Claude_diag_spec_GeneralUser.py — 59 lines (37 loc) · 2.41 KB
## PREDICT SPECIALTY AND DIAGNOSIS GENERAL USER CASE
## For each patient row, ask Claude models on AWS Bedrock for the three most
## likely medical specialties and diagnoses, storing one result column per
## model and checkpointing the CSV after each model completes.
## import libraries
import os

import pandas as pd
from tqdm import tqdm
import boto3
from langchain.prompts import PromptTemplate
from langchain_aws import ChatBedrock

## Import Functions
from functions.LLM_predictions import get_prediction_GeneralUser

## Load Data from create_ground_truth_specialty.py. Predictions are written
## back to this same file, so re-running overwrites earlier prediction columns.
DATA_PATH = "MIMIC-IV-Ext-Diagnosis-Specialty.csv"
df = pd.read_csv(DATA_PATH)

## Define the prompt template (runtime string — kept verbatim; downstream
## parsing relies on the <specialty>/<diagnosis> tags it requests).
prompt = """You are an experienced healthcare professional with expertise in determining the medical specialty and diagnosis based on a patient's history of present illness and personal information. Review the data and identify the three most likely, distinct specialties to manage the condition, followed by the three most likely diagnoses. List specialties first, in order of likelihood, then diagnoses.
Respond with the specialties in <specialty> tags and the diagnoses in <diagnosis> tags.
History of present illness: {hpi} and personal information: {patient_info}."""

## set AWS credentials
## SECURITY NOTE(review): never commit real credentials to source control.
## setdefault lets credentials already exported in the environment take
## precedence instead of being clobbered by these placeholder strings.
os.environ.setdefault("AWS_ACCESS_KEY_ID", "Enter your AWS Access Key ID")
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "Enter your AWS Secret Access Key")

prompt_chain = PromptTemplate(template=prompt, input_variables=["hpi", "patient_info"])
client = boto3.client(service_name="bedrock-runtime", region_name="us-east-1")

## Bedrock model id -> output column name. temperature=0 keeps the sampling
## greedy so repeated runs are as reproducible as the service allows.
MODELS = {
    "anthropic.claude-3-5-sonnet-20240620-v1:0": "diag_spec_Claude3.5",  # Claude Sonnet 3.5
    "anthropic.claude-3-sonnet-20240229-v1:0": "diag_spec_Claude3",      # Claude Sonnet 3
    "anthropic.claude-3-haiku-20240307-v1:0": "diag_spec_Haiku",         # Claude 3 Haiku
}

tqdm.pandas()
for model_id, column in MODELS.items():
    chain = prompt_chain | ChatBedrock(
        model_id=model_id, model_kwargs={"temperature": 0}, client=client
    )
    ## Bind the chain via a default argument to avoid the late-binding
    ## closure pitfall; progress_apply shows a tqdm bar per model.
    df[column] = df.progress_apply(
        lambda row, c=chain: get_prediction_GeneralUser(row, c), axis=1
    )
    ## Checkpoint after each model so a failure mid-run keeps earlier results.
    df.to_csv(DATA_PATH, index=False)