-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
Copy pathopenai_completions.py
481 lines (402 loc) · 17 KB
/
openai_completions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
import copy
import os
from collections import defaultdict
from importlib.util import find_spec
from typing import List, Literal, Optional, Tuple
from tqdm import tqdm
import lm_eval.models.utils
from lm_eval import utils
from lm_eval.api.model import LM, TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import retry_on_specific_exceptions
from lm_eval.utils import eval_logger
def get_result(response, ctxlen: int) -> Tuple[float, bool]:
    """Process results from OpenAI API response.

    :param response:
        A single completion choice whose ``logprobs`` attribute exposes the
        parallel lists ``tokens``, ``token_logprobs`` and ``top_logprobs``.
    :param ctxlen: int
        Length of context (so we can slice them away and only keep the predictions)
    :return:
        continuation_logprobs: float
            Sum of the log probabilities of the continuation tokens
        is_greedy: bool
            whether argmax matches given continuation exactly
    """
    logprobs = response.logprobs
    continuation_logprobs = sum(logprobs.token_logprobs[ctxlen:])

    is_greedy = True
    # Compare the token *string* that was actually scored with the most likely
    # candidate at the same position. `top_logprobs` is keyed by token string,
    # so the comparison must use `tokens`, not `token_logprobs`.
    for token, top_tokens in zip(
        logprobs.tokens[ctxlen:], logprobs.top_logprobs[ctxlen:]
    ):
        top_token = max(top_tokens, key=top_tokens.get)
        if top_token != token:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy
def oa_completion(client, chat: bool = False, **kwargs):
    """Send one request to an OpenAI-style endpoint, retrying on API errors.

    :param client:
        Client object exposing ``completions.create`` and
        ``chat.completions.create``.
    :param chat: bool
        When True the chat-completions endpoint is called, otherwise the
        plain completions endpoint.
    :param kwargs:
        Forwarded unchanged to the selected ``create`` call.
    :return:
        Whatever the selected ``create`` call returns.
    :raises Exception:
        If the `openai` or `tiktoken` package cannot be found.
    """
    dependencies_present = all(find_spec(pkg) for pkg in ("openai", "tiktoken"))
    if not dependencies_present:
        raise Exception(
            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. "
            "Please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`"
        )

    import openai

    def _print_failure(e: Exception, sleep_time: float) -> None:
        # Dump the traceback of the error that triggered the retry.
        import traceback

        traceback.print_exc()

    @retry_on_specific_exceptions(
        on_exceptions=[openai.OpenAIError],
        max_retries=None,  # retry forever, consider changing
        on_exception_callback=_print_failure,
    )
    def _send_request():
        endpoint = client.chat.completions if chat else client.completions
        return endpoint.create(**kwargs)

    return _send_request()
@register_model("openai-completions", "local-completions")
class OpenaiCompletionsLM(TemplateLM):
    """LM backed by an OpenAI-style text-completions endpoint.

    Requests are sent through ``oa_completion``; prompts are tokenized locally
    with either a tiktoken or a HuggingFace tokenizer.
    """

    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        model: str,
        base_url: Optional[str] = None,
        tokenizer: Optional[str] = None,
        tokenizer_backend: Literal["tiktoken", "huggingface"] = "tiktoken",
        truncate: bool = False,
        max_gen_toks: int = 256,
        batch_size: int = 1,
        seed: int = 1234,
        max_length: Optional[int] = None,
    ) -> None:
        """
        :param model: str
            OpenAI API engine (e.g. gpt-3.5-turbo-instruct)
        :param base_url: Optional[str]
            Custom endpoint URL handed to the OpenAI client; the default
            endpoint is used when empty.
        :param tokenizer: Optional[str]
            HuggingFace tokenizer name; falls back to `model` when empty.
            Only read when `tokenizer_backend` is "huggingface".
        :param tokenizer_backend: str
            Either "tiktoken" or "huggingface".
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        :param max_gen_toks: int
            Default number of tokens to generate per request.
        :param batch_size: int
            Number of prompts sent per API call.
        :param seed: int
            Seed forwarded with every API call.
        :param max_length: Optional[int]
            Maximum sequence length; `_DEFAULT_MAX_LENGTH` is used when unset.
        :raises Exception: if `openai` or `tiktoken` cannot be imported.
        :raises ValueError: if `tokenizer_backend` is not a supported value.
        """
        super().__init__()
        self.seed = seed
        try:
            import openai  # noqa: E401
            import tiktoken
        except ModuleNotFoundError:
            raise Exception(
                "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. "
                "please install these via `pip install lm-eval[openai]` or `pip install -e .\"[openai]\"`",
            )
        self.model = model
        self.base_url = base_url
        self.tokenizer_backend = tokenizer_backend
        self.truncate = truncate
        self._batch_size = int(batch_size)
        self._max_gen_toks = max_gen_toks
        self._max_length = max_length

        # if we have a local model, use HF tokenizer over tiktoken
        if self.tokenizer_backend == "huggingface":
            import transformers  # noqa: E401

            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                tokenizer if tokenizer else self.model
            )
            # Store the integer vocabulary size and the integer EOS id, so that
            # `eot_token_id` can be used directly inside token-id lists.
            self.vocab_size = self.tokenizer.vocab_size
            self.end_of_text_token_id = self.tokenizer.eos_token_id
        elif self.tokenizer_backend == "tiktoken":
            if self.base_url:
                eval_logger.warning(
                    f"Passed `base_url={self.base_url}` but using Tiktoken tokenizer backend. "
                    "Pass `tokenizer_backend=huggingface` and provide the HF tokenizer name if your model does not use Tiktoken."
                )

            self.tokenizer = tiktoken.encoding_for_model(self.model)
            self.vocab_size = self.tokenizer.n_vocab
            self.end_of_text_token_id = self.tokenizer.eot_token
        else:
            raise ValueError(
                f"Expected tokenizer_backend to be one of ['tiktoken', 'huggingface'] but got {self.tokenizer_backend}"
            )

        # Read from environment variable OPENAI_API_KEY
        # Set to EMPTY for local
        openai.api_key = os.environ["OPENAI_API_KEY"]
        if self.base_url:
            self.client = openai.OpenAI(base_url=self.base_url)
        else:
            self.client = openai.OpenAI()

    @property
    def eot_token_id(self):
        """End-of-text token id taken from the configured tokenizer."""
        return self.end_of_text_token_id

    @property
    def max_length(self) -> int:
        """Configured maximum length, or `_DEFAULT_MAX_LENGTH` when unset."""
        if self._max_length:
            return self._max_length
        else:
            return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self) -> int:
        """Default number of tokens to generate per request."""
        return self._max_gen_toks

    @property
    def batch_size(self) -> int:
        """Number of prompts sent per API call."""
        return self._batch_size

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def tok_encode(self, string: str, **kwargs) -> List[int]:
        """Encode `string` with the configured tokenizer (`kwargs` are ignored)."""
        return self.tokenizer.encode(string)

    def tok_decode(self, tokens: List[int]) -> str:
        """Decode `tokens` with the configured tokenizer."""
        return self.tokenizer.decode(tokens)

    def _loglikelihood_tokens(
        self, requests, disable_tqdm: bool = False
    ) -> List[Tuple[float, bool]]:
        """Score pre-tokenized (context, continuation) pairs through the API.

        :param requests:
            Iterable of `(cache_key, context_enc, continuation_enc)` tuples.
        :param disable_tqdm: bool
            Suppress the progress bar.
        :return:
            One `(summed logprob, is_greedy)` pair per request, in the
            original request order.
        """
        res = []

        def _collate(x):
            # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
            # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
            # we care about, and so we need some kind of backup for when it isn't
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(
            list(lm_eval.models.utils.chunks(re_ord.get_reordered(), self.batch_size)),
            disable=disable_tqdm,
        ):
            inps = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # max_length+1 because the API takes up to 2049 tokens, including the first context token
                inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
                # TODO: the logic is much simpler if we just look at the length of continuation tokens
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
                )

                inps.append(inp)
                ctxlens.append(ctxlen)

            # echo=True with max_tokens=0 asks the endpoint to return the
            # prompt tokens with their logprobs without generating anything.
            response = oa_completion(
                client=self.client,
                model=self.model,
                prompt=inps,
                echo=True,
                max_tokens=0,
                temperature=0.0,
                logprobs=10,
                seed=self.seed,
            )

            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
                response.choices, ctxlens, chunk
            ):
                answer = get_result(resp, ctxlen)

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
        return re_ord.get_original(res)

    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
        """Generate a completion for every request, cut at its stop sequences.

        :param requests:
            Objects whose `.args` is a `(context, request_args)` pair, where
            `request_args` is a dict that may hold `until`, `max_gen_toks`,
            `temperature` and further keyword arguments for the API call.
        :param disable_tqdm: bool
            Suppress the progress bar.
        :return:
            Generated strings in the original request order.
        """
        if not requests:
            return []
        res = []
        requests = [req.args for req in requests]

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)

        def sameuntil_chunks(xs, size):
            # Yield runs of at most `size` consecutive requests that share the
            # same request args, together with those args.
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil

        # todo: more intelligent batching for heterogeneous `until`
        for chunk, request_args in tqdm(
            list(sameuntil_chunks(re_ord.get_reordered(), self.batch_size)),
            disable=disable_tqdm,
        ):
            # Keep a per-batch override local so that it does not replace the
            # instance-wide default for the batches that follow.
            max_gen_toks = request_args.get("max_gen_toks", self.max_gen_toks)
            inps = [
                self.tok_encode(context)[-(self.max_length - max_gen_toks) :]
                for context, _ in chunk
            ]

            until = request_args.get("until", ["<|endoftext|>"])

            # Build the extra API kwargs in a fresh dict instead of writing the
            # default temperature back into the caller's request args.
            extra_kwargs = {
                k: v
                for k, v in request_args.items()
                if k not in {"do_sample", "max_gen_toks", "until"}
            }
            extra_kwargs.setdefault("temperature", 0)

            response = oa_completion(
                client=self.client,
                model=self.model,
                prompt=inps,
                max_tokens=max_gen_toks,
                stop=until,
                seed=self.seed,
                **extra_kwargs,
            )
            for resp, (context, args_) in zip(response.choices, chunk):
                s = resp.text

                # Cut the text at the first occurrence of any stop sequence.
                for term in until:
                    if len(term) > 0:
                        s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial(
                    "generate_until", (context, {"until": until}), s
                )

                res.append(s)
        return re_ord.get_original(res)

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override generate_until
        raise NotImplementedError()

    def loglikelihood_rolling(
        self, requests, disable_tqdm: bool = False
    ) -> List[float]:
        """Return the summed log-likelihood of each full request string.

        Each string is tokenized, split into rolling windows of at most
        `max_length` tokens (prefixed with the end-of-text token) and scored
        with `_loglikelihood_tokens`; the window scores are then summed.

        :param requests:
            Objects whose `.args` is a 1-tuple holding the string to score.
        :param disable_tqdm: bool
            Suppress the progress bar.
        :return:
            One summed log-likelihood per request.
        """
        loglikelihoods = []

        for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            # A `None` cache key is prepended so these windows skip partial caching.
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            string_nll = self._loglikelihood_tokens(
                rolling_token_windows,
                disable_tqdm=True,
            )

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)
        return loglikelihoods
@register_model("openai-chat-completions", "local-chat-completions")
class OpenaiChatCompletionsLM(LM):
    """LM backed by an OpenAI-style chat-completions endpoint.

    Only text generation is supported; both log-likelihood entry points raise
    `NotImplementedError`.
    """

    def __init__(
        self,
        model: str = "gpt-3.5-turbo",  # GPT model or Local model using HuggingFace model paths
        base_url: Optional[str] = None,
        truncate: bool = False,
        **kwargs,
    ) -> None:
        """
        :param model: str
            Implements an OpenAI-style chat completion API for
            accessing both OpenAI OR locally-hosted models using
            HuggingFace Tokenizer
            OpenAI API model (e.g. gpt-3.5-turbo)
            using the **gen_kwargs passed on init
        :param base_url: Optional[str]
            Custom endpoint URL handed to the OpenAI client; the default
            endpoint is used when empty.
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        :raises Exception: if the `openai` package cannot be imported.
        """
        super().__init__()
        try:
            import openai  # noqa: E401
        except ModuleNotFoundError:
            raise Exception(
                "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. "
                "please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
            )
        self.model = model
        self.base_url = base_url
        self.truncate = truncate

        # Read from environment variable OPENAI_API_KEY
        # Set to EMPTY for local
        if self.base_url:
            self.client = openai.OpenAI(base_url=self.base_url)
        else:
            self.client = openai.OpenAI()  # openai.AsyncOpenAI()

    @property
    def max_length(self) -> int:
        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
        return 2048

    @property
    def max_gen_toks(self) -> int:
        """Default number of tokens to generate per request."""
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
        """Generate one chat completion per request, cut at its stop sequences.

        :param requests:
            Objects whose `.args` is a `(context, gen_kwargs)` pair;
            `gen_kwargs` must be a dict and may hold `until` (str or list),
            `max_gen_toks`, `do_sample` and further keyword arguments for the
            API call.
        :param disable_tqdm: bool
            Suppress the progress bar.
        :return:
            Generated strings in the original request order.
        :raises ValueError:
            If `gen_kwargs` is not a dict, or `until` is neither a string nor
            a list.
        """
        res = defaultdict(list)
        re_ords = {}

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        grouper = lm_eval.models.utils.Grouper(requests, lambda x: str(x.args[1]))
        for key, reqs in grouper.get_grouped().items():
            # within each set of reqs for given kwargs, we reorder by token length, descending.
            re_ords[key] = utils.Reorderer(
                [req.args for req in reqs], lambda x: (-len(x[0]), x[0])
            )

        pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
        for key, re_ord in re_ords.items():
            # n needs to be 1 because messages in
            # chat completion are not batch but
            # is regarded as a single conversation.
            chunks = lm_eval.models.utils.chunks(re_ord.get_reordered(), n=1)
            for chunk in chunks:
                contexts, all_gen_kwargs = zip(*chunk)
                inps = [{"role": "user", "content": context} for context in contexts]

                gen_kwargs = all_gen_kwargs[0]
                until = None
                # Work on a copy so the request's own kwargs stay untouched.
                kwargs = copy.deepcopy(gen_kwargs)
                if not isinstance(kwargs, dict):
                    raise ValueError(
                        f"Expected repr(kwargs) to be of type repr(dict) but got {kwargs}"
                    )

                kwargs.pop("do_sample", None)
                if "until" in kwargs:
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        # Wrap the single stop string itself, so that `stop`
                        # is always a list of strings.
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected repr(kwargs['until']) to be of type Union[str, list] but got {until}"
                        )
                    kwargs["stop"] = until
                kwargs["max_tokens"] = kwargs.pop("max_gen_toks", self.max_gen_toks)

                response = oa_completion(
                    client=self.client,
                    chat=True,
                    messages=inps,
                    model=self.model,
                    **kwargs,
                )

                for resp, (context, args_) in zip(response.choices, chunk):
                    s = resp.message.content

                    # Cut the text at the first occurrence of any stop sequence.
                    if until is not None:
                        for term in until:
                            if len(term) > 0:
                                s = s.split(term)[0]

                    res[key].append(s)

                    self.cache_hook.add_partial(
                        "generate_until", (context, {"until": until}), s
                    )
                    pbar.update(1)
            # reorder this group of results back to original unsorted form
            res[key] = re_ord.get_original(res[key])

        pbar.close()

        return grouper.get_original(res)

    def loglikelihood(self, requests, disable_tqdm: bool = False):
        raise NotImplementedError("No support for logits.")

    def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
        raise NotImplementedError("No support for logits.")