"""Classify one measurement file with two per-sensor models and plot the result.

``inference`` loads a whitespace-delimited text file, runs ``compute_stft`` on
five sensor-A / sensor-B column pairs, predicts with one model per sensor and
returns, per column, the fraction of predictions that fell into each class.
``heatmap`` and the script entry point visualise that result.
"""

import json
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from joblib import load
from sklearn.base import BaseEstimator

# NOTE(review): the wildcard import is kept because this file may rely on other
# names it exports; ``np`` and ``Dict`` are now imported explicitly above.
from src.data_preprocessing import *  # noqa: F401,F403
from src.process_stft import compute_stft

# Column layout used by ``inference``: pair ``i`` reads positional column
# ``_SENSOR_A_FIRST_COL + i`` for sensor A and ``_SENSOR_B_FIRST_COL + i`` for
# sensor B.  NOTE(review): presumably dictated by the data-file format -- confirm.
_NUM_COLUMNS = 5
_SENSOR_A_FIRST_COL = 1
_SENSOR_B_FIRST_COL = 26

_DEFAULT_DAMAGE_CLASSES = (1, 2, 3, 4, 5, 6)


def probability_damage(
    pred: np.ndarray, model_classes: BaseEstimator, percentage=False
) -> Dict[Any, float]:
    """Return the share of predictions that fell into each class.

    Args:
        pred: Array of predicted labels.
        model_classes: Fitted estimator; its ``classes_`` attribute seeds the
            result so that classes never predicted still appear with ``0.0``.
        percentage: If true, values are scaled to 0-100 instead of 0-1.

    Returns:
        Mapping of label to ratio (or percentage).  Labels found in ``pred``
        but missing from ``classes_`` are included as well.  An empty ``pred``
        yields ``0.0`` for every class instead of dividing by zero.
    """
    labels, counts = np.unique(pred, return_counts=True)
    total = counts.sum()
    scale = 100.0 if percentage else 1.0

    # Start every known class at zero, then overwrite the ones actually seen.
    pod: Dict[Any, float] = dict.fromkeys(model_classes.classes_, 0.0)
    if total == 0:
        return pod

    for label, count in zip(labels, counts):
        pod[label] = float(count / total * scale)
    return pod


def convert_keys_to_strings(obj):
    """Recursively convert all dictionary keys to strings.

    Dicts and lists are rebuilt (nested ones included); any other value is
    returned unchanged.  Useful before ``json.dumps`` when keys are not
    plain Python strings.
    """
    if isinstance(obj, dict):
        return {
            str(key): convert_keys_to_strings(value)
            for key, value in obj.items()
        }
    if isinstance(obj, list):
        return [convert_keys_to_strings(item) for item in obj]
    return obj


def inference(model_sensor_A_path: str, model_sensor_B_path: str, file_path: str):
    """Predict per-column class ratios for one measurement file.

    Args:
        model_sensor_A_path: Path of the joblib-serialised model for sensor A.
        model_sensor_B_path: Path of the joblib-serialised model for sensor B.
        file_path: Whitespace-delimited text file; the first 10 lines are
            skipped and the following line is used as the header.

    Returns:
        ``{"data": {"Column_<k>": {"Sensor_A": {"PoD": ...},
        "Sensor_B": {"PoD": ...}}, ...}, "case": <name>}`` where each ``PoD``
        is the mapping produced by ``probability_damage`` (ratios, not
        percentages) and ``<name>`` is the file name up to its first dot.
    """
    column_index: List[Tuple[int, int]] = [
        (_SENSOR_A_FIRST_COL + i, _SENSOR_B_FIRST_COL + i)
        for i in range(_NUM_COLUMNS)
    ]

    # ``sep=r"\s+"`` is the documented equivalent of the deprecated
    # ``delim_whitespace=True``.
    df: pd.DataFrame = pd.read_csv(
        file_path, sep=r"\s+", skiprows=10, header=0, memory_map=True
    )

    case_name: str = file_path.split("/")[-1].split(".")[0]

    # joblib.load unpickles arbitrary objects: only load trusted model files.
    model_sensor_A = load(model_sensor_A_path)
    model_sensor_B = load(model_sensor_B_path)

    res = {}
    for i, (col_a, col_b) in enumerate(column_index):
        stft_A = compute_stft(df.iloc[:, col_a])
        stft_B = compute_stft(df.iloc[:, col_b])

        pred_A = model_sensor_A.predict(stft_A)
        pred_B = model_sensor_B.predict(stft_B)

        res[f"Column_{i+1}"] = {
            "Sensor_A": {"PoD": probability_damage(pred_A, model_sensor_A)},
            "Sensor_B": {"PoD": probability_damage(pred_B, model_sensor_B)},
        }

    return {"data": res, "case": case_name}


def heatmap(result, damage_classes: Optional[List[int]] = None):
    """Show an interpolated heat map of the sensor-A ratios in ``result``.

    Args:
        result: Mapping shaped like the return value of ``inference``.
        damage_classes: Class labels to plot along the y axis; defaults to
            ``[1, 2, 3, 4, 5, 6]``.  Classes missing from a column's ``PoD``
            are plotted as 0.

    Only ``Sensor_A`` values are drawn.  Values are clipped to ``[0, 1]``, so
    the input is expected to hold ratios rather than percentages.
    """
    # Imported lazily so the module stays importable without a plotting stack.
    import matplotlib.pyplot as plt
    from scipy.interpolate import RectBivariateSpline

    if damage_classes is None:
        damage_classes = list(_DEFAULT_DAMAGE_CLASSES)

    resolution = 300
    columns = result["data"]

    x = np.arange(len(columns))                 # one grid line per column
    y = np.arange(1, len(damage_classes) + 1)   # one grid line per class

    # Shape (len(x), len(y)), which is the layout RectBivariateSpline expects.
    z = np.array([
        [column_data['Sensor_A']['PoD'].get(cls, 0) for cls in damage_classes]
        for column_data in columns.values()
    ])

    x2 = np.linspace(0, len(columns) - 1, resolution)
    y2 = np.linspace(1, len(damage_classes), resolution)

    spline = RectBivariateSpline(x, y, z, kx=2, ky=2)  # quadratic in both axes
    # The spline can overshoot between samples; keep values inside [0, 1].
    Z2 = spline(x2, y2).T.clip(0, 1)
    X2, Y2 = np.meshgrid(x2, y2)

    c = plt.pcolormesh(X2, Y2, Z2, cmap='jet', shading='auto')
    plt.colorbar(c, label='Probability of Damage (PoD)')
    plt.gca().invert_xaxis()
    plt.grid(True, linestyle='-', alpha=0.7)
    plt.xticks(np.arange(int(X2.min()), int(X2.max())+1, 1))
    plt.xlabel("Column Index")
    plt.ylabel("Damage Index")
    plt.title(result["case"])
    plt.show()


def main() -> None:
    """Run inference on one hard-coded case and plot a line chart per sensor.

    One subplot row is drawn per column of the result, with sensor A on the
    left and sensor B on the right.  ``heatmap(result)`` or
    ``json.dumps(convert_keys_to_strings(result), indent=4)`` can be used on
    the same ``result`` for alternative views.
    """
    import matplotlib.pyplot as plt

    result = inference(
        "D:/thesis/models/Sensor A/SVM with StandardScaler and PCA.joblib",
        "D:/thesis/models/Sensor B/SVM with StandardScaler and PCA.joblib",
        "D:/thesis/data/dataset_B/zzzBD19.TXT"
    )

    damage_classes = [1, 2, 3, 4, 5, 6]
    x_values = list(range(len(damage_classes)))
    sensors = (("Sensor_A", "Sensor A"), ("Sensor_B", "Sensor B"))

    fig, axes = plt.subplots(
        nrows=len(result['data']), ncols=len(sensors), figsize=(5, 50),
        squeeze=False,  # keep ``axes`` two-dimensional for any row count
    )

    for row_idx, (column_name, column_data) in enumerate(result['data'].items()):
        for col_idx, (sensor_key, sensor_label) in enumerate(sensors):
            pod = column_data[sensor_key]['PoD']
            y_values = [pod.get(cls, 0) for cls in damage_classes]

            ax = axes[row_idx, col_idx]
            ax.plot(x_values, y_values, '-', linewidth=2, markersize=8)
            ax.set_title(f"{column_name} - {sensor_label}", fontsize=10)
            ax.set_xticks(x_values)
            ax.set_xticklabels(damage_classes)
            ax.set_ylim(0, 1.05)
            ax.set_ylabel('Probability')
            ax.set_xlabel('Damage Class')
            ax.grid(True, linestyle='-', alpha=0.5)

    fig.tight_layout(rect=[0, 0, 1, 0.96])  # leave room for the suptitle
    fig.subplots_adjust(hspace=1, wspace=0.3)
    fig.suptitle(f"Case {result['case']}", fontsize=16, y=0.98)
    plt.show()


if __name__ == "__main__":
    main()