Skip to content

Commit

Permalink
fix: protobuf oneof and nested messages (cloudwego#144)
Browse files Browse the repository at this point in the history
  • Loading branch information
LYF1999 committed Apr 27, 2023
1 parent 5c6c635 commit 77f08e8
Show file tree
Hide file tree
Showing 12 changed files with 910 additions and 671 deletions.
171 changes: 110 additions & 61 deletions pilota-build/src/codegen/protobuf/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use faststr::FastStr;
use itertools::Itertools;
use proc_macro2::{Ident, Span, TokenStream};
use proc_macro2::{Ident, Span};
use quote::quote;

use crate::{
Expand Down Expand Up @@ -62,12 +62,22 @@ impl ProtobufBackend {
)
.into()
} else {
let encoded_len: FastStr = if self.is_one_of(ty) {
"msg.encoded_len()".into()
} else {
let ident: FastStr = match kind {
FieldKind::Required => format!("&{ident}").into(),
FieldKind::Optional => "msg".into(),
};
format!("::pilota::prost::encoding::message::encoded_len({tag}, {ident})")
.into()
};

match kind {
FieldKind::Required => format!(
"::pilota::prost::encoding::message::encoded_len({tag}, &{ident})"
)
.into(),
FieldKind::Optional => format!("{ident}.as_ref().map_or(0, |msg| ::pilota::prost::encoding::message::encoded_len({tag}, msg))").into(),
FieldKind::Required => format!("{encoded_len}").into(),
FieldKind::Optional => {
format!("{ident}.as_ref().map_or(0, |msg| {encoded_len})").into()
}
}
}
}
Expand Down Expand Up @@ -102,6 +112,26 @@ impl ProtobufBackend {
false
}

fn is_one_of_item(&self, def_id: DefId) -> bool {
let node = self.cx.node(def_id).unwrap();
if let NodeKind::Item(item) = node.kind {
if let Item::Enum(_) = &*item {
if self.cx.contains_tag::<OneOf>(node.tags) {
return true;
}
}
}
false
}

fn is_one_of(&self, ty: &Ty) -> bool {
let mut ty = ty;
if let ty::TyKind::Vec(inner) = &ty.kind {
ty = inner;
}
matches!(&ty.kind, ty::TyKind::Path(p) if self.is_one_of_item(p.did))
}

fn ty_category(&self, ty: &Ty) -> Category {
let mut ty = ty;
if let ty::TyKind::Vec(inner) = &ty.kind {
Expand Down Expand Up @@ -199,15 +229,23 @@ impl ProtobufBackend {
)
.into()
} else {
let encode: FastStr = if self.is_one_of(ty) {
"pilota_inner_value.encode(buf);".into()
} else {
let ident: FastStr = match kind {
FieldKind::Required => format!("(&{ident})").into(),
FieldKind::Optional => "_pilota_inner_value".into(),
};
format!("::pilota::prost::encoding::message::encode({tag}, {ident}, buf);")
.into()
};

match kind {
FieldKind::Required => format!(
"::pilota::prost::encoding::message::encode({tag}, &{ident}, buf);"
).into(),
FieldKind::Optional => format!(r#"
if let Some(_pilota_inner_value) = {ident}.as_ref() {{
::pilota::prost::encoding::message::encode({tag}, _pilota_inner_value, buf);
}}
"#).into(),
FieldKind::Required => encode,
FieldKind::Optional => format!(
r#"if let Some(_pilota_inner_value) = {ident}.as_ref() {{ {encode} }}"#
)
.into(),
}
}
}
Expand Down Expand Up @@ -255,23 +293,35 @@ impl ProtobufBackend {
Some(field.id as u32).into_iter().chain(vec![])
}

fn codegen_merge_field(&self, ident: TokenStream, ty: &Ty, kind: FieldKind) -> TokenStream {
fn codegen_merge_field(&self, ident: FastStr, ty: &Ty, kind: FieldKind) -> FastStr {
match self.ty_category(ty) {
Category::Scalar | Category::Message => {
let merge_fn = match ty.kind {
ty::TyKind::Vec(_) => quote!(merge_repeated),
_ => quote!(merge),
};

let module = self.ty_module(ty);
let merge_fn = quote!(::pilota::prost::encoding::#module::#merge_fn);
if self.is_one_of(ty) {
let did = match &ty.kind {
ty::TyKind::Path(p) => p.did,
_ => unreachable!(),
};

match kind {
FieldKind::Required => quote!(#merge_fn(wire_type, #ident, buf, ctx)),
FieldKind::Optional => quote!(#merge_fn(wire_type,
#ident.get_or_insert_with(::core::default::Default::default),
buf,
ctx)),
let path = self.cx.cur_related_item_path(did);
format!("{path}::merge(&mut {ident}, tag, wire_type, buf, ctx)").into()
} else {
let module = self.ty_module(ty);
let merge_fn = format!("::pilota::prost::encoding::{module}::{merge_fn}");

match kind {
FieldKind::Required => {
format!("{merge_fn}(wire_type, {ident}, buf, ctx)").into()
}
FieldKind::Optional => format!(
r#"{merge_fn}(wire_type, {ident}.get_or_insert_with(::core::default::Default::default), buf, ctx)"#
)
.into(),
}
}
}
Category::Map => {
Expand All @@ -283,12 +333,10 @@ impl ProtobufBackend {
let key_mod = self.ty_module(key_ty);
let value_mod = self.ty_module(value_ty);

let key_merge_fn = quote!(::pilota::prost::encoding::#key_mod::merge);
let value_merge_fn = quote!(::pilota::prost::encoding::#value_mod::merge);
let key_merge_fn = format!("::pilota::prost::encoding::{key_mod}::merge");
let value_merge_fn = format!("::pilota::prost::encoding::{value_mod}::merge");

quote! {
::pilota::prost::encoding::hash_map::merge(#key_merge_fn, #value_merge_fn, &mut #ident, buf, ctx)
}
format!("::pilota::prost::encoding::hash_map::merge({key_merge_fn}, {value_merge_fn}, &mut {ident}, buf, ctx)").into()
}
}
}
Expand Down Expand Up @@ -330,7 +378,7 @@ impl CodegenBackend for ProtobufBackend {
.map(|field| {
let field_ident = self.cx.rust_name(field.did);
let merge =
self.codegen_merge_field(quote!(_inner_pilota_value), &field.ty, field.kind);
self.codegen_merge_field("_inner_pilota_value".into(), &field.ty, field.kind);
let mut tags = self.field_tags(field).map(|tag| tag.to_string());
let tags = tags.join("|");

Expand Down Expand Up @@ -429,49 +477,50 @@ impl CodegenBackend for ProtobufBackend {
})
.join(",");

let merge = e
.variants
.iter()
.map(|variant| {
let tag = variant.id.unwrap() as u32;
let variant_name = self.cx.rust_name(variant.did);
let merge = self.codegen_merge_field(
quote! {value},
variant.fields.first().unwrap(),
FieldKind::Required,
);
format! {
r#"{tag} => {{
let mut owned_value = ::core::default::Default::default();
let value = &mut owned_value;
{merge}?;
*self = {name}::{variant_name}(owned_value);
Ok(())
}}"#
}
})
.join(",");
let merge = e.variants.iter().map(|variant| {
let tag = variant.id.unwrap() as u32;
let variant_name = self.cx.rust_name(variant.did);
let merge = self.codegen_merge_field(
"value".into(),
variant.fields.first().unwrap(),
FieldKind::Required,
);
format! {
r#"{tag} => {{
match field {{
::core::option::Option::Some({name}::{variant_name}(ref mut value)) => {{
{merge}?;
}},
_ => {{
let mut owned_value = ::core::default::Default::default();
let value = &mut owned_value;
{merge}?;
*field = ::core::option::Option::Some({name}::{variant_name}(owned_value));
}},
}}
}},"#
}
}).join("");

stream.push_str(&format! {
r#"
impl ::pilota::prost::Message for {name} {{
fn encode_raw<B>(&self, buf: &mut B) where B: ::pilota::prost::bytes::BufMut {{
r#"impl {name} {{
pub fn encode<B>(&self, buf: &mut B) where B: ::pilota::prost::bytes::BufMut {{
match self {{
{encode}
}}
}}
/// Returns the encoded length of the message without a length delimiter.
#[inline]
fn encoded_len(&self) -> usize {{
pub fn encoded_len(&self) -> usize {{
match self {{
{encoded_len}
}}
}}
/// Decodes an instance of the message from a buffer, and merges it into self.
fn merge_field<B>(
&mut self,
pub fn merge<B>(
field: &mut ::core::option::Option<Self>,
tag: u32,
wire_type: ::pilota::prost::encoding::WireType,
buf: &mut B,
Expand All @@ -480,11 +529,11 @@ impl CodegenBackend for ProtobufBackend {
where B: ::pilota::prost::bytes::Buf {{
match tag {{
{merge}
_ => unreachable!(concat!("invalid ", "{name}", " tag: {{}}"), tag),
}}
_ => unreachable!(concat!("invalid ", stringify!({name}), " tag: {{}}"), tag),
}};
Ok(())
}}
}}
"#
}}"#
});
}

Expand Down
3 changes: 2 additions & 1 deletion pilota-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use std::{

// mod dedup;
pub mod plugin;
mod test;

pub use codegen::{
protobuf::ProtobufBackend, thrift::ThriftBackend, traits::CodegenBackend, Codegen,
Expand Down Expand Up @@ -343,3 +342,5 @@ where
.unwrap();
}
}

mod test;
6 changes: 2 additions & 4 deletions pilota-build/src/middle/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -409,13 +409,11 @@ impl Context {
.try_collect::<_, Vec<_>, _>()?
.join("");
anyhow::Ok(
format! {
r#"{{
format! {r#"{{
let mut map = ::std::collections::HashMap::with_capacity({len});
{kvs}
map
}}
"#}
}}"#}
.into(),
)
};
Expand Down
Loading

0 comments on commit 77f08e8

Please sign in to comment.