diff --git a/rust/src/router.rs b/rust/src/router.rs index 615a9550ef3..cbaa4673b5e 100644 --- a/rust/src/router.rs +++ b/rust/src/router.rs @@ -106,28 +106,6 @@ pub enum PolicyConfig { }, } -fn get_text_from_request(body: &Bytes, route: &str) -> String { - // convert body to json - let json = serde_json::from_slice::(body).unwrap(); - - if route == "generate" { - // get the "text" field - let text = json.get("text").and_then(|t| t.as_str()).unwrap_or(""); - return text.to_string(); - } else if route == "v1/chat/completions" { - // get the messages field as raw text - if let Some(messages) = json.get("messages") { - // Convert messages back to a string, preserving all JSON formatting - return serde_json::to_string(messages).unwrap_or_default(); - } - } else if route == "v1/completions" { - let prompt = json.get("prompt").and_then(|t| t.as_str()).unwrap_or(""); - return prompt.to_string(); - } - - return "".to_string(); -} - impl Router { pub fn new(worker_urls: Vec, policy_config: PolicyConfig) -> Result { // Wait until all workers are healthy @@ -204,20 +182,6 @@ impl Router { }) } - pub fn get_first(&self) -> Option { - match self { - Router::RoundRobin { worker_urls, .. } - | Router::Random { worker_urls } - | Router::CacheAware { worker_urls, .. } => { - if worker_urls.read().unwrap().is_empty() { - None - } else { - Some(worker_urls.read().unwrap()[0].clone()) - } - } - } - } - fn wait_for_healthy_workers( worker_urls: &[String], timeout_secs: u64, @@ -271,14 +235,76 @@ impl Router { } } - pub async fn dispatch( + fn select_first_worker(&self) -> Result { + match self { + Router::RoundRobin { worker_urls, .. } + | Router::Random { worker_urls } + | Router::CacheAware { worker_urls, .. } => { + if worker_urls.read().unwrap().is_empty() { + Err("No workers are available".to_string()) + } else { + Ok(worker_urls.read().unwrap()[0].clone()) + } + } + } + } + + async fn send_request( &self, client: &reqwest::Client, - req: HttpRequest, - body: Bytes, + worker_url: String, route: &str, ) -> HttpResponse { - let text = get_text_from_request(&body, route); + match client.get(format!("{}{}", worker_url, route)).send().await { + Ok(res) => { + let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + + match res.bytes().await { + Ok(body) => HttpResponse::build(status).body(body.to_vec()), + Err(e) => HttpResponse::InternalServerError() + .body(format!("Failed to read response body: {}", e)), + } + } + Err(e) => HttpResponse::InternalServerError().body(format!( + "Failed to send request to worker {}: {}", + worker_url, e + )), + } + } + + pub async fn route_to_first(&self, client: &reqwest::Client, route: &str) -> HttpResponse { + match self.select_first_worker() { + Ok(worker_url) => self.send_request(client, worker_url, route).await, + Err(e) => HttpResponse::InternalServerError().body(e), + } + } + + fn get_text_from_request(&self, body: &Bytes, route: &str) -> String { + // convert body to json + let json = serde_json::from_slice::(body).unwrap(); + + if route == "generate" { + // get the "text" field + let text = json.get("text").and_then(|t| t.as_str()).unwrap_or(""); + return text.to_string(); + } else if route == "v1/chat/completions" { + // get the messages field as raw text + if let Some(messages) = json.get("messages") { + // Convert messages back to a string, preserving all JSON formatting + return serde_json::to_string(messages).unwrap_or_default(); + } + } else if route == "v1/completions" { + let prompt = json.get("prompt").and_then(|t| t.as_str()).unwrap_or(""); + return prompt.to_string(); + } + + return "".to_string(); + } + + // TODO: return Result instead of panicking + fn select_generate_worker(&self, body: &Bytes, route: &str) -> String { + let text = self.get_text_from_request(&body, route); let worker_url = match self { Router::RoundRobin { @@ -366,12 +392,23 @@ impl Router { } }; + worker_url + } + + async fn send_generate_request( + &self, + client: &reqwest::Client, + req: HttpRequest, + body: Bytes, + route: &str, + worker_url: &str, + ) -> HttpResponse { let is_stream = serde_json::from_slice::(&body) .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false)) .unwrap_or(false); let res = match client - .post(format!("{}/{}", worker_url.clone(), route)) + .post(format!("{}{}", worker_url, route)) .header( "Content-Type", req.headers() @@ -403,7 +440,7 @@ impl Router { // Then decrement running queue counter if using CacheAware if let Router::CacheAware { running_queue, .. } = self { if let Ok(mut queue) = running_queue.lock() { - if let Some(count) = queue.get_mut(&worker_url) { + if let Some(count) = queue.get_mut(worker_url) { *count = count.saturating_sub(1); } } @@ -412,7 +449,7 @@ impl Router { response } else if let Router::CacheAware { running_queue, .. } = self { let running_queue = Arc::clone(running_queue); - let worker_url = worker_url.clone(); + let worker_url = worker_url.to_string(); HttpResponse::build(status) .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) @@ -431,7 +468,7 @@ impl Router { let mut locked_queue = running_queue.lock().unwrap(); let count = locked_queue.get_mut(&worker_url).unwrap(); *count = count.saturating_sub(1); - debug!("streaming is done!!") + debug!("Streaming is done!!") } }), ) @@ -444,6 +481,18 @@ impl Router { } } + pub async fn route_generate_request( + &self, + client: &reqwest::Client, + req: HttpRequest, + body: Bytes, + route: &str, + ) -> HttpResponse { + let worker_url = self.select_generate_worker(&body, route); + self.send_generate_request(client, req, body, route, &worker_url) + .await + } + pub async fn add_worker(&self, worker_url: String) -> Result { let interval_secs = 10; // check every 10 seconds let timeout_secs = 300; // 5 minutes diff --git a/rust/src/server.rs b/rust/src/server.rs index 8a0eb1547d6..4b83283976d 100644 --- a/rust/src/server.rs +++ b/rust/src/server.rs @@ -29,84 +29,41 @@ impl AppState { } } -async fn forward_request( - client: &reqwest::Client, - worker_url: String, - route: String, -) -> HttpResponse { - match client.get(format!("{}{}", worker_url, route)).send().await { - Ok(res) => { - let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); - - // print the status - println!( - "Forwarding Request Worker URL: {}, Route: {}, Status: {}", - worker_url, route, status - ); - match res.bytes().await { - Ok(body) => HttpResponse::build(status).body(body.to_vec()), - Err(_) => HttpResponse::InternalServerError().finish(), - } - } - Err(_) => HttpResponse::InternalServerError().finish(), - } -} - #[get("/health")] async fn health(data: web::Data) -> impl Responder { - let worker_url = match data.router.get_first() { - Some(url) => url, - None => return HttpResponse::InternalServerError().finish(), - }; - - forward_request(&data.client, worker_url, "/health".to_string()).await + data.router.route_to_first(&data.client, "/health").await } #[get("/health_generate")] async fn health_generate(data: web::Data) -> impl Responder { - let worker_url = match data.router.get_first() { - Some(url) => url, - None => return HttpResponse::InternalServerError().finish(), - }; - - forward_request(&data.client, worker_url, "/health_generate".to_string()).await + data.router + .route_to_first(&data.client, "/health_generate") + .await } #[get("/get_server_info")] async fn get_server_info(data: web::Data) -> impl Responder { - let worker_url = match data.router.get_first() { - Some(url) => url, - None => return HttpResponse::InternalServerError().finish(), - }; - - forward_request(&data.client, worker_url, "/get_server_info".to_string()).await + data.router + .route_to_first(&data.client, "/get_server_info") + .await } #[get("/v1/models")] async fn v1_models(data: web::Data) -> impl Responder { - let worker_url = match data.router.get_first() { - Some(url) => url, - None => return HttpResponse::InternalServerError().finish(), - }; - - forward_request(&data.client, worker_url, "/v1/models".to_string()).await + data.router.route_to_first(&data.client, "/v1/models").await } #[get("/get_model_info")] async fn get_model_info(data: web::Data) -> impl Responder { - let worker_url = match data.router.get_first() { - Some(url) => url, - None => return HttpResponse::InternalServerError().finish(), - }; - - forward_request(&data.client, worker_url, "/get_model_info".to_string()).await + data.router + .route_to_first(&data.client, "/get_model_info") + .await } #[post("/generate")] async fn generate(req: HttpRequest, body: Bytes, data: web::Data) -> impl Responder { data.router - .dispatch(&data.client, req, body, "generate") + .route_generate_request(&data.client, req, body, "/generate") .await } @@ -117,7 +74,7 @@ async fn v1_chat_completions( data: web::Data, ) -> impl Responder { data.router - .dispatch(&data.client, req, body, "v1/chat/completions") + .route_generate_request(&data.client, req, body, "/v1/chat/completions") .await } @@ -128,7 +85,7 @@ async fn v1_completions( data: web::Data, ) -> impl Responder { data.router - .dispatch(&data.client, req, body, "v1/completions") + .route_generate_request(&data.client, req, body, "/v1/completions") .await }