Skip to content

Commit

Permalink
Add protoc plugin handling to the prost-build Config and basic plugin
Browse files Browse the repository at this point in the history
This refactors some of the private logic within prost-build's lib.rs so
that it is compatible with generating the logic necessary to run a
plugin for protoc or a build.rs file.  In both cases, the output can be
changed based on how the Config is built.  This adds the capability for
paramters passed in via protoc to a plugin to allow adjust this
configuration, but the mechanics of that are left as a TODO until the
design is solidified.
  • Loading branch information
dfreese committed May 17, 2020
1 parent 8025627 commit d306145
Show file tree
Hide file tree
Showing 3 changed files with 288 additions and 13 deletions.
6 changes: 6 additions & 0 deletions prost-build/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,9 @@ which = { version = "3", default-features = false }

[dev-dependencies]
env_logger = { version = "0.7", default-features = false }

[[bin]]
name = "protoc-gen-rust"
path = "src/bin/protoc-gen-rust.rs"
test = false
bench = false
32 changes: 32 additions & 0 deletions prost-build/src/bin/protoc-gen-rust.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
extern crate prost;
extern crate prost_types;

use crate::prost::Message;
use prost_types::compiler::{CodeGeneratorRequest, CodeGeneratorResponse};
use std::io::{Error, ErrorKind, Result};
use std::io::{Read, Write};

fn main() -> Result<()> {
let mut buf = Vec::new();
std::io::stdin().read_to_end(&mut buf)?;

let request = CodeGeneratorRequest::decode(&*buf).map_err(|error| {
Error::new(
ErrorKind::InvalidInput,
format!("invalid FileDescriptorSet: {}", error.to_string()),
)
})?;

let response: CodeGeneratorResponse = prost_build::Config::new().run_plugin(request);

let mut out = Vec::new();
response.encode(&mut out).map_err(|error| {
Error::new(
ErrorKind::InvalidInput,
format!("invalid FileDescriptorSet: {}", error.to_string()),
)
})?;
std::io::stdout().write_all(&out)?;

Ok(())
}
263 changes: 250 additions & 13 deletions prost-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ use std::process::Command;

use log::trace;
use prost::Message;
use prost_types::compiler::{CodeGeneratorRequest, CodeGeneratorResponse};
use prost_types::{FileDescriptorProto, FileDescriptorSet};

pub use crate::ast::{Comments, Method, Service};
Expand Down Expand Up @@ -176,6 +177,25 @@ pub trait ServiceGenerator {
///
/// The default implementation is empty and does nothing.
fn finalize_package(&mut self, _package: &str, _buf: &mut String) {}

/// Handles parameters passed into a plugin for the generator as a customization point
///
/// This function will be called once for every parameter that is passed into the plugin. The
/// service generator is not expected to handle any of the paramters. Parameters may be
/// intended to modify `Config`, but are provided here in case they are of use.
///
/// The generator should only return an error when parsing the value of a parameter name it
/// expected to handle is not parsable. If it has handled a value, it should return `Ok(true)`
/// and `Ok(false)` otherwise.
///
/// The default implementation ignores all parameters and returns false.
fn handle_plugin_parameter(
&mut self,
_name: &str,
_value: &str,
) -> std::result::Result<bool, String> {
Ok(false)
}
}

/// Configuration options for Protobuf code generation.
Expand Down Expand Up @@ -557,12 +577,29 @@ impl Config {
format!("invalid FileDescriptorSet: {}", error.to_string()),
)
})?;
let request = CodeGeneratorRequest {
file_to_generate: descriptor_set
.file
.iter()
.filter_map(|file| file.name.as_ref())
.map(|name| name.to_owned())
.collect(),
parameter: None,
proto_file: descriptor_set.file,
compiler_version: None,
};

let modules = self.generate(descriptor_set.file)?;
for (module, content) in modules {
let mut filename = module.join(".");
filename.push_str(".rs");
let response = self
.codegen(request)
.map_err(|error| Error::new(ErrorKind::InvalidInput, error))?;

for file in response
.file
.iter()
.filter(|file| file.name.is_some() && file.content.is_some())
{
let filename = file.name.as_ref().unwrap();
let content = file.content.as_ref().unwrap();
let output_path = target.join(&filename);

let previous_content = fs::read(&output_path);
Expand All @@ -577,20 +614,45 @@ impl Config {
fs::write(output_path, content)?;
}
}

Ok(())
}

fn generate(&mut self, files: Vec<FileDescriptorProto>) -> Result<HashMap<Module, String>> {
/// Handle the request protoc provides to plugins and generate an appropriate response.
///
/// The protoc plugin mechanism provides a way for an arbitrary executable to specify the code
/// generation for a set of protofiles. The plugin reads `CodeGeneratorRequest` and reports a
/// `CodeGeneratorResponse`. Errors for plugins are reported via the error field within the
/// `CodeGeneratorResponse`. Errors that are the fault of protoc are reported using a non-zero
/// exit code in the plugin. An example of using this can be seen in `protoc-gen-rust.rs`.
pub fn run_plugin(&mut self, request: CodeGeneratorRequest) -> CodeGeneratorResponse {
self.codegen(request)
.unwrap_or_else(|message| CodeGeneratorResponse {
error: Some(message),
file: Vec::new(),
})
}

fn codegen(
&mut self,
request: CodeGeneratorRequest,
) -> std::result::Result<CodeGeneratorResponse, String> {
self.handle_plugin_parameters(&request)?;

let mut modules = HashMap::new();
let mut packages = HashMap::new();
let files: HashMap<&String, &FileDescriptorProto> = request
.proto_file
.iter()
.filter_map(|x| x.name.as_ref().map(|name| (name, x)))
.collect();

let message_graph = MessageGraph::new(&files)
.map_err(|error| Error::new(ErrorKind::InvalidInput, error))?;
let extern_paths = ExternPaths::new(&self.extern_paths, self.prost_types)
.map_err(|error| Error::new(ErrorKind::InvalidInput, error))?;
let message_graph = MessageGraph::new(&request.proto_file)?;
let extern_paths = ExternPaths::new(&self.extern_paths, self.prost_types)?;

for file in files {
for filename in request.file_to_generate {
let file: &FileDescriptorProto = files
.get(&filename)
.ok_or_else(|| "filename to generate not found in protofile field".to_string())?;
let module = self.module(&file);

// Only record packages that have services
Expand All @@ -599,7 +661,13 @@ impl Config {
}

let mut buf = modules.entry(module).or_insert_with(String::new);
CodeGenerator::generate(self, &message_graph, &extern_paths, file, &mut buf);
CodeGenerator::generate(
self,
&message_graph,
&extern_paths,
file.to_owned(),
&mut buf,
);
}

if let Some(ref mut service_generator) = self.service_generator {
Expand All @@ -609,7 +677,65 @@ impl Config {
}
}

Ok(modules)
Ok(CodeGeneratorResponse {
error: None,
file: modules
.iter()
.map(
|(module, content)| prost_types::compiler::code_generator_response::File {
name: {
let mut filename = module.join(".");
filename.push_str(".rs");
Some(filename)
},
insertion_point: None,
content: Some(content.to_owned()),
},
)
.collect(),
})
}

fn handle_plugin_parameters(
&mut self,
request: &CodeGeneratorRequest,
) -> std::result::Result<(), String> {
if request.parameter.is_none() {
return Ok(());
}
for param in request.parameter.as_ref().unwrap().split_whitespace() {
match param.find('=') {
Some(eq) => {
let name = &param[..eq];
let value = &param[eq + 1..];
let used_by_prost = self.handle_single_parameter(name, value)?;
let used_by_service = self.service_generator.as_mut().map_or_else(
|| Ok(false),
|ref mut gen| gen.handle_plugin_parameter(name, value),
)?;

if !used_by_prost && !used_by_service {
return Err(format!("Unrecognized parameter name \"{}\"", name));
}
}
None => {
return Err(format!(
"Invalid parameter \"{}\". Expected param_name=value",
param
));
}
}
}
Ok(())
}

fn handle_single_parameter(
&mut self,
_name: &str,
_value: &str,
) -> std::result::Result<bool, String> {
// TODO: provide a way to adjust Config via options parameters
Ok(false)
}

fn module(&self, file: &FileDescriptorProto) -> Module {
Expand Down Expand Up @@ -753,12 +879,19 @@ mod tests {
state: Rc<RefCell<MockState>>,
}

#[derive(Debug, Default, PartialEq)]
struct Parameter {
name: String,
value: String,
}

/// Holds state for `MockServiceGenerator`
#[derive(Default)]
struct MockState {
service_names: Vec<String>,
package_names: Vec<String>,
finalized: u32,
parameters: Vec<Parameter>,
}

impl MockServiceGenerator {
Expand All @@ -782,6 +915,19 @@ mod tests {
let mut state = self.state.borrow_mut();
state.package_names.push(package.to_string());
}

fn handle_plugin_parameter(
&mut self,
name: &str,
value: &str,
) -> std::result::Result<bool, String> {
let mut state = self.state.borrow_mut();
state.parameters.push(Parameter {
name: name.to_owned(),
value: value.to_owned(),
});
Ok(true)
}
}

#[test]
Expand Down Expand Up @@ -809,5 +955,96 @@ mod tests {
assert_eq!(&state.service_names, &["Greeting", "Farewell"]);
assert_eq!(&state.package_names, &["helloworld"]);
assert_eq!(state.finalized, 3);
assert!(state.parameters.is_empty());
}

#[test]
fn no_parameters() {
let _ = env_logger::try_init();

let state = Rc::new(RefCell::new(MockState::default()));
let gen = MockServiceGenerator::new(Rc::clone(&state));

let response =
Config::new()
.service_generator(Box::new(gen))
.run_plugin(CodeGeneratorRequest {
file_to_generate: Vec::new(),
parameter: None,
proto_file: Vec::new(),
compiler_version: None,
});

let state = state.borrow();
assert!(&state.service_names.is_empty());
assert!(&state.package_names.is_empty());
assert_eq!(state.finalized, 0);
assert!(state.parameters.is_empty());

assert_eq!(response.error, None);
assert!(response.file.is_empty());
}

#[test]
fn valid_parameter() {
let _ = env_logger::try_init();

let state = Rc::new(RefCell::new(MockState::default()));
let gen = MockServiceGenerator::new(Rc::clone(&state));

let response =
Config::new()
.service_generator(Box::new(gen))
.run_plugin(CodeGeneratorRequest {
file_to_generate: Vec::new(),
parameter: Some("service_quirk=0".to_string()),
proto_file: Vec::new(),
compiler_version: None,
});

let state = state.borrow();
assert!(&state.service_names.is_empty());
assert!(&state.package_names.is_empty());
assert_eq!(state.finalized, 0);
assert_eq!(
state.parameters,
&[Parameter {
name: "service_quirk".to_string(),
value: "0".to_string(),
}]
);

assert_eq!(response.error, None);
assert!(response.file.is_empty());
}

#[test]
fn invalid_parameter() {
let _ = env_logger::try_init();

let state = Rc::new(RefCell::new(MockState::default()));
let gen = MockServiceGenerator::new(Rc::clone(&state));

let response =
Config::new()
.service_generator(Box::new(gen))
.run_plugin(CodeGeneratorRequest {
file_to_generate: Vec::new(),
parameter: Some("service_quirk".to_string()),
proto_file: Vec::new(),
compiler_version: None,
});

let state = state.borrow();
assert!(&state.service_names.is_empty());
assert!(&state.package_names.is_empty());
assert_eq!(state.finalized, 0);
assert!(&state.parameters.is_empty());

assert_eq!(
response.error,
Some("Invalid parameter \"service_quirk\". Expected param_name=value".to_string())
);
assert!(response.file.is_empty());
}
}

0 comments on commit d306145

Please sign in to comment.