-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
ZihanLiu
authored and
ZihanLiu
committed
Apr 23, 2020
1 parent
69cf3f2
commit 0a1548a
Showing
34 changed files
with
4,778 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
.DS_Store | ||
data/ | ||
experiments/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,96 @@ | ||
# coach | ||
# Coach: A Coarse-to-Fine Approach for Cross-domain Slot Filling | ||
|
||
<img src="imgs/pytorch-logo-dark.png" width="10%"> [](https://opensource.org/licenses/MIT) | ||
|
||
<img align="right" src="imgs/HKUST.jpg" width="15%"> | ||
|
||
This repository is for the ACL-2020 paper: [Coach: A Coarse-to-Fine Approach for Cross-domain Slot Filling](path_to_the_arxiv_link). It contains the scripts for the coach framework and the baseline models [CT](https://arxiv.org/pdf/1707.02363.pdf) and [RZT](https://arxiv.org/pdf/1906.06870.pdf). | ||
|
||
This code has been written using PyTorch. If you use any source codes or ideas included in this repository for your work, please cite the following paper. | ||
<pre> | ||
bibtext | ||
</pre> | ||
|
||
## Abstract | ||
As an essential task in task-oriented dialog systems, slot filling requires extensive training data in a certain domain. However such data are not always available. | ||
Hence, cross-domain slot filling naturally arises to cope with this data scarcity problem. | ||
In this paper, we propose a Coarse-to-fine approach (Coach) for cross-domain slot filling. Our model first learns the general pattern of slot entities by detecting whether the tokens are slot entities or not. It then predicts the specific types for the slot entities. In addition, we propose a template regularization approach to improve the adaptation robustness by regularizing the representation of utterances based on utterance templates. | ||
Experimental results show that our model significantly outperforms state-of-the-art approaches in slot filling. Furthermore, our model can also be applied to the cross-domain named entity recognition task, and it achieves better adaptation performance than other existing baselines. | ||
|
||
## Coach Framework | ||
<img src="imgs/coach_framework.jpg" width=100%/> | ||
|
||
## Data | ||
- ```Cross-domain Slot Filling:``` Evaluated on [SNIPS](https://arxiv.org/pdf/1805.10190.pdf) dataset, which contains 39 slot types across seven domains (intents) and ~2000 training samples per domain. | ||
|
||
- ```Cross-domain named entity recognition (NER):``` We take CoNLL-2003 English named entity recognition (NER) dataset as the source domain and the [CBS SciTech News NER](https://github.com/jiachenwestlake/Cross-Domain_NER/tree/master/unsupervised_domain_adaptation/data/news_tech) dataset as the target domain. | ||
|
||
## Preprocessing | ||
Preprocessing scripts for cross-domain slot filling and NER are in the preprocess folder. We list the preprocess details for each file as follows: | ||
- ```slu_preprocess.py:``` The script for preprocessing the data in slot filling task. | ||
- ```gen_embeddings_for_slu.py:``` The script for the slot filling task. It contains domain descriptions, slot descriptions, and the functions for generating embeddings for the slot filling task vocabulary and the embeddings for domains and slots. | ||
- ```gen_example_emb_for_slu_baseline.py:``` The script for RZT baseline in the slot filling task. It contains slot examples and the function for generating embeddings for them. | ||
- ```gen_embeddings_for_ner.py:``` The script for the NER task. It contains entity descriptions, entity examples (for the RZT baseline), and the functions for generating embeddings for the NER task vocabulary as well as the embeddings for entities and entity examples. | ||
|
||
### Notes | ||
- We utilize [fastText](https://fasttext.cc/docs/en/pretrained-vectors.html) to get the word-level embeddings including the embeddings for those out of vocabulary (oov) words. | ||
- The original datasets, the preprocessed data and the preprocessed embeddings can be downloaded [here](LINK_TO_THE_DATA) (All in the data folder). | ||
|
||
## How to run | ||
### Configuration | ||
- ```--tgt_dm:``` Target domain | ||
- ```--n_samples:``` Number of samples used in the target domain | ||
- ```--tr:``` Use template regularization | ||
- ```--enc_type:``` Encoder type for encoding entity tokens in the Step Two | ||
- ```--model_path:``` Saved model path | ||
|
||
### Cross-domain Slot Filling | ||
Train Coach model for 50-shot adaptation to AddToPlaylist domain | ||
``` | ||
python slu_main.py --exp_name coach_lstmenc --exp_id atp_50 --bidirection --freeze_emb --tgt_dm PlayMusic --n_samples 50 | ||
``` | ||
|
||
Train Coach + Template Regularization (TR) for 50-shot adaptation to AddToPlaylist domain | ||
``` | ||
python slu_main.py --exp_name coach_tr_lstmenc --exp_id atp_50 --bidirection --freeze_emb --tr --tgt_dm AddToPlaylist --emb_file ./data/snips/emb/slu_word_char_embs_with_slotembs.npy --n_samples 50 | ||
``` | ||
|
||
Train CT model (baseline) for 50-shot adaptation to AddToPlaylist target domain | ||
``` | ||
python slu_baseline.py --exp_name ct --exp_id atp_50 --bidirection --freeze_emb --lr 1e-4 --hidden_dim 300 --tgt_dm AddToPlaylist --n_samples 50 | ||
``` | ||
|
||
Train RZT model (baseline) for 50-shot adaptation to AddToPlaylist target domain | ||
``` | ||
python slu_baseline.py --exp_name rzt --exp_id atp_50 --bidirection --freeze_emb --lr 1e-4 --hidden_dim 200 --use_example --tgt_dm AddToPlaylist --n_samples 50 | ||
``` | ||
|
||
Test Coach model on the AddToPlaylist target domain | ||
``` | ||
python slu_test.py --model_path ./experiments/coach/atp_50/best_model.pth --model_type coach --n_samples 50 --tgt_dm AddToPlaylist | ||
``` | ||
|
||
Test Coach model on seen and unseen slots for the AddToPlaylist target domain | ||
``` | ||
python slu_test.py --model_path ./experiments/coach_lstmenc/atp_50/best_model.pth --model_type coach --n_samples 50 --tgt_dm AddToPlaylist --test_mode seen_unseen | ||
``` | ||
|
||
### Cross-domain NER | ||
Train Coach model for zero-shot adaptation | ||
``` | ||
python ner_main.py --exp_name coach --exp_id ner_0 --bidirection --emb_file ./data/ner/emb/ner_embs.npy --emb_dim 300 --trs_hidden_dim 300 --lr 1e-4 | ||
``` | ||
|
||
Train CT model for zero-shot adaptation | ||
``` | ||
python ner_baseline.py --exp_name ct --exp_id ner_0 --bidirection --emb_file ./data/ner/emb/ner_embs.npy --emb_dim 300 --lr 1e-4 | ||
``` | ||
|
||
Train RZT model for zero-shot adaptation | ||
``` | ||
python ner_baseline.py --exp_name rzt --exp_id ner_0 --bidirection --emb_file ./data/ner/emb/ner_embs.npy --emb_dim 300 --lr 1e-4 --hidden_dim 150 --use_example | ||
``` | ||
|
||
### Notes | ||
- More running commands can be found in run.sh | ||
- All the models can be downloaded [here](LINK_TO_THE_EXPERIMENTS) (in the experiments folder) to reproduce our results. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import argparse | ||
|
||
def get_params(): | ||
# parse parameters | ||
parser = argparse.ArgumentParser(description="Cross-domain SLU") | ||
parser.add_argument("--exp_name", type=str, default="default", help="Experiment name") | ||
parser.add_argument("--logger_filename", type=str, default="cross-domain-slu.log") | ||
parser.add_argument("--dump_path", type=str, default="experiments", help="Experiment saved root path") | ||
parser.add_argument("--exp_id", type=str, default="1", help="Experiment id") | ||
|
||
# adaptation parameters | ||
parser.add_argument("--epoch", type=int, default=300, help="number of maximum epoch") | ||
parser.add_argument("--tgt_dm", type=str, default="", help="target_domain") | ||
parser.add_argument("--emb_file", type=str, default="./data/snips/emb/slu_word_char_embs.npy", help="embeddings file") # slu_embs.npy w/o char embeddings slu_word_char_embs.npy w/ char embeddings slu_word_char_embs_with_slotembs.npy w/ char and slot embs | ||
parser.add_argument("--emb_dim", type=int, default=400, help="embedding dimension") | ||
parser.add_argument("--batch_size", type=int, default=32, help="batch size") | ||
parser.add_argument("--num_binslot", type=int, default=3, help="number of binary slot O,B,I") | ||
parser.add_argument("--num_slot", type=int, default=72, help="number of slot types") | ||
parser.add_argument("--num_domain", type=int, default=7, help="number of domain") | ||
parser.add_argument("--freeze_emb", default=False, action="store_true", help="Freeze embeddings") | ||
|
||
parser.add_argument("--slot_emb_file", type=str, default="./data/snips/emb/slot_word_char_embs_based_on_each_domain.dict", help="dictionary type: slot embeddings based on each domain") # slot_embs_based_on_each_domain.dict w/o char embeddings slot_word_char_embs_based_on_each_domain.dict w/ char embeddings | ||
parser.add_argument("--bidirection", default=False, action="store_true", help="Bidirectional lstm") | ||
parser.add_argument("--lr", type=float, default=5e-4, help="learning rate") | ||
parser.add_argument("--dropout", type=float, default=0.3, help="dropout rate") | ||
parser.add_argument("--hidden_dim", type=int, default=200, help="hidden dimension for LSTM") | ||
parser.add_argument("--n_layer", type=int, default=2, help="number of layers for LSTM") | ||
parser.add_argument("--early_stop", type=int, default=5, help="No improvement after several epoch, we stop training") | ||
parser.add_argument("--binary", default=False, action="store_true", help="conduct binary training only") | ||
|
||
# add label_encoder | ||
parser.add_argument("--tr", default=False, action="store_true", help="use template regularization") | ||
|
||
# few shot learning | ||
parser.add_argument("--n_samples", type=int, default=0, help="number of samples for few shot learning") | ||
|
||
# encoder type for encoding entity tokens in the Step Two | ||
parser.add_argument("--enc_type", type=str, default="lstm", help="encoder type for encoding entity tokens (e.g., trs, lstm, none)") | ||
|
||
# transformer parameters | ||
parser.add_argument("--num_heads", type=int, default=4, help="Number of heads for transformer") | ||
parser.add_argument("--trs_hidden_dim", type=int, default=400, help="Dimension after combined into word level") | ||
parser.add_argument("--filter_size", type=int, default=64, help="Hidden size of the middle layer in FFN") | ||
parser.add_argument("--dim_key", type=int, default=0, help="Key dimension in transformer (if 0, then would be the same as hidden_size)") | ||
parser.add_argument("--dim_value", type=int, default=0, help="Value dimension in transformer (if 0, then would be the same as hidden_size)") | ||
parser.add_argument("--trs_layers", type=int, default=1, help="Number of layers for transformer") | ||
|
||
# baseline | ||
parser.add_argument("--use_example", default=False, action="store_true", help="use example value") | ||
parser.add_argument("--example_emb_file", type=str, default="./data/snips/emb/example_embs_based_on_each_domain.dict", help="dictionary type: example embeddings based on each domain") | ||
|
||
# test model | ||
parser.add_argument("--model_path", type=str, default="", help="Saved model path") | ||
parser.add_argument("--model_type", type=str, default="", help="Saved model type (e.g., coach, ct, rzt)") | ||
parser.add_argument("--test_mode", type=str, default="testset", help="Choose mode to test the model (e.g., testset, seen_unseen)") | ||
|
||
# NER | ||
parser.add_argument("--ner_entity_type_emb_file", type=str, default="./data/ner/emb/entity_type_embs.npy", help="entity type embeddings file path") | ||
parser.add_argument("--ner_example_emb_file", type=str, default="./data/ner/emb/example_embs.npy", help="entity example embeddings file path") | ||
parser.add_argument("--bilstmcrf", default=False, action="store_true", help="use BiLSTM-CRF baseline") | ||
parser.add_argument("--num_entity_label", type=int, default=9, help="number of entity label") | ||
|
||
params = parser.parse_args() | ||
|
||
return params |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
|
||
|
||
from src.utils import init_experiment | ||
from src.ner.baseline_loader import get_dataloader | ||
from src.ner.baseline_model import ConceptTagger, BiLSTMCRFTagger | ||
from src.ner.baseline_trainer import BaselineTrainer, BiLSTMCRFTrainer | ||
from config import get_params | ||
|
||
import numpy as np | ||
from tqdm import tqdm | ||
|
||
def run_baseline(params): | ||
# initialize experiment | ||
logger = init_experiment(params, logger_filename=params.logger_filename) | ||
|
||
# get dataloader | ||
dataloader_tr, dataloader_val, dataloader_test, vocab = get_dataloader(params.batch_size, bilstmcrf=params.bilstmcrf, n_samples=params.n_samples) | ||
|
||
# build model | ||
if params.bilstmcrf: | ||
ner_tagger = BiLSTMCRFTagger(params, vocab) | ||
ner_tagger.cuda() | ||
baseline_trainer = BiLSTMCRFTrainer(params, ner_tagger) | ||
else: | ||
concept_tagger = ConceptTagger(params, vocab) | ||
concept_tagger.cuda() | ||
baseline_trainer = BaselineTrainer(params, concept_tagger) | ||
|
||
for e in range(params.epoch): | ||
logger.info("============== epoch {} ==============".format(e+1)) | ||
loss_list = [] | ||
pbar = tqdm(enumerate(dataloader_tr), total=len(dataloader_tr)) | ||
for i, (X, lengths, y) in pbar: | ||
X, lengths = X.cuda(), lengths.cuda() | ||
|
||
loss = baseline_trainer.train_step(X, lengths, y) | ||
loss_list.append(loss) | ||
pbar.set_description("(Epoch {}) LOSS:{:.4f}".format((e+1), np.mean(loss_list))) | ||
|
||
logger.info("Finish training epoch {}. LOSS:{:.4f}".format((e+1), np.mean(loss_list))) | ||
|
||
logger.info("============== Evaluate Epoch {} ==============".format(e+1)) | ||
f1_score, stop_training_flag = baseline_trainer.evaluate(dataloader_val, istestset=False) | ||
logger.info("Eval on dev set. Entity-F1: {:.4f}.".format(f1_score)) | ||
|
||
f1_score, stop_training_flag = baseline_trainer.evaluate(dataloader_test, istestset=True) | ||
logger.info("Eval on test set. Entity-F1: {:.4f}.".format(f1_score)) | ||
|
||
if stop_training_flag == True: | ||
break | ||
|
||
if __name__ == "__main__": | ||
params = get_params() | ||
run_baseline(params) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
|
||
from src.utils import init_experiment | ||
from src.ner.dataloader import get_dataloader | ||
from src.ner.trainer import NERTrainer | ||
from src.ner.model import BinaryNERagger, EntityNamePredictor, SentRepreGenerator | ||
from config import get_params | ||
|
||
import numpy as np | ||
from tqdm import tqdm | ||
|
||
def main(params): | ||
# initialize experiment | ||
logger = init_experiment(params, logger_filename=params.logger_filename) | ||
|
||
# get dataloader | ||
dataloader_tr, dataloader_val, dataloader_test, vocab = get_dataloader(params.batch_size, use_label_encoder=params.tr, n_samples=params.n_samples) | ||
|
||
# build model | ||
binary_nertagger = BinaryNERagger(params, vocab) | ||
entityname_predictor = EntityNamePredictor(params) | ||
binary_nertagger, entityname_predictor = binary_nertagger.cuda(), entityname_predictor.cuda() | ||
|
||
if params.tr: | ||
sent_repre_generator = SentRepreGenerator(params, vocab) | ||
sent_repre_generator = sent_repre_generator.cuda() | ||
ner_trainer = NERTrainer(params, binary_nertagger, entityname_predictor, sent_repre_generator) | ||
else: | ||
ner_trainer = NERTrainer(params, binary_nertagger, entityname_predictor) | ||
|
||
for e in range(params.epoch): | ||
logger.info("============== epoch {} ==============".format(e+1)) | ||
loss_bin_list, loss_entityname_list = [], [] | ||
if params.tr: | ||
loss_tem0_list, loss_tem1_list = [], [] | ||
pbar = tqdm(enumerate(dataloader_tr), total=len(dataloader_tr)) | ||
if params.tr: | ||
for i, (X, lengths, y_bin, y_final, templates, tem_lengths) in pbar: | ||
X, lengths, templates, tem_lengths = X.cuda(), lengths.cuda(), templates.cuda(), tem_lengths.cuda() | ||
loss_bin, loss_entityname, loss_tem0, loss_tem1 = ner_trainer.train_step(X, lengths, y_bin, y_final, templates=templates, tem_lengths=tem_lengths, epoch=e) | ||
loss_bin_list.append(loss_bin) | ||
loss_entityname_list.append(loss_entityname) | ||
loss_tem0_list.append(loss_tem0) | ||
loss_tem1_list.append(loss_tem1) | ||
|
||
pbar.set_description("(Epoch {}) LOSS BIN:{:.4f} LOSS entity:{:.4f} LOSS TEM0:{:.4f} LOSS TEM1:{:.4f}".format((e+1), np.mean(loss_bin_list), np.mean(loss_entityname_list), np.mean(loss_tem0_list), np.mean(loss_tem1_list))) | ||
else: | ||
for i, (X, lengths, y_bin, y_final) in pbar: | ||
X, lengths = X.cuda(), lengths.cuda() | ||
loss_bin, loss_entityname = ner_trainer.train_step(X, lengths, y_bin, y_final) | ||
loss_bin_list.append(loss_bin) | ||
loss_entityname_list.append(loss_entityname) | ||
|
||
pbar.set_description("(Epoch {}) LOSS BIN:{:.4f} LOSS entity:{:.4f}".format((e+1), np.mean(loss_bin_list), np.mean(loss_entityname_list))) | ||
|
||
logger.info("Finish training epoch {}. LOSS BIN:{:.4f} LOSS entity:{:.4f}".format((e+1), np.mean(loss_bin_list), np.mean(loss_entityname_list))) | ||
|
||
logger.info("============== Evaluate Epoch {} ==============".format(e+1)) | ||
bin_f1, final_f1, stop_training_flag = ner_trainer.evaluate(dataloader_val, istestset=False) | ||
logger.info("Eval on dev set. Binary entity-F1: {:.4f}. Final entity-F1: {:.4f}.".format(bin_f1, final_f1)) | ||
|
||
bin_f1, final_f1, stop_training_flag = ner_trainer.evaluate(dataloader_test, istestset=True) | ||
logger.info("Eval on test set. Binary entity-F1: {:.4f}. Final entity-F1: {:.4f}.".format(bin_f1, final_f1)) | ||
|
||
if stop_training_flag == True: | ||
break | ||
|
||
if __name__ == "__main__": | ||
params = get_params() | ||
main(params) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
|
||
import numpy as np | ||
import pickle | ||
|
||
from src.ner.datareader import datareader | ||
from src.utils import load_embedding | ||
|
||
# entity_list = ["LOC", "PER", "ORG", "MISC"] | ||
entity_types = ["location", "person", "organization", "miscellaneous"] # entity descriptions | ||
|
||
example_dict = { | ||
"location": ["france", "russia"], | ||
"person": ["quigley", "samokhalova"], | ||
"organization": ["aberdeen", "nantes"], | ||
"miscellaneous": ["english", "eu-wide"] | ||
} # entity examples | ||
|
||
def get_oov_words(): | ||
_, _, _, vocab = datareader() | ||
_ = load_embedding(vocab, 300, "PATH_OF_THE_WIKI_EN_VEC") | ||
|
||
def gen_embs_for_vocab(): | ||
_, _, _, vocab = datareader() | ||
embedding = load_embedding(vocab, 300, "PATH_OF_THE_WIKI_EN_VEC", "../data/ner/emb/oov_embs.txt") | ||
|
||
np.save("../data/ner/emb/ner_embs.npy", embedding) | ||
|
||
def gen_embs_for_entity_types(emb_file, emb_dim): | ||
embedding = np.zeros((len(entity_types), emb_dim)) | ||
print("loading embeddings from %s" % emb_file) | ||
embedded_words = [] | ||
with open(emb_file, "r") as ef: | ||
pre_trained = 0 | ||
for i, line in enumerate(ef): | ||
if i == 0: continue # first line would be "num of words and dimention" | ||
line = line.strip() | ||
sp = line.split() | ||
try: | ||
assert len(sp) == emb_dim + 1 | ||
except: | ||
continue | ||
if sp[0] in entity_types and sp[0] not in embedded_words: | ||
pre_trained += 1 | ||
embedding[entity_types.index(sp[0])] = [float(x) for x in sp[1:]] | ||
embedded_words.append(sp[0]) | ||
print("Pre-train: %d / %d (%.2f)" % (pre_trained, len(entity_types), pre_trained / len(entity_types))) | ||
|
||
np.save("../data/ner/emb/entity_type_embs.npy", embedding) | ||
|
||
def gen_example_embs_for_entity_types(emb_file, emb_dim): | ||
ner_embs = np.load(emb_file) | ||
_, _, _, vocab = datareader() | ||
|
||
example_embs = np.zeros((len(entity_types), emb_dim, 2)) | ||
for i, entity_type in enumerate(entity_types): | ||
examples = example_dict[entity_type] | ||
for j, example in enumerate(examples): | ||
index = vocab.word2index[example] | ||
example_embs[i, :, j] = ner_embs[index] | ||
|
||
print("saving example embeddings") | ||
np.save("../data/ner/emb/example_embs.npy", example_embs) | ||
|
||
if __name__ == "__main__": | ||
# get_oov_words() | ||
# gen_embs_for_vocab() | ||
# gen_embs_for_entity_types("PATH_OF_THE_WIKI_EN_VEC", 300) | ||
gen_example_embs_for_entity_types("../data/ner/emb/ner_embs.npy", 300) | ||
|
Oops, something went wrong.