refactor(src): enhance compute_stft function with type hints, improved documentation by moving column renaming process from process_damage_case to compute_stft
This commit is contained in:
@@ -5,6 +5,7 @@ from scipy.signal import stft
|
|||||||
from scipy.signal.windows import hann
|
from scipy.signal.windows import hann
|
||||||
import glob
|
import glob
|
||||||
import multiprocessing # Added import for multiprocessing
|
import multiprocessing # Added import for multiprocessing
|
||||||
|
from typing import Union, Tuple
|
||||||
|
|
||||||
# Define the base directory where DAMAGE_X folders are located
|
# Define the base directory where DAMAGE_X folders are located
|
||||||
damage_base_path = 'D:/thesis/data/converted/raw'
|
damage_base_path = 'D:/thesis/data/converted/raw'
|
||||||
@@ -22,10 +23,31 @@ for dir_path in output_dirs.values():
|
|||||||
# Define STFT parameters
|
# Define STFT parameters
|
||||||
|
|
||||||
# Number of damage cases (adjust as needed)
|
# Number of damage cases (adjust as needed)
|
||||||
num_damage_cases = 0 # Change to 30 if you have 30 damage cases
|
num_damage_cases = 6 # Change to 30 if you have 30 damage cases
|
||||||
|
|
||||||
# Function to perform STFT and return magnitude
|
# Function to perform STFT and return magnitude
|
||||||
def compute_stft(vibration_data, return_param=False):
|
def compute_stft(vibration_data: np.ndarray, return_param: bool = False) -> Union[pd.DataFrame, Tuple[pd.DataFrame, list[int, int, int]]]:
|
||||||
|
"""
|
||||||
|
Computes the Short-Time Fourier Transform (STFT) magnitude of the input vibration data.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
vibration_data : numpy.ndarray
|
||||||
|
The input vibration data as a 1D NumPy array.
|
||||||
|
return_param : bool, optional
|
||||||
|
If True, the function returns additional STFT parameters (window size, hop size, and sampling frequency).
|
||||||
|
Defaults to False.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
pd.DataFrame
|
||||||
|
The transposed STFT magnitude, with frequencies as columns, if `return_param` is False.
|
||||||
|
tuple
|
||||||
|
If `return_param` is True, returns a tuple containing:
|
||||||
|
- pd.DataFrame: The transposed STFT magnitude, with frequencies as columns.
|
||||||
|
- list[int, int, int]: A list of STFT parameters [window_size, hop_size, Fs].
|
||||||
|
"""
|
||||||
|
|
||||||
window_size = 1024
|
window_size = 1024
|
||||||
hop_size = 512
|
hop_size = 512
|
||||||
window = hann(window_size)
|
window = hann(window_size)
|
||||||
@@ -40,12 +62,18 @@ def compute_stft(vibration_data, return_param=False):
|
|||||||
)
|
)
|
||||||
stft_magnitude = np.abs(Zxx)
|
stft_magnitude = np.abs(Zxx)
|
||||||
|
|
||||||
|
# Convert STFT result to DataFrame
|
||||||
|
df_stft = pd.DataFrame(
|
||||||
|
stft_magnitude.T,
|
||||||
|
columns=[f"Freq_{freq:.2f}" for freq in np.linspace(0, Fs/2, stft_magnitude.shape[1])]
|
||||||
|
)
|
||||||
|
# breakpoint()
|
||||||
if return_param:
|
if return_param:
|
||||||
return stft_magnitude.T, [window_size, hop_size, Fs] # Transpose to have frequencies as columns
|
return df_stft, [window_size, hop_size, Fs]
|
||||||
else:
|
else:
|
||||||
return stft_magnitude.T
|
return df_stft
|
||||||
|
|
||||||
def process_damage_case(damage_num, Fs=Fs,):
|
def process_damage_case(damage_num):
|
||||||
damage_folder = os.path.join(damage_base_path, f'DAMAGE_{damage_num}')
|
damage_folder = os.path.join(damage_base_path, f'DAMAGE_{damage_num}')
|
||||||
if damage_num == 0:
|
if damage_num == 0:
|
||||||
# Number of test runs per damage case
|
# Number of test runs per damage case
|
||||||
@@ -89,13 +117,8 @@ def process_damage_case(damage_num, Fs=Fs,):
|
|||||||
vibration_data = df.iloc[:, 1].values
|
vibration_data = df.iloc[:, 1].values
|
||||||
|
|
||||||
# Perform STFT
|
# Perform STFT
|
||||||
stft_magnitude, (window_size, hop_size, Fs) = compute_stft(vibration_data, return_param=True)
|
df_stft = compute_stft(vibration_data)
|
||||||
|
|
||||||
# Convert STFT result to DataFrame
|
|
||||||
df_stft = pd.DataFrame(
|
|
||||||
stft_magnitude,
|
|
||||||
columns=[f"Freq_{freq:.2f}" for freq in np.linspace(0, Fs/2, stft_magnitude.shape[1])]
|
|
||||||
)
|
|
||||||
# only inlcude 21 samples vector features for first 45 num_test_runs else include 22 samples vector features
|
# only inlcude 21 samples vector features for first 45 num_test_runs else include 22 samples vector features
|
||||||
if damage_num == 0:
|
if damage_num == 0:
|
||||||
print(f"Processing damage_num = 0, test_num = {test_num}")
|
print(f"Processing damage_num = 0, test_num = {test_num}")
|
||||||
@@ -130,4 +153,4 @@ def process_damage_case(damage_num, Fs=Fs,):
|
|||||||
|
|
||||||
if __name__ == "__main__": # Added main guard for multiprocessing
|
if __name__ == "__main__": # Added main guard for multiprocessing
|
||||||
with multiprocessing.Pool() as pool:
|
with multiprocessing.Pool() as pool:
|
||||||
pool.map(process_damage_case, range(0, num_damage_cases + 1))
|
pool.map(process_damage_case, range(num_damage_cases + 1))
|
||||||
|
|||||||
Reference in New Issue
Block a user