Skip to content

Commit

Permalink
[router] Refactor: decouple select and send stage (#2440)
Browse files Browse the repository at this point in the history
  • Loading branch information
ByronHsu authored Dec 11, 2024
1 parent 7310aed commit d4de9a6
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 101 deletions.
137 changes: 93 additions & 44 deletions rust/src/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,28 +106,6 @@ pub enum PolicyConfig {
},
}

fn get_text_from_request(body: &Bytes, route: &str) -> String {
// convert body to json
let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();

if route == "generate" {
// get the "text" field
let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
return text.to_string();
} else if route == "v1/chat/completions" {
// get the messages field as raw text
if let Some(messages) = json.get("messages") {
// Convert messages back to a string, preserving all JSON formatting
return serde_json::to_string(messages).unwrap_or_default();
}
} else if route == "v1/completions" {
let prompt = json.get("prompt").and_then(|t| t.as_str()).unwrap_or("");
return prompt.to_string();
}

return "".to_string();
}

impl Router {
pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Result<Self, String> {
// Wait until all workers are healthy
Expand Down Expand Up @@ -204,20 +182,6 @@ impl Router {
})
}

pub fn get_first(&self) -> Option<String> {
match self {
Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls }
| Router::CacheAware { worker_urls, .. } => {
if worker_urls.read().unwrap().is_empty() {
None
} else {
Some(worker_urls.read().unwrap()[0].clone())
}
}
}
}

fn wait_for_healthy_workers(
worker_urls: &[String],
timeout_secs: u64,
Expand Down Expand Up @@ -271,14 +235,76 @@ impl Router {
}
}

pub async fn dispatch(
fn select_first_worker(&self) -> Result<String, String> {
match self {
Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls }
| Router::CacheAware { worker_urls, .. } => {
if worker_urls.read().unwrap().is_empty() {
Err("No workers are available".to_string())
} else {
Ok(worker_urls.read().unwrap()[0].clone())
}
}
}
}

async fn send_request(
&self,
client: &reqwest::Client,
req: HttpRequest,
body: Bytes,
worker_url: String,
route: &str,
) -> HttpResponse {
let text = get_text_from_request(&body, route);
match client.get(format!("{}{}", worker_url, route)).send().await {
Ok(res) => {
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Err(e) => HttpResponse::InternalServerError()
.body(format!("Failed to read response body: {}", e)),
}
}
Err(e) => HttpResponse::InternalServerError().body(format!(
"Failed to send request to worker {}: {}",
worker_url, e
)),
}
}

pub async fn route_to_first(&self, client: &reqwest::Client, route: &str) -> HttpResponse {
match self.select_first_worker() {
Ok(worker_url) => self.send_request(client, worker_url, route).await,
Err(e) => HttpResponse::InternalServerError().body(e),
}
}

fn get_text_from_request(&self, body: &Bytes, route: &str) -> String {
// convert body to json
let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();

if route == "generate" {
// get the "text" field
let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
return text.to_string();
} else if route == "v1/chat/completions" {
// get the messages field as raw text
if let Some(messages) = json.get("messages") {
// Convert messages back to a string, preserving all JSON formatting
return serde_json::to_string(messages).unwrap_or_default();
}
} else if route == "v1/completions" {
let prompt = json.get("prompt").and_then(|t| t.as_str()).unwrap_or("");
return prompt.to_string();
}

return "".to_string();
}

// TODO: return Result<String, String> instead of panicking
fn select_generate_worker(&self, body: &Bytes, route: &str) -> String {
let text = self.get_text_from_request(&body, route);

let worker_url = match self {
Router::RoundRobin {
Expand Down Expand Up @@ -366,12 +392,23 @@ impl Router {
}
};

worker_url
}

async fn send_generate_request(
&self,
client: &reqwest::Client,
req: HttpRequest,
body: Bytes,
route: &str,
worker_url: &str,
) -> HttpResponse {
let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
.unwrap_or(false);

let res = match client
.post(format!("{}/{}", worker_url.clone(), route))
.post(format!("{}{}", worker_url, route))
.header(
"Content-Type",
req.headers()
Expand Down Expand Up @@ -403,7 +440,7 @@ impl Router {
// Then decrement running queue counter if using CacheAware
if let Router::CacheAware { running_queue, .. } = self {
if let Ok(mut queue) = running_queue.lock() {
if let Some(count) = queue.get_mut(&worker_url) {
if let Some(count) = queue.get_mut(worker_url) {
*count = count.saturating_sub(1);
}
}
Expand All @@ -412,7 +449,7 @@ impl Router {
response
} else if let Router::CacheAware { running_queue, .. } = self {
let running_queue = Arc::clone(running_queue);
let worker_url = worker_url.clone();
let worker_url = worker_url.to_string();

HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
Expand All @@ -431,7 +468,7 @@ impl Router {
let mut locked_queue = running_queue.lock().unwrap();
let count = locked_queue.get_mut(&worker_url).unwrap();
*count = count.saturating_sub(1);
debug!("streaming is done!!")
debug!("Streaming is done!!")
}
}),
)
Expand All @@ -444,6 +481,18 @@ impl Router {
}
}

pub async fn route_generate_request(
&self,
client: &reqwest::Client,
req: HttpRequest,
body: Bytes,
route: &str,
) -> HttpResponse {
let worker_url = self.select_generate_worker(&body, route);
self.send_generate_request(client, req, body, route, &worker_url)
.await
}

pub async fn add_worker(&self, worker_url: String) -> Result<String, String> {
let interval_secs = 10; // check every 10 seconds
let timeout_secs = 300; // 5 minutes
Expand Down
71 changes: 14 additions & 57 deletions rust/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,84 +29,41 @@ impl AppState {
}
}

async fn forward_request(
client: &reqwest::Client,
worker_url: String,
route: String,
) -> HttpResponse {
match client.get(format!("{}{}", worker_url, route)).send().await {
Ok(res) => {
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

// print the status
println!(
"Forwarding Request Worker URL: {}, Route: {}, Status: {}",
worker_url, route, status
);
match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Err(_) => HttpResponse::InternalServerError().finish(),
}
}
Err(_) => HttpResponse::InternalServerError().finish(),
}
}

#[get("/health")]
async fn health(data: web::Data<AppState>) -> impl Responder {
let worker_url = match data.router.get_first() {
Some(url) => url,
None => return HttpResponse::InternalServerError().finish(),
};

forward_request(&data.client, worker_url, "/health".to_string()).await
data.router.route_to_first(&data.client, "/health").await
}

#[get("/health_generate")]
async fn health_generate(data: web::Data<AppState>) -> impl Responder {
let worker_url = match data.router.get_first() {
Some(url) => url,
None => return HttpResponse::InternalServerError().finish(),
};

forward_request(&data.client, worker_url, "/health_generate".to_string()).await
data.router
.route_to_first(&data.client, "/health_generate")
.await
}

#[get("/get_server_info")]
async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
let worker_url = match data.router.get_first() {
Some(url) => url,
None => return HttpResponse::InternalServerError().finish(),
};

forward_request(&data.client, worker_url, "/get_server_info".to_string()).await
data.router
.route_to_first(&data.client, "/get_server_info")
.await
}

#[get("/v1/models")]
async fn v1_models(data: web::Data<AppState>) -> impl Responder {
let worker_url = match data.router.get_first() {
Some(url) => url,
None => return HttpResponse::InternalServerError().finish(),
};

forward_request(&data.client, worker_url, "/v1/models".to_string()).await
data.router.route_to_first(&data.client, "/v1/models").await
}

#[get("/get_model_info")]
async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
let worker_url = match data.router.get_first() {
Some(url) => url,
None => return HttpResponse::InternalServerError().finish(),
};

forward_request(&data.client, worker_url, "/get_model_info".to_string()).await
data.router
.route_to_first(&data.client, "/get_model_info")
.await
}

#[post("/generate")]
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
data.router
.dispatch(&data.client, req, body, "generate")
.route_generate_request(&data.client, req, body, "/generate")
.await
}

Expand All @@ -117,7 +74,7 @@ async fn v1_chat_completions(
data: web::Data<AppState>,
) -> impl Responder {
data.router
.dispatch(&data.client, req, body, "v1/chat/completions")
.route_generate_request(&data.client, req, body, "/v1/chat/completions")
.await
}

Expand All @@ -128,7 +85,7 @@ async fn v1_completions(
data: web::Data<AppState>,
) -> impl Responder {
data.router
.dispatch(&data.client, req, body, "v1/completions")
.route_generate_request(&data.client, req, body, "/v1/completions")
.await
}

Expand Down

0 comments on commit d4de9a6

Please sign in to comment.