From 82f73f9e1a70c7ff9848c19811136af384d3ce82 Mon Sep 17 00:00:00 2001
From: Reinder Vos de Wael
Date: Fri, 27 Dec 2024 15:36:29 -0500
Subject: [PATCH] refactor: Move LLM functionality to cloai (#160)

* refactor: Move LLM functionality to cloai

* refactor: Move LLM functionality to cloai

# Conflicts:
#	src/ctk_functions/microservices/aws.py
#	src/ctk_functions/microservices/azure.py
#	src/ctk_functions/microservices/llm.py

* remove qodana.yaml

* various fixes for cloai

* fix: mypy issues

* chore: Cleanup
---
 .funcignore                                  |   6 -
 .gitignore                                   |   1 +
 function_app.py                              |   7 -
 host.json                                    |  20 --
 local.settings.json                          |   7 -
 pyproject.toml                               |   8 +-
 qodana.yaml                                  |  34 --
 src/ctk_functions/core/config.py             |   1 +
 src/ctk_functions/microservices/aws.py       |  70 -----
 src/ctk_functions/microservices/azure.py     |  70 -----
 .../microservices/language_models.py         |  49 +++
 src/ctk_functions/microservices/llm.py       | 290 ------------------
 src/ctk_functions/microservices/utils.py     |  22 --
 .../routers/file_conversion/schemas.py       |   4 -
 .../routers/intake/controller.py             |   4 +-
 .../intake/intake_processing/writer.py       |   4 +-
 .../intake/intake_processing/writer_llm.py   |   9 +-
 src/ctk_functions/routers/intake/views.py    |   4 +-
 src/ctk_functions/routers/llm/controller.py  |  10 +-
 src/ctk_functions/routers/llm/schemas.py     |   4 +-
 tests/conftest.py                            |   4 +-
 tests/endpoint/test_intake.py                |   6 +-
 tests/endpoint/test_llm.py                   |   2 +-
 tests/integration/azure_test_intakes.py      |   2 +-
 tests/smoke/test_function_app.py             |  37 ---
 tests/smoke/test_intake_writer_smoke.py      |   8 +-
 uv.lock                                      |  36 +--
 27 files changed, 98 insertions(+), 621 deletions(-)
 delete mode 100644 .funcignore
 delete mode 100644 function_app.py
 delete mode 100644 host.json
 delete mode 100644 local.settings.json
 delete mode 100644 qodana.yaml
 delete mode 100644 src/ctk_functions/microservices/aws.py
 delete mode 100644 src/ctk_functions/microservices/azure.py
 create mode 100644 src/ctk_functions/microservices/language_models.py
 delete mode 100644 src/ctk_functions/microservices/llm.py
 delete mode 100644 src/ctk_functions/microservices/utils.py
 delete mode 100644 tests/smoke/test_function_app.py

diff --git a/.funcignore b/.funcignore
deleted file mode 100644
index 2829bb8..0000000
--- a/.funcignore
+++ /dev/null
@@ -1,6 +0,0 @@
-.git*
-.vscode
-__azurite_db*__.json
-__blobstorage__
-__queuestorage__
-local.settings.json
diff --git a/.gitignore b/.gitignore
index fb49b25..3bffcb5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
+qodana.yaml
 .idea
 ~*
 .ruff_cache
diff --git a/function_app.py b/function_app.py
deleted file mode 100644
index 6028636..0000000
--- a/function_app.py
+++ /dev/null
@@ -1,7 +0,0 @@
-"""Entrypoint for Azure Functions."""
-
-import azure.functions as func
-
-from ctk_functions import app as fastapi_app
-
-app = func.AsgiFunctionApp(app=fastapi_app.app, http_auth_level=func.AuthLevel.FUNCTION)
diff --git a/host.json b/host.json
deleted file mode 100644
index c3d1609..0000000
--- a/host.json
+++ /dev/null
@@ -1,20 +0,0 @@
-{
-  "version": "2.0",
-  "extensions": {
-    "http": {
-      "routePrefix": ""
-    }
-  },
-  "logging": {
-    "applicationInsights": {
-      "samplingSettings": {
-        "isEnabled": true,
-        "excludedTypes": "Request"
-      }
-    }
-  },
-  "extensionBundle": {
-    "id": "Microsoft.Azure.Functions.ExtensionBundle",
-    "version": "[3.*, 4.0.0)"
-  }
-}
diff --git a/local.settings.json b/local.settings.json
deleted file mode 100644
index 3a10b64..0000000
--- a/local.settings.json
+++ /dev/null
@@ -1,7 +0,0 @@
-{
-  "IsEncrypted": false,
-  "Values": {
-    "AzureWebJobsStorage": "UseDevelopmentStorage=true",
-    "FUNCTIONS_WORKER_RUNTIME": "python"
-  }
-}
diff --git a/pyproject.toml b/pyproject.toml
index c11d9b3..7210173 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,23 +7,19 @@ requires-python = ">=3.11, <3.12"
 dependencies = [
     "aiofiles>=24.1.0",
     "aiohttp>=3.11.11",
-    "azure-functions>=1.21.3",
     "cmi-docx>=0.3.7",
     "fastapi[standard]>=0.115.6",
-    "jsonpickle>=4.0.1",
     "language-tool-python>=2.8.1",
-    "openai>=0.27.10",
     "pycap>=2.6.0",
     "pydantic>=2.10.4",
     "pydantic-settings>=2.7.0",
     "pypandoc-binary>=1.14",
     "python-dateutil>=2.9.0.post0",
     "pytz>=2024.2",
-    "instructor[anthropic]>1.6",
-    "anthropic[bedrock]>=0.37.1",
     "spacy>=3.8.3",
     "en-core-web-sm",
-    "httpx==0.27"
+    "cloai>=1.0.0",
+    "jsonpickle>=4.0.1"
 ]

 [tool.uv]
diff --git a/qodana.yaml b/qodana.yaml
deleted file mode 100644
index f8021a3..0000000
--- a/qodana.yaml
+++ /dev/null
@@ -1,34 +0,0 @@
-#-------------------------------------------------------------------------------#
-#               Qodana analysis is configured by qodana.yaml file               #
-#             https://www.jetbrains.com/help/qodana/qodana-yaml.html            #
-#-------------------------------------------------------------------------------#
-version: '1.0'
-
-#Specify inspection profile for code analysis
-profile:
-  name: qodana.starter
-
-#Enable inspections
-#include:
-#  - name:
-
-#Disable inspections
-#exclude:
-#  - name:
-#    paths:
-#      -
-
-#Execute shell command before Qodana execution (Applied in CI/CD pipeline)
-#bootstrap: sh ./prepare-qodana.sh
-
-#Install IDE plugins before Qodana execution (Applied in CI/CD pipeline)
-#plugins:
-#  - id: #(plugin id can be found at https://plugins.jetbrains.com)
-
-#Specify Qodana linter for analysis (Applied in CI/CD pipeline)
-linter: jetbrains/qodana-python:2024.3
-exclude:
-  - name: PyArgumentListInspection
-    paths:
-      - src/ctk_functions/core/config.py
-      - src/ctk_functions/microservices/redcap.py
diff --git a/src/ctk_functions/core/config.py b/src/ctk_functions/core/config.py
index b1e553a..4ca121f 100644
--- a/src/ctk_functions/core/config.py
+++ b/src/ctk_functions/core/config.py
@@ -20,6 +20,7 @@ class Settings(pydantic_settings.BaseSettings):

     AWS_ACCESS_KEY_ID: pydantic.SecretStr
     AWS_SECRET_ACCESS_KEY: pydantic.SecretStr
+    AWS_REGION: str = "us-west-2"

     AZURE_OPENAI_API_KEY: pydantic.SecretStr
     AZURE_OPENAI_LLM_DEPLOYMENT: pydantic.SecretStr
diff --git a/src/ctk_functions/microservices/aws.py b/src/ctk_functions/microservices/aws.py
deleted file mode 100644
index 72b43cb..0000000
--- a/src/ctk_functions/microservices/aws.py
+++ /dev/null
@@ -1,70 +0,0 @@
-"""This module contains interactions with AWS microservices."""
-
-from typing import Literal
-
-import anthropic
-
-from ctk_functions.core import config
-from ctk_functions.microservices import utils
-
-settings = config.get_settings()
-
-AWS_ACCESS_KEY_ID = settings.AWS_ACCESS_KEY_ID
-AWS_SECRET_ACCESS_KEY = settings.AWS_SECRET_ACCESS_KEY
-ANTHROPIC_MODELS = Literal[
-    "anthropic.claude-3-5-sonnet-20241022-v2:0",
-    "anthropic.claude-3-5-sonnet-20240620-v1:0",
-    "anthropic.claude-3-opus-20240229-v1:0",
-]
-
-
-class ClaudeLlm(utils.LlmAbstractBaseClass):
-    """Caller for Claude Large Language models.
-
-    Attributes:
-        client: The BedRock client.
-        model: The model that is invoked.
-
-    """
-
-    def __init__(
-        self,
-        model: ANTHROPIC_MODELS,
-    ) -> None:
-        """Initializes the BedRock client."""
-        if model in (
-            "anthropic.claude-3-opus-20240229-v1:0",
-            "anthropic.claude-3-5-sonnet-20241022-v2:0",
-        ):
-            region = "us-west-2"
-        elif model == "anthropic.claude-3-5-sonnet-20240620-v1:0":
-            region = "us-east-1"
-        else:
-            msg = "Unknown model."
-            raise ValueError(msg)
-
-        self.client = anthropic.AsyncAnthropicBedrock(
-            aws_access_key=AWS_ACCESS_KEY_ID.get_secret_value(),
-            aws_secret_key=AWS_SECRET_ACCESS_KEY.get_secret_value(),
-            aws_region=region,
-        )
-        self.model = model
-
-    async def run(self, system_prompt: str, user_prompt: str) -> str:
-        """Runs the model with the given prompts.
-
-        Args:
-            system_prompt: The system prompt.
-            user_prompt: The user prompt.
-
-        Returns:
-            The output text.
-        """
-        message = await self.client.messages.create(
-            model=self.model,
-            max_tokens=5000,
-            system=system_prompt,
-            messages=[{"role": "user", "content": user_prompt}],
-        )
-
-        return message.content[0].text  # type: ignore[union-attr]
diff --git a/src/ctk_functions/microservices/azure.py b/src/ctk_functions/microservices/azure.py
deleted file mode 100644
index 9d371b6..0000000
--- a/src/ctk_functions/microservices/azure.py
+++ /dev/null
@@ -1,70 +0,0 @@
-"""A module to interact with Azure Blob Storage."""
-
-from typing import Literal
-
-import openai
-
-from ctk_functions.core import config
-from ctk_functions.microservices import utils
-
-logger = config.get_logger()
-
-settings = config.get_settings()
-AZURE_OPENAI_API_KEY = settings.AZURE_OPENAI_API_KEY
-AZURE_OPENAI_LLM_DEPLOYMENT = settings.AZURE_OPENAI_LLM_DEPLOYMENT
-AZURE_OPENAI_ENDPOINT = settings.AZURE_OPENAI_ENDPOINT
-
-GPT_MODELS = Literal["gpt-4o"]
-
-
-class AzureLlm(utils.LlmAbstractBaseClass):
-    """A class to interact with the Azure Language Model service."""
-
-    def __init__(
-        self,
-        model: GPT_MODELS = "gpt-4o",
-    ) -> None:
-        """Initialize the Azure Language Model client.
-
-        Args:
-            model: The model to use for the language model.
-        """
-        self.client = openai.AsyncAzureOpenAI(
-            api_key=AZURE_OPENAI_API_KEY.get_secret_value(),
-            azure_endpoint=AZURE_OPENAI_ENDPOINT.get_secret_value(),
-            api_version="2024-02-01",
-        )
-        self.model = model
-
-    async def run(self, system_prompt: str, user_prompt: str) -> str:
-        """Runs the model with the given prompts.
-
-        Args:
-            system_prompt: The system prompt.
-            user_prompt: The user prompt.
-
-        Returns:
-            The output text.
-        """
-        system_message = {
-            "role": "system",
-            "content": system_prompt,
-        }
-        user_message = {
-            "role": "user",
-            "content": user_prompt,
-        }
-        try:
-            response = await self.client.chat.completions.create(
-                messages=[system_message, user_message],  # type: ignore[list-item]
-                model=AZURE_OPENAI_LLM_DEPLOYMENT.get_secret_value(),
-            )
-            message = response.choices[0].message.content
-        except openai.BadRequestError:
-            # Fallback: Return a message to the user even on remote server failure.
-            # Example of this being necessary is content management policy.
-            message = "Failure in LLM processing. Please let the development team know."
-
-        if message is None:
-            message = "Failure in LLM processing. Please let the development team know."
-        return message
diff --git a/src/ctk_functions/microservices/language_models.py b/src/ctk_functions/microservices/language_models.py
new file mode 100644
index 0000000..9124a37
--- /dev/null
+++ b/src/ctk_functions/microservices/language_models.py
@@ -0,0 +1,49 @@
+"""Large Language Model client creation."""
+
+from typing import Literal, TypeGuard, get_args
+
+import cloai
+from cloai.llm import bedrock
+
+from ctk_functions.core import config
+
+settings = config.get_settings()
+
+VALID_MODELS = Literal[
+    "anthropic.claude-3-opus-20240229-v1:0",
+    "anthropic.claude-3-5-sonnet-20240620-v1:0",
+    "anthropic.claude-3-5-sonnet-20241022-v2:0",
+    "gpt-4o",
+]
+
+
+def get_llm(model: str) -> cloai.LargeLanguageModel:
+    """Gets the LLM client.
+
+    Args:
+        model: Model name to use.
+
+    Returns:
+        The client for the large language model.
+    """
+    if _is_anthropic_bedrock_model(model):
+        client = cloai.AnthropicBedrockLlm(
+            model=model,
+            aws_access_key=settings.AWS_ACCESS_KEY_ID.get_secret_value(),
+            aws_secret_key=settings.AWS_SECRET_ACCESS_KEY.get_secret_value(),
+            region=settings.AWS_REGION,
+        )
+    else:
+        client = cloai.AzureLlm(
+            deployment=settings.AZURE_OPENAI_LLM_DEPLOYMENT.get_secret_value(),
+            endpoint=settings.AZURE_OPENAI_ENDPOINT.get_secret_value(),
+            api_key=settings.AZURE_OPENAI_API_KEY.get_secret_value(),
+            api_version="2024-02-01",
+        )
+    return cloai.LargeLanguageModel(client=client)
+
+
+def _is_anthropic_bedrock_model(
+    model: str,
+) -> TypeGuard[bedrock.ANTHROPIC_BEDROCK_MODELS]:
+    return model in get_args(bedrock.ANTHROPIC_BEDROCK_MODELS)
diff --git a/src/ctk_functions/microservices/llm.py b/src/ctk_functions/microservices/llm.py
deleted file mode 100644
index c1f674e..0000000
--- a/src/ctk_functions/microservices/llm.py
+++ /dev/null
@@ -1,290 +0,0 @@
-"""This module coalesces all large language models from different microservices."""
-
-import asyncio
-import typing
-from collections.abc import Iterable
-from typing import Any, Literal, TypeGuard, TypeVar, overload
-
-import instructor
-import pydantic
-
-from ctk_functions.core import config
-from ctk_functions.microservices import aws, azure, utils
-
-settings = config.get_settings()
-logger = config.get_logger()
-
-LOGGER_PHI_LOGGING_LEVEL = settings.LOGGER_PHI_LOGGING_LEVEL
-VALID_LLM_MODELS = typing.Literal[aws.ANTHROPIC_MODELS, azure.GPT_MODELS]
-
-T = TypeVar("T")
-
-
-class GeneratedStatement(pydantic.BaseModel):
-    """A class for a statement about the correctness of an LLM result."""
-
-    statement: str = pydantic.Field(
-        ...,
-        description="A True or False statement about the text.",
-    )
-
-    @pydantic.field_validator("statement")
-    @classmethod
-    def statement_validation(cls, value: str) -> str:
-        """Check whether the phrase is actually a statement."""
-        if value[0].isnumeric():
-            msg = "statements should not be numbered."
-            raise ValueError(msg)
-        return value
-
-
-class VerificationStatement(pydantic.BaseModel):
-    """A class for a statement verifying the correctness of an LLM result."""
-
-    statement: GeneratedStatement = pydantic.Field(
-        ...,
-        description="A True or False statement about the text.",
-    )
-    correct: bool = pydantic.Field(
-        ...,
-        description="True if the answer to the statement is true, False otherwise.",
-    )
-
-
-class RewrittenText(pydantic.BaseModel):
-    """Class for rewriting text based on verification statements."""
-
-    text: str = pydantic.Field(..., description="The editted text.")
-    statements: tuple[VerificationStatement] = pydantic.Field(
-        ...,
-        description=(
-            "The statements along with whether they are True or False about the "
-            "editted text."
-        ),
-    )
-
-
-class LargeLanguageModel(pydantic.BaseModel, utils.LlmAbstractBaseClass):
-    """Llm class that provides access to all available LLMs.
-
-    Attributes:
-        client: The client for the large language model.
-    """
-
-    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True, frozen=True)
-
-    model: VALID_LLM_MODELS
-    _client: azure.AzureLlm | aws.ClaudeLlm = pydantic.PrivateAttr()
-    _instructor_client: instructor.client.AsyncInstructor = pydantic.PrivateAttr()
-
-    def model_post_init(self, __context: Any) -> None:  # noqa: ANN401
-        """Initializes the language model.
-
-        Args:
-            model: The model to use for the language model.
-        """
-        logger.info("Using LLM model: %s", self.model)
-        if self._is_azure_model(self.model):
-            self._client = azure.AzureLlm(self.model)
-            self._instructor_client = instructor.from_openai(self._client.client)
-        elif self._is_aws_model(self.model):
-            self._client = aws.ClaudeLlm(self.model)
-            self._instructor_client = instructor.from_anthropic(self._client.client)
-        else:
-            # As the model name can be supplied by the user, this case might be reached.
-            msg = f"Invalid LLM model: {self.model}"
-            raise ValueError(msg)
-
-    async def run(self, system_prompt: str, user_prompt: str) -> str:
-        """Runs the model with the given prompts.
-
-        Args:
-            system_prompt: The system prompt.
-            user_prompt: The user prompt.
-
-        Returns:
-            The output text.
-        """
-        return await self._client.run(system_prompt, user_prompt)
-
-    @overload
-    async def chain_of_verification(
-        self,
-        system_prompt: str,
-        user_prompt: str,
-        statements: list[str] = ...,
-        max_verifications: int = ...,
-        *,
-        create_new_statements: bool,
-    ) -> str:
-        pass
-
-    @overload
-    async def chain_of_verification(
-        self,
-        system_prompt: str,
-        user_prompt: str,
-        statements: None = None,
-        max_verifications: int = ...,
-        *,
-        create_new_statements: Literal[True],
-    ) -> str:
-        pass
-
-    async def chain_of_verification(
-        self,
-        system_prompt: str,
-        user_prompt: str,
-        statements: list[str] | None = None,
-        max_verifications: int = 3,
-        *,
-        create_new_statements: bool = False,
-    ) -> str:
-        """Runs an LLM prompt that is self-assessed by the LLM.
-
-        Args:
-            system_prompt: The system prompt for the initial prompt.
-            user_prompt: The user prompt for the initial prompt.
-            statements: Statements to verify the results. Defaults to None.
-            max_verifications: The maximum number of times to verify the results.
-                Defaults to 3.
-            create_new_statements: If True, generate new statements from the system
-                prompt. Defaults to False.
-
-        Returns:
-            The edited text result.
-        """
-        if statements is None and not create_new_statements:
-            msg = (
-                "Either statements must be provided, or new statements need to be "
-                "generated, or both."
-            )
-            raise ValueError(msg)
-        statements = statements or []
-
-        text_promise = self.run(system_prompt, user_prompt)
-        if create_new_statements:
-            statements_promise = self._create_statements(system_prompt)
-            text, new_statements = await asyncio.gather(
-                text_promise,
-                statements_promise,
-            )
-            statements += [statement.statement for statement in new_statements]
-        else:
-            text = await text_promise
-
-        logger.log(
-            LOGGER_PHI_LOGGING_LEVEL,
-            "Running with statements: %s",
-            statements,
-        )
-        for _ in range(max_verifications):
-            logger.log(LOGGER_PHI_LOGGING_LEVEL, text)
-            rewrite = await self._verify(
-                text,
-                statements,
-                user_prompt,
-            )
-            if all(statement.correct for statement in rewrite.statements):
-                break
-            logger.log(
-                LOGGER_PHI_LOGGING_LEVEL,
-                [q for q in rewrite.statements if not q.correct],
-            )
-            text = rewrite.text
-        else:
-            logger.warning("Reached verification limit.")
-
-        return text
-
-    async def _create_statements(self, instructions: str) -> list[GeneratedStatement]:
-        """Creates statements for prompt result validation.
-
-        Args:
-            instructions: The instructions provided to the model, commonly
-                the system prompt.
-
-        Returns:
-            List of verification statements as strings.
-        """
-        system_prompt = """
-Based on the following instructions, write a set of statements that can be
-answered with True or False to determine whether a piece of text adheres to
-these instructions. True should denote adherence to the structure whereas
-False should denote a lack of adherence.
-        """
-
-        return await self.call_instructor(
-            list[GeneratedStatement],
-            system_prompt=system_prompt,
-            user_prompt=instructions,
-            max_tokens=4096,
-        )
-
-    async def call_instructor(
-        self,
-        response_model: type[T],
-        system_prompt: str,
-        user_prompt: str,
-        max_tokens: int,
-    ) -> T:
-        """Generic interface for Anthropic/OpenAI instructor."""
-        if self._is_aws_model(self.model):
-            return await self._instructor_client.chat.completions.create(  # type: ignore[type-var]
-                response_model=response_model,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": user_prompt,
-                    },
-                ],
-                system=system_prompt,
-                model=self.model,
-                max_tokens=max_tokens,
-            )
-        if self._is_azure_model(self.model):
-            return await self._instructor_client.chat.completions.create(  # type: ignore[type-var]
-                response_model=response_model,
-                messages=[
-                    {
-                        "role": "system",
-                        "content": user_prompt,
-                    },
-                    {
-                        "role": "user",
-                        "content": user_prompt,
-                    },
-                ],
-                model=self.model,
-                max_tokens=max_tokens,
-            )
-        msg = "Invalid model."
-        raise ValueError(msg)
-
-    async def _verify(
-        self,
-        text: str,
-        statements: Iterable[str],
-        source: str,
-    ) -> RewrittenText:
-        statement_string = "\n".join(statements)
-        system_prompt = (
-            "Based on the following statements, edit the text to comply"
-            f"with all statements. The statements are as follows: "
-            f"{statement_string}. Furthermore, ensure that all edits are reflective "
-            f"of the source material: {source}"
-        )
-        return await self.call_instructor(
-            response_model=RewrittenText,
-            system_prompt=system_prompt,
-            user_prompt=text,
-            max_tokens=4096,
-        )
-
-    @staticmethod
-    def _is_azure_model(model: VALID_LLM_MODELS) -> TypeGuard[azure.GPT_MODELS]:
-        return model in typing.get_args(azure.GPT_MODELS)
-
-    @staticmethod
-    def _is_aws_model(model: VALID_LLM_MODELS) -> TypeGuard[aws.ANTHROPIC_MODELS]:
-        return model in typing.get_args(aws.ANTHROPIC_MODELS)
diff --git a/src/ctk_functions/microservices/utils.py b/src/ctk_functions/microservices/utils.py
deleted file mode 100644
index 2382ecd..0000000
--- a/src/ctk_functions/microservices/utils.py
+++ /dev/null
@@ -1,22 +0,0 @@
-"""Utilities for the microservices."""
-
-import abc
-
-
-class LlmAbstractBaseClass(abc.ABC):
-    """An abstract class for large language model interfaces."""
-
-    model: str
-
-    @abc.abstractmethod
-    def __init__(self, model: str) -> None:
-        """Initialize the Language Model client."""
-
-    @abc.abstractmethod
-    async def run(self, system_prompt: str, user_prompt: str) -> str:
-        """Runs the model with the given prompts.
-
-        Args:
-            system_prompt: The system prompt.
-            user_prompt: The user prompt.
-        """
diff --git a/src/ctk_functions/routers/file_conversion/schemas.py b/src/ctk_functions/routers/file_conversion/schemas.py
index 6c0a2c9..a53e2c9 100644
--- a/src/ctk_functions/routers/file_conversion/schemas.py
+++ b/src/ctk_functions/routers/file_conversion/schemas.py
@@ -1,12 +1,8 @@
 """Schemas for the file conversion router."""

-from typing import TypeVar
-
 import cmi_docx
 import pydantic

-ParagraphStyleType = TypeVar("ParagraphStyleType", bound=cmi_docx.ParagraphStyle)
-

 class PostMarkdown2DocxRequest(pydantic.BaseModel):
     """Definition of the markdown2docx request.
diff --git a/src/ctk_functions/routers/intake/controller.py b/src/ctk_functions/routers/intake/controller.py
index 2603aa7..d026bce 100644
--- a/src/ctk_functions/routers/intake/controller.py
+++ b/src/ctk_functions/routers/intake/controller.py
@@ -7,7 +7,7 @@
 from fastapi import status

 from ctk_functions.core import config, exceptions
-from ctk_functions.microservices import llm, redcap
+from ctk_functions.microservices import redcap
 from ctk_functions.routers.intake.intake_processing import parser, writer

 logger = config.get_logger()
@@ -15,7 +15,7 @@

 async def get_intake_report(
     survey_id: str,
-    model: llm.VALID_LLM_MODELS,
+    model: str,
     enabled_tasks: writer.EnabledTasks | None = None,
 ) -> bytes:
     """Generates an intake report for a survey.
diff --git a/src/ctk_functions/routers/intake/intake_processing/writer.py b/src/ctk_functions/routers/intake/intake_processing/writer.py
index d507239..6e24eb0 100644
--- a/src/ctk_functions/routers/intake/intake_processing/writer.py
+++ b/src/ctk_functions/routers/intake/intake_processing/writer.py
@@ -14,7 +14,7 @@
 from docx.text import paragraph as docx_paragraph

 from ctk_functions.core import config
-from ctk_functions.microservices import llm, redcap
+from ctk_functions.microservices import redcap
 from ctk_functions.routers.intake.intake_processing import (
     parser,
     transformers,
@@ -78,7 +78,7 @@ class ReportWriter:
     def __init__(
         self,
         intake: parser.IntakeInformation,
-        model: llm.VALID_LLM_MODELS,
+        model: str,
         enabled_tasks: EnabledTasks | None = None,
     ) -> None:
         """Initializes the report writer.
diff --git a/src/ctk_functions/routers/intake/intake_processing/writer_llm.py b/src/ctk_functions/routers/intake/intake_processing/writer_llm.py
index 9a25a38..402a170 100644
--- a/src/ctk_functions/routers/intake/intake_processing/writer_llm.py
+++ b/src/ctk_functions/routers/intake/intake_processing/writer_llm.py
@@ -9,7 +9,7 @@
 import pydantic

 from ctk_functions.core import config
-from ctk_functions.microservices import llm
+from ctk_functions.microservices import language_models
 from ctk_functions.routers.intake.intake_processing.utils import string_utils

 logger = config.get_logger()
@@ -105,7 +105,7 @@ class WriterLlm:

     def __init__(
         self,
-        model: llm.VALID_LLM_MODELS,
+        model: str,
         child_name: str,
         child_pronouns: Sequence[str],
     ) -> None:
@@ -116,7 +116,7 @@ def __init__(
             child_name: The name of the child in the report.
             child_pronouns: The pronouns of the child in the report.
         """
-        self.client = llm.LargeLanguageModel(model=model)
+        self.client = language_models.get_llm(model)
         self.child_name = child_name
         self.child_pronouns = child_pronouns
         self.placeholders: list[LlmPlaceholder] = []
@@ -249,7 +249,7 @@ def run_with_object_input(
         return self._run(system_prompt, user_prompt, verify=verify, comment=comment)

     def run_for_adjectives(self, description: str, comment: str | None = None) -> str:
-        """Extraces adjectives based on a description of a child.
+        """Extracts adjectives based on a description of a child.

         Args:
             description: The description of the child's strengths.
@@ -317,6 +317,7 @@ def _run(
             replacement = self.client.chain_of_verification(
                 system_prompt,
                 user_prompt,
+                response_model=str,
                 create_new_statements=True,
             )
         else:
diff --git a/src/ctk_functions/routers/intake/views.py b/src/ctk_functions/routers/intake/views.py
index 48b9b1e..cef0fef 100644
--- a/src/ctk_functions/routers/intake/views.py
+++ b/src/ctk_functions/routers/intake/views.py
@@ -5,7 +5,7 @@
 import fastapi

 from ctk_functions.core import config
-from ctk_functions.microservices import llm
+from ctk_functions.microservices import language_models
 from ctk_functions.routers.intake import controller

 logger = config.get_logger()
@@ -15,7 +15,7 @@
 @router.get("/intake-report/{mrn}")
 async def post_language_tool(
     mrn: str,
-    x_model: Annotated[llm.VALID_LLM_MODELS, fastapi.Header()],
+    x_model: Annotated[language_models.VALID_MODELS, fastapi.Header()],
 ) -> fastapi.Response:
     """POST endpoint for markdown2docx.
diff --git a/src/ctk_functions/routers/llm/controller.py b/src/ctk_functions/routers/llm/controller.py
index 558d6e0..09d1a3f 100644
--- a/src/ctk_functions/routers/llm/controller.py
+++ b/src/ctk_functions/routers/llm/controller.py
@@ -1,13 +1,13 @@
 """Controller for the LLM model."""

 from ctk_functions.core import config
-from ctk_functions.microservices import llm
+from ctk_functions.microservices import language_models

 settings = config.get_settings()


 async def run_llm(
-    model: llm.VALID_LLM_MODELS,
+    model: str,
     system_prompt: str,
     user_prompt: str,
 ) -> str:
@@ -21,7 +21,5 @@ async def run_llm(
     Returns:
         The output text.
     """
-    return await llm.LargeLanguageModel(model=model).run(
-        system_prompt,
-        user_prompt,
-    )
+    client = language_models.get_llm(model)
+    return await client.run(system_prompt, user_prompt)  # type: ignore[no-any-return] # I can't figure out why mypy believes this returns Any type.
diff --git a/src/ctk_functions/routers/llm/schemas.py b/src/ctk_functions/routers/llm/schemas.py
index 5976764..dd1ba2a 100644
--- a/src/ctk_functions/routers/llm/schemas.py
+++ b/src/ctk_functions/routers/llm/schemas.py
@@ -2,7 +2,7 @@

 import pydantic

-from ctk_functions.microservices import llm
+from ctk_functions.microservices import language_models


 class PostLlmRequest(pydantic.BaseModel):
@@ -14,6 +14,6 @@ class PostLlmRequest(pydantic.BaseModel):
         user_prompt: The user's message.
     """

-    model: llm.VALID_LLM_MODELS
+    model: language_models.VALID_MODELS
     system_prompt: str
     user_prompt: str
diff --git a/tests/conftest.py b/tests/conftest.py
index 1c8ecbb..f1f2b2b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,11 +9,11 @@

 @pytest.fixture(scope="session")
 def test_redcap_data() -> redcap.RedCapData:
-    """Returns a dictionary of test data."""
+    """Returns a dictionary of test data compliant with RedCAP."""
     return redcap.get_intake_data("mock")


 @pytest.fixture(scope="session")
 def client() -> testclient.TestClient:
-    """Returns a test client."""
+    """Returns a test client for the FastAPI server."""
     return testclient.TestClient(app.app)
diff --git a/tests/endpoint/test_intake.py b/tests/endpoint/test_intake.py
index 698ec0c..8ba0836 100644
--- a/tests/endpoint/test_intake.py
+++ b/tests/endpoint/test_intake.py
@@ -21,15 +21,15 @@ def test_intake_with_model(
 ) -> None:
     """Tests whether the GET intake endpoint works with a model."""
     mocker.patch(
-        "ctk_functions.microservices.llm.LargeLanguageModel.chain_of_verification",
+        "cloai.LargeLanguageModel.chain_of_verification",
         return_value="cov",
     )
     mocker.patch(
-        "ctk_functions.microservices.llm.LargeLanguageModel.call_instructor",
+        "cloai.LargeLanguageModel.call_instructor",
         return_value="instructor",
     )
     mocker.patch(
-        "ctk_functions.microservices.llm.LargeLanguageModel.run",
+        "cloai.LargeLanguageModel.run",
         return_value="run",
     )

diff --git a/tests/endpoint/test_llm.py b/tests/endpoint/test_llm.py
index b2573a1..068f583 100644
--- a/tests/endpoint/test_llm.py
+++ b/tests/endpoint/test_llm.py
@@ -9,7 +9,7 @@
 def test_llm(client: testclient.TestClient, mocker: pytest_mock.MockerFixture) -> None:
     """Test the LLM endpoint."""
     spy = mocker.patch(
-        "ctk_functions.microservices.aws.ClaudeLlm.run",
+        "cloai.LargeLanguageModel.run",
         return_value="output",
     )
     body = schemas.PostLlmRequest(
diff --git a/tests/integration/azure_test_intakes.py b/tests/integration/azure_test_intakes.py
index 3f69a13..bf1ed6c 100644
--- a/tests/integration/azure_test_intakes.py
+++ b/tests/integration/azure_test_intakes.py
@@ -1,4 +1,4 @@
-"""Test a few survey IDs."""
+"""Test a few survey IDs. This is run only on Azure Pipelines."""

 import asyncio
 import os
diff --git a/tests/smoke/test_function_app.py b/tests/smoke/test_function_app.py
deleted file mode 100644
index d7b9928..0000000
--- a/tests/smoke/test_function_app.py
+++ /dev/null
@@ -1,37 +0,0 @@
-"""Smoke tests for the function app module."""
-
-import inspect
-import json
-import pathlib
-import sys
-
-from azure.functions.decorators import function_app as azure_function_app
-
-import ctk_functions
-
-repository_root = pathlib.Path(ctk_functions.__file__).parent.parent.parent
-sys.path.insert(0, str(repository_root))
-
-import function_app  # noqa: E402
-
-
-def get_function_auth_level(
-    function: azure_function_app.FunctionBuilder,
-) -> str:
-    """Gets the auth level of a function."""
-    settings = json.loads(function._function.get_function_json())
-    return settings["bindings"][0]["authLevel"]  # type: ignore[no-any-return]
-
-
-def test_function_auth_level() -> None:
-    """Tests that no function has the anonymous auth level."""
-    endpoints = inspect.getmembers(
-        function_app,
-        lambda attribute: isinstance(attribute, azure_function_app.FunctionBuilder),
-    )
-    allowed_auth_levels = ["FUNCTION", "ADMIN"]
-
-    for name, endpoint in endpoints:
-        auth_level = get_function_auth_level(endpoint)
-
-        assert auth_level in allowed_auth_levels, f"{name} has the wrong auth level."
diff --git a/tests/smoke/test_intake_writer_smoke.py b/tests/smoke/test_intake_writer_smoke.py
index 67bb099..1d56135 100644
--- a/tests/smoke/test_intake_writer_smoke.py
+++ b/tests/smoke/test_intake_writer_smoke.py
@@ -26,15 +26,15 @@ async def intake_document(
 ) -> document.Document:
     """Returns a file-like object for the intake_writer.py module."""
     mocker.patch(
-        "ctk_functions.microservices.azure.AzureLlm.run",
+        "cloai.LargeLanguageModel.run",
         return_value="llm",
     )
     mocker.patch(
-        "ctk_functions.microservices.llm.LargeLanguageModel.chain_of_verification",
+        "cloai.LargeLanguageModel.chain_of_verification",
         return_value="cov",
     )
     mocker.patch(
-        "ctk_functions.microservices.llm.LargeLanguageModel.call_instructor",
+        "cloai.LargeLanguageModel.call_instructor",
         return_value="instructor",
     )
     intake_info = parser.IntakeInformation(test_redcap_data)
@@ -63,7 +63,7 @@ async def test_no_printed_objects(
     assert "<" not in text
     assert ">" not in text
     assert "none" not in text
-    assert "ctk_api" not in text
+    assert "ctk_functions" not in text
     assert "object at 0x" not in text
     assert "replacementtags" not in text
     assert re.match(regex_scientific_notation, text) is None
diff --git a/uv.lock b/uv.lock
index 5bf716e..77120fe 100644
--- a/uv.lock
+++ b/uv.lock
@@ -120,15 +120,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/89/aa/ab0f7891a01eeb2d2e338ae8fecbe57fcebea1a24dbb64d45801bfab481d/attrs-24.3.0-py3-none-any.whl", hash = "sha256:ac96cd038792094f438ad1f6ff80837353805ac950cd2aa0e0625ef19850c308", size = 63397 },
 ]

-[[package]]
-name = "azure-functions"
-version = "1.21.3"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/ee/6a/614f670cd8d28c68f744d849fafb654e6ab456954eb7160ddee25bfaa373/azure-functions-1.21.3.tar.gz", hash = "sha256:c359b9dbd2998c84d8595e31a28ffad4e8ca8dcfb4a3798327f776a67f964351", size = 205342 }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/85/9f/2ab4ee66cb803d48b4a685a2c0a1a20b8f54f9c8a4ac474ffe1ebc0f9466/azure_functions-1.21.3-py3-none-any.whl", hash = "sha256:6c45f5e61fe59328c81928a428d43c838e9ead52f1cde1628fcabb372fa10cc8", size = 185657 },
-]
-
 [[package]]
 name = "blis"
 version = "1.1.0"
@@ -238,6 +229,21 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188 },
 ]

+[[package]]
+name = "cloai"
+version = "1.0.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "anthropic", extra = ["bedrock"] },
+    { name = "httpx" },
+    { name = "instructor", extra = ["anthropic"] },
+    { name = "openai" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/3d/e7/d8e7d0f56eb045ae528e70602a54f5ed2947d2403728e187625ababe6759/cloai-1.0.0.tar.gz", hash = "sha256:b0cd28827385cb385ab71dc6ea875a277eee5a52f4016d2f0a6e89f50ed6702d", size = 29715 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/21/c3/e339910bbc8f62b4f989a4c8f0e5d2df6958c414d8722fd868bc719c425d/cloai-1.0.0-py3-none-any.whl", hash = "sha256:68eea48c336d85546cc11723138ad06f4bbd814d6497872cb528c663bfe0f802", size = 20463 },
+]
+
 [[package]]
 name = "cloudpathlib"
 version = "0.20.0"
@@ -311,16 +317,12 @@ source = { editable = "." }
 dependencies = [
     { name = "aiofiles" },
     { name = "aiohttp" },
-    { name = "anthropic", extra = ["bedrock"] },
-    { name = "azure-functions" },
+    { name = "cloai" },
     { name = "cmi-docx" },
     { name = "en-core-web-sm" },
     { name = "fastapi", extra = ["standard"] },
-    { name = "httpx" },
-    { name = "instructor", extra = ["anthropic"] },
     { name = "jsonpickle" },
     { name = "language-tool-python" },
-    { name = "openai" },
     { name = "pycap" },
     { name = "pydantic" },
     { name = "pydantic-settings" },
@@ -350,16 +352,12 @@ dev = [
 requires-dist = [
     { name = "aiofiles", specifier = ">=24.1.0" },
     { name = "aiohttp", specifier = ">=3.11.11" },
-    { name = "anthropic", extras = ["bedrock"], specifier = ">=0.37.1" },
-    { name = "azure-functions", specifier = ">=1.21.3" },
+    { name = "cloai", specifier = ">=1.0.0" },
     { name = "cmi-docx", specifier = ">=0.3.7" },
     { name = "en-core-web-sm", url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0.tar.gz" },
     { name = "fastapi", extras = ["standard"], specifier = ">=0.115.6" },
-    { name = "httpx", specifier = "==0.27" },
-    { name = "instructor", extras = ["anthropic"], specifier = ">1.6" },
     { name = "jsonpickle", specifier = ">=4.0.1" },
     { name = "language-tool-python", specifier = ">=2.8.1" },
-    { name = "openai", specifier = ">=0.27.10" },
     { name = "pycap", specifier = ">=2.6.0" },
     { name = "pydantic", specifier = ">=2.10.4" },
     { name = "pydantic-settings", specifier = ">=2.7.0" },
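
A minimal usage sketch of the language_models helper introduced by this patch (illustrative only, not part of the patch itself; the model name and prompts are placeholders):

    import asyncio

    from ctk_functions.microservices import language_models


    async def main() -> None:
        # "gpt-4o" routes to the cloai Azure client; the "anthropic.*" names in
        # VALID_MODELS route to the cloai Anthropic Bedrock client instead.
        llm = language_models.get_llm("gpt-4o")

        # cloai.LargeLanguageModel.run takes a system prompt and a user prompt,
        # replacing the removed ctk_functions.microservices.llm.LargeLanguageModel.run.
        reply = await llm.run(
            "You are a concise assistant.",  # system prompt
            "Summarize the intake note in one sentence.",  # user prompt
        )
        print(reply)


    asyncio.run(main())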