Skip to content

Commit

Permalink
Merge pull request #12 from ispras/issue11
Browse files Browse the repository at this point in the history
#11: add predict frequency resampling
  • Loading branch information
AvetisyanAram authored Jul 24, 2024
2 parents b7b9a5c + 8fdbe97 commit 21657c5
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 31 deletions.
18 changes: 12 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,24 +137,30 @@ augmented_record = Compose(transforms=[
```

### Predict
This module allows users to test trained model with the architecture from `ecglib`. You can get the prediction for the specific ECG record or the prediction for all the records in the directory.
This module allows users to test trained model with the architecture from `ecglib`. You can get the prediction for the specific ECG record or the prediction for all the records in the directory. For NPZ-typed records, the `ecg_frequency` parameter should be specified either as an integer for all records in the directory or as a dictionary with record names and their corresponding frequency.

```python
# Predict example
from ecglib.predict import Predict

ecg_signal = wfdb.rdsamp("wfdb_file")[0] # for example 00001_hr from PTB-XL dataset
ecg_signal, ann = wfdb.rdsamp("wfdb_file") # for example 00001_hr from PTB-XL dataset
ecg_frequency = ann["fs"]

predict = Predict(
weights_path="/path/to/model_weights",
model_name="densenet1d121",
pathologies=["AFIB"],
frequency=500,
model_frequency=500,
device="cuda:0",
threshold=0.5
)

result_df = predict.predict_directory(directory="path/to/data_to_predict",
file_type="wfdb")
print(predict.predict(ecg_signal, channels_first=False))
ecg_prediction = predict.predict(ecg_signal, ecg_frequency, channels_first=False)

result_df_wfdb = predict.predict_directory(directory="path/to/data_to_predict",
file_type="wfdb")

result_df_npz = predict.predict_directory(directory="path/to/data_to_predict",
file_type="npz",
ecg_frequency=1000)
```
8 changes: 5 additions & 3 deletions src/ecglib/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,17 @@ def read_ecg_record(
):
if data_type == "npz":
ecg_record = np.load(file_path)["arr_0"].astype("float64")
frequency = None
elif data_type == "wfdb":
ecg_record, _ = wfdb.rdsamp(file_path, channels=leads)
ecg_record, ann = wfdb.rdsamp(file_path, channels=leads)
ecg_record = ecg_record.T
ecg_record = ecg_record.astype("float64")
frequency = ann['fs']
else:
raise ValueError(
'data_type can have only values from the list ["npz", "wfdb"]'
)
return ecg_record
return ecg_record, frequency

def take_metadata(self, index: int):
"""
Expand Down Expand Up @@ -140,7 +142,7 @@ def __getitem__(self, index):
file_path = self.ecg_data.iloc[index]["fpath"]

# data standartization (scaling, resampling, cuts off, normalization and padding/truncation)
ecg_record = self.read_ecg_record(file_path, self.data_type, self.leads)
ecg_record, _ = self.read_ecg_record(file_path, self.data_type, self.leads)
full_ecg_record_info = EcgRecord(
signal=ecg_record[self.leads, :],
frequency=ecg_frequency,
Expand Down
63 changes: 41 additions & 22 deletions src/ecglib/predict/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ def tabular_metadata_handler(


def get_full_record(
frequency: int,
record: str,
ecg_frequency: int,
model_frequency: int,
record: Union[np.ndarray, torch.Tensor],
patient_meta: dict,
ecg_meta: dict,
normalization: str = "z_norm",
Expand All @@ -43,8 +44,9 @@ def get_full_record(
"""
Returns a full record from raw record, patient metadata, ECG metadata, and configuration.
:param frequency: int, frequency of the ECG record
:param record: str, path to ECG record
:param ecg_frequency: int, frequency of the ECG record
:param model_frequency: int, frequency of the trained model
:param record: Union[np.ndarray, torch.Tensor], ECG record
:param patient_meta: dict, patient metadata
:param ecg_meta: dict, ECG metadata
:param normalization: str, normalization type
Expand All @@ -55,16 +57,15 @@ def get_full_record(
"""

record = record[:,]
frequency = int(frequency)
patient_meta = patient_meta
trained_freq = frequency
if preprocess:
record_processed = P.Compose(transforms=preprocess, p=1.0)(record)
else:
record_processed = P.Compose(
transforms=[
P.FrequencyResample(
ecg_frequency=frequency, requested_frequency=trained_freq
ecg_frequency=int(ecg_frequency),
requested_frequency=int(model_frequency),
),
P.Normalization(norm_type=normalization),
],
Expand All @@ -89,7 +90,7 @@ def __init__(
weights_path: str,
model_name: str,
pathologies: list,
frequency: int,
model_frequency: int,
device: str,
threshold: float,
leads: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
Expand All @@ -106,7 +107,7 @@ def __init__(
:param weights_path: str, path to the model weights
:param model_name: str, name of the model
:param pathologies: list, list of pathologies
:param frequency: int, frequency of the ECG record
:param model_frequency: int, frequency of the trained model
:param device: str, device to be used for computations
:param threshold: float, threshold for the model
:param leads: list, list of leads
Expand All @@ -124,7 +125,7 @@ def __init__(
self.leads_num = len(leads)
self.leads = leads
self.device = device
self.frequency = frequency
self.model_frequency = model_frequency
self.use_sigmoid = use_sigmoid

if model is None:
Expand Down Expand Up @@ -154,15 +155,17 @@ def __init__(

def predict(
self,
record,
ecg_meta=None,
patient_meta=None,
channels_first=True,
record: Union[np.ndarray, torch.Tensor],
ecg_frequency: int,
ecg_meta: dict = None,
patient_meta: dict = None,
channels_first: bool = True,
):
"""
Function that evaluates the model on a single ECG record.
:param record: np.array or torch.tensor, ECG record
:param ecg_frequency: int, frequency of the ECG record
:param ecg_meta: dict, ECG metadata (default is None)
:param patient_meta: dict, patient metadata (default is None)
:param channels_first: bool, whether the channels are the first dimension in the input data (default is False)
Expand Down Expand Up @@ -190,7 +193,8 @@ def predict(
)

input_ = get_full_record(
self.frequency,
ecg_frequency,
self.model_frequency,
self.record,
self.patient_meta,
self.ecg_meta,
Expand Down Expand Up @@ -225,11 +229,12 @@ def predict(

def predict_directory(
self,
directory,
file_type,
write_to_file=None,
ecg_meta=None,
patient_meta=None,
directory:str,
file_type:str,
ecg_frequency:Union[dict, int, None]=None,
write_to_file:str=None,
ecg_meta:List[dict]=None,
patient_meta:List[dict]=None,
):
"""
Evaluates the model on all ECG records in a directory.
Expand All @@ -239,9 +244,11 @@ def predict_directory(
:param write_to_file: str, path to the file where the predictions will be written (default is None), or None if the predictions should not be written to a file
:param ecg_meta: list of dicts, each dict contains "filename" and "data" keys. ECG metadata (default is None)
:param patient_meta: list of dicts, each dict contains "filename" and "data" keys. Patient metadata (default is None)
:param ecg_frequency: the frequency of the ECG records
:return: pd.DataFrame, dataframe with the predictions
"""

if ecg_meta:
ecg_meta = sorted(ecg_meta, key=lambda k: k["filename"])
if patient_meta:
Expand All @@ -256,6 +263,14 @@ def predict_directory(
record_files = [file for file in all_files if file.endswith(file_type)]
record_files = sorted(record_files)

if ecg_frequency is None:
ecg_frequencies = {}
elif isinstance(ecg_frequency, int):
ecg_frequencies = {record_file: ecg_frequency for record_file in record_files}
elif isinstance(ecg_frequency, dict):
assert all(isinstance(value, int) for value in ecg_frequencies.values()), "All values in ecg_frequency should be integers."
ecg_frequencies = ecg_frequency

answer_df = pd.DataFrame(
columns=["filename", "raw_out", "prob_out", "label_out"]
)
Expand All @@ -276,10 +291,14 @@ def predict_directory(
patient_meta_counter += 1
patient_meta_ = patient_meta[patient_meta_counter]["data"]

record_ = EcgDataset.read_ecg_record(
record_, record_frequency = EcgDataset.read_ecg_record(
None, os.path.join(directory, record), file_type
)
record_answer = self.predict(record_, ecg_meta_, patient_meta_)

ecg_frequency = record_frequency if record_frequency is not None else ecg_frequencies.get(record)
assert ecg_frequency is not None, "The file should contain the record frequency or the ecg_frequency should be defined."

record_answer = self.predict(record_, ecg_frequency, ecg_meta_, patient_meta_)

answer_df_current = pd.DataFrame(
{
Expand Down

0 comments on commit 21657c5

Please sign in to comment.