-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
198 lines (167 loc) · 6.18 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
# main.py
import logging
import random
import uuid
import streamlit as st
import config
from ui import render_file_upload
from utils import get_claude_client
from utils.file_handler import FileProcessingError, process_files
from utils.message_handler import (
add_message_to_history,
clear_conversation,
process_message,
)
from utils.session import (
check_session_expiry,
clear_file_data,
initialize_session_state,
update_last_activity,
)
# Logger scoped to this module's name.
logger = logging.getLogger(__name__)

# Configure the browser tab title and the wide page layout before any other
# Streamlit output is produced (older Streamlit releases require
# set_page_config to be the first Streamlit command in the script).
st.set_page_config(layout="wide", page_title=config.APP_TITLE)
def load_css() -> None:
    """Inject the application's stylesheets into the page.

    Adds a ``<link>`` tag for Font Awesome (served from the cdnjs CDN) and
    inlines the contents of ``styles/main.css`` inside a ``<style>`` tag.

    The stylesheet path is resolved relative to this file instead of the
    process working directory, so the app still finds its CSS when started
    as ``streamlit run path/to/main.py`` from a different directory. The
    file is decoded explicitly as UTF-8 rather than with the platform's
    locale encoding.

    Raises:
        OSError: if ``styles/main.css`` cannot be opened or read.
        UnicodeDecodeError: if the stylesheet is not valid UTF-8.
    """
    from pathlib import Path

    # Load Font Awesome
    st.markdown(
        '<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css">',
        unsafe_allow_html=True,
    )
    # Load custom CSS shipped next to this script.
    css_path = Path(__file__).resolve().parent / "styles" / "main.css"
    css = css_path.read_text(encoding="utf-8")
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
# Session-state key used to carry the "session expired" notice across the
# st.rerun() that follows an expiry; an st.info() issued right before
# st.rerun() would be discarded by the rerun and never seen by the user.
_EXPIRED_NOTICE_KEY = "_session_expired_notice"


def _stream_assistant_reply(client, **process_kwargs) -> str:
    """Stream one assistant reply into a new assistant chat bubble.

    Calls ``process_message`` with the current conversation history and
    *client*, forwarding *process_kwargs* unchanged. Each yielded value
    replaces the previous one in a single placeholder, so the bubble shows
    the most recent yield.

    Returns:
        The last value yielded by ``process_message`` ("" if it yielded
        nothing).
    """
    with st.chat_message("assistant"):
        placeholder = st.empty()
        full_response = ""
        # process_message is invoked inside the chat bubble so anything it
        # renders while streaming lands in the assistant container.
        for response in process_message(
            st.session_state.messages, client, **process_kwargs
        ):
            full_response = response
            placeholder.markdown(response)
    return full_response


def _handle_prompt(prompt: str, uploaded_files, client, system_prompt: str) -> None:
    """Handle one submitted chat prompt together with any attached files.

    Shows the user's message, records it in the history (file names and
    types only, not file contents), streams Claude's reply, records that
    reply, and finally clears the stored file data.

    Returning early from here (blank prompt, file-processing error) only
    ends prompt handling; the caller still renders the conversation
    controls afterwards.
    """
    if not prompt.strip():
        st.info("Please enter a message.")
        return

    update_last_activity()  # Update last activity time

    attached_files = []
    message_id = str(uuid.uuid4())  # Unique ID linking this message to its files
    if uploaded_files:
        try:
            attached_files = process_files(uploaded_files)
        except FileProcessingError as e:
            st.error(str(e))
            return
        # Store files in session state under the message ID.
        st.session_state.files[message_id] = attached_files
        # A new widget key makes Streamlit create a fresh, empty uploader.
        st.session_state.file_uploader_key = random.randint(0, 1000000)

    # Prepare the message content shown in the chat and saved to history.
    display_content = prompt
    if attached_files:
        display_content += "\n\nAttached Files:\n" + "".join(
            f"\n- {file['name']} ({file['type']})" for file in attached_files
        )

    with st.chat_message("user"):
        st.markdown(display_content)
    # Add user message to conversation history (without file contents)
    add_message_to_history("user", display_content, message_id)

    logger.debug(
        "Current message count before processing: %d",
        len(st.session_state.messages),
    )
    full_response = _stream_assistant_reply(
        client,
        user_prompt=prompt,  # Only send the prompt, not file contents
        system_prompt=system_prompt,
        message_id=message_id,
    )
    add_message_to_history("assistant", full_response)
    logger.debug(
        "Current message count after processing: %d",
        len(st.session_state.messages),
    )

    # Clear the files after they have been processed
    clear_file_data()


def _continue_response(client, system_prompt: str) -> None:
    """Ask Claude to continue its last (truncated) reply and store the result.

    The continuation is streamed into a new assistant bubble. In the
    history, it is appended to the last message when that message is from
    the assistant; otherwise it is added as a new assistant message.
    """
    update_last_activity()  # Update last activity time
    st.session_state.max_tokens_reached = False
    continuation = _stream_assistant_reply(
        client,
        continue_last=True,
        system_prompt=system_prompt,
    )
    # Read the history only after streaming, in case process_message
    # replaced the list while it ran.
    messages = st.session_state.messages
    if messages and messages[-1]["role"] == "assistant":
        messages[-1]["content"] += continuation
    else:
        add_message_to_history("assistant", continuation)


def main() -> None:
    """Run one execution of the Streamlit chat application.

    Streamlit re-executes this function on every user interaction. Each run
    initializes session state, renders the page and input widgets, replays
    the conversation history, handles at most one newly submitted prompt,
    and renders the Clear Conversation / Continue Response controls.
    """
    initialize_session_state()
    load_css()
    st.title(config.APP_TITLE)

    # Show the expiry notice queued by the previous run, if any.
    if st.session_state.pop(_EXPIRED_NOTICE_KEY, False):
        st.info("Your session has expired. Starting a new conversation.")

    if check_session_expiry():
        clear_conversation()
        # Queue the notice so it survives the rerun below.
        st.session_state[_EXPIRED_NOTICE_KEY] = True
        st.rerun()

    try:
        client = get_claude_client()
    except Exception as e:
        logger.exception("Error initializing Claude client: %s", e)
        st.error(
            "Failed to initialize Claude client. Please check your Google Cloud credentials and configuration."
        )
        return

    system_prompt = st.text_area(
        "System Prompt (optional)",
        value=st.session_state.get("system_prompt", ""),
        help="Enter a system prompt to guide Claude's behavior.",
    )

    # The uploader's key is replaced after each send (see _handle_prompt),
    # which resets the widget.
    uploaded_files = st.file_uploader(
        "Attach a file",
        type=[
            "txt",
            "py",
            "js",
            "html",
            "css",
            "json",
            "jpg",
            "jpeg",
            "png",
            "md",
            "pdf",
            "xml",
        ],
        accept_multiple_files=True,
        key=st.session_state.file_uploader_key,
    )

    # Display conversation history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input():
        _handle_prompt(prompt, uploaded_files, client, system_prompt)

    if st.session_state.messages and st.button("Clear Conversation"):
        clear_conversation()
        st.rerun()

    # Continue response button (only shown when max tokens were reached)
    if st.session_state.get("max_tokens_reached", False) and st.button(
        "Continue Response"
    ):
        _continue_response(client, system_prompt)

    # Update last activity time
    update_last_activity()


if __name__ == "__main__":
    main()