Commit 77941c6 — update
beyondguo committed Nov 19, 2022
1 parent 296e87c
Showing 28 changed files with 1,356 additions and 324 deletions.
1 change: 0 additions & 1 deletion .gitignore
```diff
@@ -90,5 +90,4 @@ playground.ipynb
 other_nlg/
 other_gen/
 other_aug/
-nlg_eval.ipynb
 nlg_eval/
```
Binary file renamed SEGA_gby_preprint.pdf → GENIUS_gby_arxiv.pdf
File renamed without changes.
File renamed without changes.
File renamed without changes.
226 changes: 226 additions & 0 deletions _backup_scripts/model_upload.ipynb
# 1. Create Repo
https://huggingface.co/docs/huggingface_hub/how-to-manage

```python
# In a terminal, run `huggingface-cli login` to log in first.
```

```python
from huggingface_hub import create_repo

create_repo("beyond/sega-large")
```
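`create_repo` also takes options worth knowing here; a minimal sketch, where the flag values are illustrative choices, not from the notebook:

```python
from huggingface_hub import create_repo

# private=False publishes the repo immediately; exist_ok=True makes the
# call idempotent, so re-running the notebook does not raise an error.
create_repo("beyond/sega-large", private=False, exist_ok=True)
```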
# 2. Upload files to the Hub
https://huggingface.co/docs/huggingface_hub/how-to-upstream

First, install git-lfs:

```shell
curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
sudo apt-get install git-lfs
```

```python
# Upload a whole checkpoint folder with the HfApi client:
# model_path = "saved_models/bart-large-c4-l_50_200-d_13799838-yake_mask-t_3900800/checkpoint-152375"
# from huggingface_hub import HfApi
# api = HfApi()
# api.upload_folder(
#     folder_path=model_path,
#     path_in_repo="/",
#     repo_id="beyond/sega-large",
#     repo_type="model",
#     ignore_patterns="**/logs/*.txt",
# )
```
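To confirm what actually landed in the repo (and that the `logs/*.txt` files were excluded by `ignore_patterns`), listing the repo's files is a quick check; a minimal sketch using the standard `HfApi.list_repo_files` call:

```python
from huggingface_hub import HfApi

api = HfApi()
# Prints every file currently stored in the model repo.
print(api.list_repo_files("beyond/sega-large"))
```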
# Or, directly use `push_to_hub` to upload the model and files:
https://huggingface.co/docs/transformers/model_sharing

```python
from transformers import AutoModelForSeq2SeqLM

model_path = "saved_models/bart-large-c4-l_50_200-d_13799838-yake_mask-t_3900800/checkpoint-152375"
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
model.push_to_hub("beyond/sega-large")
```
```python
from transformers import AutoTokenizer

model_path = "saved_models/bart-large-c4-l_50_200-d_13799838-yake_mask-t_3900800/checkpoint-152375"
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.push_to_hub("beyond/sega-large")
```
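Once both the model and tokenizer are pushed, a round-trip check from the Hub is cheap insurance; a sketch, assuming SEGA's `<mask>`-based sketch input format (the example text is made up):

```python
from transformers import pipeline

# Load the freshly pushed checkpoint back from the Hub and generate once.
genius = pipeline("text2text-generation", model="beyond/sega-large")
print(genius("machine learning <mask> data augmentation <mask> low-resource NLP",
             num_beams=4, max_length=100))
```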
# Model Card
1. Edit it directly on the model's web page.
2. For the inference API widgets, see: https://huggingface.co/docs/hub/models-widgets
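A third option is to push a `README.md` with YAML front matter programmatically; a minimal sketch, where the metadata values and card text are illustrative placeholders, not the repo's actual card:

```python
from huggingface_hub import HfApi

# Hypothetical minimal model card. `language`, `tags`, and `widget` are
# standard Hub front-matter fields, but these values are made up.
card = """---
language: en
tags:
- text2text-generation
widget:
- text: "machine learning <mask> data augmentation <mask>"
---

# SEGA-large
Sketch-based generative augmentation model.
"""
with open("README.md", "w") as f:
    f.write(card)

api = HfApi()
api.upload_file(
    path_or_fileobj="README.md",
    path_in_repo="README.md",
    repo_id="beyond/sega-large",
    repo_type="model",
)
```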
```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

repo = "beyond/sega-large-k2t"
model_path = "saved_models/bart-large-cnn-wikipedia-paras-yake-importance-1000000d-final"
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.push_to_hub(repo)
tokenizer.push_to_hub(repo)
```
```python
# sega-chinese is slightly different: it pairs a BertTokenizer with a Chinese BART.
from transformers import BertTokenizer, BartForConditionalGeneration

model_path = "../saved_models/bart-base-chinese-chinese_clean_passages_80m_with_sketch-10000000/checkpoint-93750"
tokenizer = BertTokenizer.from_pretrained(model_path)
sega_model = BartForConditionalGeneration.from_pretrained(model_path)

sega_model.push_to_hub("beyond/sega-base-chinese")
tokenizer.push_to_hub("beyond/sega-base-chinese")
```
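A generation check for the Chinese model; a sketch that assumes the sketch input uses the tokenizer's `[MASK]` token as the keyword separator (the input text is a made-up example):

```python
from transformers import BertTokenizer, BartForConditionalGeneration, Text2TextGenerationPipeline

tokenizer = BertTokenizer.from_pretrained("beyond/sega-base-chinese")
sega_model = BartForConditionalGeneration.from_pretrained("beyond/sega-base-chinese")
genius = Text2TextGenerationPipeline(model=sega_model, tokenizer=tokenizer)

# Hypothetical sketch: keywords joined by [MASK] for the model to expand.
print(genius("自然语言处理[MASK]数据增强[MASK]低资源", max_length=100, do_sample=False))
```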
```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

repo = "beyond/sega-base"
model_path = "../saved_models/bart-base-c4-realnewslike-4templates-passage-max15sents_2-sketch4/checkpoint-129375"
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# If you see "ValueError: If not specifying clone_from, you need to pass Repository a valid git clone",
# add `use_temp_dir=True`:
model.push_to_hub(repo, use_temp_dir=True)
tokenizer.push_to_hub(repo, use_temp_dir=True)
```
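For context (an explanation, not from the notebook): in the git-based `push_to_hub` of this transformers generation, `use_temp_dir=True` clones the target repo into a temporary directory instead of treating the current working directory as a git checkout, which is why it sidesteps the `clone_from` error above.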
```python
model = AutoModelForSeq2SeqLM.from_pretrained("beyond/sega-base")
model.num_parameters() / 1_000_000  # parameter count in millions
```
6 changes: 3 additions & 3 deletions my_dataset.py → _backup_scripts/my_dataset.py
```diff
@@ -8,21 +8,21 @@
 class MyDataset(Dataset):
     def __init__(self, tokenizer, texts, labels, label2idx, maxlen):
         self.tokenizer = tokenizer
-        # 我先不用padding,后面通过data_collator来做dynamic padding
+        # no padding for now; the data_collator will do dynamic padding later
         texts = [t if (t != None and str(t) != 'nan') else '' for t in texts]
         self.encodings = tokenizer(texts, truncation=True, max_length=maxlen)
         self.labels = labels
         self.label2idx = label2idx
     def __getitem__(self, idx):
         item = {k:torch.tensor(v[idx]) for k,v in self.encodings.items()}
-        item['labels'] = torch.tensor(self.label2idx[self.labels[idx]]) # labels字段应该保存label的idx,而不是具体label名
+        item['labels'] = torch.tensor(self.label2idx[self.labels[idx]]) # 'labels' must hold the label's idx, not the label string
         return item
     def __len__(self):
         return len(self.labels)


 def get_dataloader(file_path, tokenizer, label2idx, maxlen, bsz, collate_fn, shuffle=True):
-    # 单纯地给一个csv文件,然后返回一个dataloader
+    # given just a csv file, return a dataloader
     df = pd.read_csv(file_path)
     texts, labels = list(df['content']), list(df['label'])
     dataset = MyDataset(tokenizer, texts, labels, label2idx, maxlen)
```
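Since `MyDataset` and `get_dataloader` only appear here in diff form, a hedged usage sketch follows; the CSV path, label map, and checkpoint name are made-up placeholders:

```python
from transformers import AutoTokenizer, DataCollatorWithPadding

# Hypothetical setup: any encoder checkpoint plus a binary label map.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorWithPadding(tokenizer=tokenizer)  # pads each batch to its longest member

# Signature as shown in the diff above.
dl = get_dataloader("train.csv", tokenizer, label2idx={"neg": 0, "pos": 1},
                    maxlen=128, bsz=32, collate_fn=collator)
for batch in dl:
    print(batch["input_ids"].shape)  # (bsz, <=128), padded per batch
    break
```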
File renamed without changes.
