Skip to content

Commit

Permalink
Merge pull request #3 from phelps-sg/grit/b6799721-042e-4810-8c7e-5e6…
Browse files Browse the repository at this point in the history
…9e6c9b484

[bot] Run grit migration: Apply a GritQL pattern
  • Loading branch information
phelps-sg authored Nov 17, 2023
2 parents 22588ae + 5b24ae2 commit 28a4228
Show file tree
Hide file tree
Showing 8 changed files with 409 additions and 206 deletions.
8 changes: 8 additions & 0 deletions .grit/grit.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
version: 0.0.1
patterns:
- name: github.com/getgrit/js#*
- name: github.com/getgrit/python#*
- name: github.com/getgrit/json#*
- name: github.com/getgrit/hcl#*
- name: github.com/getgrit/python#openai
level: info
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,22 @@ repos:

# mypy
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.2.0' # Use the sha / tag you want to point at
rev: 'v1.7.0' # Use the sha / tag you want to point at
hooks:
- id: mypy
exclude: 'jupyter-book/.*|tests/.*'
args: [--config=mypy.ini, --explicit-package-bases]
additional_dependencies: [ pandas-stubs ]
additional_dependencies: [ openai ]

# mypy for tests
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.2.0' # Use the sha / tag you want to point at
rev: 'v1.7.0' # Use the sha / tag you want to point at
hooks:
- id: mypy
name: mypy-tests
files: ^tests/
args: [--config=mypy-tests.ini]
additional_dependencies: [ pandas-stubs ]
additional_dependencies: [ openai ]

# pylint
- repo: local
Expand All @@ -91,4 +91,4 @@ repos:
"-rn", # Only display messages
"-sn", # Don't display the score
"--rcfile=.pylintrc",
]
]
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
![GitHub release (latest by date)](https://img.shields.io/github/v/release/phelps-sg/openai-pygenerator)
![GitHub](https://img.shields.io/github/license/phelps-sg/openai-pygenerator?color=blue)

## Import Note

The openai Python API versions 1.0.0 and later now incorporate retry functionality and type annotations. This package has been migrated to the new API, but should only be used as a temporary measure to ensure backward compatibility. If you are using this package in production, consider rewriting your code so that it uses the new openai API directly.

## Overview

This is a simple type-annotated wrapper around the OpenAI Python API which:
- provides configurable retry functionality,
- reduces the default timeout from 10 minutes to 20 seconds (configurable),
Expand Down Expand Up @@ -122,8 +128,6 @@ export GPT_MODEL=gpt-3.5-turbo
export GPT_TEMPERATURE=0.2
export GPT_MAX_TOKENS=500
export GPT_MAX_RETRIES=5
export GPT_RETRY_EXPONENT_SECONDS=2
export GPT_RETRY_BASE_SECONDS=20
export GPT_REQUEST_TIMEOUT_SECONDS=20
export OPENAI_API_KEY=<key>
python src/openai_pygenerator/example.py
Expand Down
314 changes: 263 additions & 51 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ packages = [{include="openai_pygenerator", from="src"}]

[tool.poetry.dependencies]
python = "^3.8.1"
openai = "^0.28.1"
openai = "^1.0.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.2"
Expand Down
2 changes: 1 addition & 1 deletion src/openai_pygenerator/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
user_message,
)

high_temp_completions = completer(temperature=0.8)
high_temp_completions = completer(temperature=1.0)


def heading(message: str, margin: int = 80) -> None:
Expand Down
158 changes: 86 additions & 72 deletions src/openai_pygenerator/openai_pygenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,22 @@

import logging
import os
import time
from enum import Enum, auto
from typing import Callable, Dict, Iterable, Iterator, List, NewType, Optional, TypeVar

import openai
import urllib3.exceptions
from openai.error import (
APIConnectionError,
APIError,
RateLimitError,
ServiceUnavailableError,
from functools import lru_cache
from typing import Callable, Iterator, List, NewType, Optional, TypeVar

from openai import OpenAI
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessage,
ChatCompletionMessageParam,
ChatCompletionUserMessageParam,
)

Completion = Dict[str, str]
Completion = ChatCompletionMessageParam
Seconds = NewType("Seconds", int)
Completions = Iterator[Completion]
History = Iterable[Completion]
History = List[Completion]
Completer = Callable[[History, int], Completions]
T = TypeVar("T")

Expand All @@ -54,38 +53,46 @@ def var(name: str, to_type: Callable[[str], T], default: T) -> T:
return to_type(result)


OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
GPT_MODEL = var("GPT_MODEL", str, "gpt-3.5-turbo")
GPT_TEMPERATURE = var("GPT_TEMPERATURE", float, 0.2)
GPT_MAX_TOKENS = var("GPT_MAX_TOKENS", int, 500)
GPT_MAX_RETRIES = var("GPT_MAX_RETRIES", int, 5)
GPT_RETRY_EXPONENT_SECONDS = Seconds(var("GPT_RETRY_EXPONENT_SECONDS", int, 2))
GPT_RETRY_BASE_SECONDS = Seconds(var("GPT_RETRY_BASE_SECONDS", int, 20))
GPT_REQUEST_TIMEOUT_SECONDS = Seconds(var("GPT_REQUEST_TIMEOUT_SECONDS", int, 20))

logger = logging.getLogger(__name__)
openai.api_key = os.getenv("OPENAI_API_KEY")


def completer(
model: str = GPT_MODEL,
max_tokens: int = GPT_MAX_TOKENS,
temperature: float = GPT_TEMPERATURE,
max_retries: int = GPT_MAX_RETRIES,
retry_base: Seconds = GPT_RETRY_BASE_SECONDS,
retry_exponent: Seconds = GPT_RETRY_EXPONENT_SECONDS,
# retry_base: Seconds = GPT_RETRY_BASE_SECONDS,
# retry_exponent: Seconds = GPT_RETRY_EXPONENT_SECONDS,
request_timeout: Seconds = GPT_REQUEST_TIMEOUT_SECONDS,
) -> Completer:
@lru_cache()
def get_client() -> OpenAI:
return OpenAI(
api_key=OPENAI_API_KEY,
timeout=request_timeout,
max_retries=max_retries,
)

def f(messages: History, n: int = 1) -> Completions:
return generate_completions(
get_client,
messages,
model,
max_tokens,
temperature,
max_retries,
retry_base,
retry_exponent,
request_timeout,
n,
n
# max_retries,
# retry_base,
# retry_exponent,
# request_timeout,
# n,
)

return f
Expand All @@ -95,61 +102,68 @@ def f(messages: History, n: int = 1) -> Completions:


def generate_completions(
client: Callable[[], OpenAI],
messages: History,
model: str,
max_tokens: int,
temperature: float,
max_retries: int,
retry_base: Seconds,
retry_exponent: Seconds,
request_timeout: Seconds,
# max_retries: int,
# retry_base: Seconds,
# retry_exponent: Seconds,
# request_timeout: Seconds,
n: int = 1,
retries: int = 0,
# retries: int = 0,
) -> Completions:
logger.debug("messages = %s", messages)
try:
result = openai.ChatCompletion.create(
model=model,
messages=messages,
max_tokens=max_tokens,
n=n,
temperature=temperature,
request_timeout=request_timeout,
)
logger.debug("response = %s", result)
for choice in result.choices: # type: ignore
yield choice.message
except (
openai.error.Timeout, # type: ignore
urllib3.exceptions.TimeoutError,
RateLimitError,
APIConnectionError,
APIError,
ServiceUnavailableError,
) as err:
if isinstance(err, APIError) and not (err.http_status in [524, 502, 500]):
raise
logger.warning("Error returned from openai API: %s", err)
logger.debug("retries = %d", retries)
if retries < max_retries:
logger.info("Retrying... ")
time.sleep(retry_base + retry_exponent**retries)
for completion in generate_completions(
messages,
model,
max_tokens,
temperature,
max_retries,
retry_base,
retry_exponent,
request_timeout,
n,
retries + 1,
):
yield completion
else:
logger.error("Maximum retries reached, aborting.")
raise
# try:
result = client().chat.completions.create(
model=model,
messages=messages,
max_tokens=max_tokens,
n=n,
temperature=temperature,
)
# request_timeout=request_timeout)
logger.debug("response = %s", result)
for choice in result.choices: # type: ignore
yield to_message_param(choice.message)
# except (
# openai.Timeout, # type: ignore
# urllib3.exceptions.TimeoutError,
# RateLimitError,
# APIConnectionError,
# APIError,
# ServiceUnavailableError,
# ) as err:
# if isinstance(err, APIError) and not (err.http_status in [524, 502, 500]):
# raise
# logger.warning("Error returned from openai API: %s", err)
# logger.debug("retries = %d", retries)
# if retries < max_retries:
# logger.info("Retrying... ")
# time.sleep(retry_base + retry_exponent**retries)
# for completion in generate_completions(
# messages,
# model,
# max_tokens,
# temperature,
# max_retries,
# retry_base,
# retry_exponent,
# request_timeout,
# n,
# retries + 1,
# ):
# yield completion
# else:
# logger.error("Maximum retries reached, aborting.")
# raise


def to_message_param(message: ChatCompletionMessage) -> Completion:
return ChatCompletionAssistantMessageParam(
{"role": message.role, "content": message.content}
)


def next_completion(completions: Completions) -> Optional[Completion]:
Expand All @@ -160,11 +174,11 @@ def next_completion(completions: Completions) -> Optional[Completion]:


def user_message(text: str) -> Completion:
return {"role": "user", "content": text}
return ChatCompletionUserMessageParam({"role": "user", "content": text})


def content(completion: Completion) -> str:
return completion["content"]
return str(completion["content"])


def role(completion: Completion) -> Role:
Expand Down
Loading

0 comments on commit 28a4228

Please sign in to comment.