From 296e87c762b20650358baa0b47103e719c60b1ad Mon Sep 17 00:00:00 2001
From: beyondguo
Date: Thu, 17 Nov 2022 16:08:35 +0000
Subject: [PATCH] update

---
 README.md        |  40 ++++---
 playground.ipynb | 301 +++++++----------------------------------------
 2 files changed, 65 insertions(+), 276 deletions(-)

diff --git a/README.md b/README.md
index 1a46796..9ecc035 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,15 @@
-# SEGA: SkEtch-based Generative Augmentation
+# GENIUS: Sketch-based Language Model Pre-training via Extreme and Selective Masking for Text Generation and Augmentation

-**A sketch-based generative augmentation model**
+**A sketch-based generation model**

-**SEGA** is a **general text augmentation model** that can be used for data augmentation for **various NLP tasks** (including sentiment analysis, topic classification, NER, and QA). SEGA uses an encoder-decoder structure (based on the BART architecture) and is pre-trained on the `C4-realnewslike` corpus.
+**GENIUS** is a powerful conditional text generation model that takes a **sketch** as input (key information consisting of textual spans, phrases, or words, concatenated by mask tokens) and fills in the missing contexts. GENIUS uses an encoder-decoder structure (based on the BART architecture) and is pre-trained on the `C4-realnewslike` corpus.
+**GENIUS** can also be used as a **general textual data augmentation tool** for **various NLP tasks** (including sentiment analysis, topic classification, NER, and QA).

-![sega-illustration](https://cdn.jsdelivr.net/gh/beyondguo/mdnice_pictures/typora/sega-main-illustration.png)
-- Paper: [SEGA: SkEtch-based Generative Augmentation (preprint)](https://github.com/beyondguo/SEGA/blob/master/SEGA_gby_preprint.pdf)
+![genius-illustration](https://cdn.jsdelivr.net/gh/beyondguo/mdnice_pictures/typora/what-is-genius.png)
+
+- Paper: [GENIUS: Sketch-based Language Model Pre-training via Extreme and Selective Masking for Text Generation and Augmentation (preprint)](https://github.com/beyondguo/SEGA/blob/master/SEGA_gby_preprint.pdf)

- Models hosted in 🤗 Huggingface:

@@ -15,15 +17,15 @@

| Model | #params | Language | comment |
|------------------------|--------------------------------|-------|---------|
-| [`sega-large`](https://huggingface.co/beyond/sega-large) | 406M | English | The version used in the paper |
-| [`sega-large-k2t`](https://huggingface.co/beyond/sega-large-k2t) | 406M | English | keywords-to-text |
-| [`sega-base`](https://huggingface.co/beyond/sega-base) | 139M | English | smaller version |
-| [`sega-base-ps`](https://huggingface.co/beyond/sega-base) | 139M | English | pre-trained on both paragraphs and short sentences |
-| [`sega-base-chinese`](https://huggingface.co/beyond/sega-base-chinese) | 116M | Chinese | pre-trained on 10 million clean Chinese paragraphs |
+| [`genius-large`](https://huggingface.co/beyond/genius-large) | 406M | English | The version used in the paper |
+| [`genius-large-k2t`](https://huggingface.co/beyond/genius-large-k2t) | 406M | English | keywords-to-text |
+| [`genius-base`](https://huggingface.co/beyond/genius-base) | 139M | English | smaller version |
+| [`genius-base-ps`](https://huggingface.co/beyond/genius-base) | 139M | English | pre-trained on both paragraphs and short sentences |
+| [`genius-base-chinese`](https://huggingface.co/beyond/genius-base-chinese) | 116M | Chinese | pre-trained on 10 million clean Chinese paragraphs |
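The Chinese checkpoint in the table is used a little differently from the English ones: following the pattern in `playground.ipynb`, it is loaded with a `BertTokenizer` rather than the BART tokenizer, and its sketches use `[MASK]` as the mask token. Below is a minimal, hedged sketch of that usage; the `beyond/genius-base-chinese` name is taken from the table above, and the exact interface is an assumption based on how the notebook loads the older `sega-base-chinese` checkpoint.

```python
from transformers import BertTokenizer, BartForConditionalGeneration, Text2TextGenerationPipeline

# Assumption: the Chinese checkpoint pairs a BERT tokenizer with BART
# seq2seq weights, mirroring the loading pattern in playground.ipynb.
checkpoint = 'beyond/genius-base-chinese'
tokenizer = BertTokenizer.from_pretrained(checkpoint)
model = BartForConditionalGeneration.from_pretrained(checkpoint)
generator = Text2TextGenerationPipeline(model, tokenizer, device=0)  # device=0: first GPU

# Chinese sketches use [MASK] as the separator token:
generator("今天[MASK]篮球[MASK]上海财经大学[MASK]", max_length=50, do_sample=False)
```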
-**SEGA** is able to write complete paragraphs given a sketch (or framework), which can be composed of:
+**GENIUS** is able to write complete paragraphs given a sketch (or framework), which can be composed of:
- keywords / key-phrases, like "––NLP––AI––computer––science––"
- spans, like "Conference on Empirical Methods––submission of research papers––"
- sentences, like "I really like machine learning––I work at Google since last year––"
- or a mixture of them!

### How to use
#### 1. If you want to generate sentences given a sketch
@@ -35,11 +37,11 @@
```python
from transformers import pipeline
# 1. load the model with the huggingface `pipeline`
-sega = pipeline("text2text-generation", model='beyond/sega-large', device=0)
+genius = pipeline("text2text-generation", model='beyond/genius-large', device=0)
# 2. provide a sketch (joined by <mask> tokens)
sketch = "<mask> Conference on Empirical Methods <mask> submission of research papers <mask> Deep Learning <mask>"
# 3. just do it!
-generated_text = sega(sketch, num_beams=3, do_sample=True, max_length=200)[0]['generated_text']
+generated_text = genius(sketch, num_beams=3, do_sample=True, max_length=200)[0]['generated_text']
print(generated_text)
```
Output:
```
```

@@ -48,14 +50,14 @@
#### 2. If you want to do **data augmentation** to generate new training samples
-Please check [SEGA/augmentation_tools](https://github.com/beyondguo/SEGA/tree/master/augmentation_tools), where we provide ready-to-run scripts for data augmentation for text classification/NER/MRC tasks.
+Please check [genius/augmentation_tools](https://github.com/beyondguo/genius/tree/master/augmentation_tools), where we provide ready-to-run scripts for data augmentation for text classification/NER/MRC tasks. A toy illustration of the idea is sketched below.
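To give a concrete feel for what those scripts do, here is a minimal, illustrative sketch of sketch-based augmentation for text classification. It is not the actual `augmentation_tools` implementation: the `make_sketch` helper below is a hypothetical stand-in that keeps random words, whereas the real tools select label-relevant keywords and spans (the notebook's `keynotes_yake` helper hints at YAKE-style keyword extraction).

```python
import random

from transformers import pipeline

genius = pipeline("text2text-generation", model='beyond/genius-large', device=0)

def make_sketch(text, keep_ratio=0.3):
    # Hypothetical stand-in: keep ~30% of the words at random and join the
    # kept fragments with <mask>, matching the sketch format shown above.
    kept = [w for w in text.split() if random.random() < keep_ratio]
    return "<mask> " + " <mask> ".join(kept) + " <mask>"

def augment(text, label, n=2):
    # One labeled sample in, n new samples out; the original label is reused,
    # assuming the sketch preserves the label-relevant content.
    sketch = make_sketch(text)
    return [(genius(sketch, num_beams=3, do_sample=True, max_length=200)[0]['generated_text'], label)
            for _ in range(n)]

new_samples = augment("The movie was a delightful surprise with great acting.", label=1)
```

The design point is simply that generation is conditioned on key information from a real sample, so the new text tends to stay on-label while varying the surrounding context.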
---

-## SEGA as a Strong Data Augmentation Tool:
+## GENIUS as a Strong Data Augmentation Tool:

- Setting: low-resource, where only n={50,100,200,500,1000} labeled samples are available for training. The results below are averaged over all training sizes.
- Text classification datasets: [HuffPost](https://huggingface.co/datasets/khalidalt/HuffPost), [BBC](https://huggingface.co/datasets/SetFit/bbc-news), [SST2](https://huggingface.co/datasets/glue), [IMDB](https://huggingface.co/datasets/imdb), [Yahoo](https://huggingface.co/datasets/yahoo_answers_topics), [20NG](https://huggingface.co/datasets/newsgroup).
- Base classifier: [DistilBERT](https://huggingface.co/distilbert-base-cased)

@@ -71,8 +73,8 @@ In-distribution (ID) evaluations:
| C-MLM | 80.60 | 96.13 | 45.40 | 46.36 | 77.31 | 76.91 | 70.45 |
| LAMBADA | 81.46 | 93.74 | 50.49 | 47.72 | 78.22 | 78.31 | 71.66 |
| STA | 80.74 | 95.64 | 46.96 | 47.27 | 77.88 | 77.80 | 71.05 |
-| **SEGA** | 81.43 | 95.74 | 49.60 | 50.38 | **80.16** | 78.82 | 72.68 |
-| **SEGA-f** | **81.82** | 95.99 | **50.42** | **50.81** | 79.40 | **80.57** | **73.17** |
+| **GeniusAug** | 81.43 | 95.74 | 49.60 | 50.38 | **80.16** | 78.82 | 72.68 |
+| **GeniusAug-f** | **81.82** | 95.99 | **50.42** | **50.81** | 79.40 | **80.57** | **73.17** |

@@ -84,8 +86,8 @@ Out-of-distribution (OOD) evaluations:
| | Huff->BBC | BBC->Huff | IMDB->SST2 | SST2->IMDB | avg. |
|------|-----------|-----------|------------|------------|------|
| C-MLM | 64.94 | **67.80** | 74.98 | 71.78 | 69.87 |
| LAMBADA | 68.57 | 52.79 | 75.24 | 76.04 | 68.16 |
| STA | 69.31 | 64.82 | 74.72 | 73.62 | 70.61 |
-| **SEGA** | 74.87 | 66.85 | 76.02 | 74.76 | 73.13 |
-| **SEGA-f** | **76.18** | 66.89 | **77.45** | **80.36** | **75.22** |
+| **GeniusAug** | 74.87 | 66.85 | 76.02 | 74.76 | 73.13 |
+| **GeniusAug-f** | **76.18** | 66.89 | **77.45** | **80.36** | **75.22** |

### BibTeX entry and citation info
TBD

diff --git a/playground.ipynb b/playground.ipynb
index ad14634..3f40c9e 100644
--- a/playground.ipynb
+++ b/playground.ipynb
@@ -2,59 +2,56 @@
 "cells": [
  {
   "cell_type": "code",
-  "execution_count": 5,
+  "execution_count": null,
   "metadata": {},
-  "outputs": [
-   {
-    "name": "stderr",
-    "output_type": "stream",
-    "text": [
-     "Setting `pad_token_id` to `eos_token_id`:50258 for open-end generation.\n"
-    ]
-   },
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
-     "[CLS] Ng is an [MASK] and deeplearning.<|endoftext|> <|startofpiece|> assistant professor of computer science at Stanford University, where his research interests include machine learning, computer vision, natural language processing, <|endofpiece|>\n"
-    ]
-   }
-  ],
+  "outputs": [],
  "source": [
   "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n",
-   "# tokenizer = AutoTokenizer.from_pretrained(\"BAAI/glm-10b\", trust_remote_code=True)\n",
-   "# model = AutoModelForSeq2SeqLM.from_pretrained(\"BAAI/glm-10b\", trust_remote_code=True)\n",
-   "# model = model.half().cuda()\n",
+   "tokenizer = AutoTokenizer.from_pretrained(\"BAAI/glm-10b\", trust_remote_code=True)\n",
+   "model = AutoModelForSeq2SeqLM.from_pretrained(\"BAAI/glm-10b\", trust_remote_code=True)\n",
+   "model = model.half().cuda()\n",
   "\n",
-   "inputs = tokenizer(\"Ng is an [MASK] and deeplearning.\", return_tensors=\"pt\")\n",
+   "inputs = tokenizer(\"machine learning [MASK] research interest [MASK]\", return_tensors=\"pt\")\n",
   "inputs = tokenizer.build_inputs_for_generation(inputs, max_gen_length=512, mask_id=tokenizer.mask_token_id)\n",
   "inputs = {key: value.cuda() for key, value in inputs.items()}\n",
   "inputs[\"generation_attention_mask\"] = inputs[\"generation_attention_mask\"].half()\n",
   "outputs = model.generate(**inputs, max_length=512, eos_token_id=tokenizer.eop_token_id, num_beams=4)\n",
-   "print(tokenizer.decode(outputs[0].tolist()))\n"
+   "print(tokenizer.decode(outputs[0].tolist()))\n",
+   "# inputs"
  ]
 },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "inputs['position_ids'].shape, inputs['generation_attention_mask'].shape"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "# from transformers import BertTokenizer, BartForConditionalGeneration, Text2TextGenerationPipeline\n",
+   "# checkpoint = 'fnlp/bart-base-chinese'\n",
+   "# tokenizer = BertTokenizer.from_pretrained(checkpoint)\n",
+   "# sega_model = BartForConditionalGeneration.from_pretrained(checkpoint)\n",
+   "# sega_generator = Text2TextGenerationPipeline(sega_model, tokenizer, device=0)\n",
+   "# sega_generator\n",
+   "# sega_generator = pipeline('text2text-generation', model='facebook/bart-large', device=0)\n",
+   "sega_generator('<mask> interview <mask> The Associated Press <mask> Trump announced another White House run, <mask> Pence declined <mask> former president <mask>. <mask> But he positioned himself <mask> potential alternative <mask> Republicans <mask> conservative leadership <mask> Trump era <mask>.',max_length=200,num_beams=3,do_sample=True)"
+  ]
+ },
 {
  "cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
 "metadata": {},
- "outputs": [
-  {
-   "ename": "ValueError",
-   "evalue": "Loading BAAI/glm-large requires you to execute the configuration file in that repo on your local machine. Make sure you have read the code there to avoid malicious use, then set the option `trust_remote_code=True` to remove this error.",
-   "output_type": "error",
-   "traceback": [
-    "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-    "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
-    "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mglm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpipeline\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'text2text-generation'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'BAAI/glm-large'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrust_remote_code\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
-    "\u001b[0;32m~/.local/lib/python3.6/site-packages/transformers/pipelines/__init__.py\u001b[0m in \u001b[0;36mpipeline\u001b[0;34m(task, model, config, tokenizer, feature_extractor, framework, revision, use_fast, use_auth_token, model_kwargs, pipeline_class, **kwargs)\u001b[0m\n\u001b[1;32m 539\u001b[0m \u001b[0mconfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mAutoConfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_pretrained\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrevision\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrevision\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_from_pipeline\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mmodel_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 540\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mconfig\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 541\u001b[0;31m \u001b[0mconfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mAutoConfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_pretrained\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrevision\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrevision\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_from_pipeline\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mmodel_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 542\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 543\u001b[0m \u001b[0mmodel_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/.local/lib/python3.6/site-packages/transformers/models/auto/configuration_auto.py\u001b[0m in \u001b[0;36mfrom_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m 654\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mtrust_remote_code\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 655\u001b[0m raise ValueError(\n\u001b[0;32m--> 656\u001b[0;31m \u001b[0;34mf\"Loading {pretrained_model_name_or_path} requires you to execute the configuration file in that repo \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 657\u001b[0m \u001b[0;34m\"on your local machine. Make sure you have read the code there to avoid malicious use, then set \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 658\u001b[0m \u001b[0;34m\"the option `trust_remote_code=True` to remove this error.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mValueError\u001b[0m: Loading BAAI/glm-large requires you to execute the configuration file in that repo on your local machine. Make sure you have read the code there to avoid malicious use, then set the option `trust_remote_code=True` to remove this error." - ] - } - ], + "outputs": [], "source": [ - "glm = pipeline('text2text-generation',model='BAAI/glm-large', trust_remote_code=True)" + "sega_generator.model.num_parameters()" ] }, { @@ -75,20 +72,9 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'generated_text': 'I am interested in machine learning, and my research interest is in data science. I am a graduate student at the University of California, Davis.'}]" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "from transformers import pipeline\n", "# sega = pipeline(\"text2text-generation\",model='saved_models/bart-base-c4-realnewslike-4templates-passage-max15sents_2-sketch4/checkpoint-129375', framework='pt')\n", @@ -98,17 +84,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2022-11-07 09:47:24.890192: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0\n" - ] - } - ], + "outputs": [], "source": [ "# sega-chinese\n", "from transformers import BertTokenizer, BartForConditionalGeneration, Text2TextGenerationPipeline\n", @@ -123,40 +101,18 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'generated_text': '今 天 的 篮 球 是 上 海 财 经 大 学 篮 球'}]" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "bart_generator(\"今天[MASK]篮球[MASK]上海财经大学[MASK]\", max_length=50, do_sample=False)" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'generated_text': '今 天 的 篮 球 比 赛 是 由 上 海 财 经 大 学 举 办 的 。'}]" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "sega_generator(\"今天[MASK]篮球[MASK]上海财经大学[MASK]\", max_length=50, do_sample=False)" ] @@ -181,175 +137,6 @@ " print('SEGA-chinese output:\\n>>> 
',sega_generator(sketch, max_length=100, do_sample=True, num_beams=3)[0]['generated_text'].replace(' ',''),'\n')\n"
  ]
 },
- {
-  "cell_type": "code",
-  "execution_count": null,
-  "metadata": {},
-  "outputs": [],
-  "source": [
-   "from datasets import load_dataset\n",
-   "dataset_name = 'conll2003'\n",
-   "# dataset_name = 'wikiann'\n",
-   "raw_datasets = load_dataset(dataset_name)\n",
-   "raw_datasets"
-  ]
- },
- {
-  "cell_type": "code",
-  "execution_count": null,
-  "metadata": {},
-  "outputs": [],
-  "source": [
-   "from transformers import pipeline\n",
-   "sega = pipeline('text2text-generation', model='beyond/sega-base-ps')"
-  ]
- },
- {
-  "cell_type": "code",
-  "execution_count": 20,
-  "metadata": {},
-  "outputs": [],
-  "source": [
-   "from collections import defaultdict\n",
-   "import numpy as np\n",
-   "\n",
-   "# concatenate multiple samples into longer sequences\n",
-   "def concat_multiple_sequences(dataset, size=3, overlap=True):\n",
-   "    # expects a regular huggingface Dataset;\n",
-   "    # for a subset, pick it out with the .select() method first\n",
-   "    new_dataset = defaultdict(list)\n",
-   "    l = len(dataset)\n",
-   "    if overlap: # sliding window, consecutive windows overlap\n",
-   "        for i in range(l-size):\n",
-   "            concat_tokens = np.concatenate(dataset[i:i+size]['tokens'])\n",
-   "            concat_tags = np.concatenate(dataset[i:i+size]['ner_tags'])\n",
-   "            new_dataset['tokens'].append(concat_tokens)\n",
-   "            new_dataset['ner_tags'].append(concat_tags)\n",
-   "    else: # non-overlapping chunks\n",
-   "        for i in range(l//size):\n",
-   "            concat_tokens = np.concatenate(dataset[i*size:(i+1)*size]['tokens'])\n",
-   "            concat_tags = np.concatenate(dataset[i*size:(i+1)*size]['ner_tags'])\n",
-   "            new_dataset['tokens'].append(concat_tokens)\n",
-   "            new_dataset['ner_tags'].append(concat_tags)\n",
-   "    return new_dataset\n",
-   "\n",
-   "\n",
-   "tag_names = raw_datasets['train'].features['ner_tags'].feature.names\n",
-   "\n",
-   "def get_mention_name(tag):\n",
-   "    # tag: the number/index of the tag name\n",
-   "    # tag_names: the list of tag names\n",
-   "    # mention: ORG, LOC, etc.\n",
-   "    return tag_names[tag].split('-')[-1]\n",
-   "\n",
-   "# pull out the entity mentions on their own\n",
-   "def extract_mentions(tokens, tags):\n",
-   "    \"\"\"\n",
-   "    return: \n",
-   "    mentions: []\n",
-   "    mention_dict: {'MISC': [], 'PER': [], 'LOC': [], 'ORG': []}\n",
-   "    \"\"\"\n",
-   "    mentions = []\n",
-   "    mention_dict = {t:[] for t in list(set([t.split('-')[-1] for t in tag_names])) if t != 'O'}\n",
-   "    for i in range(len(tokens)):\n",
-   "        mention = get_mention_name(tags[i])\n",
-   "        if mention == 'O':\n",
-   "            continue\n",
-   "        if tags[i] % 2 == 1:\n",
-   "            # the start\n",
-   "            mention_dict[mention].append([tokens[i]])\n",
-   "            mentions.append([tokens[i]])\n",
-   "        else:\n",
-   "            # the remaining part\n",
-   "            mention_dict[mention][-1].append(tokens[i])\n",
-   "            mentions[-1].append(tokens[i])\n",
-   "    for k in mention_dict:\n",
-   "        if mention_dict[k]: # not empty\n",
-   "            mention_dict[k] = [' '.join(items) for items in mention_dict[k]]\n",
-   "    mentions = [' '.join(items) for items in mentions]\n",
-   "    return mentions,mention_dict\n",
-   "    \n",
-   "\n",
-   "def get_spans(tokens,window=3):\n",
-   "    spans = []\n",
-   "    for i in range(len(tokens) // window):\n",
-   "        spans.append(' '.join(tokens[i*window:(i+1)*window]))\n",
-   "    return spans\n",
-   "\n",
-   "def extract_mention_spans(tokens, tags):\n",
-   "    \"\"\"\n",
-   "    Pick out the spans of a sentence that contain entity mentions.\"\"\"\n",
-   "    text = ' '.join(tokens)\n",
-   "    kws = keynotes_yake(text, 2, 5)\n",
-   "    spans = get_spans(tokens, window=3)\n",
-   "    mentions, _ = extract_mentions(tokens, tags)\n",
-   "    wanted_spans = []\n",
-   "    for span in spans[1:-1]:\n",
-   "        for w in mentions+[]:\n",
-   "            if w in span:\n",
-   "                wanted_spans.append(span)\n",
-   "                break\n",
-   "    m = '<mask> ' + ' <mask> '.join(wanted_spans) + ' <mask>'\n",
-   "    m = f\"{spans[0]} {m} {spans[-1]}\" # anchors the sketch with the first and last spans (fixes the boundaries)\n",
-   "    return wanted_spans, m\n",
-   "\n",
-   "from torch.utils.data import Dataset\n",
-   "class MDataset(Dataset):\n",
-   "    def __init__(self, m_list):\n",
-   "        self.masked_contents = m_list\n",
-   "    def __len__(self):\n",
-   "        return len(self.masked_contents)\n",
-   "    def __getitem__(self, i):\n",
-   "        return self.masked_contents[i]"
-  ]
- },
- {
-  "cell_type": "code",
-  "execution_count": 8,
-  "metadata": {},
-  "outputs": [
-   {
-    "data": {
-     "text/plain": [
-      "Dataset({\n",
-      "    features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n",
-      "    num_rows: 100\n",
-      "})"
-     ]
-    },
-    "execution_count": 8,
-    "metadata": {},
-    "output_type": "execute_result"
-   }
-  ],
-  "source": [
-   "orig_data = raw_datasets['train'].select(range(100))\n",
-   "orig_data"
-  ]
- },
- {
-  "cell_type": "code",
-  "execution_count": 27,
-  "metadata": {},
-  "outputs": [
-   {
-    "data": {
-     "text/plain": [
-      "array(['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British',\n",
-      "       'lamb', '.', 'Peter', 'Blackburn', 'BRUSSELS', '1996-08-22'],\n",
-      "      dtype='