Commit

Update interface (#16)
* CLI cleanup

* Also support uploading files as a convenience to the user

* Events in the CLI (#23)

* Events in the CLI

* Update message about ctrl-c

* Version

* Forgot to use the api_base arg (#20)

* Forgot to use the api_base arg

* Bump version

* newline

Co-authored-by: hallacy <[email protected]>
rachellim and hallacy authored May 21, 2021
1 parent 5f8c4a8 commit 7b0f97e
Showing 5 changed files with 160 additions and 23 deletions.
2 changes: 1 addition & 1 deletion openai/api_resources/file.py
@@ -17,7 +17,7 @@ def create(
):
requestor = api_requestor.APIRequestor(
api_key,
api_base=openai.file_api_base or openai.api_base,
api_base=api_base or openai.file_api_base or openai.api_base,
api_version=api_version,
organization=organization,
)
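The one-line fix above means an api_base passed explicitly to File.create is no longer ignored; it now takes precedence over openai.file_api_base and openai.api_base. A minimal sketch of the resulting behavior (the file name and host below are illustrative, not part of the diff):

    import openai

    # Hypothetical per-call override: the host here is illustrative only.
    resp = openai.File.create(
        file=open("train.jsonl"),
        purpose="fine-tune",
        api_base="https://files.example.com/v1",
    )
    print(resp["id"])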
40 changes: 38 additions & 2 deletions openai/api_resources/fine_tune.py
@@ -4,7 +4,7 @@
nested_resource_class_methods,
)
from openai.six.moves.urllib.parse import quote_plus
from openai import util
from openai import api_requestor, util


@nested_resource_class_methods("event", operations=["list"])
@@ -18,4 +18,40 @@ def cancel(cls, id, api_key=None, request_id=None, **params):
url = "%s/%s/cancel" % (base, extn)
instance = cls(id, api_key, **params)
headers = util.populate_headers(request_id=request_id)
return instance.request("post", url, headers=headers)
return instance.request("post", url, headers=headers)

@classmethod
def stream_events(
cls,
id,
api_key=None,
api_base=None,
request_id=None,
api_version=None,
organization=None,
**params
):
base = cls.class_url()
extn = quote_plus(id)

requestor = api_requestor.APIRequestor(
api_key,
api_base=api_base,
api_version=api_version,
organization=organization,
)
url = "%s/%s/events?stream=true" % (base, extn)
headers = util.populate_headers(request_id=request_id)
response, _, api_key = requestor.request(
"get", url, params, headers=headers, stream=True
)

return (
util.convert_to_openai_object(
line,
api_key,
api_version,
organization,
)
for line in response
)
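The new stream_events classmethod opens a streaming request against the job's /events endpoint and yields each event as it arrives. A minimal consumption sketch (the job id is illustrative):

    import datetime
    import openai

    # Follow a fine-tune job's events as the server emits them.
    for event in openai.FineTune.stream_events("ft-abc123"):
        ts = datetime.datetime.fromtimestamp(event["created_at"])
        print(ts, event["message"])

The CLI changes below use this same generator to power the default wait behavior of fine_tunes.create and the new --stream flag on fine_tunes.events.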
138 changes: 119 additions & 19 deletions openai/cli.py
@@ -1,4 +1,7 @@
import datetime
import json
import os
import signal
import sys
import warnings

@@ -205,21 +208,34 @@ def list(cls, args):
print(file)


class FineTuneCLI:
class FineTune:
@classmethod
def list(cls, args):
resp = openai.FineTune.list()
print(resp)

@classmethod
def _get_or_upload(cls, file):
try:
openai.File.retrieve(file)
except openai.error.InvalidRequestError as e:
if e.http_status == 404 and os.path.isfile(file):
resp = openai.File.create(file=open(file), purpose="fine-tune")
sys.stdout.write(
"Uploaded file from {file}: {id}\n".format(file=file, id=resp["id"])
)
return resp["id"]
return file

@classmethod
def create(cls, args):
create_args = {
"train_file": args.train_file,
"training_file": cls._get_or_upload(args.training_file),
}
if args.test_file:
create_args["test_file"] = args.test_file
if args.base_model:
create_args["base_model"] = args.base_model
if args.validation_file:
create_args["validation_file"] = cls._get_or_upload(args.validation_file)
if args.model:
create_args["model"] = args.model
if args.hparams:
try:
hparams = json.loads(args.hparams)
@@ -231,7 +247,35 @@ def create(cls, args):
create_args.update(hparams)

resp = openai.FineTune.create(**create_args)
print(resp)

if args.no_wait:
print(resp)
return

sys.stdout.write(
"Created job: {job_id}\n"
"Streaming events until the job is complete...\n\n"
"(Ctrl-C will interrupt the stream, but not cancel the job)\n".format(
job_id=resp["id"]
)
)
cls._stream_events(resp["id"])

resp = openai.FineTune.retrieve(id=resp["id"])
status = resp["status"]
sys.stdout.write("\nJob complete! Status: {status}".format(status=status))
if status == "succeeded":
sys.stdout.write(" 🎉")
sys.stdout.write(
"\nTry out your fine-tuned model: {model}\n"
"(Pass this as the model parameter to a completion request)".format(
model=resp["fine_tuned_model"]
)
)
# TODO(rachel): Print instructions on how to use the model here.
elif status == "failed":
sys.stdout.write("\nPlease contact [email protected] for assistance.")
sys.stdout.write("\n")

@classmethod
def get(cls, args):
Expand All @@ -240,8 +284,39 @@ def get(cls, args):

@classmethod
def events(cls, args):
resp = openai.FineTune.list_events(id=args.id)
print(resp)
if not args.stream:
resp = openai.FineTune.list_events(id=args.id)
print(resp)
return
cls._stream_events(args.id)

@classmethod
def _stream_events(cls, job_id):
def signal_handler(sig, frame):
status = openai.FineTune.retrieve(job_id).status
sys.stdout.write(
"\nStream interrupted. Job is still {status}. "
"To cancel your job, run:\n"
"`openai api fine_tunes.cancel -i {job_id}`\n".format(
status=status, job_id=job_id
)
)
sys.exit(0)

signal.signal(signal.SIGINT, signal_handler)

events = openai.FineTune.stream_events(job_id)
# TODO(rachel): Add a nifty spinner here.
for event in events:
sys.stdout.write(
"[%s] %s"
% (
datetime.datetime.fromtimestamp(event["created_at"]),
event["message"],
)
)
sys.stdout.write("\n")
sys.stdout.flush()

@classmethod
def cancel(cls, args):
@@ -436,27 +511,52 @@ def help(args):

# Finetune
sub = subparsers.add_parser("fine_tunes.list")
sub.set_defaults(func=FineTuneCLI.list)
sub.set_defaults(func=FineTune.list)

sub = subparsers.add_parser("fine_tunes.create")
sub.add_argument("-t", "--train_file", required=True, help="File to train")
sub.add_argument("--test_file", help="File to test")
sub.add_argument(
"-b",
"--base_model",
help="The model name to start the run from",
"-t",
"--training_file",
required=True,
help="JSONL file containing prompt-completion examples for training. This can "
"be the ID of a file uploaded through the OpenAI API (e.g. file-abcde12345) "
"or a local file path.",
)
sub.add_argument(
"-v",
"--validation_file",
help="JSONL file containing prompt-completion examples for validation. This can "
"be the ID of a file uploaded through the OpenAI API (e.g. file-abcde12345) "
"or a local file path.",
)
sub.add_argument(
"-m",
"--model",
help="The model to start fine-tuning from",
)
sub.add_argument(
"--no_wait",
action="store_true",
help="If set, returns immediately after creating the job. Otherwise, waits for the job to complete.",
)
sub.add_argument("-p", "--hparams", help="Hyperparameter JSON")
sub.set_defaults(func=FineTuneCLI.create)
sub.set_defaults(func=FineTune.create)

sub = subparsers.add_parser("fine_tunes.get")
sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
sub.set_defaults(func=FineTuneCLI.get)
sub.set_defaults(func=FineTune.get)

sub = subparsers.add_parser("fine_tunes.events")
sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
sub.set_defaults(func=FineTuneCLI.events)
sub.add_argument(
"-s",
"--stream",
action="store_true",
help="If set, events will be streamed until the job is done. Otherwise, "
"displays the event history to date.",
)
sub.set_defaults(func=FineTune.events)

sub = subparsers.add_parser("fine_tunes.cancel")
sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
sub.set_defaults(func=FineTuneCLI.cancel)
sub.set_defaults(func=FineTune.cancel)
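With these CLI changes, -t/--training_file accepts either an uploaded file ID or a local path (uploaded on the fly by _get_or_upload), and fine_tunes.create streams events until the job finishes unless --no_wait is given. A rough Python sketch of the same end-to-end flow, assuming a local train.jsonl (the path is illustrative):

    import openai

    # Upload local training data, as the CLI's _get_or_upload does for a path.
    training_file = openai.File.create(file=open("train.jsonl"), purpose="fine-tune")

    # Start the fine-tune job and follow its events to completion.
    job = openai.FineTune.create(training_file=training_file["id"])
    for event in openai.FineTune.stream_events(job["id"]):
        print(event["message"])

    # Check the final status and the resulting model name.
    job = openai.FineTune.retrieve(id=job["id"])
    print(job["status"], job["fine_tuned_model"])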
1 change: 1 addition & 0 deletions openai/util.py
@@ -64,6 +64,7 @@ def log_info(message, **params):
print(msg, file=sys.stderr)
logger.info(msg)


def log_warn(message, **params):
msg = logfmt(dict(message=message, **params))
print(msg, file=sys.stderr)
2 changes: 1 addition & 1 deletion openai/version.py
@@ -1 +1 @@
VERSION = "0.6.3"
VERSION = "0.6.4"
