diff --git a/.gitignore b/.gitignore
index 6b727a1..2bb5cad 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,22 @@
 __pycache__/
 .pybuilder/
 .ipynb_checkpoints
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
diff --git a/MANIFEST.in b/MANIFEST.in
index 61d0546..4b4fd6f 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1 +1 @@
-include ecglib/models/weights/*.yaml
\ No newline at end of file
+include src/ecglib/models/weights/*.yaml
\ No newline at end of file
diff --git a/README.md b/README.md
index de7da8a..5c64225 100644
--- a/README.md
+++ b/README.md
@@ -2,56 +2,102 @@
 ## Table of contents
 
-- [Introduction](#datasets)
-- [Datasets](#datasets)
+- [Introduction](#introduction)
+- [Credits](#credits)
+- [Installation](#installation)
+- [Data](#data)
 - [Models](#models)
 - [Preprocessing](#preprocessing)
-- [ToDo](#todo)
-- [Credits](#credits)
+- [Predict](#predict)
 
 ### Introduction
 
-**Ecg** **lib**rary (`ecglib`) is a tool for ECG signal analysis. The library helps with preprocessing ECG signals, downloading the datasets, creating Dataset classes to train models to detect ECG pathologies. The library allows researchers to use model architectures pretrained on more than 500,000 ECG records to fine-tune them on their own datasets.
+**Ecg** **lib**rary (`ecglib`) is a tool for ECG signal analysis. The library helps with preprocessing ECG signals, downloading datasets, creating Dataset classes to train models to detect ECG pathologies, and creating EcgRecord classes to store single ECG records. The library allows researchers to use model architectures pretrained on more than 500,000 ECG records and fine-tune them on their own datasets.
+
+### Credits
+
+If you find this tool useful in your research, please consider citing the paper:
+
+1) **Deep Neural Networks Generalization and Fine-Tuning for 12-lead ECG Classification** - We demonstrate that training deep neural networks on a large dataset and fine-tuning them on a small dataset from another domain outperforms the networks trained only on one of the datasets.
+
+       @article{avetisyan2023deep,
+         title={Deep Neural Networks Generalization and Fine-Tuning for 12-lead ECG Classification},
+         author={Avetisyan, Aram and Tigranyan, Shahane and Asatryan, Ariana and Mashkova, Olga and Skorik, Sergey and Ananev, Vladislav and Markin, Yury},
+         journal={arXiv preprint arXiv:2305.18592},
+         year={2023}
+       }
 
-### Datasets
-This module allows user to load and store ECG datasets in different formats and to extract meta-information about each single ECG signal (i.e. frequency, full path to file, scp-codes, patient age etc.).
+### Installation
+
+To install the latest version from PyPI:
+
+```
+pip install ecglib
+```
+
+### Data
+This module allows the user to load and store ECG datasets and records in different formats and to extract meta-information about each ECG signal (e.g. frequency, full path to file, scp-codes, patient age).
 
 Via `load_datasets.py` one can download [PTB-XL ECG database](https://physionet.org/content/ptb-xl/1.0.2/) in its original *wfdb* format and to store information concerning each record in a *csv* file.
 
 ```python
 # download PTB-XL ECG dataset
-from ecglib.datasets import load_ptb_xl
+from ecglib.data import load_ptb_xl
 
 ptb_xl_info = load_ptb_xl(download=True)
 ```
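+
+`load_ptb_xl` returns the PTB-XL `ptbxl_database.csv` metadata as a dataframe with an added `fpath` column pointing to each record, so it can be filtered with ordinary pandas operations before building a dataset (a small sketch; `validated_by_human` is a standard PTB-XL column):
+
+```python
+# keep only human-validated records
+validated_info = ptb_xl_info[ptb_xl_info["validated_by_human"] == True]
+```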
-`datasets.py` script contains classes for storing ECG datasets. 
-- *ECGDataset* is a general class for storing main features of your ECG dataset such as number of leads, number of classes to predict, augmentation etc..
-- *PTBXLDataset* is a child class with respect to *ECGDataset*; one can load each record from *wfdb* or *npz* format and preprocess it before further utilization. It is also possible to create a *png* picture of each record using [ecg-plot](https://pypi.org/project/ecg-plot/).
+Via `ecg_record.py` one can create an *EcgRecord* class to store the important information of an ECG record.
 
 ```python
-# create PTBXLDataset class from PTB-XL map file
+# creating EcgRecord class for an example file
+
+from ecglib.data import EcgRecord
+import wfdb
+
+ecg_signal = wfdb.rdsamp("wfdb_file")[0]  # for example 00001_hr from PTB-XL dataset
+ecg_record = EcgRecord(signal=ecg_signal.T, frequency=500, patient_id=1)
+
+```
+
+Via `datasets.py` one can create an *EcgDataset* class to store ECG datasets. It stores the main features of your ECG dataset such as the number of leads, the number of classes to predict, augmentation etc. It is also possible to plot each record using [ecg-plot](https://pypi.org/project/ecg-plot/).
+
+```python
+# create EcgDataset class from PTB-XL map file
 # fit targets for 'AFIB' binary classification
 
-from ecglib.datasets import PTBXLDataset
+from ecglib.data import EcgDataset
 
 targets = [[0.0] if 'AFIB' in eval(ptb_xl_info.iloc[i]['scp_codes']).keys() else [1.0] for i in range(ptb_xl_info.shape[0])]
 
-ecg_data = PTBXLDataset(ecg_data=ptb_xl_info, target=targets)
+ecg_data = EcgDataset(ecg_data=ptb_xl_info, target=targets)
 ```
 
 ### Models
 This module comprises components of model architectures and open weights for models derived from binary classification experiments in several pathologies.
 
-`create_model` function allows user to create a model from scratch (currently supported architectures include *densenet1d121*, *densenet1d201*) as well as load a pretrained model checkpoint from `weights` folder (currently supported architectures include *densenet1d121*).
+The `create_model` function allows the user to create a model from scratch (supported architectures include *resnet1d18*, *resnet1d50*, *resnet1d101*, *densenet1d121*, *densenet1d201*) as well as load a pretrained model checkpoint from the `weights` folder (supported architectures include *resnet1d18*, *resnet1d50*, *resnet1d101*, *densenet1d121*). `create_model` also allows using both the ECG record and metadata during training by concatenating an FCN to the network that takes the ECG record as an input.
 
 ```python
 # create 'densenet1d121' model from scratch for binary classification 12-lead experiment
-from ecglib.models import create_model
+from ecglib.models.model_builder import create_model
+
+model = create_model(model_name='densenet1d121', pathology='1AVB', pretrained=False)
 
-model = create_model(model_name='densenet1d121', pathology='1AVB', pretrained=False, leads_num=12)
+# create 'cnntabular' model with 'densenet1d121' architecture for the ECG record and an FCN for metadata.
+# The number of input features is set to 5 by default and can be changed by passing a config.
+
+from ecglib.models.model_builder import Combination
+from ecglib.models.config.model_configs import DenseNetConfig, TabularNetConfig
+
+densenet_config = DenseNetConfig()
+tabular_config = TabularNetConfig(inp_features=50)
+model = create_model(model_name=['densenet1d121', 'tabular'],
+                     config=[densenet_config, tabular_config],
+                     combine=Combination.CNNTAB,
+                     pathology='1AVB',
+                     pretrained=False)
 ```
 
 ```python
@@ -59,21 +105,19 @@ model = create_model(model_name='densenet1d121', pathology='1AVB', pretrained=Fa
 from ecglib.models import create_model
 
-model = create_model(model_name='densenet1d121', pathology='AFIB', pretrained=True, leads_num=12)
+model = create_model(model_name='densenet1d121', pathology='AFIB', pretrained=True)
 ```
 
 `architectures` folder includes model architectures.
 
-`config` folder contains default parameter dataclasses for building a model.
-
-In `weights` folder one can find file with paths to the models derived from the following binary classification 12-lead experiments. Currently avaliable pathologies (scp-codes): *AFIB*, *1AVB*, *STACH*, *SBRAD*, *RBBB*, *LBBB*, *PVC*, *LVH*.
+In the `weights` folder one can find a file with paths to the models derived from the following binary classification 12-lead experiments. Available pathologies (scp-codes): *AFIB*, *1AVB*, *STACH*, *SBRAD*, *IRBBB*, *CRBBB*, *PVC*.
 
 ### Preprocessing
 This module includes framework inspired by [Albumentations](https://albumentations.ai/) Python library for preprocessing and augmenting ECG data.
 
 `composition.py` script contains *SomeOf*, *OneOf* and *Compose* structures for building your own preprocessing and augmentation pipeline.
 
-`preprocess.py` and `functional.py` both comprise classes and functions respectively describing different preprocessing and augmentation techniques. For more information see code commentary.
+`preprocess.py` and `functional.py` comprise classes and functions respectively describing different preprocessing and augmentation techniques. You can preprocess either numpy data or EcgRecord data. For more information see the code commentary.
 
 ```python
 # augmentation example
@@ -81,31 +125,36 @@ import torch
 from ecglib.preprocessing.preprocess import *
 from ecglib.preprocessing.composition import *
 
-# provide an ecg-record in `numpy.ndarray` form
-ecg_record = read_any_ECG_ndarray_type
+# provide an ecg record in `numpy.ndarray` form
+import wfdb
+from ecglib.data import EcgRecord
+
+ecg_signal = wfdb.rdsamp("wfdb_file")[0]  # for example 00001_hr from PTB-XL dataset
+ecg_record = EcgRecord(signal=ecg_signal.T, frequency=500, patient_id=1)
 
 augmented_record = Compose(transforms=[
     SumAug(leads=[0, 6, 11]), 
     RandomConvexAug(n=4), 
     OneOf(transforms=[ButterworthFilter(), IIRNotchFilter()], transform_prob=[0.8, 0.2])
-], p=0.5)(ecg_record)
+], p=0.5)(ecg_record)  # ecg_signal can be used instead of ecg_record
 ```
 
-### ToDo
-**Next release in December 2022**
-- **Datasets**: add support for more data formats and datasets. Change TisDataset/PTBXLDataset to remove duplicates
-- **Models**: add more model architectures and weights of these models for different pathologies
-- **Preprocessing**: add class ECGrecord and update preprocessing methods
-- Add possibility to use metadata for analysis
-- Add complex segmentation methods
+### Predict
+This module allows users to test a trained model with an architecture from `ecglib`. You can get predictions for a specific ECG record or for all records in a directory.
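+
+The example below runs inference on CUDA; on a CPU-only machine only the `device` argument changes (a minimal sketch mirroring the constructor in the example below, with a placeholder weights path):
+
+```python
+from ecglib.predict import Predict
+
+cpu_predict = Predict(
+    weights_path="/path/to/model_weights",
+    model_name="densenet1d121",
+    pathologies=["AFIB"],
+    frequency=500,
+    device="cpu",
+    threshold=0.5,
+)
+```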
-### Credits -This project is made possible by: - -- [Aram Avetisyan](https://github.com/avetisyanaram) (a.a.avetisyan@gmail.com) -- [Olga Mashkova](https://github.com/omashkova) -- [Vladislav Ananev](https://github.com/Survial53) -- [Shahane Tigranyan](https://github.com/decoder-99) -- [Ariana Asatryan](https://github.com/arianasatryan) -- [Sergey Skorik](https://github.com/Skorik99) -- [Yury Markin](https://github.com/grandkarabas) +```python +# Predict example +from ecglib.predict import Predict + +ecg_signal = wfdb.rdsamp("wfdb_file")[0] # for example 00001_hr from PTB-XL dataset + +predict = Predict( + weights_path="/path/to/model_weights", + model_name="densenet1d121", + pathologies=["AFIB"], + frequency=500, + device="cuda:0", + threshold=0.5 +) + +result_df = predict.predict_directory(directory="path/to/data_to_predict", + file_type="wfdb") +print(predict.predict(ecg_signal, channels_first=False)) +``` diff --git a/ecglib/__init__.py b/ecglib/__init__.py deleted file mode 100644 index f550b72..0000000 --- a/ecglib/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from . import models -from . import preprocessing -from . import datasets - -__all__ = [ - 'models', - 'preprocessing', - 'datasets', -] diff --git a/ecglib/datasets/__init__.py b/ecglib/datasets/__init__.py deleted file mode 100644 index 5e95061..0000000 --- a/ecglib/datasets/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .load_datasets import * -from .datasets import * - -__all__ = [ - 'load_datasets', - 'datasets', -] \ No newline at end of file diff --git a/ecglib/datasets/datasets.py b/ecglib/datasets/datasets.py deleted file mode 100644 index 98ee6fa..0000000 --- a/ecglib/datasets/datasets.py +++ /dev/null @@ -1,393 +0,0 @@ -from enum import IntEnum -from pathlib import Path -from typing import Callable, Optional - -import ecg_plot -import numpy as np -import pandas as pd -import torch -import wfdb -from ecglib import preprocessing as P -from torch.utils.data import Dataset - -__all__ = [ - "EcgDataset", - "TisDataset", - "PTBXLDataset", -] - - -class EcgDataset(Dataset): - """ - EcgDataset - :param ecg_data: dataframe with ecg info - :param target: a list of targets - :param frequency: frequency for signal resampling - :param leads: a list of leads - :param ecg_length: length of ECG signal after padding / cropping - :param cut_range: cutting parameters - :param pad_mode: padding mode - :param classes: number of classes - :param use_meta: whether to use metadata or not - :param augmentation: a bunch of augmentations and other preprocessing techniques - """ - - def __init__( - self, - ecg_data: pd.DataFrame, - target: list, - frequency: int = 500, - leads: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], - ecg_length: float = 10, - cut_range: list = [0, 0], - pad_mode: str = "constant", - norm_type: str = "z_norm", - classes: int = 2, - use_meta: bool = False, - augmentation: Callable = None, - ): - super().__init__() - if "fpath" not in ecg_data.columns: - raise ValueError("column 'fpath' not in ecg_data") - self.ecg_data = ecg_data - self.target = target - self.frequency = frequency - self.leads = leads - self.ecg_length = ecg_length - self.cut_range = cut_range - self.pad_mode = pad_mode - self.norm_type = norm_type - self.classes = classes - if use_meta and "ecg_parameters" not in ecg_data.columns: - raise ValueError("metadata column 'ecg_parameters' not in ecg_data") - self.meta = use_meta - self.augmentation = augmentation - - def __len__(self): - return self.ecg_data.shape[0] - - def get_fpath(self, index: int) -> str: 
- """ - Returns path to file with ECG leads - :param index: Index of ECG in dataset - - :return: Path to ECG file - """ - - return self.ecg_data.iloc[index]["fpath"] - - def get_name(self, index: int) -> str: - """ - Returns name of ECG file - :param index: Index of ECG in dataset - - :return: ECG file name - """ - - return str(Path(self.get_fpath(index)).stem) - - -class TisDataset(EcgDataset): - """ - TisDataset - :param ecg_data: dataframe with ecg info - :param target: a list of targets - :param frequency: frequency for signal resampling - :param leads: a list of leads - :param ecg_length: length of ECG signal after padding / cropping - :param cut_range: cutting parameters - :param pad_mode: padding mode - :param classes: number of classes - :param use_meta: whether to use metadata or not - :param augmentation: a bunch of augmentations and other preprocessing techniques - """ - - def __init__( - self, - ecg_data: pd.DataFrame, - target: list, - frequency: int = 500, - leads: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], - ecg_length: float = 10, - cut_range: list = [0, 0], - pad_mode: str = "constant", - norm_type: str = "z_norm", - classes: int = 2, - use_meta: bool = False, - augmentation: Callable = None, - ): - super().__init__( - ecg_data, - target, - frequency, - leads, - ecg_length, - cut_range, - pad_mode, - norm_type, - classes, - use_meta, - augmentation, - ) - - def __getitem__(self, index): - - if "frequency" in self.ecg_data.columns: - ecg_frequency = float(self.ecg_data.iloc[index]["frequency"]) - else: - ecg_frequency = self.frequency - - # data standartization (scaling, resampling, cuts off, normalization and padding/truncation) - ecg_record = np.load(self.ecg_data.iloc[index]["fpath"])["arr_0"].astype( - "float64" - ) - ecg_record = P.Compose( - transforms=[ - P.FrequencyResample( - ecg_frequency=ecg_frequency, requested_frequency=self.frequency - ), - P.EdgeCut(cut_range=self.cut_range, frequency=self.frequency), - P.Normalization(norm_type=self.norm_type), - P.Padding( - observed_ecg_length=self.ecg_length, frequency=self.frequency - ), - ], - p=1.0, - )(ecg_record) - assert not np.any( - np.isnan(ecg_record) - ), f"ecg_record = {ecg_record}, index = {index}" - - # data preprocessing if specified (augmentation, filtering) - if self.augmentation is not None: - ecg_record = self.augmentation(ecg_record) - - target = self.target[index] - - result = [ - torch.tensor(ecg_record[self.leads, :], dtype=torch.float), - torch.tensor(target, dtype=torch.float), - ] - - if self.meta: - result.append( - torch.tensor( - self.ecg_data.iloc[index]["ecg_parameters"], dtype=torch.float - ) - ) - - return (index, result) - - def save_as_png( - self, index: int, dest_path: str, postfix: Optional[str] = None - ) -> None: - """ - Saves the image of ecg record - - :param index: Index of ECG - :param dest_path: Directory to save the image - :param postfix: Subdirectory where the image will be saved, defaults to None - """ - - ecg = (np.load(self.get_fpath(index))["arr_0"].astype("float64"),) - ecg = np.squeeze(ecg) - - if "frequency" in self.ecg_data.columns: - frequency = self.ecg_data.iloc[index]["frequency"] - else: - frequency = self.frequency - ecg = ecg / np.max( - ecg - ) # added to fit the record to the visible part of the plot - ecg_plot.plot(ecg, sample_rate=frequency) - ecg_fname = self.get_name(index) - - if postfix: - dest_path = str(Path(dest_path).joinpath(postfix)) - - dest_path = ( - "{}/".format(dest_path) if not dest_path.endswith("/") else dest_path - ) - - if not 
Path(dest_path).exists(): - Path(dest_path).mkdir(parents=True, exist_ok=True) - - ecg_plot.save_as_png(file_name=ecg_fname, path=dest_path) - - @staticmethod - def for_train_from_config( - data: pd.DataFrame, target: list, augmentation: Callable, config: dict, classes_num: int - ) -> EcgDataset: - """ - A wrapper with just four parameters to create `TisDataset` for training and validation - :param data: dataframe with ecg info - :param target: a list of targets - :param augmentation: a bunch of augmentations and other preprocessing techniques - :param config: config dictionary - :param classes_num: number of classes - - :return: TisDataset - """ - - return TisDataset( - data, - target, - frequency=config.ecg_record_params.resampled_frequency, - leads=config.ecg_record_params.leads, - ecg_length=config.ecg_record_params.observed_ecg_length, - norm_type=config.ecg_record_params.normalization, - classes=classes_num, - use_meta=config.ecg_metadata.use_metadata, - cut_range=config.ecg_record_params.ecg_cut_range, - augmentation=augmentation, - ) - - -class PTBXLDataset(EcgDataset): - """ - PTBXLDataset - :param ecg_data: dataframe with ecg info - :param target: a list of targets - :param frequency: frequency for signal resampling - :param leads: a list of leads - :param ecg_length: length of ECG signal after padding / cropping - :param cut_range: cutting parameters - :param pad_mode: padding mode - :param classes: number of classes - :param use_meta: whether to use metadata or not - :param augmentation: a bunch of augmentations and other preprocessing techniques - """ - - def __init__( - self, - ecg_data: pd.DataFrame, - target: list, - frequency: int = 500, - leads: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], - ecg_length: float = 10, - cut_range: list = [0, 0], - pad_mode: str = "constant", - norm_type: str = "z_norm", - classes: int = 2, - use_meta: bool = False, - augmentation: Callable = None, - ): - super().__init__( - ecg_data, - target, - 500, - leads, - 10, - cut_range, - pad_mode, - norm_type, - classes, - False, - augmentation, - ) - - def __getitem__(self, index): - if "frequency" in self.ecg_data.columns: - ecg_frequency = float(self.ecg_data.iloc[index]["frequency"]) - else: - ecg_frequency = self.frequency - - ecg_record, _ = wfdb.rdsamp( - self.ecg_data.iloc[index]["fpath"], channels=self.leads - ) - ecg_record = ecg_record.T - ecg_record = ecg_record.astype("float64") - - # data standartization (scaling, resampling, cuts off, normalization and padding/truncation) - ecg_record = P.Compose( - transforms=[ - P.FrequencyResample( - ecg_frequency=ecg_frequency, requested_frequency=self.frequency - ), - P.EdgeCut(cut_range=self.cut_range, frequency=self.frequency), - P.Normalization(norm_type=self.norm_type), - P.Padding( - observed_ecg_length=self.ecg_length, frequency=self.frequency - ), - ], - p=1.0, - )(ecg_record) - - # data preprocessing if specified (augmentation, filtering) - if self.augmentation is not None: - ecg_record = self.augmentation(ecg_record) - - target = self.target[index] - - result = [ - torch.tensor(ecg_record[self.leads, :], dtype=torch.float), - torch.tensor(target, dtype=torch.float), - ] - - return (index, result) - - def save_as_png( - self, index: int, dest_path: str, postfix: Optional[str] = None - ) -> None: - """ - Saves the image of ecg record - :param index: Index of ECG - :param dest_path: Directory to save the image - :param postfix: Subdirectory where the image will be saved, defaults to None - """ - - if "frequency" in self.ecg_data.columns: 
- frequency = self.ecg_data.iloc[index]["frequency"] - else: - frequency = self.frequency - - ecg, _ = wfdb.rdsamp(self.ecg_data.iloc[index]["fpath"], channels=self.leads) - ecg = ecg.T - ecg = ecg.astype("float64") - - ecg = ecg / np.max( - ecg - ) # added to fit the record to the visible part of the plot - ecg_plot.plot(ecg, sample_rate=frequency) - ecg_fname = self.get_name(index) - - if postfix: - dest_path = str(Path(dest_path).joinpath(postfix)) - - dest_path = ( - "{}/".format(dest_path) if not dest_path.endswith("/") else dest_path - ) - - if not Path(dest_path).exists(): - Path(dest_path).mkdir(parents=True, exist_ok=True) - - ecg_plot.save_as_png(file_name=ecg_fname, path=dest_path) - - @staticmethod - def for_train_from_config( - data: pd.DataFrame, target: list, augmentation: Callable, config: dict, classes_num: int - ) -> EcgDataset: - """ - A wrapper with just four parameters to create `PTBXLDataset` for training and validation - :param data: dataframe with ecg info - :param target: a list of targets - :param augmentation: a bunch of augmentations and other preprocessing techniques - :param config: config dictionary - :param classes_num: number of classes - - :return: PTBXLDataset - """ - - return PTBXLDataset( - data, - target, - frequency=config.ecg_record_params.resampled_frequency, - leads=config.ecg_record_params.leads, - ecg_length=config.ecg_record_params.observed_ecg_length, - norm_type=config.ecg_record_params.normalization, - classes=classes_num, - use_meta=config.ecg_metadata.use_metadata, - cut_range=config.ecg_record_params.edge_cut.ecg_cut_range, - augmentation=augmentation, - ) - diff --git a/ecglib/datasets/load_datasets.py b/ecglib/datasets/load_datasets.py deleted file mode 100644 index 5a626a7..0000000 --- a/ecglib/datasets/load_datasets.py +++ /dev/null @@ -1,63 +0,0 @@ -import requests -from zipfile import ZipFile -import os - -from tqdm import tqdm -import pandas as pd - - -__all__ = [ - "load_ptb_xl", -] - - -def load_ptb_xl(download: bool = False, - path_to_zip: str = "./", - path_to_unzip: str = "./", - delete_zip: bool = True, - ) -> pd.DataFrame: - ''' - Load PTB-XL dataset - :param download: whether to download PTB-XL from Physionet - :param path_to_zip: path where to store PTB-XL .zip file - :param path_to_unzip: path where to unarchive PTB-XL .zip file - :param delete_zip: whether to delete PTB-XL .zip file after unarchiving - - :return: dataframe with PTB-XL dataset info - ''' - - if download: - - url = 'https://physionet.org/static/published-projects/ptb-xl/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.2.zip' - ptb_xl_zip = os.path.join(path_to_zip, 'ptb_xl.zip') - response = requests.get(url, stream=True) - total_size_in_bytes= int(response.headers.get('content-length', 0)) - print('Loading PTB-XL file...') - with open(ptb_xl_zip, 'wb') as f: - progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) - for chunk in response.iter_content(chunk_size=512): - if chunk: - progress_bar.update(len(chunk)) - f.write(chunk) - progress_bar.close() - print('Loading completed!') - f.close() - - print('Unzipping PTB-XL file...') - with ZipFile(ptb_xl_zip, 'r') as zip_ref: - for member in tqdm(zip_ref.infolist(), desc=''): - try: - zip_ref.extract(member, path_to_unzip) - except zipfile.error as e: - pass - print('Unzipping completed!') - - if delete_zip: - print(f'Deleting {ptb_xl_zip} file...') - os.remove(ptb_xl_zip) - print('Deleting completed!') - - ptb_xl_info = pd.read_csv(os.path.join(path_to_unzip, 
'ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.2', 'ptbxl_database.csv')) - ptb_xl_info['fpath'] = [os.path.join(path_to_unzip, 'ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.2', ptb_xl_info.iloc[i]['filename_hr']) for i in range(len(ptb_xl_info['filename_hr']))] - - return ptb_xl_info \ No newline at end of file diff --git a/ecglib/models/__init__.py b/ecglib/models/__init__.py deleted file mode 100644 index c3757b8..0000000 --- a/ecglib/models/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .create_model import create_model \ No newline at end of file diff --git a/ecglib/models/config/model_configs.py b/ecglib/models/config/model_configs.py deleted file mode 100644 index 6b4bdb2..0000000 --- a/ecglib/models/config/model_configs.py +++ /dev/null @@ -1,21 +0,0 @@ -import imp -from dataclasses import dataclass -from typing import Optional, Type - -import torch.nn as nn - - -@dataclass(repr=True, eq=True) -class DenseNetConfig: - """ - Default parameters correspond DenseNet121_1d model - """ - - growth_rate: int = 32 - block_config: tuple = (6, 12, 24, 16) - num_init_features: int = 64 - bottleneck_size: int = 4 - kernel_size: int = 3 - input_channels: int = 12 - num_classes: int = 1 - reinit: bool = True \ No newline at end of file diff --git a/ecglib/models/create_model.py b/ecglib/models/create_model.py deleted file mode 100644 index e219aa8..0000000 --- a/ecglib/models/create_model.py +++ /dev/null @@ -1,102 +0,0 @@ -import os -from collections import OrderedDict -from dataclasses import asdict -from operator import mod - -import hydra -import torch -from omegaconf import DictConfig -import yaml - -from .architectures.densenet1d import densenet121_1d, densenet201_1d -from .config.model_configs import DenseNetConfig - -resource_package = __name__ - -pathologies = ["AFIB", "STACH", "SBRAD", "RBBB", "LBBB", "PVC", "1AVB"] - -arch_map = { - "densenet1d121": densenet121_1d, - "densenet1d201": densenet201_1d, -} - - -def get_config(model_name: str, config: dict = None): - if config: - model_name = config.model_name - return hydra.utils.instantiate(config.config) - - if "densenet" in model_name: - return DenseNetConfig() - else: - raise Exception("Unknown model type") - - -def get_model( - model_name: str, - leads_num: int, - model_cfg: DictConfig, - num_classes: int, -) -> torch.nn.Module: - if model_cfg: - model_name = model_cfg.model_name - - assert ( - model_name in arch_map - ), "Model name must be one of ['densenet1d121', 'densenet1d201']" - - model_config = get_config(model_name=model_name, config=model_cfg) - model_config.input_channels = leads_num - model_config.num_classes = num_classes - - # Note: Overloads config params according to input args values - if isinstance(model_config, DenseNetConfig): - return arch_map[model_name](**asdict(model_config)) - - -def create_model( - model_name: str, - pathology: str, - model_cfg: dict = None, - pretrained: bool = False, - leads_num: int = 12, -): - model = get_model( - model_name=model_name, - leads_num=leads_num, - model_cfg=model_cfg, - num_classes=1, - # meta=meta, - ) - - if pretrained: - if pathology not in pathologies: - raise KeyError( - "pathology must be one of ['AFIB', 'STACH', 'SBRAD', 'RBBB', 'LBBB', 'PVC', '1AVB']" - ) - - dirname = os.path.dirname(__file__) - weights_path = os.path.join(dirname, 'weights/model_weights_paths.yaml') - with open(weights_path, 'r') as file: - weights_config = yaml.safe_load(file) - if model_name not in weights_config[f"{leads_num}_leads"][pathology]: - raise 
KeyError( - "the weights are currently available for the following architectures ['densenet1d121']" - ) - weights_path = weights_config[f"{leads_num}_leads"][pathology][model_name] - - model_info = torch.hub.load_state_dict_from_url(weights_path, progress=True, check_hash=True, file_name=f"{leads_num}_leads_{model_name}_{pathology}.pt") - - print('{} model trained on {}-lead {} second ECG records with {} frequency to detect {}. {} normalization was applied to all the records during preprocessing.'.format( - model_name, - leads_num, - model_info['config_file']['ecg_record_params']['observed_ecg_length'], - model_info['config_file']['ecg_record_params']['resampled_frequency'], - pathology, - model_info['config_file']['ecg_record_params']['normalization'], - )) - - - model_state_dict = model_info['model'] - model.load_state_dict(model_state_dict) - return model diff --git a/ecglib/models/weights/model_weights_paths.yaml b/ecglib/models/weights/model_weights_paths.yaml deleted file mode 100644 index 53cae38..0000000 --- a/ecglib/models/weights/model_weights_paths.yaml +++ /dev/null @@ -1,16 +0,0 @@ -12_leads: - AFIB: - densenet1d121: "https://drive.google.com/uc?export=download&id=16DLKW4UjgfMxKfmETZNHQwOxVJ_KUphj" - 1AVB: - densenet1d121: "https://drive.google.com/uc?export=download&id=1pBrxx-I3kz91DIlkns5Aujr7nU8N2Ohr" - SBRAD: - densenet1d121: "https://drive.google.com/uc?export=download&id=1efH77HPHKxGHklU-yhaFB6zd98ymjz4z" - STACH: - densenet1d121: "https://drive.google.com/uc?export=download&id=1xv99PTimid0I-fadsla1SOlCBGvKrUVx" - PVC: - densenet1d121: "https://drive.google.com/uc?export=download&id=1j30e6jwDd02q08P0A_1X8pI5_x_S_9Qw" - RBBB: - densenet1d121: "https://drive.google.com/uc?export=download&id=1OtMUJ0eDTZm3u0rbpKsQxE2x62waHZGG" - LBBB: - densenet1d121: "https://drive.google.com/uc?export=download&id=1Z7xZe5UhXA2BOK_l1lnYmHopPnMmmxNU" - \ No newline at end of file diff --git a/ecglib/preprocessing/preprocess.py b/ecglib/preprocessing/preprocess.py deleted file mode 100644 index abca775..0000000 --- a/ecglib/preprocessing/preprocess.py +++ /dev/null @@ -1,508 +0,0 @@ -import copy -from typing import Union - -import numpy as np -import pandas as pd -import pywt -from scipy import signal -from scipy.stats import zscore - -from . import functional as F - - -__all__ = [ - "FrequencyResample", - "Padding", - "EdgeCut", - "Normalization", - "ButterworthFilter", - "IIRNotchFilter", - "BaselineWanderRemoval", - "WaveletTransform", - "LeadCrop", - "RandomLeadCrop", - "TimeCrop", - "RandomTimeCrop", - "SumAug", - "RandomSumAug", - "ReflectAug", - "ConvexAug", - "RandomConvexAug", -] - - -class FrequencyResample: - """ - Apply frequency resample - :param ecg_frequency: sampling frequency of a signal - :param requested_frequency: sampling frequency of a preprocessed signal - - :return: preprocessed data - """ - - def __init__( - self, - ecg_frequency: int, - requested_frequency: int = 500, - ): - if isinstance(ecg_frequency, (int, float)): - self.ecg_frequency = ecg_frequency - else: - raise ValueError('ecg_frequency must be scalar') - if isinstance(requested_frequency, (int, float)): - self.requested_frequency = requested_frequency - else: - raise ValueError('requested_frequency must be scalar') - self.func = F.ecg_to_one_frequency - - def __call__(self, x): - - x = self.func(x, int(self.ecg_frequency), int(self.requested_frequency)) - - return x - - -class Padding: - """ - Apply padding. If ECG is longer than the observed_ecg_length the record is cut. 
- :param observed_ecg_length: length of padded signal in seconds - :param frequency: sampling frequency of a signal - :param pad_mode: padding mode - - :return: preprocessed data - """ - - def __init__( - self, - observed_ecg_length: float = 10, - frequency: int = 500, - pad_mode: str = "constant", - ): - self.observed_ecg_length = observed_ecg_length - self.frequency = frequency - self.pad_mode = pad_mode - - def __call__(self, x): - - if self.observed_ecg_length*self.frequency - x.shape[1] > 0: - x = np.pad(x, ((0, 0), (0, self.observed_ecg_length*self.frequency - x.shape[1])), mode=self.pad_mode) - else: - x = x[:, :self.observed_ecg_length*self.frequency] - - return x - - -class EdgeCut: - """ - Cut signal edges - :param cut_range: cutting parameters - :param frequency: sampling frequency of a signal - - :return: preprocessed data - """ - - def __init__( - self, - cut_range: list = [0, 0], - frequency: int = 500, - ): - self.cut_range = cut_range - self.frequency = frequency - self.func = F.cut_ecg - - def __call__(self, x): - - x = self.func(x, self.cut_range, self.frequency) - - return x - - -class Normalization: - """ - Apply normalization - :param norm_type: type of normalization ('z_norm' or 'cycle') - :param leads: leads to be normalized - - :return: preprocessed data - """ - - def __init__( - self, - norm_type: str = "z_norm", - leads: list = None, - ): - if leads is None: - self.leads = list(range(12)) - elif isinstance(leads, list): - self.leads = leads - else: - raise ValueError('leads must be list type') - self.norm_type = norm_type - if norm_type == "cycle": - self.func = F.cycle_normalization - elif norm_type == "min_max": - self.func = F.minmax_normalization - elif norm_type == "z_norm": - self.func = F.z_normalization - else: - raise ValueError('norm_type must be one of [cycle, min_max, z_norm]') - - def __call__(self, x): - - if self.norm_type is not None: - return self.func(x[self.leads, :]) - else: - return x[self.leads, :] - - -class ButterworthFilter: - """ - Apply Butterworth filter augmentation - :param filter_type: type of Butterworth filter ('bandpass', 'lowpass' or 'highpass') - :param leads: leads to be filtered - :param n: filter order - :param Wn: cutoff frequency(ies) - :param fs: filtered signal frequency - - :return: preprocessed data - """ - - def __init__( - self, - filter_type: str = "bandpass", - leads: list = None, - n: int = 10, - Wn: Union[float,list] = [3, 30], - fs: int = 500, - ): - if leads is None: - self.leads = list(range(12)) - elif isinstance(leads, list): - self.leads = leads - else: - raise ValueError('leads must be list type') - self.filter_type = filter_type - if filter_type == "bandpass": - self.func = F.butterworth_bandpass_filter - elif filter_type == "lowpass": - self.func = F.butterworth_lowpass_filter - elif filter_type == "highpass": - self.func = F.butterworth_highpass_filter - else: - raise ValueError("Filter type must be one of [bandpass, lowpass, highpass]") - self.n = n - if filter_type == "bandpass" and not isinstance(Wn, list): - raise ValueError('Wn must be list type in case of bandpass filter') - elif (filter_type == "highpass" or filter_type == "lowpass") and not isinstance(Wn, (int, float)): - raise ValueError(f'Wn must be a scalar in case of {filter_type} filter') - self.Wn = Wn - self.fs = fs - - def __call__(self, x): - - return np.apply_along_axis(self.func, axis=1, arr=x[self.leads, :], n=self.n, Wn=self.Wn, fs=self.fs) - - -class IIRNotchFilter: - """ - Apply IIR notch filter augmentation - :param leads: leads 
to be filtered - :param w0: frequency to remove from a signal - :param Q: quality factor - :param fs: sampling frequency of a signal - - :return: preprocessed data - """ - - def __init__( - self, - leads: list = None, - w0: float = 50, - Q: float = 30, - fs: int = 500, - ): - if leads is None: - self.leads = list(range(12)) - elif isinstance(leads, list): - self.leads = leads - else: - raise ValueError('leads must be list type') - self.w0 = w0 - self.Q = Q - self.fs = fs - self.func = F.IIR_notch_filter - - def __call__(self, x): - - return np.apply_along_axis(self.func, axis=1, arr=x[self.leads, :], w0=self.w0, Q=self.Q, fs=self.fs) - - -class BaselineWanderRemoval: - """ - Remove baseline wander using wavelets (see article https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.308.6789&rep=rep1&type=pdf) - :param leads: leads to be preprocessed - :param wavelet: wavelet name - - :return: preprocessed data - """ - - def __init__( - self, - leads: list = None, - wavelet: str = 'db4', - ): - if leads is None: - self.leads = list(range(12)) - elif isinstance(leads, list): - self.leads = leads - else: - raise ValueError('leads must be list type') - self.wavelet = wavelet - self.func = F.DWT_BW - - def __call__(self, x): - - return np.apply_along_axis(self.func, axis=1, arr=x[self.leads, :], wavelet=self.wavelet) - - -class WaveletTransform: - """ - Apply wavelet transform augmentation - :param wt_type: type of wavelet transform ('DWT' with soft thresholding or 'SWT') - :param leads: leads to be transformed - :param wavelet: wavelet name - :param level: decomposition level - :param threshold: thresholding value for all coefficients except the first one (only for DWT) - :param low: thresholding value for the first coefficient (only for DWT) - - :return: preprocessed data - """ - - def __init__( - self, - wt_type: str = "DWT", - leads: list = None, - wavelet: str = 'db4', - level: int = 3, - threshold: float = 2, - low: float = 1e6, - ): - if leads is None: - self.leads = list(range(12)) - elif isinstance(leads, list): - self.leads = leads - else: - raise ValueError('leads must be list type') - self.wt_type = wt_type - self.wavelet = wavelet - self.level = level - if wt_type == "DWT": - self.threshold = threshold - self.low = low - self.func = F.DWT_filter - elif wt_type == "SWT": - self.threshold = None - self.low = None - self.func = F.SWT_filter - else: - raise ValueError('wt_type must be one of [DWT, SWT]') - - def __call__(self, x): - - if self.wt_type == "DWT": - return np.apply_along_axis(self.func, axis=1, arr=x[self.leads, :], wavelet=self.wavelet, - level=self.level, threshold=self.threshold, low=self.low) - elif self.wt_type == "SWT": - return np.apply_along_axis(self.func, axis=1, arr=x[self.leads, :], wavelet=self.wavelet, - level=self.level) - - -class LeadCrop: - """ - Apply lead crop augmentation - :param leads: leads to be cropped - - :return: preprocessed data - """ - - def __init__( - self, - leads: list = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], - ): - self.leads = leads - self.func = F.lead_crop - - def __call__(self, x): - - return self.func(x, leads=self.leads) - - -class RandomLeadCrop(LeadCrop): - """ - Apply random lead crop augmentation - :param n: number of leads to be cropped (chosen randomly) - - :return: preprocessed data - """ - - def __init__( - self, - n: int = 11, - ): - ls = np.arange(12, dtype='int') - leads_to_remove = np.random.choice(ls, size=n, replace=False) - super().__init__(leads_to_remove) - self.n = n - - def __call__(self, x): - - return 
self.func(x, leads=self.leads) - - -class TimeCrop: - """ - Apply time crop augmentation - :param time: length of time segment to be cropped (the same units as signal) - :param leads: leads to be cropped - - :return: preprocessed data - """ - - def __init__( - self, - time: int = 100, - leads: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], - ): - self.time = time - self.leads = leads - self.func = F.time_crop - - def __call__(self, x): - - return self.func(x, time=self.time, leads=self.leads) - - -class RandomTimeCrop(TimeCrop): - """ - Apply random time crop augmentation - :param time: length of time segment to be cropped (the same units as signal) - :param n: number of leads to be cropped (chosen randomly) - - :return: preprocessed data - """ - - def __init__( - self, - time: int = 100, - n: int = 12, - ): - ls = np.arange(12, dtype='int') - leads_to_modify = np.random.choice(ls, size=n, replace=False) - super().__init__(time, leads_to_modify) - self.time = time - self.n = n - - def __call__(self, x): - - return self.func(x, time=self.time, leads=self.leads) - - -class SumAug: - """ - Apply sum augmentation to selected leads - :param leads: leads to be replaced by sum of all leads - - :return: preprocessed data - """ - - def __init__( - self, - leads: list = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], - ): - self.leads = leads - self.func = F.sum_augmentation - - def __call__(self, x): - - return self.func(x, leads=self.leads) - - -class RandomSumAug(SumAug): - """ - Apply random sum augmentation - :param n: number of leads to be replaced by sum of all leads (chosen randomly) - - :return: preprocessed data - """ - - def __init__( - self, - n: int = 11, - ): - ls = np.arange(12, dtype='int') - leads_to_remove = np.random.choice(ls, size=n, replace=False) - super().__init__(leads_to_remove) - self.n = n - - def __call__(self, x): - - return self.func(x, leads=self.leads) - - -class ReflectAug: - """ - Apply reflection augmentation - - :return: preprocessed data - """ - - def __init__( - self, - ): - self.func = F.reflect_augmentation - - def __call__(self, x): - - return self.func(x) - - -class ConvexAug: - """ - Apply convex augmentation - :param leads: leads to be replaced by convex combination of some leads (chosen randomly) - - :return: preprocessed data - """ - - def __init__( - self, - leads: list = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], - ): - self.leads = leads - self.func = F.convex_augmentation - - def __call__(self, x): - - return self.func(x, leads=self.leads) - - -class RandomConvexAug(ConvexAug): - """ - Apply random convex augmentation - :param n: number of leads (chosen randomly) to be replaced by convex combination of some leads (chosen randomly) - - :return: preprocessed data - """ - - def __init__( - self, - n: int = 11, - ): - ls = np.arange(12, dtype='int') - leads_to_remove = np.random.choice(ls, size=n, replace=False) - super().__init__(leads_to_remove) - self.n = n - - def __call__(self, x): - - return self.func(x, leads=self.leads) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 2853508..0000000 --- a/requirements.txt +++ /dev/null @@ -1,13 +0,0 @@ -pandas -ecg_plot -tqdm -numpy -torch -fastai -scipy -ipywidgets -pyyaml -PyWavelets -wfdb -hydra-core -omegaconf \ No newline at end of file diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..8960de6 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + +[metadata] +name = ecglib +version = 1.1.0 +description = ECG library with 
pretrained models and tools for ECG analysis
+license = Apache License 2.0
+author = Aram Avetisyan
+author_email = a.a.avetisyan@gmail.com
+keywords = ecg analysis, pytorch, pretrained models, ecg preprocessing, ecg datasets
+url = https://github.com/ispras/EcgLib
+
+[options]
+include_package_data = True
+packages = find:
+package_dir =
+    =src
+install_requires =
+    pandas>=1.5.2
+    ecg_plot>=0.0.1
+    tqdm>=4.64.0
+    numpy>=1.2
+    torch>=1.13.1
+    fastai>=2.7.10
+    neurokit2>=0.2.2
+    scipy>=1.7.3
+    ipywidgets>=7.6.5
+    pyyaml>=6.0
+    PyWavelets>=1.3.0
+    wfdb>=4.0.0
+    omegaconf>=2.3.0
+python_requires = >=3.6
+
+[options.packages.find]
+where=src
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 3924bf5..fc1f76c 100644
--- a/setup.py
+++ b/setup.py
@@ -1,41 +1,3 @@
-from setuptools import setup, find_packages
-from os import path
+from setuptools import setup
 
-here = path.abspath(path.dirname(__file__))
-
-# Get the long description from the README file
-with open(path.join(here, 'README.md'), encoding='utf-8') as f:
-    long_description = f.read()
-
-setup_args = dict(
-    name='ecglib',
-    version='1.0.1',
-    description='ECG library with pretrained models and tools for ECG analysis',
-    long_description=long_description,
-    long_description_content_type='text/markdown',
-    license='Apache License 2.0',
-    packages=find_packages(),
-    include_package_data=True,
-    author='Aram Avetisyan',
-    author_email='a.a.avetisyan@gmail.com',
-    install_requires=['pandas',
-                      'ecg_plot',
-                      'tqdm',
-                      'numpy',
-                      'torch',
-                      'scipy',
-                      'ipywidgets',
-                      'pyyaml',
-                      'PyWavelets',
-                      'wfdb',
-                      'hydra-core',
-                      'omegaconf',
-                      ],
-    keywords=['ecg analysis', 'pytorch', 'pretrained models', 'ecg preprocessing', 'ecg datasets'],
-    url = 'https://github.com/ispras/EcgLib',
-    python_requires='>=3.6',
-)
-
-
-if __name__ == '__main__':
-    setup(**setup_args)
+setup()
\ No newline at end of file
diff --git a/src/ecglib/__init__.py b/src/ecglib/__init__.py
new file mode 100644
index 0000000..cc29d17
--- /dev/null
+++ b/src/ecglib/__init__.py
@@ -0,0 +1,11 @@
+from . import models
+from . import preprocessing
+from . import data
+from . import predict
+
+__all__ = [
+    "models",
+    "preprocessing",
+    "data",
+    "predict",
+]
diff --git a/src/ecglib/data/__init__.py b/src/ecglib/data/__init__.py
new file mode 100644
index 0000000..64bc3bb
--- /dev/null
+++ b/src/ecglib/data/__init__.py
@@ -0,0 +1,9 @@
+from .load_datasets import *
+from .datasets import *
+from .ecg_record import *
+
+__all__ = [
+    "load_datasets",
+    "datasets",
+    "ecg_record",
+]
diff --git a/src/ecglib/data/datasets.py b/src/ecglib/data/datasets.py
new file mode 100644
index 0000000..2fdce91
--- /dev/null
+++ b/src/ecglib/data/datasets.py
@@ -0,0 +1,236 @@
+from pathlib import Path
+from typing import Callable, Optional, Union
+
+import ecg_plot
+import numpy as np
+import pandas as pd
+import torch
+import wfdb
+from ecglib import preprocessing as P
+from .ecg_record import EcgRecord
+from torch.utils.data import Dataset
+
+
+__all__ = [
+    "EcgDataset",
+]
+
+
+class EcgDataset(Dataset):
+    """
+    EcgDataset
+    :param ecg_data: dataframe with ecg info
+    :param target: a list of targets
+    :param frequency: frequency for signal resampling
+    :param leads: a list of leads
+    :param data_type: ECG file format ('wfdb' or 'npz')
+    :param ecg_length: length of ECG signal after padding / cropping
+    :param cut_range: cutting parameters
+    :param pad_mode: padding mode
+    :param norm_type: normalization type
+    :param classes: number of classes
+    :param augmentation: a bunch of augmentations and other preprocessing techniques
+    """
+
+    def __init__(
+        self,
+        ecg_data: pd.DataFrame,
+        target: list,
+        frequency: int = 500,
+        leads: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
+        data_type: str = "wfdb",
+        ecg_length: Union[int, float] = 10,
+        cut_range: list = [0, 0],
+        pad_mode: str = "constant",
+        norm_type: str = "z_norm",
+        classes: int = 2,
+        augmentation: Callable = None,
+    ):
+        super().__init__()
+        if "fpath" not in ecg_data.columns:
+            raise ValueError("column 'fpath' not in ecg_data")
+        self.ecg_data = ecg_data
+        self.target = target
+        self.frequency = frequency
+        self.leads = leads
+        self.data_type = data_type
+        self.ecg_length = ecg_length
+        self.cut_range = cut_range
+        self.pad_mode = pad_mode
+        self.norm_type = norm_type
+        self.classes = classes
+        self.augmentation = augmentation
+
+    def __len__(self):
+        return self.ecg_data.shape[0]
+
+    def get_fpath(self, index: int) -> str:
+        """
+        Returns path to file with ECG leads
+        :param index: Index of ECG in dataset
+
+        :return: Path to ECG file
+        """
+
+        return self.ecg_data.iloc[index]["fpath"]
+
+    def get_name(self, index: int) -> str:
+        """
+        Returns name of ECG file
+        :param index: Index of ECG in dataset
+
+        :return: ECG file name
+        """
+
+        return str(Path(self.get_fpath(index)).stem)
+
+    def read_ecg_record(
+        self, file_path, data_type, leads=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
+    ):
+        if data_type == "npz":
+            ecg_record = np.load(file_path)["arr_0"].astype("float64")
+        elif data_type == "wfdb":
+            ecg_record, _ = wfdb.rdsamp(file_path, channels=leads)
+            ecg_record = ecg_record.T
+            ecg_record = ecg_record.astype("float64")
+        else:
+            raise ValueError(
+                'data_type can have only values from the list ["npz", "wfdb"]'
+            )
+        return ecg_record
+
+    def __getitem__(self, index):
+        ecg_frequency = float(self.ecg_data.iloc[index]["frequency"])
+        patient_meta = (
+            self.ecg_data.iloc[index]["patient_metadata"]
+            if "patient_metadata" in self.ecg_data.iloc[index]
+            else dict()
+        )
+        ecg_record_meta = (
+            self.ecg_data.iloc[index]["ecg_metadata"]
+            if "ecg_metadata" in self.ecg_data.iloc[index]
+            else dict()
+        )
+        file_path = self.ecg_data.iloc[index]["fpath"]
+
+        # data standardization (scaling, resampling, edge cuts, normalization and padding/truncation)
+        ecg_record = self.read_ecg_record(file_path, self.data_type, self.leads)
+        ecg_record = P.Compose(
+            transforms=[
+                P.FrequencyResample(
+                    ecg_frequency=ecg_frequency, requested_frequency=self.frequency
+                ),
+                P.EdgeCut(cut_range=self.cut_range, frequency=self.frequency),
+                P.Normalization(norm_type=self.norm_type),
+                P.Padding(
+                    observed_ecg_length=self.ecg_length, frequency=self.frequency
+                ),
+            ],
+            p=1.0,
+        )(ecg_record)
+        assert not np.any(
+            np.isnan(ecg_record)
+        ), f"ecg_record = {ecg_record}, index = {index}"
+
+        # data preprocessing if specified (augmentation, filtering)
+        if self.augmentation is not None:
+            ecg_record = self.augmentation(ecg_record)
+
+        target = self.target[index]
+
+        patient_meta = {
+            key: patient_meta[key]
+            if isinstance(patient_meta[key], list)
+            else [patient_meta[key]]
+            for key in patient_meta
+        }
+
+        ecg_record_meta = {
+            key: ecg_record_meta[key]
+            if isinstance(ecg_record_meta[key], list)
+            else [ecg_record_meta[key]]
+            for key in ecg_record_meta
+        }
+
+        full_ecg_record_info = EcgRecord(
+            signal=ecg_record[self.leads, :],
+            frequency=ecg_frequency,
+            name=file_path,
+            lead_order=self.leads,
+            ecg_metadata=ecg_record_meta,
+            patient_metadata=patient_meta,
+        )
+
+        result = [
+            full_ecg_record_info.to_tensor(),
+            torch.tensor(target, dtype=torch.float),
+        ]
+
+        return (index, result)
+
+    def save_as_png(
+        self, index: int, dest_path: str, postfix: Optional[str] = None
+    ) -> None:
+        """
+        Saves the image of ecg record
+
+        :param index: Index of ECG
+        :param dest_path: Directory to save the image
+        :param postfix: Subdirectory where the image will be saved, defaults to None
+        """
+
+        ecg = (np.load(self.get_fpath(index))["arr_0"].astype("float64"),)
+        ecg = np.squeeze(ecg)
+
+        if "frequency" in self.ecg_data.columns:
+            frequency = self.ecg_data.iloc[index]["frequency"]
+        else:
+            frequency = self.frequency
+        ecg = ecg / np.max(
+            ecg
+        )  # added to fit the record to the visible part of the plot
+        ecg_plot.plot(ecg, sample_rate=frequency)
+        ecg_fname = self.get_name(index)
+
+        if postfix:
+            dest_path = str(Path(dest_path).joinpath(postfix))
+
+        dest_path = (
+            "{}/".format(dest_path) if not dest_path.endswith("/") else dest_path
+        )
+
+        if not Path(dest_path).exists():
+            Path(dest_path).mkdir(parents=True, exist_ok=True)
+
+        ecg_plot.save_as_png(file_name=ecg_fname, path=dest_path)
+
+    @staticmethod
+    def for_train_from_config(
+        data: pd.DataFrame,
+        target: list,
+        augmentation: Callable,
+        config: dict,
+        classes_num: int,
+    ):
+        """
+        A wrapper with just five parameters to create `EcgDataset` for training and validation
+        :param data: dataframe with ecg info
+        :param target: a list of targets
+        :param augmentation: a bunch of augmentations and other preprocessing techniques
+        :param config: config dictionary
+        :param classes_num: number of classes
+
+        :return: EcgDataset
+        """
+
+        return EcgDataset(
+            data,
+            target,
+            frequency=config.ecg_record_params.resampled_frequency,
+            leads=config.ecg_record_params.leads,
+            data_type=config.ecg_record_params.data_type,
+            ecg_length=config.ecg_record_params.observed_ecg_length,
+            norm_type=config.ecg_record_params.normalization,
+            classes=classes_num,
+            cut_range=config.ecg_record_params.ecg_cut_range,
+            augmentation=augmentation,
+        )
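+
+# Usage sketch (an illustration, not part of the module): __getitem__ returns
+# (index, [EcgRecord.to_tensor(), target]), so with a default-collate DataLoader
+# a training loop unpacks batches roughly as follows:
+#   loader = torch.utils.data.DataLoader(ecg_dataset, batch_size=8)
+#   for indices, ((signals, patient_meta, ecg_meta), targets) in loader:
+#       ...  # signals: (batch, leads, samples)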
diff --git a/src/ecglib/data/ecg_record.py b/src/ecglib/data/ecg_record.py
new file mode 100644
index 0000000..2a8be3d
--- /dev/null
+++ b/src/ecglib/data/ecg_record.py
@@ -0,0 +1,135 @@
+from dataclasses import dataclass, field
+from typing import Optional, Dict
+
+import ecg_plot
+import numpy as np
+import torch
+
+from ..preprocessing import preprocess
+
+
+@dataclass(repr=True, eq=True)
+class EcgRecord:
+    """
+    Class that describes an ECG record
+    :param signal: ECG signal
+    :param frequency: ECG record frequency
+    :param name: ECG name
+    :param lead_order: order of ECG leads in signal
+    :param duration: ECG record length
+    :param leads_num: number of leads
+    :param patient_id: patient_id
+    :param ecg_segment_info: location of ECG record peaks, intervals and segments
+    :param ecg_metadata: ECG signal metadata
+    :param patient_metadata: patient's metadata
+    :param annotation_info: annotation of the record
+    :param preprocessing_info: a list of preprocessing techniques applied to the signal
+    """
+
+    signal: np.ndarray
+    frequency: int
+    name: str = "ecg_record"
+    lead_order: list = field(
+        default_factory=lambda: [
+            "I",
+            "II",
+            "III",
+            "AVR",
+            "AVL",
+            "AVF",
+            "V1",
+            "V2",
+            "V3",
+            "V4",
+            "V5",
+            "V6",
+        ]
+    )
+    duration: float = 0
+    leads_num: Optional[int] = 12
+    patient_id: Optional[str] = ""
+    ecg_segment_info: dict = field(default_factory=lambda: {})
+    ecg_metadata: Optional[Dict[str, float]] = field(default_factory=lambda: {})
+    patient_metadata: Optional[Dict[str, float]] = field(
+        default_factory=lambda: {
+            "age": None,
+            "sex": None,
+            "weight": None,
+            "height": None,
+        }
+    )
+    annotation_info: Optional[Dict[str, list]] = field(default_factory=lambda: {})
+    preprocessing_info: list = field(default_factory=lambda: [])
+
+    def __post_init__(self):
+        self.duration = len(self.signal[0]) / self.frequency
+        self.leads_num = len(self.signal)
+
+    def ecg_plot(self, path="./", save_img=False):
+        ecg = self.signal / np.max(
+            self.signal
+        )
+        ecg_plot.plot(ecg, sample_rate=self.frequency)
+        if save_img:
+            ecg_plot.save_as_png(path + self.name)
+
+    def to_tensor(self):
+        ecg_tensor = torch.tensor(self.signal, dtype=torch.float)
+        patient_tensor_values = []
+        for meta in self.patient_metadata.values():
+            if meta is not None:
+                patient_tensor_values += [float(param) for param in meta]
+            else:
+                patient_tensor_values += [None]
+        ecg_metadata_tensor_values = []
+        for meta in self.ecg_metadata.values():
+            if meta is not None:
+                ecg_metadata_tensor_values += [float(param) for param in meta]
+            else:
+                ecg_metadata_tensor_values += [None]
+        patient_tensor_values = np.array(patient_tensor_values, dtype=float)
+        np.nan_to_num(patient_tensor_values, copy=False)
+        patient_tensor = torch.tensor(patient_tensor_values)
+        ecg_metadata_tensor_values = np.array(ecg_metadata_tensor_values, dtype=float)
+        np.nan_to_num(ecg_metadata_tensor_values, copy=False)
+        ecg_metadata_tensor = torch.tensor(ecg_metadata_tensor_values)
+        return list([ecg_tensor, patient_tensor, ecg_metadata_tensor])
+
+    def frequency_resample(self, requested_frequency=500):
+        old_frequency = self.frequency
+        self.signal = preprocess.FrequencyResample(
+            ecg_frequency=self.frequency, requested_frequency=requested_frequency
+        )(self.signal)
+        self.frequency = requested_frequency
+        self.preprocessing_info.append(
+            f"changed frequency from {old_frequency} to {requested_frequency}"
+        )
+
+    def cut_ecg(self, cut_range=[0, 0]):
+        self.signal = preprocess.EdgeCut(cut_range=cut_range, frequency=self.frequency)(
+            self.signal
+        )
+        self.duration = self.duration - cut_range[0] - cut_range[1]
+        self.preprocessing_info.append(
+            f"cut ecg record {cut_range[0]} seconds from the beginning "
+            f"and {cut_range[1]} seconds from the end leaving {self.duration} seconds"
+        )
+
+    def get_fixed_length(self, requested_length=10):
+        old_duration = self.duration
+        self.signal = preprocess.Padding(
+            observed_ecg_length=requested_length, frequency=self.frequency
+        )(self.signal)
+        self.duration = requested_length
+        self.preprocessing_info.append(
+            f"changed length of ecg record from {old_duration} seconds "
+            f"to {requested_length} seconds"
+        )
+
+    def normalize(self, norm_type="z_norm"):
+        self.signal = preprocess.Normalization(norm_type=norm_type)(self.signal)
+        self.preprocessing_info.append(f"applied {norm_type} normalization")
+
+    def remove_baselinewander(self, wavelet="db4"):
+        self.signal = preprocess.BaselineWanderRemoval(wavelet=wavelet)(self.signal)
+        self.preprocessing_info.append(
+            f"removed baseline wander with wavelet {wavelet}"
+        )
diff --git a/src/ecglib/data/load_datasets.py b/src/ecglib/data/load_datasets.py
new file mode 100644
index 0000000..02483ba
--- /dev/null
+++ b/src/ecglib/data/load_datasets.py
@@ -0,0 +1,199 @@
+import requests
+import zipfile
+import os
+
+import wfdb
+from tqdm import tqdm
+import pandas as pd
+
+
+__all__ = [
+    "load_ptb_xl",
+    "load_physionet2020",
+]
+
+
+def load_ptb_xl(
+    download: bool = False,
+    path_to_zip: str = "./",
+    path_to_unzip: str = "./",
+    delete_zip: bool = True,
+) -> pd.DataFrame:
+    """
+    Load PTB-XL dataset
+    :param download: whether to download PTB-XL from Physionet
+    :param path_to_zip: path where to store PTB-XL .zip file
+    :param path_to_unzip: path where to unarchive PTB-XL .zip file
+    :param delete_zip: whether to delete PTB-XL .zip file after unarchiving
+
+    :return: dataframe with PTB-XL dataset info
+    """
+
+    if download:
+        url = "https://physionet.org/static/published-projects/ptb-xl/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.2.zip"
+        ptb_xl_zip = os.path.join(path_to_zip, "ptb_xl.zip")
+        response = requests.get(url, stream=True)
+        total_size_in_bytes = int(response.headers.get("content-length", 0))
+        print("Loading PTB-XL file...")
+        with open(ptb_xl_zip, "wb") as f:
+            progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
+            for chunk in response.iter_content(chunk_size=1024):
+                if chunk:
+                    progress_bar.update(len(chunk))
+                    f.write(chunk)
+            progress_bar.close()
+        print("Loading completed!")
+        f.close()
+
+        print("Unzipping PTB-XL file...")
+        with zipfile.ZipFile(ptb_xl_zip, "r") as zip_ref:
+            for member in tqdm(zip_ref.infolist(), desc=""):
+                try:
+                    zip_ref.extract(member, path_to_unzip)
+                except zipfile.error as e:
+                    pass
+        print("Unzipping completed!")
+
+        if delete_zip:
+            print(f"Deleting {ptb_xl_zip} file...")
+            os.remove(ptb_xl_zip)
+            print("Deleting completed!")
+
+    ptb_xl_info = pd.read_csv(
+        os.path.join(
+            path_to_unzip,
+            "ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.2",
+            "ptbxl_database.csv",
+        )
+    )
+    ptb_xl_info["fpath"] = [
+        os.path.join(
+            path_to_unzip,
+            "ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.2",
+            ptb_xl_info.iloc[i]["filename_hr"],
+        )
+        for i in range(len(ptb_xl_info["filename_hr"]))
+    ]
+
+    return ptb_xl_info
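+
+
+# Usage sketch for load_physionet2020 below (an illustration, not part of the
+# original module): download one source, then read the csv it writes to
+# path_to_unzip:
+#   load_physionet2020(download=True, selected_datasets=["georgia"])
+#   georgia_info = pd.read_csv("georgia_dataset.csv")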
unarchiving selected datasets + :param selected_datasets: list of the dataset names to extract from ['georgia','st_petersburg_incart','cpsc_2018','ptb-xl','ptb',cpsc_2018_extra'] + """ + + if download: + url = "https://physionet.org/static/published-projects/challenge-2020/classification-of-12-lead-ecgs-the-physionetcomputing-in-cardiology-challenge-2020-1.0.2.zip" + print(f"Loading started...") + with open( + os.path.join( + path_to_zip, + "classification-of-12-lead-ecgs-the-physionetcomputing-in-cardiology-challenge-2020-1.0.2.zip", + ), + "wb", + ) as f: + response = requests.get(url, stream=True) + total_size_in_bytes = int(response.headers.get("content-length", 0)) + progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + for chunk in response.iter_content(chunk_size=512): + if chunk: + progress_bar.update(len(chunk)) + f.write(chunk) + progress_bar.close() + print("Loading completed!") + + if not os.path.exists(path_to_unzip): + os.makedirs(path_to_unzip) + + if not selected_datasets: + selected_datasets = [ + "georgia", + "st_petersburg_incart", + "cpsc_2018", + "ptb-xl", + "ptb", + "cpsc_2018_extra", + ] + + selected_datasets_paths = tuple( + [ + f"classification-of-12-lead-ecgs-the-physionetcomputing-in-cardiology-challenge-2020-1.0.2/training/{dataset}/" + for dataset in selected_datasets + ] + ) + print(f"Unzipping started...") + with zipfile.ZipFile( + os.path.join( + path_to_zip, + "classification-of-12-lead-ecgs-the-physionetcomputing-in-cardiology-challenge-2020-1.0.2.zip", + ), + "r", + ) as zip_ref: + for member in tqdm(zip_ref.infolist(), desc=""): + if member.filename.startswith(selected_datasets_paths): + filename = os.path.basename(member.filename) + if not filename: + continue + member.filename = member.filename.replace( + "classification-of-12-lead-ecgs-the-physionetcomputing-in-cardiology-challenge-2020-1.0.2/training/", + "", + ) + zip_ref.extract(member, path_to_unzip) + print(f"Unzipping completed!") + + if delete_zip: + print(f"Deleting zipfile started...") + os.remove( + os.path.join( + path_to_zip, + "classification-of-12-lead-ecgs-the-physionetcomputing-in-cardiology-challenge-2020-1.0.2.zip", + ) + ) + print("Deleting completed!") + + for dataset in selected_datasets: + dataset_meta = [] + print(f"Collecting {dataset} information...") + path_to_dataset = os.path.join(path_to_unzip, dataset) + folders = os.listdir(path_to_dataset) + for folder in tqdm(folders): + for filename in os.listdir(os.path.join(path_to_dataset, folder)): + if not filename.endswith("hea"): + continue + + file_info = {} + record, metadata = wfdb.rdsamp( + os.path.join(path_to_dataset, folder, filename[:-4]) + ) + file_info["file_name"] = filename[:-4] + file_info["fpath"] = os.path.join( + path_to_dataset, folder, filename[:-4] + ) + file_info["ecg_shape"] = record.shape + file_info["frequency"] = metadata["fs"] + file_info["ecg_duration"] = metadata["sig_len"] / metadata["fs"] + + for comment in metadata["comments"]: + key, val = tuple(comment.strip().replace(" ", "").split(":")) + file_info[key] = val + dataset_meta.append(file_info) + + dataset_meta = pd.DataFrame(dataset_meta) + dataset_meta["Dx"] = dataset_meta["Dx"].apply(lambda x: x.split(",")) + dataset_meta.to_csv( + os.path.join(path_to_unzip, f"{dataset}_dataset.csv"), index=False + ) + print("Information is collected!") + print("Loading completed!") diff --git a/ecglib/models/architectures/__init__.py b/src/ecglib/models/__init__.py similarity index 100% rename from 
ecglib/models/architectures/__init__.py rename to src/ecglib/models/__init__.py diff --git a/ecglib/models/config/__init__.py b/src/ecglib/models/architectures/__init__.py similarity index 100% rename from ecglib/models/config/__init__.py rename to src/ecglib/models/architectures/__init__.py diff --git a/src/ecglib/models/architectures/cnn_tabular.py b/src/ecglib/models/architectures/cnn_tabular.py new file mode 100644 index 0000000..298de71 --- /dev/null +++ b/src/ecglib/models/architectures/cnn_tabular.py @@ -0,0 +1,74 @@ +import torch +from fastai.layers import AdaptiveConcatPool1d, LinBnDrop + + +class CnnTabular(torch.nn.Module): + def __init__( + self, + cnn_backbone, + cnn_out_features, + tabular_model, + tabular_out_features, + classes, + head_ftrs=[512], + head_drop_prob=0.2, + ): + super(CnnTabular, self).__init__() + + self.cnn_backbone = cnn_backbone + self.tabular = tabular_model + self.head_layers = head_ftrs + + self.head_inp_features = cnn_out_features * 2 + tabular_out_features + + self.cnn_pooling = torch.nn.Sequential( + AdaptiveConcatPool1d(), + torch.nn.Flatten(), + ) + + self.head = torch.nn.Sequential() + for i, f in enumerate(self.head_layers): + if not self.head: + self.head.add_module( + "input", + LinBnDrop( + n_in=self.head_inp_features, + n_out=f, + bn=True, + p=head_drop_prob, + act=torch.nn.ReLU(), + lin_first=False, + ), + ) + + self.head.add_module( + "hidden{}".format(i), + LinBnDrop( + n_in=f, + n_out=f, + bn=True, + p=head_drop_prob, + act=torch.nn.ReLU(), + lin_first=False, + ), + ) + + self.head.add_module( + "output", + LinBnDrop( + n_in=self.head_layers[-1], + n_out=classes, + bn=True, + p=head_drop_prob, + act=None, + lin_first=False, + ), + ) + + def forward(self, input): + y_cnn = self.cnn_backbone(input[0]) + y_cnn = self.cnn_pooling(y_cnn) + y_tabular = self.tabular(input[1]) + y = torch.cat((y_cnn, y_tabular), dim=-1) + y = self.head(y) + return y diff --git a/ecglib/models/architectures/densenet1d.py b/src/ecglib/models/architectures/densenet1d.py similarity index 65% rename from ecglib/models/architectures/densenet1d.py rename to src/ecglib/models/architectures/densenet1d.py index 52932af..0e4ab08 100644 --- a/ecglib/models/architectures/densenet1d.py +++ b/src/ecglib/models/architectures/densenet1d.py @@ -3,7 +3,8 @@ class DenseLayer(nn.Module): - ''' Paper: https://arxiv.org/pdf/1608.06993v5.pdf ''' + """Paper: https://arxiv.org/pdf/1608.06993v5.pdf""" + def __init__(self, input_channels, growth_rate, bottleneck_size, kernel_size): super().__init__() self.use_bottleneck = bottleneck_size > 0 @@ -15,7 +16,8 @@ def __init__(self, input_channels, growth_rate, bottleneck_size, kernel_size): input_channels, self.num_bottleneck_output_filters, kernel_size=1, - stride=1) + stride=1, + ) self.bn1 = nn.BatchNorm1d(self.num_bottleneck_output_filters) self.act1 = nn.ReLU(inplace=True) self.conv1 = nn.Conv1d( @@ -24,7 +26,8 @@ def __init__(self, input_channels, growth_rate, bottleneck_size, kernel_size): kernel_size=kernel_size, stride=1, dilation=1, - padding=kernel_size // 2) + padding=kernel_size // 2, + ) def forward(self, x): if self.use_bottleneck: @@ -38,15 +41,21 @@ def forward(self, x): class DenseBlock(nn.ModuleDict): - def __init__(self, num_layers, input_channels, growth_rate, kernel_size, bottleneck_size): + def __init__( + self, num_layers, input_channels, growth_rate, kernel_size, bottleneck_size + ): super().__init__() self.num_layers = num_layers for i in range(self.num_layers): - self.add_module(f'denselayer{i}', - 
DenseLayer(input_channels + i * growth_rate, - growth_rate, - bottleneck_size, - kernel_size)) + self.add_module( + f"denselayer{i}", + DenseLayer( + input_channels + i * growth_rate, + growth_rate, + bottleneck_size, + kernel_size, + ), + ) def forward(self, x): layer_outputs = [x] @@ -62,7 +71,9 @@ def __init__(self, input_channels, out_channels): super().__init__() self.bn = nn.BatchNorm1d(input_channels) self.act = nn.ReLU(inplace=True) - self.conv = nn.Conv1d(input_channels, out_channels, kernel_size=1, stride=1, dilation=1) + self.conv = nn.Conv1d( + input_channels, out_channels, kernel_size=1, stride=1, dilation=1 + ) self.pool = nn.AvgPool1d(kernel_size=2, stride=2) def forward(self, x): @@ -75,27 +86,38 @@ def forward(self, x): class DenseNet1d(nn.Module): def __init__( - self, - growth_rate: int = 32, - block_config: tuple = (6, 12, 24, 16), - num_init_features: int = 64, - bottleneck_size: int = 4, - kernel_size: int = 3, - input_channels: int = 3, - num_classes: int = 1, - reinit: bool = True, + self, + growth_rate: int = 32, + block_config: tuple = (6, 12, 24, 16), + num_init_features: int = 64, + bottleneck_size: int = 4, + kernel_size: int = 3, + input_channels: int = 3, + num_classes: int = 1, + reinit: bool = True, ): super().__init__() + self.stem = None + self.backbone = None + self.head = None - self.features = nn.Sequential( + # make stem + self.stem = nn.Sequential( nn.Conv1d( - input_channels, num_init_features, - kernel_size=7, stride=2, padding=3, dilation=1), + input_channels, + num_init_features, + kernel_size=7, + stride=2, + padding=3, + dilation=1, + ), nn.BatchNorm1d(num_init_features), nn.ReLU(inplace=True), nn.MaxPool1d(kernel_size=3, stride=2, padding=1), ) + # make backbone + self.backbone = nn.Sequential() num_features = num_init_features for i, num_layers in enumerate(block_config): block = DenseBlock( @@ -105,15 +127,17 @@ def __init__( kernel_size=kernel_size, bottleneck_size=bottleneck_size, ) - self.features.add_module(f'denseblock{i}', block) + self.backbone.add_module(f"denseblock{i}", block) num_features = num_features + num_layers * growth_rate if i != len(block_config) - 1: trans = TransitionBlock( - input_channels=num_features, - out_channels=num_features // 2) - self.features.add_module(f'transition{i}', trans) + input_channels=num_features, out_channels=num_features // 2 + ) + self.backbone.add_module(f"transition{i}", trans) num_features = num_features // 2 + self.backbone_out_features = num_features + # make head self.final_bn = nn.BatchNorm1d(num_features) self.final_act = nn.ReLU(inplace=True) self.final_pool = nn.AdaptiveAvgPool1d(1) @@ -130,15 +154,15 @@ def __init__( elif isinstance(m, nn.Linear): nn.init.constant_(m.bias, 0) - def forward_features(self, x): - out = self.features(x) - out = self.final_bn(out) - out = self.final_act(out) - out = self.final_pool(out) - return out - def forward(self, x): - features = self.forward_features(x) + # stem + features = self.stem(x) + # backbone + features = self.backbone(features) + # head + features = self.final_bn(features) + features = self.final_act(features) + features = self.final_pool(features) features = features.squeeze(-1) out = self.classifier(features) return out @@ -149,6 +173,9 @@ def reset_classifier(self): def get_classifier(self): return self.classifier + def get_cnn(self): + return (nn.Sequential(self.stem, self.backbone), self.backbone_out_features) + def densenet121_1d(**kwargs): kwargs["block_config"] = (6, 12, 24, 16) diff --git 
a/src/ecglib/models/architectures/model_types.py b/src/ecglib/models/architectures/model_types.py new file mode 100644 index 0000000..b102e44 --- /dev/null +++ b/src/ecglib/models/architectures/model_types.py @@ -0,0 +1,24 @@ +from enum import IntEnum + +__all__ = ["MType"] + + +class MType(IntEnum): + RESNET = 0 + DENSENET = 1 + TABULAR = 2 + OTHER = 5 # use to sign custom models + + @staticmethod + def from_string(label: str) -> IntEnum: + label = label.lower() + if "resnet" in label: + return MType.RESNET + elif "densenet" in label: + return MType.DENSENET + elif "tabular" in label: + return MType.TABULAR + elif "other" in label: + return MType.OTHER + else: + raise ValueError diff --git a/src/ecglib/models/architectures/registred_models.py b/src/ecglib/models/architectures/registred_models.py new file mode 100644 index 0000000..6ba07be --- /dev/null +++ b/src/ecglib/models/architectures/registred_models.py @@ -0,0 +1,68 @@ +from typing import List, Optional, Callable + +from torch.nn import Module + +from .resnet1d import resnet1d18, resnet1d50, resnet1d101 +from .densenet1d import densenet121_1d, densenet201_1d +from .tabular import tabular + +__all__ = ["register_model", "registred_models", "get_builder", "is_model_registred"] + + +# extensible model's storage +BUILTIN_MODELS = { + "densenet1d121": densenet121_1d, + "densenet1d201": densenet201_1d, + "resnet1d18": resnet1d18, + "resnet1d50": resnet1d50, + "resnet1d101": resnet1d101, + "tabular": tabular, +} + + +def register_model( + name: Optional[str] = None, +) -> Callable[[Callable[..., Module]], Callable[..., Module]]: + """ + Function decorator which helps to register new models + """ + + def wrapper(fn: Callable[..., Module]) -> Callable[..., Module]: + key = name if name is not None else fn.__name__ + if key in BUILTIN_MODELS: + raise ValueError(f"An entry is already registered under the name '{key}'.") + BUILTIN_MODELS[key] = fn + return fn + + return wrapper + + +def registered_models() -> List: + """ + Returns a list with the names of registered models. + """ + return list(BUILTIN_MODELS.keys()) + + +def get_builder(name: str) -> Callable[[Callable[..., Module]], Callable[..., Module]]: + """ + Returns a model builder callable object. + + param: name (str): Model name. + return: Callable + """ + if name not in BUILTIN_MODELS: + raise ValueError( + f"An entry is not registered in `BUILTIN_MODELS`. Available models: {registered_models()}." + ) + return BUILTIN_MODELS[name] + + +def is_model_registred(name: str) -> bool: + """ + Checks is model was registered. + + param: name (str): Model name. + return: Boolean flag. 
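As a sketch of how the registry above is meant to be extended, the `register_model` decorator can add a custom builder under a new key; the model below is a toy placeholder, not part of the library:

```python
import torch.nn as nn
from ecglib.models.architectures.registred_models import register_model, get_builder

@register_model("toynet1d")  # registers the builder under the key "toynet1d"
def toynet1d(**kwargs):
    # minimal stand-in for a real 1d architecture
    return nn.Sequential(nn.Flatten(), nn.LazyLinear(1))

model = get_builder("toynet1d")()  # look the builder up and instantiate it
```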
+ """ + return name in BUILTIN_MODELS diff --git a/src/ecglib/models/architectures/resnet1d.py b/src/ecglib/models/architectures/resnet1d.py new file mode 100644 index 0000000..de3b9c1 --- /dev/null +++ b/src/ecglib/models/architectures/resnet1d.py @@ -0,0 +1,273 @@ +import torch.nn as nn +from fastai.layers import AdaptiveConcatPool1d, LinBnDrop +from fastcore.basics import listify + + +class ResidualBlock1d(nn.Module): + expansion = 4 + + def __init__(self, inplanes, out_ftrs, stride=1, kernel_size=3, downsample=None): + super().__init__() + + self.conv1 = nn.Conv1d(inplanes, out_ftrs, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm1d(out_ftrs) + + self.conv2 = nn.Conv1d( + out_ftrs, + out_ftrs, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + bias=False, + ) + self.bn2 = nn.BatchNorm1d(out_ftrs) + + self.conv3 = nn.Conv1d(out_ftrs, out_ftrs * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm1d(out_ftrs * 4) + + self.out_features = out_ftrs * 4 + + self.stride = stride + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet1d(nn.Module): + """Paper: https://arxiv.org/pdf/1512.03385.pdf""" + + def __init__( + self, + block, + layers, + kernel_size=3, + num_classes=2, + input_channels=3, + inplanes=64, + fix_feature_dim=False, + kernel_size_stem=None, + stride_stem=2, + pooling_stem=True, + stride=2, + lin_ftrs_head=None, + ps_head=0.5, + bn_final_head=False, + bn_head=True, + act_head="relu", + concat_pooling=True, + ): + super(ResNet1d, self).__init__() + + self.stem = None + self.backbone = None + self.pooling_adapter = None + self.head = None + self.inplanes = inplanes + + self.kernel_size_stem = ( + kernel_size if kernel_size_stem is None else kernel_size_stem + ) + + # stem + self.stem = self._make_stem( + in_channels=input_channels, + inplanes=inplanes, + kernel_size=self.kernel_size_stem, + stride=stride_stem, + pooling=pooling_stem, + ) + # backbone + self.backbone = self._make_backbone( + inplanes=inplanes, + bb_layers=layers, + bb_block=block, + feature_dim=fix_feature_dim, + bb_kernel_size=kernel_size, + bb_stride=stride, + ) + # head + head_ftrs = ( + inplanes if fix_feature_dim else (2 ** len(layers) * inplanes) + ) * block.expansion + self.head = self._make_head( + n_features=head_ftrs, + n_classes=num_classes, + lin_ftrs=lin_ftrs_head, + ps=ps_head, + bn_final=bn_final_head, + bn=bn_head, + act=act_head, + concat_pooling=concat_pooling, + ) + + def _make_stem(self, in_channels, inplanes, kernel_size, stride, pooling): + stem = nn.Sequential( + nn.Conv1d( + in_channels, + inplanes, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + bias=False, + ), + nn.BatchNorm1d(inplanes), + nn.ReLU(inplace=True), + nn.MaxPool1d(kernel_size=3, stride=2, padding=1) if pooling else None, + ) + return stem + + def _make_backbone( + self, bb_layers, bb_block, inplanes, feature_dim, bb_kernel_size, bb_stride + ): + backbone = nn.Sequential() + + for i, blocks in enumerate(bb_layers): + if not backbone: + out_ftrs = inplanes + else: + out_ftrs = inplanes if feature_dim else (2**i) * inplanes + + backbone.add_module( + "hidden_block{}".format(i), + 
self._make_block( + bb_block=bb_block, + out_ftrs=out_ftrs, + blocks=bb_layers[i], + stride=bb_stride, + kernel_size=bb_kernel_size, + ), + ) + + return backbone + + def _make_block(self, bb_block, out_ftrs, blocks, stride=1, kernel_size=3): + downsample = None + if stride != 1 or self.inplanes != out_ftrs * bb_block.expansion: + downsample = self._perform_downsample(bb_block, out_ftrs, stride) + + block_layers = nn.Sequential() + block_layers.add_module( + f"{bb_block.__name__}_layer0", + bb_block(self.inplanes, out_ftrs, stride, kernel_size, downsample), + ) + + self.inplanes = out_ftrs * bb_block.expansion + + for i in range(1, blocks): + block_layers.add_module( + f"{bb_block.__name__}_layer{i}", bb_block(self.inplanes, out_ftrs) + ) + + return block_layers + + def _perform_downsample(self, block, out_ftrs, stride): + downsample = nn.Sequential() + + downsample.add_module( + f"{block.__name__}_downsample", + nn.Conv1d( + self.inplanes, + out_ftrs * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + ) + downsample.add_module( + f"{block.__name__}_normalize", + nn.BatchNorm1d(out_ftrs * block.expansion), + ) + return downsample + + def _make_head( + self, + n_features, + n_classes, + lin_ftrs=None, + ps=0.5, + bn_final=False, + bn=True, + act="relu", + concat_pooling=True, + ): + lin_ftrs = ( + [n_features if concat_pooling else n_features, n_classes] + if lin_ftrs is None + else [2 * n_features if concat_pooling else n_features] + + lin_ftrs + + [n_classes] + ) + + probs = listify(ps) + if len(probs) == 1: + probs = [probs[0] / 2] * (len(lin_ftrs) - 2) + probs + + actns = [nn.ReLU(inplace=True) if act == "relu" else nn.ELU(inplace=True)] * ( + len(lin_ftrs) - 2 + ) + [None] + + pooling_adapter = nn.Sequential( + AdaptiveConcatPool1d() if concat_pooling else nn.MaxPool1d(2), + nn.Flatten(), + ) + layers = nn.Sequential() + layers.add_module("pooling_adapter", pooling_adapter) + + for ni, no, p, actn in zip(lin_ftrs[:-1], lin_ftrs[1:], probs, actns): + layers.add_module("lin_bn_drop", LinBnDrop(ni, no, bn, p, actn)) + + if bn_final: + layers.add_module("bn_final", nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01)) + + return layers + + def get_cnn(self): + return ( + nn.Sequential(self.stem, self.backbone), + self.backbone[-1][-1].out_features, + ) + + def forward(self, x): + y = self.stem(x) + y = self.backbone(y) + y = self.head(y) + return y + + +def resnet1d18(**kwargs): + kwargs["block"] = ResidualBlock1d + kwargs["layers"] = [1, 2, 2, 1] + return ResNet1d(**kwargs) + + +def resnet1d50(**kwargs): + kwargs["block"] = ResidualBlock1d + kwargs["layers"] = [3, 4, 6, 3] + return ResNet1d(**kwargs) + + +def resnet1d101(**kwargs): + kwargs["block"] = ResidualBlock1d + kwargs["layers"] = [3, 4, 23, 3] + return ResNet1d(**kwargs) diff --git a/src/ecglib/models/architectures/tabular.py b/src/ecglib/models/architectures/tabular.py new file mode 100644 index 0000000..2a44e96 --- /dev/null +++ b/src/ecglib/models/architectures/tabular.py @@ -0,0 +1,71 @@ +import torch +from fastai.layers import LinBnDrop +from fastcore.basics import listify + + +class TabularNet(torch.nn.Module): + def __init__( + self, + inp_features, + lin_ftrs=[10, 10, 10, 10, 10, 8], + drop=[0.5], + act_fn="relu", + bn_last=True, + act_last=True, + drop_last=True, + ): + super(TabularNet, self).__init__() + + criterion = ( + torch.nn.ReLU(inplace=True) + if act_fn == "relu" + else torch.nn.ELU(inplace=True) + ) + actns = [criterion] * (len(lin_ftrs) - 2) + bns = [True] * (len(lin_ftrs) - 2) + + 
actns.append(criterion) if act_last else actns.append(None) + bns.append(True) if bn_last else bns.append(False) + + if isinstance(drop, int): + drop = listify(drop) + + if len(drop) == 1: + drop_ps = [drop[0] / 2] * (len(lin_ftrs) - 2) + drop_ps.append(drop[0]) if drop_last else drop_ps.append(0.0) + elif len(drop) != len(lin_ftrs): + raise + + self.mlp = torch.nn.Sequential() + for i, (n_inp, n_out, drop_p, bn, act_fn) in enumerate( + zip(lin_ftrs[:-1], lin_ftrs[1:], drop_ps, bns, actns) + ): + if not self.mlp: + self.mlp.add_module( + "input", + LinBnDrop( + n_in=inp_features, + n_out=n_inp, + bn=bn, + p=drop_p, + act=act_fn, + lin_first=True, + ), + ) + + self.mlp.add_module( + "hidden{}".format(i), + LinBnDrop( + n_in=n_inp, n_out=n_out, bn=bn, p=drop_p, act=act_fn, lin_first=True + ), + ) + + self.out_size = lin_ftrs[-1] + + def forward(self, x): + return self.mlp(x) + + +def tabular(**kwargs): + """Constructs an TabularNet model""" + return TabularNet(**kwargs) diff --git a/ecglib/models/weights/__init__.py b/src/ecglib/models/config/__init__.py similarity index 100% rename from ecglib/models/weights/__init__.py rename to src/ecglib/models/config/__init__.py diff --git a/src/ecglib/models/config/model_configs.py b/src/ecglib/models/config/model_configs.py new file mode 100644 index 0000000..002fb9e --- /dev/null +++ b/src/ecglib/models/config/model_configs.py @@ -0,0 +1,99 @@ +from dataclasses import dataclass, field +from typing import Optional, Union, Tuple, Dict, Any + +import torch + +__all__ = [ + "BaseConfig", + "ResNetConfig", + "TabularNetConfig", + "DenseNetConfig", +] + + +@dataclass +class BaseConfig: + """ + Base congiguration class + """ + + checkpoint_url: Optional[Union[str, Tuple[str, str]]] = None + checkpoint_file: Optional[str] = None + configs_hub: Optional[str] = None + + @property + def has_checkpoint(self) -> bool: + return self.checkpoint_url or self.checkpoint_file or self.configs_hub + + def dict(self) -> Dict: + return {f: getattr(self, f) for f in self.__annotations__} + + def load_checkpoint(self) -> Any: + assert self.has_checkpoint() + if self.checkpoint_file: + return torch.load(self.checkpoint_file) + elif self.checkpoint_url: + raise NotImplementedError + else: + raise FileNotFoundError + + def load_from_hub(self): + # Must be implemented in child classes to load checkpoints from `configs_hub` + raise NotImplementedError + + +@dataclass(repr=True, eq=True) +class ResNetConfig(BaseConfig): + """ + Default parameters correspond Resnet1d50 model + """ + + layers: list = field(default_factory=lambda: [3, 4, 6, 3]) + kernel_size: int = 3 + num_classes: int = 1 + input_channels: int = 12 + inplanes: int = 64 + fix_feature_dim: bool = False + kernel_size_stem: Optional[int] = None + stride_stem: int = 2 + pooling_stem: bool = True + stride: int = 2 + lin_ftrs_head: Optional[list] = None + ps_head: float = 0.5 + bn_final_head: bool = False + bn_head: bool = True + act_head: str = "relu" + concat_pooling: bool = True + + +@dataclass(repr=True, eq=True) +class TabularNetConfig(BaseConfig): + """ + Default parameters correspond TabularNet model + """ + + inp_features: int = 5 + lin_ftrs: list = field( + default_factory=lambda: [10, 10, 10, 10, 10, 8], + ) + drop: Union[int, str] = field(default_factory=lambda: [0.5]) + act_fn: str = field(default_factory=lambda: "relu") + bn_last: bool = field(default_factory=lambda: True) + act_last: bool = field(default_factory=lambda: True) + drop_last: bool = field(default_factory=lambda: True) + + +@dataclass(repr=True, 
eq=True) +class DenseNetConfig(BaseConfig): + """ + Default parameters correspond DenseNet121_1d model + """ + + growth_rate: int = 32 + block_config: tuple = (6, 12, 24, 16) + num_init_features: int = 64 + bottleneck_size: int = 4 + kernel_size: int = 3 + input_channels: int = 12 + num_classes: int = 1 + reinit: bool = True diff --git a/src/ecglib/models/config/registred_configs.py b/src/ecglib/models/config/registred_configs.py new file mode 100644 index 0000000..c0224d1 --- /dev/null +++ b/src/ecglib/models/config/registred_configs.py @@ -0,0 +1,70 @@ +from .model_configs import ( + BaseConfig, + ResNetConfig, + DenseNetConfig, + TabularNetConfig, +) + +from ..architectures.model_types import MType + +from typing import List, Any + +__all__ = ["register_config", "registred_configs", "config", "is_conf_registred"] + +# extensible config's storage +BUILTIN_CONFIGS = { + MType.RESNET: ResNetConfig, + MType.DENSENET: DenseNetConfig, + MType.TABULAR: TabularNetConfig, +} + + +def register_config( + model_type: MType, +) -> Any: + """ + Function decorator which helps to register new configs + """ + + def wrapper(config_obj: Any) -> Any: + # key = model_type if model_type is not None else config_obj.__name__ + if model_type in BUILTIN_CONFIGS: + raise ValueError( + f"An entry is already registered under the name '{model_type}'." + ) + BUILTIN_CONFIGS[model_type] = config_obj + return config_obj + + return wrapper + + +def registred_configs() -> List[str]: + """ + Returns a list with the names of registered configs. + """ + return list(BUILTIN_CONFIGS.keys()) + + +def config(model_type: MType) -> BaseConfig: + """ + Returns config object class. + + param: name (str): Model name. + return: config (BaseConfig) object + """ + + if model_type not in BUILTIN_CONFIGS: + raise ValueError( + f"An entry is not registered in `BUILTIN_CONFIGS`. Available configs: {registred_configs()}." + ) + return BUILTIN_CONFIGS[model_type] + + +def is_conf_registred(model_type: MType) -> bool: + """ + Checks is model was registered. + + param: name (str): Model name. + return: Boolean flag. + """ + return model_type in BUILTIN_CONFIGS diff --git a/src/ecglib/models/model_builder.py b/src/ecglib/models/model_builder.py new file mode 100644 index 0000000..f981c1d --- /dev/null +++ b/src/ecglib/models/model_builder.py @@ -0,0 +1,269 @@ +import os +import yaml +from enum import IntEnum +from urllib.parse import urlparse +from typing import Callable, Dict, List, Union, Optional, Tuple +import dataclasses + +import torch +from torch.nn import Module + +from .architectures.registred_models import get_builder +from .config.registred_configs import ( + registred_configs, + config, + MType, +) +from .config.model_configs import BaseConfig +from .architectures.cnn_tabular import CnnTabular +from .weights.checkpoint import ModelChekpoint + +resource_package = __name__ + +__all__ = [ + "Combination", + "create_model", + "get_ecglib_url", + "get_model", + "get_config", + "weights_from_checkpoint", + "save_checkpoint", +] + + +class Combination(IntEnum): + SINGLE = 1 + CNNTAB = 2 + + @staticmethod + def from_string(label: str) -> IntEnum: + label = label.lower() + if "single" in label: + return Combination.SINGLE + elif "cnntab" in label: + return Combination.CNNTAB + else: + raise ValueError(f"label for combination must be one of [\'single\', \'cnntab\']. 
Recieved {label}") + + +def create_model( + model_name: Union[str, List[str]], + config: Union[BaseConfig, List[BaseConfig]] = None, + combine: Union[Combination, str] = Combination.SINGLE, + pretrained: bool = False, + pretrained_path: str = "ecglib", + pathology: Union[str, List[str]] = "AFIB", + leads_count: int = 12, + num_classes: int = 1, +) -> Module: + weights = None + configs = [config] if isinstance(config, BaseConfig) else config + model_name = [model_name] if isinstance(model_name, str) else model_name + combine = Combination.from_string(combine) if isinstance(combine, str) else combine + + if pretrained: # Currently working only for Combination.SINGLE + assert ( + combine is Combination.SINGLE + ), "pretrained is currently working only for Combination.SINGLE" + + assert ( + num_classes == 1 + ), "pretrained is currently working only for binary classification" + + if pretrained_path == "ecglib": + pretrained_path = get_ecglib_url( + model_name=model_name[0], pathology=pathology, leads_count=leads_count + ) + + weights = weights_from_checkpoint( + pretrained_path, + meta_info={ + "pathology": pathology, + "leads_count": str(leads_count), + "model_name": model_name[0], + }, + ) + + if combine is Combination.SINGLE: + assert ( + len(model_name) == 1 + ), "For Combination.SINGLE case `model_name` must contain only one model name" + + if configs: + configs = configs[0] + + return get_model(name=model_name[0], config=configs, weights=weights) + elif combine is Combination.CNNTAB: + assert ( + len(model_name) == 2 + ), f"For Combination.CNNTAB case `model_name` must contain 2 model names (N models in future releases)" + + assert ( + model_name[0] != 'tabular' and model_name[1] == 'tabular' + ), f"Combination.CNNTAB suggest using a cnn-like architecture as a cnn_backbone part and using TabularNet class as tabular model. Recieved: {model_name[0]}; {model_name[1]}" + + cnn_conf, tab_conf = ( + (None, None) if configs == None else (configs[0], configs[1]) + ) + + cnn, cnn_out = get_model(name=model_name[0], config=cnn_conf).get_cnn() + tab = get_model(name=model_name[1], config=tab_conf) + tab_out = tab.out_size + + model = CnnTabular( + cnn_backbone=cnn, + cnn_out_features=cnn_out, + tabular_model=tab, + tabular_out_features=tab_out, + classes=num_classes, + head_ftrs=[512], + head_drop_prob=0.2, + ) + + return model + else: + raise ValueError + + +def get_ecglib_url(model_name: str, pathology: str, leads_count: int): + try: + dirname = os.path.dirname(__file__) + weights_path = os.path.join(dirname, "weights/model_weights_paths.yaml") + with open(weights_path, "r") as file: + url = yaml.safe_load(file)[f"{leads_count}_leads"][pathology][model_name] + return url + except KeyError as e: + raise KeyError(f"Key {str(e)} doesn't exist. Check 'ecglib/model/weights/model_weights_paths.yaml\' to see all option for pretrained models.") + + +def _get_model_builder(name: str) -> Callable[..., Module]: + """ + Gets the model name and returns the model builder method. + + param: name (str): The name under which the model is registered. + param: fn (Callable): The model builder method. + """ + name = name.lower() + try: + fn = get_builder(name) + except KeyError: + raise ValueError(f"Unknown model {name}") + return fn + + +def _get_config_stub(model_type: MType) -> BaseConfig: + """ + Gets the model name and returns the model builder method. + + param: model_name (str): The name under which the model is registered. 
+ return: config object (BaseConfig) + """ + try: + conf_obj = config(model_type) + except KeyError: + raise ValueError( + f"Unknown model type {model_type}. It must be in {registred_configs()}" + ) + return conf_obj + + +def get_model( + name: str, config: Optional[BaseConfig] = None, weights: dict = None +) -> Module: + """ + Gets the model name and configuration and returns an instantiated model. + + param: name (str): The name under which the model is registered. + param: config (BaseConfig): Object which contains parameters for the model builder method. + param: weights (dict): model weights. + return: model (nn.Module): The initialized model. + """ + builder = _get_model_builder(name) + if not config: + config = get_config(MType.from_string(name)) + model = builder(**config.dict()) + if weights: + assert isinstance(weights, dict) + model.load_state_dict(weights, strict=True) + return model + return model + + +def get_config( + model_type: MType, + params_overlay: Optional[Dict] = None, +) -> BaseConfig: + """ + Returns Config class according the model type + + param: model_type(MType): The type of model. + param: params_overlay (dict): Replace key-values in BaseConfig object with these (NOTE: Only identical keys for `params_overlay` and `BaseConfig instance` will be replaced). + rerurn config (BaseConfig): BaseConfig instance + """ + builder = _get_config_stub(model_type) # get config object + if params_overlay: + s1 = set(params_overlay.keys()) + s2 = set(f.name for f in dataclasses.fields(builder)) + identical_keys = s1 & s2 + params_overlay = dict([(key, params_overlay[key]) for key in identical_keys]) + return builder(**params_overlay) + return builder() + + +def _checkpoint_from_local(checkpoint_path: str) -> ModelChekpoint: + """ + Load model info from local checkpoint file. + + param: checkpoint_path(str): Path to checkpoint file. + return: Dictionary (Dict) which contains information about model checkpoint. + """ + model_info = torch.load(checkpoint_path, map_location=torch.device("cpu")) + return ModelChekpoint(model_info) + + +def _checkpoint_from_remote(url: str, meta_info: dict) -> ModelChekpoint: + """ + Load model info using remote link. + + param: url(str): Link to remote checkpoint file. + return: Dictionary (Dict) which contains information about model checkpoint. + """ + model_info = torch.hub.load_state_dict_from_url( + url=url, + map_location='cpu', + progress=False, + check_hash=False, + file_name=f"{meta_info['leads_count']}_leads_{meta_info['model_name']}_{meta_info['pathology']}_1_1_0.pt", + ) + return ModelChekpoint(model_info) + + +def weights_from_checkpoint( + checkpoint_path: str, + meta_info: dict, +) -> Tuple: + """ + Return BaseConfig instance from checkpoint file. + + param: checkpoint_path (str): Path to checkpoint file. + + return: config: Tuple which contains models weights, models configs, and experiment information. 
+ """ + + model_info = None + if urlparse(checkpoint_path).scheme: + model_info = _checkpoint_from_remote(url=checkpoint_path, meta_info=meta_info) + else: + model_info = _checkpoint_from_local(checkpoint_path=checkpoint_path) + + return model_info + + +def save_checkpoint(path: str, model: Module, info: Dict, exclude: List = None) -> None: + """ + Save model weights and experiment info as checpoint file + + """ + info["model"] = model.state_dict + checkpoint = ModelChekpoint.make_checkpoint(model_info=info, exclude_keys=exclude) + torch.save(checkpoint, path) diff --git a/src/ecglib/models/weights/__init__.py b/src/ecglib/models/weights/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ecglib/models/weights/checkpoint.py b/src/ecglib/models/weights/checkpoint.py new file mode 100644 index 0000000..49d4650 --- /dev/null +++ b/src/ecglib/models/weights/checkpoint.py @@ -0,0 +1,21 @@ +from typing import Dict, List + + +class ModelChekpoint(dict): + @classmethod + def make_checkpoint(cls, model_info: Dict, exclude_keys: List = None): + """Make ModelChekpoint instance from input dictionary + + :param model_info: _description_ + :type model_info: Dict + :param exclude_keys: _description_, defaults to None + :type exclude_keys: List, optional + :return: _description_ + :rtype: _type_ + """ + if exclude_keys: + assert isinstance(exclude_keys, list) + for key in exclude_keys: + del model_info[key] + + return cls(model_info) diff --git a/src/ecglib/models/weights/model_weights_paths.yaml b/src/ecglib/models/weights/model_weights_paths.yaml new file mode 100644 index 0000000..400811b --- /dev/null +++ b/src/ecglib/models/weights/model_weights_paths.yaml @@ -0,0 +1,36 @@ +12_leads: + AFIB: + densenet1d121: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_densenet1d121_AFIB.pt" + resnet1d18: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d18_AFIB.pt" + resnet1d50: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d50_AFIB.pt" + resnet1d101: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d101_AFIB.pt" + 1AVB: + densenet1d121: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_densenet1d121_1AVB.pt" + resnet1d18: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d18_1AVB.pt" + resnet1d50: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d50_1AVB.pt" + resnet1d101: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d101_1AVB.pt" + SBRAD: + densenet1d121: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_densenet1d121_SBRAD.pt" + resnet1d18: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d18_SBRAD.pt" + resnet1d50: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d50_SBRAD.pt" + resnet1d101: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d101_SBRAD.pt" + STACH: + densenet1d121: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_densenet1d121_STACH.pt" + resnet1d18: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d18_STACH.pt" + resnet1d50: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d50_STACH.pt" + resnet1d101: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d101_STACH.pt" + PVC: + densenet1d121: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_densenet1d121_PVC.pt" + 
resnet1d18: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d18_PVC.pt" + resnet1d50: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d50_PVC.pt" + resnet1d101: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d101_PVC.pt" + CRBBB: + densenet1d121: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_densenet1d121_CRBBB.pt" + resnet1d18: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d18_CRBBB.pt" + resnet1d50: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d50_CRBBB.pt" + resnet1d101: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d101_CRBBB.pt" + IRBBB: + densenet1d121: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_densenet1d121_IRBBB.pt" + resnet1d18: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d18_IRBBB.pt" + resnet1d50: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d50_IRBBB.pt" + resnet1d101: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d101_IRBBB.pt" diff --git a/src/ecglib/predict/__init__.py b/src/ecglib/predict/__init__.py new file mode 100644 index 0000000..c83e04d --- /dev/null +++ b/src/ecglib/predict/__init__.py @@ -0,0 +1,5 @@ +from .predict import * + +__all__ = [ + "predict", +] diff --git a/src/ecglib/predict/predict.py b/src/ecglib/predict/predict.py new file mode 100644 index 0000000..478db1c --- /dev/null +++ b/src/ecglib/predict/predict.py @@ -0,0 +1,298 @@ +import os +from typing import List, Union + +import pandas as pd +import numpy as np +import torch + +from ecglib.models.model_builder import ( + create_model, + Combination, +) +from ecglib import preprocessing as P +from ecglib.models.config.model_configs import BaseConfig +from ecglib.data.datasets import EcgDataset + + +def tabular_metadata_handler( + patient_metadata: torch.Tensor, ecg_metadata: torch.Tensor, axis: int = 1 +) -> torch.Tensor: + """ + Concatenates patient and ECG metadata along a specified axis. + + :param patient_metadata: torch.Tensor, patient metadata + :param ecg_metadata: torch.Tensor, ECG metadata + :param axis: int, axis along which the tensors will be concatenated (default is 1) + + :return: torch.Tensor, concatenated patient and ECG metadata + """ + assert isinstance(patient_metadata, torch.Tensor) + assert isinstance(ecg_metadata, torch.Tensor) + return torch.concat((patient_metadata, ecg_metadata), axis) + + +def get_full_record( + frequency: int, + record: str, + patient_meta: dict, + ecg_meta: dict, + normalization: str = "z_norm", + use_metadata: bool = False, + preprocess: list = None, +) -> list: + """ + Returns a full record from raw record, patient metadata, ECG metadata, and configuration. 
+ + :param frequency: int, frequency of the ECG record + :param record: str, path to ECG record + :param patient_meta: dict, patient metadata + :param ecg_meta: dict, ECG metadata + :param normalization: str, normalization type + :param use_metadata: bool, whether to use metadata or not + :param preprocess: list of preprocessing methods applied to ECG record + + :return: list, processed ECG record along with its metadata + """ + + record = record[:,] + frequency = int(frequency) + patient_meta = patient_meta + trained_freq = frequency + if preprocess: + record_processed = P.Compose(transforms=preprocess, p=1.0)(record) + else: + record_processed = P.Compose( + transforms=[ + P.FrequencyResample( + ecg_frequency=frequency, requested_frequency=trained_freq + ), + P.Normalization(norm_type=normalization), + ], + p=1.0, + )(record) + + assert not np.isnan(record_processed).any(), "NaN values in record" + + ecg_tensor = torch.tensor(record_processed, dtype=torch.float) + if use_metadata: + patient_meta_list = [float(param) for param in patient_meta.values()] + ecg_meta_list = [float(param) for param in ecg_meta.values()] + else: + patient_meta_list = [] + ecg_meta_list = [] + return [ecg_tensor, patient_meta_list, ecg_meta_list] + + +class Predict: + def __init__( + self, + weights_path: str, + model_name: str, + pathologies: list, + frequency: int, + device: str, + threshold: float, + leads: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + model: torch.nn.Module = None, + model_config: Union[BaseConfig, List[BaseConfig]] = None, + combine: Combination = Combination.SINGLE, + use_metadata: bool = False, + use_sigmoid: bool = True, + normalization: str = "z_norm", + ): + """ + Class for making predictions using a trained model. + + :param weights_path: str, path to the model weights + :param model_name: str, name of the model + :param pathologies: list, list of pathologies + :param frequency: int, frequency of the ECG record + :param device: str, device to be used for computations + :param threshold: float, threshold for the model + :param leads: list, list of leads + :param model: torch.nn.Module, model to be used for predictions + :param model_config: config with parameters of trained model + :param combine: Combination to select which type of model to use + :param use_metadata: bool, whether to use metadata or not + :param use_sigmoid: bool, whether to apply sigmoid after model output + :param normalization: str, normalization type + """ + + if not isinstance(pathologies, list): + pathologies = [pathologies] + + self.leads_num = len(leads) + self.leads = leads + self.device = device + self.frequency = frequency + self.use_sigmoid = use_sigmoid + + if model is None: + self.model = create_model( + model_name=model_name, + pretrained=True, + pretrained_path=weights_path, + pathology=pathologies, + leads_count=self.leads_num, + config=model_config, + combine=combine, + ) + else: + self.model = model + + self.model.to(self.device) + self.model.eval() + + self.handler = None + + self.use_metadata = use_metadata + self.normalization = normalization + + if use_metadata: + self.handler = tabular_metadata_handler + self.threshold = threshold + + def predict( + self, + record, + ecg_meta=None, + patient_meta=None, + channels_first=True, + ): + """ + Function that evaluates the model on a single ECG record. 
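A usage sketch for the `Predict` class defined above; to stay self-contained it wires in an untrained model explicitly instead of downloading pretrained weights, and all parameter values are illustrative:

```python
import numpy as np
from ecglib.models.model_builder import create_model
from ecglib.predict import Predict

predictor = Predict(
    weights_path=None,            # unused when `model` is supplied
    model_name="densenet1d121",
    pathologies="AFIB",
    frequency=500,
    device="cpu",
    threshold=0.5,
    model=create_model(model_name="densenet1d121", pretrained=False),
)
result = predictor.predict(np.random.randn(12, 5000), channels_first=True)
print(result["prob_out"].item(), result["label_out"].item())
```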
+ + :param record: np.array or torch.tensor, ECG record + :param ecg_meta: dict, ECG metadata (default is None) + :param patient_meta: dict, patient metadata (default is None) + :param channels_first: bool, whether the channels are the first dimension in the input data (default is False) + + :return: dict, predicted probability, raw output, and labels + """ + + if patient_meta is None and ecg_meta is None and self.use_metadata: + raise ValueError("Patient or ECG metadata is required") + self.patient_meta = patient_meta + self.ecg_meta = ecg_meta + self.record = record + self.channels_first = channels_first + + if not self.channels_first: + if isinstance(self.record, torch.Tensor): + self.record = self.record.permute(1, 0) + elif isinstance(self.record, np.ndarray): + self.record = self.record.transpose(1, 0) + else: + raise ValueError( + "Record type must be either torch.tensor or np.array. Given type: {}".format( + type(self.record) + ) + ) + + input_ = get_full_record( + self.frequency, + self.record, + self.patient_meta, + self.ecg_meta, + self.normalization, + self.use_metadata, + ) + + ecg_signal = input_[0] + patient_meta = input_[1] + ecg_meta = input_[2] + inp = ( + ecg_signal + if not self.handler + else [ecg_signal, self.handler(patient_meta, ecg_meta)] + ) + + inp = ( + [item.to(self.device) for item in inp] + if isinstance(inp, list) + else inp.to(self.device) + ) + inp = inp.unsqueeze(0) + with torch.no_grad(): + outputs = self.model(inp) + if self.use_sigmoid: + probability = torch.nn.Sigmoid()(outputs) + else: + probability = outputs + label = (torch.nn.Sigmoid()(outputs) > self.threshold).float() + + return {"raw_out": outputs, "prob_out": probability, "label_out": label} + + def predict_directory( + self, + directory, + file_type, + write_to_file=None, + ecg_meta=None, + patient_meta=None, + ): + """ + Evaluates the model on all ECG records in a directory. + + :param directory: str, path to the directory with ECG records + :param file_type: str, file type of the ECG records + :param write_to_file: str, path to the file where the predictions will be written (default is None), or None if the predictions should not be written to a file + :param ecg_meta: list of dicts, each dict contains "filename" and "data" keys. ECG metadata (default is None) + :param patient_meta: list of dicts, each dict contains "filename" and "data" keys. 
Patient metadata (default is None) + + :return: pd.DataFrame, dataframe with the predictions + """ + if ecg_meta: + ecg_meta = sorted(ecg_meta, key=lambda k: k["filename"]) + if patient_meta: + patient_meta = sorted(patient_meta, key=lambda k: k["filename"]) + + all_files = os.listdir(directory) + + # filter by file_type + if file_type == "wfdb": + record_files = [file[:-4] for file in all_files if file.endswith(".dat")] + else: + record_files = [file for file in all_files if file.endswith(file_type)] + record_files = sorted(record_files) + + answer_df = pd.DataFrame( + columns=["filename", "raw_out", "prob_out", "label_out"] + ) + + ecg_meta_counter = 0 + patient_meta_counter = 0 + for record in record_files: + ecg_meta_ = None + patient_meta_ = None + + if ecg_meta: + if ecg_meta[ecg_meta_counter]["filename"] == record: + ecg_meta_counter += 1 + ecg_meta_ = ecg_meta[ecg_meta_counter]["data"] + + if patient_meta: + if patient_meta[patient_meta_counter]["filename"] == record: + patient_meta_counter += 1 + patient_meta_ = patient_meta[patient_meta_counter]["data"] + + record_ = EcgDataset.read_ecg_record( + None, os.path.join(directory, record), file_type + ) + record_answer = self.predict(record_, ecg_meta_, patient_meta_) + + answer_df_current = pd.DataFrame( + { + "filename": record, + "raw_out": record_answer["raw_out"].cpu().numpy().item(), + "prob_out": record_answer["prob_out"].cpu().numpy().item(), + "label_out": record_answer["label_out"].cpu().numpy().item(), + }, + index=[0], + ) + answer_df = pd.concat([answer_df_current, answer_df], ignore_index=True) + + if write_to_file: + answer_df.to_csv(write_to_file) + + return answer_df diff --git a/ecglib/preprocessing/__init__.py b/src/ecglib/preprocessing/__init__.py similarity index 62% rename from ecglib/preprocessing/__init__.py rename to src/ecglib/preprocessing/__init__.py index c55f1aa..8331647 100644 --- a/ecglib/preprocessing/__init__.py +++ b/src/ecglib/preprocessing/__init__.py @@ -4,7 +4,7 @@ __all__ = [ - 'functional', - 'preprocess', - 'composition', -] \ No newline at end of file + "functional", + "preprocess", + "composition", +] diff --git a/ecglib/preprocessing/composition.py b/src/ecglib/preprocessing/composition.py similarity index 85% rename from ecglib/preprocessing/composition.py rename to src/ecglib/preprocessing/composition.py index 89eeaac..0fe7aef 100644 --- a/ecglib/preprocessing/composition.py +++ b/src/ecglib/preprocessing/composition.py @@ -31,8 +31,9 @@ def __init__( self.p = p def __call__(self, x): - - idx = np.random.RandomState(random.randint(0, (1 << 32) - 1)).choice(2, p=[self.p, 1.0-self.p]) + idx = np.random.RandomState(random.randint(0, (1 << 32) - 1)).choice( + 2, p=[self.p, 1.0 - self.p] + ) if idx == 0: for t in self.transforms: x = t(x) @@ -59,7 +60,7 @@ def __init__( self.transforms = transforms self.n = n if transform_prob is None: - self.transform_prob = [1 / len(self.transforms)]*len(self.transforms) + self.transform_prob = [1 / len(self.transforms)] * len(self.transforms) else: if sum(transform_prob) > 1.0: raise ValueError("Sum of probabilities should be equal to 1.0") @@ -67,14 +68,15 @@ def __init__( self.transform_prob = transform_prob def __call__(self, x): - - idx = np.random.RandomState(random.randint(0, (1 << 32) - 1)).choice(len(self.transforms), size=self.n, p=self.transform_prob, replace=False) + idx = np.random.RandomState(random.randint(0, (1 << 32) - 1)).choice( + len(self.transforms), size=self.n, p=self.transform_prob, replace=False + ) for i in idx: t = 
self.transforms[i] data = t(x) - + return data - + class OneOf: """ @@ -84,7 +86,7 @@ class OneOf: :return: preprocessed data """ - + def __init__( self, transforms: list, @@ -92,7 +94,7 @@ def __init__( ): self.transforms = transforms if transform_prob is None: - self.transform_prob = [1 / len(self.transforms)]*len(self.transforms) + self.transform_prob = [1 / len(self.transforms)] * len(self.transforms) else: if sum(transform_prob) > 1.0: raise ValueError("Sum of probabilities should be equal to 1.0") @@ -100,9 +102,10 @@ def __init__( self.transform_prob = transform_prob def __call__(self, x): - - idx = np.random.RandomState(random.randint(0, (1 << 32) - 1)).choice(len(self.transforms), p=self.transform_prob) + idx = np.random.RandomState(random.randint(0, (1 << 32) - 1)).choice( + len(self.transforms), p=self.transform_prob + ) t = self.transforms[idx] x = t(x) - - return x \ No newline at end of file + + return x diff --git a/ecglib/preprocessing/functional.py b/src/ecglib/preprocessing/functional.py similarity index 53% rename from ecglib/preprocessing/functional.py rename to src/ecglib/preprocessing/functional.py index f829746..b522e2e 100644 --- a/ecglib/preprocessing/functional.py +++ b/src/ecglib/preprocessing/functional.py @@ -1,4 +1,5 @@ import copy +from typing import Union import numpy as np import pywt @@ -7,134 +8,154 @@ __all__ = [ - "butterworth_bandpass_filter", - "butterworth_highpass_filter", - "butterworth_lowpass_filter", + "frequency_resample", + "cut_ecg", + "butterworth_filter", "IIR_notch_filter", - "elliptic_bandpass_filter", + "elliptic_filter", "minmax_normalization", "z_normalization", - "cycle_normalization", "DWT_filter", "SWT_filter", "lead_crop", "time_crop", "sum_augmentation", "convex_augmentation", - "reflect_augmentation", - "ecg_to_one_frequency", "DWT_BW", - "cut_ecg", ] -def butterworth_bandpass_filter( - s: np.ndarray, - n: int = 10, - Wn: list = [3, 30], - fs: float = 500, +def frequency_resample( + ecg_record: np.ndarray, + ecg_frequency: int, + requested_frequency: int, ) -> np.ndarray: """ - Butterworth bandpass filter augmentation - :param s: one lead signal - :param n: filter order - :param Wn: cutoff frequencies - :param fs: filtered signal frequency + Frequency resample + :param record: signal + :param ecg_frequency: sampling frequency of a signal + :param requested_frequency: sampling frequency of a preprocessed signal :return: preprocessed data """ - - sos = signal.butter(N=n, btype='bandpass', Wn=Wn, fs=fs, output='sos') - filtered_signal = signal.sosfiltfilt(sos, s) - return filtered_signal + + if ecg_frequency == requested_frequency: + return ecg_record + ecg_record = signal.resample( + ecg_record, + int(ecg_record.shape[1] * requested_frequency / ecg_frequency), + axis=1, + ) + return ecg_record -def butterworth_highpass_filter( - s: np.ndarray, - n: int = 7, - Wn: float = 0.5, - fs: int = 500, +def cut_ecg( + data: np.ndarray, + cut_range: list, + frequency: int, ) -> np.ndarray: """ - Butterworth highpass filter augmentation - :param s: one lead signal - :param n: filter order - :param Wn: cutoff frequency - :param fs: filtered signal frequency + Cut signal edges + :param data: signal + :param cut_range: cutting parameters + :param frequency: sampling frequency of a signal :return: preprocessed data """ - - sos = signal.butter(N=n, btype='highpass', Wn=Wn, fs=fs, output='sos') - filtered_signal = signal.sosfiltfilt(sos, s) - return filtered_signal + cut_data = [] + start = int(cut_range[0] * frequency) + for rec in data: + end = 
-int(cut_range[1] * frequency) if cut_range[1] != 0 else len(rec) + cut_data.append(rec[start:end]) -def butterworth_lowpass_filter( - s: np.ndarray, - n: int = 6, - Wn: float = 20, - fs: int = 500, + return np.array(cut_data) + + +def butterworth_filter( + s: np.ndarray, + leads: list, + btype: str = "bandpass", + n: int = 10, + Wn: Union[float, int, list] = [3, 30], + fs: float = 500, ) -> np.ndarray: """ - Butterworth lowpass filter augmentation - :param s: one lead signal + Butterworth bandpass filter augmentation + :param s: ECG signal + :param leads: leads to be filtered + :param btype: type of Butterworth filter ('bandpass', 'lowpass' or 'highpass') :param n: filter order - :param Wn: cutoff frequency + :param Wn: cutoff frequency(ies) :param fs: filtered signal frequency :return: preprocessed data """ - - sos = signal.butter(N=n, btype='lowpass', Wn=Wn, fs=fs, output='sos') - filtered_signal = signal.sosfiltfilt(sos, s) - return filtered_signal + if btype == "bandpass" and not isinstance(Wn, list): + raise ValueError("Wn must be list type in case of bandpass filter") + elif (btype == "highpass" or btype == "lowpass") and not isinstance( + Wn, (int, float) + ): + raise ValueError(f"Wn must be a scalar in case of {btype} filter") + sos = signal.butter(N=n, btype=btype, Wn=Wn, fs=fs, output="sos") + s[leads, :] = signal.sosfiltfilt(sos, s[leads, :]) + return s def IIR_notch_filter( - s: np.ndarray, - w0: float = 50, - Q: float = 30, + s: np.ndarray, + leads: list, + w0: float = 50, + Q: float = 30, fs: int = 500, ) -> np.ndarray: """ IIR notch filter augmentation - :param s: one lead signal + :param s: ECG signal + :param leads: leads to be filtered :param w0: frequency to remove from a signal :param Q: quality factor :param fs: sampling frequency of a signal :return: preprocessed data """ - b, a = signal.iirnotch(w0=w0, Q=Q, fs=fs) - filtered_signal = signal.filtfilt(b, a, s) - return filtered_signal + s[leads, :] = signal.filtfilt(b, a, s[leads, :]) + return s -def elliptic_bandpass_filter( - s: np.ndarray, - n: int = 10, - rp: float = 4, - rs: float = 5, - Wn: list = [0.5, 50], +def elliptic_filter( + s: np.ndarray, + leads: list, + btype: str = "bandpass", + n: int = 10, + rp: float = 4, + rs: float = 5, + Wn: Union[float, int, list] = [0.5, 50], fs: int = 500, ) -> np.ndarray: """ - Elliptic bandpass filter - :param s: one lead signal + Elliptic filter + :param s: ECG signal + :param leads: leads to be filtered + :param btype: type of elliptic filter ('bandpass', 'lowpass' or 'highpass') :param n: filter order :param rp: maximum ripple allowed below unity gain in the passband :param rs: minimum attenuation required in the stop band - :param Wn: cutoff frequencies + :param Wn: cutoff frequency(ies) :param fs: filtered signal frequency - + :return: preprocessed data """ - - sos = signal.ellip(N=n, btype='bandpass', rp=rp, rs=rs, Wn=Wn, fs=fs, output='sos') - filtered_signal = signal.sosfiltfilt(sos, s) - return filtered_signal + if btype == "bandpass" and not isinstance(Wn, list): + raise ValueError("Wn must be list type in case of bandpass filter") + elif (btype == "highpass" or btype == "lowpass") and not isinstance( + Wn, (int, float) + ): + raise ValueError(f"Wn must be a scalar in case of {btype} filter") + sos = signal.ellip(N=n, btype=btype, rp=rp, rs=rs, Wn=Wn, fs=fs, output="sos") + s[leads, :] = signal.sosfiltfilt(sos, s[leads, :]) + return s def minmax_normalization( @@ -146,45 +167,35 @@ def minmax_normalization( :return: preprocessed data """ - - return 
(s-np.min(s))/(np.max(s)-np.min(s)) + smin = np.min(s) + smax = np.max(s) + s = (s - smin) / (smax - smin) + return s def z_normalization( s: np.ndarray, + handle_constant_axis: bool=False, ) -> np.ndarray: """ Z-normalization :param s: signal + :param handle_constant_axis: Flag indicating whether to handle constant values in the signal. :return: preprocessed data """ - - return zscore(s, axis=1, nan_policy='raise') - - -def cycle_normalization( - s: np.ndarray, -) -> np.ndarray: - """ - Cycle normalization - :param s: one lead signal - - :return: preprocessed data - """ - - smin = np.min(s) - smax = np.max(s) - n = len(s) - 1 - i = np.arange(len(s)) - return ((n - i) / n) * ((s - smin) / (smax - smin)) + (i / n) * ((s - smin) / (smax - smin)) + s_norm = zscore(s, axis=1, nan_policy="raise") + if handle_constant_axis: + same_values = np.all(s == s[:, 0][:, np.newaxis], axis=1) + s_norm[same_values] = 0 + return s_norm def DWT_filter( - s: np.ndarray, - wavelet: str = 'db4', - level: int = 3, - threshold: float = 2, + s: np.ndarray, + wavelet: str = "db4", + level: int = 3, + threshold: float = 2, low: float = 1e6, ) -> np.ndarray: """ @@ -197,7 +208,7 @@ def DWT_filter( :return: preprocessed data """ - + w = pywt.Wavelet(wavelet) maxlev = pywt.dwt_max_level(len(s), w.dec_len) @@ -205,14 +216,16 @@ def DWT_filter( coeffs = pywt.wavedec(s, w, level=level) for i in range(1, len(coeffs)): - coeffs[i] = pywt.threshold(coeffs[i], threshold * np.sqrt(np.log2(len(coeffs[i]))), mode='soft') - coeffs[0] = pywt.threshold(coeffs[0], low, mode='less') - return pywt.waverec(coeffs, wavelet, mode='periodic') + coeffs[i] = pywt.threshold( + coeffs[i], threshold * np.sqrt(np.log2(len(coeffs[i]))), mode="soft" + ) + coeffs[0] = pywt.threshold(coeffs[0], low, mode="less") + return pywt.waverec(coeffs, wavelet, mode="periodic") def SWT_filter( - s: np.ndarray, - wavelet: str = 'db4', + s: np.ndarray, + wavelet: str = "db4", level: int = 6, ) -> np.ndarray: """ @@ -223,15 +236,15 @@ def SWT_filter( :return: preprocessed data """ - + if len(s) % 2 == 0: - width = (2**int(np.ceil(np.log2(len(s)))) - len(s)) // 2 - s_padded = np.pad(s, pad_width=width, mode='symmetric') + width = (2 ** int(np.ceil(np.log2(len(s)))) - len(s)) // 2 + s_padded = np.pad(s, pad_width=width, mode="symmetric") else: - width1 = (2**int(np.ceil(np.log2(len(s)))) - len(s)) // 2 - width2 = (2**int(np.ceil(np.log2(len(s)))) - len(s)) // 2 + 1 - s_padded = np.pad(s, pad_width=(width1, width2), mode='symmetric') - + width1 = (2 ** int(np.ceil(np.log2(len(s)))) - len(s)) // 2 + width2 = (2 ** int(np.ceil(np.log2(len(s)))) - len(s)) // 2 + 1 + s_padded = np.pad(s, pad_width=(width1, width2), mode="symmetric") + w = pywt.Wavelet(wavelet) maxlev = pywt.swt_max_level(len(s_padded)) @@ -242,7 +255,7 @@ def SWT_filter( def lead_crop( - record: np.ndarray, + record: np.ndarray, leads: list, ) -> np.ndarray: """ @@ -252,14 +265,14 @@ def lead_crop( :return: preprocessed data """ - + record[leads, :] = np.zeros(record.shape[1]) return record def time_crop( - record: np.ndarray, - time: int, + record: np.ndarray, + time: int, leads: list, ) -> np.ndarray: """ @@ -270,12 +283,12 @@ def time_crop( :return: preprocessed data """ - + assert time <= record.shape[1] for lead in leads: - ls = np.arange(0, len(record[lead])-time, dtype="int") + ls = np.arange(0, len(record[lead]) - time, dtype="int") start = np.random.choice(ls) - record[lead, start:start+time] = 0 + record[lead, start : start + time] = 0 return record @@ -285,49 +298,34 @@ def sum_augmentation( 
) -> np.ndarray: """ Signal summation augmentation - + :param record: signal :param leads: leads to be replaced by sum of all leads :return: preprocessed data """ - + record[leads, :] = np.sum(record, axis=0) return record -def reflect_augmentation( - record: np.ndarray, -) -> np.ndarray: - """ - Reflection augmentation - - :param record: signal - - :return: preprocessed data - :rtype: numpy 2d array - """ - - return -record - - def convex_augmentation( record: np.ndarray, leads: list, ) -> np.ndarray: """ Convex augmentation - + :param record: signal :param leads: leads to be replaced by convex combination of some leads (chosen randomly) :return: preprocessed data """ - + result = copy.deepcopy(record) - ls = np.arange(12, dtype='int') + ls = np.arange(12, dtype="int") for ltr in leads: - ln = np.random.randint(12, size=1)[0] + ln = np.random.randint(1, 13, size=1)[0] leads_to_sum = np.random.choice(ls, size=ln, replace=False) convex_coeffs = np.random.dirichlet(np.ones(ln), size=1)[0] if ln != 0: @@ -335,30 +333,10 @@ def convex_augmentation( return result -def ecg_to_one_frequency( - ecg_record: np.ndarray, - ecg_frequency: int, - requested_frequency: int, -) -> np.ndarray: - """ - Frequency resample - :param record: signal - :param ecg_frequency: sampling frequency of a signal - :param requested_frequency: sampling frequency of a preprocessed signal - - :return: preprocessed data - """ - - if ecg_frequency == requested_frequency: - return ecg_record - ecg_record = signal.resample(ecg_record, int(ecg_record.shape[1] * requested_frequency / ecg_frequency), axis=1) - return ecg_record - - def DWT_BW( - s: np.ndarray, - wavelet: str = 'db4', -)-> np.ndarray: + s: np.ndarray, + wavelet: str = "db4", +) -> np.ndarray: """ Remove baseline wander using wavelets (see article https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.308.6789&rep=rep1&type=pdf) :param s: one lead signal @@ -366,47 +344,26 @@ def DWT_BW( :return: preprocessed data """ - + w = pywt.Wavelet(wavelet) maxlev = pywt.dwt_max_level(len(s), w.dec_len) - + diffs = [] for i in range(1, maxlev + 1): - coeffs = pywt.wavedec(s, w, level=i, mode='periodic') - diff = np.sum(np.square(coeffs[0])) - np.sum([np.sum(np.square(coeffs[j])) for j in range(1, i)]) + coeffs = pywt.wavedec(s, w, level=i, mode="periodic") + diff = np.sum(np.square(coeffs[0])) - np.sum( + [np.sum(np.square(coeffs[j])) for j in range(1, i)] + ) diffs.append(diff) diffs = np.array(diffs) if np.max(diffs) > 6500: ixs = np.where(diffs > 6500)[0] ix = ixs[np.argmin(diffs[ixs])] - coeffs = pywt.wavedec(s, w, level=ix, mode='periodic') - coeffs[0] = np.array([0]* len(coeffs[0])) + coeffs = pywt.wavedec(s, w, level=ix, mode="periodic") + coeffs[0] = np.array([0] * len(coeffs[0])) else: ix = np.argmin(diffs[-3:]) ix += len(diffs) - 3 - coeffs = pywt.wavedec(s, w, level=ix, mode='periodic') - coeffs[0] = np.array([0]* len(coeffs[0])) - return pywt.waverec(coeffs, wavelet, mode='periodic') - - -def cut_ecg( - data: np.ndarray, - cut_range: list, - frequency: int, -) -> np.ndarray: - """ - Cut signal edges - :param data: signal - :param cut_range: cutting parameters - :param frequency: sampling frequency of a signal - - :return: preprocessed data - """ - - cut_data = [] - start = int(cut_range[0]*frequency) - for rec in data: - end = -int(cut_range[1]*frequency) if cut_range[1] != 0 else len(rec) - cut_data.append(rec[start: end]) - - return np.array(cut_data) \ No newline at end of file + coeffs = pywt.wavedec(s, w, level=ix, mode="periodic") + coeffs[0] = 
np.array([0] * len(coeffs[0])) + return pywt.waverec(coeffs, wavelet, mode="periodic") diff --git a/src/ecglib/preprocessing/preprocess.py b/src/ecglib/preprocessing/preprocess.py new file mode 100644 index 0000000..0c7d050 --- /dev/null +++ b/src/ecglib/preprocessing/preprocess.py @@ -0,0 +1,816 @@ +from typing import Union + +import numpy as np + +from . import functional as F +import ecglib + + +__all__ = [ + "FrequencyResample", + "Padding", + "EdgeCut", + "Normalization", + "ButterworthFilter", + "IIRNotchFilter", + "EllipticFilter", + "BaselineWanderRemoval", + "WaveletTransform", + "LeadCrop", + "RandomLeadCrop", + "TimeCrop", + "RandomTimeCrop", + "SumAug", + "RandomSumAug", + "ConvexAug", + "RandomConvexAug", +] + + +class FrequencyResample: + """ + Apply frequency resample + :param ecg_frequency: sampling frequency of a signal + :param requested_frequency: sampling frequency of a preprocessed signal + + :return: preprocessed data + """ + + def __init__( + self, + ecg_frequency: int, + requested_frequency: int = 500, + ): + if isinstance(ecg_frequency, (int, float)): + self.ecg_frequency = ecg_frequency + else: + raise ValueError("ecg_frequency must be scalar") + if isinstance(requested_frequency, (int, float)): + self.requested_frequency = requested_frequency + else: + raise ValueError("requested_frequency must be scalar") + self.func = F.frequency_resample + + def __call__(self, x): + if isinstance(x, ecglib.data.ecg_record.EcgRecord): + x.signal = self.func( + x.signal, int(x.frequency), int(self.requested_frequency) + ) + x.preprocessing_info.append( + f"applied FrequencyResample to change the frequency from {x.frequency} to {self.requested_frequency}" + ) + x.frequency = self.requested_frequency + else: + x = self.func(x, int(self.ecg_frequency), int(self.requested_frequency)) + + return x + + +class Padding: + """ + Apply padding. If ECG is longer than the observed_ecg_length the record is cut. 
+ :param observed_ecg_length: length of padded signal in seconds + :param frequency: sampling frequency of a signal + :param pad_mode: padding mode + + :return: preprocessed data + """ + + def __init__( + self, + observed_ecg_length: float = 10, + frequency: int = 500, + pad_mode: str = "constant", + ): + self.observed_ecg_length = observed_ecg_length + self.frequency = frequency + self.pad_mode = pad_mode + + def apply_pad(self, x, frequency): + if self.observed_ecg_length * frequency - x.shape[1] > 0: + x = np.pad( + x, + ((0, 0), (0, int(self.observed_ecg_length * frequency - x.shape[1]))), + mode=self.pad_mode, + ) + else: + x = x[:, : int(self.observed_ecg_length * frequency)] + return x + + def __call__(self, x): + if isinstance(x, ecglib.data.ecg_record.EcgRecord): + x.signal = self.apply_pad(x.signal, x.frequency) + x.duration = self.observed_ecg_length + x.preprocessing_info.append( + f"applied Padding with a length of {self.observed_ecg_length}" + ) + else: + x = self.apply_pad(x, self.frequency) + + return x + + +class EdgeCut: + """ + Cut signal edges + :param cut_range: cutting parameters + :param frequency: sampling frequency of a signal + + :return: preprocessed data + """ + + def __init__( + self, + cut_range: list = [0, 0], + frequency: int = 500, + ): + self.cut_range = cut_range + self.frequency = frequency + self.func = F.cut_ecg + + def apply_edge_cut(self, x, frequency): + if x.shape[1] / frequency <= sum(self.cut_range): + raise ValueError( + f"cut_range must be < length of the input signal ({x.shape[1]/frequency})" + ) + + x = self.func(x, self.cut_range, frequency) + + return x + + def __call__(self, x): + if isinstance(x, ecglib.data.ecg_record.EcgRecord): + x.signal = self.apply_edge_cut(x.signal, x.frequency) + x.duration = x.signal.shape[1] / x.frequency + x.preprocessing_info.append( + f"applied EdgeCut with a cut_range of {self.cut_range}" + ) + else: + x = self.apply_edge_cut(x, self.frequency) + + return x + + +class Normalization: + """ + Apply normalization + :param norm_type: type of normalization ('z_norm', 'z_norm_constant_handle', and 'min_max') + + :return: preprocessed data + """ + + def __init__( + self, + norm_type: str = "z_norm", + ): + self.norm_type = norm_type + if norm_type == "min_max": + self.func = F.minmax_normalization + elif norm_type == "z_norm" or norm_type == "z_norm_constant_handle": + self.func = F.z_normalization + else: + raise ValueError( + "norm_type must be one of [min_max, z_norm, z_norm_constant_handle]" + ) + + def apply_normalization(self, x): + if self.norm_type is not None: + if self.norm_type == "z_norm_constant_handle": + return self.func(x, handle_constant_axis=True) + return self.func(x) + else: + return x + + def __call__(self, x): + if isinstance(x, ecglib.data.ecg_record.EcgRecord): + x.signal = self.apply_normalization(x.signal) + x.preprocessing_info.append( + f"applied Normalization with norm_type {self.norm_type}" + ) + else: + x = self.apply_normalization(x) + return x + + +class ButterworthFilter: + """ + Apply Butterworth filter augmentation + :param filter_type: type of Butterworth filter ('bandpass', 'lowpass' or 'highpass') + :param leads: leads to be filtered + :param n: filter order + :param Wn: cutoff frequency(ies) + :param fs: filtered signal frequency + + :return: preprocessed data + """ + + def __init__( + self, + filter_type: str = "bandpass", + leads: list = None, + n: int = 10, + Wn: Union[float, int, list] = [3, 30], + fs: int = 500, + ): + self.leads = leads + self.filter_type = filter_type 
+ self.func = F.butterworth_filter + if filter_type not in ["bandpass", "lowpass", "highpass"]: + raise ValueError("Filter type must be one of [bandpass, lowpass, highpass]") + self.n = n + if filter_type == "bandpass" and not isinstance(Wn, list): + raise ValueError("Wn must be list type in case of bandpass filter") + elif (filter_type == "highpass" or filter_type == "lowpass") and not isinstance( + Wn, (int, float) + ): + raise ValueError(f"Wn must be a scalar in case of {filter_type} filter") + self.Wn = Wn + self.fs = fs + + def apply_butterworth(self, x, frequency): + if self.leads is None: + self.leads = np.arange(x.shape[0]) + + return self.func( + x, + leads=self.leads, + btype=self.filter_type, + n=self.n, + Wn=self.Wn, + fs=frequency, + ) + + def __call__(self, x): + if isinstance(x, ecglib.data.ecg_record.EcgRecord): + x.signal = self.apply_butterworth(x.signal, x.frequency) + x.preprocessing_info.append( + f"applied Butterworth filter with filter_type {self.filter_type} " + f"on leads {self.leads} with parameters Wn={self.Wn} and n={self.n}" + ) + else: + x = self.apply_butterworth(x, self.fs) + + return x + + +class IIRNotchFilter: + """ + Apply IIR notch filter augmentation + :param leads: leads to be filtered + :param w0: frequency to remove from a signal + :param Q: quality factor + :param fs: sampling frequency of a signal + + :return: preprocessed data + """ + + def __init__( + self, + leads: list = None, + w0: float = 50, + Q: float = 30, + fs: int = 500, + ): + self.leads = leads + self.w0 = w0 + self.Q = Q + self.fs = fs + self.func = F.IIR_notch_filter + + def apply_iirnotch(self, x, frequency): + if self.leads is None: + self.leads = np.arange(x.shape[0]) + + return self.func(x, leads=self.leads, w0=self.w0, Q=self.Q, fs=frequency) + + def __call__(self, x): + if isinstance(x, ecglib.data.ecg_record.EcgRecord): + x.signal = self.apply_iirnotch(x.signal, x.frequency) + x.preprocessing_info.append( + f"applied IIRNotch filter " + f"on leads {self.leads} with parameters w0={self.w0} and Q={self.Q}" + ) + else: + x = self.apply_iirnotch(x, self.fs) + + return x + + +class EllipticFilter: + """ + Apply elliptic filter augmentation + :param filter_type: type of elliptic filter ('bandpass', 'lowpass' or 'highpass') + :param leads: leads to be filtered + :param n: filter order + :param rp: maximum ripple allowed below unity gain in the passband + :param rs: minimum attenuation required in the stop band + :param Wn: cutoff frequency(ies) + :param fs: filtered signal frequency + + :return: preprocessed data + """ + + def __init__( + self, + filter_type: str = "bandpass", + leads: list = None, + n: int = 10, + rp: float = 4, + rs: float = 5, + Wn: Union[float, int, list] = [0.5, 50], + fs: int = 500, + ): + self.leads = leads + self.filter_type = filter_type + self.func = F.elliptic_filter + if filter_type not in ["bandpass", "lowpass", "highpass"]: + raise ValueError("Filter type must be one of [bandpass, lowpass, highpass]") + self.n = n + if filter_type == "bandpass" and not isinstance(Wn, list): + raise ValueError("Wn must be list type in case of bandpass filter") + elif (filter_type == "highpass" or filter_type == "lowpass") and not isinstance( + Wn, (int, float) + ): + raise ValueError(f"Wn must be a scalar in case of {filter_type} filter") + self.rp = rp + self.rs = rs + self.Wn = Wn + self.fs = fs + + def apply_elliptic(self, x, frequency): + if self.leads is None: + self.leads = np.arange(x.shape[0]) + + return self.func( + x, + leads=self.leads, + 
btype=self.filter_type,
+            n=self.n,
+            rp=self.rp,
+            rs=self.rs,
+            Wn=self.Wn,
+            fs=frequency,
+        )
+
+    def __call__(self, x):
+        if isinstance(x, ecglib.data.ecg_record.EcgRecord):
+            x.signal = self.apply_elliptic(x.signal, x.frequency)
+            x.preprocessing_info.append(
+                f"applied Elliptic filter with filter_type {self.filter_type} "
+                f"on leads {self.leads} with parameters n={self.n}, rp={self.rp}, rs={self.rs}, Wn={self.Wn}"
+            )
+        else:
+            x = self.apply_elliptic(x, self.fs)
+
+        return x
+
+
+class BaselineWanderRemoval:
+    """
+    Remove baseline wander using wavelets
+    (see article https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.308.6789&rep=rep1&type=pdf)
+    :param leads: leads to be processed
+    :param wavelet: wavelet name
+
+    :return: preprocessed data
+    """
+
+    def __init__(
+        self,
+        leads: list = None,
+        wavelet: str = "db4",
+    ):
+        self.leads = leads
+        self.wavelet = wavelet
+        self.func = F.DWT_BW
+
+    def apply_bas_wander(self, x):
+        if self.leads is None:
+            self.leads = np.arange(x.shape[0])
+        for lead in self.leads:
+            func_result = self.func(x[lead, :], wavelet=self.wavelet)
+            if len(func_result) >= x.shape[1]:
+                x[lead, :] = func_result[: x.shape[1]]
+            else:
+                x[lead, :] = np.pad(func_result, (0, x.shape[1] - len(func_result)))
+        return x
+
+    def __call__(self, x):
+        if isinstance(x, ecglib.data.ecg_record.EcgRecord):
+            x.signal = self.apply_bas_wander(x.signal)
+            x.preprocessing_info.append(
+                f"applied BaselineWanderRemoval filter with wavelet {self.wavelet} on leads {self.leads}"
+            )
+        else:
+            x = self.apply_bas_wander(x)
+
+        return x
+
+
+class WaveletTransform:
+    """
+    Apply wavelet transform augmentation
+    :param wt_type: type of wavelet transform ('DWT' with soft thresholding or 'SWT')
+    :param leads: leads to be transformed
+    :param wavelet: wavelet name
+    :param level: decomposition level
+    :param threshold: thresholding value for all coefficients except the first one (only for DWT)
+    :param low: thresholding value for the first coefficient (only for DWT)
+
+    :return: preprocessed data
+    """
+
+    def __init__(
+        self,
+        wt_type: str = "DWT",
+        leads: list = None,
+        wavelet: str = "db4",
+        level: int = 3,
+        threshold: float = 2,
+        low: float = 1e6,
+    ):
+        self.leads = leads
+        self.wt_type = wt_type
+        self.wavelet = wavelet
+        self.level = level
+        if wt_type == "DWT":
+            self.threshold = threshold
+            self.low = low
+            self.func = F.DWT_filter
+        elif wt_type == "SWT":
+            self.threshold = None
+            self.low = None
+            self.func = F.SWT_filter
+        else:
+            raise ValueError("wt_type must be one of [DWT, SWT]")
+
+    def apply_wavelet_transform(self, x):
+        if self.leads is None:
+            self.leads = np.arange(x.shape[0])
+        for lead in self.leads:
+            if self.wt_type == "DWT":
+                func_result = self.func(
+                    x[lead, :],
+                    wavelet=self.wavelet,
+                    level=self.level,
+                    threshold=self.threshold,
+                    low=self.low,
+                )
+            elif self.wt_type == "SWT":
+                func_result = self.func(
+                    x[lead, :], wavelet=self.wavelet, level=self.level
+                )
+            if len(func_result) >= x.shape[1]:
+                x[lead, :] = func_result[: x.shape[1]]
+            else:
+                x[lead, :] = np.pad(func_result, (0, x.shape[1] - len(func_result)))
+        return x
+
+    def __call__(self, x):
+        if isinstance(x, ecglib.data.ecg_record.EcgRecord):
+            x.signal = self.apply_wavelet_transform(x.signal)
+            x.preprocessing_info.append(
+                f"applied WaveletTransform filter with wavelet {self.wavelet}, level {self.level}, "
+                f"and wt_type {self.wt_type} on leads {self.leads}"
+            )
+        else:
+            x = self.apply_wavelet_transform(x)
+
+        return x
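+
+
+# Usage sketch for the two wavelet-based transforms above (illustrative only,
+# not part of the public API; `ecg` is assumed to be a (leads, samples) float
+# numpy array). Both classes also accept an EcgRecord, in which case the step
+# is logged in its preprocessing_info:
+#
+#     ecg = BaselineWanderRemoval(wavelet="db4")(ecg)
+#     ecg = WaveletTransform(wt_type="DWT", wavelet="db4", level=3)(ecg)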
+
+
+class LeadCrop:
+    """
+    Apply lead crop augmentation
+    :param leads: leads to be cropped
+
+    :return: preprocessed data
+    """
+
+    def __init__(
+        self,
+        leads: list = None,
+    ):
+        self.leads = leads
+        self.func = F.lead_crop
+
+    def apply_lead_crop(self, x):
+        if self.leads is None:
+            self.leads = [0]
+        return self.func(x, leads=self.leads)
+
+    def __call__(self, x):
+        if isinstance(x, ecglib.data.ecg_record.EcgRecord):
+            x.signal = self.apply_lead_crop(x.signal)
+            x.preprocessing_info.append(
+                f"applied LeadCrop filter on leads {self.leads}"
+            )
+        else:
+            x = self.apply_lead_crop(x)
+
+        return x
+
+
+class RandomLeadCrop(LeadCrop):
+    """
+    Apply random lead crop augmentation
+    :param leads: leads to be potentially cropped
+    :param n: number of leads to be cropped (chosen randomly)
+
+    :return: preprocessed data
+    """
+
+    def __init__(
+        self,
+        leads: list = None,
+        n: int = None,
+    ):
+        super().__init__(leads=leads)
+        self.n = n
+
+    def apply_random_lead_crop(self, x):
+        # self.leads is overwritten with the sampled subset, so reusing the
+        # same instance draws later subsets from the narrowed pool
+        if isinstance(self.leads, list):
+            if self.n is None:
+                self.n = 1
+            if isinstance(self.n, int) and self.n > len(self.leads):
+                raise ValueError(f"n must be <= {len(self.leads)}")
+            leads_to_crop = list(
+                np.random.choice(self.leads, size=self.n, replace=False)
+            )
+            self.leads = leads_to_crop
+        else:
+            self.leads = np.arange(x.shape[0])
+            if self.n is None:
+                self.n = 1
+            leads_to_crop = list(
+                np.random.choice(self.leads, size=self.n, replace=False)
+            )
+            self.leads = leads_to_crop
+
+        return self.func(x, leads=self.leads)
+
+    def __call__(self, x):
+        if isinstance(x, ecglib.data.ecg_record.EcgRecord):
+            x.signal = self.apply_random_lead_crop(x.signal)
+            x.preprocessing_info.append(
+                f"applied RandomLeadCrop filter for {self.n} leads from the leads {self.leads}"
+            )
+        else:
+            x = self.apply_random_lead_crop(x)
+
+        return x
+
+
+class TimeCrop:
+    """
+    Apply time crop augmentation
+    :param time: length of the segment to be cropped (in samples)
+    :param leads: leads to be cropped
+
+    :return: preprocessed data
+    """
+
+    def __init__(
+        self,
+        time: int = 100,
+        leads: list = None,
+    ):
+        self.time = time
+        self.leads = leads
+        self.func = F.time_crop
+
+    def apply_time_crop(self, x):
+        if self.leads is None:
+            self.leads = np.arange(x.shape[0])
+        return self.func(x, time=self.time, leads=self.leads)
+
+    def __call__(self, x):
+        if isinstance(x, ecglib.data.ecg_record.EcgRecord):
+            x.signal = self.apply_time_crop(x.signal)
+            x.preprocessing_info.append(
+                f"applied TimeCrop filter with a size of {self.time} from the leads {self.leads}"
+            )
+        else:
+            x = self.apply_time_crop(x)
+
+        return x
+
+
+class RandomTimeCrop(TimeCrop):
+    """
+    Apply random time crop augmentation
+    :param time: length of the segment to be cropped (in samples)
+    :param leads: leads to be potentially cropped
+    :param n: number of leads to be cropped (chosen randomly)
+
+    :return: preprocessed data
+    """
+
+    def __init__(
+        self,
+        time: int = 100,
+        leads: list = None,
+        n: int = None,
+    ):
+        super().__init__(time=time, leads=leads)
+        self.n = n
+
+    def apply_random_time_crop(self, x):
+        if isinstance(self.leads, list):
+            if self.n is None:
+                self.n = len(self.leads)
+            if isinstance(self.n, int) and self.n > len(self.leads):
+                raise ValueError(f"n must be <= {len(self.leads)}")
+            leads_to_crop = list(
+                np.random.choice(self.leads, size=self.n, replace=False)
+            )
+            self.leads = leads_to_crop
+        else:
+            self.leads = np.arange(x.shape[0])
+            if self.n is None:
+                self.n = len(self.leads)
+            else:
+                leads_to_crop = list(
+                    np.random.choice(self.leads, size=self.n, replace=False)
+                )
+                self.leads = leads_to_crop
+
+        return self.func(x, time=self.time, leads=self.leads)
+
+    def __call__(self, x):
+        if isinstance(x, ecglib.data.ecg_record.EcgRecord):
+            x.signal = self.apply_random_time_crop(x.signal)
+            x.preprocessing_info.append(
+                f"applied RandomTimeCrop filter with a size of {self.time} to {self.n} leads "
+                f"from the leads {self.leads}"
+            )
+        else:
+            x = self.apply_random_time_crop(x)
+
+        return x
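+
+
+# Usage sketch for the crop augmentations above (illustrative only; `ecg` is
+# assumed to be a (12, 5000) float numpy array):
+#
+#     ecg = TimeCrop(time=250, leads=[0, 5])(ecg)  # zero a random 250-sample
+#                                                  # window in leads 0 and 5
+#     ecg = RandomLeadCrop(n=2)(ecg)               # zero two randomly chosen leads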
+
+
+class SumAug:
+    """
+    Apply sum augmentation to selected leads
+    :param leads: leads to be replaced by sum of all leads
+
+    :return: preprocessed data
+    """
+
+    def __init__(
+        self,
+        leads: list = None,
+    ):
+        self.leads = leads
+        self.func = F.sum_augmentation
+
+    def apply_sum_aug(self, x):
+        if self.leads is None:
+            self.leads = np.arange(x.shape[0])
+        return self.func(x, leads=self.leads)
+
+    def __call__(self, x):
+        if isinstance(x, ecglib.data.ecg_record.EcgRecord):
+            x.signal = self.apply_sum_aug(x.signal)
+            x.preprocessing_info.append(
+                f"applied SumAug filter on the leads {self.leads}"
+            )
+        else:
+            x = self.apply_sum_aug(x)
+
+        return x
+
+
+class RandomSumAug(SumAug):
+    """
+    Apply random sum augmentation
+    :param leads: leads to be potentially modified
+    :param n: number of leads to be replaced by sum of all leads (chosen randomly)
+
+    :return: preprocessed data
+    """
+
+    def __init__(
+        self,
+        leads: list = None,
+        n: int = None,
+    ):
+        super().__init__(leads=leads)
+        self.n = n
+
+    def apply_random_sum_aug(self, x):
+        if isinstance(self.leads, list):
+            if self.n is None:
+                self.n = len(self.leads)
+            if isinstance(self.n, int) and self.n > len(self.leads):
+                raise ValueError(f"n must be <= {len(self.leads)}")
+            leads_to_sum = list(
+                np.random.choice(self.leads, size=self.n, replace=False)
+            )
+            self.leads = leads_to_sum
+        else:
+            self.leads = np.arange(x.shape[0])
+            if self.n is None:
+                self.n = len(self.leads)
+            else:
+                leads_to_sum = list(
+                    np.random.choice(self.leads, size=self.n, replace=False)
+                )
+                self.leads = leads_to_sum
+
+        return self.func(x, leads=self.leads)
+
+    def __call__(self, x):
+        if isinstance(x, ecglib.data.ecg_record.EcgRecord):
+            x.signal = self.apply_random_sum_aug(x.signal)
+            x.preprocessing_info.append(
+                f"applied RandomSumAug filter to {self.n} leads from the leads {self.leads}"
+            )
+        else:
+            x = self.apply_random_sum_aug(x)
+
+        return x
+
+
+class ConvexAug:
+    """
+    Apply convex augmentation
+    :param leads: leads to be replaced by convex combination of some leads (chosen randomly)
+
+    :return: preprocessed data
+    """
+
+    def __init__(
+        self,
+        leads: list = None,
+    ):
+        self.leads = leads
+        self.func = F.convex_augmentation
+
+    def apply_convex_aug(self, x):
+        if self.leads is None:
+            self.leads = np.arange(x.shape[0])
+        return self.func(x, leads=self.leads)
+
+    def __call__(self, x):
+        if isinstance(x, ecglib.data.ecg_record.EcgRecord):
+            x.signal = self.apply_convex_aug(x.signal)
+            x.preprocessing_info.append(
+                f"applied ConvexAug filter on the leads {self.leads}"
+            )
+        else:
+            x = self.apply_convex_aug(x)
+
+        return x
+
+
+class RandomConvexAug(ConvexAug):
+    """
+    Apply random convex augmentation
+    :param leads: leads to be potentially replaced by convex combinations
+    :param n: number of leads (chosen randomly) to be replaced by convex combination of some leads (chosen randomly)
+
+    :return: preprocessed data
+    """
+
+    def __init__(
+        self,
+        leads: list = None,
+        n: int = None,
+    ):
+        super().__init__(leads=leads)
+        self.n = n
+
+    def apply_random_convex_aug(self, x):
+        if isinstance(self.leads, list):
+            if self.n 
is None: + self.n = len(self.leads) + if isinstance(self.n, int) and self.n > len(self.leads): + raise ValueError(f"n must be <= {len(self.leads)}") + leads_to_convex = list( + np.random.choice(self.leads, size=self.n, replace=False) + ) + self.leads = leads_to_convex + else: + self.leads = np.arange(x.shape[0]) + if self.n is None: + self.n = len(self.leads) + else: + leads_to_convex = list( + np.random.choice(self.leads, size=self.n, replace=False) + ) + self.leads = leads_to_convex + + return self.func(x, leads=self.leads) + + def __call__(self, x): + if isinstance(x, ecglib.data.ecg_record.EcgRecord): + x.signal = self.apply_random_convex_aug(x.signal) + x.preprocessing_info.append( + f"applied RandomConvexAug filter to {self.n} leads from the leads {self.leads}" + ) + else: + x = self.apply_random_convex_aug(x) + + return x
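+
+
+if __name__ == "__main__":
+    # Smoke-test sketch (illustrative only, not part of the library API):
+    # run a dummy 12-lead, 10 s, 500 Hz record through a short chain.
+    # ButterworthFilter delegates to functional.butterworth_filter; the
+    # random signal merely stands in for a real ECG.
+    ecg = np.random.randn(12, 5000)
+    chain = [
+        Padding(observed_ecg_length=10, frequency=500),
+        ButterworthFilter(filter_type="bandpass", Wn=[3, 30], n=10, fs=500),
+        Normalization(norm_type="z_norm"),
+    ]
+    for transform in chain:
+        ecg = transform(ecg)  # plain arrays bypass the EcgRecord bookkeeping
+    assert ecg.shape == (12, 5000)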