Skip to content

Commit

Permalink
Merge pull request #64 from KruxAI/chat
Browse files Browse the repository at this point in the history
Add chat playground
  • Loading branch information
aravind10x authored Oct 12, 2024
2 parents f36cb2c + daebed9 commit 1f760e9
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 23 deletions.
84 changes: 62 additions & 22 deletions src/ragbuilder/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import openai
import optuna
import math
import sqlite3
from optuna.storages import RDBStorage
from optuna.trial import TrialState
from tenacity import retry, stop_after_attempt, wait_random_exponential, before_sleep_log, retry_if_exception_type, retry_if_result
Expand Down Expand Up @@ -219,7 +220,7 @@ def rag_builder_bayes_optimization_optuna(**kwargs):
# result=0
logger.info(f"Evaluating RAG Config #{progress_state.get_progress()['current_run']}... (this may take a while)")
rageval=eval.RagEvaluator(
rag_builder, # code for rag function
rag_builder,
test_ds,
llm = llm,
embeddings = embeddings,
Expand All @@ -229,6 +230,7 @@ def rag_builder_bayes_optimization_optuna(**kwargs):
)
result=rageval.evaluate()
logger.debug(f'progress_state={progress_state.get_progress()}')
rag_manager.cache_rag(rageval.id, rag_builder.rag)

if kwargs['include_granular_combos']:
# Objective function for Bayesian optimization on the custom RAG configurations
Expand Down Expand Up @@ -289,8 +291,11 @@ def objective(trial):
if percent_none > 0.2:
logger.warning(f"More than 20% of the records have 'answer_correctness' as None. Skipping this config...")
return float('NaN')

if not progress_state.get_progress()['first_eval_complete']:
progress_state.set_first_eval_complete()

rag_manager.cache_rag(rageval.id, rag_builder.rag)
return result['answer_correctness']
return float('NaN')

Expand Down Expand Up @@ -400,6 +405,7 @@ def rag_builder(**kwargs):
)
result=rageval.evaluate()
logger.debug(f'progress_state={progress_state.get_progress()}')
rag_manager.cache_rag(rageval.id, rag_builder.rag)
if configs['type'] == 'CUSTOM':
for val in configs['configs'].values():
progress_state.increment_progress()
Expand All @@ -422,6 +428,7 @@ def rag_builder(**kwargs):
is_async=RUN_CONFIG_IS_ASYNC
)
result=rageval.evaluate()
rag_manager.cache_rag(rageval.id, rag_builder.rag)
return result

import importlib
Expand Down Expand Up @@ -465,6 +472,19 @@ def __repr__(self):
class RagBuilderException(Exception):
    """Raised when a RAG object cannot be created from generated code."""
    pass

@retry(
    stop=stop_after_attempt(3),
    wait=wait_random_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_result(lambda result: result is None)
        | retry_if_exception_type(openai.APITimeoutError)
        | retry_if_exception_type(openai.APIError)
        | retry_if_exception_type(openai.APIConnectionError)
        | retry_if_exception_type(openai.RateLimitError)
    ),
    before_sleep=before_sleep_log(logger, LOG_LEVEL),
)
def _exec(code):
    """Execute generated RAG code and return the pipeline it builds.

    ``code`` is run with ``exec`` against a fresh local namespace. The code
    is expected to define a callable named ``rag_pipeline``; that callable
    is invoked with no arguments and its return value is returned.

    The call is attempted up to three times (with randomized exponential
    backoff) when the result is ``None`` or one of the listed OpenAI API
    errors is raised.

    Raises:
        KeyError: if ``code`` does not define ``rag_pipeline``.
    """
    # NOTE(review): exec() runs arbitrary Python. Callers pass generated code
    # and code snippets read back from the database -- confirm that source
    # can never contain untrusted input.
    namespace = {}
    exec(code, None, namespace)
    build_pipeline = namespace['rag_pipeline']
    return build_pipeline()

class RagBuilder:
def __init__(self, val):
self.config = val
Expand All @@ -480,12 +500,6 @@ def __init__(self, val):
self.retriever_kwargs=val['retriever_kwargs']
# self.prompt_text = val['prompt_text']

# self.router(Configs) # Calls appropriate code generator calls codeGen Within returns Code string
# namespace={}
# exec(rag_func_str, namespace) # executes code
# ragchain=namespace['ragchain'] catch the func object
# self.runCode=ragchain()

# output of router is the generated code as a string
self.router=rag.codeGen(
framework=self.framework,
Expand All @@ -503,24 +517,11 @@ def __init__(self, val):
try:
# execute the generated code string
logger.debug(f"Generated Code\n{self.router}")
self.rag = self._exec()
self.rag = _exec(self.router)
# print(f"self.rag = {self.rag}")
except Exception as e:
logger.error(f"Error creating RAG object from generated code. ERROR: {e}")
raise RagBuilderException(f"Error creating RAG object from generated code. ERROR: {e}")

@retry(stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, min=4, max=60),
retry=(retry_if_result(lambda result: result is None)
| retry_if_exception_type(openai.APITimeoutError)
| retry_if_exception_type(openai.APIError)
| retry_if_exception_type(openai.APIConnectionError)
| retry_if_exception_type(openai.RateLimitError)),
before_sleep=before_sleep_log(logger, LOG_LEVEL))
def _exec(self):
locals_dict={}
exec(self.router, None, locals_dict)
ragchain=locals_dict['rag_pipeline']()
return ragchain

def __repr__(self):
try:
Expand Down Expand Up @@ -572,4 +573,43 @@ def byor_ragbuilder(test_ds,eval_llm,eval_embedding):
)
result=rageval.evaluate()
logger.debug(f'progress_state={progress_state.get_progress()}')

rag_manager.cache_rag(rageval.id, rag_builder.rag)

class RagManager:
    """In-memory cache of RAG pipeline objects keyed by evaluation id.

    Each entry is paired with the wall-clock time of its last access so
    that ``clear_cache`` can drop entries that have not been used recently.
    """

    def __init__(self):
        # eval_id -> RAG pipeline object
        self.cache: Dict[int, Any] = {}
        # eval_id -> time.time() of the most recent store or lookup
        self.last_accessed: Dict[int, float] = {}

    def get_rag(self, eval_id: int, db: sqlite3.Connection):
        """Return the RAG pipeline for ``eval_id``, rebuilding it if needed.

        On a cache miss the stored ``code_snippet`` for the evaluation is
        read from ``rag_eval_summary`` and executed via ``_exec`` to
        reconstruct the pipeline, which is then cached.

        Args:
            eval_id: Evaluation id to look up.
            db: Open SQLite connection; rows must support lookup by column
                name (``row['code_snippet']``).

        Raises:
            ValueError: if no row exists for ``eval_id``.
        """
        if eval_id not in self.cache:
            # Fetch configuration from database
            row = db.execute(
                "SELECT code_snippet FROM rag_eval_summary WHERE eval_id = ?",
                (eval_id,),
            ).fetchone()
            if not row:
                raise ValueError(f"No configuration found for eval_id {eval_id}")

            # Reconstruct RAG pipeline from the stored code
            self.cache[eval_id] = _exec(row['code_snippet'])

        self.last_accessed[eval_id] = time.time()
        return self.cache[eval_id]

    def cache_rag(self, eval_id: int, rag_builder: Any):
        """Store ``rag_builder`` under ``eval_id`` and mark it as just used."""
        self.cache[eval_id] = rag_builder
        self.last_accessed[eval_id] = time.time()

    def clear_cache(self, max_age: int = 86400):
        """Evict entries whose last access is more than ``max_age`` seconds ago."""
        now = time.time()
        # Iterate over a snapshot of the keys because entries are deleted
        # from the dict inside the loop.
        for key in list(self.last_accessed):
            if now - self.last_accessed[key] > max_age:
                del self.cache[key]
                del self.last_accessed[key]

# Shared module-level RagManager instance; ragbuilder.py imports it by name
# and uses it from the chat endpoints.
rag_manager = RagManager()
33 changes: 32 additions & 1 deletion src/ragbuilder/ragbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import optuna
from pathlib import Path
from urllib.parse import urlparse
from ragbuilder.executor import rag_builder, rag_builder_bayes_optimization_optuna, get_model_obj
from ragbuilder.executor import rag_builder, rag_builder_bayes_optimization_optuna, rag_manager, get_model_obj
from ragbuilder.langchain_module.loader import loader as l
from ragbuilder.langchain_module.common import setup_logging, progress_state
from ragbuilder import generate_data
Expand Down Expand Up @@ -94,6 +94,37 @@ async def startup():
# def shutdown_db():
# db.close()

# Request body for POST /create_chat_session.
class ChatSessionCreate(BaseModel):
    # Id of the evaluation whose RAG pipeline the chat session should use.
    eval_id: int

# Request body for POST /chat/{session_id}.
class ChatMessage(BaseModel):
    # Text passed to the RAG pipeline's invoke() call.
    message: str

@app.post("/create_chat_session")
async def create_chat_session(session_data: ChatSessionCreate, db: sqlite3.Connection = Depends(get_db)):
    """Prepare the RAG pipeline for an evaluation and issue a session id.

    The session id has the form ``"<eval_id>_<unix timestamp>"``.

    Raises:
        HTTPException: 500 if the pipeline cannot be loaded for any reason.
    """
    requested_eval = session_data.eval_id
    try:
        # Called for its effect only: loads the pipeline into rag_manager's
        # cache (or fails) before a session id is handed out.
        rag_manager.get_rag(requested_eval, db)
        issued_at = int(time.time())
        return {"session_id": f"{requested_eval}_{issued_at}"}
    except Exception as e:
        # NOTE(review): an unknown eval_id also ends up here and is reported
        # as a 500 -- consider whether a 404 would be more appropriate.
        logger.error(f"Error creating chat session: {e}")
        raise HTTPException(status_code=500, detail="Failed to create chat session")

@app.get("/chat/{eval_id}", response_class=HTMLResponse)
async def chat_page(request: Request, eval_id: int):
    """Render the ``chat.html`` playground page for the given evaluation id."""
    context = {"request": request, "eval_id": eval_id}
    return templates.TemplateResponse("chat.html", context)

@app.post("/chat/{session_id}")
async def chat(session_id: str, chat_message: ChatMessage, db: sqlite3.Connection = Depends(get_db)):
    """Answer one chat message with the RAG pipeline behind ``session_id``.

    Args:
        session_id: Id issued by ``/create_chat_session``, of the form
            ``"<eval_id>_<unix timestamp>"``; only the leading eval id is used.
        chat_message: Request body carrying the user's message text.
        db: SQLite connection used to rebuild the pipeline on a cache miss.

    Returns:
        ``{"answer": ...}`` taken from the ``"answer"`` key of the pipeline's
        ``invoke`` result.

    Raises:
        HTTPException: 400 if ``session_id`` does not start with an integer
            eval id; 500 if loading or invoking the pipeline fails.
    """
    # Validate the id before the catch-all below, so a malformed session id is
    # reported as a client error instead of a generic 500.
    try:
        eval_id = int(session_id.split('_', 1)[0])
    except ValueError:
        logger.error(f"Error in chat: invalid session id {session_id!r}")
        raise HTTPException(status_code=400, detail="Invalid session id") from None

    try:
        rag = rag_manager.get_rag(eval_id, db)
        response = rag.invoke(chat_message.message)
        return {"answer": response["answer"]}
    except Exception as e:
        logger.error(f"Error in chat: {e}")
        raise HTTPException(status_code=500, detail="Failed to process chat message") from e

@app.get("/", response_class=HTMLResponse)
def index(request: Request, db: sqlite3.Connection = Depends(get_db)):
cur = db.execute("""
Expand Down
Binary file added src/ragbuilder/static/ai-avatar.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added src/ragbuilder/static/user-avatar.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
122 changes: 122 additions & 0 deletions src/ragbuilder/templates/chat.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
{% extends "layouts.html" %}

{% block content %}
<div class="container mt-5" id="chat-container" data-eval-id="{{ eval_id }}">
<h4>Chat Playground for Eval ID: {{ eval_id }}</h4>
<div class="chat-messages" style="height: 400px; overflow-y: auto; border: 1px solid #ccc; padding: 10px; margin-bottom: 20px;">
<!-- Chat messages will be dynamically added here -->
</div>
<form id="chatForm" class="mt-3">
<div class="input-group mb-3">
<input type="text" id="userMessage" class="form-control" placeholder="Type your message..." required>
<button class="btn btn-primary" type="submit" id="sendMessage" disabled>Send</button>
</div>
</form>
</div>

<script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.3/dist/js/bootstrap.bundle.min.js" integrity="sha384-YvpcrYf0tY3lHB60NNkmXc5s9fDVZLESaAA55NDzOxhy9GkcIdslK1eN7N6jIeHz" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/@popperjs/core@2.11.8/dist/umd/popper.min.js" integrity="sha384-I7E8VVD/ismYTF4hNIPjVp/Zjvgyol6VFvRkX/vR+Vc4jQkC+hVqc2pM8ODewa9r" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.3/dist/js/bootstrap.min.js" integrity="sha384-0pUGZvbkm6XF6gxjEnlmuGrJXVbNuzT9qBBavbLwCsOGabYfZo0T0to5eqruptLy" crossorigin="anonymous"></script>

<script>
let currentSessionId = null;
var evalId = document.getElementById('chat-container').getAttribute('data-eval-id');

// On page load, ask the backend to prepare the RAG pipeline for this eval and
// hand back a session id. The Send button stays disabled until that succeeds.
$(document).ready(function() {
    const $chat = $('.chat-messages');
    $chat.append('<div class="text-center"><div class="spinner-border" role="status"><span class="visually-hidden">Loading...</span></div><p>Initializing chat session...</p></div>');
    $.ajax({
        url: `/create_chat_session`,
        method: 'POST',
        contentType: 'application/json',
        data: JSON.stringify({ eval_id: evalId }),
        success: function(response) {
            currentSessionId = response.session_id;
            $chat.empty();
            $('#sendMessage').prop('disabled', false);
        },
        error: function(error) {
            console.error('Error creating chat session:', error);
            // Clear the "Initializing..." spinner first; otherwise it keeps
            // spinning next to the failure message.
            $chat.empty();
            $chat.append('<div class="text-center text-danger"><p>Failed to create chat session. Please try again.</p></div>');
        }
    });
});

// Send the typed message to the backend and render both sides of the exchange.
// The user's text and the model's answer are inserted with .text(), never as
// HTML, so markup inside either one is displayed literally instead of being
// parsed and executed by the browser.
$('#chatForm').submit(function(e) {
    e.preventDefault();
    const message = $('#userMessage').val();
    if (!message.trim() || !currentSessionId) return;

    const $chat = $('.chat-messages');

    // Add user message to chat
    $chat.append(
        $('<div class="d-flex justify-content-end mb-4"></div>').append(
            $('<div class="msg_cotainer_send"></div>').text(message),
            '<div class="img_cont_send">&ensp;<img src="/static/user-avatar.png" class="rounded-circle user_img_send"></div>'
        )
    );

    $('#userMessage').val('');
    $('#sendMessage').prop('disabled', true);

    // Add AI response placeholder (static markup only, safe to append as HTML)
    $chat.append(`
        <div class="d-flex justify-content-start mb-4 ai-response-spinner">
            <div class="img_cont_msg">
                <img src="/static/ai-avatar.png" class="rounded-circle user_img_msg">&ensp;
            </div>
            <div class="msg_cotainer">
                <div class="spinner-border spinner-border-sm" role="status">
                    <span class="visually-hidden">Loading...</span>
                </div>
                Generating response...
            </div>
        </div>
    `);

    $chat.scrollTop($chat[0].scrollHeight);

    // Send message to backend
    $.ajax({
        url: `/chat/${currentSessionId}`,
        method: 'POST',
        contentType: 'application/json',
        data: JSON.stringify({ message: message }),
        success: function(response) {
            $('.ai-response-spinner').remove();
            // Add AI response to chat. String() keeps .text() in setter mode
            // even if the answer field is missing.
            $chat.append(
                $('<div class="d-flex justify-content-start mb-4"></div>').append(
                    '<div class="img_cont_msg"><img src="/static/ai-avatar.png" class="rounded-circle user_img_msg">&ensp;</div>',
                    $('<div class="msg_cotainer"></div>').text(String(response.answer))
                )
            );
            $chat.scrollTop($chat[0].scrollHeight);
        },
        error: function(error) {
            $('.ai-response-spinner').remove();
            console.error('Error sending message:', error);
            $chat.append(`<div class="text-danger">Error: Failed to get response</div>`);
        },
        complete: function() {
            $('#sendMessage').prop('disabled', false);
        }
    });
});

// Keep Send disabled while the input is blank or no chat session exists yet,
// so typing during session setup cannot enable a button that would do nothing.
$('#userMessage').on('input', function() {
    const isBlank = $(this).val().trim() === '';
    $('#sendMessage').prop('disabled', isBlank || !currentSessionId);
});

</script>
{% endblock content %}
31 changes: 31 additions & 0 deletions src/ragbuilder/templates/summary.html
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ <h5>EST. LATENCY</h5>
<th class="timestamp">Timestamp</th>
<th>Details</th>
<th>Code Snippet</th>
<th>Chat</th>
</tr>
</thead>
<tbody>
Expand All @@ -109,6 +110,7 @@ <h5>EST. LATENCY</h5>
<td>{{ eval.eval_ts }}</td>
<td><a href="/details/{{ eval.eval_id }}" class="btn btn-outline-primary btn-sm"><span style="font-size:smaller;">View Details</span></a></td>
<td><a href="#" class="btn btn-outline-success btn-sm view-code-snippet" data-code="{{ eval.code_snippet }}"><span style="font-size:smaller;">View Code snippet</span></a></td>
<td><a href="#" class="btn btn-outline-info btn-sm open-chat" data-eval-id="{{ eval.eval_id }}"><span style="font-size:smaller;">Chat</span></a></td>
</tr>
{% endfor %}
</tbody>
Expand Down Expand Up @@ -146,6 +148,29 @@ <h1 class="modal-title fs-5" id="codeSnippetModalLabel">Code Snippet</h1>
</div>
</div>

<!-- Chat Modal -->
<div class="modal fade" id="chatModal" tabindex="-1" aria-labelledby="chatModalLabel" aria-hidden="true">
<div class="modal-dialog modal-lg">
<div class="modal-content">
<div class="modal-header">
<h5 class="modal-title" id="chatModalLabel">Chat Playground</h5>
<button type="button" class="btn-close" data-bs-dismiss="modal" aria-label="Close"></button>
</div>
<div class="modal-body">
<div class="chat-messages" style="height: 300px; overflow-y: auto;">
<!-- Chat messages will be dynamically added here -->
</div>
<form id="chatForm" class="mt-3">
<div class="input-group mb-3">
<input type="text" id="userMessage" class="form-control" placeholder="Type your message..." required>
<button class="btn btn-primary" type="submit" id="sendMessage" disabled>Send</button>
</div>
</form>
</div>
</div>
</div>
</div>

{{ modal }}

<script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
Expand Down Expand Up @@ -518,6 +543,12 @@ <h1 class="modal-title fs-5" id="codeSnippetModalLabel">Code Snippet</h1>
});
});
});

// Delegated handler: open the chat playground for the clicked row's
// evaluation in a new browser tab.
$(document).on('click', '.open-chat', function(event) {
    event.preventDefault();
    const chatUrl = `/chat/${$(this).data('eval-id')}`;
    window.open(chatUrl, '_blank');
});
</script>

{% endblock content %}

0 comments on commit 1f760e9

Please sign in to comment.