Merge pull request #3 from RENCI-NER/langchain
Initial langchain implementation
YaphetKG authored Jul 3, 2024
2 parents 2ec70d3 + 11d740c commit d840c29
Showing 9 changed files with 423 additions and 0 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/release-docker.yaml
@@ -0,0 +1,32 @@
name: 'Release a new version to Github Packages'

on:
  release:
    types: [published]

env:
  REGISTRY: ghcr.io

jobs:
  push_to_registry:
    name: Push Docker image to GitHub Packages tagged with "latest" and version number.
    runs-on: ubuntu-latest
    steps:
      - name: Check out the repo
        uses: actions/checkout@v2
      - name: Get the version
        id: get_version
        run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT"
      - name: Login to ghcr
        uses: docker/login-action@v1
        with:
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
      - name: Build and push base image
        uses: docker/build-push-action@v5
        with:
          context: .
          push: true
          tags: ghcr.io/renci-ner/bagel:${{ steps.get_version.outputs.VERSION }}
20 changes: 20 additions & 0 deletions Dockerfile
@@ -0,0 +1,20 @@
FROM ghcr.io/translatorsri/renci-python-image:3.11
ARG BRANCH=main

RUN pip install --upgrade pip

ENV USER nru
ENV HOME /home/$USER

USER $USER
WORKDIR $HOME

ENV PATH=$HOME/.local/bin:$PATH

COPY --chown=$USER . bagel/
WORKDIR $HOME/bagel
ENV PYTHONPATH=$HOME/bagel/src
RUN pip install -r requirements.txt
ENTRYPOINT python src/server.py


5 changes: 5 additions & 0 deletions requirements.txt
@@ -4,4 +4,9 @@
langchain==0.2.1
langchain-community==0.2.1
langchain-core==0.2.3
langchain-text-splitters==0.2.0
pydantic==1.10.13
langchainhub==0.1.20
langchain-openai==0.1.8
langserve[all]==0.2.2
fastapi==0.111.0
git+https://github.com/RENCI-NER/benchmarks.git
16 changes: 16 additions & 0 deletions settings.yaml
@@ -0,0 +1,16 @@
prompts:
  - name: "ask_classes"
    version: ""

openai_config:
  llm_model_name: "gpt-4o-2024-05-13"
  organization: ""
  access_key: ""
  llm_model_args: {}

ollama_config:
  llm_model_name: "llama3"
  ollama_base_url: "https://ollama.apps.renci.org"
  llm_model_args: {}

langServe: true
130 changes: 130 additions & 0 deletions src/chain.py
@@ -0,0 +1,130 @@
from prompt import load_prompts
from config import settings, OLLAMAConfig, OpenAIConfig
from langchain_openai import ChatOpenAI
from langchain_community.llms import Ollama
from langchain_core.language_models import BaseLanguageModel
from langchain.prompts import Prompt
from models import SynonymListContext, SynonymClassesResponse, Entity
from langchain_core.output_parsers import JsonOutputParser
from typing import Any, Dict, List


def get_ollama_llm(ollama_config: OLLAMAConfig):
    """
    Get an instance of the Ollama LLM class.
    :param ollama_config: configuration for the Ollama class
    :return: an Ollama instance
    """
    return Ollama(
        base_url=ollama_config.ollama_base_url,
        model=ollama_config.llm_model_name,
        **ollama_config.llm_model_args
    )


def get_openai_llm(openai_config: OpenAIConfig):
    """
    Get an instance of the ChatOpenAI class.
    :param openai_config: configuration for the OpenAI class
    :return: a ChatOpenAI instance
    """
    return ChatOpenAI(
        api_key=openai_config.access_key,
        organization=openai_config.organization,
        model=openai_config.llm_model_name,
        **openai_config.llm_model_args,
    )


class LLMHelper:
    @classmethod
    async def ask(cls, llm: BaseLanguageModel, prompt: Prompt, synonym_context: SynonymListContext):
        chain = cls.get_chain(prompt=prompt, llm=llm)
        response = await chain.ainvoke({
            'text': synonym_context.text,
            'term': synonym_context.entity,
            'synonym_list': synonym_context.pretty_print_synonyms()
        }, verbose=True)
        return cls.re_map_responses(synonym_context.synonyms, response)

    @classmethod
    def re_map_responses(cls, synonym_list: List[Entity], llm_response: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Map entity ids back onto LLM responses via their color codes.
        by_color = {
            entity.color_code: entity for entity in synonym_list
        }
        entities = []
        for item in llm_response:
            synonym_type = item['synonym_type']
            entity = by_color[item['color_code']]
            final = entity.dict()
            final.update({
                "synonym_type": synonym_type
            })
            del final['color_code']
            entities.append(final)
        return entities

    @classmethod
    async def ask_batch(cls, llm: BaseLanguageModel, prompt: Prompt, synonym_contexts: List[SynonymListContext]):
        chain = cls.get_chain(prompt=prompt, llm=llm)
        responses = await chain.abatch([{
            'text': synonym_context.text,
            'term': synonym_context.entity,
            'synonym_list': synonym_context.pretty_print_synonyms()
        } for synonym_context in synonym_contexts])
        results = []
        for synonym_context, response in zip(synonym_contexts, responses):
            results.append(cls.re_map_responses(synonym_context.synonyms, response))
        return results

    @classmethod
    def get_chain(cls, prompt: Prompt, llm: BaseLanguageModel, model_name: str = ""):
        chain = (prompt | llm | JsonOutputParser(pydantic_object=SynonymClassesResponse))
        chain.name = prompt.metadata['lc_hub_repo'] + '_' + model_name
        return chain


class ChainFactory:
    chains = {}

    @classmethod
    def get_llms(cls):
        # Add additional LLMs here.
        return [
            (settings.openai_config.llm_model_name, get_openai_llm(settings.openai_config)),
            (settings.ollama_config.llm_model_name, get_ollama_llm(settings.ollama_config))
        ]

    @classmethod
    def init_chains(cls):
        prompts = load_prompts(settings.prompts)
        for prompt_name, prompt_obj in prompts.items():
            cls.chains[prompt_name] = [
                LLMHelper.get_chain(prompt_obj, llm=llm, model_name=model_name)
                for model_name, llm in cls.get_llms()
            ]

    @classmethod
    def get_chain(cls, prompt_name):
        if not cls.chains:
            cls.init_chains()
        if prompt_name not in cls.chains:
            raise ValueError(f"Prompt {prompt_name} not found locally or in hub.")
        return cls.chains[prompt_name]

    @classmethod
    def get_all_chains(cls):
        if not cls.chains:
            cls.init_chains()
        all_chains = []
        for prompt_name, chains in cls.chains.items():
            all_chains += chains
        return all_chains
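For orientation, here is a minimal driver sketch showing how the factory and helper might be used together. Everything in it is illustrative: the sentence, the candidate entities, and the identifiers are made up, and it assumes the "ask_classes" prompt from settings.yaml is available on the hub.

# Hypothetical driver for ChainFactory; names and data are illustrative only.
import asyncio

from chain import ChainFactory
from models import Entity, SynonymListContext

context = SynonymListContext(
    text="Patients were treated with acetaminophen for fever.",  # made-up sentence
    entity="acetaminophen",
    synonyms=[
        Entity(label="Acetaminophen", identifier="CHEBI:46195"),  # made-up candidates
        Entity(label="Paracetamol", identifier="RXCUI:161"),
    ],
)

async def main():
    # get_chain returns one chain per configured LLM (OpenAI and Ollama above).
    for chain in ChainFactory.get_chain("ask_classes"):
        response = await chain.ainvoke({
            "text": context.text,
            "term": context.entity,
            "synonym_list": context.pretty_print_synonyms(),
        })
        print(response)

asyncio.run(main())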



55 changes: 55 additions & 0 deletions src/config.py
@@ -0,0 +1,55 @@
from pydantic import BaseModel, Field
from typing import List
import yaml, pathlib, os


class PromptSettings(BaseModel):
    """ Basic class for prompt location and version. """
    name: str = Field(..., description="Name of the prompt")
    version: str = Field("", description="Version of the prompt")
    # The langchainhub owner address to pull prompts from, e.g. "bagel";
    # see: https://smith.langchain.com/hub/bagel
    hub_base: str = Field('bagel', description="Langchain hub org name")

    @classmethod
    def from_str(cls, prompt_address: str) -> "PromptSettings":
        hub_base, name = prompt_address.split('/')
        version = ""
        if ':' in name:
            name, version = name.split(':')
        return PromptSettings(name=name, version=version, hub_base=hub_base)

    def __str__(self) -> str:
        name = self.name + ':' + self.version if self.version else self.name
        return self.hub_base + '/' + name


class OpenAIConfig(BaseModel):
    llm_model_name: str = Field(default="gpt4o", description="Name of the model")
    organization: str = Field(default="org id", description="OPENAI organization")
    access_key: str = Field(default="access key", description="OPENAI access key")
    llm_model_args: dict = Field(default_factory=dict, description="Arguments to pass to the model")


class OLLAMAConfig(BaseModel):
    llm_model_name: str = Field(default="llama3", description="Name of the model")
    ollama_base_url: str = Field(default="https://ollama.apps.renci.org", description="URL of the OLLAMA instance")
    llm_model_args: dict = Field(default_factory=dict, description="Arguments to pass to the model")


class Settings(BaseModel):
    prompts: List[PromptSettings] = Field(default=[], description="Prompts to be used in the application")
    openai_config: OpenAIConfig = Field(default=None, description="OpenAI connection settings")
    ollama_config: OLLAMAConfig = Field(default=None, description="Ollama connection settings")
    langServe: bool = True


def load_settings():
    yaml_path = pathlib.Path(os.path.dirname(__file__), '..', 'settings.yaml')
    with open(str(yaml_path), 'r') as stream:
        _settings = Settings(**yaml.load(stream, yaml.FullLoader))
    return _settings


# app settings
settings = load_settings()
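As a quick illustration of the hub address format that from_str and __str__ implement, a small round-trip sketch (the version hash is made up):

from config import PromptSettings

ps = PromptSettings(name="ask_classes", version="abc123", hub_base="bagel")
print(str(ps))                                   # -> "bagel/ask_classes:abc123"

parsed = PromptSettings.from_str("bagel/ask_classes:abc123")
assert (parsed.hub_base, parsed.name, parsed.version) == ("bagel", "ask_classes", "abc123")

# Without a version, the address omits the ":" suffix.
print(str(PromptSettings(name="ask_classes")))   # -> "bagel/ask_classes"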
69 changes: 69 additions & 0 deletions src/models.py
@@ -0,0 +1,69 @@
from pydantic import BaseModel, Field
from typing import List


colors = [
    "red",
    "blue",
    "green",
    "yellow",
    "purple",
    "orange",
    "pink",
    "brown",
    "black",
    "white",
    "gray",
    "cyan",
    "magenta",
    "lime",
    "maroon",
    "navy",
    "olive",
    "teal",
    "aqua",
    "gold"
]


class Entity(BaseModel):
    """ Entity class: Represents a normalized entity (Synonym). """
    label: str = Field(..., description="Label of the Entity.")
    identifier: str = Field("", description="Curie identifier of the Entity.")
    description: str = Field("", description="Formal description (definition) of the Entity.")
    entity_type: str = Field("", description="Type of the Entity.")
    color_code: str = Field("", description="Color coding for mapping back items.")


class SynonymListContext(BaseModel):
    text: str = Field(..., description="Body of text containing entity.")
    entity: str = Field(..., description="Entity identified in text.")
    synonyms: List[Entity] = Field(..., description="Entities linked to the target entity, to be re-ranked.")

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for index, synonym in enumerate(self.synonyms):
            synonym.color_code = colors[index]

    def pretty_print_synonyms(self):
        string = "\n"
        for synonym in self.synonyms:
            string += f"\t- {synonym.label}"
            string += f" ({synonym.entity_type}) ({synonym.color_code})" if synonym.entity_type else ""
            string += f" : {synonym.description}" if synonym.description else ""
            string += "\n"
        return string


class SynonymClassResponse(BaseModel):
    synonym: str = Field(..., description="Synonym")
    vocabulary_class: str = Field(..., description="Vocabulary class")
    synonym_type: str = Field(..., description="Synonym type")


class SynonymClassesResponse(BaseModel):
    synonyms: List[SynonymClassResponse] = Field(..., description="Synonyms")
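To see how color codes are assigned on construction, a small sketch (labels and descriptions are placeholders). Note that colors are assigned positionally, so the colors list caps a context at 20 synonyms.

from models import Entity, SynonymListContext

ctx = SynonymListContext(
    text="...",  # placeholder text
    entity="aspirin",
    synonyms=[
        Entity(label="Aspirin", entity_type="Drug", description="An NSAID."),
        Entity(label="Acetylsalicylic acid", entity_type="Chemical"),
    ],
)
# First synonym gets "red", the second "blue", and so on.
print([s.color_code for s in ctx.synonyms])   # ['red', 'blue']
print(ctx.pretty_print_synonyms())
# 	- Aspirin (Drug) (red) : An NSAID.
# 	- Acetylsalicylic acid (Chemical) (blue)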
17 changes: 17 additions & 0 deletions src/prompt.py
@@ -0,0 +1,17 @@
# Set the LANGCHAIN_API_KEY environment variable (create a key in LangSmith settings).
from langchain import hub
from typing import List

from config import PromptSettings


def load_prompt_from_hub(prompt_name):
    """ Loads a prompt from langchain hub and returns a prompt object. """
    return hub.pull(str(prompt_name))


def load_prompts(prompts: List[PromptSettings]):
    """ Loads all configured prompts, keyed by prompt name. """
    prompt_mapping = {}
    for prompt in prompts:
        prompt_mapping[prompt.name] = load_prompt_from_hub(str(prompt))
    return prompt_mapping
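A minimal sketch of pulling the configured prompts, assuming a LANGCHAIN_API_KEY is set and "bagel/ask_classes" exists on the hub; the printed input variables are what chain.py's inputs suggest, not guaranteed by the hub object itself.

import os

from config import settings
from prompt import load_prompts

os.environ.setdefault("LANGCHAIN_API_KEY", "lsv2-...")  # placeholder key
prompts = load_prompts(settings.prompts)                # e.g. {"ask_classes": <hub prompt>}
ask_classes = prompts["ask_classes"]
print(ask_classes.input_variables)                      # expected: ['synonym_list', 'term', 'text']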
