From a8288b1426fe3a27fa947dc48a51b1e0070eebf2 Mon Sep 17 00:00:00 2001 From: nuluh Date: Mon, 11 Aug 2025 13:15:48 +0700 Subject: [PATCH] refactor(src): add type hints and a docstring to compute_stft, and move the DataFrame conversion and column naming from `process_damage_case` into `compute_stft` --- code/src/process_stft.py | 47 ++++++++++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 12 deletions(-) diff --git a/code/src/process_stft.py b/code/src/process_stft.py index e100a8d..113c977 100644 --- a/code/src/process_stft.py +++ b/code/src/process_stft.py @@ -5,6 +5,7 @@ from scipy.signal import stft from scipy.signal.windows import hann import glob import multiprocessing # Added import for multiprocessing +from typing import Union, Tuple # Define the base directory where DAMAGE_X folders are located damage_base_path = 'D:/thesis/data/converted/raw' @@ -22,10 +23,31 @@ for dir_path in output_dirs.values(): # Define STFT parameters # Number of damage cases (adjust as needed) -num_damage_cases = 0 # Change to 30 if you have 30 damage cases +num_damage_cases = 6 # Change to 30 if you have 30 damage cases # Function to perform STFT and return magnitude -def compute_stft(vibration_data, return_param=False): +def compute_stft(vibration_data: np.ndarray, return_param: bool = False) -> Union[pd.DataFrame, Tuple[pd.DataFrame, list[int, int, int]]]: + """ + Computes the Short-Time Fourier Transform (STFT) magnitude of the input vibration data. + + Parameters + ---------- + vibration_data : numpy.ndarray + The input vibration data as a 1D NumPy array. + return_param : bool, optional + If True, the function returns additional STFT parameters (window size, hop size, and sampling frequency). + Defaults to False. + + Returns + ------- + pd.DataFrame + The transposed STFT magnitude, with frequencies as columns, if `return_param` is False. 
+ tuple + If `return_param` is True, returns a tuple containing: + - pd.DataFrame: The transposed STFT magnitude, with frequencies as columns. + - list[int, int, int]: A list of STFT parameters [window_size, hop_size, Fs]. + """ + window_size = 1024 hop_size = 512 window = hann(window_size) @@ -40,12 +62,18 @@ def compute_stft(vibration_data, return_param=False): ) stft_magnitude = np.abs(Zxx) + # Convert STFT result to DataFrame + df_stft = pd.DataFrame( + stft_magnitude.T, + columns=[f"Freq_{freq:.2f}" for freq in np.linspace(0, Fs/2, stft_magnitude.shape[1])] + ) + # breakpoint() if return_param: - return stft_magnitude.T, [window_size, hop_size, Fs] # Transpose to have frequencies as columns + return df_stft, [window_size, hop_size, Fs] else: - return stft_magnitude.T + return df_stft -def process_damage_case(damage_num, Fs=Fs,): +def process_damage_case(damage_num): damage_folder = os.path.join(damage_base_path, f'DAMAGE_{damage_num}') if damage_num == 0: # Number of test runs per damage case @@ -89,13 +117,8 @@ def process_damage_case(damage_num, Fs=Fs,): vibration_data = df.iloc[:, 1].values # Perform STFT - stft_magnitude, (window_size, hop_size, Fs) = compute_stft(vibration_data, return_param=True) + df_stft = compute_stft(vibration_data) - # Convert STFT result to DataFrame - df_stft = pd.DataFrame( - stft_magnitude, - columns=[f"Freq_{freq:.2f}" for freq in np.linspace(0, Fs/2, stft_magnitude.shape[1])] - ) # only inlcude 21 samples vector features for first 45 num_test_runs else include 22 samples vector features if damage_num == 0: print(f"Processing damage_num = 0, test_num = {test_num}") @@ -130,4 +153,4 @@ def process_damage_case(damage_num, Fs=Fs,): if __name__ == "__main__": # Added main guard for multiprocessing with multiprocessing.Pool() as pool: - pool.map(process_damage_case, range(0, num_damage_cases + 1)) + pool.map(process_damage_case, range(num_damage_cases + 1))