wip: add function to create stratified train-test split from STFT data
This commit is contained in:
50
code/src/ml/model_selection.py
Normal file
50
code/src/ml/model_selection.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.model_selection import train_test_split as sklearn_split
|
||||
|
||||
|
||||
def create_train_test_split(
    ready_data: pd.DataFrame,
    test_size: float = 0.2,
    random_state: int = 42,
    stratify: np.ndarray = None,
) -> tuple:
    """
    Create a train-test split from STFT data, optionally stratified.

    Parameters:
    -----------
    ready_data : pd.DataFrame
        The input DataFrame containing STFT data. If it has a column named
        "label_column", that column is used as the labels and is removed
        from the features; otherwise the whole frame is the feature set.
    test_size : float
        Proportion of data to use for testing (default: 0.2)
    random_state : int
        Random seed for reproducibility (default: 42)
    stratify : np.ndarray, optional
        Labels to use for stratified sampling. It is forwarded unchanged to
        the underlying split, so with the default of None no stratification
        is requested. When ``ready_data`` has no "label_column", this array
        is also used as the labels ``y``.

    Returns:
    --------
    tuple
        (X_train, X_test, y_train, y_test) - Split datasets
    """
    label_column = "label_column"

    # Extract features and labels. The label column, when present, must not
    # leak into the feature matrix.
    if label_column in ready_data.columns:
        X = ready_data.drop(label_column, axis=1)
        y = ready_data[label_column]
    else:
        X = ready_data
        # No label column: fall back to the caller-supplied labels.
        # NOTE(review): if `stratify` is also None, `y` is None and handling
        # is left entirely to the underlying split function - confirm callers
        # always supply one source of labels.
        y = stratify

    # Create split
    X_train, X_test, y_train, y_test = sklearn_split(
        X, y, test_size=test_size, random_state=random_state, stratify=stratify
    )

    return X_train, X_test, y_train, y_test
|
||||
Reference in New Issue
Block a user