Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allocate stack with best price #30

Merged
merged 3 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 12 additions & 22 deletions atoma-proxy/src/server/chat_completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,6 @@ use super::streamer::Streamer;
/// and is used to process chat-based requests for AI model inference.
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";

// The default value when creating a new stack entry.
// TODO: Make this configurable or compute from the available stacks subscriptions.
const STACK_ENTRY_COMPUTE_UNITS: u64 = 1000;

// The default price for a new stack entry.
// TODO: Make this configurable or compute from the available stacks subscriptions.
const STACK_ENTRY_PRICE: u64 = 100;

/// The interval for the keep-alive message in the SSE stream.
const STREAM_KEEP_ALIVE_INTERVAL_IN_SECONDS: u64 = 15;

Expand Down Expand Up @@ -619,15 +611,15 @@ async fn get_selected_node(
if stacks.is_empty() {
let (result_sender, result_receiver) = oneshot::channel();
state_manager_sender
.send(AtomaAtomaStateManagerEvent::GetTasksForModel {
.send(AtomaAtomaStateManagerEvent::GetCheapestNodeForModel {
model: model.to_string(),
result_sender,
})
.map_err(|err| {
error!("Failed to send GetTasksForModel event: {:?}", err);
StatusCode::INTERNAL_SERVER_ERROR
})?;
let tasks = result_receiver
let node = result_receiver
.await
.map_err(|err| {
error!("Failed to receive GetTasksForModel result: {:?}", err);
Expand All @@ -637,22 +629,20 @@ async fn get_selected_node(
error!("Failed to get GetTasksForModel result: {:?}", err);
StatusCode::INTERNAL_SERVER_ERROR
})?;
if tasks.is_empty() {
error!("No tasks found for model {}", model);
return Err(StatusCode::NOT_FOUND);
}
// TODO: What should be the default values for the stack entry/price?
if total_tokens > STACK_ENTRY_COMPUTE_UNITS {
error!("Total tokens exceeds maximum limit of {STACK_ENTRY_COMPUTE_UNITS}");
return Err(StatusCode::BAD_REQUEST);
}
let node: atoma_state::types::CheapestNode = match node {
Some(node) => node,
None => {
error!("No tasks found for model {}", model);
return Err(StatusCode::NOT_FOUND);
}
};
let event = sui
.write()
.await
.acquire_new_stack_entry(
tasks[0].task_small_id as u64,
STACK_ENTRY_COMPUTE_UNITS,
STACK_ENTRY_PRICE,
node.task_small_id as u64,
node.max_num_compute_units as u64,
node.price_per_compute_unit as u64,
)
.await
.map_err(|err| {
Expand Down
18 changes: 18 additions & 0 deletions atoma-state/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,24 @@ pub(crate) async fn handle_state_manager_event(
.send(tasks)
.map_err(|_| AtomaStateManagerError::ChannelSendError)?;
}
AtomaAtomaStateManagerEvent::GetCheapestNodeForModel {
model,
result_sender,
} => {
trace!(
target = "atoma-state-handlers",
event = "handle-state-manager-event",
"Getting cheapest node for model: {}",
model
);
let node = state_manager
.state
.get_cheapest_node_for_model(&model)
.await;
result_sender
.send(node)
.map_err(|_| AtomaStateManagerError::ChannelSendError)?;
}
AtomaAtomaStateManagerEvent::UpsertNodePublicAddress {
node_small_id,
public_address,
Expand Down
46 changes: 45 additions & 1 deletion atoma-state/src/state_manager.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::build_query_with_in;
use crate::handlers::{handle_atoma_event, handle_state_manager_event};
use crate::types::{
AtomaAtomaStateManagerEvent, NodeSubscription, Stack, StackAttestationDispute,
AtomaAtomaStateManagerEvent, CheapestNode, NodeSubscription, Stack, StackAttestationDispute,
StackSettlementTicket, Task,
};

Expand Down Expand Up @@ -284,6 +284,50 @@ impl AtomaState {
.collect()
}

/// Get node settings for model with the cheapest price (based on the current node subscription).
///
/// This method fetches the task from the database that is associated with
/// the given model through the `tasks` table and has the cheapest price per compute unit.
/// The price is determined based on the node subscription for the task.
///
/// # Arguments
///
/// * `model` - The model name for the task.
///
/// # Returns
///
/// - `Result<Option<CheapestNode>>>`: A result containing either:
/// - `Ok(Some<CheapestNode>)`: A `CheapestNode` object representing the node setting with the cheapest price.
/// - `Ok(None)`: If no task is found for the given model.
/// - `Err(AtomaStateManagerError)`: An error if the database query fails or if there's an issue parsing the results.
///
/// # Errors
///
/// This function will return an error if the database query fails.
#[instrument(level = "trace", skip_all, fields(%model))]
pub async fn get_cheapest_node_for_model(&self, model: &str) -> Result<Option<CheapestNode>> {
let node_settings = sqlx::query(
"SELECT tasks.task_small_id, price_per_compute_unit, max_num_compute_units
FROM (SELECT *
FROM tasks
WHERE is_deprecated=false
AND model_name = $1) AS tasks
JOIN (SELECT *
FROM node_subscriptions
WHERE valid = true) AS node_subscriptions
ON tasks.task_small_id=node_subscriptions.task_small_id
ORDER BY node_subscriptions.price_per_compute_unit
LIMIT 1",
)
.bind(model)
Cifko marked this conversation as resolved.
Show resolved Hide resolved
.bind(false)
.fetch_optional(&self.db)
.await?;
Ok(node_settings
.map(|node_settings| CheapestNode::from_row(&node_settings))
.transpose()?)
}

/// Retrieves all tasks from the database.
///
/// This method fetches all task records from the `tasks` table in the database.
Expand Down
15 changes: 15 additions & 0 deletions atoma-state/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ impl From<TaskRegisteredEvent> for Task {
}
}

/// Represents the cheapest node settings for a specific model.
///
/// All fields are `i64` and the struct derives `FromRow`, so it can be built from a
/// database row exposing columns with the same names.
#[derive(Clone, Debug, PartialEq, Eq, FromRow)]
pub struct CheapestNode {
    /// Unique small integer identifier for the task
    pub task_small_id: i64,
    /// Price per compute unit for the task that is offered by some node
    pub price_per_compute_unit: i64,
    /// Maximum number of compute units for the task that is offered by the cheapest node
    pub max_num_compute_units: i64,
}

/// Represents a stack of compute units for a specific task
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, FromRow)]
pub struct Stack {
Expand Down Expand Up @@ -232,6 +243,10 @@ pub enum AtomaAtomaStateManagerEvent {
model: String,
result_sender: oneshot::Sender<Result<Vec<Task>>>,
},
GetCheapestNodeForModel {
model: String,
result_sender: oneshot::Sender<Result<Option<CheapestNode>>>,
},
UpsertNodePublicAddress {
node_small_id: i64,
public_address: String,
Expand Down