Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Add support for any GPT-2 model hosted in Huggingface (#4360)
Browse files Browse the repository at this point in the history
  • Loading branch information
cahya-wirawan authored Feb 17, 2022
1 parent 7a28f15 commit 31e049d
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 33 deletions.
5 changes: 4 additions & 1 deletion parlai/agents/hugging_face/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
We offer wrappers for generative transformers from [Hugging Face's transformers repository](https://github.com/huggingface/transformers) for fine-tuning and evaluating in ParlAI.

## GPT2
To use GPT2, run your command with the flag: `-m hugging_face/gpt2`.
To use GPT2, run your command with the flag: `-m hugging_face/gpt2`. To use a model other
than the default English GPT2 (small, medium, large, or xl), pass `-m hugging_face/gpt2 --model_name <gpt2 model name>`,
where `<gpt2 model name>` can be any GPT2 model hosted on Hugging Face, such as **anonymous-german-nlp/german-gpt2**
or **indonesian-nlp/gpt2**.

### Examples
**Talk to GPT2 large in interactive mode, with beam size 10, 3-gram beam blocking, and minimum beam length 25:**
Expand Down
35 changes: 19 additions & 16 deletions parlai/agents/hugging_face/dict.py
Original file line number Diff line number Diff line change
def get_tokenizer(self, opt):
    """
    Instantiate tokenizer.

    Resolution order for the tokenizer source:
    1. An explicit Hugging Face model id given via ``opt["model_name"]``
       (e.g. ``indonesian-nlp/gpt2``) takes precedence.
    2. Otherwise, the built-in GPT-2 variant selected by ``opt["gpt2_size"]``
       is used; if the tokenizer files are already cached under the ParlAI
       datapath, that local copy is loaded instead of downloading.

    :param opt: options dict; reads ``model_name``, ``gpt2_size`` and
        ``datapath``.
    :return: a ``GPT2Tokenizer`` instance.
    """
    if opt["model_name"]:
        # Any GPT-2 model hosted on the Hugging Face hub.
        fle_key = opt["model_name"]
    else:
        model_sz = opt["gpt2_size"]
        if model_sz == "small":
            # "small" is published under the bare "gpt2" hub name.
            model_key = "gpt2"
        elif model_sz == "distilgpt2":
            model_key = "distilgpt2"
        else:
            model_key = f"gpt2-{model_sz}"
        # Check if the datapath has the files that the Hugging Face code
        # looks for; if so, load from the local cache instead of the hub.
        hf_dir = os.path.join(opt["datapath"], "hf", model_key)
        if all(
            PathManager.exists(os.path.join(hf_dir, file_name))
            for file_name in ["merges.txt", "vocab.json"]
        ):
            fle_key = PathManager.get_local_path(hf_dir, recursive=True)
        else:
            fle_key = model_key
    return GPT2Tokenizer.from_pretrained(fle_key)

def add_additional_special_tokens(self, additional_special_tokens: List[str]):
Expand Down
41 changes: 25 additions & 16 deletions parlai/agents/hugging_face/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,26 @@ def __init__(self, opt, dict):

def _init_from_pretrained(self, opt):
    """
    Load the pretrained GPT-2 model weights.

    Mirrors the tokenizer resolution logic: an explicit Hugging Face model
    id given via ``opt["model_name"]`` takes precedence; otherwise the
    built-in variant selected by ``opt["gpt2_size"]`` is used, preferring a
    locally cached copy under the ParlAI datapath when one exists.

    :param opt: options dict; reads ``model_name``, ``gpt2_size`` and
        ``datapath``.
    :return: a ``GPT2Model`` loaded via ``from_pretrained``.
    """
    if opt["model_name"]:
        # Any GPT-2 model hosted on the Hugging Face hub.
        fle_key = opt["model_name"]
    else:
        model_sz = opt["gpt2_size"]
        if model_sz == "small":
            # "small" is published under the bare "gpt2" hub name.
            model_key = "gpt2"
        elif model_sz == "distilgpt2":
            model_key = "distilgpt2"
        else:
            model_key = f"gpt2-{model_sz}"

        # Check if the datapath has the files that the Hugging Face code
        # looks for; if so, load from the local cache instead of the hub.
        hf_dir = os.path.join(opt["datapath"], "hf", model_key)
        if all(
            PathManager.exists(os.path.join(hf_dir, file_name))
            for file_name in ["pytorch_model.bin", "config.json"]
        ):
            fle_key = PathManager.get_local_path(hf_dir, recursive=True)
        else:
            fle_key = model_key
    return GPT2Model.from_pretrained(fle_key)

def forward(self, input, encoder_state, incr_state=None):
Expand Down Expand Up @@ -237,6 +240,12 @@ def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
agent = parser.add_argument_group("Gpt2 Args")
agent.add_argument(
"--model-name",
type=str,
default=None,
help="Any GPT-2 model names.",
)
agent.add_argument(
"--gpt2-size",
type=str,
Expand Down

0 comments on commit 31e049d

Please sign in to comment.