Skip to content

Commit

Permalink
Dataset: Make PDMDataset to load multiple types of files
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianolorenti committed Feb 25, 2024
1 parent 2212495 commit a3cdc67
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
9 changes: 0 additions & 9 deletions ceruleo/dataset/catalog/PHMDataset2018.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,6 @@ def get_key_from_filename(filename: str) -> str:
)
)


def _load_life(self, filename: str) -> pd.DataFrame:
return pd.read_parquet(filename)

def get_time_series(self, i: int) -> pd.DataFrame:
df = self._load_life(self.cycles_metadata.iloc[i]["Filename"])
return df


def prepare_raw_dataset(self):
"""Download and unzip the raw files
Expand Down
25 changes: 22 additions & 3 deletions ceruleo/dataset/ts_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from collections.abc import Iterable
from pathlib import Path



try:
from types import EllipsisType
except:
Expand All @@ -16,9 +18,10 @@
TENSORFLOW_ENABLED = True
except:
TENSORFLOW_ENABLED = False
from tqdm.auto import tqdm
from abc import ABC, abstractmethod, abstractproperty

from tqdm.auto import tqdm


class DatasetIterator:
def __init__(self, dataset):
Expand Down Expand Up @@ -393,8 +396,24 @@ def _prepare_dataset(self):
pass

def get_time_series(self, i: int) -> pd.DataFrame:
df = pd.read_csv(self.cycles_metadata.iloc[i]["Filename"])
return df
file_path = Path(self.cycles_metadata.iloc[i]["Filename"]).resolve()
file_extension = file_path.suffix.lower()
if file_extension == '.csv':
return pd.read_csv(file_path)
elif file_extension in ('.xls', '.xlsx'):
return pd.read_excel(file_path)
elif file_extension == '.parquet':
return pd.read_parquet(file_path)
elif file_extension == '.feather':
return pd.read_feather(file_path)
elif file_extension in ('.h5', '.hdf5'):
return pd.read_hdf(file_path)
elif file_extension in ('.pkl', '.pickle'):
return pd.read_pickle(file_path)
else:
raise ValueError(f"Unsupported file extension: {file_extension}")



@property
def n_time_series(self) -> int:
Expand Down

0 comments on commit a3cdc67

Please sign in to comment.