From 9b26e00afd7a2cf650863fb3654d1418382927d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Lagunas?= Date: Thu, 24 Dec 2020 16:54:52 +0100 Subject: [PATCH] Refactor to extend easily to other networks and tasks. --- .isort.cfg | 4 + Makefile | 3 + .../question_answering/qa_sparse_train.py | 83 +--- .../question_answering/qa_sparse_xp.py | 115 +---- .../examples/question_answering/qa_xp.py | 8 +- .../squad_v2_local/evaluate.py | 331 +++++++++++++ .../squad_v2_local/squad_v2_local.py | 137 ++++++ .../examples/text-classification/README.md | 289 ++++++++++++ .../text-classification/requirements.txt | 3 + .../examples/text-classification/run_glue.py | 443 ++++++++++++++++++ .../examples/{question_answering => }/xp.py | 2 + nn_pruning/hp_naming.py | 38 -- .../patch_coordinator.py} | 222 ++++++--- nn_pruning/modules/sparse_trainer.py | 95 ++++ nn_pruning/tests/test_convert.py | 19 + pyproject.toml | 3 + 16 files changed, 1507 insertions(+), 288 deletions(-) create mode 100644 .isort.cfg create mode 100644 Makefile create mode 100644 nn_pruning/examples/question_answering/squad_v2_local/evaluate.py create mode 100644 nn_pruning/examples/question_answering/squad_v2_local/squad_v2_local.py create mode 100644 nn_pruning/examples/text-classification/README.md create mode 100644 nn_pruning/examples/text-classification/requirements.txt create mode 100644 nn_pruning/examples/text-classification/run_glue.py rename nn_pruning/examples/{question_answering => }/xp.py (99%) rename nn_pruning/{examples/question_answering/qa_sparse_patch.py => modules/patch_coordinator.py} (73%) create mode 100644 nn_pruning/modules/sparse_trainer.py create mode 100644 nn_pruning/tests/test_convert.py create mode 100644 pyproject.toml diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 00000000..6e7207ec --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,4 @@ +[settings] +multi_line_output=3 +include_trailing_comma=True + diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..7ad30576 --- /dev/null +++ b/Makefile @@ -0,0 +1,3 @@ +style: + black . + isort . diff --git a/nn_pruning/examples/question_answering/qa_sparse_train.py b/nn_pruning/examples/question_answering/qa_sparse_train.py index c88a0c8b..eb9a3490 100644 --- a/nn_pruning/examples/question_answering/qa_sparse_train.py +++ b/nn_pruning/examples/question_answering/qa_sparse_train.py @@ -17,83 +17,12 @@ """ # You can also adapt this script on your own question answering task. Pointers for this are left as comments. 
-from typing import Dict -from transformers.optimization import AdamW, get_linear_schedule_with_warmup -from .qa_sparse_patch import QASparseModelPatchingCoordinator +from nn_pruning.modules.sparse_trainer import SparseTrainer from .qa_train import QATrainer - -class QASparseTrainer(QATrainer): +# SparseTrainer should appear first in the base classes, as its functions must override QATrainer and its base classes (Trainer) +class QASparseTrainer(SparseTrainer, QATrainer): def __init__(self, sparse_args, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.sparse_args = sparse_args - - def set_patch_coordinator(self, patch_coordinator: QASparseModelPatchingCoordinator): - self.patch_coordinator = patch_coordinator - - def log(self, logs: Dict[str, float]) -> None: - add = {self.log_prefix + k: v for k, v in self.patch_coordinator.log().items()} - - logs.update(add) - - return super().log(logs) - - def schedule_threshold(self, training: bool): - step = self.state.global_step - self.patch_coordinator.schedule_threshold(step, self.state.max_steps, self.args.warmup_steps, training) - - def training_step(self, *args, **kwargs): - self.schedule_threshold(True) - self.log_prefix = "" - return super().training_step(*args, **kwargs) - - def compute_loss(self, model, inputs): - """ - How the loss is computed by Trainer. By default, all models return the loss in the first element. - - Subclass and override for custom behavior. - """ - outputs = model(**inputs) - - # Save past state if it exists - # TODO: this needs to be fixed and made cleaner later. - if self.args.past_index >= 0: - self._past = outputs[self.args.past_index] - - # We don't use .loss here since the model may return tuples instead of ModelOutput. - loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] - - loss = self.patch_coordinator.distil_loss_combine(loss, inputs, outputs) - loss += self.patch_coordinator.regularization_loss(model) - - return loss - - def evaluate(self, *args, **kwargs): - self.schedule_threshold(False) - self.log_prefix = "eval_" - ret = super().evaluate(*args, **kwargs) - return ret - - def create_optimizer_and_scheduler(self, num_training_steps: int): - assert self.optimizer is None - self.optimizer = self.create_optimizer(self.model) - - assert self.lr_scheduler is None - self.lr_scheduler = self.create_scheduler(num_training_steps) - - def create_optimizer(self, model): - args = self.args - - optimizer_grouped_parameters = self.patch_coordinator.create_optimizer_groups(model, self.args, self.sparse_args) - - optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) - return optimizer - - def create_scheduler(self, num_training_steps: int): - scheduler = get_linear_schedule_with_warmup( - self.optimizer, - num_warmup_steps=self.args.warmup_steps, - num_training_steps=num_training_steps, - ) - return scheduler + print(QASparseTrainer.__mro__) + QATrainer.__init__(self, *args, **kwargs) + SparseTrainer.__init__(self, sparse_args) diff --git a/nn_pruning/examples/question_answering/qa_sparse_xp.py b/nn_pruning/examples/question_answering/qa_sparse_xp.py index 575359fc..711fad4f 100644 --- a/nn_pruning/examples/question_answering/qa_sparse_xp.py +++ b/nn_pruning/examples/question_answering/qa_sparse_xp.py @@ -20,12 +20,11 @@ import json import shutil -from dataclasses import dataclass, field from pathlib import Path from types import SimpleNamespace from nn_pruning.hp_naming import TrialShortNamer -from .qa_sparse_patch import 
QASparseModelPatchingCoordinator +from nn_pruning.modules.patch_coordinator import SparseTrainingArguments, ModelPatchingCoordinator from .qa_sparse_train import QASparseTrainer from .qa_xp import ( QAXP, @@ -35,114 +34,6 @@ ) -@dataclass -class SparseTrainingArguments: - """ - Sparse training specific arguments - """ - - mask_scores_learning_rate: float = field( - default=1e-2, metadata={"help": "The initial learning rate for mask_scores."} - ) - - dense_pruning_method: str = field(default="topk", metadata={"help": "Dense Layers pruning method."}) - - attention_pruning_method: str = field(default="topk", metadata={"help": "Dense Layers pruning method."}) - - ampere_pruning_method: str = field( - default="disabled", - metadata={"help": "Ampere sparse method ('disabled' for no ampere sparsity)"}, - ) - - mask_init: str = field(default="constant", metadata={"help": "Mask scores initialization method"}) - - mask_scale: float = field( - default=0.0, - metadata={"help": "Parameter to use with mask_init."}, - ) - - dense_block_rows: int = field( - default=1, - metadata={"help": "Block size in rows for dense layers."}, - ) - - dense_block_cols: int = field( - default=1, - metadata={"help": "Block size in cols for dense layers."}, - ) - - attention_block_rows: int = field( - default=1, - metadata={"help": "Block size in rows for attention."}, - ) - - attention_block_cols: int = field( - default=1, - metadata={"help": "Block size in cols for attention."}, - ) - - initial_threshold: float = field( - default=1.0, - metadata={"help": "Initial value of the threshold (for scheduling)."}, - ) - final_threshold: float = field( - default=0.5, - metadata={"help": "Final value of the threshold (for scheduling)."}, - ) - - initial_warmup: float = field( - default=1, - metadata={ - "help": "Run `initial_warmup` * `warmup_steps` steps of threshold warmup during which threshold stays at its `initial_threshold` value (sparsity schedule)." - }, - ) - final_warmup: float = field( - default=2, - metadata={ - "help": "Run `final_warmup` * `warmup_steps` steps of threshold cool-down during which threshold stays" - }, - ) - - initial_ampere_temperature: float = field( - default=0.0, - metadata={"help": "Initial value of the ampere temperature (for scheduling)."}, - ) - final_ampere_temperature: float = field( - default=20.0, - metadata={"help": "Final value of the ampere temperature (for scheduling)."}, - ) - - regularization: str = field( - default="disabled", - metadata={"help": "Add L0 or L1 regularization to the mask scores."}, - ) - - regularization_final_lambda: float = field( - default=0.0, - metadata={"help": "Regularization intensity (used in conjunction with `regularization`)."}, - ) - - distil_teacher_name_or_path: str = field( - default=None, - metadata={"help": "Path to the already SQuAD fine-tuned teacher model. Only for distillation."}, - ) - - distil_alpha_ce: float = field( - default=0.5, - metadata={"help": "Cross entropy loss linear weight. Only for distillation."}, - ) - - distil_alpha_teacher: float = field( - default=0.5, - metadata={"help": "Distillation loss linear weight. Only for distillation."}, - ) - - distil_temperature: float = field( - default=2.0, - metadata={"help": "Distillation temperature. 
Only for distillation."}, - ) - - class SparseQAShortNamer(TrialShortNamer): DEFAULTS = { "adam_beta1": 0.9, @@ -232,7 +123,7 @@ class QASparseXP(QAXP): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.patch_coordinator = QASparseModelPatchingCoordinator(self.sparse_args, self.training_args.device, self.model_args.cache_dir) + self.patch_coordinator = ModelPatchingCoordinator(self.sparse_args, self.training_args.device, self.model_args.cache_dir) def create_trainer(self, *args, **kwargs): super().create_trainer(*args, **kwargs) @@ -264,7 +155,7 @@ def load_json_to_obj(name): model_args.model_name_or_path = str(src_path) model = cls._model_init(model_args, model_config, trial=None) - patcher = QASparseModelPatcher(sparse_args) + patcher = ModelPatchingCoordinator(sparse_args) patcher.patch_model(model, trial=None) import torch diff --git a/nn_pruning/examples/question_answering/qa_xp.py b/nn_pruning/examples/question_answering/qa_xp.py index f4a19be1..329dcc82 100644 --- a/nn_pruning/examples/question_answering/qa_xp.py +++ b/nn_pruning/examples/question_answering/qa_xp.py @@ -22,7 +22,7 @@ from dataclasses import dataclass, field from pathlib import Path -from datasets import load_dataset, load_metric +from datasets import load_metric from transformers import ( AutoModelForQuestionAnswering, DataCollatorWithPadding, @@ -34,7 +34,7 @@ from .qa_train import QATrainer from .qa_utils import postprocess_qa_predictions -from .xp import XP, DataTrainingArguments, ModelArguments, TrainingArguments +from nn_pruning.examples.xp import XP, DataTrainingArguments, ModelArguments, TrainingArguments logger = logging.getLogger(__name__) @@ -80,7 +80,7 @@ class QAXP(XP): SHORT_NAMER = TrialShortNamer @classmethod - def _model_init(self, model_args, model_config, trial=None): + def _model_init(self, model_args, model_config): model = AutoModelForQuestionAnswering.from_pretrained( model_args.model_name_or_path, from_tf=bool(".ckpt" in model_args.model_name_or_path), @@ -90,7 +90,7 @@ def _model_init(self, model_args, model_config, trial=None): return model def model_init(self, trial=None): - return self._model_init(self.model_args, self.config, trial) + return self._model_init(self.model_args, self.config) def prepare_column_names(self): training_args = self.training_args diff --git a/nn_pruning/examples/question_answering/squad_v2_local/evaluate.py b/nn_pruning/examples/question_answering/squad_v2_local/evaluate.py new file mode 100644 index 00000000..41b11811 --- /dev/null +++ b/nn_pruning/examples/question_answering/squad_v2_local/evaluate.py @@ -0,0 +1,331 @@ +"""Official evaluation script for SQuAD version 2.0. + +In addition to basic functionality, we also compute additional statistics and +plot precision-recall curves if an additional na_prob.json file is provided. +This file is expected to map question ID's to the model's predicted probability +that a question is unanswerable. 
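For example (the question ids below are made up, not taken from SQuAD), a na_prob.json of
{"q1": 0.93, "q2": 0.01} would mean the model is fairly confident that q1 is unanswerable
and that q2 does have an answer.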
+""" +import argparse +import collections +import json +import os +import re +import string +import sys + +import numpy as np + +OPTS = None + + +def parse_args(): + parser = argparse.ArgumentParser("Official evaluation script for SQuAD version 2.0.") + parser.add_argument("data_file", metavar="data.json", help="Input data JSON file.") + parser.add_argument("pred_file", metavar="pred.json", help="Model predictions.") + parser.add_argument( + "--out-file", + "-o", + metavar="eval.json", + help="Write accuracy metrics to file (default is stdout).", + ) + parser.add_argument( + "--na-prob-file", + "-n", + metavar="na_prob.json", + help="Model estimates of probability of no answer.", + ) + parser.add_argument( + "--na-prob-thresh", + "-t", + type=float, + default=1.0, + help='Predict "" if no-answer probability exceeds this (default = 1.0).', + ) + parser.add_argument( + "--out-image-dir", + "-p", + metavar="out_images", + default=None, + help="Save precision-recall curves to directory.", + ) + parser.add_argument("--verbose", "-v", action="store_true") + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) + return parser.parse_args() + + +def make_qid_to_has_ans(dataset): + qid_to_has_ans = {} + for article in dataset: + for p in article["paragraphs"]: + for qa in p["qas"]: + qid_to_has_ans[qa["id"]] = bool(qa["answers"]["text"]) + return qid_to_has_ans + + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): + regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) + return re.sub(regex, " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def get_tokens(s): + if not s: + return [] + return normalize_answer(s).split() + + +def compute_exact(a_gold, a_pred): + return int(normalize_answer(a_gold) == normalize_answer(a_pred)) + + +def compute_f1(a_gold, a_pred): + gold_toks = get_tokens(a_gold) + pred_toks = get_tokens(a_pred) + common = collections.Counter(gold_toks) & collections.Counter(pred_toks) + num_same = sum(common.values()) + if len(gold_toks) == 0 or len(pred_toks) == 0: + # If either is no-answer, then F1 is 1 if they agree, 0 otherwise + return int(gold_toks == pred_toks) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(pred_toks) + recall = 1.0 * num_same / len(gold_toks) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def get_raw_scores(dataset, preds): + exact_scores = {} + f1_scores = {} + for article in dataset: + for p in article["paragraphs"]: + for qa in p["qas"]: + qid = qa["id"] + gold_answers = [t for t in qa["answers"]["text"] if normalize_answer(t)] + if not gold_answers: + # For unanswerable questions, only correct answer is empty string + gold_answers = [""] + if qid not in preds: + print("Missing prediction for %s" % qid) + continue + a_pred = preds[qid] + # Take max over all gold answers + exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers) + f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers) + return exact_scores, f1_scores + + +def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): + new_scores = {} + for qid, s in scores.items(): + pred_na = na_probs[qid] > na_prob_thresh + if pred_na: + new_scores[qid] = float(not qid_to_has_ans[qid]) 
+ else: + new_scores[qid] = s + return new_scores + + +def make_eval_dict(exact_scores, f1_scores, qid_list=None): + if not qid_list: + total = len(exact_scores) + return collections.OrderedDict( + [ + ("exact", 100.0 * sum(exact_scores.values()) / total), + ("f1", 100.0 * sum(f1_scores.values()) / total), + ("total", total), + ] + ) + else: + total = len(qid_list) + return collections.OrderedDict( + [ + ("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total), + ("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total), + ("total", total), + ] + ) + + +def merge_eval(main_eval, new_eval, prefix): + for k in new_eval: + main_eval["%s_%s" % (prefix, k)] = new_eval[k] + + +def plot_pr_curve(precisions, recalls, out_image, title): + plt.step(recalls, precisions, color="b", alpha=0.2, where="post") + plt.fill_between(recalls, precisions, step="post", alpha=0.2, color="b") + plt.xlabel("Recall") + plt.ylabel("Precision") + plt.xlim([0.0, 1.05]) + plt.ylim([0.0, 1.05]) + plt.title(title) + plt.savefig(out_image) + plt.clf() + + +def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, out_image=None, title=None): + qid_list = sorted(na_probs, key=lambda k: na_probs[k]) + true_pos = 0.0 + cur_p = 1.0 + cur_r = 0.0 + precisions = [1.0] + recalls = [0.0] + avg_prec = 0.0 + for i, qid in enumerate(qid_list): + if qid_to_has_ans[qid]: + true_pos += scores[qid] + cur_p = true_pos / float(i + 1) + cur_r = true_pos / float(num_true_pos) + if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 1]]: + # i.e., if we can put a threshold after this point + avg_prec += cur_p * (cur_r - recalls[-1]) + precisions.append(cur_p) + recalls.append(cur_r) + if out_image: + plot_pr_curve(precisions, recalls, out_image, title) + return {"ap": 100.0 * avg_prec} + + +def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, out_image_dir): + if out_image_dir and not os.path.exists(out_image_dir): + os.makedirs(out_image_dir) + num_true_pos = sum(1 for v in qid_to_has_ans.values() if v) + if num_true_pos == 0: + return + pr_exact = make_precision_recall_eval( + exact_raw, + na_probs, + num_true_pos, + qid_to_has_ans, + out_image=os.path.join(out_image_dir, "pr_exact.png"), + title="Precision-Recall curve for Exact Match score", + ) + pr_f1 = make_precision_recall_eval( + f1_raw, + na_probs, + num_true_pos, + qid_to_has_ans, + out_image=os.path.join(out_image_dir, "pr_f1.png"), + title="Precision-Recall curve for F1 score", + ) + oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()} + pr_oracle = make_precision_recall_eval( + oracle_scores, + na_probs, + num_true_pos, + qid_to_has_ans, + out_image=os.path.join(out_image_dir, "pr_oracle.png"), + title="Oracle Precision-Recall curve (binary task of HasAns vs. 
NoAns)", + ) + merge_eval(main_eval, pr_exact, "pr_exact") + merge_eval(main_eval, pr_f1, "pr_f1") + merge_eval(main_eval, pr_oracle, "pr_oracle") + + +def histogram_na_prob(na_probs, qid_list, image_dir, name): + if not qid_list: + return + x = [na_probs[k] for k in qid_list] + weights = np.ones_like(x) / float(len(x)) + plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0)) + plt.xlabel("Model probability of no-answer") + plt.ylabel("Proportion of dataset") + plt.title("Histogram of no-answer probability: %s" % name) + plt.savefig(os.path.join(image_dir, "na_prob_hist_%s.png" % name)) + plt.clf() + + +def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): + num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) + cur_score = num_no_ans + best_score = cur_score + best_thresh = 0.0 + qid_list = sorted(na_probs, key=lambda k: na_probs[k]) + for i, qid in enumerate(qid_list): + if qid not in scores: + continue + if qid_to_has_ans[qid]: + diff = scores[qid] + else: + if preds[qid]: + diff = -1 + else: + diff = 0 + cur_score += diff + if cur_score > best_score: + best_score = cur_score + best_thresh = na_probs[qid] + return 100.0 * best_score / len(scores), best_thresh + + +def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): + best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) + best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) + main_eval["best_exact"] = best_exact + main_eval["best_exact_thresh"] = exact_thresh + main_eval["best_f1"] = best_f1 + main_eval["best_f1_thresh"] = f1_thresh + + +def main(): + with open(OPTS.data_file) as f: + dataset_json = json.load(f) + dataset = dataset_json["data"] + with open(OPTS.pred_file) as f: + preds = json.load(f) + if OPTS.na_prob_file: + with open(OPTS.na_prob_file) as f: + na_probs = json.load(f) + else: + na_probs = {k: 0.0 for k in preds} + qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False + has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] + no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] + exact_raw, f1_raw = get_raw_scores(dataset, preds) + exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh) + f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh) + out_eval = make_eval_dict(exact_thresh, f1_thresh) + if has_ans_qids: + has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) + merge_eval(out_eval, has_ans_eval, "HasAns") + if no_ans_qids: + no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) + merge_eval(out_eval, no_ans_eval, "NoAns") + if OPTS.na_prob_file: + find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans) + if OPTS.na_prob_file and OPTS.out_image_dir: + run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, OPTS.out_image_dir) + histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, "hasAns") + histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, "noAns") + if OPTS.out_file: + with open(OPTS.out_file, "w") as f: + json.dump(out_eval, f) + else: + print(json.dumps(out_eval, indent=2)) + + +if __name__ == "__main__": + OPTS = parse_args() + if OPTS.out_image_dir: + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + main() diff --git a/nn_pruning/examples/question_answering/squad_v2_local/squad_v2_local.py 
b/nn_pruning/examples/question_answering/squad_v2_local/squad_v2_local.py new file mode 100644 index 00000000..140fb5dc --- /dev/null +++ b/nn_pruning/examples/question_answering/squad_v2_local/squad_v2_local.py @@ -0,0 +1,137 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" SQuAD v2 metric. """ + +import datasets + +from .evaluate import ( + apply_no_ans_threshold, + find_all_best_thresh, + get_raw_scores, + make_eval_dict, + make_qid_to_has_ans, + merge_eval, +) + +_CITATION = """\ +@inproceedings{Rajpurkar2016SQuAD10, + title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text}, + author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang}, + booktitle={EMNLP}, + year={2016} +} +""" + +_DESCRIPTION = """ +This metric wrap the official scoring script for version 2 of the Stanford Question +Answering Dataset (SQuAD). + +Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by +crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, +from the corresponding reading passage, or the question might be unanswerable. + +SQuAD2.0 combines the 100,000 questions in SQuAD1.1 with over 50,000 unanswerable questions +written adversarially by crowdworkers to look similar to answerable ones. +To do well on SQuAD2.0, systems must not only answer questions when possible, but also +determine when no answer is supported by the paragraph and abstain from answering. +""" + +_KWARGS_DESCRIPTION = """ +Computes SQuAD v2 scores (F1 and EM). +Args: + predictions: List of triple for question-answers to score with the following elements: + - the question-answer 'id' field as given in the references (see below) + - the text of the answer + - the probability that the question has no answer + references: List of question-answers dictionaries with the following key-values: + - 'id': id of the question-answer pair (see above), + - 'answers': a list of Dict {'text': text of the answer as a string} + no_answer_threshold: float + Probability threshold to decide that a question has no answer. 
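Example (illustrative only: the ids and texts are invented, and the metric is assumed to be
loaded from this package's local copy rather than from the datasets Hub):
    predictions = [{'id': 'q1', 'prediction_text': '1976', 'no_answer_probability': 0.0}]
    references = [{'id': 'q1', 'answers': {'text': ['1976'], 'answer_start': [97]}}]
    squad_v2 = datasets.load_metric('nn_pruning/examples/question_answering/squad_v2_local')
    results = squad_v2.compute(predictions=predictions, references=references)
    # results['exact'], results['f1'], results['best_f1_thresh'], ...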
+Returns: + 'exact': Exact match (the normalized answer exactly match the gold answer) + 'f1': The F-score of predicted tokens versus the gold answer + 'total': Number of score considered + 'HasAns_exact': Exact match (the normalized answer exactly match the gold answer) + 'HasAns_f1': The F-score of predicted tokens versus the gold answer + 'HasAns_total': Number of score considered + 'NoAns_exact': Exact match (the normalized answer exactly match the gold answer) + 'NoAns_f1': The F-score of predicted tokens versus the gold answer + 'NoAns_total': Number of score considered + 'best_exact': Best exact match (with varying threshold) + 'best_exact_thresh': No-answer probability threshold associated to the best exact match + 'best_f1': Best F1 (with varying threshold) + 'best_f1_thresh': No-answer probability threshold associated to the best F1 +""" + + +class SquadV2(datasets.Metric): + def _info(self): + return datasets.MetricInfo( + description=_DESCRIPTION, + citation=_CITATION, + inputs_description=_KWARGS_DESCRIPTION, + features=datasets.Features( + { + "predictions": { + "id": datasets.Value("string"), + "prediction_text": datasets.Value("string"), + "no_answer_probability": datasets.Value("float32"), + }, + "references": { + "id": datasets.Value("string"), + "answers": datasets.features.Sequence( + { + "text": datasets.Value("string"), + "answer_start": datasets.Value("int32"), + } + ), + }, + } + ), + codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"], + reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"], + ) + + def _compute(self, predictions, references, no_answer_threshold=1.0): + no_answer_probabilities = dict((p["id"], p["no_answer_probability"]) for p in predictions) + dataset = [{"paragraphs": [{"qas": references}]}] + predictions = dict((p["id"], p["prediction_text"]) for p in predictions) + + qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False + has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] + no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] + + exact_raw, f1_raw = get_raw_scores(dataset, predictions) + exact_thresh = apply_no_ans_threshold(exact_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold) + f1_thresh = apply_no_ans_threshold(f1_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold) + out_eval = make_eval_dict(exact_thresh, f1_thresh) + + if has_ans_qids: + has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) + merge_eval(out_eval, has_ans_eval, "HasAns") + if no_ans_qids: + no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) + merge_eval(out_eval, no_ans_eval, "NoAns") + find_all_best_thresh( + out_eval, + predictions, + exact_raw, + f1_raw, + no_answer_probabilities, + qid_to_has_ans, + ) + + return out_eval diff --git a/nn_pruning/examples/text-classification/README.md b/nn_pruning/examples/text-classification/README.md new file mode 100644 index 00000000..99961399 --- /dev/null +++ b/nn_pruning/examples/text-classification/README.md @@ -0,0 +1,289 @@ + + +## GLUE Benchmark + +# Run TensorFlow 2.0 version + +Based on the script [`run_tf_glue.py`](https://github.com/huggingface/transformers/blob/master/examples/text-classification/run_tf_glue.py). + +Fine-tuning the library TensorFlow 2.0 Bert model for sequence classification on the MRPC task of the GLUE benchmark: [General Language Understanding Evaluation](https://gluebenchmark.com/). 
+ +This script has an option for mixed precision (Automatic Mixed Precision / AMP) to run models on Tensor Cores (NVIDIA Volta/Turing GPUs) and future hardware and an option for XLA, which uses the XLA compiler to reduce model runtime. +Options are toggled using `USE_XLA` or `USE_AMP` variables in the script. +These options and the below benchmark are provided by @tlkh. + +Quick benchmarks from the script (no other modifications): + +| GPU | Mode | Time (2nd epoch) | Val Acc (3 runs) | +| --------- | -------- | ----------------------- | ----------------------| +| Titan V | FP32 | 41s | 0.8438/0.8281/0.8333 | +| Titan V | AMP | 26s | 0.8281/0.8568/0.8411 | +| V100 | FP32 | 35s | 0.8646/0.8359/0.8464 | +| V100 | AMP | 22s | 0.8646/0.8385/0.8411 | +| 1080 Ti | FP32 | 55s | - | + +Mixed precision (AMP) reduces the training time considerably for the same hardware and hyper-parameters (same batch size was used). + + +## Run generic text classification script in TensorFlow + +The script [run_tf_text_classification.py](https://github.com/huggingface/transformers/blob/master/examples/text-classification/run_tf_text_classification.py) allows users to run a text classification on their own CSV files. For now there are few restrictions, the CSV files must have a header corresponding to the column names and not more than three columns: one column for the id, one column for the text and another column for a second piece of text in case of an entailment classification for example. + +To use the script, one as to run the following command line: +```bash +python run_tf_text_classification.py \ + --train_file train.csv \ ### training dataset file location (mandatory if running with --do_train option) + --dev_file dev.csv \ ### development dataset file location (mandatory if running with --do_eval option) + --test_file test.csv \ ### test dataset file location (mandatory if running with --do_predict option) + --label_column_id 0 \ ### which column corresponds to the labels + --model_name_or_path bert-base-multilingual-uncased \ + --output_dir model \ + --num_train_epochs 4 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 32 \ + --do_train \ + --do_eval \ + --do_predict \ + --logging_steps 10 \ + --evaluation_strategy steps \ + --save_steps 10 \ + --overwrite_output_dir \ + --max_seq_length 128 +``` + +# Run PyTorch version + +Based on the script [`run_glue.py`](https://github.com/huggingface/transformers/blob/master/examples/text-classification/run_glue.py). + +Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding +Evaluation](https://gluebenchmark.com/). This script can fine-tune the following models: BERT, XLM, XLNet and RoBERTa. + +GLUE is made up of a total of 9 different tasks. We get the following results on the dev set of the benchmark with an +uncased BERT base model (the checkpoint `bert-base-uncased`). All experiments ran single V100 GPUs with a total train +batch sizes between 16 and 64. Some of these tasks have a small dataset and training can lead to high variance in the results +between different runs. We report the median on 5 runs (with different seeds) for each of the metrics. + +| Task | Metric | Result | +|-------|------------------------------|-------------| +| CoLA | Matthew's corr | 49.23 | +| SST-2 | Accuracy | 91.97 | +| MRPC | F1/Accuracy | 89.47/85.29 | +| STS-B | Person/Spearman corr. | 83.95/83.70 | +| QQP | Accuracy/F1 | 88.40/84.31 | +| MNLI | Matched acc./Mismatched acc. 
| 80.61/81.08 | +| QNLI | Accuracy | 87.46 | +| RTE | Accuracy | 61.73 | +| WNLI | Accuracy | 45.07 | + +Some of these results are significantly different from the ones reported on the test set +of GLUE benchmark on the website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the +website. + +```bash +export TASK_NAME=MRPC + +python run_glue.py \ + --model_name_or_path bert-base-cased \ + --task_name $TASK_NAME \ + --do_train \ + --do_eval \ + --max_seq_length 128 \ + --per_device_train_batch_size 32 \ + --learning_rate 2e-5 \ + --num_train_epochs 3.0 \ + --output_dir /tmp/$TASK_NAME/ +``` + +where task name can be one of CoLA, SST-2, MRPC, STS-B, QQP, MNLI, QNLI, RTE, WNLI. + +The dev set results will be present within the text file `eval_results.txt` in the specified output_dir. +In case of MNLI, since there are two separate dev sets (matched and mismatched), there will be a separate +output folder called `/tmp/MNLI-MM/` in addition to `/tmp/MNLI/`. + +The code has not been tested with half-precision training with apex on any GLUE task apart from MRPC, MNLI, +CoLA, SST-2. The following section provides details on how to run half-precision training with MRPC. With that being +said, there shouldn’t be any issues in running half-precision training with the remaining GLUE tasks as well, +since the data processor for each task inherits from the base class DataProcessor. + +## Running on TPUs in PyTorch + +Even when running PyTorch, you can accelerate your workloads on Google's TPUs, using `pytorch/xla`. For information on +how to setup your TPU environment refer to the +[pytorch/xla README](https://github.com/pytorch/xla/blob/master/README.md). + +For running your GLUE task on MNLI dataset you can run something like the following form the root of the transformers +repo: + +``` +python examples/xla_spawn.py \ + --num_cores=8 \ + transformers/examples/text-classification/run_glue.py \ + --do_train \ + --do_eval \ + --task_name=mrpc \ + --num_train_epochs=3 \ + --max_seq_length=128 \ + --learning_rate=5e-5 \ + --output_dir=/tmp/mrpc \ + --overwrite_output_dir \ + --logging_steps=5 \ + --save_steps=5 \ + --tpu_metrics_debug \ + --model_name_or_path=bert-base-cased \ + --per_device_train_batch_size=64 \ + --per_device_eval_batch_size=64 +``` + + +#### Using Apex and mixed-precision + +Using Apex and 16 bit precision, the fine-tuning on MRPC only takes 27 seconds. First install +[apex](https://github.com/NVIDIA/apex), then run the following example: + +```bash + +python run_glue.py \ + --model_name_or_path bert-base-cased \ + --task_name MRPC \ + --do_train \ + --do_eval \ + --max_seq_length 128 \ + --per_device_train_batch_size 32 \ + --learning_rate 2e-5 \ + --num_train_epochs 3.0 \ + --output_dir /tmp/mrpc_output/ \ + --fp16 +``` + +#### Distributed training + +Here is an example using distributed training on 8 V100 GPUs. The model used is the BERT whole-word-masking and it +reaches F1 > 92 on MRPC. 
+ +```bash + +python -m torch.distributed.launch \ + --nproc_per_node 8 run_glue.py \ + --model_name_or_path bert-base-cased \ + --task_name mrpc \ + --do_train \ + --do_eval \ + --max_seq_length 128 \ + --per_device_train_batch_size 8 \ + --learning_rate 2e-5 \ + --num_train_epochs 3.0 \ + --output_dir /tmp/mrpc_output/ +``` + +Training with these hyper-parameters gave us the following results: + +```bash +acc = 0.8823529411764706 +acc_and_f1 = 0.901702786377709 +eval_loss = 0.3418912578906332 +f1 = 0.9210526315789473 +global_step = 174 +loss = 0.07231863956341798 +``` + +### MNLI + +The following example uses the BERT-large, uncased, whole-word-masking model and fine-tunes it on the MNLI task. + +```bash +export GLUE_DIR=/path/to/glue + +python -m torch.distributed.launch \ + --nproc_per_node 8 run_glue.py \ + --model_name_or_path bert-base-cased \ + --task_name mnli \ + --do_train \ + --do_eval \ + --max_seq_length 128 \ + --per_device_train_batch_size 8 \ + --learning_rate 2e-5 \ + --num_train_epochs 3.0 \ + --output_dir output_dir \ +``` + +The results are the following: + +```bash +***** Eval results ***** + acc = 0.8679706601466992 + eval_loss = 0.4911287787382479 + global_step = 18408 + loss = 0.04755385363816904 + +***** Eval results ***** + acc = 0.8747965825874695 + eval_loss = 0.45516540421714036 + global_step = 18408 + loss = 0.04755385363816904 +``` + +# Run PyTorch version using PyTorch-Lightning + +Run `bash run_pl.sh` from the `glue` directory. This will also install `pytorch-lightning` and the requirements in +`examples/requirements.txt`. It is a shell pipeline that will automatically download, preprocess the data and run the +specified models. Logs are saved in `lightning_logs` directory. + +Pass `--gpus` flag to change the number of GPUs. Default uses 1. At the end, the expected results are: + +``` +TEST RESULTS {'val_loss': tensor(0.0707), 'precision': 0.852427800698191, 'recall': 0.869537067011978, 'f1': 0.8608974358974358} +``` + + +# XNLI + +Based on the script [`run_xnli.py`](https://github.com/huggingface/transformers/blob/master/examples/text-classification/run_xnli.py). + +[XNLI](https://www.nyu.edu/projects/bowman/xnli/) is a crowd-sourced dataset based on [MultiNLI](http://www.nyu.edu/projects/bowman/multinli/). It is an evaluation benchmark for cross-lingual text representations. Pairs of text are labeled with textual entailment annotations for 15 different languages (including both high-resource language such as English and low-resource languages such as Swahili). + +#### Fine-tuning on XNLI + +This example code fine-tunes mBERT (multi-lingual BERT) on the XNLI dataset. It runs in 106 mins +on a single tesla V100 16GB. The data for XNLI can be downloaded with the following links and should be both saved (and un-zipped) in a +`$XNLI_DIR` directory. 
+ +* [XNLI 1.0](https://www.nyu.edu/projects/bowman/xnli/XNLI-1.0.zip) +* [XNLI-MT 1.0](https://dl.fbaipublicfiles.com/XNLI/XNLI-MT-1.0.zip) + +```bash +export XNLI_DIR=/path/to/XNLI + +python run_xnli.py \ + --model_name_or_path bert-base-multilingual-cased \ + --language de \ + --train_language en \ + --do_train \ + --do_eval \ + --data_dir $XNLI_DIR \ + --per_device_train_batch_size 32 \ + --learning_rate 5e-5 \ + --num_train_epochs 2.0 \ + --max_seq_length 128 \ + --output_dir /tmp/debug_xnli/ \ + --save_steps -1 +``` + +Training with the previously defined hyper-parameters yields the following results on the **test** set: + +```bash +acc = 0.7093812375249501 +``` diff --git a/nn_pruning/examples/text-classification/requirements.txt b/nn_pruning/examples/text-classification/requirements.txt new file mode 100644 index 00000000..0f5c38bd --- /dev/null +++ b/nn_pruning/examples/text-classification/requirements.txt @@ -0,0 +1,3 @@ +datasets >= 1.1.3 +sentencepiece != 0.1.92 +protobuf diff --git a/nn_pruning/examples/text-classification/run_glue.py b/nn_pruning/examples/text-classification/run_glue.py new file mode 100644 index 00000000..56858c4c --- /dev/null +++ b/nn_pruning/examples/text-classification/run_glue.py @@ -0,0 +1,443 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Finetuning the library models for sequence classification on GLUE.""" +# You can also adapt this script on your own text classification task. Pointers for this are left as comments. + +import logging +import os +import random +import sys +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +import transformers +from datasets import load_dataset, load_metric +from transformers import ( + AutoConfig, + AutoModelForSequenceClassification, + AutoTokenizer, + EvalPrediction, + HfArgumentParser, + PretrainedConfig, + Trainer, + TrainingArguments, + default_data_collator, + set_seed, +) +from transformers.trainer_utils import is_main_process + +task_to_keys = { + "cola": ("sentence", None), + "mnli": ("premise", "hypothesis"), + "mrpc": ("sentence1", "sentence2"), + "qnli": ("question", "sentence"), + "qqp": ("question1", "question2"), + "rte": ("sentence1", "sentence2"), + "sst2": ("sentence", None), + "stsb": ("sentence1", "sentence2"), + "wnli": ("sentence1", "sentence2"), +} + +logger = logging.getLogger(__name__) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + + Using `HfArgumentParser` we can turn this class + into argparse arguments to be able to specify them on + the command line. + """ + + task_name: Optional[str] = field( + default=None, + metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, + ) + max_seq_length: int = field( + default=128, + metadata={ + "help": "The maximum total input sequence length after tokenization. 
Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + overwrite_cache: bool = field( + default=False, + metadata={"help": "Overwrite the cached preprocessed datasets or not."}, + ) + pad_to_max_length: bool = field( + default=True, + metadata={ + "help": "Whether to pad all samples to `max_seq_length`. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch." + }, + ) + train_file: Optional[str] = field( + default=None, + metadata={"help": "A csv or a json file containing the training data."}, + ) + validation_file: Optional[str] = field( + default=None, + metadata={"help": "A csv or a json file containing the validation data."}, + ) + + def __post_init__(self): + if self.task_name is not None: + self.task_name = self.task_name.lower() + if self.task_name not in task_to_keys.keys(): + raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys())) + elif self.train_file is None or self.validation_file is None: + raise ValueError("Need either a GLUE task or a training/validation file.") + else: + extension = self.train_file.split(".")[-1] + assert extension in [ + "csv", + "json", + ], "`train_file` should be a csv or a json file." + extension = self.validation_file.split(".")[-1] + assert extension in [ + "csv", + "json", + ], "`validation_file` should be a csv or a json file." + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, + metadata={"help": "Pretrained config name or path if not the same as model_name"}, + ) + tokenizer_name: Optional[str] = field( + default=None, + metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if ( + os.path.exists(training_args.output_dir) + and os.listdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." 
+ ) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN, + ) + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + # Set the verbosity to info of the Transformers logger (on main process only): + if is_main_process(training_args.local_rank): + transformers.utils.logging.set_verbosity_info() + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + logger.info(f"Training/evaluation parameters {training_args}") + + # Set seed before initializing model. + set_seed(training_args.seed) + + # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) + # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the + # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named + # label if at least two columns are provided. + # + # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this + # single column. You can easily tweak this behavior (see below) + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if data_args.task_name is not None: + # Downloading and loading a dataset from the hub. + datasets = load_dataset("glue", data_args.task_name) + elif data_args.train_file.endswith(".csv"): + # Loading a dataset from local csv files + datasets = load_dataset( + "csv", + data_files={ + "train": data_args.train_file, + "validation": data_args.validation_file, + }, + ) + else: + # Loading a dataset from local json files + datasets = load_dataset( + "json", + data_files={ + "train": data_args.train_file, + "validation": data_args.validation_file, + }, + ) + # See more about loading any type of standard or custom dataset at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Labels + if data_args.task_name is not None: + is_regression = data_args.task_name == "stsb" + if not is_regression: + label_list = datasets["train"].features["label"].names + num_labels = len(label_list) + else: + num_labels = 1 + else: + # Trying to have good defaults here, don't hesitate to tweak to your needs. + is_regression = datasets["train"].features["label"].dtype in [ + "float32", + "float64", + ] + if is_regression: + num_labels = 1 + else: + # A useful fast method: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique + label_list = datasets["train"].unique("label") + label_list.sort() # Let's sort it for determinism + num_labels = len(label_list) + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. 
+ config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + num_labels=num_labels, + finetuning_task=data_args.task_name, + cache_dir=model_args.cache_dir, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + ) + model = AutoModelForSequenceClassification.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + ) + + # Preprocessing the datasets + if data_args.task_name is not None: + sentence1_key, sentence2_key = task_to_keys[data_args.task_name] + else: + # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. + non_label_column_names = [name for name in datasets["train"].column_names if name != "label"] + if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: + sentence1_key, sentence2_key = "sentence1", "sentence2" + else: + if len(non_label_column_names) >= 2: + sentence1_key, sentence2_key = non_label_column_names[:2] + else: + sentence1_key, sentence2_key = non_label_column_names[0], None + + # Padding strategy + if data_args.pad_to_max_length: + padding = "max_length" + max_length = data_args.max_seq_length + else: + # We will pad later, dynamically at batch creation, to the max sequence length in each batch + padding = False + max_length = None + + # Some models have set the order of the labels to use, so let's make sure we do use it. + label_to_id = None + if ( + model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id + and data_args.task_name is not None + and is_regression + ): + # Some have all caps in their config, some don't. + label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} + if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): + label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)} + else: + logger.warn( + "Your model seems to have been trained with labels, but they don't match the dataset: ", + f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." 
+ "\nIgnoring the model labels as a result.", + ) + elif data_args.task_name is None: + label_to_id = {v: i for i, v in enumerate(label_list)} + + def preprocess_function(examples): + # Tokenize the texts + args = ( + (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) + ) + result = tokenizer(*args, padding=padding, max_length=max_length, truncation=True) + + # Map labels to IDs (not necessary for GLUE tasks) + if label_to_id is not None and "label" in examples: + result["label"] = [label_to_id[l] for l in examples["label"]] + return result + + datasets = datasets.map( + preprocess_function, + batched=True, + load_from_cache_file=not data_args.overwrite_cache, + ) + + train_dataset = datasets["train"] + eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"] + if data_args.task_name is not None: + test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"] + + # Log a few random samples from the training set: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # Get the metric function + if data_args.task_name is not None: + metric = load_metric("glue", data_args.task_name) + # TODO: When datasets metrics include regular accuracy, make an else here and remove special branch from + # compute_metrics + + # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a + # predictions and label_ids field) and has to return a dictionary string to float. + def compute_metrics(p: EvalPrediction): + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) + if data_args.task_name is not None: + result = metric.compute(predictions=preds, references=p.label_ids) + if len(result) > 1: + result["combined_score"] = np.mean(list(result.values())).item() + return result + elif is_regression: + return {"mse": ((preds - p.label_ids) ** 2).mean().item()} + else: + return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} + + # Initialize our Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset if training_args.do_eval else None, + compute_metrics=compute_metrics, + tokenizer=tokenizer, + # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding. 
+ data_collator=default_data_collator if data_args.pad_to_max_length else None, + ) + + # Training + if training_args.do_train: + trainer.train( + model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None + ) + trainer.save_model() # Saves the tokenizer too for easy upload + + # Evaluation + eval_results = {} + if training_args.do_eval: + logger.info("*** Evaluate ***") + + # Loop to handle MNLI double evaluation (matched, mis-matched) + tasks = [data_args.task_name] + eval_datasets = [eval_dataset] + if data_args.task_name == "mnli": + tasks.append("mnli-mm") + eval_datasets.append(datasets["validation_mismatched"]) + + for eval_dataset, task in zip(eval_datasets, tasks): + eval_result = trainer.evaluate(eval_dataset=eval_dataset) + + output_eval_file = os.path.join(training_args.output_dir, f"eval_results_{task}.txt") + if trainer.is_world_process_zero(): + with open(output_eval_file, "w") as writer: + logger.info(f"***** Eval results {task} *****") + for key, value in eval_result.items(): + logger.info(f" {key} = {value}") + writer.write(f"{key} = {value}\n") + + eval_results.update(eval_result) + + if training_args.do_predict: + logger.info("*** Test ***") + + # Loop to handle MNLI double evaluation (matched, mis-matched) + tasks = [data_args.task_name] + test_datasets = [test_dataset] + if data_args.task_name == "mnli": + tasks.append("mnli-mm") + test_datasets.append(datasets["test_mismatched"]) + + for test_dataset, task in zip(test_datasets, tasks): + # Removing the `label` columns because it contains -1 and Trainer won't like that. + test_dataset.remove_columns_("label") + predictions = trainer.predict(test_dataset=test_dataset).predictions + predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1) + + output_test_file = os.path.join(training_args.output_dir, f"test_results_{task}.txt") + if trainer.is_world_process_zero(): + with open(output_test_file, "w") as writer: + logger.info(f"***** Test results {task} *****") + writer.write("index\tprediction\n") + for index, item in enumerate(predictions): + if is_regression: + writer.write(f"{index}\t{item:3.3f}\n") + else: + item = label_list[item] + writer.write(f"{index}\t{item}\n") + return eval_results + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/nn_pruning/examples/question_answering/xp.py b/nn_pruning/examples/xp.py similarity index 99% rename from nn_pruning/examples/question_answering/xp.py rename to nn_pruning/examples/xp.py index 48484ec0..2882f3c2 100644 --- a/nn_pruning/examples/question_answering/xp.py +++ b/nn_pruning/examples/xp.py @@ -155,6 +155,8 @@ def get_all_args(self, exclude_base=False): all_args[name] = getattr(self, name) return all_args + + @classmethod def run_from_json_file(cls, filename): json_file_name = Path(filename).resolve() diff --git a/nn_pruning/hp_naming.py b/nn_pruning/hp_naming.py index 466dcc64..d3123bd6 100644 --- a/nn_pruning/hp_naming.py +++ b/nn_pruning/hp_naming.py @@ -140,41 +140,3 @@ def parse_repr(cls, repr): parameters[k] = cls.DEFAULTS[k] return parameters - - -if False: - - class MyTrialShortNamer(TrialShortNamer): - DEFAULTS = {"a": 0, "b": 0} - - def hp_space(trial): - return {} - - def model_init(trial): - if trial is not None: - a = trial.suggest_int("a", -4, 4) - b = trial.suggest_int("b", -4, 4) - else: - a = 0 - b = 0 - config = RegressionModelConfig(a=a, b=b, double_output=False) - - return RegressionPreTrainedModel(config) - - def 
hp_name(trial): - return MyTrialShortNamer.shortname(trial.params) - - with tempfile.TemporaryDirectory() as tmp_dir: - trainer = get_regression_trainer( - output_dir=tmp_dir, - learning_rate=0.1, - logging_steps=1, - evaluation_strategy=EvaluationStrategy.EPOCH, - num_train_epochs=4, - disable_tqdm=True, - load_best_model_at_end=True, - logging_dir="runs", - run_name="test", - model_init=model_init, - ) - trainer.hyperparameter_search(direction="minimize", hp_space=hp_space, hp_name=hp_name, n_trials=4) diff --git a/nn_pruning/examples/question_answering/qa_sparse_patch.py b/nn_pruning/modules/patch_coordinator.py similarity index 73% rename from nn_pruning/examples/question_answering/qa_sparse_patch.py rename to nn_pruning/modules/patch_coordinator.py index f60cd687..0c36fc23 100644 --- a/nn_pruning/examples/question_answering/qa_sparse_patch.py +++ b/nn_pruning/modules/patch_coordinator.py @@ -22,9 +22,10 @@ import torch.nn as nn import torch.nn.functional as nn_functional from transformers import AutoConfig, AutoModelForQuestionAnswering +from dataclasses import dataclass, field from nn_pruning.model_structure import BertStructure -from nn_pruning.modules.masked_nn import ( +from .masked_nn import ( ChannelPruningModulePatcher, JointPruningModulePatcher, LinearPruningModulePatcher, @@ -38,15 +39,114 @@ PatcherContextModule, ) -class ModelPatchingCoordinator: - def log(self): - logs = {} - for k, v in self.patcher_context.enumerate_context_data(): - logs[k] = v - - return logs +@dataclass +class SparseTrainingArguments: + """ + Sparse training specific arguments + """ + + mask_scores_learning_rate: float = field( + default=1e-2, metadata={"help": "The initial learning rate for mask_scores."} + ) + + dense_pruning_method: str = field(default="topk", metadata={"help": "Dense Layers pruning method."}) + + attention_pruning_method: str = field(default="topk", metadata={"help": "Dense Layers pruning method."}) + + ampere_pruning_method: str = field( + default="disabled", + metadata={"help": "Ampere sparse method ('disabled' for no ampere sparsity)"}, + ) + + mask_init: str = field(default="constant", metadata={"help": "Mask scores initialization method"}) + + mask_scale: float = field( + default=0.0, + metadata={"help": "Parameter to use with mask_init."}, + ) + + dense_block_rows: int = field( + default=1, + metadata={"help": "Block size in rows for dense layers."}, + ) + + dense_block_cols: int = field( + default=1, + metadata={"help": "Block size in cols for dense layers."}, + ) + + attention_block_rows: int = field( + default=1, + metadata={"help": "Block size in rows for attention."}, + ) + + attention_block_cols: int = field( + default=1, + metadata={"help": "Block size in cols for attention."}, + ) + + initial_threshold: float = field( + default=1.0, + metadata={"help": "Initial value of the threshold (for scheduling)."}, + ) + final_threshold: float = field( + default=0.5, + metadata={"help": "Final value of the threshold (for scheduling)."}, + ) + + initial_warmup: float = field( + default=1, + metadata={ + "help": "Run `initial_warmup` * `warmup_steps` steps of threshold warmup during which threshold stays at its `initial_threshold` value (sparsity schedule)." 
+        },
+    )
+    final_warmup: float = field(
+        default=2,
+        metadata={
+            "help": "Run `final_warmup` * `warmup_steps` steps of threshold cool-down during which threshold stays at its `final_threshold` value (sparsity schedule)."
+        },
+    )
+
+    initial_ampere_temperature: float = field(
+        default=0.0,
+        metadata={"help": "Initial value of the ampere temperature (for scheduling)."},
+    )
+    final_ampere_temperature: float = field(
+        default=20.0,
+        metadata={"help": "Final value of the ampere temperature (for scheduling)."},
+    )
+
+    regularization: str = field(
+        default="disabled",
+        metadata={"help": "Add L0 or L1 regularization to the mask scores."},
+    )
+
+    regularization_final_lambda: float = field(
+        default=0.0,
+        metadata={"help": "Regularization intensity (used in conjunction with `regularization`)."},
+    )
+
+    distil_teacher_name_or_path: str = field(
+        default=None,
+        metadata={"help": "Path to the already fine-tuned teacher model. Only for distillation."},
+    )
+
+    distil_alpha_ce: float = field(
+        default=0.5,
+        metadata={"help": "Cross entropy loss linear weight. Only for distillation."},
+    )
+
+    distil_alpha_teacher: float = field(
+        default=0.5,
+        metadata={"help": "Distillation loss linear weight. Only for distillation."},
+    )
+
+    distil_temperature: float = field(
+        default=2.0,
+        metadata={"help": "Distillation temperature. Only for distillation."},
+    )
 
-class QASparseModelPatchingCoordinator(ModelPatchingCoordinator):
+class ModelPatchingCoordinator:
     MODEL_STRUCTURE = BertStructure
 
     def __init__(self, sparse_args, device, cache_dir):
@@ -64,51 +164,14 @@ def parse_pruning_method(self, method):
         raise RuntimeError("Could not parse pruning method")
 
     def patch_model(self, model, trial):
-        assert trial is None or len(trial.params) == 0
-        attention_pruning_method_parts = self.parse_pruning_method(self.sparse_args.attention_pruning_method)
+        raise NotImplementedError("Implement in subclass")
 
-        parameters_attention = LinearPruningParameters(
-            method=attention_pruning_method_parts[0],
-            submethod=attention_pruning_method_parts[1],
-            ampere_method=self.sparse_args.ampere_pruning_method,
-            block_rows=self.sparse_args.attention_block_rows,
-            block_cols=self.sparse_args.attention_block_cols,
-        )
-
-        dense_pruning_method_parts = self.parse_pruning_method(self.sparse_args.dense_pruning_method)
-
-        parameters_dense = LinearPruningParameters(
-            method=dense_pruning_method_parts[0],
-            submethod=dense_pruning_method_parts[1],
-            ampere_method=self.sparse_args.ampere_pruning_method,
-            block_rows=self.sparse_args.dense_block_rows,
-            block_cols=self.sparse_args.dense_block_cols,
-        )
-
-        patcher_context = self.patcher_context
-
-        p_attention = JointPruningModulePatcher(patcher_context, parameters_attention, suffix=".attention")
-
-        if parameters_dense.submethod.startswith("1d"):
-            p_dense = ChannelPruningModulePatcher(
-                patcher_context, parameters_dense, self.MODEL_STRUCTURE, suffix="dense"
-            )
-        else:
-            p_dense = LinearPruningModulePatcher(patcher_context, parameters_dense, suffix="dense")
-
-        module_patchers = dict(
-            query=p_attention,
-            key=p_attention,
-            value=p_attention,
-            att_dense=p_dense,
-            interm_dense=p_dense,
-            output_dense=p_dense,
-        )
-
-        patcher = BertLinearModelPatcher(module_patchers)
-        patcher.patch(model)
+    def log(self):
+        logs = {}
+        for k, v in self.patcher_context.enumerate_context_data():
+            logs[k] = v
 
-        assert patcher.stats["patched"] == 72
+        return logs
 
     def create_teacher(self, device, cache_dir):
         sparse_args = self.sparse_args
@@ -272,11 +335,6 @@ def create_optimizer_groups(self, model, args, sparse_args):
                 "lr": 
args.learning_rate,
                 "weight_decay": args.weight_decay,
             },
-            {
-                "params": decay_params,
-                "lr": args.learning_rate,
-                "weight_decay": 0.0,
-            },
         ]
 
         return optimizer_grouped_parameters
 
@@ -285,3 +343,53 @@ def compile_model(self, model):
         self.schedule_threshold()
         compiler = MaskedLinearModelCompiler()
         compiler.patch(model)
+
+    def patch_model(self, model, trial):
+        assert trial is None or len(trial.params) == 0
+        attention_pruning_method_parts = self.parse_pruning_method(self.sparse_args.attention_pruning_method)
+
+        parameters_attention = LinearPruningParameters(
+            method=attention_pruning_method_parts[0],
+            submethod=attention_pruning_method_parts[1],
+            ampere_method=self.sparse_args.ampere_pruning_method,
+            block_rows=self.sparse_args.attention_block_rows,
+            block_cols=self.sparse_args.attention_block_cols,
+        )
+
+        dense_pruning_method_parts = self.parse_pruning_method(self.sparse_args.dense_pruning_method)
+
+        parameters_dense = LinearPruningParameters(
+            method=dense_pruning_method_parts[0],
+            submethod=dense_pruning_method_parts[1],
+            ampere_method=self.sparse_args.ampere_pruning_method,
+            block_rows=self.sparse_args.dense_block_rows,
+            block_cols=self.sparse_args.dense_block_cols,
+        )
+
+        patcher_context = self.patcher_context
+
+        p_attention = JointPruningModulePatcher(patcher_context, parameters_attention, suffix=".attention")
+
+        if parameters_dense.submethod.startswith("1d"):
+            p_dense = ChannelPruningModulePatcher(
+                patcher_context, parameters_dense, self.MODEL_STRUCTURE, suffix="dense"
+            )
+        else:
+            p_dense = LinearPruningModulePatcher(patcher_context, parameters_dense, suffix="dense")
+
+        module_patchers = dict(
+            query=p_attention,
+            key=p_attention,
+            value=p_attention,
+            att_dense=p_dense,
+            interm_dense=p_dense,
+            output_dense=p_dense,
+        )
+
+        patcher = BertLinearModelPatcher(module_patchers)
+        patcher.patch(model)
+        assert patcher.stats["patched"] == 72
+
+        return patcher
+
+
diff --git a/nn_pruning/modules/sparse_trainer.py b/nn_pruning/modules/sparse_trainer.py
new file mode 100644
index 00000000..5c622789
--- /dev/null
+++ b/nn_pruning/modules/sparse_trainer.py
@@ -0,0 +1,95 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Team All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Sparse fine-tuning of the library models: a task-agnostic trainer mixin.
+"""
+# You can also adapt this trainer mixin to your own task.
+
+from typing import Dict
+from transformers.optimization import AdamW, get_linear_schedule_with_warmup
+from .patch_coordinator import ModelPatchingCoordinator
+
+class SparseTrainer:
+    def __init__(self, sparse_args):
+        self.sparse_args = sparse_args
+
+    def set_patch_coordinator(self, patch_coordinator: ModelPatchingCoordinator):
+        self.patch_coordinator = patch_coordinator
+
+    def log(self, logs: Dict[str, float]) -> None:
+        add = {self.log_prefix + k: v for k, v in self.patch_coordinator.log().items()}
+
+        logs.update(add)
+
+        return super().log(logs)
+
+    def schedule_threshold(self, training: bool):
+        step = self.state.global_step
+        self.patch_coordinator.schedule_threshold(step, self.state.max_steps, self.args.warmup_steps, training)
+
+    def training_step(self, *args, **kwargs):
+        self.schedule_threshold(True)
+        self.log_prefix = ""
+        return super().training_step(*args, **kwargs)
+
+    def compute_loss(self, model, inputs):
+        """
+        How the loss is computed by Trainer. By default, all models return the loss in the first element.
+
+        Subclass and override for custom behavior.
+        """
+        outputs = model(**inputs)
+
+        # Save past state if it exists
+        # TODO: this needs to be fixed and made cleaner later.
+        if self.args.past_index >= 0:
+            self._past = outputs[self.args.past_index]
+
+        # We don't use .loss here since the model may return tuples instead of ModelOutput.
+        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+
+        loss = self.patch_coordinator.distil_loss_combine(loss, inputs, outputs)
+        loss += self.patch_coordinator.regularization_loss(model)
+
+        return loss
+
+    def evaluate(self, *args, **kwargs):
+        self.schedule_threshold(False)
+        self.log_prefix = "eval_"
+        ret = super().evaluate(*args, **kwargs)
+        return ret
+
+    def create_optimizer_and_scheduler(self, num_training_steps: int):
+        assert self.optimizer is None
+        self.optimizer = self.create_optimizer(self.model)
+
+        assert self.lr_scheduler is None
+        self.lr_scheduler = self.create_scheduler(num_training_steps)
+
+    def create_optimizer(self, model):
+        args = self.args
+
+        optimizer_grouped_parameters = self.patch_coordinator.create_optimizer_groups(model, self.args, self.sparse_args)
+
+        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+        return optimizer
+
+    def create_scheduler(self, num_training_steps: int):
+        scheduler = get_linear_schedule_with_warmup(
+            self.optimizer,
+            num_warmup_steps=self.args.warmup_steps,
+            num_training_steps=num_training_steps,
+        )
+        return scheduler
diff --git a/nn_pruning/tests/test_convert.py b/nn_pruning/tests/test_convert.py
new file mode 100644
index 00000000..c476b5a0
--- /dev/null
+++ b/nn_pruning/tests/test_convert.py
@@ -0,0 +1,19 @@
+import shutil
+import unittest
+from pathlib import Path
+from unittest import TestCase
+
+import nn_pruning.examples.question_answering.run_qa_sparse as run_qa_sparse
+
+
+class TestFun(TestCase):
+    def test_base(self):
+        checkpoint_path = Path("/home/lagunas/tmp/checkpoint-5")
+        dest_path = checkpoint_path.parent / (checkpoint_path.name + "-compiled")
+        shutil.rmtree(dest_path, ignore_errors=True)
+
+        model = run_qa_sparse.QASparseTraining.compile_model(checkpoint_path, str(dest_path))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..291558c9
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,3 @@
+[tool.black]
+line-length = 119
+target-version = ['py35']
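
Usage sketch (not part of the patch): the intent of the refactor is that a task-specific Trainer mixes in the new SparseTrainer, while a ModelPatchingCoordinator patches the model and drives the sparsity schedule. The snippet below is a minimal sketch of that wiring; the SparseTaskTrainer and build_sparse_trainer names, the hyper-parameter values and the device/cache_dir handling are illustrative assumptions rather than APIs defined here, and it relies on the BERT patch_model implementation kept at the end of patch_coordinator.py.

# Sketch only -- SparseTaskTrainer and build_sparse_trainer are hypothetical names.
from transformers import Trainer, TrainingArguments

from nn_pruning.modules.patch_coordinator import ModelPatchingCoordinator, SparseTrainingArguments
from nn_pruning.modules.sparse_trainer import SparseTrainer


class SparseTaskTrainer(SparseTrainer, Trainer):
    # SparseTrainer is listed first so its training_step/compute_loss/log/optimizer
    # overrides take precedence over the plain Trainer ones.
    def __init__(self, sparse_args, *args, **kwargs):
        Trainer.__init__(self, *args, **kwargs)
        SparseTrainer.__init__(self, sparse_args)


def build_sparse_trainer(model, train_dataset, eval_dataset, output_dir="output", device="cuda", cache_dir=None):
    # Pruning hyper-parameters; anything not set keeps the SparseTrainingArguments default.
    sparse_args = SparseTrainingArguments(
        dense_pruning_method="topk",
        attention_pruning_method="topk",
        regularization="disabled",
    )
    # Regular HF training arguments; warmup_steps is also reused to drive the threshold schedule.
    training_args = TrainingArguments(output_dir=output_dir, learning_rate=3e-5, warmup_steps=120)

    # Patch the (BERT-like) model with maskable linear layers; None stands for the hyper-parameter-search trial.
    coordinator = ModelPatchingCoordinator(sparse_args, device, cache_dir)
    coordinator.patch_model(model, None)

    trainer = SparseTaskTrainer(
        sparse_args,
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.set_patch_coordinator(coordinator)
    return trainer

Other network structures are meant to be supported by subclassing ModelPatchingCoordinator and overriding patch_model, which is why the base method is stubbed with NotImplementedError.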