-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 0715b36
Showing
39 changed files
with
319,659 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
*.mat | ||
|
||
# force add | ||
!.gitkeep | ||
|
||
# data file | ||
*.pickle | ||
# *.json | ||
*.jsonl | ||
*log.txt | ||
*.vect | ||
*.log | ||
|
||
# Distribution / packaging | ||
.Python | ||
env/ | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
dataset/mimic3/ | ||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
.hypothesis/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# dotenv | ||
.env | ||
|
||
# virtualenv | ||
.venv | ||
venv/ | ||
ENV/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
|
||
*.DS_Store | ||
.DS_Store | ||
.idea/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
Copyright (c) 2016, mp2893 | ||
All rights reserved. | ||
|
||
Redistribution and use in source and binary forms, with or without | ||
modification, are permitted provided that the following conditions are met: | ||
|
||
* Redistributions of source code must retain the above copyright notice, this | ||
list of conditions and the following disclaimer. | ||
|
||
* Redistributions in binary form must reproduce the above copyright notice, | ||
this list of conditions and the following disclaimer in the documentation | ||
and/or other materials provided with the distribution. | ||
|
||
* Neither the name of RETAIN nor the names of its | ||
contributors may be used to endorse or promote products derived from | ||
this software without specific prior written permission. | ||
|
||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | ||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | ||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | ||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | ||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | ||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
RETAIN | ||
========================================= | ||
|
||
RETAIN is an interpretable predictive model for healthcare applications. Given patient records, it can make predictions while explaining how each medical code (diagnosis codes, medication codes, or procedure codes) at each visit contributes to the prediction. The interpretation is possible due to the use of neural attention mechanism. | ||
|
||
[](https://youtu.be/co3lTOSgFlA?t=1m46s "RETAIN Interpretation Demo - Click to Watch!") | ||
Using RETAIN, you can calculate how positively/negatively each medical code (diagnosis, medication, or procedure code) at different visits contributes to the final score. In this case, we are predicting whether the given patient will be diagnosed with Heart Failure (HF). You can see that the codes that are highly related to HF makes positive contributions. RETAIN also learns to pay more attention to new information than old information. You can see that Cardiac Dysrythmia (CD) makes a bigger contribution as it occurs in the more recent visit. | ||
|
||
#### Relevant Publications | ||
|
||
RETAIN implements an algorithm introduced in the following [paper](http://papers.nips.cc/paper/6321-retain-an-interpretable-predictive-model-for-healthcare-using-reverse-time-attention-mechanism): | ||
|
||
RETAIN: An Interpretable Predictive Model for Healthcare using Reverse Time Attention Mechanism | ||
Edward Choi, Mohammad Taha Bahadori, Joshua A. Kulas, Andy Schuetz, Walter F. Stewart, Jimeng Sun, | ||
NIPS 2016, pp.3504-3512 | ||
|
||
#### Notice | ||
|
||
The RETAIN paper formulates the model as being able to make prediction at each timestep (e.g. try to predict what diagnoses the patient will receive at each visit), and treats sequence classification (e.g. Given a patient record, will he be diagnosed with heart failure in the future?) as a special case, since sequence classification makes the prediction at the last timestep only. | ||
|
||
This code, however, is implemented to perform the sequence classification task. For example, you can use this code to predict whether the given patient is a heart failure patient or not. Or you can predict whether this patient will be readmitted in the future. The more general version of RETAIN will be released in the future. | ||
|
||
#### Running RETAIN | ||
|
||
**STEP 1: Installation** | ||
|
||
1. Install [python](https://www.python.org/), [Theano](http://deeplearning.net/software/theano/index.html). We use Python 2.7, Theano 0.8. Theano can be easily installed in Ubuntu as suggested [here](http://deeplearning.net/software/theano/install_ubuntu.html#install-ubuntu) | ||
|
||
2. If you plan to use GPU computation, install [CUDA](https://developer.nvidia.com/cuda-downloads) | ||
|
||
3. Download/clone the RETAIN code | ||
|
||
**STEP 2: Fast way to test RETAIN with MIMIC-III** | ||
This step describes how to train RETAIN, with minimum number of steps using MIMIC-III, to predict patients' mortality using their visit records. | ||
|
||
0. You will first need to request access for [MIMIC-III](https://mimic.physionet.org/gettingstarted/access/), a publicly avaiable electronic health records collected from ICU patients over 11 years. | ||
|
||
1. You can use "process_mimic.py" to process MIMIC-III dataset and generate a suitable training dataset for RETAIN. | ||
Place the script to the same location where the MIMIC-III CSV files are located, and run the script. | ||
The execution command is `python process_mimic.py ADMISSIONS.csv DIAGNOSES_ICD.csv PATIENTS.csv <output file>`. | ||
|
||
2. Run RETAIN using the ".seqs" and ".morts" file generated by process_mimic.py. | ||
The ".seqs" file contains the sequence of visits for each patient. Each visit consists of multiple diagnosis codes. | ||
However we recommend using ".3digitICD9.seqs" file instead, as the results will be much more interpretable. | ||
(Or you could use [Single-level Clical Classification Software for ICD9](https://www.hcup-us.ahrq.gov/toolssoftware/ccs/ccs.jsp#examples) to decrease the number of codes to a couple of hundreds, which will even more improve the performance) | ||
The ".morts" file contains the sequence of mortality labels for each patient. | ||
The command is `python retain.py <3digitICD9.seqs file> 942 <morts file> <output path> --simple_load --n_epochs 100 --keep_prob_context 0.8 --keep_prob_emb 0.5`. | ||
`942` is the number of the entire 3-digit ICD9 codes used in the dataset. | ||
|
||
3. To test the model for interpretation, please refer to Step 6. I personally found that _perinatal jaundice (ICD9 774)_ has high correlation with mortality. | ||
|
||
4. The model reaches AUC above 0.8 with the above command, but the interpretations are not super clear. | ||
You could tune the hyper-parameters, but I doubt things will dramatically improve. | ||
After all, only 7,500 patients made more than a single hospital visit, and most of them have only two visits. | ||
|
||
**STEP 3: How to prepare your own dataset** | ||
|
||
1. RETAIN's training dataset needs to be a Python cPickled list of list of list. The outermost list corresponds to patients, the intermediate to the visit sequence each patient made, and the innermost to the medical codes (e.g. diagnosis codes, medication codes, procedure codes, etc.) that occurred within each visit. | ||
First, medical codes need to be converted to an integer. Then a single visit can be seen as a list of integers. Then a patient can be seen as a list of visits. | ||
For example, [5,8,15] means the patient was assigned with code 5, 8, and 15 at a certain visit. | ||
If a patient made two visits [1,2,3] and [4,5,6,7], it can be converted to a list of list [[1,2,3], [4,5,6,7]]. | ||
Multiple patients can be represented as [[[1,2,3], [4,5,6,7]], [[2,4], [8,3,1], [3]]], which means there are two patients where the first patient made two visits and the second patient made three visits. | ||
This list of list of list needs to be pickled using cPickle. We will refer to this file as the "visit file". | ||
|
||
2. The total number of unique medical codes is required to run RETAIN. | ||
For example, if the dataset is using 14,000 diagnosis codes and 11,000 procedure codes, the total number is 25,000. | ||
|
||
3. The label dataset (let us call this "label file") needs to be a Python cPickled list. Each element corresponds to the true label of each patient. For example, 1 can be the case patient and 0 can be the control patient. If there are two patients where only the first patient is a case, then we should have [1,0]. | ||
|
||
4. The "visit file" and "label file" need to have 3 sets respectively: training set, validation set, and test set. | ||
The file extension must be ".train", ".valid", and ".test" respectivley. | ||
For example, if you want to use a file named "my_visit_sequences" as the "visit file", then RETAIN will try to load "my_visit_sequences.train", "my_visit_sequences.valid", and "my_visit_sequences.test". | ||
This is also true for the "label file" | ||
|
||
5. You can use the time information regarding the visits as an additional source of information. Let us call this "time file". | ||
Note that the time information could be anything: duration between consecutive visits, cumulative number of days since the first visit, etc. | ||
"time file" needs to be prepared as a Python cPickled list of list. The outermost list corresponds to patients, and the innermost to the time information of each visit. | ||
For example, given a "visit file" [[[1,2,3], [4,5,6,7]], [[2,4], [8,3,1], [3]]], its corresponding "time file" could look like [[0, 15], [0, 45, 23]], if we are using the duration between the consecutive visits. (of course the numbers are fake, and I've set the duration for the first visit to zero.) | ||
Use `--time_file <path to time file>` option to use "time file" | ||
Remember that the ".train", ".valid", ".test" rule also applies to the "time file" as well. | ||
|
||
**Additional: Using your own medical code representations** | ||
RETAIN internally learns the vector representation of medical codes while training. These vectors are initialized with random values of course. | ||
You can, however, also use your own medical code representations, if you have one. (They can be trained by using Skip-gram like algorithms. Refer to [Med2Vec](http://www.kdd.org/kdd2016/subtopic/view/multi-layer-representation-learning-for-medical-concepts) or [this](http://arxiv.org/abs/1602.03686) for further details.) | ||
If you want to provide the medical code representations, it has to be a list of list (basically a matrix) of N rows and M columns where N is the number of unique codes in your "visit file" and M is the size of the code representations. | ||
Specify the path to your code representation file using `--embed_file <path to embedding file>`. | ||
Additionally, even if you use your own medical code representations, you can re-train (a.k.a fine-tune) them as you train RETAIN. | ||
Use `--embed_finetune` option to do this. If you are not providing your own medical code representations, RETAIN will use randomly initialized one, which obviously requires this fine-tuning process. Since the default is to use the fine-tuning, you do not need to worry about this. | ||
|
||
**STEP 4: Running RETAIN** | ||
|
||
1. The minimum input you need to run RETAIN is the "visit file", the number of unique medical codes in the "visit file", | ||
the "label file", and the output path. The output path is where the learned weights and the log will be saved. | ||
`python retain.py <visit file> <# codes in the visit file> <label file> <output path>` | ||
|
||
2. Specifying `--verbose` option will print training process after each 10 mini-batches. | ||
|
||
3. You can specify the size of the embedding W_emb, the size of the hidden layer of the GRU that generates alpha, and the size of the hidden layer of the GRU that generates beta. | ||
The respective commands are `--embed_size <integer>`, `--alpha_hidden_dim_size <integer>`, and `--beta_hidden_dim_size <integer>`. | ||
For example `--alpha_hidden_dim_size 128` will tell RETAIN to use a GRU with 128-dimensional hidden layer for generating alpha. | ||
|
||
4. Dropouts are applied to two places: 1) to the input embedding, 2) to the context vector c_i. The respective dropout rates can be adjusted using `--keep_prob_embed {0.0, 1.0}` and `--keep_prob_context {0.0, 1.0}`. Dropout values affect the performance so it is recommended to tune them for your data. | ||
|
||
5. L2 regularizations can be applied to W_emb, w_alpha, W_beta, and w_output. | ||
|
||
6. Additional options can be specified such as the size of the batch size, the number of epochs, etc. Detailed information can be accessed by `python retain.py --help` | ||
|
||
7. My personal recommendation: use mild regularization (0.0001 ~ 0.001) on all four weights, and use moderate dropout on the context vector only. But this entirely depends on your data, so you should always tune the hyperparameters for yourself. | ||
|
||
**STEP 5: Getting your results** | ||
|
||
RETAIN checks the AUC of the validation set after each epoch, and if it is higher than all previous values, it will save the current model. The model file is generated by [numpy.savez_compressed](http://docs.scipy.org/doc/numpy-1.10.1/reference/generated/numpy.savez_compressed.html). | ||
|
||
**Step 6: Testing your model** | ||
|
||
1. Using the file "test_retain.py", you can calculate the contributions of each medical code at each visit. First you need to have a trained model that was saved by numpy.savez_compressed. Note that you need to know the configuration with which you trained RETAIN (e.g. use of `--time_file`, use of `--use_log_time`.) | ||
|
||
2. Again, you need the "visit file" and "label file" prepared in the same way. This time, however, you do not need to follow the ".train", ".valid", ".test" rule. The testing script will try to load the file name as given. | ||
|
||
3. You also need the mapping information between the actual string medical codes and their integer codes. | ||
(e.g. "Hypertension" is mapped to 24) | ||
This file (let's call this "mapping file") need to be a Python cPickled dictionary where the keys are the string medical codes and the values are the corresponding intergers. | ||
(e.g. The mapping file generated by process_mimic.py is the ".types" file) | ||
This file is required to print the contributions of each medical code in a user-friendly format. | ||
|
||
4. For the additional options such as `--time_file` or `--use_log_time`, you should use exactly the same configuration with which you trained the model. For more detailed information, use "--help" option. | ||
|
||
5. The minimum input to run the testing script is the "model file", "visit file", "label file", "mapping file", and "output file". "output file" is where the contributions will be stored. | ||
`python test_retain.py <model file> <visit file> <label file> <mapping file> <output file>` |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.