Refactor Store Api into client side and driver side (#935)
Refactors the Store API into a driver (backend) implementation and a
client-side Store/StoreLike API. Store and StoreLike have sizes known
at compile time, which lets us add generics on the client side to make
the API easier to work with: for example, we no longer need to Pin the
store, and we'll be able to accept arguments like
`digest: impl Into<DigestInfo>` to make callers' lives much easier.

towards: #934
allada authored Jun 4, 2024
1 parent da2c4a7 commit 04beafd
Showing 47 changed files with 1,095 additions and 1,327 deletions.
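
To make the shape of this split concrete, here is a minimal, hypothetical sketch of the pattern the diffs below adopt. The names (Store, StoreLike, downcast_ref) mirror the commit, but every body here is an illustrative assumption rather than nativelink's actual code; in particular, the real downcast_ref also takes an Option<DigestInfo> to pick an inner store, and the real API carries the has/update_oneshot/get_part methods the call sites below use.

// Hypothetical sketch only -- not the real nativelink definitions.
use std::any::Any;
use std::sync::Arc;

// Driver side: the trait that concrete backends (memory, gRPC, ...) implement.
trait StoreDriver: Send + Sync {
    fn as_any(&self) -> &dyn Any;
}

// Client side: a Sized, cloneable handle over the driver. Because it is
// Sized, methods take plain `&self` (no `Pin<&dyn Store>`) and can use
// generic parameters such as `digest: impl Into<DigestInfo>`.
#[derive(Clone)]
pub struct Store {
    inner: Arc<dyn StoreDriver>,
}

impl Store {
    pub fn new(inner: Arc<dyn StoreDriver>) -> Self {
        Self { inner }
    }

    // Replaces the old `store.inner_store(..).as_any().downcast_ref()` dance
    // seen on the removed lines below. (Assumed simplification: the real
    // method also routes by an Option<DigestInfo> argument.)
    pub fn downcast_ref<T: 'static>(&self) -> Option<&T> {
        self.inner.as_any().downcast_ref::<T>()
    }
}

// Client-facing trait so helpers can accept anything store-like.
pub trait StoreLike {
    fn as_store_driver(&self) -> &dyn StoreDriver;
}

impl StoreLike for Store {
    fn as_store_driver(&self) -> &dyn StoreDriver {
        self.inner.as_ref()
    }
}

// Toy driver to show construction, matching the test diff's
// `Store::new(Arc::new(MemoryStore::new(..)))` shape.
struct MemoryStore;
impl StoreDriver for MemoryStore {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

fn main() {
    let store = Store::new(Arc::new(MemoryStore));
    assert!(store.downcast_ref::<MemoryStore>().is_some());
}

The payoff is visible at call sites throughout the diff: `Pin::new(store.as_ref()).has(digest)` collapses to `store.has(digest)`, and the `store_pin`/`store_ref` juggling becomes direct method calls.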
22 changes: 7 additions & 15 deletions nativelink-scheduler/src/cache_lookup_scheduler.rs
@@ -13,7 +13,6 @@
 // limitations under the License.
 
 use std::collections::HashMap;
-use std::pin::Pin;
 use std::sync::Arc;
 
 use async_trait::async_trait;
@@ -30,7 +29,7 @@ use nativelink_util::action_messages::{
 use nativelink_util::background_spawn;
 use nativelink_util::common::DigestInfo;
 use nativelink_util::digest_hasher::DigestHasherFunc;
-use nativelink_util::store_trait::Store;
+use nativelink_util::store_trait::{Store, StoreLike};
 use parking_lot::{Mutex, MutexGuard};
 use scopeguard::guard;
 use tokio::select;
@@ -50,7 +49,7 @@ type CheckActions = HashMap<ActionInfoHashKey, Arc<watch::Sender<Arc<ActionState
 pub struct CacheLookupScheduler {
     /// A reference to the AC to find existing actions in.
     /// To prevent unintended issues, this store should probably be a CompletenessCheckingStore.
-    ac_store: Arc<dyn Store>,
+    ac_store: Store,
     /// The "real" scheduler to use to perform actions if they were not found
     /// in the action cache.
     action_scheduler: Arc<dyn ActionScheduler>,
@@ -59,14 +58,13 @@ pub struct CacheLookupScheduler {
 }
 
 async fn get_action_from_store(
-    ac_store: Pin<&dyn Store>,
+    ac_store: &Store,
     action_digest: DigestInfo,
     instance_name: String,
     digest_function: DigestHasherFunc,
 ) -> Option<ProtoActionResult> {
     // If we are a GrpcStore we shortcut here, as this is a special store.
-    let any_store = ac_store.inner_store(Some(action_digest)).as_any();
-    if let Some(grpc_store) = any_store.downcast_ref::<GrpcStore>() {
+    if let Some(grpc_store) = ac_store.downcast_ref::<GrpcStore>(Some(action_digest)) {
         let action_result_request = GetActionResultRequest {
             instance_name,
             action_digest: Some(action_digest.into()),
@@ -103,10 +101,7 @@ fn subscribe_to_existing_action(
 }
 
 impl CacheLookupScheduler {
-    pub fn new(
-        ac_store: Arc<dyn Store>,
-        action_scheduler: Arc<dyn ActionScheduler>,
-    ) -> Result<Self, Error> {
+    pub fn new(ac_store: Store, action_scheduler: Arc<dyn ActionScheduler>) -> Result<Self, Error> {
         Ok(Self {
             ac_store,
             action_scheduler,
@@ -170,17 +165,14 @@ impl ActionScheduler for CacheLookupScheduler {
             let action_digest = current_state.action_digest();
             let instance_name = action_info.instance_name().clone();
             if let Some(action_result) = get_action_from_store(
-                Pin::new(ac_store.as_ref()),
+                &ac_store,
                 *action_digest,
                 instance_name,
                 current_state.id.unique_qualifier.digest_function,
             )
             .await
             {
-                match Pin::new(ac_store.clone().as_ref())
-                    .has(*action_digest)
-                    .await
-                {
+                match ac_store.has(*action_digest).await {
                     Ok(Some(_)) => {
                         Arc::make_mut(&mut current_state).stage =
                             ActionStage::CompletedFromCache(action_result);
13 changes: 6 additions & 7 deletions nativelink-scheduler/tests/cache_lookup_scheduler_test.rs
@@ -13,7 +13,6 @@
 // limitations under the License.
 
 use std::collections::HashMap;
-use std::pin::Pin;
 use std::sync::Arc;
 use std::time::UNIX_EPOCH;
 
@@ -35,7 +34,7 @@ use nativelink_util::action_messages::{
 };
 use nativelink_util::common::DigestInfo;
 use nativelink_util::digest_hasher::DigestHasherFunc;
-use nativelink_util::store_trait::Store;
+use nativelink_util::store_trait::{Store, StoreLike};
 use prost::Message;
 use tokio::sync::watch;
 use tokio::{self};
@@ -44,15 +43,15 @@ use utils::scheduler_utils::{make_base_action_info, INSTANCE_NAME};
 
 struct TestContext {
     mock_scheduler: Arc<MockActionScheduler>,
-    ac_store: Arc<dyn Store>,
+    ac_store: Store,
     cache_scheduler: CacheLookupScheduler,
 }
 
 fn make_cache_scheduler() -> Result<TestContext, Error> {
     let mock_scheduler = Arc::new(MockActionScheduler::new());
-    let ac_store = Arc::new(MemoryStore::new(
+    let ac_store = Store::new(Arc::new(MemoryStore::new(
         &nativelink_config::stores::MemoryStore::default(),
-    ));
+    )));
     let cache_scheduler = CacheLookupScheduler::new(ac_store.clone(), mock_scheduler.clone())?;
     Ok(TestContext {
         mock_scheduler,
@@ -93,8 +92,8 @@ mod cache_lookup_scheduler_tests {
         let context = make_cache_scheduler()?;
         let action_info = make_base_action_info(UNIX_EPOCH);
         let action_result = ProtoActionResult::from(ActionResult::default());
-        let store_pin = Pin::new(context.ac_store.as_ref());
-        store_pin
+        context
+            .ac_store
             .update_oneshot(*action_info.digest(), action_result.encode_to_vec().into())
             .await?;
         let (_forward_watch_channel_tx, forward_watch_channel_rx) =
18 changes: 7 additions & 11 deletions nativelink-service/src/ac_server.rs
@@ -14,8 +14,6 @@
 
 use std::collections::HashMap;
 use std::fmt::Debug;
-use std::pin::Pin;
-use std::sync::Arc;
 
 use bytes::BytesMut;
 use nativelink_config::cas_server::{AcStoreConfig, InstanceName};
@@ -31,14 +29,14 @@ use nativelink_store::grpc_store::GrpcStore;
 use nativelink_store::store_manager::StoreManager;
 use nativelink_util::common::DigestInfo;
 use nativelink_util::digest_hasher::make_ctx_for_hash_func;
-use nativelink_util::store_trait::Store;
+use nativelink_util::store_trait::{Store, StoreLike};
 use prost::Message;
 use tonic::{Request, Response, Status};
 use tracing::{error_span, event, instrument, Level};
 
 #[derive(Clone)]
 pub struct AcStoreInfo {
-    store: Arc<dyn Store>,
+    store: Store,
     read_only: bool,
 }
 
@@ -97,14 +95,12 @@ impl AcServer {
             .try_into()?;
 
         // If we are a GrpcStore we shortcut here, as this is a special store.
-        let any_store = store_info.store.inner_store(Some(digest)).as_any();
-        if let Some(grpc_store) = any_store.downcast_ref::<GrpcStore>() {
+        if let Some(grpc_store) = store_info.store.downcast_ref::<GrpcStore>(Some(digest)) {
             return grpc_store.get_action_result(Request::new(request)).await;
         }
 
         Ok(Response::new(
-            get_and_decode_digest::<ActionResult>(Pin::new(store_info.store.as_ref()), &digest)
-                .await?,
+            get_and_decode_digest::<ActionResult>(&store_info.store, &digest).await?,
         ))
     }
 
@@ -132,8 +128,7 @@ impl AcServer {
             .try_into()?;
 
         // If we are a GrpcStore we shortcut here, as this is a special store.
-        let any_store = store_info.store.inner_store(Some(digest)).as_any();
-        if let Some(grpc_store) = any_store.downcast_ref::<GrpcStore>() {
+        if let Some(grpc_store) = store_info.store.downcast_ref::<GrpcStore>(Some(digest)) {
            return grpc_store.update_action_result(Request::new(request)).await;
        }
 
@@ -146,7 +141,8 @@ impl AcServer {
             .encode(&mut store_data)
             .err_tip(|| "Provided ActionResult could not be serialized")?;
 
-        Pin::new(store_info.store.as_ref())
+        store_info
+            .store
             .update_oneshot(digest, store_data.freeze())
             .await
             .err_tip(|| "Failed to update in action cache")?;
25 changes: 11 additions & 14 deletions nativelink-service/src/bytestream_server.rs
@@ -44,7 +44,7 @@ use nativelink_util::digest_hasher::{
 use nativelink_util::proto_stream_utils::WriteRequestStreamWrapper;
 use nativelink_util::resource_info::ResourceInfo;
 use nativelink_util::spawn;
-use nativelink_util::store_trait::{Store, UploadSizeInfo};
+use nativelink_util::store_trait::{Store, StoreLike, UploadSizeInfo};
 use nativelink_util::task::JoinHandleDropGuard;
 use parking_lot::Mutex;
 use tokio::time::sleep;
@@ -153,7 +153,7 @@ type BytesWrittenAndIdleStream = (Arc<AtomicU64>, Option<IdleStream>);
 type SleepFn = Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
 
 pub struct ByteStreamServer {
-    stores: HashMap<String, Arc<dyn Store>>,
+    stores: HashMap<String, Store>,
     // Max number of bytes to send on each grpc stream chunk.
     max_bytes_per_stream: usize,
     active_uploads: Arc<Mutex<HashMap<String, BytesWrittenAndIdleStream>>>,
@@ -206,7 +206,7 @@ impl ByteStreamServer {
     fn create_or_join_upload_stream(
         &self,
         uuid: String,
-        store: Arc<dyn Store>,
+        store: Store,
         digest: DigestInfo,
     ) -> Result<ActiveStreamGuard<'_>, Error> {
         let (uuid, bytes_received) = match self.active_uploads.lock().entry(uuid) {
@@ -236,7 +236,7 @@ impl ByteStreamServer {
         let store_update_fut = Box::pin(async move {
             // We need to wrap `Store::update()` in a another future because we need to capture
             // `store` to ensure it's lifetime follows the future and not the caller.
-            Pin::new(store.as_ref())
+            store
                 // Bytestream always uses digest size as the actual byte size.
                 .update(
                     digest,
@@ -260,7 +260,7 @@
 
     async fn inner_read(
         &self,
-        store: Arc<dyn Store>,
+        store: Store,
         digest: DigestInfo,
         read_request: ReadRequest,
     ) -> Result<Response<ReadStream>, Error> {
@@ -289,7 +289,7 @@ impl ByteStreamServer {
             maybe_get_part_result: None,
             get_part_fut: Box::pin(async move {
                 store
-                    .get_part_arc(digest, tx, read_request.read_offset as usize, read_limit)
+                    .get_part(digest, tx, read_request.read_offset as usize, read_limit)
                     .await
             }),
         });
@@ -383,7 +383,7 @@
     )]
     async fn inner_write(
         &self,
-        store: Arc<dyn Store>,
+        store: Store,
         digest: DigestInfo,
         stream: WriteRequestStreamWrapper<Streaming<WriteRequest>, Status>,
     ) -> Result<Response<WriteResponse>, Error> {
@@ -520,8 +520,7 @@ impl ByteStreamServer {
         let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)?;
 
         // If we are a GrpcStore we shortcut here, as this is a special store.
-        let any_store = store_clone.inner_store(Some(digest)).as_any();
-        if let Some(grpc_store) = any_store.downcast_ref::<GrpcStore>() {
+        if let Some(grpc_store) = store_clone.downcast_ref::<GrpcStore>(Some(digest)) {
             return grpc_store
                 .query_write_status(Request::new(query_request.clone()))
                 .await;
@@ -544,7 +543,7 @@
             }
         }
 
-        let has_fut = Pin::new(store_clone.as_ref()).has(digest);
+        let has_fut = store_clone.has(digest);
         let Some(item_size) = has_fut.await.err_tip(|| "Failed to call .has() on store")? else {
             return Err(make_err!(Code::NotFound, "{}", "not found"));
         };
@@ -583,8 +582,7 @@ impl ByteStream for ByteStreamServer {
         let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)?;
 
         // If we are a GrpcStore we shortcut here, as this is a special store.
-        let any_store = store.inner_store(Some(digest)).as_any();
-        if let Some(grpc_store) = any_store.downcast_ref::<GrpcStore>() {
+        if let Some(grpc_store) = store.downcast_ref::<GrpcStore>(Some(digest)) {
             let stream = grpc_store.read(Request::new(read_request)).await?;
             return Ok(Response::new(Box::pin(stream)));
         }
@@ -640,8 +638,7 @@
             .err_tip(|| "Invalid digest input in ByteStream::write")?;
 
         // If we are a GrpcStore we shortcut here, as this is a special store.
-        let any_store = store.inner_store(Some(digest)).as_any();
-        if let Some(grpc_store) = any_store.downcast_ref::<GrpcStore>() {
+        if let Some(grpc_store) = store.downcast_ref::<GrpcStore>(Some(digest)) {
             return grpc_store.write(stream).await.map_err(|e| e.into());
         }
 
27 changes: 11 additions & 16 deletions nativelink-service/src/cas_server.rs
@@ -14,7 +14,6 @@
 
 use std::collections::{HashMap, VecDeque};
 use std::pin::Pin;
-use std::sync::Arc;
 
 use bytes::Bytes;
 use futures::stream::{FuturesUnordered, Stream};
@@ -35,12 +34,12 @@ use nativelink_store::grpc_store::GrpcStore;
 use nativelink_store::store_manager::StoreManager;
 use nativelink_util::common::DigestInfo;
 use nativelink_util::digest_hasher::make_ctx_for_hash_func;
-use nativelink_util::store_trait::Store;
+use nativelink_util::store_trait::{Store, StoreLike};
 use tonic::{Request, Response, Status};
 use tracing::{error_span, event, instrument, Level};
 
 pub struct CasServer {
-    stores: HashMap<String, Arc<dyn Store>>,
+    stores: HashMap<String, Store>,
 }
 
 type GetTreeStream = Pin<Box<dyn Stream<Item = Result<GetTreeResponse, Status>> + Send + 'static>>;
@@ -79,7 +78,7 @@ impl CasServer {
         for digest in request.blob_digests.iter() {
             requested_blobs.push(DigestInfo::try_from(digest.clone())?);
         }
-        let sizes = Pin::new(store.as_ref())
+        let sizes = store
             .has_many(&requested_blobs)
             .await
             .err_tip(|| "In find_missing_blobs")?;
@@ -109,12 +108,11 @@ impl CasServer {
         // If we are a GrpcStore we shortcut here, as this is a special store.
         // Note: We don't know the digests here, so we try perform a very shallow
         // check to see if it's a grpc store.
-        let any_store = store.inner_store(None).as_any();
-        if let Some(grpc_store) = any_store.downcast_ref::<GrpcStore>() {
+        if let Some(grpc_store) = store.downcast_ref::<GrpcStore>(None) {
             return grpc_store.batch_update_blobs(Request::new(request)).await;
         }
 
-        let store_pin = Pin::new(store.as_ref());
+        let store_ref = &store;
         let update_futures: FuturesUnordered<_> = request
             .requests
             .into_iter()
@@ -133,7 +131,7 @@ impl CasServer {
                     size_bytes,
                     request_data.len()
                 );
-                let result = store_pin
+                let result = store_ref
                     .update_oneshot(digest_info, request_data)
                     .await
                     .err_tip(|| "Error writing to store");
@@ -165,19 +163,18 @@ impl CasServer {
         // If we are a GrpcStore we shortcut here, as this is a special store.
         // Note: We don't know the digests here, so we try perform a very shallow
         // check to see if it's a grpc store.
-        let any_store = store.inner_store(None).as_any();
-        if let Some(grpc_store) = any_store.downcast_ref::<GrpcStore>() {
+        if let Some(grpc_store) = store.downcast_ref::<GrpcStore>(None) {
            return grpc_store.batch_read_blobs(Request::new(request)).await;
        }
 
-        let store_pin = Pin::new(store.as_ref());
+        let store_ref = &store;
         let read_futures: FuturesUnordered<_> = request
             .digests
             .into_iter()
             .map(|digest| async move {
                 let digest_copy = DigestInfo::try_from(digest.clone())?;
                 // TODO(allada) There is a security risk here of someone taking all the memory on the instance.
-                let result = store_pin
+                let result = store_ref
                     .get_part_unchunked(digest_copy, 0, None)
                     .await
                     .err_tip(|| "Error reading from store");
@@ -223,15 +220,13 @@ impl CasServer {
         // If we are a GrpcStore we shortcut here, as this is a special store.
         // Note: We don't know the digests here, so we try perform a very shallow
         // check to see if it's a grpc store.
-        let any_store = store.inner_store(None).as_any();
-        if let Some(grpc_store) = any_store.downcast_ref::<GrpcStore>() {
+        if let Some(grpc_store) = store.downcast_ref::<GrpcStore>(None) {
             let stream = grpc_store
                 .get_tree(Request::new(request))
                 .await?
                 .into_inner();
             return Ok(Response::new(Box::pin(stream)));
         }
-        let store_pin = Pin::new(store.as_ref());
         let root_digest: DigestInfo = request
             .root_digest
             .err_tip(|| "Expected root_digest to exist in GetTreeRequest")?
@@ -260,7 +255,7 @@
 
         while !deque.is_empty() {
             let digest: DigestInfo = deque.pop_front().err_tip(|| "In VecDeque::pop_front")?;
-            let directory = get_and_decode_digest::<Directory>(store_pin, &digest)
+            let directory = get_and_decode_digest::<Directory>(&store, &digest)
                 .await
                 .err_tip(|| "Converting digest to Directory")?;
             if digest == page_token_digest {
(Diff truncated: the remaining 42 of the 47 changed files are not shown.)
