refactor(ml): clean up model_selection.py by removing unused code and improving function structure

This commit is contained in:
nuluh
2025-07-18 19:27:46 +07:00
parent 18824e05c0
commit f6c71739df

View File

@@ -8,7 +8,7 @@ from joblib import load
def create_ready_data(
stft_data_path: str,
stratify: np.ndarray = None,
) -> tuple:
) -> tuple[pd.DataFrame, np.ndarray]:
"""
Build the combined dataset and corresponding labels from STFT data.
@@ -22,7 +22,7 @@ def create_ready_data(
Returns:
--------
tuple
(X_train, X_test, y_train, y_test) - Split datasets
(pd.DataFrame, np.ndarray) - Combined data and corresponding labels
"""
ready_data = []
for file in os.listdir(stft_data_path):
@@ -155,7 +155,7 @@ def train_and_evaluate_model(
except Exception as e:
result["error"] = f"Training error: {str(e)}"
return result
def plot_confusion_matrix(results_sensor, y_test):
def plot_confusion_matrix(results_sensor, y_test, title):
"""
Plot confusion matrices for each model in results_sensor.
@@ -193,8 +193,7 @@ def plot_confusion_matrix(results_sensor, y_test):
# Plot
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp.plot(cmap=plt.cm.Blues) # You can change colormap
plt.title(f"{i['model']} {i['sensor']} Test")
plt.show()
plt.title(f"{title}")
def calculate_label_percentages(labels):
"""
@@ -255,9 +254,9 @@ def inference_model(
fs=1024,
window=hann(1024),
nperseg=1024,
noverlap=512
noverlap=1024-512
)
data = pd.DataFrame(np.abs(Zxx).T, columns=[f"Freq_{freq:.2f}" for freq in np.linspace(0, 1024/2, Zxx.shape[1])])
data = data.rename(columns={"Freq_0.00": "00"}) # To match the model input format
model = load(models) # Load the model from the provided path
return calculate_label_percentages(model.predict(data))
return calculate_label_percentages(model.predict(data.iloc[:21,:]))