refactor(data): update type annotations for damage files index and related classes. Need better implementation
This commit is contained in:
@@ -7,19 +7,31 @@ from typing import TypedDict, Dict, List
|
||||
from joblib import load
|
||||
from pprint import pprint
|
||||
|
||||
# class DamageFilesIndex(TypedDict):
|
||||
# class DamageFilesIndices(TypedDict):
|
||||
# damage_index: int
|
||||
# files: list[int]
|
||||
DamageFilesIndex = Dict[int, List[str]]
|
||||
OriginalSingleDamageScenarioFilePath = str
|
||||
DamageScenarioGroupIndex = int
|
||||
OriginalSingleDamageScenario = pd.DataFrame
|
||||
SensorIndex = int
|
||||
VectorColumnIndex = List[SensorIndex]
|
||||
VectorColumnIndices = List[VectorColumnIndex]
|
||||
DamageScenarioGroup = List[OriginalSingleDamageScenario]
|
||||
GroupDataset = List[DamageScenarioGroup]
|
||||
|
||||
|
||||
def generate_damage_files_index(**kwargs) -> DamageFilesIndex:
|
||||
prefix = kwargs.get("prefix", "zzzAD")
|
||||
extension = kwargs.get("extension", ".TXT")
|
||||
num_damage = kwargs.get("num_damage")
|
||||
file_index_start = kwargs.get("file_index_start")
|
||||
col = kwargs.get("col")
|
||||
base_path = kwargs.get("base_path")
|
||||
class DamageFilesIndices(TypedDict):
|
||||
damage_index: int
|
||||
files: List[str]
|
||||
|
||||
|
||||
def generate_damage_files_index(**kwargs) -> DamageFilesIndices:
|
||||
prefix: str = kwargs.get("prefix", "zzzAD")
|
||||
extension: str = kwargs.get("extension", ".TXT")
|
||||
num_damage: int = kwargs.get("num_damage")
|
||||
file_index_start: int = kwargs.get("file_index_start")
|
||||
col: int = kwargs.get("col")
|
||||
base_path: str = kwargs.get("base_path")
|
||||
|
||||
damage_scenarios = {}
|
||||
a = file_index_start
|
||||
@@ -53,7 +65,7 @@ def generate_damage_files_index(**kwargs) -> DamageFilesIndex:
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
def __init__(self, file_index: Dict[int, List[str]], cache_path: str = None):
|
||||
def __init__(self, file_index: DamageFilesIndices, cache_path: str = None):
|
||||
self.file_index = file_index
|
||||
if cache_path:
|
||||
self.data = load(cache_path)
|
||||
@@ -80,7 +92,7 @@ class DataProcessor:
|
||||
|
||||
return tokens # Prepend 'Time' column if applicable
|
||||
|
||||
def _load_dataframe(self, file_path: str) -> pd.DataFrame:
|
||||
def _load_dataframe(self, file_path: str) -> OriginalSingleDamageScenario:
|
||||
"""
|
||||
Loads a single data file into a pandas DataFrame.
|
||||
|
||||
@@ -94,7 +106,7 @@ class DataProcessor:
|
||||
df.columns = col_names
|
||||
return df
|
||||
|
||||
def _load_all_data(self) -> List[List[pd.DataFrame]]:
|
||||
def _load_all_data(self) -> GroupDataset:
|
||||
"""
|
||||
Loads all data files based on the grouping dictionary and returns a nested list.
|
||||
|
||||
@@ -164,12 +176,12 @@ class DataProcessor:
|
||||
else type(self.data).__name__
|
||||
)
|
||||
|
||||
def _create_vector_column_index(self):
|
||||
vector_col_idx = []
|
||||
def _create_vector_column_index(self) -> VectorColumnIndices:
|
||||
vector_col_idx: VectorColumnIndices = []
|
||||
y = 0
|
||||
for data_group in self.data: # len(data_group[i]) = 5
|
||||
for j in data_group: # len(j[i]) =
|
||||
c = [] # column vector c_{j}
|
||||
c: VectorColumnIndex = [] # column vector c_{j}
|
||||
x = 0
|
||||
for _ in range(6): # TODO: range(6) should be dynamic and parameterized
|
||||
c.append(x + y)
|
||||
@@ -178,7 +190,7 @@ class DataProcessor:
|
||||
y += 1
|
||||
return vector_col_idx
|
||||
|
||||
def create_vector_column(self, overwrite=True):
|
||||
def create_vector_column(self, overwrite=True) -> List[List[List[pd.DataFrame]]]:
|
||||
"""
|
||||
Create a vector column from the loaded data.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user