Skip to content

Commit

Permalink
Add support for Writer models
Browse files Browse the repository at this point in the history
  • Loading branch information
samjulien committed Oct 23, 2024
1 parent 7ca0d43 commit 1e471dd
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
7 changes: 7 additions & 0 deletions fastchat/model/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,3 +1000,10 @@ def get_model_info(name: str) -> ModelInfo:
"https://huggingface.co/cllm",
"consistency-llm is a new generation of parallel decoder LLMs with fast generation speed.",
)

register_model_info(
["palmyra-x-004"],
"Palmyra X 004",
"https://dev.writer.com/home/models",
"Palmyra by Writer",
)
48 changes: 48 additions & 0 deletions fastchat/serve/api_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,16 @@ def get_api_provider_stream_iter(
api_base=model_api_dict["api_base"],
api_key=model_api_dict["api_key"],
)
elif model_api_dict["api_type"] == "writer":
prompt = conv.to_openai_api_messages()
stream_iter = writer_api_stream_iter(
model_name,
prompt,
temperature,
top_p,
max_tokens=max_new_tokens,
api_key=model_api_dict["api_key"],
)
else:
raise NotImplementedError()

Expand Down Expand Up @@ -1261,3 +1271,41 @@ def metagen_api_stream_iter(
"text": f"**API REQUEST ERROR** Reason: Unknown.",
"error_code": 1,
}


def writer_api_stream_iter(
model_name, messages, temperature, top_p, max_tokens, api_key
):
from writerai import Writer

api_key = api_key or os.environ["WRITER_API_KEY"]

client = Writer(api_key=api_key)

# Make requests
gen_params = {
"model": model_name,
"messages": messages,
"temperature": temperature,
"top_p": top_p,
"max_tokens": max_tokens,
}
logger.info(f"==== request ====\n{gen_params}")

res = client.chat.chat(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
model=model_name,
stream=True,
)
text = ""
for chunk in res:
if len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
text += chunk.choices[0].delta.content
data = {
"text": text,
"error_code": 0,
}
yield data

0 comments on commit 1e471dd

Please sign in to comment.