Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Add support for any GPT-2 model hosted in Huggingface #4360

Merged
merged 3 commits into from
Feb 17, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion parlai/agents/hugging_face/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
We offer wrappers for generative transformers from [Hugging Face's transformers repository](https://github.com/huggingface/transformers) for fine-tuning and evaluating in ParlAI.

## GPT2
To use GPT2, run your command with the flag: `-m hugging_face/gpt2`.
To use GPT2, run your command with the flag: `-m hugging_face/gpt2`. To use a GPT2 model other
than the default English GPT2 (small, medium, large, or xl), use `-m hugging_face/gpt2 --model_name <gpt2 model name>`,
where `<gpt2 model name>` can be any GPT2 model hosted on Hugging Face, such as **anonymous-german-nlp/german-gpt2**
or **indonesian-nlp/gpt2**.

### Examples
**Talk to GPT2 large in interactive mode, with beam size 10, 3-gram beam blocking, and minimum beam length 25:**
Expand Down
35 changes: 19 additions & 16 deletions parlai/agents/hugging_face/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,23 +140,26 @@ def get_tokenizer(self, opt):
"""
Instantiate tokenizer.
"""
model_sz = opt["gpt2_size"]
if model_sz == "small":
model_key = "gpt2"
elif model_sz == "distilgpt2":
model_key = "distilgpt2"
else:
model_key = f"gpt2-{model_sz}"
# check if datapath has the files that hugging face code looks for
hf_dir = os.path.join(opt["datapath"], "hf", model_key)
if all(
PathManager.exists(os.path.join(hf_dir, file_name))
for file_name in ["merges.txt", "vocab.json"]
):
fle_key = PathManager.get_local_path(hf_dir, recursive=True)

if opt["model_name"]:
fle_key = opt["model_name"]
else:
fle_key = model_key
model_sz = opt["gpt2_size"]
if model_sz == "small":
model_key = "gpt2"
elif model_sz == "distilgpt2":
model_key = "distilgpt2"
else:
model_key = f"gpt2-{model_sz}"
# check if datapath has the files that hugging face code looks for
hf_dir = os.path.join(opt["datapath"], "hf", model_key)
if all(
PathManager.exists(os.path.join(hf_dir, file_name))
for file_name in ["merges.txt", "vocab.json"]
):
fle_key = PathManager.get_local_path(hf_dir, recursive=True)

else:
fle_key = model_key
return GPT2Tokenizer.from_pretrained(fle_key)

def add_additional_special_tokens(self, additional_special_tokens: List[str]):
Expand Down
41 changes: 25 additions & 16 deletions parlai/agents/hugging_face/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,26 @@ def __init__(self, opt, dict):

def _init_from_pretrained(self, opt):
# load model
model_sz = opt["gpt2_size"]
if model_sz == "small":
model_key = "gpt2"
elif model_sz == "distilgpt2":
model_key = "distilgpt2"
if opt["model_name"]:
fle_key = opt["model_name"]
else:
model_key = f"gpt2-{model_sz}"

# check if datapath has the files that hugging face code looks for
hf_dir = os.path.join(opt["datapath"], "hf", model_key)
if all(
PathManager.exists(os.path.join(hf_dir, file_name))
for file_name in ["pytorch_model.bin", "config.json"]
):
fle_key = PathManager.get_local_path(hf_dir, recursive=True)
else:
fle_key = model_key
model_sz = opt["gpt2_size"]
if model_sz == "small":
model_key = "gpt2"
elif model_sz == "distilgpt2":
model_key = "distilgpt2"
else:
model_key = f"gpt2-{model_sz}"

# check if datapath has the files that hugging face code looks for
hf_dir = os.path.join(opt["datapath"], "hf", model_key)
if all(
PathManager.exists(os.path.join(hf_dir, file_name))
for file_name in ["pytorch_model.bin", "config.json"]
):
fle_key = PathManager.get_local_path(hf_dir, recursive=True)
else:
fle_key = model_key
return GPT2Model.from_pretrained(fle_key)

def forward(self, input, encoder_state, incr_state=None):
Expand Down Expand Up @@ -237,6 +240,12 @@ def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
agent = parser.add_argument_group("Gpt2 Args")
agent.add_argument(
"--model_name",
cahya-wirawan marked this conversation as resolved.
Show resolved Hide resolved
type=str,
default=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we provide a default here that would fall back to the existing behavior?

Copy link
Contributor Author

@cahya-wirawan cahya-wirawan Feb 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, it is already the case. If we don't use the argument "--model_name" at all, it will fall back to the old behavior. This is because the default value for this argument is None

help="Any GPT-2 model names.",
)
agent.add_argument(
"--gpt2-size",
type=str,
Expand Down