feat(notebooks): Enhance STFT notebook and model selection functionality
- Updated paths in the STFT notebook to reflect new data files. - Improved plotting aesthetics for combined plots and added grid lines. - Introduced a 3D spectrogram visualization for better data representation. - Refactored model training function to include error handling and model export functionality. - Adjusted model training calls to include export paths for saved models. Closes #90 - Added additional markdown cells for better documentation and clarity in the notebook.
This commit is contained in:
@@ -55,3 +55,101 @@ def create_ready_data(
|
||||
y = np.array([])
|
||||
|
||||
return X, y
|
||||
|
||||
|
||||
def train_and_evaluate_model(
|
||||
model, model_name, sensor_label, x_train, y_train, x_test, y_test, export=None
|
||||
):
|
||||
"""
|
||||
Train a machine learning model, evaluate its performance, and optionally export it.
|
||||
|
||||
This function trains the provided model on the training data, evaluates its
|
||||
performance on test data using accuracy score, and can save the trained model
|
||||
to disk if an export path is provided.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : estimator object
|
||||
The machine learning model to train.
|
||||
model_name : str
|
||||
Name of the model, used for the export filename and in the returned results.
|
||||
sensor_label : str
|
||||
Label identifying which sensor's data the model is being trained on.
|
||||
x_train : array-like or pandas.DataFrame
|
||||
The training input samples.
|
||||
y_train : array-like
|
||||
The target values for training.
|
||||
x_test : array-like or pandas.DataFrame
|
||||
The test input samples.
|
||||
y_test : array-like
|
||||
The target values for testing.
|
||||
export : str, optional
|
||||
Directory path where the trained model should be saved. If None, model won't be saved.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
Dictionary containing:
|
||||
- 'model': model_name (str)
|
||||
- 'sensor': sensor_label (str)
|
||||
- 'accuracy': accuracy percentage (float)
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> from sklearn.svm import SVC
|
||||
>>> from sklearn.model_selection import train_test_split
|
||||
>>> X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2)
|
||||
>>> result = train_and_evaluate_model(
|
||||
... SVC(),
|
||||
... "SVM",
|
||||
... "sensor1",
|
||||
... X_train,
|
||||
... y_train,
|
||||
... X_test,
|
||||
... y_test,
|
||||
... export="models/sensor1"
|
||||
... )
|
||||
>>> print(f"Model accuracy: {result['accuracy']:.2f}%")
|
||||
"""
|
||||
from sklearn.metrics import accuracy_score
|
||||
|
||||
result = {"model": model_name, "sensor": sensor_label, "success": False}
|
||||
|
||||
try:
|
||||
# Train the model
|
||||
model.fit(x_train, y_train)
|
||||
|
||||
try:
|
||||
y_pred = model.predict(x_test)
|
||||
except Exception as e:
|
||||
result["error"] = f"Prediction error: {str(e)}"
|
||||
return result
|
||||
|
||||
# Calculate accuracy
|
||||
try:
|
||||
accuracy = accuracy_score(y_test, y_pred) * 100
|
||||
result["accuracy"] = accuracy
|
||||
except Exception as e:
|
||||
result["error"] = f"Accuracy calculation error: {str(e)}"
|
||||
return result
|
||||
|
||||
# Export model if requested
|
||||
if export:
|
||||
try:
|
||||
import joblib
|
||||
|
||||
full_path = os.path.join(export, f"{model_name}.joblib")
|
||||
os.makedirs(os.path.dirname(full_path), exist_ok=True)
|
||||
joblib.dump(model, full_path)
|
||||
print(f"Model saved to {full_path}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to export model to {export}: {str(e)}")
|
||||
result["export_error"] = str(e)
|
||||
# Continue despite export error
|
||||
|
||||
result["success"] = True
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
result["error"] = f"Training error: {str(e)}"
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user