diff --git a/README.md b/README.md index 5c64225..1e530b9 100644 --- a/README.md +++ b/README.md @@ -137,24 +137,30 @@ augmented_record = Compose(transforms=[ ``` ### Predict -This module allows users to test trained model with the architecture from `ecglib`. You can get the prediction for the specific ECG record or the prediction for all the records in the directory. +This module allows users to test trained model with the architecture from `ecglib`. You can get the prediction for the specific ECG record or the prediction for all the records in the directory. For NPZ-typed records, the `ecg_frequency` parameter should be specified either as an integer for all records in the directory or as a dictionary with record names and their corresponding frequency. ```python # Predict example from ecglib.predict import Predict -ecg_signal = wfdb.rdsamp("wfdb_file")[0] # for example 00001_hr from PTB-XL dataset +ecg_signal, ann = wfdb.rdsamp("wfdb_file") # for example 00001_hr from PTB-XL dataset +ecg_frequency = ann["fs"] predict = Predict( weights_path="/path/to/model_weights", model_name="densenet1d121", pathologies=["AFIB"], - frequency=500, + model_frequency=500, device="cuda:0", threshold=0.5 ) -result_df = predict.predict_directory(directory="path/to/data_to_predict", - file_type="wfdb") -print(predict.predict(ecg_signal, channels_first=False)) +ecg_prediction = predict.predict(ecg_signal, ecg_frequency, channels_first=False) + +result_df_wfdb = predict.predict_directory(directory="path/to/data_to_predict", + file_type="wfdb") + +result_df_npz = predict.predict_directory(directory="path/to/data_to_predict", + file_type="npz", + ecg_frequency=1000) ``` diff --git a/src/ecglib/data/datasets.py b/src/ecglib/data/datasets.py index acb69f6..381d035 100644 --- a/src/ecglib/data/datasets.py +++ b/src/ecglib/data/datasets.py @@ -88,15 +88,17 @@ def read_ecg_record( ): if data_type == "npz": ecg_record = np.load(file_path)["arr_0"].astype("float64") + frequency = None elif data_type == "wfdb": - ecg_record, _ = wfdb.rdsamp(file_path, channels=leads) + ecg_record, ann = wfdb.rdsamp(file_path, channels=leads) ecg_record = ecg_record.T ecg_record = ecg_record.astype("float64") + frequency = ann['fs'] else: raise ValueError( 'data_type can have only values from the list ["npz", "wfdb"]' ) - return ecg_record + return ecg_record, frequency def take_metadata(self, index: int): """ @@ -140,7 +142,7 @@ def __getitem__(self, index): file_path = self.ecg_data.iloc[index]["fpath"] # data standartization (scaling, resampling, cuts off, normalization and padding/truncation) - ecg_record = self.read_ecg_record(file_path, self.data_type, self.leads) + ecg_record, _ = self.read_ecg_record(file_path, self.data_type, self.leads) full_ecg_record_info = EcgRecord( signal=ecg_record[self.leads, :], frequency=ecg_frequency, diff --git a/src/ecglib/predict/predict.py b/src/ecglib/predict/predict.py index 478db1c..4170bab 100644 --- a/src/ecglib/predict/predict.py +++ b/src/ecglib/predict/predict.py @@ -32,8 +32,9 @@ def tabular_metadata_handler( def get_full_record( - frequency: int, - record: str, + ecg_frequency: int, + model_frequency: int, + record: Union[np.ndarray, torch.Tensor], patient_meta: dict, ecg_meta: dict, normalization: str = "z_norm", @@ -43,8 +44,9 @@ def get_full_record( """ Returns a full record from raw record, patient metadata, ECG metadata, and configuration. - :param frequency: int, frequency of the ECG record - :param record: str, path to ECG record + :param ecg_frequency: int, frequency of the ECG record + :param model_frequency: int, frequency of the trained model + :param record: Union[np.ndarray, torch.Tensor], ECG record :param patient_meta: dict, patient metadata :param ecg_meta: dict, ECG metadata :param normalization: str, normalization type @@ -55,16 +57,15 @@ def get_full_record( """ record = record[:,] - frequency = int(frequency) patient_meta = patient_meta - trained_freq = frequency if preprocess: record_processed = P.Compose(transforms=preprocess, p=1.0)(record) else: record_processed = P.Compose( transforms=[ P.FrequencyResample( - ecg_frequency=frequency, requested_frequency=trained_freq + ecg_frequency=int(ecg_frequency), + requested_frequency=int(model_frequency), ), P.Normalization(norm_type=normalization), ], @@ -89,7 +90,7 @@ def __init__( weights_path: str, model_name: str, pathologies: list, - frequency: int, + model_frequency: int, device: str, threshold: float, leads: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], @@ -106,7 +107,7 @@ def __init__( :param weights_path: str, path to the model weights :param model_name: str, name of the model :param pathologies: list, list of pathologies - :param frequency: int, frequency of the ECG record + :param model_frequency: int, frequency of the trained model :param device: str, device to be used for computations :param threshold: float, threshold for the model :param leads: list, list of leads @@ -124,7 +125,7 @@ def __init__( self.leads_num = len(leads) self.leads = leads self.device = device - self.frequency = frequency + self.model_frequency = model_frequency self.use_sigmoid = use_sigmoid if model is None: @@ -154,15 +155,17 @@ def __init__( def predict( self, - record, - ecg_meta=None, - patient_meta=None, - channels_first=True, + record: Union[np.ndarray, torch.Tensor], + ecg_frequency: int, + ecg_meta: dict = None, + patient_meta: dict = None, + channels_first: bool = True, ): """ Function that evaluates the model on a single ECG record. :param record: np.array or torch.tensor, ECG record + :param ecg_frequency: int, frequency of the ECG record :param ecg_meta: dict, ECG metadata (default is None) :param patient_meta: dict, patient metadata (default is None) :param channels_first: bool, whether the channels are the first dimension in the input data (default is False) @@ -190,7 +193,8 @@ def predict( ) input_ = get_full_record( - self.frequency, + ecg_frequency, + self.model_frequency, self.record, self.patient_meta, self.ecg_meta, @@ -225,11 +229,12 @@ def predict( def predict_directory( self, - directory, - file_type, - write_to_file=None, - ecg_meta=None, - patient_meta=None, + directory:str, + file_type:str, + ecg_frequency:Union[dict, int, None]=None, + write_to_file:str=None, + ecg_meta:List[dict]=None, + patient_meta:List[dict]=None, ): """ Evaluates the model on all ECG records in a directory. @@ -239,9 +244,11 @@ def predict_directory( :param write_to_file: str, path to the file where the predictions will be written (default is None), or None if the predictions should not be written to a file :param ecg_meta: list of dicts, each dict contains "filename" and "data" keys. ECG metadata (default is None) :param patient_meta: list of dicts, each dict contains "filename" and "data" keys. Patient metadata (default is None) + :param ecg_frequency: the frequency of the ECG records :return: pd.DataFrame, dataframe with the predictions """ + if ecg_meta: ecg_meta = sorted(ecg_meta, key=lambda k: k["filename"]) if patient_meta: @@ -256,6 +263,14 @@ def predict_directory( record_files = [file for file in all_files if file.endswith(file_type)] record_files = sorted(record_files) + if ecg_frequency is None: + ecg_frequencies = {} + elif isinstance(ecg_frequency, int): + ecg_frequencies = {record_file: ecg_frequency for record_file in record_files} + elif isinstance(ecg_frequency, dict): + assert all(isinstance(value, int) for value in ecg_frequencies.values()), "All values in ecg_frequency should be integers." + ecg_frequencies = ecg_frequency + answer_df = pd.DataFrame( columns=["filename", "raw_out", "prob_out", "label_out"] ) @@ -276,10 +291,14 @@ def predict_directory( patient_meta_counter += 1 patient_meta_ = patient_meta[patient_meta_counter]["data"] - record_ = EcgDataset.read_ecg_record( + record_, record_frequency = EcgDataset.read_ecg_record( None, os.path.join(directory, record), file_type ) - record_answer = self.predict(record_, ecg_meta_, patient_meta_) + + ecg_frequency = record_frequency if record_frequency is not None else ecg_frequencies.get(record) + assert ecg_frequency is not None, "The file should contain the record frequency or the ecg_frequency should be defined." + + record_answer = self.predict(record_, ecg_frequency, ecg_meta_, patient_meta_) answer_df_current = pd.DataFrame( {