From 90a5a76609b2686f7441d657b92d31d388b6b689 Mon Sep 17 00:00:00 2001 From: nuluh Date: Wed, 23 Apr 2025 12:48:15 +0700 Subject: [PATCH] wip: add function to create stratified train-test split from STFT data --- code/src/ml/model_selection.py | 50 ++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 code/src/ml/model_selection.py diff --git a/code/src/ml/model_selection.py b/code/src/ml/model_selection.py new file mode 100644 index 0000000..7c97fce --- /dev/null +++ b/code/src/ml/model_selection.py @@ -0,0 +1,50 @@ +import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split as sklearn_split + + +def create_train_test_split( + ready_data: pd.DataFrame, + test_size: float = 0.2, + random_state: int = 42, + stratify: np.ndarray = None, +) -> tuple: + """ + Create a stratified train-test split from STFT data. + + Parameters: + ----------- + data : pd.DataFrame + The input DataFrame containing STFT data + test_size : float + Proportion of data to use for testing (default: 0.2) + random_state : int + Random seed for reproducibility (default: 42) + stratify : np.ndarray, optional + Labels to use for stratified sampling + + Returns: + -------- + tuple + (X_train, X_test, y_train, y_test) - Split datasets + """ + y_data = [i for i in range(len(ready_data))] + + for i in range(len(y_data)): + y_data[i] = [y_data[i]] * ready_data[i].shape[0] + y_data[i] = np.array(y_data[i]) + + # Extract features and labels + X = ( + ready_data.drop("label_column", axis=1) + if "label_column" in ready_data.columns + else ready_data + ) + y = ready_data["label_column"] if "label_column" in ready_data.columns else stratify + + # Create split + X_train, X_test, y_train, y_test = sklearn_split( + X, y, test_size=test_size, random_state=random_state, stratify=stratify + ) + + return X_train, X_test, y_train, y_test