refactor(data): update type annotations for damage files index and related classes. Need better implementation

This commit is contained in:
nuluh
2025-03-22 19:48:50 +07:00
parent 58a316d9c8
commit ff64f3a3ab
2 changed files with 34 additions and 21 deletions

View File

@@ -7,19 +7,31 @@ from typing import TypedDict, Dict, List
from joblib import load from joblib import load
from pprint import pprint from pprint import pprint
# class DamageFilesIndex(TypedDict): # class DamageFilesIndices(TypedDict):
# damage_index: int # damage_index: int
# files: list[int] # files: list[int]
DamageFilesIndex = Dict[int, List[str]] OriginalSingleDamageScenarioFilePath = str
DamageScenarioGroupIndex = int
OriginalSingleDamageScenario = pd.DataFrame
SensorIndex = int
VectorColumnIndex = List[SensorIndex]
VectorColumnIndices = List[VectorColumnIndex]
DamageScenarioGroup = List[OriginalSingleDamageScenario]
GroupDataset = List[DamageScenarioGroup]
def generate_damage_files_index(**kwargs) -> DamageFilesIndex: class DamageFilesIndices(TypedDict):
prefix = kwargs.get("prefix", "zzzAD") damage_index: int
extension = kwargs.get("extension", ".TXT") files: List[str]
num_damage = kwargs.get("num_damage")
file_index_start = kwargs.get("file_index_start")
col = kwargs.get("col") def generate_damage_files_index(**kwargs) -> DamageFilesIndices:
base_path = kwargs.get("base_path") prefix: str = kwargs.get("prefix", "zzzAD")
extension: str = kwargs.get("extension", ".TXT")
num_damage: int = kwargs.get("num_damage")
file_index_start: int = kwargs.get("file_index_start")
col: int = kwargs.get("col")
base_path: str = kwargs.get("base_path")
damage_scenarios = {} damage_scenarios = {}
a = file_index_start a = file_index_start
@@ -53,7 +65,7 @@ def generate_damage_files_index(**kwargs) -> DamageFilesIndex:
class DataProcessor: class DataProcessor:
def __init__(self, file_index: Dict[int, List[str]], cache_path: str = None): def __init__(self, file_index: DamageFilesIndices, cache_path: str = None):
self.file_index = file_index self.file_index = file_index
if cache_path: if cache_path:
self.data = load(cache_path) self.data = load(cache_path)
@@ -80,7 +92,7 @@ class DataProcessor:
return tokens # Prepend 'Time' column if applicable return tokens # Prepend 'Time' column if applicable
def _load_dataframe(self, file_path: str) -> pd.DataFrame: def _load_dataframe(self, file_path: str) -> OriginalSingleDamageScenario:
""" """
Loads a single data file into a pandas DataFrame. Loads a single data file into a pandas DataFrame.
@@ -94,7 +106,7 @@ class DataProcessor:
df.columns = col_names df.columns = col_names
return df return df
def _load_all_data(self) -> List[List[pd.DataFrame]]: def _load_all_data(self) -> GroupDataset:
""" """
Loads all data files based on the grouping dictionary and returns a nested list. Loads all data files based on the grouping dictionary and returns a nested list.
@@ -164,12 +176,12 @@ class DataProcessor:
else type(self.data).__name__ else type(self.data).__name__
) )
def _create_vector_column_index(self): def _create_vector_column_index(self) -> VectorColumnIndices:
vector_col_idx = [] vector_col_idx: VectorColumnIndices = []
y = 0 y = 0
for data_group in self.data: # len(data_group[i]) = 5 for data_group in self.data: # len(data_group[i]) = 5
for j in data_group: # len(j[i]) = for j in data_group: # len(j[i]) =
c = [] # column vector c_{j} c: VectorColumnIndex = [] # column vector c_{j}
x = 0 x = 0
for _ in range(6): # TODO: range(6) should be dynamic and parameterized for _ in range(6): # TODO: range(6) should be dynamic and parameterized
c.append(x + y) c.append(x + y)
@@ -178,7 +190,7 @@ class DataProcessor:
y += 1 y += 1
return vector_col_idx return vector_col_idx
def create_vector_column(self, overwrite=True): def create_vector_column(self, overwrite=True) -> List[List[List[pd.DataFrame]]]:
""" """
Create a vector column from the loaded data. Create a vector column from the loaded data.

View File

@@ -1,7 +1,8 @@
from convert import * from convert import *
from joblib import dump from joblib import dump, load
a = generate_damage_files_index( # a = generate_damage_files_index(
num_damage=6, file_index_start=1, col=5, base_path="D:/thesis/data/dataset_A" # num_damage=6, file_index_start=1, col=5, base_path="D:/thesis/data/dataset_A"
) # )
dump(DataProcessor(file_index=a), "D:/cache.joblib") # dump(DataProcessor(file_index=a), "D:/cache.joblib")
a = load("D:/cache.joblib")