Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Change stored_chunks return type #512

Merged
merged 3 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ hex = "0.4.3"
leb128 = "0.2.5"
ring = "0.16.20"
serde = "1.0.162"
serde_bytes = "0.11.9"
serde_bytes = "0.11.13"
serde_cbor = "0.11.2"
serde_json = "1.0.96"
serde_repr = "0.1.12"
Expand Down
13 changes: 11 additions & 2 deletions ic-utils/src/interfaces/management_canister.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
use crate::{call::AsyncCall, Canister};
use candid::{CandidType, Deserialize, Nat};
use ic_agent::{export::Principal, Agent};
use std::{convert::AsRef, fmt::Debug, ops::Deref};
use std::{convert::AsRef, ops::Deref};
use strum_macros::{AsRefStr, EnumString};

pub mod attributes;
pub mod builders;
mod serde_impls;
#[doc(inline)]
pub use builders::{
CreateCanisterBuilder, InstallBuilder, InstallChunkedCodeBuilder, InstallCodeBuilder,
Expand Down Expand Up @@ -139,6 +140,14 @@ pub struct DefiniteCanisterSettings {
#[derive(Clone, Debug, Deserialize, CandidType)]
pub struct UploadChunkResult {
/// The hash of the uploaded chunk.
#[serde(with = "serde_bytes")]
pub hash: ChunkHash,
}

/// The result of a [`ManagementCanister::stored_chunks`] call.
#[derive(Clone, Debug)]
pub struct ChunkInfo {
/// The hash of the stored chunk.
pub hash: ChunkHash,
}

Expand Down Expand Up @@ -367,7 +376,7 @@ impl<'agent> ManagementCanister<'agent> {
pub fn stored_chunks(
&self,
canister_id: &Principal,
) -> impl 'agent + AsyncCall<(Vec<ChunkHash>,)> {
) -> impl 'agent + AsyncCall<(Vec<ChunkInfo>,)> {
#[derive(CandidType)]
struct Argument<'a> {
canister_id: &'a Principal,
Expand Down
2 changes: 1 addition & 1 deletion ic-utils/src/interfaces/management_canister/builders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ impl<'agent: 'canister, 'canister: 'builder, 'builder> InstallBuilder<'agent, 'c
)
} else {
let (existing_chunks,) = self.canister.stored_chunks(&self.canister_id).call_and_wait().await?;
let existing_chunks = existing_chunks.into_iter().collect::<BTreeSet<_>>();
let existing_chunks = existing_chunks.into_iter().map(|c| c.hash).collect::<BTreeSet<_>>();
let to_upload_chunks_ordered = self.wasm.chunks(1024 * 1024).map(|x| (<[u8; 32]>::from(Sha256::digest(x)), x)).collect::<Vec<_>>();
let to_upload_chunks = to_upload_chunks_ordered.iter().map(|&(k, v)| (k, v)).collect::<BTreeMap<_, _>>();
let (new_chunks, setup) = if existing_chunks.iter().all(|hash| to_upload_chunks.contains_key(hash)) {
Expand Down
125 changes: 125 additions & 0 deletions ic-utils/src/interfaces/management_canister/serde_impls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
use std::fmt::Formatter;

use super::ChunkInfo;
use candid::types::{CandidType, Type, TypeInner};
use serde::de::{Deserialize, Deserializer, Error, IgnoredAny, MapAccess, SeqAccess, Visitor};
use serde_bytes::ByteArray;
// ChunkInfo can be deserialized from both `blob` and `record { hash: blob }`.
// This impl can be removed when both mainnet and dfx no longer return `blob`.
impl CandidType for ChunkInfo {
fn _ty() -> Type {
Type(<_>::from(TypeInner::Unknown))
}
fn idl_serialize<S>(&self, _serializer: S) -> Result<(), S::Error>
where
S: candid::types::Serializer,
{
unimplemented!()
}
}
impl<'de> Deserialize<'de> for ChunkInfo {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_any(ChunkInfoVisitor)
}
}
struct ChunkInfoVisitor;
impl<'de> Visitor<'de> for ChunkInfoVisitor {
type Value = ChunkInfo;
fn expecting(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
formatter.write_str("blob or record {hash: blob}")
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: Error,
{
// deserialize_any combined with visit_bytes produces an extra 6 byte for difficult reasons
let v = if v.len() == 33 && v[0] == 6 {
&v[1..]
} else {
v
};
Ok(ChunkInfo {
hash: v
.try_into()
.map_err(|_| E::invalid_length(v.len(), &"32 bytes"))?,
})
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let mut hash = [0; 32];
for (i, n) in hash.iter_mut().enumerate() {
*n = seq
.next_element()?
.ok_or_else(|| A::Error::invalid_length(i, &"32 bytes"))?;
}
if seq.next_element::<IgnoredAny>()?.is_some() {
Err(A::Error::invalid_length(
seq.size_hint().unwrap_or(33),
&"32 bytes",
))
} else {
Ok(ChunkInfo { hash })
}
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
while let Some(k) = map.next_key::<Field>()? {
eprintln!("here");
if matches!(k, Field::Hash) {
return Ok(ChunkInfo {
hash: map.next_value::<ByteArray<32>>()?.into_array(),
});
} else {
map.next_value::<IgnoredAny>()?;
}
}
Err(A::Error::missing_field("hash"))
}
}
// Needed because candid cannot infer field names without specifying them in _ty()
enum Field {
Hash,
Other,
}
impl<'de> Deserialize<'de> for Field {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_identifier(FieldVisitor)
}
}
struct FieldVisitor;
impl<'de> Visitor<'de> for FieldVisitor {
type Value = Field;
fn expecting(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
formatter.write_str("a field name")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: Error,
{
if v == "hash" {
Ok(Field::Hash)
} else {
Ok(Field::Other)
}
}
fn visit_u32<E>(self, v: u32) -> Result<Self::Value, E>
where
E: Error,
{
if v == 1158164430 {
Ok(Field::Hash)
} else {
Ok(Field::Other)
}
}
}
Loading