Adding summarization tasks for model upload.
echarlaix committed Jul 17, 2021
1 parent 6b005ca commit ce21d56
Showing 7 changed files with 103 additions and 42 deletions.
2 changes: 1 addition & 1 deletion analysis/analyze_run.py
@@ -27,7 +27,7 @@ def __init__(self, path, output_name, task, copy_to_tmp_path = False):
def open_model(self):
if self.copy_to_tmp_path:
self.dest_path_ = tempfile.TemporaryDirectory()
self.dest_path = Path(self.dest_path_.name)
self.dest_path = self.dest_path_.name
else:
self.dest_path = None

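For context, a minimal standalone sketch of the tempfile behaviour this hunk touches (variable names are illustrative): TemporaryDirectory().name is a plain str, so whether it is wrapped in pathlib.Path decides whether later code gets path semantics such as the / operator.

import tempfile
from pathlib import Path

tmp = tempfile.TemporaryDirectory()
print(type(tmp.name))        # <class 'str'>
print(type(Path(tmp.name)))  # <class 'pathlib.PosixPath'> on Linux/macOS
tmp.cleanup()                # the directory exists until cleanup() or garbage collection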
61 changes: 48 additions & 13 deletions analysis/create_model.py
@@ -3,10 +3,12 @@
import sh
import shutil
from transformers import (
BertForQuestionAnswering,
BertForSequenceClassification,
TFBertForQuestionAnswering,
TFBertForSequenceClassification
AutoModelForQuestionAnswering,
AutoModelForSequenceClassification,
AutoModelForSeq2SeqLM,
TFAutoModelForQuestionAnswering,
TFAutoModelForSequenceClassification,
TFAutoModelForSeq2SeqLM
)
from tempfile import TemporaryDirectory
from nn_pruning.inference_model_patcher import optimize_model
@@ -23,6 +25,7 @@

from examples.question_answering.qa_sparse_xp import QASparseXP
from examples.text_classification.glue_sparse_xp import GlueSparseXP
from examples.seq2seq.summarization_sparse_xp import SummarizationSparseXP

def pretty_json(p):
return json.dumps(p, sort_keys=True, indent=indent, separators = [", ", ": "])
@@ -61,6 +64,7 @@ class Packager:
EVAL_DIR = "eval"
QA_TASKS = {"squadv1", "squadv2"}
GLUE_TASKS = {"mnli", "qqp", "sst2"}
SUMMARIZATION_TASKS = {"cnn_dailymail"}
TASK_INFO = {
"squadv1": {
"main_metric": "f1",
@@ -111,10 +115,22 @@ class Packager:
"speed_report.json"
],
},
"cnn_dailymail": {
"main_metric": "eval_rouge2",
"base_value": 30.13,
"base_model": "facebook/bart-large-cnn",
"eval_files": [
"eval_results.json",
"evaluate_timing.json",
"sparsity_report.json",
"speed_report.json"
],
},
}
METRIC_INFO = {
"eval_accuracy": {"name": "accuracy", "title_name": "acc"},
"f1": {"name": "F1", "title_name": "f"},
"eval_rouge2": {"name": "R2", "title_name": "r"},
}

def __init__(self,
@@ -137,7 +153,7 @@ def __init__(self,
def build_model_name_(cls, base_name, task, speedup, metric_name, metric_value, linear_sparsity, kind, is_ampere, version):

density = int(100 - linear_sparsity)

base_name = base_name.split("/")[-1]
name = f"{base_name}-{task}-x{speedup:.2f}-{metric_name}{metric_value:.1f}-d{density}-{kind}"
if is_ampere:
name += "-ampere"
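For illustration only, a minimal sketch of how the new cnn_dailymail entry could flow through this naming scheme; the speedup, metric value, sparsity and kind below are made-up placeholders, and which METRIC_INFO field is passed as metric_name is an assumption.

# Hypothetical inputs -- not measurements from this commit.
base_model = "facebook/bart-large-cnn"   # TASK_INFO["cnn_dailymail"]["base_model"]
metric_name = "r"                        # assuming METRIC_INFO["eval_rouge2"]["title_name"]
speedup, metric_value, linear_sparsity, kind = 2.0, 28.5, 65, "pruned"

density = int(100 - linear_sparsity)
base_name = base_model.split("/")[-1]
name = f"{base_name}-cnn_dailymail-x{speedup:.2f}-{metric_name}{metric_value:.1f}-d{density}-{kind}"
print(name)  # bart-large-cnn-cnn_dailymail-x2.00-r28.5-d35-pruned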
@@ -239,6 +255,8 @@ def copy_model_files(self, force = False):
model = QASparseXP.compile_model(src_path, dest_path=d.name)
elif self.task in self.GLUE_TASKS:
model = GlueSparseXP.compile_model(src_path, dest_path=d.name)
elif self.task in self.SUMMARIZATION_TASKS:
model = SummarizationSparseXP.compile_model(src_path, dest_path=d.name)
else:
raise Exception(f"Unknown task {self.task}")

@@ -249,10 +267,13 @@
with TemporaryDirectory() as d2:
if self.task in self.QA_TASKS:
QASparseXP.final_fine_tune_bertarize(src_path, d2, remove_head_pruning=True)
tf_model = TFBertForQuestionAnswering.from_pretrained(d2, from_pt=True)
tf_model = TFAutoModelForQuestionAnswering.from_pretrained(d2, from_pt=True)
elif self.task in self.GLUE_TASKS:
GlueSparseXP.final_fine_tune_bertarize(src_path, d2, remove_head_pruning=True)
tf_model = TFBertForSequenceClassification.from_pretrained(d2, from_pt=True)
tf_model = TFAutoModelForSequenceClassification.from_pretrained(d2, from_pt=True)
elif self.task in self.SUMMARIZATION_TASKS:
SummarizationSparseXP.final_fine_tune_bertarize(src_path, d2, remove_head_pruning=True)
tf_model = TFAutoModelForSeq2SeqLM.from_pretrained(d2, from_pt=True)
else:
raise Exception(f"Unknown task {self.task}")
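As a minimal sketch of what the new summarization branch amounts to (the checkpoint path is a placeholder), the pruned PyTorch seq2seq checkpoint is loaded into the TensorFlow class and re-saved:

from transformers import TFAutoModelForSeq2SeqLM

# "path/to/pruned-bart-cnn" is a hypothetical local checkpoint directory.
tf_model = TFAutoModelForSeq2SeqLM.from_pretrained("path/to/pruned-bart-cnn", from_pt=True)
tf_model.save_pretrained("path/to/pruned-bart-cnn")  # writes tf_model.h5 alongside pytorch_model.bin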

@@ -261,9 +282,11 @@

if force or not (self.git_path / "pytorch_model.bin").exists():
if self.task in self.QA_TASKS:
model = BertForQuestionAnswering.from_pretrained(src_path)
model = AutoModelForQuestionAnswering.from_pretrained(src_path)
elif self.task in self.GLUE_TASKS:
model = BertForSequenceClassification.from_pretrained(src_path)
model = AutoModelForSequenceClassification.from_pretrained(src_path)
elif self.task in self.SUMMARIZATION_TASKS:
model = AutoModelForSeq2SeqLM.from_pretrained(src_path)
else:
raise Exception(f"Unknown task {self.task}")
model.save_pretrained(self.git_path)
@@ -305,9 +328,13 @@ def create_graphics(self, url_base, model_card_path):
density_plotter = DensityBokehPlotter("density", self.JS_PATH)

if self.task in self.QA_TASKS:
model = BertForQuestionAnswering.from_pretrained(self.git_path)
model = AutoModelForQuestionAnswering.from_pretrained(self.git_path)
elif self.task in self.GLUE_TASKS:
model = BertForSequenceClassification.from_pretrained(self.git_path)
model = AutoModelForSequenceClassification.from_pretrained(self.git_path)
elif self.task in self.SUMMARIZATION_TASKS:
model = AutoModelForSeq2SeqLM.from_pretrained(self.git_path)
else:
raise Exception(f"Unknown task {self.task}")

fig, js, html = density_plotter.run(model=model,
dest_path=model_card_path / "images",
@@ -342,8 +369,16 @@ def create_readme(self):
template.undefined = jinja2.StrictUndefined

config = checkpoint_info["config"]
pruned_heads = sum([len(x) for x in config["pruned_heads"].values()])
total_heads = config["num_hidden_layers"] * config["num_attention_heads"]
pruned_heads = config.get("pruned_heads", {})
pruned_heads = sum([len(x) for x in pruned_heads.values()])
total_heads = config.get("heads_count")
if total_heads is None:
model_structure = name2struct.get(config["model_type"])
num_hidden_layers = config.get(model_structure.NAME_CONFIG["num_hidden_layers"])
num_attention_heads = config.get(model_structure.NAME_CONFIG["num_attention_heads"])
assert num_hidden_layers is not None
assert num_attention_heads is not None
total_heads = num_hidden_layers * num_attention_heads

sparsity_report = dict(linear_density = self.density,
total_density = self.total_density,
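A small self-contained sketch of the head accounting above, using a toy BERT-style config dict; the numbers are illustrative, and the direct key lookups stand in for the NAME_CONFIG indirection.

config = {
    "num_hidden_layers": 12,
    "num_attention_heads": 12,
    "pruned_heads": {"0": [1, 5], "3": [0, 2, 7]},
}
pruned_heads = sum(len(v) for v in config.get("pruned_heads", {}).values())  # 5
total_heads = config["num_hidden_layers"] * config["num_attention_heads"]    # 144
print(f"{pruned_heads}/{total_heads} attention heads pruned")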
22 changes: 16 additions & 6 deletions analysis/model_card_graphics.py
@@ -180,8 +180,9 @@ def add_info(self, traces_, kind, **kwargs):
def layer_short_name(self, name):
shortname = name
shortname = shortname.split(".")
shortname = shortname[3:]
shortname = shortname[2:]
shortname = ".".join(shortname)
shortname = ("decoder." if self.model_structure.is_decoder(name) else "encoder.") + shortname
shortname = self.replacements_apply(shortname, [".self", ".weight", ".dense", (".", ".")])

return shortname
@@ -207,7 +208,6 @@ def create_fig(self,
self.full_color = full_color
self.empty_color = empty_color
self.url_base = url_base
self.attention_size = model.config.hidden_size

self.model_structure = struct_from_config(model.config_class)
self.attention_size = getattr(model.config, self.model_structure.NAME_CONFIG["hidden_size"])
@@ -217,8 +217,9 @@

traces = {}
part_index = 0
max_layer_encoder = -1
linear_per_layer = len(self.model_structure.LAYER_PATTERNS)
positions = []
len_layer_pattern = len(self.model_structure.LAYER_PATTERNS)
for layer in self.layers:
name = layer["name"]
density = layer["density"] * 100
@@ -230,17 +231,17 @@
if v in name:
increment = 1
if k in self.model_structure.ATTENTION_LAYERS:
kind = k
kind = k.replace('encoder_decoder_', '')
else:
kind = "fully connected"
break

if kind is None:
print(name)
assert (False)

shortname = self.layer_short_name(name)

linear_per_layer = 6 if not self.model_structure.is_decoder(name) else len_layer_pattern
x = part_index / linear_per_layer + 1 / (linear_per_layer * 2)
url = f"{self.url_base}/{layer['filename']}"
img_height = str(int(layer["size"][0] / 8)) + "px"
@@ -258,7 +259,16 @@
if x not in positions:
positions.append(x)

part_index += increment
layer_number = self.model_structure.layer_index(name)
is_decoder = self.model_structure.is_decoder(name)
if not is_decoder:
max_layer_encoder = max(max_layer_encoder, layer_number)
else:
layer_number += max_layer_encoder + 1
if self.model_structure.is_ffn(name) and self.model_structure.LAYER_PATTERNS["output_dense"] in name:
part_index = (layer_number + 1 ) * linear_per_layer
else:
part_index += increment

colors = ["#6573f7", "#ed5642", "#20cb97", "#aa69f7"]

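A minimal sketch of the encoder/decoder index offsetting introduced above (the layer counts are illustrative): decoder layers are shifted past the last encoder layer so that both stacks occupy distinct positions along the x axis of the density plot.

def global_layer_index(layer_number, is_decoder, max_layer_encoder):
    # Decoder layer i is placed after every encoder layer.
    return layer_number + max_layer_encoder + 1 if is_decoder else layer_number

max_layer_encoder = 11  # e.g. a 12-layer encoder
print(global_layer_index(0, False, max_layer_encoder))  # 0  -> encoder layer 0
print(global_layer_index(0, True, max_layer_encoder))   # 12 -> decoder layer 0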
2 changes: 1 addition & 1 deletion examples/seq2seq/seq2seq_train.py
@@ -67,7 +67,7 @@ def checkpoint_dir(self):
def instrument_model(self, model):
if self.args.optimize_model_before_eval != "disabled":
model = optimize_model(self.model, self.args.optimize_model_before_eval)
return TimingModule(model, method_list=["generate"])
return TimingModule(model, method_list=["generate", "config"])

def run_dir(self):
return Path(self.args.output_dir)
48 changes: 29 additions & 19 deletions nn_pruning/inference_model_patcher.py
@@ -3,52 +3,62 @@
from transformers import BertConfig
from .model_patcher import ModelPatcher
from .model_structure import struct_from_config, count_num_heads

from collections import defaultdict

class BertHeadsPruner:
def __init__(self, model):
self.model = model
self.model_structure = struct_from_config(model.config_class)
attention_layers = [self.model_structure.LAYER_PATTERNS[i] for i in self.model_structure.ATTENTION_LAYERS]
if hasattr(self.model_structure, "ATTENTION_PREFIX"):
ATTENTION_PREFIX = self.model_structure.ATTENTION_PREFIX
self.attention_prefix = ATTENTION_PREFIX[0] if isinstance(ATTENTION_PREFIX, tuple) else ATTENTION_PREFIX
else:
self.attention_prefix = ".".join(attention_layers[0].split('.')[:-1])
self.attention_layers = [attention_layer[len(self.attention_prefix)+1:] for attention_layer in attention_layers]
self.attention_head_size = self.model_structure.NAME_CONFIG["attention_head_size"]
self.num_attention_heads = self.model_structure.NAME_CONFIG["num_attention_heads"]

def analyze_head(self, p, head_size):
p0 = (p != 0).reshape(p.shape[0] // head_size, head_size, p.shape[1]).any(-1).any(-1)
return p0

def get_pruned_heads(self):
heads_count = 0
to_prune = {}
to_prune = defaultdict(list)
max_layer_encoder = -1
layer_number = -1
shift = 0
for name, module in self.model.named_modules():
if name.endswith(self.attention_prefix):
if hasattr(module, self.attention_head_size):
head_size = getattr(module, self.attention_head_size)
num_heads = getattr(module, self.num_attention_heads)
prev_layer = layer_number
layer_number = int("".join(ch for ch in name if ch.isnumeric()))
is_decoder = self.model_structure.is_decoder(name)
if not is_decoder:
max_layer_encoder = max(max_layer_encoder, layer_number)
else:
layer_number += max_layer_encoder + 1
if prev_layer == layer_number:
shift += num_heads
else:
shift = 0
parts = []
for a in self.attention_layers:
p = self.analyze_head(getattr(module, a).weight, module.attention_head_size)
parts.append(p)
for k, v in module.named_children():
if self.model_structure.is_attention(".".join([name, k]), exclude_att_dense=True):
p = self.analyze_head(v.weight, head_size)
parts.append(p)
parts = list(torch.stack(parts, 0).all(0).cpu().detach().numpy())
heads_count += len(parts)

heads_to_prune = [i for i, p in enumerate(parts) if not p]
heads_to_prune = [i + shift for i, p in enumerate(parts) if not p]

# TEMPORARY : AT LEAST KEEP ONE HEAD
if len(heads_to_prune) == len(parts):
heads_to_prune.remove(0)

to_prune[layer_number] = heads_to_prune
to_prune[layer_number].extend(heads_to_prune)
return to_prune, heads_count

def run(self):
model = self.model
to_prune, heads_count = self.get_pruned_heads()
if isinstance(self.model.config, BertConfig):
to_prune, heads_count = self.get_pruned_heads()
model.prune_heads(to_prune)
return sum([len(p) for p in to_prune.values()]), heads_count
return 0, count_num_heads(model)
return sum([len(p) for p in to_prune.values()]), heads_count


class SparseDimensionsLinear(nn.Module):
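A standalone sketch of the per-head liveness test that analyze_head performs (shapes and values are illustrative): the projection weight is reshaped into one block per head, and a head counts as alive if any of its weights is non-zero.

import torch

head_size, num_heads, hidden = 4, 3, 8
weight = torch.zeros(num_heads * head_size, hidden)
weight[4:8] = 1.0  # only head 1 keeps non-zero weights

alive = (weight != 0).reshape(num_heads, head_size, hidden).any(-1).any(-1)
print(alive.tolist())                                    # [False, True, False]
print([i for i, keep in enumerate(alive) if not keep])   # heads to prune: [0, 2]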
5 changes: 4 additions & 1 deletion nn_pruning/model_structure.py
@@ -73,6 +73,7 @@ class BertStructure(ModelStructure):
intermediate_size="intermediate_size",
num_hidden_layers="num_hidden_layers",
num_attention_heads="num_attention_heads",
attention_head_size="attention_head_size",
)

class BartStructure(ModelStructure):
@@ -96,7 +97,8 @@ class BartStructure(ModelStructure):
hidden_size="d_model",
intermediate_size="encoder_ffn_dim",
num_hidden_layers="num_hidden_layers",
num_attention_heads="encoder_attention_heads",
num_attention_heads="num_heads",
attention_head_size = "head_dim",
)

class T5Structure(ModelStructure):
@@ -121,6 +123,7 @@ class T5Structure(ModelStructure):
intermediate_size="d_ff",
num_hidden_layers="num_layers",
num_attention_heads="num_heads",
attention_head_size="key_value_proj_dim",
)

config2struct = {
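For illustration, a minimal sketch of how the new attention_head_size mapping can be resolved against an attention module; the attribute names come from the diff above, while treating them as per-model module attributes is an assumption.

# Per-model attribute name for the head size, as added in the diff above.
ATTENTION_HEAD_SIZE_ATTR = {
    "bert": "attention_head_size",   # BertSelfAttention
    "bart": "head_dim",              # BartAttention
    "t5": "key_value_proj_dim",      # T5Attention
}

def head_size_of(attention_module, model_type):
    return getattr(attention_module, ATTENTION_HEAD_SIZE_ATTR[model_type])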
5 changes: 4 additions & 1 deletion nn_pruning/sparse_xp.py
@@ -182,7 +182,10 @@ def final_fine_tune_bertarize(cls, src_path, dest_path, remove_head_pruning = Fa
shutil.copytree(src_path, dest_path, dirs_exist_ok=True)

if remove_head_pruning:
del config["pruned_heads"]
try:
del config["pruned_heads"]
except KeyError:
pass
with (dest_path / "config.json").open("w") as f:
json.dump(config, f)

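An equivalent, slightly more compact idiom for the try/except above is dict.pop with a default:

config.pop("pruned_heads", None)  # removes the key if present, does nothing otherwise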
