Skip to content

Commit

Permalink
pass the wallet ID
Browse files Browse the repository at this point in the history
  • Loading branch information
sanderegg committed Oct 26, 2023
1 parent ebed946 commit 2d6e0fb
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
OnDemandCluster,
)
from models_library.users import UserID
from models_library.wallets import WalletID
from servicelib.rabbitmq import (
RabbitMQRPCClient,
RemoteMethodNotRegisteredError,
Expand All @@ -24,15 +25,18 @@


async def get_or_create_on_demand_cluster(
user_id: UserID, rabbitmq_rpc_client: RabbitMQRPCClient
rabbitmq_rpc_client: RabbitMQRPCClient,
*,
user_id: UserID,
wallet_id: WalletID | None,
) -> BaseCluster:
try:
returned_cluster: OnDemandCluster = await rabbitmq_rpc_client.request(
RPCNamespace("clusters-keeper"),
RPCMethodName("get_or_create_cluster"),
timeout_s=300,
user_id=user_id,
wallet_id=None, # NOTE: --> MD this will need to be replaced by the real walletID
wallet_id=wallet_id,
)
_logger.info("received cluster: %s", returned_cluster)
if returned_cluster.state is not ClusterState.RUNNING:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ async def _cluster_dask_client(
cluster: BaseCluster = scheduler.settings.default_cluster
if pipeline_params.use_on_demand_clusters:
cluster = await get_or_create_on_demand_cluster(
user_id, scheduler.rabbitmq_rpc_client
scheduler.rabbitmq_rpc_client,
user_id=user_id,
wallet_id=pipeline_params.run_metadata.get("wallet_id"),
)
if pipeline_params.cluster_id != DEFAULT_CLUSTER_ID:
clusters_repo = ClustersRepository.instance(scheduler.db_engine)
Expand Down

0 comments on commit 2d6e0fb

Please sign in to comment.