"""Functions for loading and interacting with pretrained language models.
This module is designed to house all the annoying branching logic
that comes with supporting analysis of many slightly different model
implementations.
"""
import argparse
import logging
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Iterator, Literal, Optional, Sequence, overload
from remedi.utils import tokenizer_utils
from remedi.utils.typing import Device, Model, Tokenizer
import torch
import transformers
logger = logging.getLogger(__name__)
GPT_J_NAME_SHORT = "gptj" # A useful alias for the CLI.
GPT_J_NAME = "EleutherAI/gpt-j-6B"
GPT_NEO_X_NAME_SHORT = "neox"
GPT_NEO_X_NAME = "EleutherAI/gpt-neox-20b"


@dataclass(frozen=True)
class ModelAndTokenizer:
    """A pretrained model and its tokenizer."""

    model: Model
    tokenizer: Tokenizer

    def to_(self, device: Optional[Device]) -> None:
        """Send model to the device."""
        self.model.to(device)

    def eval_(self) -> None:
        """Set model to eval mode."""
        self.model.eval()


def unwrap_model(value: Model | ModelAndTokenizer) -> Model:
    """Unwrap the model if necessary."""
    if isinstance(value, ModelAndTokenizer):
        return value.model
    return value


def unwrap_tokenizer(tokenizer: ModelAndTokenizer | Tokenizer) -> Tokenizer:
    """Unwrap the tokenizer."""
    if isinstance(tokenizer, ModelAndTokenizer):
        return tokenizer.tokenizer
    return tokenizer


def determine_layers(model: ModelAndTokenizer | Model) -> tuple[int, ...]:
    """Return the indices of all hidden layers in the given model."""
    model = unwrap_model(model)
    assert isinstance(model, Model)

    if isinstance(model, transformers.GPTNeoXForCausalLM):
        n_layer = model.config.num_hidden_layers
    else:
        n_layer = model.config.n_layer

    return (*range(n_layer),)


@overload
def determine_layer_paths(
    model: ModelAndTokenizer | Model,
    layers: Optional[Sequence[int]] = ...,
    *,
    return_dict: Literal[False] = ...,
) -> Sequence[str]:
    """Determine layer path for each layer."""
    ...


@overload
def determine_layer_paths(
    model: ModelAndTokenizer | Model,
    layers: Optional[Sequence[int]] = ...,
    *,
    return_dict: Literal[True],
) -> dict[int, str]:
    """Determine mapping from layer to layer path."""
    ...


def determine_layer_paths(
    model: ModelAndTokenizer | Model,
    layers: Optional[Sequence[int]] = None,
    *,
    return_dict: bool = False,
) -> Sequence[str] | dict[int, str]:
    """Determine the absolute paths to the given layers in the model.

    Args:
        model: The model.
        layers: The specific layer (numbers) to look at. Defaults to all of them.
        return_dict: If True, return a mapping from layer to layer path,
            otherwise just return the list of layer paths in the same order
            as `layers`.

    Returns:
        Sequence of layer paths in the same order as `layers`, or a mapping
        from layer number to layer path if `return_dict` is True.
    """
    model = unwrap_model(model)

    if layers is None:
        layers = determine_layers(model)

    assert isinstance(model, Model), type(model)

    layer_paths = {}
    for layer in layers:
        if isinstance(model, transformers.GPTNeoXForCausalLM):
            layer_path = f"gpt_neox.layers.{layer}"
        else:
            layer_path = f"transformer.h.{layer}"
        layer_paths[layer] = layer_path

    return layer_paths if return_dict else tuple(layer_paths[la] for la in layers)
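
# Illustrative usage (a sketch; assumes `mt` is a loaded ModelAndTokenizer):
#
#     paths = determine_layer_paths(mt, layers=[0, 1])
#     # e.g. ("transformer.h.0", "transformer.h.1") for GPT-2/GPT-J style models,
#     # or ("gpt_neox.layers.0", "gpt_neox.layers.1") for GPT-NeoX.
#     by_layer = determine_layer_paths(mt, return_dict=True)
#     # e.g. {0: "transformer.h.0", 1: "transformer.h.1"}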


def determine_hidden_size(model: ModelAndTokenizer | Model) -> int:
    """Determine hidden rep size for the model."""
    model = unwrap_model(model)
    return model.config.hidden_size


def determine_device(model: ModelAndTokenizer | Model) -> torch.device | None:
    """Determine device model is running on."""
    parameter = any_parameter(model)
    return parameter.device if parameter is not None else None


def determine_dtype(model: ModelAndTokenizer | Model) -> torch.dtype | None:
    """Determine dtype of model."""
    parameter = any_parameter(model)
    return parameter.dtype if parameter is not None else None


def any_parameter(model: ModelAndTokenizer | Model) -> torch.nn.Parameter | None:
    """Get any example parameter for the model."""
    model = unwrap_model(model)
    return next(iter(model.parameters()), None)
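
# Illustrative usage (a sketch; assumes `mt` is a loaded ModelAndTokenizer):
#
#     device = determine_device(mt)            # e.g. device(type="cuda", index=0)
#     dtype = determine_dtype(mt)              # e.g. torch.float16
#     hidden_size = determine_hidden_size(mt)  # e.g. 4096 for GPT-J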


@contextmanager
def set_padding_side(
    tokenizer: Tokenizer | ModelAndTokenizer, padding_side: str = "right"
) -> Iterator[None]:
    """Wrap `tokenizer_utils.set_padding_side`."""
    tokenizer = unwrap_tokenizer(tokenizer)
    with tokenizer_utils.set_padding_side(tokenizer, padding_side=padding_side):
        yield
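
# Illustrative usage (a sketch; assumes `mt` is a loaded ModelAndTokenizer; left
# padding is what decoder-only models typically need for batched generation):
#
#     with set_padding_side(mt, padding_side="left"):
#         batch = mt.tokenizer(
#             ["first prompt", "second"], padding=True, return_tensors="pt"
#         )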


def map_to(
    orig: Any, device: Device | None = None, dtype: torch.dtype | None = None
) -> Any:
    """Map all tensors in the given value to the device and/or dtype.

    Args:
        orig: Any sequence of or mapping to tensors, or just a tensor.
        device: Device to send to.
        dtype: Dtype to cast tensors to.

    Returns:
        Same value, but with all tensors moved to the device and cast to the dtype.
    """
    if device is None and dtype is None:
        return orig

    result = orig
    if isinstance(orig, torch.Tensor):
        result = orig.to(device=device, dtype=dtype)
    elif isinstance(orig, dict):
        result = {
            key: map_to(value, device=device, dtype=dtype)
            for key, value in orig.items()
        }
    elif isinstance(orig, (list, tuple)):
        result = orig.__class__(
            map_to(value, device=device, dtype=dtype) for value in orig
        )
    assert isinstance(result, orig.__class__), f"{type(result)}/{type(orig)}"
    return result
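
# Illustrative usage (a sketch; assumes `mt` is a loaded ModelAndTokenizer and
# `batch` is a hypothetical dict of tensors):
#
#     batch = {"input_ids": torch.tensor([[1, 2, 3]])}
#     batch = map_to(batch, device=determine_device(mt))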


def load_model(
    name: str, device: Optional[Device] = None, fp16: Optional[bool] = None
) -> ModelAndTokenizer:
    """Load the model given its string name.

    Args:
        name: Name of the model.
        device: If set, send model to this device. Defaults to CPU.
        fp16: Whether to use half precision. If not set, depends on model.

    Returns:
        ModelAndTokenizer: Loaded model and its tokenizer.
    """
    if name == GPT_J_NAME_SHORT:
        name = GPT_J_NAME
    elif name == GPT_NEO_X_NAME_SHORT:
        name = GPT_NEO_X_NAME

    # I usually save randomly initialized variants under the short name of the
    # corresponding real model (e.g. gptj_random, neox_random), so check here
    # if we are dealing with *any* variant of the big model.
    is_gpt_j_variant = name == GPT_J_NAME or GPT_J_NAME_SHORT in name
    is_neo_x_variant = name == GPT_NEO_X_NAME or GPT_NEO_X_NAME_SHORT in name

    if fp16 is None:
        fp16 = is_gpt_j_variant or is_neo_x_variant

    torch_dtype = torch.float16 if fp16 else None

    kwargs: dict = dict(torch_dtype=torch_dtype)
    if is_gpt_j_variant:
        kwargs["low_cpu_mem_usage"] = True
        if fp16:
            kwargs["revision"] = "float16"

    logger.info(f"loading {name} (device={device}, fp16={fp16})")

    model = transformers.AutoModelForCausalLM.from_pretrained(name, **kwargs)
    if is_neo_x_variant:
        model.to(torch_dtype)
    model.to(device).eval()

    tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    tokenizer.pad_token = tokenizer.eos_token

    return ModelAndTokenizer(model, tokenizer)
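
# Illustrative usage (a sketch; assumes the named checkpoint can be fetched from
# the Hugging Face hub or a local cache):
#
#     mt = load_model("gptj", device="cuda", fp16=True)
#     inputs = mt.tokenizer("The capital of France is", return_tensors="pt")
#     inputs = map_to(dict(inputs), device=determine_device(mt))
#     outputs = mt.model(**inputs)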


def add_model_args(parser: argparse.ArgumentParser) -> None:
    """Add args needed to load a model.

    The args include:
        --model: The language model to load, defaulting to GPT-J.
        --device: The device to send model and inputs to.
        --fp16: Whether to use the half precision version of the model.
            Note this is used as `--fp16 False` since the default value depends
            on which model we are loading.
    """
    parser.add_argument(
        "--model",
        "-m",
        default=GPT_J_NAME_SHORT,
        help="model to edit",
    )
    parser.add_argument("--device", help="device to send model and inputs to")
    parser.add_argument(
        "--fp16",
        # A bare `type=bool` would treat any non-empty string (including "False")
        # as True, so parse the flag value explicitly.
        type=lambda value: value.strip().lower() in ("1", "true", "yes"),
        default=None,
        help="set whether to use fp16",
    )