Skip to content

Commit

Permalink
Refactor to extend easily to other networks and tasks.
Browse files Browse the repository at this point in the history
  • Loading branch information
madlag committed Dec 24, 2020
1 parent 7a12870 commit 9b26e00
Show file tree
Hide file tree
Showing 16 changed files with 1,507 additions and 288 deletions.
4 changes: 4 additions & 0 deletions .isort.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[settings]
# Mode 3 = "vertical hanging indent" — the multi-line import style black emits.
multi_line_output=3
# Trailing commas keep one-name-per-line diffs minimal and match black.
include_trailing_comma=True

3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Auto-format the whole repo: run `make style` before committing.
style:
	black .
	isort .
83 changes: 6 additions & 77 deletions nn_pruning/examples/question_answering/qa_sparse_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,83 +17,12 @@
"""
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.

from typing import Dict
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
from .qa_sparse_patch import QASparseModelPatchingCoordinator
from nn_pruning.modules.sparse_trainer import SparseTrainer
from .qa_train import QATrainer


class QASparseTrainer(QATrainer):
# SparseTrainer should appear first in the base classes, as its functions must override QATrainer and its base classes (Trainer)
class QASparseTrainer(SparseTrainer, QATrainer):
def __init__(self, sparse_args, *args, **kwargs):
super().__init__(*args, **kwargs)

self.sparse_args = sparse_args

def set_patch_coordinator(self, patch_coordinator: QASparseModelPatchingCoordinator):
self.patch_coordinator = patch_coordinator

def log(self, logs: Dict[str, float]) -> None:
    """Merge the patch coordinator's metrics (prefixed with the current
    phase prefix, "" for train / "eval_" for eval) into *logs*, then
    delegate to the parent trainer's log()."""
    coordinator_metrics = self.patch_coordinator.log()
    for key, value in coordinator_metrics.items():
        logs[self.log_prefix + key] = value
    return super().log(logs)

def schedule_threshold(self, training: bool):
    """Advance the coordinator's pruning-threshold schedule to the
    trainer's current global step."""
    current_step = self.state.global_step
    self.patch_coordinator.schedule_threshold(
        current_step,
        self.state.max_steps,
        self.args.warmup_steps,
        training,
    )

def training_step(self, *args, **kwargs):
    """One optimization step: select the train log prefix, move the
    threshold schedule forward, then run the parent training step."""
    # The two setup statements are independent of each other.
    self.log_prefix = ""
    self.schedule_threshold(True)
    return super().training_step(*args, **kwargs)

def compute_loss(self, model, inputs):
    """
    How the loss is computed by Trainer. By default, all models return the loss in the first element.
    Subclass and override for custom behavior.

    Here the base task loss is combined with the distillation loss and
    the sparsity regularization term supplied by the patch coordinator.
    """
    outputs = model(**inputs)

    # Save past state if it exists
    # TODO: this needs to be fixed and made cleaner later.
    if self.args.past_index >= 0:
        self._past = outputs[self.args.past_index]

    # We don't use .loss here since the model may return tuples instead of ModelOutput.
    if isinstance(outputs, dict):
        task_loss = outputs["loss"]
    else:
        task_loss = outputs[0]

    combined = self.patch_coordinator.distil_loss_combine(task_loss, inputs, outputs)
    return combined + self.patch_coordinator.regularization_loss(model)

def evaluate(self, *args, **kwargs):
    """Run evaluation with the threshold schedule in eval mode and
    coordinator metrics logged under the "eval_" prefix."""
    self.schedule_threshold(False)
    self.log_prefix = "eval_"
    return super().evaluate(*args, **kwargs)

def create_optimizer_and_scheduler(self, num_training_steps: int):
    """Build the sparse-aware optimizer and LR scheduler exactly once.

    The optimizer must be created first — create_scheduler reads
    self.optimizer. The asserts guard against double initialization.
    """
    assert self.optimizer is None
    self.optimizer = self.create_optimizer(self.model)

    assert self.lr_scheduler is None
    self.lr_scheduler = self.create_scheduler(num_training_steps)

def create_optimizer(self, model):
    """Return an AdamW optimizer over the parameter groups the patch
    coordinator builds (separate groups, e.g. for mask scores)."""
    parameter_groups = self.patch_coordinator.create_optimizer_groups(
        model, self.args, self.sparse_args
    )
    return AdamW(
        parameter_groups,
        lr=self.args.learning_rate,
        eps=self.args.adam_epsilon,
    )

def create_scheduler(self, num_training_steps: int):
    """Return a linear warmup-then-decay schedule wrapping self.optimizer.

    Must be called after create_optimizer (reads self.optimizer).
    """
    return get_linear_schedule_with_warmup(
        self.optimizer,
        num_warmup_steps=self.args.warmup_steps,
        num_training_steps=num_training_steps,
    )
print(QASparseTrainer.__mro__)
QATrainer.__init__(self, *args, **kwargs)
SparseTrainer.__init__(self, sparse_args)
115 changes: 3 additions & 112 deletions nn_pruning/examples/question_answering/qa_sparse_xp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@

import json
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from types import SimpleNamespace

from nn_pruning.hp_naming import TrialShortNamer
from .qa_sparse_patch import QASparseModelPatchingCoordinator
from nn_pruning.modules.patch_coordinator import SparseTrainingArguments, ModelPatchingCoordinator
from .qa_sparse_train import QASparseTrainer
from .qa_xp import (
QAXP,
Expand All @@ -35,114 +34,6 @@
)


@dataclass
class SparseTrainingArguments:
    """
    Sparse training specific arguments

    Hyper-parameters controlling movement-pruning-style sparsification:
    mask learning rate, pruning methods per layer type, block shapes,
    threshold/temperature schedules, regularization, and distillation.
    """

    mask_scores_learning_rate: float = field(
        default=1e-2, metadata={"help": "The initial learning rate for mask_scores."}
    )

    dense_pruning_method: str = field(default="topk", metadata={"help": "Dense Layers pruning method."})

    # Fixed: help text was a copy-paste of the dense-layer wording.
    attention_pruning_method: str = field(default="topk", metadata={"help": "Attention pruning method."})

    ampere_pruning_method: str = field(
        default="disabled",
        metadata={"help": "Ampere sparse method ('disabled' for no ampere sparsity)"},
    )

    mask_init: str = field(default="constant", metadata={"help": "Mask scores initialization method"})

    mask_scale: float = field(
        default=0.0,
        metadata={"help": "Parameter to use with mask_init."},
    )

    dense_block_rows: int = field(
        default=1,
        metadata={"help": "Block size in rows for dense layers."},
    )

    dense_block_cols: int = field(
        default=1,
        metadata={"help": "Block size in cols for dense layers."},
    )

    attention_block_rows: int = field(
        default=1,
        metadata={"help": "Block size in rows for attention."},
    )

    attention_block_cols: int = field(
        default=1,
        metadata={"help": "Block size in cols for attention."},
    )

    # Threshold schedule: moves from initial_threshold to final_threshold
    # during training (see the coordinator's schedule_threshold).
    initial_threshold: float = field(
        default=1.0,
        metadata={"help": "Initial value of the threshold (for scheduling)."},
    )
    final_threshold: float = field(
        default=0.5,
        metadata={"help": "Final value of the threshold (for scheduling)."},
    )

    initial_warmup: float = field(
        default=1,
        metadata={
            "help": "Run `initial_warmup` * `warmup_steps` steps of threshold warmup during which threshold stays at its `initial_threshold` value (sparsity schedule)."
        },
    )
    final_warmup: float = field(
        default=2,
        metadata={
            "help": "Run `final_warmup` * `warmup_steps` steps of threshold cool-down during which threshold stays"
        },
    )

    initial_ampere_temperature: float = field(
        default=0.0,
        metadata={"help": "Initial value of the ampere temperature (for scheduling)."},
    )
    final_ampere_temperature: float = field(
        default=20.0,
        metadata={"help": "Final value of the ampere temperature (for scheduling)."},
    )

    regularization: str = field(
        default="disabled",
        metadata={"help": "Add L0 or L1 regularization to the mask scores."},
    )

    regularization_final_lambda: float = field(
        default=0.0,
        metadata={"help": "Regularization intensity (used in conjunction with `regularization`)."},
    )

    # NOTE(review): default None does not match the `str` annotation —
    # logically Optional[str]; annotation kept unchanged for compatibility.
    distil_teacher_name_or_path: str = field(
        default=None,
        metadata={"help": "Path to the already SQuAD fine-tuned teacher model. Only for distillation."},
    )

    distil_alpha_ce: float = field(
        default=0.5,
        metadata={"help": "Cross entropy loss linear weight. Only for distillation."},
    )

    distil_alpha_teacher: float = field(
        default=0.5,
        metadata={"help": "Distillation loss linear weight. Only for distillation."},
    )

    distil_temperature: float = field(
        default=2.0,
        metadata={"help": "Distillation temperature. Only for distillation."},
    )


class SparseQAShortNamer(TrialShortNamer):
DEFAULTS = {
"adam_beta1": 0.9,
Expand Down Expand Up @@ -232,7 +123,7 @@ class QASparseXP(QAXP):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.patch_coordinator = QASparseModelPatchingCoordinator(self.sparse_args, self.training_args.device, self.model_args.cache_dir)
self.patch_coordinator = ModelPatchingCoordinator(self.sparse_args, self.training_args.device, self.model_args.cache_dir)

def create_trainer(self, *args, **kwargs):
super().create_trainer(*args, **kwargs)
Expand Down Expand Up @@ -264,7 +155,7 @@ def load_json_to_obj(name):
model_args.model_name_or_path = str(src_path)

model = cls._model_init(model_args, model_config, trial=None)
patcher = QASparseModelPatcher(sparse_args)
patcher = ModelPatchingCoordinator(sparse_args)
patcher.patch_model(model, trial=None)
import torch

Expand Down
8 changes: 4 additions & 4 deletions nn_pruning/examples/question_answering/qa_xp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from dataclasses import dataclass, field
from pathlib import Path

from datasets import load_dataset, load_metric
from datasets import load_metric
from transformers import (
AutoModelForQuestionAnswering,
DataCollatorWithPadding,
Expand All @@ -34,7 +34,7 @@

from .qa_train import QATrainer
from .qa_utils import postprocess_qa_predictions
from .xp import XP, DataTrainingArguments, ModelArguments, TrainingArguments
from nn_pruning.examples.xp import XP, DataTrainingArguments, ModelArguments, TrainingArguments

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -80,7 +80,7 @@ class QAXP(XP):
SHORT_NAMER = TrialShortNamer

@classmethod
def _model_init(self, model_args, model_config, trial=None):
def _model_init(self, model_args, model_config):
model = AutoModelForQuestionAnswering.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
Expand All @@ -90,7 +90,7 @@ def _model_init(self, model_args, model_config, trial=None):
return model

def model_init(self, trial=None):
return self._model_init(self.model_args, self.config, trial)
return self._model_init(self.model_args, self.config)

def prepare_column_names(self):
training_args = self.training_args
Expand Down
Loading

0 comments on commit 9b26e00

Please sign in to comment.