Merge branch 'sdavis_trn' into 'master'
transformer

See merge request machine-learning/bonito!166
Showing 7 changed files with 264 additions and 14 deletions.

@@ -1,11 +1,11 @@
# Bonito

[![PyPI version](https://badge.fury.io/py/ont-bonito.svg)](https://badge.fury.io/py/ont-bonito)
[![py38](https://img.shields.io/badge/python-3.8-brightgreen.svg)](https://img.shields.io/badge/python-3.8-brightgreen.svg)
[![py39](https://img.shields.io/badge/python-3.9-brightgreen.svg)](https://img.shields.io/badge/python-3.9-brightgreen.svg)
[![py310](https://img.shields.io/badge/python-3.10-brightgreen.svg)](https://img.shields.io/badge/python-3.10-brightgreen.svg)
[![py311](https://img.shields.io/badge/python-3.11-brightgreen.svg)](https://img.shields.io/badge/python-3.11-brightgreen.svg)
[![cu117](https://img.shields.io/badge/cuda-11.7-blue.svg)](https://img.shields.io/badge/cuda-11.7-blue.svg)
[![cu118](https://img.shields.io/badge/cuda-11.8-blue.svg)](https://img.shields.io/badge/cuda-11.8-blue.svg)

Bonito is an open source research basecaller for Oxford Nanopore reads.

@@ -30,6 +30,15 @@ $ bonito download --models --show # show all available models
$ bonito download --models # download all available models
```

## Transformer Models

The `bonito.transformer` package requires
[flash-attn](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features).

It must be installed manually because the `flash-attn` packaging system prevents it from being listed as a normal dependency.

Setting `CUDA_HOME` to the relevant library directory will help avoid CUDA version mismatches between packages.
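
For example, a minimal sketch (the CUDA path here is an assumption; point it at the toolkit that matches the installed PyTorch build):

```bash
# make the matching CUDA toolkit visible to the flash-attn build (example path)
export CUDA_HOME=/usr/local/cuda-11.8
# install flash-attn against the already-installed torch
pip install flash-attn --no-build-isolation
```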

## Modified Bases

Modified base calling is handled by [Remora](https://github.com/nanoporetech/remora).

@@ -49,7 +58,7 @@ $ bonito basecaller dna_r9.4.1 --save-ctc --reference reference.mmi /data/reads
$ bonito train --directory /data/training/ctc-data /data/training/model-dir
```

In addition to training a new model from scratch, you can also easily fine-tune one of the pretrained models.

```bash
bonito train --epochs 1 --lr 5e-4 --pretrained [email protected] --directory /data/training/ctc-data /data/training/fine-tuned-model

@@ -62,7 +71,7 @@ $ bonito download --training
$ bonito train /data/training/model-dir
```

All training calls use Automatic Mixed Precision to speed up training. To disable this, pass the `--no-amp` flag, as in the example below.
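
A minimal sketch, adding the flag to the training call above:

```bash
# train without Automatic Mixed Precision
bonito train --no-amp /data/training/model-dir
```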

## Developer Quickstart

@@ -72,9 +81,13 @@ $ cd bonito
$ python3 -m venv venv3
$ source venv3/bin/activate
(venv3) $ pip install --upgrade pip
(venv3) $ pip install -e .
(venv3) $ pip install -e .[cu118] --extra-index-url https://download.pytorch.org/whl/cu118
```

The `ont-bonito[cu118]` and `ont-bonito[cu121]` optional dependencies can be used, along
with the corresponding `--extra-index-url`, to ensure the PyTorch package matches the
local CUDA setup.
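
For example, the cu121 variant of the editable install above (a sketch mirroring the `cu121` extra added to `setup.py` below):

```bash
(venv3) $ pip install -e .[cu121] --extra-index-url https://download.pytorch.org/whl/cu121
```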

## Interface

- `bonito view` - view a model architecture for a given `.toml` file and the number of parameters in the network.

@@ -0,0 +1,2 @@
from .model import Model
from .basecall import basecall

@@ -0,0 +1 @@
from bonito.crf import basecall

@@ -0,0 +1,154 @@
import logging
import types
from functools import lru_cache

logger = logging.getLogger(__name__)

import torch
import torch.nn.functional as F
try:
    from flash_attn import flash_attn_qkvpacked_func
    from flash_attn.layers.rotary import RotaryEmbedding
    from flash_attn.modules.mlp import GatedMlp
    from flash_attn.ops.triton.layer_norm import RMSNorm
except ImportError:
    logger.warning(
        "please install flash-attn to use the transformer module: "
        "`pip install flash-attn --no-build-isolation`"
    )

from bonito.crf.model import SeqdistModel
from bonito.nn import from_dict, register, LinearCRFEncoder, MakeContiguous, Module, Permute, Serial


def deepnorm_params(depth):
    """
    Returns the DeepNorm (https://arxiv.org/abs/2203.00555) alpha and beta parameters.
    """
    alpha = round((2*depth)**0.25, 7)
    beta = round((8*depth)**(-1/4), 7)
    return alpha, beta


@lru_cache(maxsize=2)
def sliding_window_mask(seq_len, window, device):
    # boolean band: position i may attend to positions [i - window[0], i + window[1]]
    band = torch.full((seq_len, seq_len), fill_value=1.0)
    band = torch.triu(band, diagonal=-window[0])
    band = band * torch.tril(band, diagonal=window[1])
    band = band.to(torch.bool).to(device)
    return band


class MultiHeadAttention(Module):
    def __init__(self, d_model, nhead, qkv_bias=False, out_bias=True, rotary_dim=None, attn_window=None):
        super().__init__()
        assert d_model % nhead == 0, "d_model must be divisible by nhead"

        self.d_model = d_model
        self.nhead = nhead
        self.head_dim = d_model // nhead
        self.rotary_dim = self.head_dim if rotary_dim is None else rotary_dim

        self.Wqkv = torch.nn.Linear(d_model, 3 * d_model, bias=qkv_bias)
        self.out_proj = torch.nn.Linear(d_model, d_model, bias=out_bias)

        self.rotary_emb = RotaryEmbedding(self.rotary_dim, interleaved=False)
        self.attn_window = (-1, -1) if attn_window is None else tuple(attn_window)

    def attn_func(self, qkv):
        # use the fused flash-attn kernel on compute capability >= 8.0 when running in half precision/autocast,
        # otherwise fall back to PyTorch scaled_dot_product_attention with an explicit banded mask
        if torch.cuda.get_device_capability(qkv.device)[0] >= 8 and (torch.is_autocast_enabled() or qkv.dtype == torch.half):
            attn_output = flash_attn_qkvpacked_func(qkv, window_size=self.attn_window)
        else:
            q, k, v = torch.chunk(qkv.permute(0, 2, 3, 1, 4), chunks=3, dim=1)
            mask = sliding_window_mask(qkv.shape[1], self.attn_window, q.device)
            attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
            attn_output = attn_output.permute(0, 1, 3, 2, 4)
        return attn_output

    def forward(self, x):
        N, T, _ = x.shape

        qkv = self.Wqkv(x).view(N, T, 3, self.nhead, self.head_dim)

        qkv = self.rotary_emb(qkv)

        attn_output = self.attn_func(qkv).reshape(N, T, self.d_model)

        out = self.out_proj(attn_output)

        return out


@register
class TransformerEncoderLayer(Module):
    def __init__(self, d_model, nhead, dim_feedforward, deepnorm_alpha, deepnorm_beta, attn_window=None):
        super().__init__()
        self.kwargs = {
            "d_model": d_model,
            "nhead": nhead,
            "dim_feedforward": dim_feedforward,
            "deepnorm_alpha": deepnorm_alpha,
            "deepnorm_beta": deepnorm_beta,
            "attn_window": attn_window
        }

        self.self_attn = MultiHeadAttention(
            d_model=d_model,
            nhead=nhead,
            qkv_bias=False,
            out_bias=True,
            attn_window=attn_window
        )
        self.ff = GatedMlp(
            d_model,
            hidden_features=dim_feedforward,
            activation=F.silu,
            bias1=False,
            bias2=False,
            multiple_of=1,
        )
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)

        self.register_buffer("deepnorm_alpha", torch.tensor(deepnorm_alpha))
        self.reset_parameters()

    def reset_parameters(self):
        # DeepNorm initialisation: beta gain on the feed-forward and value/output projections, unit gain on query/key
        db = self.kwargs["deepnorm_beta"]
        d_model = self.kwargs["d_model"]
        torch.nn.init.xavier_normal_(self.ff.fc1.weight, gain=db)
        torch.nn.init.xavier_normal_(self.ff.fc2.weight, gain=db)
        torch.nn.init.xavier_normal_(self.self_attn.out_proj.weight, gain=db)
        torch.nn.init.xavier_normal_(self.self_attn.Wqkv.weight[2*d_model:], gain=db)
        torch.nn.init.xavier_normal_(self.self_attn.Wqkv.weight[:2*d_model], gain=1)

    def forward(self, x):
        # post-norm residual connections with the DeepNorm alpha scaling on the skip path
        x = self.norm1(self.self_attn(x), self.deepnorm_alpha*x)
        x = self.norm2(self.ff(x), self.deepnorm_alpha*x)
        return x

    def to_dict(self, include_weights=False):
        if include_weights:
            raise NotImplementedError
        return self.kwargs


def use_koi(self, **kwargs):
    # koi needs modified LinearCRFLayer settings
    def _expand_blanks(m):
        if isinstance(m, LinearCRFEncoder):
            m.expand_blanks = False
    self.encoder.apply(_expand_blanks)
    self.encoder = Serial([
        self.encoder,
        Permute([1, 0, 2]),
        MakeContiguous(),
    ])


def Model(config):
    model_config = {k: v for k, v in config["model"].items() if k != "package"}
    model = from_dict(model_config)
    model.config = config
    model.use_koi = types.MethodType(use_koi, model)
    return model
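
The two module-level helpers above are pure functions and can be sanity-checked in isolation. A minimal sketch (the import path assumes the package layout implied by the `__init__.py` above; the depth and window values are illustrative):

```python
import torch

from bonito.transformer.model import deepnorm_params, sliding_window_mask

# DeepNorm scaling for an 18-layer stack: alpha = (2*18)**0.25, beta = (8*18)**-0.25
alpha, beta = deepnorm_params(depth=18)
print(alpha, beta)  # 2.4494897 0.2886751

# banded mask for a window of one position either side: position i attends to i-1..i+1
mask = sliding_window_mask(4, (1, 1), torch.device("cpu"))
print(mask.int())
# [[1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [0, 1, 1, 1],
#  [0, 0, 1, 1]]
```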

@@ -1,7 +1,6 @@
import os
import re
from setuptools import setup, find_packages
from setuptools.command.install import install


__pkg_name__ = 'bonito'

@@ -35,12 +34,15 @@
    author='Oxford Nanopore Technologies, Ltd',
    author_email='[email protected]',
    url='https://github.com/nanoporetech/bonito',
    entry_points = {
    extras_require={
        # --extra-index-url https://download.pytorch.org/whl/cu118
        "cu118": ["torch==2.2.1+cu118"],
        # --extra-index-url https://download.pytorch.org/whl/cu121
        "cu121": ["torch==2.2.1+cu121"],
    },
    entry_points={
        'console_scripts': [
            '{0} = {0}:main'.format(__pkg_name__)
        ]
    },
    dependency_links=[
        'https://download.pytorch.org/whl/cu113',
    ]
)