Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
beyondguo committed Nov 17, 2022
1 parent 5bd666c commit 296e87c
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 276 deletions.
40 changes: 21 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
# SEGA: SkEtch-based Generative Augmentation
# GENIUS: Sketch-based Language Model Pre-training via Extreme and Selective Masking for Text Generation and Augmentation

**基于草稿的生成式增强模型**
**基于草稿的生成模型**

**SEGA** is a **general text augmentation model** that can be used for data augmentation for **various NLP tasks** (including sentiment analysis, topic classification, NER, and QA). SEGA uses an encoder-decoder structure (based on the BART architecture) and is pre-trained on the `C4-realnewslike` corpus.
**GENIUS** is a powerful conditional text generation model using sketches as input, which can fill in the missing contexts for a given **sketch** (key information consisting of textual spans, phrases, or words, concatenated by mask tokens). GENIUS uses an encoder-decoder structure (based on the BART architecture) and is pre-trained on the `C4-realnewslike` corpus.

**GENIUS** can also be used as a **general textual data augmentation tool** for **various NLP tasks** (including sentiment analysis, topic classification, NER, and QA).

![sega-illustration](https://cdn.jsdelivr.net/gh/beyondguo/mdnice_pictures/typora/sega-main-illustration.png)

- Paper: [SEGA: SkEtch-based Generative Augmentation (preprint)](https://github.com/beyondguo/SEGA/blob/master/SEGA_gby_preprint.pdf)
![genius-illustration](https://cdn.jsdelivr.net/gh/beyondguo/mdnice_pictures/typora/what-is-genius.png)

- Paper: [genius: SkEtch-based Generative Augmentation (preprint)](https://github.com/beyondguo/SEGA/blob/master/SEGA_gby_preprint.pdf)

- Models hosted in 🤗 Huggingface:

**Model variations:**

| Model | #params | Language | comment|
|------------------------|--------------------------------|-------|---------|
| [`sega-large`](https://huggingface.co/beyond/sega-large) | 406M | English | The version used in paper |
| [`sega-large-k2t`](https://huggingface.co/beyond/sega-large-k2t) | 406M | English | keywords-to-text |
| [`sega-base`](https://huggingface.co/beyond/sega-base) | 139M | English | smaller version |
| [`sega-base-ps`](https://huggingface.co/beyond/sega-base) | 139M | English | pre-trained both in paragraphs and short sentences |
| [`sega-base-chinese`](https://huggingface.co/beyond/sega-base-chinese) | 116M | 中文 | 在一千万纯净中文段落上预训练|
| [`genius-large`](https://huggingface.co/beyond/genius-large) | 406M | English | The version used in paper |
| [`genius-large-k2t`](https://huggingface.co/beyond/genius-large-k2t) | 406M | English | keywords-to-text |
| [`genius-base`](https://huggingface.co/beyond/genius-base) | 139M | English | smaller version |
| [`genius-base-ps`](https://huggingface.co/beyond/genius-base) | 139M | English | pre-trained both in paragraphs and short sentences |
| [`genius-base-chinese`](https://huggingface.co/beyond/genius-base-chinese) | 116M | 中文 | 在一千万纯净中文段落上预训练|

<img src="https://cdn.jsdelivr.net/gh/beyondguo/mdnice_pictures/typora/sega-hf-api.jpg" width="50%" />

**SEGA** is able to write complete paragraphs given a sketch (or framework), which can be composed of:
**GENIUS** is able to write complete paragraphs given a sketch (or framework), which can be composed of:
- keywords /key-phrases, like "––NLP––AI––computer––science––"
- spans, like "Conference on Empirical Methods––submission of research papers––"
- sentences, like "I really like machine learning––I work at Google since last year––"
Expand All @@ -35,11 +37,11 @@
```python
from transformers import pipeline
# 1. load the model with the huggingface `pipeline`
sega = pipeline("text2text-generation", model='beyond/sega-large', device=0)
genius = pipeline("text2text-generation", model='beyond/genius-large', device=0)
# 2. provide a sketch (joint by <mask> tokens)
sketch = "<mask> Conference on Empirical Methods <mask> submission of research papers <mask> Deep Learning <mask>"
# 3. just do it!
generated_text = sega(sketch, num_beams=3, do_sample=True, max_length=200)[0]['generated_text']
generated_text = genius(sketch, num_beams=3, do_sample=True, max_length=200)[0]['generated_text']
print(generated_text)
```
Output:
Expand All @@ -48,14 +50,14 @@ Output:
```

#### 2. If you want to do **data augmentation** to generate new training samples
Please check [SEGA/augmentation_tools](https://github.com/beyondguo/SEGA/tree/master/augmentation_tools), where we provide ready-to-run scripts for data augmentation for text classification/NER/MRC tasks.
Please check [genius/augmentation_tools](https://github.com/beyondguo/genius/tree/master/augmentation_tools), where we provide ready-to-run scripts for data augmentation for text classification/NER/MRC tasks.




---

## SEGA as A Strong Data Augmentation Tool:
## GENIUS as A Strong Data Augmentation Tool:
- Setting: Low-resource setting, where only n={50,100,200,500,1000} labeled samples are available for training. The below results are the average of all training sizes.
- Text Classification Datasets: [HuffPost](https://huggingface.co/datasets/khalidalt/HuffPost), [BBC](https://huggingface.co/datasets/SetFit/bbc-news), [SST2](https://huggingface.co/datasets/glue), [IMDB](https://huggingface.co/datasets/imdb), [Yahoo](https://huggingface.co/datasets/yahoo_answers_topics), [20NG](https://huggingface.co/datasets/newsgroup).
- Base classifier: [DistilBERT](https://huggingface.co/distilbert-base-cased)
Expand All @@ -71,8 +73,8 @@ In-distribution (ID) evaluations:
| C-MLM | 80.60 | 96.13 | 45.40 | 46.36 | 77.31 | 76.91 | 70.45 |
| LAMBADA | 81.46 | 93.74 | 50.49 | 47.72 | 78.22 | 78.31 | 71.66 |
| STA | 80.74 | 95.64 | 46.96 | 47.27 | 77.88 | 77.80 | 71.05 |
| **SEGA** | 81.43 | 95.74 | 49.60 | 50.38 | **80.16** | 78.82 | 72.68 |
| **SEGA-f** | **81.82** | 95.99 | **50.42** | **50.81** | 79.40 | **80.57** | **73.17** |
| **GeniusAug** | 81.43 | 95.74 | 49.60 | 50.38 | **80.16** | 78.82 | 72.68 |
| **GeniusAug-f** | **81.82** | 95.99 | **50.42** | **50.81** | 79.40 | **80.57** | **73.17** |

Out-of-distribution (OOD) evaluations:
| | Huff->BBC | BBC->Huff | IMDB->SST2 | SST2->IMDB | avg. |
Expand All @@ -84,8 +86,8 @@ Out-of-distribution (OOD) evaluations:
| C-MLM | 64.94 | **67.80** | 74.98 | 71.78 | 69.87 |
| LAMBADA | 68.57 | 52.79 | 75.24 | 76.04 | 68.16 |
| STA | 69.31 | 64.82 | 74.72 | 73.62 | 70.61 |
| **SEGA** | 74.87 | 66.85 | 76.02 | 74.76 | 73.13 |
| **SEGA-f** | **76.18** | 66.89 | **77.45** | **80.36** | **75.22** |
| **GeniusAug** | 74.87 | 66.85 | 76.02 | 74.76 | 73.13 |
| **GeniusAug-f** | **76.18** | 66.89 | **77.45** | **80.36** | **75.22** |

### BibTeX entry and citation info
TBD
Expand Down
Loading

0 comments on commit 296e87c

Please sign in to comment.