diff --git a/src/global_context.rs b/src/global_context.rs index c872da5..abb8539 100644 --- a/src/global_context.rs +++ b/src/global_context.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use std::path::PathBuf; use std::sync::{Arc, Mutex}; use std::sync::RwLock as StdRwLock; +use tokio::sync::Mutex as AMutex; use tokio::sync::RwLock as ARwLock; use tokenizers::Tokenizer; use structopt::StructOpt; @@ -34,7 +35,7 @@ pub struct CommandLine { } -#[derive(Debug)] +// #[derive(Debug)] pub struct GlobalContext { pub http_client: reqwest::Client, pub ask_shutdown_sender: Arc>>, @@ -45,7 +46,7 @@ pub struct GlobalContext { pub cmdline: CommandLine, pub completions_cache: Arc>, pub telemetry: Arc>, - pub vecdb_search: Arc>, + pub vecdb_search: Arc>>, } @@ -124,6 +125,7 @@ pub async fn create_global_context( cmdline: cmdline.clone(), completions_cache: Arc::new(StdRwLock::new(CompletionCache::new())), telemetry: Arc::new(StdRwLock::new(telemetry_storage::Storage::new())), + vecdb_search: Arc::new(AMutex::new(Box::new(crate::vecdb_search::VecdbSearchTest::new()))), }; (Arc::new(ARwLock::new(cx)), ask_shutdown_receiver, cmdline) } diff --git a/src/http_server.rs b/src/http_server.rs index 20bc77e..2f950dd 100644 --- a/src/http_server.rs +++ b/src/http_server.rs @@ -21,6 +21,7 @@ use crate::custom_error::ScratchError; use crate::telemetry_basic; use crate::telemetry_snippets; use crate::completion_cache; +// use crate::vecdb_search::VecdbSearch; async fn _get_caps_and_tokenizer( @@ -155,7 +156,7 @@ async fn handle_v1_code_completion( let prompt = scratchpad.prompt( 2048, &mut code_completion_post.parameters, - ).map_err(|e| + ).await.map_err(|e| ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("Prompt: {}", e)) )?; // info!("prompt {:?}\n{}", t1.elapsed(), prompt); @@ -193,7 +194,7 @@ async fn handle_v1_chat( ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR,format!("Tokenizer: {}", e)) )?; - let vecdb_search = ; + let vecdb_search = global_context.read().await.vecdb_search.clone(); let mut scratchpad = scratchpads::create_chat_scratchpad( chat_post.clone(), &scratchpad_name, @@ -207,7 +208,7 @@ async fn handle_v1_chat( let prompt = scratchpad.prompt( 2048, &mut chat_post.parameters, - ).map_err(|e| + ).await.map_err(|e| ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("Prompt: {}", e)) )?; // info!("chat prompt {:?}\n{}", t1.elapsed(), prompt); diff --git a/src/lsp.rs b/src/lsp.rs index 96cc0d1..f5acf11 100644 --- a/src/lsp.rs +++ b/src/lsp.rs @@ -61,7 +61,7 @@ impl Document { } } -#[derive(Debug)] +// #[derive(Debug)] GlobalContext does not implement Debug pub struct Backend { pub gcx: Arc>, pub client: tower_lsp::Client, diff --git a/src/main.rs b/src/main.rs index 28a5256..fd261cc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,15 +20,6 @@ mod telemetry_snippets; mod telemetry_storage; mod vecdb_search; mod lsp; -use crate::vecdb_search::VecdbSearch; - - -async fn test_vecdb() -{ - let mut v = vecdb_search::VecdbSearchTest::new(); - let res = v.search("ParallelTasksV3").await; - info!("{:?}", res); -} #[tokio::main] @@ -54,8 +45,6 @@ async fn main() { .init(); info!("started"); info!("cache dir: {}", cache_dir.display()); - test_vecdb().await; - return; let gcx2 = gcx.clone(); let gcx3 = gcx.clone(); diff --git a/src/scratchpad_abstract.rs b/src/scratchpad_abstract.rs index 227cfc8..563ee60 100644 --- a/src/scratchpad_abstract.rs +++ b/src/scratchpad_abstract.rs @@ -3,15 +3,17 @@ use std::sync::Arc; use std::sync::RwLock; use tokenizers::Tokenizer; use crate::call_validation::SamplingParameters; +use async_trait::async_trait; +#[async_trait] pub trait ScratchpadAbstract: Send { fn apply_model_adaptation_patch( &mut self, patch: &serde_json::Value, ) -> Result<(), String>; - fn prompt( + async fn prompt( &mut self, context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, diff --git a/src/scratchpads/chat_generic.rs b/src/scratchpads/chat_generic.rs index 1bb39c7..335b82e 100644 --- a/src/scratchpads/chat_generic.rs +++ b/src/scratchpads/chat_generic.rs @@ -5,8 +5,12 @@ use crate::call_validation::ChatPost; use crate::call_validation::ChatMessage; use crate::call_validation::SamplingParameters; use crate::scratchpads::chat_utils_limit_history::limit_messages_history; +use crate::vecdb_search::{VecdbSearch, embed_vecdb_results}; + use std::sync::Arc; use std::sync::RwLock; +use async_trait::async_trait; +use tokio::sync::Mutex as AMutex; use tokenizers::Tokenizer; use tracing::info; @@ -14,7 +18,6 @@ use tracing::info; const DEBUG: bool = true; -#[derive(Debug)] pub struct GenericChatScratchpad { pub t: HasTokenizerAndEot, pub dd: DeltaDeltaChatStreamer, @@ -24,12 +27,14 @@ pub struct GenericChatScratchpad { pub keyword_user: String, pub keyword_asst: String, pub default_system_message: String, + pub vecdb_search: Arc>>, } impl GenericChatScratchpad { pub fn new( tokenizer: Arc>, post: ChatPost, + vecdb_search: Arc>>, ) -> Self { GenericChatScratchpad { t: HasTokenizerAndEot::new(tokenizer), @@ -40,10 +45,12 @@ impl GenericChatScratchpad { keyword_user: "".to_string(), keyword_asst: "".to_string(), default_system_message: "".to_string(), + vecdb_search } } } +#[async_trait] impl ScratchpadAbstract for GenericChatScratchpad { fn apply_model_adaptation_patch( &mut self, @@ -68,11 +75,12 @@ impl ScratchpadAbstract for GenericChatScratchpad { Ok(()) } - fn prompt( + async fn prompt( &mut self, context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { + embed_vecdb_results(self.vecdb_search.clone(), &mut self.post, 3).await; let limited_msgs: Vec = limit_messages_history(&self.t, &self.post, context_size, &self.default_system_message)?; sampling_parameters_to_patch.stop = Some(self.dd.stop_list.clone()); // adapted from https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/model.py#L24 diff --git a/src/scratchpads/chat_llama2.rs b/src/scratchpads/chat_llama2.rs index 9aa95fe..dfaaeb7 100644 --- a/src/scratchpads/chat_llama2.rs +++ b/src/scratchpads/chat_llama2.rs @@ -1,3 +1,10 @@ +use tracing::info; +use std::sync::Arc; +use std::sync::RwLock as StdRwLock; +use tokio::sync::Mutex as AMutex; +use tokenizers::Tokenizer; +use async_trait::async_trait; + use crate::scratchpad_abstract::ScratchpadAbstract; use crate::scratchpad_abstract::HasTokenizerAndEot; use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer; @@ -5,13 +12,8 @@ use crate::call_validation::ChatPost; use crate::call_validation::ChatMessage; use crate::call_validation::SamplingParameters; use crate::scratchpads::chat_utils_limit_history::limit_messages_history; -use crate::vecdb_search; -use std::sync::Arc; -use std::sync::RwLock as StdRwLock; -use std::sync::Mutex; +use crate::vecdb_search::{VecdbSearch, embed_vecdb_results}; -use tokenizers::Tokenizer; -use tracing::info; const DEBUG: bool = true; @@ -24,14 +26,15 @@ pub struct ChatLlama2 { pub keyword_s: String, // "SYSTEM:" keyword means it's not one token pub keyword_slash_s: String, pub default_system_message: String, - pub vecdb_search: Arc>>, + pub vecdb_search: Arc>>, } + impl ChatLlama2 { pub fn new( tokenizer: Arc>, post: ChatPost, - vecdb_search: Arc>>, + vecdb_search: Arc>>, ) -> Self { ChatLlama2 { t: HasTokenizerAndEot::new(tokenizer), @@ -40,11 +43,12 @@ impl ChatLlama2 { keyword_s: "".to_string(), keyword_slash_s: "".to_string(), default_system_message: "".to_string(), - vecdb_search: vecdb_search + vecdb_search } } } +#[async_trait] impl ScratchpadAbstract for ChatLlama2 { fn apply_model_adaptation_patch( &mut self, @@ -62,11 +66,12 @@ impl ScratchpadAbstract for ChatLlama2 { Ok(()) } - fn prompt( + async fn prompt( &mut self, context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { + embed_vecdb_results(self.vecdb_search.clone(), &mut self.post, 3).await; let limited_msgs: Vec = limit_messages_history(&self.t, &self.post, context_size, &self.default_system_message)?; sampling_parameters_to_patch.stop = Some(self.dd.stop_list.clone()); // loosely adapted from https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/model.py#L24 @@ -101,6 +106,7 @@ impl ScratchpadAbstract for ChatLlama2 { // This only supports assistant, not suggestions for user self.dd.role = "assistant".to_string(); if DEBUG { + // info!("llama2 chat vdb_suggestion {:?}", vdb_suggestion); info!("llama2 chat prompt\n{}", prompt); info!("llama2 chat re-encode whole prompt again gives {} tokes", self.t.count_tokens(prompt.as_str())?); } diff --git a/src/scratchpads/completion_single_file_fim.rs b/src/scratchpads/completion_single_file_fim.rs index 5536c07..0e63585 100644 --- a/src/scratchpads/completion_single_file_fim.rs +++ b/src/scratchpads/completion_single_file_fim.rs @@ -8,6 +8,8 @@ use std::sync::RwLock as StdRwLock; use tokenizers::Tokenizer; use ropey::Rope; use tracing::info; +use async_trait::async_trait; + use crate::completion_cache; use crate::telemetry_storage; use crate::telemetry_snippets; @@ -42,6 +44,7 @@ impl SingleFileFIM { } +#[async_trait] impl ScratchpadAbstract for SingleFileFIM { fn apply_model_adaptation_patch( &mut self, @@ -59,7 +62,7 @@ impl ScratchpadAbstract for SingleFileFIM { Ok(()) } - fn prompt( + async fn prompt( &mut self, context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, diff --git a/src/scratchpads/mod.rs b/src/scratchpads/mod.rs index f9076ab..ad7e670 100644 --- a/src/scratchpads/mod.rs +++ b/src/scratchpads/mod.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use std::sync::RwLock as StdRwLock; -use std::sync::Mutex; +use tokio::sync::Mutex as AMutex; + use tokenizers::Tokenizer; pub mod completion_single_file_fim; @@ -46,11 +47,11 @@ pub fn create_chat_scratchpad( scratchpad_name: &str, scratchpad_patch: &serde_json::Value, tokenizer_arc: Arc>, - vecdb_search: Arc>>, + vecdb_search: Arc>>, ) -> Result, String> { let mut result: Box; if scratchpad_name == "CHAT-GENERIC" { - result = Box::new(chat_generic::GenericChatScratchpad::new(tokenizer_arc, post)); + result = Box::new(chat_generic::GenericChatScratchpad::new(tokenizer_arc, post, vecdb_search)); } else if scratchpad_name == "CHAT-LLAMA2" { result = Box::new(chat_llama2::ChatLlama2::new(tokenizer_arc, post, vecdb_search)); } else { diff --git a/src/vecdb_search.rs b/src/vecdb_search.rs index 8c2b0c8..e13a067 100644 --- a/src/vecdb_search.rs +++ b/src/vecdb_search.rs @@ -1,10 +1,13 @@ +use crate::call_validation::{ChatMessage, ChatPost}; // use reqwest::header::AUTHORIZATION; use reqwest::header::CONTENT_TYPE; use reqwest::header::HeaderMap; use reqwest::header::HeaderValue; use serde::{Deserialize, Serialize}; use serde_json::json; -use tracing::info; + +use std::sync::Arc; +use tokio::sync::Mutex as AMutex; use async_trait::async_trait; @@ -20,6 +23,57 @@ pub struct VecdbResult { pub results: Vec, } +pub async fn embed_vecdb_results( + vecdb_search: Arc>>, + post: &mut ChatPost, + limit_examples_cnt: usize, +) { + let my_vdb = vecdb_search.clone(); + let latest_msg_cont = &post.messages.last().unwrap().content; + let mut vecdb_locked = my_vdb.lock().await; + let vdb_resp = vecdb_locked.search(&latest_msg_cont).await; + let vdb_cont = vecdb_resp_to_prompt(&vdb_resp, limit_examples_cnt); + if vdb_cont.len() > 0 { + post.messages = [ + &post.messages[..post.messages.len() -1], + &[ChatMessage { + role: "user".to_string(), + content: vdb_cont, + }], + &post.messages[post.messages.len() -1..], + ].concat(); + } +} + + +fn vecdb_resp_to_prompt( + resp: &Result, + limit_examples_cnt: usize, +) -> String { + let mut cont = "".to_string(); + match resp { + Ok(resp) => { + cont.push_str("CONTEXT:\n"); + for i in 0..limit_examples_cnt { + if i >= resp.results.len() { + break; + } + cont.push_str("FILENAME:\n"); + cont.push_str(resp.results[i].file_name.clone().as_str()); + cont.push_str("\nTEXT:"); + cont.push_str(resp.results[i].text.clone().as_str()); + cont.push_str("\n"); + } + cont.push_str("\nRefer to the context to answer my next question.\n"); + cont + } + Err(e) => { + format!("Vecdb error: {}", e); + cont + } + } +} + #[async_trait] pub trait VecdbSearch: Send { async fn search( @@ -39,6 +93,8 @@ impl VecdbSearchTest { } } +// unsafe impl Send for VecdbSearchTest {} + #[async_trait] impl VecdbSearch for VecdbSearchTest { async fn search( @@ -51,7 +107,7 @@ impl VecdbSearch for VecdbSearchTest { headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap()); let body = json!({ "texts": [query], - "account": "smc", + "account": "XXX", "top_k": 3, }); let res = reqwest::Client::new() @@ -60,17 +116,6 @@ impl VecdbSearch for VecdbSearchTest { .body(body.to_string()) .send() .await.map_err(|e| format!("Vecdb search HTTP error (1): {}", e))?; - // print Allow header - // println!("{:?}", res.headers().get("allow")); - - // let x = VecdbResult { - // results: vec![VecdbResultRec { - // file_name: "test.txt".to_string(), - // text: "test".to_string(), - // score: "0.0".to_string(), - // }], - // }; - // info!("example: {:?}", serde_json::to_string(&x).unwrap()); let body = res.text().await.map_err(|e| format!("Vecdb search HTTP error (2): {}", e))?; // info!("Vecdb search result: {:?}", &body); @@ -85,4 +130,3 @@ impl VecdbSearch for VecdbSearchTest { Ok(result0) } } -