New feature - LR Finder
Ubuntu committed Jul 9, 2020
1 parent 746496f commit 24e9c16
Showing 3 changed files with 19 additions and 50 deletions.
29 changes: 18 additions & 11 deletions README.md
@@ -4,18 +4,13 @@
[![PyPI version](https://badge.fury.io/py/fast-bert.svg)](https://badge.fury.io/py/fast-bert)
![Python 3.6, 3.7](https://img.shields.io/badge/python-3.6%20%7C%203.7-green.svg)

**New - Includes Summarisation using BERT Seq2Seq**
**New - Learning Rate Finder for Text Classification Training (borrowed with thanks from https://github.com/davidtvs/pytorch-lr-finder)**

**New model architectures: ALBERT, CamemBERT, DistilRoberta**

**DistilBERT (from HuggingFace), Smaller, faster, cheaper, lighter**

**RoBERTa model support added to Fastbert**

**Now supports LAMB optimizer for faster training.**
**Supports LAMB optimizer for faster training.**
Please refer to https://arxiv.org/abs/1904.00962 for the paper on LAMB optimizer.

**Now supports BERT and XLNet for both Multi-Class and Multi-Label text classification.**
**Supports BERT and XLNet for both Multi-Class and Multi-Label text classification.**

Fast-Bert is the deep learning library that allows developers and data scientists to train and deploy BERT and XLNet based models for natural language processing tasks, beginning with Text Classification.

@@ -199,7 +194,19 @@ learner = BertLearner.from_pretrained_model(
| multi_label | multilabel classification |
| logging_steps | number of steps between each Tensorboard metrics calculation. Set it to 0 to disable Tensorboard logging. Setting this value too low will slow down training, as the model will be evaluated each time the metrics are logged |
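
For context, a minimal sketch of the `from_pretrained_model` call that these parameters belong to, assuming `databunch`, `metrics`, `device_cuda`, `logger` and `OUTPUT_DIR` have been created earlier in the README:

```python
learner = BertLearner.from_pretrained_model(
    databunch,
    pretrained_path='bert-base-uncased',
    metrics=metrics,
    device=device_cuda,
    logger=logger,
    output_dir=OUTPUT_DIR,
    warmup_steps=500,
    multi_gpu=True,
    is_fp16=True,
    multi_label=False,
    logging_steps=50)
```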

### 3. Train the model
### 3. Find the optimal learning rate

The learning rate is one of the most important hyperparameters for model training. We have incorporated the learning rate finder that was proposed by Leslie Smith and later built into the fastai library.

```python
learner.lr_find(start_lr=1e-5, optimizer_type='lamb')
```

The code is heavily borrowed from David Silva's [pytorch-lr-finder library](https://github.com/davidtvs/pytorch-lr-finder).
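
For reference, the standalone library follows the same pattern. A minimal, self-contained sketch of its usage (the toy model and data are purely illustrative; fast-bert wires this up internally inside `lr_find`):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch_lr_finder import LRFinder

# Toy model and data, purely for illustration
model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
train_loader = DataLoader(
    TensorDataset(torch.randn(256, 10), torch.randint(0, 2, (256,))),
    batch_size=32,
)

lr_finder = LRFinder(model, optimizer, criterion)
lr_finder.range_test(train_loader, end_lr=10, num_iter=100)  # exponential LR sweep
lr_finder.plot()   # loss vs. learning rate on a log scale
lr_finder.reset()  # restore model and optimizer to their initial state
```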

![Learning rate range test](images/lr_finder.png)

### 4. Train the model

```python
learner.fit(epochs=6,
@@ -211,7 +218,7 @@ learner.fit(epochs=6,

Fast-Bert now supports the LAMB optimizer. Due to the speed of training, we have set LAMB as the default optimizer. You can switch back to AdamW by setting optimizer_type to 'adamw'.
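
For example, a sketch of such a call (the `lr` and `schedule_type` values here are illustrative, not recommendations):

```python
learner.fit(epochs=6,
            lr=6e-5,
            validate=True,                  # evaluate on the validation set after each epoch
            schedule_type="warmup_cosine",
            optimizer_type="adamw")         # switch from the default LAMB to AdamW
```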

### 4. Save trained model artifacts
### 5. Save trained model artifacts

```python
learner.save_model()
@@ -230,7 +237,7 @@ Model artefacts will be persisted in the output_dir/'model_out' path provided to

As the model artefacts are all stored in the same folder, you will be able to instantiate the learner object to run inference by pointing pretrained_path to this location.
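
For example, a sketch using the library's standalone predictor to load the saved artefacts (the paths are illustrative; `MODEL_PATH` points at the model_out folder and `LABEL_PATH` at the folder containing labels.csv):

```python
from fast_bert.prediction import BertClassificationPredictor

MODEL_PATH = "output/model_out"  # created by learner.save_model()
LABEL_PATH = "labels"            # folder containing labels.csv

predictor = BertClassificationPredictor(
    model_path=MODEL_PATH,
    label_path=LABEL_PATH,
    multi_label=False,
    model_type="bert",
    do_lower_case=True,
)

print(predictor.predict("this movie really surprised me"))
```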

### 5. Model Inference
### 6. Model Inference

If you already have a Learner object with a trained model instantiated, just call the predict_batch method on the learner object with the list of text data:
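
A sketch of such a call (the example texts are illustrative):

```python
texts = ["I really love the Netflix original movies",
         "this movie is not worth watching"]
predictions = learner.predict_batch(texts)
```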

Binary file added images/lr_finder.png
40 changes: 1 addition & 39 deletions setup.py
@@ -4,50 +4,12 @@
import subprocess


# from pip.req import parse_requirements

# install requirements for mixed precision training
# try:
#     import torch

#     TORCH_MAJOR = int(torch.__version__.split(".")[0])

#     if TORCH_MAJOR == 0:
#         subprocess.run(
#             [
#                 sys.executable,
#                 "-m",
#                 "pip",
#                 "install",
#                 "git+https://github.com/NVIDIA/apex",
#                 "-v",
#                 "--no-cache-dir",
#             ]
#         )
#     else:
#         subprocess.run(
#             [
#                 sys.executable,
#                 "-m",
#                 "pip",
#                 "install",
#                 "git+https://github.com/NVIDIA/apex",
#                 "-v",
#                 "--no-cache-dir",
#                 "--global-option=--cpp_ext",
#                 "--global-option=--cuda_ext",
#             ]
#         )
# except Exception:
#     pass


with open("requirements.txt") as f:
    install_requires = f.read().strip().split("\n")

setup(
    name="fast_bert",
    version="1.7.2",
    version="1.8.0",
    description="AI Library using BERT",
    author="Kaushal Trivedi",
    author_email="[email protected]",
