Refactor Store API into client side and driver side
Refactors the Store API into the driver (backend) implementation
and a client-side Store/StoreLike API. Store & StoreLike have their
sizes known at compile time. This lets us add generic helpers on
the client side to make it easier to work with: for example, we no
longer need to Pin the store, and we'll be able to accept arguments
like `digest: impl Into<DigestInfo>` to make the caller's life much
easier.

towards: TraceMachina#934
allada committed May 28, 2024
1 parent 9c45e86 commit f3c1657
Showing 47 changed files with 1,040 additions and 1,210 deletions.
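
For illustration only, here is a minimal sketch (not part of the commit) of the call-site shape the new client-side API allows, based on the patterns in the diffs below. The function, its arguments, and the exact import paths are assumptions; `Store`, `StoreLike`, `has`, `update_oneshot`, and `downcast_ref::<GrpcStore>(..)` are taken from the changes in this commit.

```rust
use nativelink_error::Error;
use nativelink_store::grpc_store::GrpcStore;
use nativelink_util::common::DigestInfo;
use nativelink_util::store_trait::{Store, StoreLike};

// Hypothetical caller: `Store` is now a sized client-side handle, so its
// methods are called directly instead of through `Pin::new(store.as_ref())`.
async fn example_caller(store: &Store, digest: DigestInfo, data: bytes::Bytes) -> Result<(), Error> {
    // The GrpcStore special case collapses into a single downcast call,
    // mirroring the shortcut used throughout this commit.
    if let Some(_grpc_store) = store.downcast_ref::<GrpcStore>(Some(digest)) {
        // ...forward the request to the backing gRPC store instead...
    }

    // Plain method calls on the handle; previously these required pinning
    // the `Arc<dyn Store>` first.
    store.update_oneshot(digest, data).await?;
    let _size_if_present = store.has(digest).await?;

    // Per the commit message, a follow-up could accept
    // `digest: impl Into<DigestInfo>` to simplify callers further.
    Ok(())
}
```

Because the handle is sized, it can also live directly in struct fields and maps, as the server structs in this commit do (e.g. `stores: HashMap<String, Store>`).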
22 changes: 7 additions & 15 deletions nativelink-scheduler/src/cache_lookup_scheduler.rs
@@ -13,7 +13,6 @@
// limitations under the License.

use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;

use async_trait::async_trait;
@@ -30,7 +29,7 @@ use nativelink_util::action_messages::{
use nativelink_util::background_spawn;
use nativelink_util::common::DigestInfo;
use nativelink_util::digest_hasher::DigestHasherFunc;
use nativelink_util::store_trait::Store;
use nativelink_util::store_trait::{Store, StoreLike};
use parking_lot::{Mutex, MutexGuard};
use scopeguard::guard;
use tokio::select;
@@ -50,7 +49,7 @@ type CheckActions = HashMap<ActionInfoHashKey, Arc<watch::Sender<Arc<ActionState
pub struct CacheLookupScheduler {
/// A reference to the AC to find existing actions in.
/// To prevent unintended issues, this store should probably be a CompletenessCheckingStore.
ac_store: Arc<dyn Store>,
ac_store: Store,
/// The "real" scheduler to use to perform actions if they were not found
/// in the action cache.
action_scheduler: Arc<dyn ActionScheduler>,
@@ -59,14 +58,13 @@ pub struct CacheLookupScheduler {
}

async fn get_action_from_store(
ac_store: Pin<&dyn Store>,
ac_store: &Store,
action_digest: DigestInfo,
instance_name: String,
digest_function: DigestHasherFunc,
) -> Option<ProtoActionResult> {
// If we are a GrpcStore we shortcut here, as this is a special store.
let any_store = ac_store.inner_store(Some(action_digest)).as_any();
if let Some(grpc_store) = any_store.downcast_ref::<GrpcStore>() {
if let Some(grpc_store) = ac_store.downcast_ref::<GrpcStore>(Some(action_digest)) {
let action_result_request = GetActionResultRequest {
instance_name,
action_digest: Some(action_digest.into()),
@@ -103,10 +101,7 @@ fn subscribe_to_existing_action(
}

impl CacheLookupScheduler {
pub fn new(
ac_store: Arc<dyn Store>,
action_scheduler: Arc<dyn ActionScheduler>,
) -> Result<Self, Error> {
pub fn new(ac_store: Store, action_scheduler: Arc<dyn ActionScheduler>) -> Result<Self, Error> {
Ok(Self {
ac_store,
action_scheduler,
@@ -169,17 +164,14 @@ impl ActionScheduler for CacheLookupScheduler {
let action_digest = current_state.action_digest();
let instance_name = action_info.instance_name().clone();
if let Some(action_result) = get_action_from_store(
Pin::new(ac_store.as_ref()),
&ac_store,
*action_digest,
instance_name,
current_state.unique_qualifier.digest_function,
)
.await
{
match Pin::new(ac_store.clone().as_ref())
.has(*action_digest)
.await
{
match ac_store.has(*action_digest).await {
Ok(Some(_)) => {
Arc::make_mut(&mut current_state).stage =
ActionStage::CompletedFromCache(action_result);
13 changes: 6 additions & 7 deletions nativelink-scheduler/tests/cache_lookup_scheduler_test.rs
@@ -13,7 +13,6 @@
// limitations under the License.

use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::time::UNIX_EPOCH;

@@ -33,7 +32,7 @@ use nativelink_store::memory_store::MemoryStore;
use nativelink_util::action_messages::{ActionInfoHashKey, ActionResult, ActionStage, ActionState};
use nativelink_util::common::DigestInfo;
use nativelink_util::digest_hasher::DigestHasherFunc;
use nativelink_util::store_trait::Store;
use nativelink_util::store_trait::{Store, StoreLike};
use prost::Message;
use tokio::sync::watch;
use tokio::{self};
@@ -42,15 +41,15 @@ use utils::scheduler_utils::{make_base_action_info, INSTANCE_NAME};

struct TestContext {
mock_scheduler: Arc<MockActionScheduler>,
ac_store: Arc<dyn Store>,
ac_store: Store,
cache_scheduler: CacheLookupScheduler,
}

fn make_cache_scheduler() -> Result<TestContext, Error> {
let mock_scheduler = Arc::new(MockActionScheduler::new());
let ac_store = Arc::new(MemoryStore::new(
let ac_store = Store::new(Arc::new(MemoryStore::new(
&nativelink_config::stores::MemoryStore::default(),
));
)));
let cache_scheduler = CacheLookupScheduler::new(ac_store.clone(), mock_scheduler.clone())?;
Ok(TestContext {
mock_scheduler,
@@ -91,8 +90,8 @@ mod cache_lookup_scheduler_tests {
let context = make_cache_scheduler()?;
let action_info = make_base_action_info(UNIX_EPOCH);
let action_result = ProtoActionResult::from(ActionResult::default());
let store_pin = Pin::new(context.ac_store.as_ref());
store_pin
context
.ac_store
.update_oneshot(*action_info.digest(), action_result.encode_to_vec().into())
.await?;
let (_forward_watch_channel_tx, forward_watch_channel_rx) =
18 changes: 7 additions & 11 deletions nativelink-service/src/ac_server.rs
@@ -14,8 +14,6 @@

use std::collections::HashMap;
use std::fmt::Debug;
use std::pin::Pin;
use std::sync::Arc;

use bytes::BytesMut;
use nativelink_config::cas_server::{AcStoreConfig, InstanceName};
@@ -31,14 +29,14 @@ use nativelink_store::grpc_store::GrpcStore;
use nativelink_store::store_manager::StoreManager;
use nativelink_util::common::DigestInfo;
use nativelink_util::digest_hasher::make_ctx_for_hash_func;
use nativelink_util::store_trait::Store;
use nativelink_util::store_trait::{Store, StoreLike};
use prost::Message;
use tonic::{Request, Response, Status};
use tracing::{error_span, event, instrument, Level};

#[derive(Clone)]
pub struct AcStoreInfo {
store: Arc<dyn Store>,
store: Store,
read_only: bool,
}

@@ -97,14 +95,12 @@ impl AcServer {
.try_into()?;

// If we are a GrpcStore we shortcut here, as this is a special store.
let any_store = store_info.store.inner_store(Some(digest)).as_any();
if let Some(grpc_store) = any_store.downcast_ref::<GrpcStore>() {
if let Some(grpc_store) = store_info.store.downcast_ref::<GrpcStore>(Some(digest)) {
return grpc_store.get_action_result(Request::new(request)).await;
}

Ok(Response::new(
get_and_decode_digest::<ActionResult>(Pin::new(store_info.store.as_ref()), &digest)
.await?,
get_and_decode_digest::<ActionResult>(&store_info.store, &digest).await?,
))
}

@@ -132,8 +128,7 @@ impl AcServer {
.try_into()?;

// If we are a GrpcStore we shortcut here, as this is a special store.
let any_store = store_info.store.inner_store(Some(digest)).as_any();
if let Some(grpc_store) = any_store.downcast_ref::<GrpcStore>() {
if let Some(grpc_store) = store_info.store.downcast_ref::<GrpcStore>(Some(digest)) {
return grpc_store.update_action_result(Request::new(request)).await;
}

@@ -146,7 +141,8 @@ impl AcServer {
.encode(&mut store_data)
.err_tip(|| "Provided ActionResult could not be serialized")?;

Pin::new(store_info.store.as_ref())
store_info
.store
.update_oneshot(digest, store_data.freeze())
.await
.err_tip(|| "Failed to update in action cache")?;
25 changes: 11 additions & 14 deletions nativelink-service/src/bytestream_server.rs
@@ -44,7 +44,7 @@ use nativelink_util::digest_hasher::{
use nativelink_util::proto_stream_utils::WriteRequestStreamWrapper;
use nativelink_util::resource_info::ResourceInfo;
use nativelink_util::spawn;
use nativelink_util::store_trait::{Store, UploadSizeInfo};
use nativelink_util::store_trait::{Store, StoreLike, UploadSizeInfo};
use nativelink_util::task::JoinHandleDropGuard;
use parking_lot::Mutex;
use tokio::time::sleep;
@@ -153,7 +153,7 @@ type BytesWrittenAndIdleStream = (Arc<AtomicU64>, Option<IdleStream>);
type SleepFn = Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;

pub struct ByteStreamServer {
stores: HashMap<String, Arc<dyn Store>>,
stores: HashMap<String, Store>,
// Max number of bytes to send on each grpc stream chunk.
max_bytes_per_stream: usize,
active_uploads: Arc<Mutex<HashMap<String, BytesWrittenAndIdleStream>>>,
@@ -206,7 +206,7 @@ impl ByteStreamServer {
fn create_or_join_upload_stream(
&self,
uuid: String,
store: Arc<dyn Store>,
store: Store,
digest: DigestInfo,
) -> Result<ActiveStreamGuard<'_>, Error> {
let (uuid, bytes_received) = match self.active_uploads.lock().entry(uuid) {
@@ -236,7 +236,7 @@ impl ByteStreamServer {
let store_update_fut = Box::pin(async move {
// We need to wrap `Store::update()` in another future because we need to capture
// `store` to ensure it's lifetime follows the future and not the caller.
Pin::new(store.as_ref())
store
// Bytestream always uses digest size as the actual byte size.
.update(
digest,
@@ -260,7 +260,7 @@

async fn inner_read(
&self,
store: Arc<dyn Store>,
store: Store,
digest: DigestInfo,
read_request: ReadRequest,
) -> Result<Response<ReadStream>, Error> {
@@ -289,7 +289,7 @@ impl ByteStreamServer {
maybe_get_part_result: None,
get_part_fut: Box::pin(async move {
store
.get_part_arc(digest, tx, read_request.read_offset as usize, read_limit)
.get_part(digest, tx, read_request.read_offset as usize, read_limit)
.await
}),
});
@@ -383,7 +383,7 @@ impl ByteStreamServer {
)]
async fn inner_write(
&self,
store: Arc<dyn Store>,
store: Store,
digest: DigestInfo,
stream: WriteRequestStreamWrapper<Streaming<WriteRequest>, Status>,
) -> Result<Response<WriteResponse>, Error> {
@@ -520,8 +519,7 @@ impl ByteStreamServer {
let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)?;

// If we are a GrpcStore we shortcut here, as this is a special store.
let any_store = store_clone.inner_store(Some(digest)).as_any();
if let Some(grpc_store) = any_store.downcast_ref::<GrpcStore>() {
if let Some(grpc_store) = store_clone.downcast_ref::<GrpcStore>(Some(digest)) {
return grpc_store
.query_write_status(Request::new(query_request.clone()))
.await;
@@ -544,7 +543,7 @@
}
}

let has_fut = Pin::new(store_clone.as_ref()).has(digest);
let has_fut = store_clone.has(digest);
let Some(item_size) = has_fut.await.err_tip(|| "Failed to call .has() on store")? else {
return Err(make_err!(Code::NotFound, "{}", "not found"));
};
@@ -583,8 +582,7 @@ impl ByteStream for ByteStreamServer {
let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)?;

// If we are a GrpcStore we shortcut here, as this is a special store.
let any_store = store.inner_store(Some(digest)).as_any();
if let Some(grpc_store) = any_store.downcast_ref::<GrpcStore>() {
if let Some(grpc_store) = store.downcast_ref::<GrpcStore>(Some(digest)) {
let stream = grpc_store.read(Request::new(read_request)).await?;
return Ok(Response::new(Box::pin(stream)));
}
@@ -640,8 +638,7 @@ impl ByteStream for ByteStreamServer {
.err_tip(|| "Invalid digest input in ByteStream::write")?;

// If we are a GrpcStore we shortcut here, as this is a special store.
let any_store = store.inner_store(Some(digest)).as_any();
if let Some(grpc_store) = any_store.downcast_ref::<GrpcStore>() {
if let Some(grpc_store) = store.downcast_ref::<GrpcStore>(Some(digest)) {
return grpc_store.write(stream).await.map_err(|e| e.into());
}

27 changes: 11 additions & 16 deletions nativelink-service/src/cas_server.rs
@@ -14,7 +14,6 @@

use std::collections::{HashMap, VecDeque};
use std::pin::Pin;
use std::sync::Arc;

use bytes::Bytes;
use futures::stream::{FuturesUnordered, Stream};
@@ -35,12 +34,12 @@ use nativelink_store::grpc_store::GrpcStore;
use nativelink_store::store_manager::StoreManager;
use nativelink_util::common::DigestInfo;
use nativelink_util::digest_hasher::make_ctx_for_hash_func;
use nativelink_util::store_trait::Store;
use nativelink_util::store_trait::{Store, StoreLike};
use tonic::{Request, Response, Status};
use tracing::{error_span, event, instrument, Level};

pub struct CasServer {
stores: HashMap<String, Arc<dyn Store>>,
stores: HashMap<String, Store>,
}

type GetTreeStream = Pin<Box<dyn Stream<Item = Result<GetTreeResponse, Status>> + Send + 'static>>;
@@ -79,7 +78,7 @@ impl CasServer {
for digest in request.blob_digests.iter() {
requested_blobs.push(DigestInfo::try_from(digest.clone())?);
}
let sizes = Pin::new(store.as_ref())
let sizes = store
.has_many(&requested_blobs)
.await
.err_tip(|| "In find_missing_blobs")?;
@@ -109,12 +108,11 @@ impl CasServer {
// If we are a GrpcStore we shortcut here, as this is a special store.
// Note: We don't know the digests here, so we try to perform a very shallow
// check to see if it's a grpc store.
let any_store = store.inner_store(None).as_any();
if let Some(grpc_store) = any_store.downcast_ref::<GrpcStore>() {
if let Some(grpc_store) = store.downcast_ref::<GrpcStore>(None) {
return grpc_store.batch_update_blobs(Request::new(request)).await;
}

let store_pin = Pin::new(store.as_ref());
let store_ref = &store;
let update_futures: FuturesUnordered<_> = request
.requests
.into_iter()
@@ -133,7 +131,7 @@
size_bytes,
request_data.len()
);
let result = store_pin
let result = store_ref
.update_oneshot(digest_info, request_data)
.await
.err_tip(|| "Error writing to store");
@@ -165,19 +163,18 @@ impl CasServer {
// If we are a GrpcStore we shortcut here, as this is a special store.
// Note: We don't know the digests here, so we try to perform a very shallow
// check to see if it's a grpc store.
let any_store = store.inner_store(None).as_any();
if let Some(grpc_store) = any_store.downcast_ref::<GrpcStore>() {
if let Some(grpc_store) = store.downcast_ref::<GrpcStore>(None) {
return grpc_store.batch_read_blobs(Request::new(request)).await;
}

let store_pin = Pin::new(store.as_ref());
let store_ref = &store;
let read_futures: FuturesUnordered<_> = request
.digests
.into_iter()
.map(|digest| async move {
let digest_copy = DigestInfo::try_from(digest.clone())?;
// TODO(allada) There is a security risk here of someone taking all the memory on the instance.
let result = store_pin
let result = store_ref
.get_part_unchunked(digest_copy, 0, None)
.await
.err_tip(|| "Error reading from store");
@@ -223,15 +220,13 @@ impl CasServer {
// If we are a GrpcStore we shortcut here, as this is a special store.
// Note: We don't know the digests here, so we try to perform a very shallow
// check to see if it's a grpc store.
let any_store = store.inner_store(None).as_any();
if let Some(grpc_store) = any_store.downcast_ref::<GrpcStore>() {
if let Some(grpc_store) = store.downcast_ref::<GrpcStore>(None) {
let stream = grpc_store
.get_tree(Request::new(request))
.await?
.into_inner();
return Ok(Response::new(Box::pin(stream)));
}
let store_pin = Pin::new(store.as_ref());
let root_digest: DigestInfo = request
.root_digest
.err_tip(|| "Expected root_digest to exist in GetTreeRequest")?
@@ -260,7 +255,7 @@

while !deque.is_empty() {
let digest: DigestInfo = deque.pop_front().err_tip(|| "In VecDeque::pop_front")?;
let directory = get_and_decode_digest::<Directory>(store_pin, &digest)
let directory = get_and_decode_digest::<Directory>(&store, &digest)
.await
.err_tip(|| "Converting digest to Directory")?;
if digest == page_token_digest {
(Remaining changed files not shown.)
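
Finally, a usage sketch of how the two halves get wired together, mirroring the updated cache_lookup_scheduler_test.rs above. The function name and module paths are assumptions; the `Store::new(Arc::new(MemoryStore::new(..)))` shape and the `CacheLookupScheduler::new(Store, ..)` signature come from this commit.

```rust
use std::sync::Arc;

use nativelink_error::Error;
use nativelink_scheduler::action_scheduler::ActionScheduler;
use nativelink_scheduler::cache_lookup_scheduler::CacheLookupScheduler;
use nativelink_store::memory_store::MemoryStore;
use nativelink_util::store_trait::Store;

// Hypothetical wiring helper (the module paths above are assumed from context).
fn wire_up(action_scheduler: Arc<dyn ActionScheduler>) -> Result<CacheLookupScheduler, Error> {
    // Driver (backend) side: a concrete MemoryStore implementation.
    let ac_driver = Arc::new(MemoryStore::new(
        &nativelink_config::stores::MemoryStore::default(),
    ));
    // Client side: wrap the driver in the sized `Store` handle, then hand it
    // to consumers that now take `Store` instead of `Arc<dyn Store>`.
    let ac_store = Store::new(ac_driver);
    CacheLookupScheduler::new(ac_store, action_scheduler)
}
```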
