feat(src): Add confusion matrix plotting function for model evaluation
This commit is contained in:
@@ -1,8 +1,9 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import os
|
import os
|
||||||
from sklearn.model_selection import train_test_split as sklearn_split
|
import matplotlib.pyplot as plt
|
||||||
|
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
|
||||||
|
from joblib import load
|
||||||
|
|
||||||
def create_ready_data(
|
def create_ready_data(
|
||||||
stft_data_path: str,
|
stft_data_path: str,
|
||||||
@@ -121,6 +122,7 @@ def train_and_evaluate_model(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
y_pred = model.predict(x_test)
|
y_pred = model.predict(x_test)
|
||||||
|
result["y_pred"] = y_pred # Convert to numpy array
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
result["error"] = f"Prediction error: {str(e)}"
|
result["error"] = f"Prediction error: {str(e)}"
|
||||||
return result
|
return result
|
||||||
@@ -153,3 +155,45 @@ def train_and_evaluate_model(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
result["error"] = f"Training error: {str(e)}"
|
result["error"] = f"Training error: {str(e)}"
|
||||||
return result
|
return result
|
||||||
|
def plot_confusion_matrix(results_sensor, x_test, y_test, models_dir="D:/thesis/models"):
    """
    Plot a confusion matrix for each trained model listed in ``results_sensor``.

    For every entry, the fitted estimator is loaded from
    ``{models_dir}/{sensor}/{model}.joblib``, used to predict on ``x_test``,
    and its confusion matrix against ``y_test`` is displayed with matplotlib.

    Parameters
    ----------
    results_sensor : list of dict
        Model results; each dict must contain at least the keys
        ``'sensor'`` and ``'model'`` (both used to locate the saved model).
    x_test : array-like
        Test input samples fed to each model's ``predict``.
    y_test : array-like
        True labels for the test samples.
    models_dir : str, optional
        Root directory holding the saved ``.joblib`` models. Defaults to
        ``"D:/thesis/models"`` to preserve the previous hard-coded behavior.

    Returns
    -------
    None
        One matplotlib figure is shown per model (``plt.show()``).

    Example
    -------
    >>> results_sensor = [
    ...     {'sensor': 'sensor1', 'model': 'model1', 'accuracy': 95.0},
    ...     {'sensor': 'sensor1', 'model': 'model2', 'accuracy': 90.0},
    ... ]
    >>> x_test = np.random.rand(100, 10)          # example test data
    >>> y_test = np.random.randint(0, 2, size=100)  # example true labels
    >>> plot_confusion_matrix(results_sensor, x_test, y_test)
    """
    for entry in results_sensor:
        # Load the fitted estimator saved by the training step.
        model = load(f"{models_dir}/{entry['sensor']}/{entry['model']}.joblib")

        y_pred = model.predict(x_test)
        cm = confusion_matrix(y_test, y_pred)  # -> ndarray of counts

        # Use the estimator's own class order so axis labels match the matrix rows/cols.
        labels = model.classes_

        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
        disp.plot(cmap=plt.cm.Blues)
        plt.title(f"{entry['model']} {entry['sensor']} Test")
        plt.show()
|
||||||
Reference in New Issue
Block a user