torch_ecg

ECG Deep Learning Framework Implemented using PyTorch.

Documentation (under development):

The system design is depicted as follows

Installation

torch_ecg requires Python 3.6+ and is available through pip:

python -m pip install torch-ecg

One can download the development version hosted at GitHub via

git clone https://github.com/DeepPSP/torch_ecg.git
cd torch_ecg
python -m pip install .

or use pip directly via

python -m pip install git+https://github.com/DeepPSP/torch_ecg.git

Main Modules

Augmenters

Augmenters are classes (subclasses of torch Module) that perform data augmentation in a uniform way and are managed by the AugmenterManager (also a subclass of torch Module). Augmenters and the manager share a common signature of the forward method:

forward(self, sig:Tensor, label:Optional[Tensor]=None, *extra_tensors:Sequence[Tensor], **kwargs:Any) -> Tuple[Tensor, ...]:

The following augmenters are implemented:

  1. baseline wander (adding sinusoidal and gaussian noises)
  2. cutmix
  3. mixup
  4. random flip
  5. random masking
  6. random renormalize
  7. stretch-or-compress (scaling)
  8. label smooth (not actually for data augmentation, but has similar behavior)
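
As a rough illustration of item 8, label smoothing is commonly defined as y' = (1 - eps) * y + eps / K for K classes. The sketch below implements that textbook formula in plain Python. The function name `smooth_labels` and the parameter `eps` are assumptions made for illustration, not torch_ecg's API, and the library's own implementation may differ.

```python
from typing import List


def smooth_labels(label: List[float], eps: float = 0.1) -> List[float]:
    """Textbook label smoothing: mix the label with a uniform distribution."""
    n_classes = len(label)
    return [(1.0 - eps) * y + eps / n_classes for y in label]


# one-hot label for the second of four classes
smoothed = smooth_labels([0.0, 1.0, 0.0, 0.0], eps=0.1)
print(smoothed)
```

The smoothed label still sums to 1, but no class keeps a hard 0 or 1 target, which is why the operation behaves like a mild regularizer on the labels rather than on the signal.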

Usage example (this example uses all augmenters except cutmix, each with default config):

import torch
from torch_ecg.cfg import CFG
from torch_ecg.augmenters import AugmenterManager

config = CFG(
    random=False,
    fs=500,
    baseline_wander={},
    label_smooth={},
    mixup={},
    random_flip={},
    random_masking={},
    random_renormalize={},
    stretch_compress={},
)
am = AugmenterManager.from_config(config)
sig, label, mask = torch.rand(2,12,5000), torch.rand(2,26), torch.rand(2,5000,1)
sig, label, mask = am(sig, label, mask)

Augmenters can be stochastic along the batch dimension and/or the channel dimension (see the get_indices method of the Augmenter base class).
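
To make "stochastic along the batch dimension" concrete, here is a minimal sketch in plain Python, assuming an augmenter that applies a sign flip to a random subset of the samples in a batch. The helper `pick_indices` is a hypothetical stand-in and is not the library's `get_indices` method, whose signature and behavior may differ.

```python
import random
from typing import List, Optional


def pick_indices(prob: float, pop_size: int, rng: random.Random) -> List[int]:
    """Pick round(prob * pop_size) distinct indices out of range(pop_size)."""
    k = round(prob * pop_size)
    return sorted(rng.sample(range(pop_size), k))


def random_flip(
    batch: List[List[float]], prob: float, seed: Optional[int] = None
) -> List[List[float]]:
    """Flip the sign of a randomly chosen subset of the samples in a batch."""
    rng = random.Random(seed)
    chosen = set(pick_indices(prob, len(batch), rng))
    return [
        [-x for x in sig] if i in chosen else list(sig)
        for i, sig in enumerate(batch)
    ]


# a toy batch of 4 single-channel signals
batch = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
flipped = random_flip(batch, prob=0.5, seed=42)
n_flipped = sum(sig[0] < 0 for sig in flipped)
```

With `prob=0.5` and a batch of 4, exactly 2 samples are flipped. Which two depends on the seed.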

Preprocessors

There are also preprocessors that act on numpy arrays. As with augmenters, they are managed by a manager:

import torch
from torch_ecg.cfg import CFG
from torch_ecg._preprocessors import PreprocManager

config = CFG(
    random=False,
    resample={"fs": 500},
    bandpass={},
    normalize={},
)
ppm = PreprocManager.from_config(config)
sig = torch.rand(12,80000).numpy()
sig, fs = ppm(sig, 200)

The following preprocessors are implemented

  1. baseline removal (detrend)
  2. normalize (z-score, min-max, naïve)
  3. bandpass
  4. resample
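
For readers unfamiliar with the normalization variants in item 2, the sketch below shows the usual definitions of z-score and min-max normalization on a single-lead signal, using only the standard library. The function names are illustrative assumptions, not the library's preprocessors, and the sketch assumes a non-constant signal (otherwise the denominators are zero).

```python
import statistics
from typing import List


def z_score(sig: List[float]) -> List[float]:
    """Shift to zero mean and scale to unit (population) standard deviation."""
    mean = statistics.fmean(sig)
    std = statistics.pstdev(sig)
    return [(x - mean) / std for x in sig]


def min_max(sig: List[float]) -> List[float]:
    """Rescale linearly so that the minimum maps to 0 and the maximum to 1."""
    lo, hi = min(sig), max(sig)
    return [(x - lo) / (hi - lo) for x in sig]


sig = [1.0, 2.0, 3.0, 4.0]
z = z_score(sig)
mm = min_max(sig)
```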

For more examples, see the README file of the preprocessors module.

Databases

This module includes classes that handle the I/O of the ECG signals and labels in an ECG database and maintain its metadata (statistics, paths, plots, list of records, etc.). This module is migrated and improved from DeepPSP/database_reader.

After the migration, all database classes need to be tested again. Current progress:

| Database | Source | Tested |
| --- | --- | --- |
| AFDB | PhysioNet | :heavy_check_mark: |
| ApneaECG | PhysioNet | :x: |
| CinC2017 | PhysioNet | :x: |
| CinC2018 | PhysioNet | :x: |
| CinC2020 | PhysioNet | :heavy_check_mark: |
| CinC2021 | PhysioNet | :heavy_check_mark: |
| LTAFDB | PhysioNet | :x: |
| LUDB | PhysioNet | :heavy_check_mark: |
| MITDB | PhysioNet | :heavy_check_mark: |
| SHHS | NSRR | :x: |
| CPSC2018 | CPSC | :heavy_check_mark: |
| CPSC2019 | CPSC | :heavy_check_mark: |
| CPSC2020 | CPSC | :heavy_check_mark: |
| CPSC2021 | CPSC | :heavy_check_mark: |
| SPH | Figshare | :heavy_check_mark: |

Note that these classes should not be confused with a torch Dataset, which is strongly tied to the task (or the model). However, one can build Datasets based on these classes, for example the Dataset for the 4th China Physiological Signal Challenge 2021 (CPSC2021).
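
The distinction between a database reader and a task-specific Dataset can be sketched as follows. Everything here is a toy: `ToyDatabase`, `ToyClassificationDataset`, and their method names are assumptions chosen for illustration and do not reproduce the torch_ecg classes. The reader only knows how to load raw records, while the dataset decides what a training sample looks like for one task.

```python
from typing import Dict, List, Tuple


class ToyDatabase:
    """Hypothetical reader: loads signals and annotations by record name."""

    def __init__(self, records: Dict[str, Tuple[List[float], str]]) -> None:
        self._records = records

    @property
    def all_records(self) -> List[str]:
        return sorted(self._records)

    def load_data(self, rec: str) -> List[float]:
        return self._records[rec][0]

    def load_ann(self, rec: str) -> str:
        return self._records[rec][1]


class ToyClassificationDataset:
    """Hypothetical task-specific wrapper: yields (signal, class index) pairs."""

    def __init__(self, reader: ToyDatabase, classes: List[str]) -> None:
        self.reader = reader
        self.classes = classes

    def __len__(self) -> int:
        return len(self.reader.all_records)

    def __getitem__(self, index: int) -> Tuple[List[float], int]:
        rec = self.reader.all_records[index]
        label = self.classes.index(self.reader.load_ann(rec))
        return self.reader.load_data(rec), label


reader = ToyDatabase({"rec_b": ([0.3, 0.1], "AF"), "rec_a": ([0.0, 0.2], "NSR")})
dataset = ToyClassificationDataset(reader, classes=["NSR", "AF"])
```

A second task (say, sequence tagging) would reuse the same reader and only swap the wrapper.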

One can use the built-in Datasets in torch_ecg.databases.datasets as follows

from copy import deepcopy

from torch_ecg.databases.datasets.cinc2021 import CINC2021Dataset, CINC2021TrainCfg

config = deepcopy(CINC2021TrainCfg)
config.db_dir = "some/path/to/db"
dataset = CINC2021Dataset(config, training=True, lazy=False)

Implemented Neural Network Architectures

  1. CRNN, both for classification and sequence tagging (segmentation)
  2. U-Net
  3. RR-LSTM

A typical signature of the instantiation (__init__) function of a model is as follows

__init__(self, classes:Sequence[str], n_leads:int, config:Optional[CFG]=None, **kwargs:Any) -> None

If a config is not specified, the default config (stored in the model_configs module) is used.
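
The fallback to a default config can be pictured with the following minimal sketch. `ToyModel` and `DEFAULT_CONFIG` are hypothetical stand-ins (a plain dict instead of CFG), and how the real models treat `**kwargs` is not shown in this README, so the sketch merely stores them.

```python
from copy import deepcopy
from typing import Any, Dict, Optional, Sequence

# hypothetical defaults, standing in for an entry of the model_configs module
DEFAULT_CONFIG: Dict[str, Any] = {"cnn": {"name": "resnet"}, "dropout": 0.2}


class ToyModel:
    def __init__(
        self,
        classes: Sequence[str],
        n_leads: int,
        config: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        self.classes = list(classes)
        self.n_leads = n_leads
        # copy so that neither the defaults nor the caller's dict is mutated later
        self.config = deepcopy(DEFAULT_CONFIG if config is None else config)
        self.extra = kwargs  # assumption: extra options are simply kept


default_model = ToyModel(["NSR", "AF"], 12)
custom_model = ToyModel(["NSR", "AF"], 12, {"cnn": {"name": "vgg"}, "dropout": 0.5})
```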

Quick Example

A quick example is as follows:

import torch
from torch_ecg.utils.utils_nn import adjust_cnn_filter_lengths
from torch_ecg.model_configs import ECG_CRNN_CONFIG
from torch_ecg.models.ecg_crnn import ECG_CRNN

config = adjust_cnn_filter_lengths(ECG_CRNN_CONFIG, fs=400)
# change the default CNN backbone
# bottleneck with global context attention variant of Nature Communications ResNet
config.cnn.name="resnet_nature_comm_bottle_neck_gc"

classes = ["NSR", "AF", "PVC", "SPB"]
n_leads = 12
model = ECG_CRNN(classes, n_leads, config)

model(torch.rand(2, 12, 4000))  # signal length 4000, batch size 2

Then a model for the classification of 4 classes, namely "NSR", "AF", "PVC", "SPB", on 12-lead ECGs is created. One can check the size of a model, in terms of the number of parameters via

model.module_size

or in terms of memory consumption via

model.module_size_

Custom Model

One can adjust the configs to create a custom model. For example, the building blocks of the 4 stages of a TResNet backbone are basic, basic, bottleneck, bottleneck. To change the second block to a bottleneck block with squeeze-and-excitation (SE) attention:

from copy import deepcopy

from torch_ecg.models.ecg_crnn import ECG_CRNN
from torch_ecg.model_configs import (
    ECG_CRNN_CONFIG,
    tresnetF, resnet_bottle_neck_se,
)

# copy the imported TResNet config so that the default stays untouched
my_resnet = deepcopy(tresnetF)
my_resnet.building_block[1] = "bottleneck"
my_resnet.block[1] = resnet_bottle_neck_se

The convolutions in a TResNet are anti-aliasing convolutions. To change them further to normal convolutions:

for b in my_resnet.block:
    b.conv_type = None

or change them to separable convolutions via

for b in my_resnet.block:
    b.conv_type = "separable"

Finally, replace the default CNN backbone via

my_model_config = deepcopy(ECG_CRNN_CONFIG)
my_model_config.cnn.name = "my_resnet"
my_model_config.cnn.my_resnet = my_resnet

model = ECG_CRNN(["NSR", "AF", "PVC", "SPB"], 12, my_model_config)
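
The reason the examples above call `deepcopy` before editing a config is that the configs are nested: a shallow copy would share inner lists and dicts with the default. The sketch below shows the difference with a `SimpleNamespace` standing in for a backbone config. The object and its field are illustrative assumptions, not the actual contents of the TResNet config.

```python
from copy import copy, deepcopy
from types import SimpleNamespace

# hypothetical stand-in for a default backbone config
default_backbone = SimpleNamespace(
    building_block=["basic", "basic", "bottleneck", "bottleneck"],
)

# a shallow copy shares the nested list with the default ...
shallow = copy(default_backbone)
shares_list = shallow.building_block is default_backbone.building_block

# ... whereas a deep copy can be edited without touching the default
custom = deepcopy(default_backbone)
custom.building_block[1] = "bottleneck"
```

After these lines, `default_backbone.building_block[1]` is still `"basic"`, so other models built from the default are unaffected.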

CNN Backbones

Implemented

  1. VGG
  2. ResNet (including vanilla ResNet, ResNet-B, ResNet-C, ResNet-D, ResNeXT, TResNet, Stanford ResNet, Nature Communications ResNet, etc.)
  3. MultiScopicNet (CPSC2019 SOTA)
  4. DenseNet (CPSC2020 SOTA)
  5. Xception

In general, variants of ResNet are the most commonly used architectures, as can be inferred from CinC2020 and CinC2021.

Ongoing

  1. MobileNet
  2. DarkNet
  3. EfficientNet

TODO

  1. HarDNet
  2. HO-ResNet
  3. U-Net++
  4. U-Squared Net
  5. etc.

More details and a list of references can be found in the README file of this module.

Components

This module consists of frequently used components such as loggers, trainers, etc.

Loggers

The following loggers are implemented and managed uniformly by a manager:

  1. CSV logger
  2. text logger
  3. tensorboard logger

Outputs

The Output classes implemented in this module serve as containers for ECG downstream task model outputs, including

  • ClassificationOutput
  • MultiLabelClassificationOutput
  • SequenceTaggingOutput
  • WaveDelineationOutput
  • RPeaksDetectionOutput

Each has some required fields (keys) and can hold an arbitrary number of custom fields. These classes are useful for the computation of metrics.
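
The idea of "required fields plus arbitrary custom fields" can be sketched as a small dict subclass. The class names and the field names `classes`, `prob`, and `pred` below are assumptions for illustration. The real Output classes may require different keys.

```python
from typing import Any, Tuple


class ToyOutput(dict):
    """Hypothetical dict-like container with required fields and free-form extras."""

    required_fields: Tuple[str, ...] = ()

    def __init__(self, **fields: Any) -> None:
        missing = [name for name in self.required_fields if name not in fields]
        if missing:
            raise ValueError(f"missing required fields: {missing}")
        super().__init__(**fields)


class ToyClassificationOutput(ToyOutput):
    required_fields = ("classes", "prob", "pred")


# required fields plus one custom field ("record")
out = ToyClassificationOutput(
    classes=["NSR", "AF"], prob=[0.2, 0.8], pred=1, record="rec_a"
)

# omitting required fields is rejected at construction time
try:
    ToyClassificationOutput(classes=["NSR", "AF"])
    missing_error = None
except ValueError as err:
    missing_error = str(err)
```

Because every output of a given task is guaranteed to carry the same keys, a metrics routine can consume it without checking for their presence.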

Metrics

This module has the following pre-defined (built-in) Metrics classes:

  • ClassificationMetrics
  • RPeaksDetectionMetrics
  • WaveDelineationMetrics

These metrics are computed according to the definitions given on Wikipedia or in the published literature.
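
As an example of the kind of definitions involved, the sketch below computes a few standard binary classification metrics from confusion-matrix counts. It is a plain-Python illustration with an assumed function name `binary_metrics`, not the implementation behind ClassificationMetrics.

```python
from typing import Dict, Sequence


def binary_metrics(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[str, float]:
    """Sensitivity, specificity, precision and F1 from binary labels (0 or 1)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == 1 and p == 1 for t, p in pairs)
    tn = sum(t == 0 and p == 0 for t, p in pairs)
    fp = sum(t == 0 and p == 1 for t, p in pairs)
    fn = sum(t == 1 and p == 0 for t, p in pairs)
    return {
        "sensitivity": tp / (tp + fn) if tp + fn else 0.0,
        "specificity": tn / (tn + fp) if tn + fp else 0.0,
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "f1": 2 * tp / (2 * tp + fp + fn) if tp + fp + fn else 0.0,
    }


metrics = binary_metrics(y_true=[1, 1, 0, 0, 1], y_pred=[1, 0, 0, 1, 1])
```

For this input there are 2 true positives, 1 true negative, 1 false positive, and 1 false negative.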

Trainer

An abstract base class BaseTrainer is implemented, in which some common steps in building a training pipeline (workflow) are implemented. A few task-specific methods are declared as abstract methods, for example the method

evaluate(self, data_loader:DataLoader) -> Dict[str, float]

for evaluation on the validation set during training and perhaps further for model selection and early stopping.
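
The division of labor between the base class and a task-specific trainer can be sketched with the standard `abc` module. The classes below are toys with assumed names: `ToyBaseTrainer` is not BaseTrainer, and its loop omits optimizers, logging, and early stopping. Only the method names `run_one_step` and `evaluate` are taken from this README.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Sequence, Tuple

Batch = Sequence[Tuple[float, int]]  # (value, label) pairs, a toy stand-in for tensors


class ToyBaseTrainer(ABC):
    """Hypothetical base class: owns the loop, delegates task-specific steps."""

    def __init__(self, n_epochs: int, monitor: str = "accuracy") -> None:
        self.n_epochs = n_epochs
        self.monitor = monitor
        self.history: List[Dict[str, float]] = []

    def train(
        self, train_loader: Sequence[Batch], val_loader: Sequence[Batch]
    ) -> Dict[str, float]:
        best: Dict[str, float] = {}
        for _ in range(self.n_epochs):
            for batch in train_loader:
                self.run_one_step(batch)
            metrics = self.evaluate(val_loader)
            self.history.append(metrics)
            # model selection: keep the metrics of the best epoch so far
            if not best or metrics[self.monitor] > best[self.monitor]:
                best = metrics
        return best

    @abstractmethod
    def run_one_step(self, batch: Batch) -> None:
        """Update the model on one batch."""

    @abstractmethod
    def evaluate(self, data_loader: Sequence[Batch]) -> Dict[str, float]:
        """Return validation metrics keyed by name."""


class ThresholdTrainer(ToyBaseTrainer):
    """Toy task: learn a decision threshold as the mean of all seen values."""

    def __init__(self, n_epochs: int) -> None:
        super().__init__(n_epochs)
        self.total, self.count = 0.0, 0

    def run_one_step(self, batch: Batch) -> None:
        for value, _ in batch:
            self.total += value
            self.count += 1

    def evaluate(self, data_loader: Sequence[Batch]) -> Dict[str, float]:
        threshold = self.total / self.count
        hits = [int(v > threshold) == y for batch in data_loader for v, y in batch]
        return {"accuracy": sum(hits) / len(hits)}


train_loader = [[(0.0, 0), (1.0, 1)], [(0.2, 0), (0.8, 1)]]
val_loader = [[(0.1, 0), (0.9, 1), (0.4, 0)]]
trainer = ThresholdTrainer(n_epochs=2)
best = trainer.train(train_loader, val_loader)
```

Instantiating `ToyBaseTrainer` directly raises a `TypeError`, which is how the abstract methods force every task to supply its own `evaluate`.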

Other Useful Tools

R peaks detection algorithms

This is a collection of traditional (non deep learning) algorithms for R peaks detection collected from WFDB and BioSPPy.
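
To give a flavor of what a traditional detector does, here is a deliberately naive sketch: it marks local maxima above a fixed threshold and enforces a refractory period between detections. It is not one of the WFDB or BioSPPy algorithms, which involve filtering and adaptive thresholds. The function name and its parameters are assumptions for illustration.

```python
from typing import List, Sequence


def detect_peaks(
    sig: Sequence[float], fs: int, threshold: float, refractory_ms: float = 200.0
) -> List[int]:
    """Naive detector: local maxima above a threshold, separated by a refractory period."""
    min_gap = int(fs * refractory_ms / 1000)  # refractory period in samples
    peaks: List[int] = []
    for i in range(1, len(sig) - 1):
        is_local_max = sig[i - 1] < sig[i] >= sig[i + 1]
        if sig[i] > threshold and is_local_max:
            if not peaks or i - peaks[-1] >= min_gap:
                peaks.append(i)
    return peaks


# synthetic 3-second signal at 100 Hz with one spike per second
fs = 100
sig = [0.0] * 300
for idx in (50, 150, 250):
    sig[idx] = 1.0
sig[60] = 0.7  # a bump inside the refractory period of the first peak

rpeaks = detect_peaks(sig, fs=fs, threshold=0.5)
rr_intervals = [(b - a) / fs for a, b in zip(rpeaks, rpeaks[1:])]  # in seconds
```

The bump at sample 60 lies only 100 ms after the first detection and is therefore skipped, leaving RR intervals of 1 s each.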

Usage Examples

See case studies in the benchmarks folder.

A large part of the case studies was migrated from other DeepPSP repositories. Some are implemented in the old fashion and are inconsistent with the new system architecture of torch_ecg, hence they need updating and testing.

| Benchmark | Architecture | Source | Finished | Updated | Tested |
| --- | --- | --- | --- | --- | --- |
| CinC2020 | CRNN | DeepPSP/cinc2020 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CinC2021 | CRNN | DeepPSP/cinc2021 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CinC2022[^1] | Multi Task Learning (MTL) | DeepPSP/cinc2022 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CPSC2019 | SequenceTagging/U-Net | NA | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CPSC2020 | CRNN/SequenceTagging | DeepPSP/cpsc2020 | :heavy_check_mark: | :x: | :x: |
| CPSC2021 | CRNN/SequenceTagging/LSTM | DeepPSP/cpsc2021 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| LUDB | U-Net | NA | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |

[^1]: Although CinC2022 dealt with acoustic cardiac signals (phonocardiogram, PCG), the tasks and signals can be treated similarly.

Taking CPSC2021 for example, the steps are

  1. Write a Dataset to fit the training data for the model(s) and the training workflow, or directly use the built-in Datasets in torch_ecg.databases.datasets. In this example, 3 tasks are considered, 2 of which use a MaskedBCEWithLogitsLoss function, hence the Dataset produces an extra tensor (the mask) for these 2 tasks:

    def __getitem__(self, index:int) -> Tuple[np.ndarray, ...]:
        if self.lazy:
            if self.task in ["qrs_detection"]:
                return self.fdr[index][:2]
            else:
                return self.fdr[index]
        else:
            if self.task in ["qrs_detection"]:
                return self._all_data[index], self._all_labels[index]
            else:
                return self._all_data[index], self._all_labels[index], self._all_masks[index]

  2. Inherit a base model to create task-specific models, along with tailored model configs

  3. Inherit the BaseTrainer to build the training pipeline, with the abstractmethods (_setup_dataloaders, run_one_step, evaluate, batch_dim, etc.) implemented.
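
Step 1 mentions that two of the tasks need an extra mask tensor for their loss. One common way to define a masked binary cross-entropy on logits is to average the per-position loss over the positions where the mask is nonzero, as in the plain-Python sketch below. This is an assumption about the general technique, written without torch. The library's MaskedBCEWithLogitsLoss may weight or reduce the loss differently.

```python
import math
from typing import Sequence


def masked_bce_with_logits(
    logits: Sequence[float], targets: Sequence[float], mask: Sequence[int]
) -> float:
    """Mean binary cross-entropy over the positions where mask is nonzero."""
    total, n_valid = 0.0, 0
    for z, y, m in zip(logits, targets, mask):
        if not m:
            continue  # masked-out positions contribute nothing
        # numerically stable form of -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))]
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
        n_valid += 1
    return total / n_valid if n_valid else 0.0


logits = [0.0, 5.0, -5.0]
targets = [1.0, 1.0, 1.0]
mask = [1, 1, 0]  # the last position (a confident wrong prediction) is ignored
loss = masked_bce_with_logits(logits, targets, mask)
unmasked_loss = masked_bce_with_logits(logits, targets, [1, 1, 1])
```

Here the masked loss is smaller than the unmasked one, because the badly predicted third position is excluded from the average.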

CAUTION

Most of the time, but not always, I run the notebooks in the benchmarks manually after updates. If you find a bug, please raise an issue. The test workflow is to be enhanced and automated, see this project.

Work in progress

See the projects page.

Citation

@misc{torch_ecg,
      title = {{torch\_ecg: An ECG Deep Learning Framework Implemented using PyTorch}},
     author = {WEN, Hao and KANG, Jingsu},
        doi = {10.5281/ZENODO.6435048},
        url = {https://zenodo.org/record/6435048},
  publisher = {Zenodo},
       year = {2022},
  copyright = {{MIT License}}
}
@article{torch_ecg_paper,
      title = {{A Novel Deep Learning Package for Electrocardiography Research}},
     author = {Hao Wen and Jingsu Kang},
    journal = {{Physiological Measurement}},
        doi = {10.1088/1361-6579/ac9451},
       year = {2022},
      month = {11},
  publisher = {{IOP Publishing}},
     volume = {43},
     number = {11},
      pages = {115006}
}

Thanks

Much is learned, especially the modular design, from the adversarial NLP library TextAttack and from Hugging Face transformers.

MIT License

Copyright (c) 2021 WEN Hao and KANG Jingsu

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
