feat(src): Add confusion matrix plotting function for model evaluation
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import os
|
||||
from sklearn.model_selection import train_test_split as sklearn_split
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
|
||||
from joblib import load
|
||||
|
||||
def create_ready_data(
|
||||
stft_data_path: str,
|
||||
@@ -121,6 +122,7 @@ def train_and_evaluate_model(
|
||||
|
||||
try:
|
||||
y_pred = model.predict(x_test)
|
||||
result["y_pred"] = y_pred # Convert to numpy array
|
||||
except Exception as e:
|
||||
result["error"] = f"Prediction error: {str(e)}"
|
||||
return result
|
||||
@@ -153,3 +155,45 @@ def train_and_evaluate_model(
|
||||
except Exception as e:
|
||||
result["error"] = f"Training error: {str(e)}"
|
||||
return result
|
||||
def plot_confusion_matrix(results_sensor, x_test, y_test):
|
||||
"""
|
||||
Plot confusion matrices for each model in results_sensor1.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
results_sensor1 : list
|
||||
List of dictionaries containing model results.
|
||||
x_test1 : array-like
|
||||
Test input samples.
|
||||
y_test : array-like
|
||||
True labels for the test samples.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
None
|
||||
This function will display confusion matrices for each model in results_sensor1.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> results_sensor1 = [
|
||||
... {'model': 'model1', 'accuracy': 95.0},
|
||||
... {'model': 'model2', 'accuracy': 90.0}
|
||||
... ]
|
||||
>>> x_test1 = np.random.rand(100, 10) # Example test data
|
||||
>>> y_test = np.random.randint(0, 2, size=100) # Example true labels
|
||||
>>> plot_confusion_matrix(results_sensor1, x_test1, y_test)
|
||||
"""
|
||||
# Iterate through each model result and plot confusion matrix
|
||||
for i in results_sensor:
|
||||
model = load(f"D:/thesis/models/{i['sensor']}/{i['model']}.joblib")
|
||||
y_pred = model.predict(x_test)
|
||||
cm = confusion_matrix(y_test, y_pred) # -> ndarray
|
||||
|
||||
# get the class labels
|
||||
labels = model.classes_
|
||||
|
||||
# Plot
|
||||
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
|
||||
disp.plot(cmap=plt.cm.Blues) # You can change colormap
|
||||
plt.title(f"{i['model']} {i['sensor']} Test")
|
||||
plt.show()
|
||||
Reference in New Issue
Block a user