refactor(ml): clean up model_selection.py by removing unused code and improving function structure
This commit is contained in:
@@ -8,7 +8,7 @@ from joblib import load
|
|||||||
def create_ready_data(
|
def create_ready_data(
|
||||||
stft_data_path: str,
|
stft_data_path: str,
|
||||||
stratify: np.ndarray = None,
|
stratify: np.ndarray = None,
|
||||||
) -> tuple:
|
) -> tuple[pd.DataFrame, np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Create a stratified train-test split from STFT data.
|
Create a stratified train-test split from STFT data.
|
||||||
|
|
||||||
@@ -22,7 +22,7 @@ def create_ready_data(
|
|||||||
Returns:
|
Returns:
|
||||||
--------
|
--------
|
||||||
tuple
|
tuple
|
||||||
(X_train, X_test, y_train, y_test) - Split datasets
|
(pd.DataFrame, np.ndarray) - Combined data and corresponding labels
|
||||||
"""
|
"""
|
||||||
ready_data = []
|
ready_data = []
|
||||||
for file in os.listdir(stft_data_path):
|
for file in os.listdir(stft_data_path):
|
||||||
@@ -155,7 +155,7 @@ def train_and_evaluate_model(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
result["error"] = f"Training error: {str(e)}"
|
result["error"] = f"Training error: {str(e)}"
|
||||||
return result
|
return result
|
||||||
def plot_confusion_matrix(results_sensor, y_test):
|
def plot_confusion_matrix(results_sensor, y_test, title):
|
||||||
"""
|
"""
|
||||||
Plot confusion matrices for each model in results_sensor1.
|
Plot confusion matrices for each model in results_sensor1.
|
||||||
|
|
||||||
@@ -193,8 +193,7 @@ def plot_confusion_matrix(results_sensor, y_test):
|
|||||||
# Plot
|
# Plot
|
||||||
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
|
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
|
||||||
disp.plot(cmap=plt.cm.Blues) # You can change colormap
|
disp.plot(cmap=plt.cm.Blues) # You can change colormap
|
||||||
plt.title(f"{i['model']} {i['sensor']} Test")
|
plt.title(f"{title}")
|
||||||
plt.show()
|
|
||||||
|
|
||||||
def calculate_label_percentages(labels):
|
def calculate_label_percentages(labels):
|
||||||
"""
|
"""
|
||||||
@@ -255,9 +254,9 @@ def inference_model(
|
|||||||
fs=1024,
|
fs=1024,
|
||||||
window=hann(1024),
|
window=hann(1024),
|
||||||
nperseg=1024,
|
nperseg=1024,
|
||||||
noverlap=512
|
noverlap=1024-512
|
||||||
)
|
)
|
||||||
data = pd.DataFrame(np.abs(Zxx).T, columns=[f"Freq_{freq:.2f}" for freq in np.linspace(0, 1024/2, Zxx.shape[1])])
|
data = pd.DataFrame(np.abs(Zxx).T, columns=[f"Freq_{freq:.2f}" for freq in np.linspace(0, 1024/2, Zxx.shape[1])])
|
||||||
data = data.rename(columns={"Freq_0.00": "00"}) # To match the model input format
|
data = data.rename(columns={"Freq_0.00": "00"}) # To match the model input format
|
||||||
model = load(models) # Load the model from the provided path
|
model = load(models) # Load the model from the provided path
|
||||||
return calculate_label_percentages(model.predict(data))
|
return calculate_label_percentages(model.predict(data.iloc[:21,:]))
|
||||||
Reference in New Issue
Block a user