From 7da3179d0863c827fd92f01cff360b12ecc852a5 Mon Sep 17 00:00:00 2001 From: nuluh Date: Thu, 29 May 2025 22:57:28 +0700 Subject: [PATCH] refactor(nb): Create and implement helper function `train_and_evaluate_model` --- code/notebooks/stft.ipynb | 267 +++++++++++++++----------------------- 1 file changed, 104 insertions(+), 163 deletions(-) diff --git a/code/notebooks/stft.ipynb b/code/notebooks/stft.ipynb index b018033..e32eda5 100644 --- a/code/notebooks/stft.ipynb +++ b/code/notebooks/stft.ipynb @@ -537,8 +537,8 @@ "metadata": {}, "outputs": [], "source": [ - "# len(y_data[0])\n", - "y_data" + "len(y_data[0])\n", + "# y_data" ] }, { @@ -621,137 +621,15 @@ "metadata": {}, "outputs": [], "source": [ - "accuracies1 = []\n", - "accuracies2 = []\n", - "\n", - "\n", - "# 1. Random Forest\n", - "rf_model1 = RandomForestClassifier()\n", - "rf_model1.fit(x_train1, y_train)\n", - "rf_pred1 = rf_model1.predict(x_test1)\n", - "acc1 = accuracy_score(y_test, rf_pred1) * 100\n", - "accuracies1.append(acc1)\n", - "# format with color coded if acc1 > 90\n", - "acc1 = f\"\\033[92m{acc1:.2f}\\033[00m\" if acc1 > 90 else f\"{acc1:.2f}\"\n", - "print(\"Random Forest Accuracy for sensor 1:\", acc1)\n", - "rf_model2 = RandomForestClassifier()\n", - "rf_model2.fit(x_train2, y_train)\n", - "rf_pred2 = rf_model2.predict(x_test2)\n", - "acc2 = accuracy_score(y_test, rf_pred2) * 100\n", - "accuracies2.append(acc2)\n", - "# format with color coded if acc2 > 90\n", - "acc2 = f\"\\033[92m{acc2:.2f}\\033[00m\" if acc2 > 90 else f\"{acc2:.2f}\"\n", - "print(\"Random Forest Accuracy for sensor 2:\", acc2)\n", - "# print(rf_pred)\n", - "# print(y_test)\n", - "\n", - "# 2. Bagged Trees\n", - "bagged_model1 = BaggingClassifier(estimator=DecisionTreeClassifier(), n_estimators=10)\n", - "bagged_model1.fit(x_train1, y_train)\n", - "bagged_pred1 = bagged_model1.predict(x_test1)\n", - "acc1 = accuracy_score(y_test, bagged_pred1) * 100\n", - "accuracies1.append(acc1)\n", - "# format with color coded if acc1 > 90\n", - "acc1 = f\"\\033[92m{acc1:.2f}\\033[00m\" if acc1 > 90 else f\"{acc1:.2f}\"\n", - "print(\"Bagged Trees Accuracy for sensor 1:\", acc1)\n", - "bagged_model2 = BaggingClassifier(estimator=DecisionTreeClassifier(), n_estimators=10)\n", - "bagged_model2.fit(x_train2, y_train)\n", - "bagged_pred2 = bagged_model2.predict(x_test2)\n", - "acc2 = accuracy_score(y_test, bagged_pred2) * 100\n", - "accuracies2.append(acc2)\n", - "# format with color coded if acc2 > 90\n", - "acc2 = f\"\\033[92m{acc2:.2f}\\033[00m\" if acc2 > 90 else f\"{acc2:.2f}\"\n", - "print(\"Bagged Trees Accuracy for sensor 2:\", acc2)\n", - "\n", - "# 3. Decision Tree\n", - "dt_model = DecisionTreeClassifier()\n", - "dt_model.fit(x_train1, y_train)\n", - "dt_pred1 = dt_model.predict(x_test1)\n", - "acc1 = accuracy_score(y_test, dt_pred1) * 100\n", - "accuracies1.append(acc1)\n", - "# format with color coded if acc1 > 90\n", - "acc1 = f\"\\033[92m{acc1:.2f}\\033[00m\" if acc1 > 90 else f\"{acc1:.2f}\"\n", - "print(\"Decision Tree Accuracy for sensor 1:\", acc1)\n", - "dt_model2 = DecisionTreeClassifier()\n", - "dt_model2.fit(x_train2, y_train)\n", - "dt_pred2 = dt_model2.predict(x_test2)\n", - "acc2 = accuracy_score(y_test, dt_pred2) * 100\n", - "accuracies2.append(acc2)\n", - "# format with color coded if acc2 > 90\n", - "acc2 = f\"\\033[92m{acc2:.2f}\\033[00m\" if acc2 > 90 else f\"{acc2:.2f}\"\n", - "print(\"Decision Tree Accuracy for sensor 2:\", acc2)\n", - "\n", - "# 4. KNeighbors\n", - "knn_model = KNeighborsClassifier()\n", - "knn_model.fit(x_train1, y_train)\n", - "knn_pred1 = knn_model.predict(x_test1)\n", - "acc1 = accuracy_score(y_test, knn_pred1) * 100\n", - "accuracies1.append(acc1)\n", - "# format with color coded if acc1 > 90\n", - "acc1 = f\"\\033[92m{acc1:.2f}\\033[00m\" if acc1 > 90 else f\"{acc1:.2f}\"\n", - "print(\"KNeighbors Accuracy for sensor 1:\", acc1)\n", - "knn_model2 = KNeighborsClassifier()\n", - "knn_model2.fit(x_train2, y_train)\n", - "knn_pred2 = knn_model2.predict(x_test2)\n", - "acc2 = accuracy_score(y_test, knn_pred2) * 100\n", - "accuracies2.append(acc2)\n", - "# format with color coded if acc2 > 90\n", - "acc2 = f\"\\033[92m{acc2:.2f}\\033[00m\" if acc2 > 90 else f\"{acc2:.2f}\"\n", - "print(\"KNeighbors Accuracy for sensor 2:\", acc2)\n", - "\n", - "# 5. Linear Discriminant Analysis\n", - "lda_model = LinearDiscriminantAnalysis()\n", - "lda_model.fit(x_train1, y_train)\n", - "lda_pred1 = lda_model.predict(x_test1)\n", - "acc1 = accuracy_score(y_test, lda_pred1) * 100\n", - "accuracies1.append(acc1)\n", - "# format with color coded if acc1 > 90\n", - "acc1 = f\"\\033[92m{acc1:.2f}\\033[00m\" if acc1 > 90 else f\"{acc1:.2f}\"\n", - "print(\"Linear Discriminant Analysis Accuracy for sensor 1:\", acc1)\n", - "lda_model2 = LinearDiscriminantAnalysis()\n", - "lda_model2.fit(x_train2, y_train)\n", - "lda_pred2 = lda_model2.predict(x_test2)\n", - "acc2 = accuracy_score(y_test, lda_pred2) * 100\n", - "accuracies2.append(acc2)\n", - "# format with color coded if acc2 > 90\n", - "acc2 = f\"\\033[92m{acc2:.2f}\\033[00m\" if acc2 > 90 else f\"{acc2:.2f}\"\n", - "print(\"Linear Discriminant Analysis Accuracy for sensor 2:\", acc2)\n", - "\n", - "# 6. Support Vector Machine\n", - "svm_model = SVC()\n", - "svm_model.fit(x_train1, y_train)\n", - "svm_pred1 = svm_model.predict(x_test1)\n", - "acc1 = accuracy_score(y_test, svm_pred1) * 100\n", - "accuracies1.append(acc1)\n", - "# format with color coded if acc1 > 90\n", - "acc1 = f\"\\033[92m{acc1:.2f}\\033[00m\" if acc1 > 90 else f\"{acc1:.2f}\"\n", - "print(\"Support Vector Machine Accuracy for sensor 1:\", acc1)\n", - "svm_model2 = SVC()\n", - "svm_model2.fit(x_train2, y_train)\n", - "svm_pred2 = svm_model2.predict(x_test2)\n", - "acc2 = accuracy_score(y_test, svm_pred2) * 100\n", - "accuracies2.append(acc2)\n", - "# format with color coded if acc2 > 90\n", - "acc2 = f\"\\033[92m{acc2:.2f}\\033[00m\" if acc2 > 90 else f\"{acc2:.2f}\"\n", - "print(\"Support Vector Machine Accuracy for sensor 2:\", acc2)\n", - "\n", - "# 7. XGBoost\n", - "xgboost_model = XGBClassifier()\n", - "xgboost_model.fit(x_train1, y_train)\n", - "xgboost_pred1 = xgboost_model.predict(x_test1)\n", - "acc1 = accuracy_score(y_test, xgboost_pred1) * 100\n", - "accuracies1.append(acc1)\n", - "# format with color coded if acc1 > 90\n", - "acc1 = f\"\\033[92m{acc1:.2f}\\033[00m\" if acc1 > 90 else f\"{acc1:.2f}\"\n", - "print(\"XGBoost Accuracy:\", acc1)\n", - "xgboost_model2 = XGBClassifier()\n", - "xgboost_model2.fit(x_train2, y_train)\n", - "xgboost_pred2 = xgboost_model2.predict(x_test2)\n", - "acc2 = accuracy_score(y_test, xgboost_pred2) * 100\n", - "accuracies2.append(acc2)\n", - "# format with color coded if acc2 > 90\n", - "acc2 = f\"\\033[92m{acc2:.2f}\\033[00m\" if acc2 > 90 else f\"{acc2:.2f}\"\n", - "print(\"XGBoost Accuracy:\", acc2)" + "def train_and_evaluate_model(model, model_name, sensor_label, x_train, y_train, x_test, y_test):\n", + " model.fit(x_train, y_train)\n", + " y_pred = model.predict(x_test)\n", + " accuracy = accuracy_score(y_test, y_pred) * 100\n", + " return {\n", + " \"model\": model_name,\n", + " \"sensor\": sensor_label,\n", + " \"accuracy\": accuracy\n", + " }" ] }, { @@ -760,8 +638,59 @@ "metadata": {}, "outputs": [], "source": [ - "print(accuracies1)\n", - "print(accuracies2)" + "# Define models for sensor1\n", + "models_sensor1 = {\n", + " # \"Random Forest\": RandomForestClassifier(),\n", + " # \"Bagged Trees\": BaggingClassifier(estimator=DecisionTreeClassifier(), n_estimators=10),\n", + " # \"Decision Tree\": DecisionTreeClassifier(),\n", + " # \"KNN\": KNeighborsClassifier(),\n", + " # \"LDA\": LinearDiscriminantAnalysis(),\n", + " \"SVM\": SVC(),\n", + " \"XGBoost\": XGBClassifier()\n", + "}\n", + "\n", + "results_sensor1 = []\n", + "for name, model in models_sensor1.items():\n", + " res = train_and_evaluate_model(model, name, \"sensor1\", x_train1, y_train, x_test1, y_test)\n", + " results_sensor1.append(res)\n", + " print(f\"{name} on sensor1: Accuracy = {res['accuracy']:.2f}%\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "models_sensor2 = {\n", + " # \"Random Forest\": RandomForestClassifier(),\n", + " # \"Bagged Trees\": BaggingClassifier(estimator=DecisionTreeClassifier(), n_estimators=10),\n", + " # \"Decision Tree\": DecisionTreeClassifier(),\n", + " # \"KNN\": KNeighborsClassifier(),\n", + " # \"LDA\": LinearDiscriminantAnalysis(),\n", + " \"SVM\": SVC(),\n", + " \"XGBoost\": XGBClassifier()\n", + "}\n", + "\n", + "results_sensor2 = []\n", + "for name, model in models_sensor2.items():\n", + " res = train_and_evaluate_model(model, name, \"sensor2\", x_train2, y_train, x_test2, y_test)\n", + " results_sensor2.append(res)\n", + " print(f\"{name} on sensor2: Accuracy = {res['accuracy']:.2f}%\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "all_results = {\n", + " \"sensor1\": results_sensor1,\n", + " \"sensor2\": results_sensor2\n", + "}\n", + "\n", + "print(all_results)" ] }, { @@ -773,36 +702,48 @@ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", - "models = [rf_model, bagged_model, dt_model, knn_model, lda_model, svm_model, xgboost_model]\n", - "model_names = [\"Random Forest\", \"Bagged Trees\", \"Decision Tree\", \"KNN\", \"LDA\", \"SVM\", \"XGBoost\"]\n", + "def prepare_plot_data(results_dict):\n", + " # Gather unique model names\n", + " models_set = {entry['model'] for sensor in results_dict.values() for entry in sensor}\n", + " models = sorted(list(models_set))\n", + " \n", + " # Create dictionaries mapping sensor -> accuracy list ordered by model name\n", + " sensor_accuracies = {}\n", + " for sensor, entries in results_dict.items():\n", + " # Build a mapping: model -> accuracy for the given sensor\n", + " mapping = {entry['model']: entry['accuracy'] for entry in entries}\n", + " # Order the accuracies consistent with the sorted model names\n", + " sensor_accuracies[sensor] = [mapping.get(model, 0) for model in models]\n", + " \n", + " return models, sensor_accuracies\n", "\n", - "bar_width = 0.35 # Width of each bar\n", - "index = np.arange(len(model_names)) # Index for the bars\n", + "def plot_accuracies(models, sensor_accuracies):\n", + " bar_width = 0.35\n", + " x = np.arange(len(models))\n", + " sensors = list(sensor_accuracies.keys())\n", + " \n", + " plt.figure(figsize=(10, 6))\n", + " # Assume two sensors for plotting grouped bars\n", + " plt.bar(x - bar_width/2, sensor_accuracies[sensors[0]], width=bar_width, color='blue', label=sensors[0])\n", + " plt.bar(x + bar_width/2, sensor_accuracies[sensors[1]], width=bar_width, color='orange', label=sensors[1])\n", + " \n", + " # Add text labels on top of bars\n", + " for i, (a1, a2) in enumerate(zip(sensor_accuracies[sensors[0]], sensor_accuracies[sensors[1]])):\n", + " plt.text(x[i] - bar_width/2, a1 + 0.1, f\"{a1:.2f}%\", ha='center', va='bottom', color='black')\n", + " plt.text(x[i] + bar_width/2, a2 + 0.1, f\"{a2:.2f}%\", ha='center', va='bottom', color='black')\n", + " \n", + " plt.xlabel('Model Name')\n", + " plt.ylabel('Accuracy (%)')\n", + " plt.title('Accuracy of Classifiers for Each Sensor')\n", + " plt.xticks(x, models)\n", + " plt.legend()\n", + " plt.ylim(0, 105)\n", + " plt.tight_layout()\n", + " plt.show()\n", "\n", - "# Plotting the bar graph\n", - "plt.figure(figsize=(14, 8))\n", - "\n", - "# Bar plot for Sensor 1\n", - "plt.bar(index, accuracies1, width=bar_width, color='blue', label='Sensor 1')\n", - "\n", - "# Bar plot for Sensor 2\n", - "plt.bar(index + bar_width, accuracies2, width=bar_width, color='orange', label='Sensor 2')\n", - "\n", - "# Add values on top of each bar\n", - "for i, acc1, acc2 in zip(index, accuracies1, accuracies2):\n", - " plt.text(i, acc1 + .1, f'{acc1:.2f}%', ha='center', va='bottom', color='black')\n", - " plt.text(i + bar_width, acc2 + 1, f'{acc2:.2f}%', ha='center', va='bottom', color='black')\n", - "\n", - "# Customize the plot\n", - "plt.xlabel('Model Name →')\n", - "plt.ylabel('Accuracy →')\n", - "plt.title('Accuracy of classifiers for Sensors 1 and 2 with 513 features')\n", - "plt.xticks(index + bar_width / 2, model_names) # Set x-tick positions\n", - "plt.legend()\n", - "plt.ylim(0, 100)\n", - "\n", - "# Show the plot\n", - "plt.show()\n" + "# Use the functions\n", + "models, sensor_accuracies = prepare_plot_data(all_results)\n", + "plot_accuracies(models, sensor_accuracies)\n" ] }, {