feat(src): Add confusion matrix plotting function for model evaluation

This commit is contained in:
nuluh
2025-06-24 00:27:15 +07:00
parent 6196523ea0
commit 114ab849b9

View File

@@ -1,8 +1,9 @@
import numpy as np
import pandas as pd
import os
from sklearn.model_selection import train_test_split as sklearn_split
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from joblib import load
def create_ready_data(
stft_data_path: str,
@@ -121,6 +122,7 @@ def train_and_evaluate_model(
try:
y_pred = model.predict(x_test)
result["y_pred"] = y_pred # Convert to numpy array
except Exception as e:
result["error"] = f"Prediction error: {str(e)}"
return result
@@ -153,3 +155,45 @@ def train_and_evaluate_model(
except Exception as e:
result["error"] = f"Training error: {str(e)}"
return result
def plot_confusion_matrix(results_sensor, x_test, y_test):
"""
Plot confusion matrices for each model in results_sensor1.
Parameters:
-----------
results_sensor1 : list
List of dictionaries containing model results.
x_test1 : array-like
Test input samples.
y_test : array-like
True labels for the test samples.
Returns:
--------
None
This function will display confusion matrices for each model in results_sensor1.
Example
-------
>>> results_sensor1 = [
... {'model': 'model1', 'accuracy': 95.0},
... {'model': 'model2', 'accuracy': 90.0}
... ]
>>> x_test1 = np.random.rand(100, 10) # Example test data
>>> y_test = np.random.randint(0, 2, size=100) # Example true labels
>>> plot_confusion_matrix(results_sensor1, x_test1, y_test)
"""
# Iterate through each model result and plot confusion matrix
for i in results_sensor:
model = load(f"D:/thesis/models/{i['sensor']}/{i['model']}.joblib")
y_pred = model.predict(x_test)
cm = confusion_matrix(y_test, y_pred) # -> ndarray
# get the class labels
labels = model.classes_
# Plot
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp.plot(cmap=plt.cm.Blues) # You can change colormap
plt.title(f"{i['model']} {i['sensor']} Test")
plt.show()