diff --git a/code/src/ml/model_selection.py b/code/src/ml/model_selection.py index 3dc73ff..51c9f9b 100644 --- a/code/src/ml/model_selection.py +++ b/code/src/ml/model_selection.py @@ -1,8 +1,9 @@ import numpy as np import pandas as pd import os -from sklearn.model_selection import train_test_split as sklearn_split - +import matplotlib.pyplot as plt +from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay +from joblib import load def create_ready_data( stft_data_path: str, @@ -121,6 +122,7 @@ def train_and_evaluate_model( try: y_pred = model.predict(x_test) + result["y_pred"] = y_pred # Convert to numpy array except Exception as e: result["error"] = f"Prediction error: {str(e)}" return result @@ -153,3 +155,45 @@ def train_and_evaluate_model( except Exception as e: result["error"] = f"Training error: {str(e)}" return result +def plot_confusion_matrix(results_sensor, x_test, y_test): + """ + Plot confusion matrices for each model in results_sensor1. + + Parameters: + ----------- + results_sensor1 : list + List of dictionaries containing model results. + x_test1 : array-like + Test input samples. + y_test : array-like + True labels for the test samples. + + Returns: + -------- + None + This function will display confusion matrices for each model in results_sensor1. + + Example + ------- + >>> results_sensor1 = [ + ... {'model': 'model1', 'accuracy': 95.0}, + ... {'model': 'model2', 'accuracy': 90.0} + ... ] + >>> x_test1 = np.random.rand(100, 10) # Example test data + >>> y_test = np.random.randint(0, 2, size=100) # Example true labels + >>> plot_confusion_matrix(results_sensor1, x_test1, y_test) + """ + # Iterate through each model result and plot confusion matrix + for i in results_sensor: + model = load(f"D:/thesis/models/{i['sensor']}/{i['model']}.joblib") + y_pred = model.predict(x_test) + cm = confusion_matrix(y_test, y_pred) # -> ndarray + + # get the class labels + labels = model.classes_ + + # Plot + disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels) + disp.plot(cmap=plt.cm.Blues) # You can change colormap + plt.title(f"{i['model']} {i['sensor']} Test") + plt.show() \ No newline at end of file