Skip to content

Commit

Permalink
feat: allocate stack with best price (#30)
Browse files Browse the repository at this point in the history
* feat: allocate stack with best price

* rename cheapest task to node

* rename task to node
  • Loading branch information
Cifko authored Nov 19, 2024
1 parent a82f82a commit be02541
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 23 deletions.
34 changes: 12 additions & 22 deletions atoma-proxy/src/server/chat_completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,6 @@ use super::streamer::Streamer;
/// and is used to process chat-based requests for AI model inference.
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";

// The default value when creating a new stack entry.
// TODO: Make this configurable or compute from the available stacks subscriptions.
const STACK_ENTRY_COMPUTE_UNITS: u64 = 1000;

// The default price for a new stack entry.
// TODO: Make this configurable or compute from the available stacks subscriptions.
const STACK_ENTRY_PRICE: u64 = 100;

/// The interval for the keep-alive message in the SSE stream.
const STREAM_KEEP_ALIVE_INTERVAL_IN_SECONDS: u64 = 15;

Expand Down Expand Up @@ -665,15 +657,15 @@ async fn get_selected_node(
if stacks.is_empty() {
let (result_sender, result_receiver) = oneshot::channel();
state_manager_sender
.send(AtomaAtomaStateManagerEvent::GetTasksForModel {
.send(AtomaAtomaStateManagerEvent::GetCheapestNodeForModel {
model: model.to_string(),
result_sender,
})
.map_err(|err| {
error!("Failed to send GetTasksForModel event: {:?}", err);
StatusCode::INTERNAL_SERVER_ERROR
})?;
let tasks = result_receiver
let node = result_receiver
.await
.map_err(|err| {
error!("Failed to receive GetTasksForModel result: {:?}", err);
Expand All @@ -683,22 +675,20 @@ async fn get_selected_node(
error!("Failed to get GetTasksForModel result: {:?}", err);
StatusCode::INTERNAL_SERVER_ERROR
})?;
if tasks.is_empty() {
error!("No tasks found for model {}", model);
return Err(StatusCode::NOT_FOUND);
}
// TODO: What should be the default values for the stack entry/price?
if total_tokens > STACK_ENTRY_COMPUTE_UNITS {
error!("Total tokens exceeds maximum limit of {STACK_ENTRY_COMPUTE_UNITS}");
return Err(StatusCode::BAD_REQUEST);
}
let node: atoma_state::types::CheapestNode = match node {
Some(node) => node,
None => {
error!("No tasks found for model {}", model);
return Err(StatusCode::NOT_FOUND);
}
};
let event = sui
.write()
.await
.acquire_new_stack_entry(
tasks[0].task_small_id as u64,
STACK_ENTRY_COMPUTE_UNITS,
STACK_ENTRY_PRICE,
node.task_small_id as u64,
node.max_num_compute_units as u64,
node.price_per_compute_unit as u64,
)
.await
.map_err(|err| {
Expand Down
18 changes: 18 additions & 0 deletions atoma-state/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,24 @@ pub(crate) async fn handle_state_manager_event(
.send(tasks)
.map_err(|_| AtomaStateManagerError::ChannelSendError)?;
}
// Looks up the cheapest node offer for `model` and hands the lookup result
// back to the requester through `result_sender`.
AtomaAtomaStateManagerEvent::GetCheapestNodeForModel {
model,
result_sender,
} => {
trace!(
target = "atoma-state-handlers",
event = "handle-state-manager-event",
"Getting cheapest node for model: {}",
model
);
// The query result is deliberately not unwrapped with `?` here: it is
// forwarded as-is, so a lookup error is reported to the requester
// instead of aborting this handler.
let node = state_manager
.state
.get_cheapest_node_for_model(&model)
.await;
result_sender
.send(node)
// A failed send is mapped to `ChannelSendError`; presumably this only
// happens when the receiving side was already dropped (confirm).
.map_err(|_| AtomaStateManagerError::ChannelSendError)?;
}
AtomaAtomaStateManagerEvent::UpsertNodePublicAddress {
node_small_id,
public_address,
Expand Down
46 changes: 45 additions & 1 deletion atoma-state/src/state_manager.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::build_query_with_in;
use crate::handlers::{handle_atoma_event, handle_state_manager_event};
use crate::types::{
AtomaAtomaStateManagerEvent, NodeSubscription, Stack, StackAttestationDispute,
AtomaAtomaStateManagerEvent, CheapestNode, NodeSubscription, Stack, StackAttestationDispute,
StackSettlementTicket, Task,
};

Expand Down Expand Up @@ -284,6 +284,50 @@ impl AtomaState {
.collect()
}

/// Gets the node settings for a model with the cheapest price (based on the current node
/// subscriptions).
///
/// The query joins non-deprecated rows of `tasks` whose `model_name` matches `model` with
/// valid rows of `node_subscriptions` on `task_small_id`, orders the result by
/// `price_per_compute_unit` and keeps only the first row.
///
/// # Arguments
///
/// * `model` - The model name for the task.
///
/// # Returns
///
/// - `Result<Option<CheapestNode>>`: A result containing either:
///   - `Ok(Some(CheapestNode))`: The task id, price per compute unit and maximum number of
///     compute units of the cheapest matching subscription.
///   - `Ok(None)`: If no valid subscription exists for a non-deprecated task of the given model.
///   - `Err(AtomaStateManagerError)`: If the query fails or the row cannot be converted.
///
/// # Errors
///
/// This function will return an error if the database query fails or if the returned row
/// cannot be converted into a `CheapestNode`.
#[instrument(level = "trace", skip_all, fields(%model))]
pub async fn get_cheapest_node_for_model(&self, model: &str) -> Result<Option<CheapestNode>> {
    let node_settings = sqlx::query(
        "SELECT tasks.task_small_id, price_per_compute_unit, max_num_compute_units
        FROM (SELECT *
              FROM tasks
              WHERE is_deprecated=false
                AND model_name = $1) AS tasks
        JOIN (SELECT *
              FROM node_subscriptions
              WHERE valid = true) AS node_subscriptions
        ON tasks.task_small_id=node_subscriptions.task_small_id
        ORDER BY node_subscriptions.price_per_compute_unit
        LIMIT 1",
    )
    // `$1` is the only placeholder in the statement, so exactly one value is bound.
    // Binding more values than the statement declares can be rejected by the database.
    .bind(model)
    .fetch_optional(&self.db)
    .await?;
    Ok(node_settings
        .map(|node_settings| CheapestNode::from_row(&node_settings))
        .transpose()?)
}

/// Retrieves all tasks from the database.
///
/// This method fetches all task records from the `tasks` table in the database.
Expand Down
15 changes: 15 additions & 0 deletions atoma-state/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ impl From<TaskRegisteredEvent> for Task {
}
}

/// Represents the cheapest node settings for a specific model.
///
/// Instances are built from the row returned by the cheapest-node query, which selects
/// exactly the three columns mirrored by the fields below.
#[derive(Clone, Debug, PartialEq, Eq, FromRow)]
pub struct CheapestNode {
    /// Unique small integer identifier for the task
    pub task_small_id: i64,
    /// Price per compute unit for the task that is offered by some node
    pub price_per_compute_unit: i64,
    /// Maximum number of compute units for the task that is offered by the cheapest node
    pub max_num_compute_units: i64,
}

/// Represents a stack of compute units for a specific task
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, FromRow)]
pub struct Stack {
Expand Down Expand Up @@ -232,6 +243,10 @@ pub enum AtomaAtomaStateManagerEvent {
model: String,
result_sender: oneshot::Sender<Result<Vec<Task>>>,
},
GetCheapestNodeForModel {
model: String,
result_sender: oneshot::Sender<Result<Option<CheapestNode>>>,
},
UpsertNodePublicAddress {
node_small_id: i64,
public_address: String,
Expand Down

0 comments on commit be02541

Please sign in to comment.