"""Compute STFT magnitude features for vibration recordings.

For every ``DAMAGE_<n>`` folder under ``damage_base_path`` and for each of the
two sensors, the script reads all ``TEST<k>`` CSV recordings, computes the
short-time Fourier transform magnitude of each one, stacks the frames of all
test runs vertically and writes one aggregated CSV per (sensor, damage case).
Damage cases are processed in parallel with a ``multiprocessing.Pool``.
"""
# Provenance: commit 2decff0cfb1317a4de92793b1802371285bfe994 by nuluh
# (Fri, 13 Dec 2024), "Closes #24 feat(stft): Implement STFT processing for
# vibration data with multiprocessing support to include all the data for
# training process instead of just using `TEST1` only".

import glob
import multiprocessing
import os

import numpy as np
import pandas as pd
from scipy.signal import stft
# The window functions were removed from the top-level ``scipy.signal``
# namespace in SciPy 1.13; ``scipy.signal.windows`` is the supported location.
from scipy.signal.windows import hann

# Define the base directory where DAMAGE_X folders are located
damage_base_path = 'D:/thesis/data/converted/raw'

# Define output directories for each sensor
output_dirs = {
    'sensor1': os.path.join(damage_base_path, 'sensor1'),
    'sensor2': os.path.join(damage_base_path, 'sensor2'),
}

# Create output directories if they don't exist.
# NOTE: this runs at import time, so worker processes that re-import the
# module also execute it; ``exist_ok=True`` makes that harmless.
for dir_path in output_dirs.values():
    os.makedirs(dir_path, exist_ok=True)

# Define STFT parameters
window_size = 1024
hop_size = 512
window = hann(window_size)
Fs = 1024

# Number of damage cases (adjust as needed)
num_damage_cases = 6  # Change to 30 if you have 30 damage cases

# Number of test runs per damage case
num_test_runs = 5


def compute_stft(vibration_data):
    """Return the STFT magnitude of a 1-D signal.

    Args:
        vibration_data: 1-D array-like of samples. It must contain at least
            ``window_size`` samples, otherwise ``scipy.signal.stft`` raises
            ``ValueError`` because an explicit window array is supplied.

    Returns:
        2-D ``numpy.ndarray`` of magnitudes with one row per time frame and
        one column per frequency bin (``window_size // 2 + 1`` columns).
    """
    _, _, Zxx = stft(
        vibration_data,
        fs=Fs,
        window=window,
        nperseg=window_size,
        noverlap=window_size - hop_size,
    )
    # Transpose to have frequencies as columns
    return np.abs(Zxx).T


def _read_vibration_data(file_path):
    """Load the sensor column of one recording, or return None if unusable.

    A message is printed for every file that is skipped (missing, unreadable
    or with an unexpected number of columns).
    """
    if not os.path.isfile(file_path):
        print(f"File {file_path} does not exist. Skipping...")
        return None

    # Best-effort read: any parsing problem skips this file instead of
    # aborting the whole damage case.
    try:
        df = pd.read_csv(file_path)
    except Exception as e:
        print(f"Error reading {file_path}: {e}. Skipping...")
        return None

    # Ensure the CSV has exactly two columns: 'Timestamp (s)' and 'Sensor X'
    if df.shape[1] != 2:
        print(f"Unexpected number of columns in {file_path}. Expected 2, got {df.shape[1]}. Skipping...")
        return None

    # The second column is assumed to hold the sensor data
    return df.iloc[:, 1].values


def process_damage_case(damage_num):
    """Aggregate the STFT of all test runs of one damage case, per sensor.

    For each sensor (1 and 2) the STFT magnitude frames of every readable
    ``DAMAGE_<damage_num>_TEST<k>_0<sensor>.csv`` file are concatenated
    vertically and written to
    ``output_dirs['sensor<sensor>']/stft_data<sensor>_<damage_num>.csv``.

    Args:
        damage_num: Number identifying the ``DAMAGE_<damage_num>`` folder.

    Returns:
        None. Progress and skipped inputs are reported with ``print``.
    """
    damage_folder = os.path.join(damage_base_path, f'DAMAGE_{damage_num}')

    # Check if the damage folder exists
    if not os.path.isdir(damage_folder):
        print(f"Folder {damage_folder} does not exist. Skipping...")
        return

    # Process Sensor 1 and Sensor 2 separately
    for sensor_num in (1, 2):
        aggregated_stft = []  # STFT frames from all test runs

        for test_num in range(1, num_test_runs + 1):
            # Sensor 1 corresponds to '_01', Sensor 2 corresponds to '_02'
            sensor_suffix = f'_0{sensor_num}'
            file_name = f'DAMAGE_{damage_num}_TEST{test_num}{sensor_suffix}.csv'
            file_path = os.path.join(damage_folder, file_name)

            vibration_data = _read_vibration_data(file_path)
            if vibration_data is None:
                continue

            # stft() raises ValueError when the explicit window is longer
            # than the signal; skip such recordings instead of crashing the
            # worker (and with it the whole pool.map call).
            if len(vibration_data) < window_size:
                print(
                    f"Signal in {file_path} has {len(vibration_data)} samples, "
                    f"fewer than the STFT window ({window_size}). Skipping..."
                )
                continue

            stft_magnitude = compute_stft(vibration_data)

            # Label each column with its bin frequency (0 .. Fs/2)
            df_stft = pd.DataFrame(
                stft_magnitude,
                columns=[f"Freq_{freq:.2f}" for freq in np.linspace(0, Fs/2, stft_magnitude.shape[1])]
            )
            aggregated_stft.append(df_stft)

        if not aggregated_stft:
            print(f"No STFT data aggregated for Sensor {sensor_num}, Damage {damage_num}.")
            continue

        # Concatenate all STFT DataFrames vertically
        df_aggregated = pd.concat(aggregated_stft, ignore_index=True)

        output_file = os.path.join(
            output_dirs[f'sensor{sensor_num}'],
            f'stft_data{sensor_num}_{damage_num}.csv'
        )

        # Save the aggregated STFT to CSV
        df_aggregated.to_csv(output_file, index=False)
        print(f"Saved aggregated STFT for Sensor {sensor_num}, Damage {damage_num} to {output_file}")


if __name__ == "__main__":  # Main guard is required for multiprocessing
    with multiprocessing.Pool() as pool:
        pool.map(process_damage_case, range(1, num_damage_cases + 1))