wip: add function to create stratified train-test split from STFT data

This commit is contained in:
nuluh
2025-04-23 12:48:15 +07:00
parent c8509aa728
commit 90a5a76609

View File

@@ -0,0 +1,50 @@
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split as sklearn_split
def create_train_test_split(
ready_data: pd.DataFrame,
test_size: float = 0.2,
random_state: int = 42,
stratify: np.ndarray = None,
) -> tuple:
"""
Create a stratified train-test split from STFT data.
Parameters:
-----------
data : pd.DataFrame
The input DataFrame containing STFT data
test_size : float
Proportion of data to use for testing (default: 0.2)
random_state : int
Random seed for reproducibility (default: 42)
stratify : np.ndarray, optional
Labels to use for stratified sampling
Returns:
--------
tuple
(X_train, X_test, y_train, y_test) - Split datasets
"""
y_data = [i for i in range(len(ready_data))]
for i in range(len(y_data)):
y_data[i] = [y_data[i]] * ready_data[i].shape[0]
y_data[i] = np.array(y_data[i])
# Extract features and labels
X = (
ready_data.drop("label_column", axis=1)
if "label_column" in ready_data.columns
else ready_data
)
y = ready_data["label_column"] if "label_column" in ready_data.columns else stratify
# Create split
X_train, X_test, y_train, y_test = sklearn_split(
X, y, test_size=test_size, random_state=random_state, stratify=stratify
)
return X_train, X_test, y_train, y_test