forked from cbizon/Bagel
Merge pull request #3 from RENCI-NER/langchain
Initial langchain implementation
Showing 9 changed files with 423 additions and 0 deletions.
@@ -0,0 +1,32 @@
name: 'Release a new version to Github Packages'

on:
  release:
    types: [published]

env:
  REGISTRY: ghcr.io

jobs:
  push_to_registry:
    name: Push Docker image to GitHub Packages tagged with "latest" and version number.
    runs-on: ubuntu-latest
    steps:
      - name: Check out the repo
        uses: actions/checkout@v2
      - name: Get the version
        id: get_version
        # ::set-output is deprecated; write to $GITHUB_OUTPUT instead
        run: echo "VERSION=${GITHUB_REF/refs\/tags\//}" >> "$GITHUB_OUTPUT"
      - name: Login to ghcr
        uses: docker/login-action@v1
        with:
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
      - name: Build and push base image
        uses: docker/build-push-action@v5
        with:
          context: .
          push: true
          tags: ghcr.io/renci-ner/bagel:${{ steps.get_version.outputs.VERSION }}
          # NOTE: no "meta_base_image" step is defined in this job, so this expression resolves to an empty value
          labels: ${{ steps.meta_base_image.outputs.labels }}
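The "Get the version" step uses bash parameter substitution to strip the refs/tags/ prefix from GITHUB_REF. A minimal Python sketch of the same transformation, with an assumed example tag:

# Illustration only: what ${GITHUB_REF/refs\/tags\//} does in the step above.
github_ref = "refs/tags/v0.1.0"  # hypothetical value on a published release
version = github_ref.removeprefix("refs/tags/")  # Python 3.9+
assert version == "v0.1.0"  # becomes the image tag ghcr.io/renci-ner/bagel:v0.1.0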
@@ -0,0 +1,20 @@ | ||
FROM ghcr.io/translatorsri/renci-python-image:3.11 | ||
ARG BRANCH=main | ||
|
||
RUN pip install --upgrade pip | ||
|
||
ENV USER nru | ||
ENV HOME /home/$USER | ||
|
||
USER $USER | ||
WORKDIR $HOME | ||
|
||
ENV PATH=$HOME/.local/bin:$PATH | ||
|
||
COPY --chown=$USER . bagel/ | ||
WORKDIR $HOME/bagel | ||
ENV PYTHONPATH=$HOME/bagel/src | ||
RUN pip install -r requirements.txt | ||
ENTRYPOINT python src/server.py | ||
|
||
|
@@ -0,0 +1,16 @@
prompts:
  - name: "ask_classes"
    version: ""

openai_config:
  llm_model_name: "gpt-4o-2024-05-13"
  organization: ""
  access_key: ""
  llm_model_args: {}

ollama_config:
  llm_model_name: "llama3"
  ollama_base_url: "https://ollama.apps.renci.org"
  llm_model_args: {}

langServe: true
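For reference, a minimal sketch of how this YAML is parsed into the Settings model defined in config.py below (this mirrors load_settings(); the relative path is an assumption):

import yaml
from config import Settings  # importing config also runs its own load_settings()

with open("settings.yaml") as stream:
    settings = Settings(**yaml.load(stream, yaml.FullLoader))

print(settings.openai_config.llm_model_name)  # "gpt-4o-2024-05-13"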
@@ -0,0 +1,130 @@
from prompt import load_prompts
from config import settings, OLLAMAConfig, OpenAIConfig
from langchain_openai import ChatOpenAI
from langchain_community.llms import Ollama
from langchain.llms import BaseLLM
from langchain.prompts import Prompt
from models import SynonymListContext, SynonymClassesResponse, Entity
from langchain_core.output_parsers import JsonOutputParser
from typing import Any, Dict, List


def get_ollama_llm(ollama_config: OLLAMAConfig):
    """
    Get an instance of the Ollama LLM class.
    :param ollama_config: configuration for the Ollama instance
    :return: an Ollama client configured from settings
    """
    return Ollama(
        base_url=ollama_config.ollama_base_url,
        model=ollama_config.llm_model_name,
        **ollama_config.llm_model_args
    )


def get_openai_llm(openai_config: OpenAIConfig):
    """
    Get an instance of the ChatOpenAI class.
    :param openai_config: configuration for the OpenAI client
    :return: a ChatOpenAI client configured from settings
    """
    return ChatOpenAI(
        api_key=openai_config.access_key,
        organization=openai_config.organization,
        model=openai_config.llm_model_name,
        **openai_config.llm_model_args,
    )


class LLMHelper:
    @classmethod
    async def ask(cls, llm: BaseLLM, prompt: Prompt, synonym_context: SynonymListContext):
        chain = cls.get_chain(prompt=prompt, llm=llm)
        response = await chain.ainvoke({
            'text': synonym_context.text,
            'term': synonym_context.entity,
            'synonym_list': synonym_context.pretty_print_synonyms()
        }, verbose=True)
        return LLMHelper.re_map_responses(synonym_context.synonyms, response)

    @classmethod
    def re_map_responses(cls, synonym_list: List[Entity], llm_response: List[Dict[str, Any]]) -> List[dict]:
        # Map LLM response items back to their entities via the color codes
        # assigned in SynonymListContext.
        by_color = {
            entity.color_code: entity for entity in synonym_list
        }
        entities = []
        for item in llm_response:
            synonym_type = item['synonym_type']
            entity = by_color[item['color_code']]
            final = entity.dict()
            final.update({
                "synonym_type": synonym_type
            })
            del final['color_code']
            entities.append(final)
        return entities

    @classmethod
    async def ask_batch(cls, llm: BaseLLM, prompt: Prompt, synonym_contexts: List[SynonymListContext]):
        chain = cls.get_chain(prompt=prompt, llm=llm)
        responses = await chain.abatch([{
            'text': synonym_context.text,
            'term': synonym_context.entity,
            'synonym_list': synonym_context.pretty_print_synonyms()
        } for synonym_context in synonym_contexts])
        results = []
        for synonym_context, response in zip(synonym_contexts, responses):
            results.append(LLMHelper.re_map_responses(synonym_context.synonyms, response))
        return results

    @classmethod
    def get_chain(cls, prompt: Prompt, llm: BaseLLM, model_name: str = ""):
        chain = (prompt | llm | JsonOutputParser(pydantic_object=SynonymClassesResponse))
        chain.name = prompt.metadata['lc_hub_repo'] + '_' + model_name
        return chain


class ChainFactory:
    chains = {}

    @classmethod
    def get_llms(cls):
        # Add additional LLMs here as (model_name, llm) pairs.
        return [
            (settings.openai_config.llm_model_name, get_openai_llm(settings.openai_config)),
            (settings.ollama_config.llm_model_name, get_ollama_llm(settings.ollama_config))
        ]

    @classmethod
    def init_chains(cls):
        # Build one chain per (prompt, LLM) combination.
        prompts = load_prompts(settings.prompts)
        for key, value in prompts.items():
            cls.chains[key] = [LLMHelper.get_chain(value, llm=llm, model_name=model_name)
                               for model_name, llm in cls.get_llms()]

    @classmethod
    def get_chain(cls, prompt_name):
        if not cls.chains:
            cls.init_chains()
        if prompt_name not in cls.chains:
            raise ValueError(f"Prompt {prompt_name} not found locally or in hub.")
        return cls.chains[prompt_name]

    @classmethod
    def get_all_chains(cls):
        if not cls.chains:
            cls.init_chains()
        all_chains = []
        for prompt_name, chains in cls.chains.items():
            all_chains += chains
        return all_chains
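A hedged usage sketch of the helpers above; the prompt name, text, and synonym values are hypothetical, and running it requires a reachable Ollama endpoint plus LangChain hub access:

import asyncio
from langchain import hub
from models import Entity, SynonymListContext

async def main():
    llm = get_ollama_llm(settings.ollama_config)
    prompt = hub.pull("bagel/ask_classes")  # assumes this prompt exists on the hub
    context = SynonymListContext(
        text="Patients were treated with aspirin.",
        entity="aspirin",
        synonyms=[Entity(label="acetylsalicylic acid", identifier="CHEBI:15365")],
    )
    # ask() runs prompt | llm | JSON parser, then maps color codes back to entities
    print(await LLMHelper.ask(llm, prompt, context))

asyncio.run(main())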
@@ -0,0 +1,55 @@
from pydantic import BaseModel, Field
from typing import List
import yaml, pathlib, os


class PromptSettings(BaseModel):
    """Basic class for prompt location and version."""
    name: str = Field(..., description="Name of the prompt")
    version: str = Field("", description="Version of the prompt")
    # The langchainhub owner to pull prompts from, e.g. "bagel";
    # see: https://smith.langchain.com/hub/bagel
    hub_base: str = Field('bagel', description="Langchain hub org name")

    @classmethod
    def from_str(cls, prompt_address: str) -> "PromptSettings":
        # e.g. "bagel/ask_classes:v1" -> hub_base="bagel", name="ask_classes", version="v1"
        hub_base, name = prompt_address.split('/')
        version = ""
        if ':' in name:
            name, version = name.split(':')
        return PromptSettings(name=name, version=version, hub_base=hub_base)

    def __str__(self) -> str:
        name = self.name + ':' + self.version if self.version else self.name
        return self.hub_base + '/' + name


class OpenAIConfig(BaseModel):
    llm_model_name: str = Field(default="gpt4o", description="Name of the model")
    organization: str = Field(default="org id", description="OPENAI organization")
    access_key: str = Field(default="access key", description="OPENAI access key")
    llm_model_args: dict = Field(default_factory=dict, description="Arguments to pass to the model")


class OLLAMAConfig(BaseModel):
    llm_model_name: str = Field(default="llama3", description="Name of the model")
    ollama_base_url: str = Field(default="https://ollama.apps.renci.org", description="URL of the OLLAMA instance")
    llm_model_args: dict = Field(default_factory=dict, description="Arguments to pass to the model")


class Settings(BaseModel):
    prompts: List[PromptSettings] = Field(default_factory=list, description="Prompts to be used in the application")
    openai_config: OpenAIConfig = Field(default=None, description="OpenAI client configuration")
    ollama_config: OLLAMAConfig = Field(default=None, description="Ollama client configuration")
    langServe: bool = True


def load_settings():
    yaml_path = pathlib.Path(os.path.dirname(__file__), '..', 'settings.yaml')
    with open(str(yaml_path), 'r') as stream:
        _settings = Settings(**yaml.load(stream, yaml.FullLoader))
    return _settings


# app settings
settings = load_settings()
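A quick round-trip sketch for PromptSettings (the address is hypothetical):

ps = PromptSettings.from_str("bagel/ask_classes:v1")
assert (ps.hub_base, ps.name, ps.version) == ("bagel", "ask_classes", "v1")
assert str(ps) == "bagel/ask_classes:v1"  # __str__ reassembles the hub address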
@@ -0,0 +1,69 @@
from pydantic import BaseModel, Field
from typing import List


# Color names used to tag each synonym so LLM responses can be mapped back
# to entities; supports at most len(colors) synonyms per context.
colors = [
    "red",
    "blue",
    "green",
    "yellow",
    "purple",
    "orange",
    "pink",
    "brown",
    "black",
    "white",
    "gray",
    "cyan",
    "magenta",
    "lime",
    "maroon",
    "navy",
    "olive",
    "teal",
    "aqua",
    "gold"
]


class Entity(BaseModel):
    """Entity class: represents a normalized entity (synonym)."""
    label: str = Field(..., description="Label of the Entity.")
    identifier: str = Field("", description="Curie identifier of the Entity.")
    description: str = Field("", description="Formal description (definition) of the Entity.")
    entity_type: str = Field("", description="Type of the Entity.")
    color_code: str = Field("", description="Color coding for mapping back items.")


class SynonymListContext(BaseModel):
    text: str = Field(..., description="Body of text containing entity.")
    entity: str = Field(..., description="Entity identified in text.")
    synonyms: List[Entity] = Field(..., description="Entities linked to the target entity, to be re-ranked.")

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Assign each synonym a unique color code in construction order.
        for index, synonym in enumerate(self.synonyms):
            synonym.color_code = colors[index]

    def pretty_print_synonyms(self):
        string = "\n"
        for synonym in self.synonyms:
            string += f"\t- {synonym.label}"
            string += f" ({synonym.entity_type})" if synonym.entity_type else ""
            # Always print the color code: re_map_responses relies on the LLM echoing it.
            string += f" ({synonym.color_code})"
            string += f" : {synonym.description}" if synonym.description else ""
            string += "\n"
        return string


class SynonymClassResponse(BaseModel):
    synonym: str = Field(..., description="Synonym")
    vocabulary_class: str = Field(..., description="Vocabulary class")
    synonym_type: str = Field(..., description="Synonym type")


class SynonymClassesResponse(BaseModel):
    synonyms: List[SynonymClassResponse] = Field(..., description="Synonyms")
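A short sketch of color-code assignment and pretty printing (all values are made up):

ctx = SynonymListContext(
    text="Patients were treated with aspirin.",
    entity="aspirin",
    synonyms=[
        Entity(label="acetylsalicylic acid", entity_type="ChemicalEntity"),
        Entity(label="ASA", entity_type="ChemicalEntity"),
    ],
)
print(ctx.synonyms[0].color_code)  # "red"  (assigned in construction order)
print(ctx.synonyms[1].color_code)  # "blue"
print(ctx.pretty_print_synonyms())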
@@ -0,0 +1,17 @@
# set the LANGCHAIN_API_KEY environment variable (create a key in settings)
from typing import List

from langchain import hub
from config import PromptSettings


def load_prompt_from_hub(prompt_name):
    """Loads a prompt from langchain hub and returns a prompt object."""
    return hub.pull(str(prompt_name))


def load_prompts(prompts: List[PromptSettings]):
    """Loads all prompts into a name -> prompt mapping."""
    prompt_mapping = {}
    for prompt in prompts:
        prompt_mapping[prompt.name] = load_prompt_from_hub(str(prompt))
    return prompt_mapping
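A usage sketch (requires LANGCHAIN_API_KEY and hub access; the prompt name matches settings.yaml):

prompts = load_prompts([PromptSettings(name="ask_classes", version="")])
ask_classes_prompt = prompts["ask_classes"]  # pulled as "bagel/ask_classes"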