Skip to content

Commit

Permalink
Merge pull request #92 from firstbatchxyz/erhant/rpc-links-some-rfks
Browse files Browse the repository at this point in the history
`AvailableNodes` logic, handler refactors, workflows re-export
  • Loading branch information
erhant authored Aug 16, 2024
2 parents 4c71e9a + a27c6ad commit af4a67d
Show file tree
Hide file tree
Showing 17 changed files with 328 additions and 171 deletions.
12 changes: 6 additions & 6 deletions .env.example
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
## DRIA (required) ##
# Secret key of your compute node (32 byte, hexadecimal, without 0x prefix).
# e.g.: DKN_WALLET_SECRET_KEY=ac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80
# Secret key of your compute node, 32 byte in hexadecimal.
# e.g.: DKN_WALLET_SECRET_KEY=0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80
DKN_WALLET_SECRET_KEY=
# model1,model2,model3,... (comma separated, case-insensitive)
DKN_MODELS=phi3:3.8b
# Public key of Dria Admin node (33-byte compressed, hexadecimal, without 0x prefix).
# Public key of Dria Admin node, 33-byte (compressed) in hexadecimal.
# You don't need to change this, simply copy and paste it.
DKN_ADMIN_PUBLIC_KEY=0208ef5e65a9c656a6f92fb2c770d5d5e2ecffe02a6aade19207f75110be6ae658
# model1,model2,model3,... (comma separated, case-insensitive)
DKN_MODELS=phi3:3.8b

## DRIA (optional) ##
# info | debug | error | none,dkn_compute=debug
Expand All @@ -22,7 +22,7 @@ DKN_BOOTSTRAP_NODES=
OPENAI_API_KEY=

## Ollama (if used, optional) ##
# do not change the host, it is used by Docker
# do not change this, it is used by Docker
OLLAMA_HOST=http://host.docker.internal
# you can change the port if you would like
OLLAMA_PORT=11434
Expand Down
35 changes: 16 additions & 19 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "dkn-compute"
version = "0.1.3"
version = "0.1.4"
edition = "2021"
license = "Apache-2.0"
readme = "README.md"
Expand All @@ -12,6 +12,7 @@ parking_lot = "0.12.2"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
async-trait = "0.1.81"
reqwest = "0.12.5"

# utilities
base64 = "0.22.0"
Expand All @@ -34,8 +35,7 @@ sha3 = "0.10.8"
fastbloom-rs = "0.5.9"

# workflows
ollama-workflows = { git = "https://github.com/andthattoo/ollama-workflows", rev = "274b26e" }
ollama-rs = "0.2.0"
ollama-workflows = { git = "https://github.com/andthattoo/ollama-workflows", rev = "25467d2" }

# peer-to-peer
libp2p = { version = "0.53", features = [
Expand Down
2 changes: 1 addition & 1 deletion compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ services:
JINA_API_KEY: ${JINA_API_KEY}
OLLAMA_HOST: ${OLLAMA_HOST}
OLLAMA_PORT: ${OLLAMA_PORT}
OLLAMA_AUTO_PULL: ${OLLAMA_AUTO_PULL}
OLLAMA_AUTO_PULL: ${OLLAMA_AUTO_PULL:-true}
network_mode: "host"
extra_hosts:
# for Linux, we need to add this line manually
Expand Down
2 changes: 1 addition & 1 deletion examples/common/ollama.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::time::SystemTime;

use ollama_rs::{
use ollama_workflows::ollama_rs::{
generation::completion::{request::GenerationRequest, GenerationResponse},
Ollama,
};
Expand Down
6 changes: 3 additions & 3 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ impl DriaComputeNodeConfig {
pub fn new() -> Self {
let secret_key = match env::var("DKN_WALLET_SECRET_KEY") {
Ok(secret_env) => {
let secret_dec =
hex::decode(secret_env).expect("Secret key should be 32-bytes hex encoded.");
let secret_dec = hex::decode(secret_env.trim_start_matches("0x"))
.expect("Secret key should be 32-bytes hex encoded.");
SecretKey::parse_slice(&secret_dec).expect("Secret key should be parseable.")
}
Err(err) => {
Expand All @@ -60,7 +60,7 @@ impl DriaComputeNodeConfig {

let admin_public_key = match env::var("DKN_ADMIN_PUBLIC_KEY") {
Ok(admin_public_key) => {
let pubkey_dec = hex::decode(admin_public_key)
let pubkey_dec = hex::decode(admin_public_key.trim_start_matches("0x"))
.expect("Admin public key should be 33-bytes hex encoded.");
PublicKey::parse_slice(&pubkey_dec, None)
.expect("Admin public key should be parseable.")
Expand Down
2 changes: 1 addition & 1 deletion src/config/ollama.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use ollama_rs::Ollama;
use ollama_workflows::ollama_rs::Ollama;

const DEFAULT_OLLAMA_HOST: &str = "http://127.0.0.1";
const DEFAULT_OLLAMA_PORT: u16 = 11434;
Expand Down
11 changes: 10 additions & 1 deletion src/errors/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use ollama_rs::error::OllamaError;
use ollama_workflows::ollama_rs::error::OllamaError;

/// Alias for `Result<T, NodeError>`.
pub type NodeResult<T> = std::result::Result<T, NodeError>;
Expand Down Expand Up @@ -96,6 +96,15 @@ impl From<libp2p::gossipsub::SubscriptionError> for NodeError {
}
}

impl From<reqwest::Error> for NodeError {
fn from(value: reqwest::Error) -> Self {
Self {
message: value.to_string(),
source: "reqwest".to_string(),
}
}
}

impl From<libp2p::gossipsub::PublishError> for NodeError {
fn from(value: libp2p::gossipsub::PublishError) -> Self {
Self {
Expand Down
18 changes: 16 additions & 2 deletions src/handlers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
use async_trait::async_trait;
use libp2p::gossipsub::MessageAcceptance;

mod pingpong;
pub use pingpong::HandlesPingpong;
pub use pingpong::PingpongHandler;

mod workflow;
pub use workflow::HandlesWorkflow;
pub use workflow::WorkflowHandler;

use crate::{errors::NodeResult, p2p::P2PMessage, DriaComputeNode};

#[async_trait]
pub trait ComputeHandler {
async fn handle_compute(
node: &mut DriaComputeNode,
message: P2PMessage,
result_topic: &str,
) -> NodeResult<MessageAcceptance>;
}
28 changes: 12 additions & 16 deletions src/handlers/pingpong.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
use crate::{
errors::NodeResult, node::DriaComputeNode, p2p::P2PMessage, utils::get_current_time_nanos,
};
use async_trait::async_trait;
use libp2p::gossipsub::MessageAcceptance;
use ollama_workflows::{Model, ModelProvider};
use serde::{Deserialize, Serialize};

use super::ComputeHandler;

pub struct PingpongHandler;

#[derive(Serialize, Deserialize, Debug, Clone)]
struct PingpongPayload {
uuid: String,
Expand All @@ -18,19 +23,10 @@ struct PingpongResponse {
pub(crate) timestamp: u128,
}

/// A ping-pong is a message sent by a node to indicate that it is alive.
/// Compute nodes listen to `pong` topic, and respond to `ping` topic.
pub trait HandlesPingpong {
fn handle_heartbeat(
&mut self,
message: P2PMessage,
result_topic: &str,
) -> NodeResult<MessageAcceptance>;
}

impl HandlesPingpong for DriaComputeNode {
fn handle_heartbeat(
&mut self,
#[async_trait]
impl ComputeHandler for PingpongHandler {
async fn handle_compute(
node: &mut DriaComputeNode,
message: P2PMessage,
result_topic: &str,
) -> NodeResult<MessageAcceptance> {
Expand All @@ -53,15 +49,15 @@ impl HandlesPingpong for DriaComputeNode {
// respond
let response_body = PingpongResponse {
uuid: pingpong.uuid.clone(),
models: self.config.model_config.models.clone(),
models: node.config.model_config.models.clone(),
timestamp: get_current_time_nanos(),
};
let response = P2PMessage::new_signed(
serde_json::json!(response_body).to_string(),
result_topic,
&self.config.secret_key,
&node.config.secret_key,
);
self.publish(response)?;
node.publish(response)?;

// accept message, someone else may be included in the filter
Ok(MessageAcceptance::Accept)
Expand Down
31 changes: 13 additions & 18 deletions src/handlers/workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ use crate::p2p::P2PMessage;
use crate::utils::get_current_time_nanos;
use crate::utils::payload::{TaskRequest, TaskRequestPayload};

use super::ComputeHandler;

pub struct WorkflowHandler;

#[derive(Debug, Deserialize)]
struct WorkflowPayload {
/// Workflow object to be parsed.
Expand All @@ -23,18 +27,9 @@ struct WorkflowPayload {
}

#[async_trait]
pub trait HandlesWorkflow {
async fn handle_workflow(
&mut self,
message: P2PMessage,
result_topic: &str,
) -> NodeResult<MessageAcceptance>;
}

#[async_trait]
impl HandlesWorkflow for DriaComputeNode {
async fn handle_workflow(
&mut self,
impl ComputeHandler for WorkflowHandler {
async fn handle_compute(
node: &mut DriaComputeNode,
message: P2PMessage,
result_topic: &str,
) -> NodeResult<MessageAcceptance> {
Expand All @@ -55,7 +50,7 @@ impl HandlesWorkflow for DriaComputeNode {
}

// check task inclusion via the bloom filter
if !task.filter.contains(&self.config.address)? {
if !task.filter.contains(&node.config.address)? {
log::info!(
"Task {} does not include this node within the filter.",
task.task_id
Expand All @@ -75,7 +70,7 @@ impl HandlesWorkflow for DriaComputeNode {
};

// read model / provider from the task
let (model_provider, model) = self
let (model_provider, model) = node
.config
.model_config
.get_any_matching_model(task.input.model)?;
Expand All @@ -85,8 +80,8 @@ impl HandlesWorkflow for DriaComputeNode {
let executor = if model_provider == ModelProvider::Ollama {
Executor::new_at(
model,
&self.config.ollama_config.host,
self.config.ollama_config.port,
&node.config.ollama_config.host,
node.config.ollama_config.port,
)
} else {
Executor::new(model)
Expand All @@ -98,7 +93,7 @@ impl HandlesWorkflow for DriaComputeNode {
.map(|prompt| Entry::try_value_or_str(&prompt));
let result: Option<String>;
tokio::select! {
_ = self.cancellation.cancelled() => {
_ = node.cancellation.cancelled() => {
log::info!("Received cancellation, quitting all tasks.");
return Ok(MessageAcceptance::Accept)
},
Expand All @@ -113,7 +108,7 @@ impl HandlesWorkflow for DriaComputeNode {
let result = result.ok_or::<String>(format!("No result for task {}", task.task_id))?;

// publish the result
self.send_result(result_topic, &task.public_key, &task.task_id, result)?;
node.send_result(result_topic, &task.public_key, &task.task_id, result)?;

// accept message, someone else may be included in the filter
Ok(MessageAcceptance::Accept)
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
config.check_services().await?;

// launch the node
let mut node = DriaComputeNode::new(config, CancellationToken::new())?;
let mut node = DriaComputeNode::new(config, CancellationToken::new()).await?;
node.launch().await?;

Ok(())
Expand Down
Loading

0 comments on commit af4a67d

Please sign in to comment.