From db5ba93791edfd3cafc26a5cd265900c87754167 Mon Sep 17 00:00:00 2001 From: Richard Dodd Date: Tue, 19 Feb 2019 13:06:50 +0000 Subject: [PATCH 1/2] First commit for review. --- contrib/codegen/build.rs | 63 +- contrib/codegen/src/database.rs | 21 +- contrib/codegen/src/lib.rs | 7 +- contrib/codegen/tests/compile-test.rs | 66 +- contrib/lib/Cargo.toml | 1 + contrib/lib/src/cors.rs | 442 ++++++++++ contrib/lib/src/databases.rs | 164 ++-- contrib/lib/src/helmet/helmet.rs | 4 +- contrib/lib/src/helmet/policy.rs | 18 +- contrib/lib/src/json.rs | 26 +- contrib/lib/src/lib.rs | 40 +- contrib/lib/src/msgpack.rs | 30 +- contrib/lib/src/serve.rs | 16 +- contrib/lib/src/templates/context.rs | 38 +- contrib/lib/src/templates/engine.rs | 27 +- contrib/lib/src/templates/fairing.rs | 15 +- contrib/lib/src/templates/metadata.rs | 5 +- contrib/lib/src/templates/mod.rs | 91 +- contrib/lib/src/templates/tera_templates.rs | 12 +- contrib/lib/src/uuid.rs | 6 +- contrib/lib/tests/databases.rs | 12 +- contrib/lib/tests/helmet.rs | 37 +- contrib/lib/tests/static_files.rs | 28 +- contrib/lib/tests/templates.rs | 43 +- core/codegen/build.rs | 63 +- core/codegen/src/attribute/catch.rs | 46 +- core/codegen/src/attribute/route.rs | 114 ++- core/codegen/src/attribute/segments.rs | 68 +- core/codegen/src/bang/mod.rs | 24 +- core/codegen/src/bang/uri.rs | 73 +- core/codegen/src/bang/uri_parsing.rs | 81 +- core/codegen/src/derive/from_form.rs | 111 ++- core/codegen/src/derive/from_form_value.rs | 71 +- core/codegen/src/derive/responder.rs | 97 ++- core/codegen/src/derive/uri_display.rs | 76 +- core/codegen/src/http_codegen.rs | 84 +- core/codegen/src/lib.rs | 26 +- core/codegen/src/proc_macro_ext.rs | 6 +- core/codegen/tests/compile-test.rs | 64 +- core/codegen/tests/expansion.rs | 8 +- core/codegen/tests/from_form.rs | 288 +++--- core/codegen/tests/from_form_value.rs | 22 +- core/codegen/tests/responder.rs | 47 +- core/codegen/tests/route-data.rs | 22 +- core/codegen/tests/route-format.rs | 72 +- core/codegen/tests/route-params.rs | 1 - core/codegen/tests/route-ranking.rs | 25 +- core/codegen/tests/route.rs | 86 +- core/codegen/tests/typed-uris.rs | 65 +- core/codegen/tests/uri_display.rs | 137 ++- core/http/src/accept.rs | 25 +- core/http/src/content_type.rs | 23 +- core/http/src/cookies.rs | 33 +- core/http/src/ext.rs | 12 +- core/http/src/header.rs | 77 +- core/http/src/hyper.rs | 36 +- core/http/src/known_media_types.rs | 202 ++--- core/http/src/lib.rs | 37 +- core/http/src/media_type.rs | 54 +- core/http/src/method.rs | 2 +- core/http/src/parse/accept.rs | 38 +- core/http/src/parse/checkers.rs | 5 +- core/http/src/parse/indexed.rs | 55 +- core/http/src/parse/media_type.rs | 30 +- core/http/src/parse/mod.rs | 7 +- core/http/src/parse/uri/error.rs | 46 +- core/http/src/parse/uri/mod.rs | 17 +- core/http/src/parse/uri/parser.rs | 20 +- core/http/src/parse/uri/tables.rs | 106 +-- core/http/src/parse/uri/tests.rs | 2 +- core/http/src/raw_str.rs | 10 +- core/http/src/route.rs | 35 +- core/http/src/status.rs | 8 +- core/http/src/tls.rs | 4 +- core/http/src/uncased.rs | 47 +- core/http/src/uri/absolute.rs | 12 +- core/http/src/uri/authority.rs | 29 +- core/http/src/uri/encoding.rs | 25 +- core/http/src/uri/formatter.rs | 11 +- core/http/src/uri/from_uri_param.rs | 7 +- core/http/src/uri/mod.rs | 20 +- core/http/src/uri/origin.rs | 45 +- core/http/src/uri/segments.rs | 17 +- core/http/src/uri/uri.rs | 24 +- core/http/src/uri/uri_display.rs | 24 +- core/lib/benches/format-routing.rs | 33 +- core/lib/benches/ranked-routing.rs | 33 +- core/lib/benches/simple-routing.rs | 54 +- core/lib/build.rs | 63 +- core/lib/src/catcher.rs | 54 +- core/lib/src/codegen.rs | 6 +- core/lib/src/config/builder.rs | 6 +- core/lib/src/config/config.rs | 168 ++-- core/lib/src/config/custom_values.rs | 59 +- core/lib/src/config/environment.rs | 2 +- core/lib/src/config/error.rs | 84 +- core/lib/src/config/mod.rs | 818 ++++++++++++------ core/lib/src/config/toml_ext.rs | 94 +- core/lib/src/data/data.rs | 30 +- core/lib/src/data/data_stream.rs | 10 +- core/lib/src/data/from_data.rs | 25 +- core/lib/src/data/mod.rs | 2 +- core/lib/src/data/net_stream.rs | 29 +- core/lib/src/error.rs | 33 +- core/lib/src/ext.rs | 7 +- core/lib/src/fairing/ad_hoc.rs | 51 +- core/lib/src/fairing/fairings.rs | 31 +- core/lib/src/fairing/info_kind.rs | 2 +- core/lib/src/fairing/mod.rs | 10 +- core/lib/src/handler.rs | 9 +- core/lib/src/lib.rs | 75 +- core/lib/src/local/client.rs | 30 +- core/lib/src/local/mod.rs | 4 +- core/lib/src/local/request.rs | 33 +- core/lib/src/logger.rs | 81 +- core/lib/src/outcome.rs | 26 +- core/lib/src/request/form/error.rs | 4 +- core/lib/src/request/form/form.rs | 22 +- core/lib/src/request/form/form_items.rs | 58 +- core/lib/src/request/form/from_form_value.rs | 31 +- core/lib/src/request/form/lenient.rs | 9 +- core/lib/src/request/form/mod.rs | 10 +- core/lib/src/request/from_request.rs | 21 +- core/lib/src/request/mod.rs | 16 +- core/lib/src/request/param.rs | 22 +- core/lib/src/request/query.rs | 2 +- core/lib/src/request/request.rs | 148 ++-- core/lib/src/request/state.rs | 6 +- core/lib/src/request/tests.rs | 4 +- core/lib/src/response/content.rs | 4 +- core/lib/src/response/flash.rs | 36 +- core/lib/src/response/mod.rs | 15 +- core/lib/src/response/named_file.rs | 4 +- core/lib/src/response/redirect.rs | 150 ++-- core/lib/src/response/responder.rs | 29 +- core/lib/src/response/response.rs | 100 ++- core/lib/src/response/status.rs | 16 +- core/lib/src/response/stream.rs | 6 +- core/lib/src/rocket.rs | 126 +-- core/lib/src/router/collider.rs | 80 +- core/lib/src/router/mod.rs | 60 +- core/lib/src/router/route.rs | 62 +- .../lib/tests/absolute-uris-okay-issue-443.rs | 3 +- .../fairing_before_head_strip-issue-546.rs | 5 +- .../lib/tests/flash-lazy-removes-issue-466.rs | 5 +- core/lib/tests/form_method-issue-45.rs | 11 +- .../lib/tests/form_value_decoding-issue-82.rs | 8 +- core/lib/tests/head_handling.rs | 9 +- core/lib/tests/limits.rs | 21 +- .../local-request-content-type-issue-505.rs | 56 +- .../local_request_private_cookie-issue-368.rs | 10 +- core/lib/tests/nested-fairing-attaches.rs | 8 +- .../tests/precise-content-type-matching.rs | 13 +- .../tests/redirect_from_catcher-issue-113.rs | 5 +- .../lib/tests/responder_lifetime-issue-345.rs | 15 +- core/lib/tests/route_guard.rs | 10 +- core/lib/tests/segments-issues-41-86.rs | 18 +- core/lib/tests/strict_and_lenient_forms.rs | 19 +- .../tests/uri-percent-encoding-issue-808.rs | 8 +- examples/config/tests/common/mod.rs | 6 +- examples/config/tests/development.rs | 3 +- examples/config/tests/production.rs | 3 +- examples/config/tests/staging.rs | 3 +- examples/content_types/src/main.rs | 29 +- examples/content_types/src/tests.rs | 69 +- examples/cookies/src/main.rs | 9 +- examples/cookies/src/tests.rs | 10 +- examples/errors/src/main.rs | 12 +- examples/errors/src/tests.rs | 14 +- examples/fairings/src/main.rs | 14 +- examples/form_kitchen_sink/src/main.rs | 14 +- examples/form_kitchen_sink/src/tests.rs | 59 +- examples/form_validation/src/main.rs | 13 +- examples/form_validation/src/tests.rs | 51 +- examples/handlebars_templates/src/main.rs | 51 +- examples/handlebars_templates/src/tests.rs | 68 +- examples/hello_2018/src/main.rs | 3 +- examples/hello_2018/src/tests.rs | 18 +- examples/hello_person/src/main.rs | 6 +- examples/hello_person/src/tests.rs | 8 +- examples/hello_world/src/main.rs | 6 +- examples/json/src/main.rs | 18 +- examples/json/src/tests.rs | 41 +- examples/managed_queue/src/main.rs | 8 +- examples/managed_queue/src/tests.rs | 2 +- examples/manual_routes/src/main.rs | 21 +- examples/manual_routes/src/tests.rs | 15 +- examples/msgpack/src/main.rs | 16 +- examples/msgpack/src/tests.rs | 22 +- examples/optional_redirect/src/main.rs | 9 +- examples/optional_redirect/src/tests.rs | 12 +- examples/pastebin/src/main.rs | 10 +- examples/pastebin/src/paste_id.rs | 16 +- examples/pastebin/src/tests.rs | 8 +- examples/query_params/src/main.rs | 8 +- examples/query_params/src/tests.rs | 54 +- examples/ranking/src/main.rs | 6 +- examples/ranking/src/tests.rs | 25 +- examples/raw_sqlite/src/main.rs | 38 +- examples/raw_upload/src/main.rs | 11 +- examples/raw_upload/src/tests.rs | 10 +- examples/redirect/src/main.rs | 6 +- examples/redirect/src/tests.rs | 7 +- examples/request_guard/src/main.rs | 7 +- examples/request_local_state/src/main.rs | 8 +- examples/request_local_state/src/tests.rs | 2 +- examples/session/src/main.rs | 29 +- examples/session/src/tests.rs | 10 +- examples/state/src/main.rs | 8 +- examples/state/src/tests.rs | 53 +- examples/static_files/src/main.rs | 3 +- examples/static_files/src/tests.rs | 10 +- examples/stream/src/main.rs | 8 +- examples/stream/src/tests.rs | 3 +- examples/tera_templates/src/main.rs | 20 +- examples/tera_templates/src/tests.rs | 64 +- examples/testing/src/main.rs | 5 +- examples/tls/src/main.rs | 6 +- examples/todo/src/main.rs | 65 +- examples/todo/src/task.rs | 27 +- examples/todo/src/tests.rs | 34 +- examples/uuid/src/main.rs | 14 +- examples/uuid/src/tests.rs | 23 +- 223 files changed, 5548 insertions(+), 3044 deletions(-) create mode 100644 contrib/lib/src/cors.rs diff --git a/contrib/codegen/build.rs b/contrib/codegen/build.rs index 8b79857060..345628b4e0 100644 --- a/contrib/codegen/build.rs +++ b/contrib/codegen/build.rs @@ -1,11 +1,11 @@ //! This tiny build script ensures that rocket is not compiled with an //! incompatible version of rust. -extern crate yansi; extern crate version_check; +extern crate yansi; -use yansi::Color::{Red, Yellow, Blue}; -use version_check::{supports_features, is_min_version, is_min_date}; +use version_check::{is_min_date, is_min_version, supports_features}; +use yansi::Color::{Blue, Red, Yellow}; // Specifies the minimum nightly version needed to compile Rocket. const MIN_DATE: &'static str = "2018-10-05"; @@ -18,38 +18,55 @@ fn main() { let triple = (ok_channel, ok_version, ok_date); let print_version_err = |version: &str, date: &str| { - eprintln!("{} {}. {} {}.", - "Installed version is:", - Yellow.paint(format!("{} ({})", version, date)), - "Minimum required:", - Yellow.paint(format!("{} ({})", MIN_VERSION, MIN_DATE))); + eprintln!( + "{} {}. {} {}.", + "Installed version is:", + Yellow.paint(format!("{} ({})", version, date)), + "Minimum required:", + Yellow.paint(format!("{} ({})", MIN_VERSION, MIN_DATE)) + ); }; if let (Some(ok_channel), Some((ok_version, version)), Some((ok_date, date))) = triple { if !ok_channel { - eprintln!("{} {}", - Red.paint("Error:").bold(), - "Rocket requires a nightly or dev version of Rust."); + eprintln!( + "{} {}", + Red.paint("Error:").bold(), + "Rocket requires a nightly or dev version of Rust." + ); print_version_err(&*version, &*date); - eprintln!("{}{}{}", - Blue.paint("See the getting started guide ("), - "https://rocket.rs/v0.4/guide/getting-started/", - Blue.paint(") for more information.")); + eprintln!( + "{}{}{}", + Blue.paint("See the getting started guide ("), + "https://rocket.rs/v0.4/guide/getting-started/", + Blue.paint(") for more information.") + ); panic!("Aborting compilation due to incompatible compiler.") } if !ok_version || !ok_date { - eprintln!("{} {}", - Red.paint("Error:").bold(), - "Rocket requires a more recent version of rustc."); - eprintln!("{}{}{}", - Blue.paint("Use `"), "rustup update", - Blue.paint("` or your preferred method to update Rust.")); + eprintln!( + "{} {}", + Red.paint("Error:").bold(), + "Rocket requires a more recent version of rustc." + ); + eprintln!( + "{}{}{}", + Blue.paint("Use `"), + "rustup update", + Blue.paint("` or your preferred method to update Rust.") + ); print_version_err(&*version, &*date); panic!("Aborting compilation due to incompatible compiler.") } } else { - println!("cargo:warning={}", "Rocket was unable to check rustc compatibility."); - println!("cargo:warning={}", "Build may fail due to incompatible rustc version."); + println!( + "cargo:warning={}", + "Rocket was unable to check rustc compatibility." + ); + println!( + "cargo:warning={}", + "Build may fail due to incompatible rustc version." + ); } } diff --git a/contrib/codegen/src/database.rs b/contrib/codegen/src/database.rs index 2bec14bfa6..7277c5ca0e 100644 --- a/contrib/codegen/src/database.rs +++ b/contrib/codegen/src/database.rs @@ -1,6 +1,6 @@ +use devise::{Result, Spanned}; use proc_macro::TokenStream; -use devise::{Spanned, Result}; -use syn::{DataStruct, Fields, Data, Type, LitStr, DeriveInput, Ident, Visibility}; +use syn::{Data, DataStruct, DeriveInput, Fields, Ident, LitStr, Type, Visibility}; #[derive(Debug)] struct DatabaseInvocation { @@ -19,9 +19,9 @@ struct DatabaseInvocation { const EXAMPLE: &str = "example: `struct MyDatabase(diesel::SqliteConnection);`"; const ONLY_ON_STRUCTS_MSG: &str = "`database` attribute can only be used on structs"; const ONLY_UNNAMED_FIELDS: &str = "`database` attribute can only be applied to \ - structs with exactly one unnamed field"; + structs with exactly one unnamed field"; const NO_GENERIC_STRUCTS: &str = "`database` attribute cannot be applied to structs \ - with generics"; + with generics"; fn parse_invocation(attr: TokenStream, input: TokenStream) -> Result { let attr_stream2 = ::proc_macro2::TokenStream::from(attr); @@ -36,7 +36,7 @@ fn parse_invocation(attr: TokenStream, input: TokenStream) -> Result s, - _ => return Err(input.span().error(ONLY_ON_STRUCTS_MSG)) + _ => return Err(input.span().error(ONLY_ON_STRUCTS_MSG)), }; let inner_type = match structure.fields { @@ -44,7 +44,13 @@ fn parse_invocation(attr: TokenStream, input: TokenStream) -> Result return Err(structure.fields.span().error(ONLY_UNNAMED_FIELDS).help(EXAMPLE)) + _ => { + return Err(structure + .fields + .span() + .error(ONLY_UNNAMED_FIELDS) + .help(EXAMPLE)); + } }; Ok(DatabaseInvocation { @@ -153,5 +159,6 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result &'static str { match self { - #[cfg(windows)] Kind::Dynamic => ".dll", - #[cfg(all(unix, target_os = "macos"))] Kind::Dynamic => ".dylib", - #[cfg(all(unix, not(target_os = "macos")))] Kind::Dynamic => ".so", - Kind::Static => ".rlib" + #[cfg(windows)] + Kind::Dynamic => ".dll", + #[cfg(all(unix, target_os = "macos"))] + Kind::Dynamic => ".dylib", + #[cfg(all(unix, not(target_os = "macos")))] + Kind::Dynamic => ".so", + Kind::Static => ".rlib", } } } fn target_path() -> PathBuf { - #[cfg(debug_assertions)] const ENVIRONMENT: &str = "debug"; - #[cfg(not(debug_assertions))] const ENVIRONMENT: &str = "release"; + #[cfg(debug_assertions)] + const ENVIRONMENT: &str = "debug"; + #[cfg(not(debug_assertions))] + const ENVIRONMENT: &str = "release"; Path::new(env!("CARGO_MANIFEST_DIR")) - .parent().unwrap().parent().unwrap() + .parent() + .unwrap() + .parent() + .unwrap() .join("target") .join(ENVIRONMENT) } @@ -41,7 +49,8 @@ fn link_flag(flag: &str, lib: &str, rel_path: &[&str]) -> String { } fn best_time_for(metadata: &Metadata) -> SystemTime { - metadata.created() + metadata + .created() .or_else(|_| metadata.modified()) .or_else(|_| metadata.accessed()) .unwrap_or_else(|_| SystemTime::now()) @@ -55,12 +64,18 @@ fn extern_dep(name: &str, kind: Kind) -> io::Result { for entry in deps_root.read_dir().expect("read_dir call failed") { let entry = match entry { Ok(entry) => entry, - Err(_) => continue + Err(_) => continue, }; let filename = entry.file_name(); let filename = filename.to_string_lossy(); - let lib_name = filename.split('.').next().unwrap().split('-').next().unwrap(); + let lib_name = filename + .split('.') + .next() + .unwrap() + .split('-') + .next() + .unwrap(); if lib_name == dep_name && filename.ends_with(kind.extension()) { if let Some(ref mut existing) = dep_path { @@ -74,8 +89,14 @@ fn extern_dep(name: &str, kind: Kind) -> io::Result { } let dep = dep_path.ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?; - let filename = dep.file_name().ok_or_else(|| io::Error::from(io::ErrorKind::InvalidData))?; - Ok(link_flag("--extern", name, &["deps", &filename.to_string_lossy()])) + let filename = dep + .file_name() + .ok_or_else(|| io::Error::from(io::ErrorKind::InvalidData))?; + Ok(link_flag( + "--extern", + name, + &["deps", &filename.to_string_lossy()], + )) } fn run_mode(mode: &'static str, path: &'static str) { @@ -84,13 +105,16 @@ fn run_mode(mode: &'static str, path: &'static str) { config.src_base = format!("tests/{}", path).into(); config.clean_rmeta(); - config.target_rustcflags = Some([ - link_flag("-L", "crate", &[]), - link_flag("-L", "dependency", &["deps"]), - extern_dep("rocket_http", Kind::Static).expect("find http dep"), - extern_dep("rocket", Kind::Static).expect("find core dep"), - extern_dep("rocket_contrib", Kind::Static).expect("find contrib dep"), - ].join(" ")); + config.target_rustcflags = Some( + [ + link_flag("-L", "crate", &[]), + link_flag("-L", "dependency", &["deps"]), + extern_dep("rocket_http", Kind::Static).expect("find http dep"), + extern_dep("rocket", Kind::Static).expect("find core dep"), + extern_dep("rocket_contrib", Kind::Static).expect("find contrib dep"), + ] + .join(" "), + ); compiletest::run_tests(&config); } diff --git a/contrib/lib/Cargo.toml b/contrib/lib/Cargo.toml index a7b08f9d56..9cc2a605f0 100644 --- a/contrib/lib/Cargo.toml +++ b/contrib/lib/Cargo.toml @@ -23,6 +23,7 @@ tera_templates = ["tera", "templates"] handlebars_templates = ["handlebars", "templates"] helmet = ["time"] serve = [] +cors = [] # The barage of user-facing database features. diesel_sqlite_pool = ["databases", "diesel/sqlite", "diesel/r2d2"] diff --git a/contrib/lib/src/cors.rs b/contrib/lib/src/cors.rs new file mode 100644 index 0000000000..5d0fa4fe32 --- /dev/null +++ b/contrib/lib/src/cors.rs @@ -0,0 +1,442 @@ +//! Implementation of CORS based on [the fetch whatwg +//! spec](https://fetch.spec.whatwg.org/#http-cors-protocol). +//! +//! This fairing is appropriate when your whole site will follow the same CORS rules. It doesn't +//! yet support custom CORS on individual routes. +use rocket::{ + fairing, + http::{ext::IntoOwned, uncased::Uncased, uri, ContentType, Method, Status}, + response::{Response, ResponseBuilder}, + Request, +}; +use std::{borrow::Cow, collections::HashSet, error::Error as StdError, fmt, io, mem}; + +/// Generate compile-time constant header names. +macro_rules! hdrs { + ($($s:expr),*) => { + [$(Uncased { string: Cow::Borrowed($s) }),*] + }; +} + +/// Headers that are allowed to be accessed in all CORS requests. +const ALLOWED_HEADERS: &'static [Uncased<'static>] = + &hdrs!["Accept", "Accept-Language", "Content-Language"]; + +/// Headers that are allowed to be accessed in all responses to CORS requests. +const EXPOSED_HEADERS: &'static [Uncased<'static>] = &hdrs![ + "Cache-Control", + "Content-Language", + "Content-Type", + "Expires", + "Last-Modified", + "Pragma" +]; + +/// the possibilities for allowed origin. +// This isn't public because the types may change after `http 1.0` is release, breaking backwards +// compat. +enum AllowedOrigin { + /// A whitelist of origins that are allowed. (e.g. `http://my_host.tld:2123`, + /// `https://example.com`). + Some(HashSet<&'static str>), + /// All origins are allowed (corresponds to "*") + Any, +} + +/// Adds Cross-origin resource sharing (CORS) support as a fairing. +pub struct CORS { + /// The origins that will be accepted when responding to requests. + allow_origin: AllowedOrigin, + /// Which headers we allow the client to read in javascript. + expose_headers: HashSet>, + /// Whether cookies should be sent. + allow_credentials: bool, + /// Which headers the client is allowed to send during the actual request (used in preflight) + allow_headers: HashSet>, + /// All methods could be possibly allowed or not. + allow_methods: HashSet, + /// The maximum time between a preflight request and the real request. + max_age: Option, +} + +impl CORS { + /// Helper to create empty CORS object. + fn new(allow_origin: AllowedOrigin) -> CORS { + CORS { + allow_origin, + expose_headers: HashSet::new(), + allow_credentials: false, + allow_headers: HashSet::new(), + allow_methods: HashSet::new(), + max_age: None, + } + } + + /// Create a CORS fairing from a comma-separated list of origins, or `*` to allow for all + /// origins. + /// + /// The CORS spec states that if the origin matches the allowed origin, we set that as the + /// `Access-Control-Allow-Origin` header, or set it to `*` if we support any origin. We extend + /// this slightly by allowing multiple origins, and if a request origin is in the list, we + /// reflect it back on its own, thereby complying with the spec. + pub fn from_origin(origin: &'static str) -> Result { + let allow_origin = match origin.trim() { + "*" => AllowedOrigin::Any, + o => { + AllowedOrigin::Some( + o.split(',') + .map(|o| o.trim()) + .filter(|o| !o.is_empty()) + .map(|o| { + // we parse the origin as a url to check it is valid. + let parsed = uri::Absolute::parse(o) + .map_err(|e| OriginError::from_parts(origin, e))?; + match parsed.scheme() { + "http" | "https" => (), + other => { + return Err(OriginError::from_parts( + origin, + OriginErrorKind::SchemeNotHyperText(other.to_owned()), + )); + } + }; + if let None = parsed.authority() { + return Err(OriginError::from_parts( + origin, + OriginErrorKind::HasNoAuthority, + )); + }; + if let Some(uri_origin) = parsed.origin() { + return Err(OriginError::from_parts( + origin, + OriginErrorKind::HasOrigin(uri_origin.to_owned()), + )); + }; + Ok(o) + }) + .collect::, OriginError>>()?, + ) + } + }; + + Ok(CORS::new(allow_origin)) + } + + /// Allow all origins (`*` in the header). + pub fn any() -> CORS { + CORS::new(AllowedOrigin::Any) + } + + /// Whether credentials are allowed to be present in cross-origin requests. + pub fn allow_credentials(mut self, allow_credentials: bool) -> CORS { + self.allow_credentials = allow_credentials; + self + } + + /// The http methods allowed in cross-origin requests. + pub fn allow_methods(mut self, methods: impl IntoIterator) -> CORS { + for method in methods.into_iter() { + self.allow_methods.insert(method); + } + self + } + + /// These are used for preflight request (OPTIONS) to specify which headers are allowed in the + /// real request. + /// + /// See [the spec](https://fetch.spec.whatwg.org/#http-cors-protocol) for a list of headers + /// that are allowed by default. + pub fn allow_headers( + mut self, + headers: impl IntoIterator>>, + ) -> CORS { + for header in headers.into_iter() { + let header = Uncased::new(header); + if ALLOWED_HEADERS.contains(&header) { + warn!( + "Header \"{}\" is allowed by default and does not need to be included", + header + ); + } else { + self.allow_headers.insert(header); + } + } + self + } + + /// Which headers in the response should be exposed to the client javascript. + pub fn exposed_headers( + mut self, + headers: impl IntoIterator>>, + ) -> CORS { + for header in headers.into_iter() { + let header = Uncased::new(header); + if EXPOSED_HEADERS.contains(&header) { + warn!( + "Header \"{}\" is allowed by default and does not need to be included", + header + ); + } else { + self.expose_headers.insert(header); + } + } + self + } + + /// The maximum amount of time that a preflight request should be valid for. After this, the + /// client should repeat the preflight before the main request. In seconds. + pub fn max_age(mut self, max_age: usize) -> CORS { + self.max_age = Some(max_age); + self + } + + /// Handle a preflight CORS request (method OPTIONS) + fn handle_preflight(&self, request: &Request, response: &mut Response) { + // Only handle requests that weren't handled explicitally. + if response.status() != Status::NotFound { + return; + } + let headers = request.headers(); + // Our response object + let mut cors_response = Response::build().status(Status::Ok).finalize(); + // If the request doesn't match our allowed requests, return a client failure. + if !self.check_origin(request, &mut cors_response) { + return; + } + if !self.check_method(request, &mut cors_response) { + return; + } + if !self.check_headers(request, &mut cors_response) { + return; + } + mem::swap(&mut cors_response, response); + if self.allow_credentials { + response.set_raw_header("Access-Control-Allow-Credentails", "true"); + } + if self.expose_headers.len() > 0 { + set_exposed_headers_header(response, &self.expose_headers); + } + // drop old response. + } + + /// Modify a standard request to add CORS. + fn handle_cors(&self, request: &Request, response: &mut Response) { + if !self.check_origin(request, response) { + return; + } + } + + /// If the origin check passes, add the related header, else replace with an error response. + fn check_origin(&self, request: &Request, response: &mut Response) -> bool { + match self.allow_origin { + AllowedOrigin::Some(ref origins) => match request.headers().get_one("Origin") { + Some(origin) => { + if let Some(cors_origin) = origins.get(origin) { + set_origin_header(response, Some(cors_origin)); + true + } else { + unauthorized(response, "origin not valid"); + false + } + } + _ => { + unauthorized(response, "origin not present"); + false + } + }, + AllowedOrigin::Any => { + set_origin_header(response, None); + true + } + } + } + + /// If the method check passes, add the related header, else replace with an error response. + fn check_method(&self, request: &Request, response: &mut Response) -> bool { + if let Some(method) = request.headers().get_one("Access-Control-Request-Method") { + match ::from_str(&method) { + Ok(method) if self.allow_methods.contains(&method) => { + set_methods_header(response, &self.allow_methods); + true + } + _ => { + unauthorized(response, "requested method not valid"); + false + } + } + } else { + true + } + } + + /// If the allowed headers check passes, add the related header, else replace with an error response. + fn check_headers(&self, request: &Request, response: &mut Response) -> bool { + if let Some(allow_headers) = request.headers().get_one("Access-Control-Request-Headers") { + for header in allow_headers.split(',').map(|s| Uncased::new(s.trim())) { + if !(self.allow_headers.contains(&header) || ALLOWED_HEADERS.contains(&header)) { + unauthorized(response, "a requested header is not supported"); + return false; + } + } + } + if self.allow_headers.len() > 0 { + set_headers_header(response, &self.allow_headers); + } + true + } +} + +impl fairing::Fairing for CORS { + fn info(&self) -> fairing::Info { + use rocket::fairing::{Info, Kind}; + Info { + name: "Cross-origin resource sharing (CORS) support", + kind: Kind::Response, + } + } + + fn on_response(&self, request: &Request, response: &mut Response) { + if request.method() == Method::Options { + self.handle_preflight(request, response) + } else { + self.handle_cors(request, response) + } + } +} + +// Error handling for CORS +// ----------------------- + +/// Ways we can fail to parse a CORS origin. +#[derive(Debug)] +pub enum OriginErrorKind { + ParsingFailed(uri::Error<'static>), + SchemeNotHyperText(String), + HasNoAuthority, + HasOrigin(uri::Origin<'static>), +} + +impl<'a> From> for OriginErrorKind { + fn from(e: uri::Error) -> Self { + OriginErrorKind::ParsingFailed(e.into_owned()) + } +} + +/// A failure in parsing a CORS origin. +#[derive(Debug)] +pub struct OriginError { + uri: String, + kind: OriginErrorKind, +} + +impl OriginError { + fn from_parts(uri: impl Into, kind: impl Into) -> Self { + OriginError { + uri: uri.into(), + kind: kind.into(), + } + } +} + +impl fmt::Display for OriginError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use self::OriginErrorKind::*; + write!(f, r#"error in uri "{}": "#, self.uri)?; + match &self.kind { + ParsingFailed(inner) => fmt::Display::fmt(inner, f), + SchemeNotHyperText(scheme) => write!( + f, + r#"exepected scheme to be "http" or "https", found "{}""#, + scheme + ), + + HasNoAuthority => f.write_str("the uri should have an authority, none found"), + HasOrigin(origin) => write!(f, r#"expected empty origin, found "{}"#, origin), + } + } +} + +impl StdError for OriginError {} + +// Helper methdos +// -------------- + +/// Replace the current response with one representing unauthorized. +/// +/// It is important that these methods don't leak any information they shouldn't. +fn unauthorized(response: &mut Response, msg: &'static str) { + let mut cors_response = Response::build() + .status(Status::Unauthorized) + .header(ContentType::Plain) + .sized_body(io::Cursor::new(msg)) + .finalize(); + mem::swap(response, &mut cors_response); + // drop original response. +} + +/// Set the origin header of the response to the given origin, or "*" if None. +/// +/// It would be nice to remove the `'static` restriction on the origin, to allow them to be set +/// dynamically, but I'm not sure the typechecker (or me) can check our memory safety rules. +/// +/// Altenratively, they could be interned so there is at most one allocation for each origin, but +/// the current way requires no allocations. +fn set_origin_header(res: &mut Response, origin: Option<&&'static str>) { + res.set_raw_header( + "Access-Control-Allow-Origin", + origin.map(|o| *o).unwrap_or("*"), + ); +} + +/// Set the methods header of the response to the given allowed methods. +fn set_methods_header(res: &mut Response, methods: &HashSet) { + // for now we build the string for every request, but we could do some + // interning. It's not automatic this would be faster, but could bench to + // see. + // + // longest method is 7 bytes. + let mut methods_str = String::with_capacity(methods.len() * 7); + for method in methods { + methods_str.push_str(method.as_str()); + methods_str.push_str(", "); + } + methods_str.pop(); // remove last ' ' + methods_str.pop(); // remove last ',' + res.set_raw_header("Access-Control-Allow-Methods", methods_str); +} + +/// Set the allowed headers header of the response to the given allowed headers. +fn set_headers_header(res: &mut Response, headers: &HashSet>) { + // I guess that most headers are less than 16 bytes + let mut headers_str = String::with_capacity(headers.len() * 18); + for header in headers { + headers_str.push_str(header.as_str()); + headers_str.push_str(", "); + } + headers_str.pop(); // remove last ' ' + headers_str.pop(); // remove last ',' + res.set_raw_header("Access-Control-Allow-Headers", headers_str); +} + +/// Set the allowed headers header of the response to the given allowed headers. +fn set_exposed_headers_header(res: &mut Response, headers: &HashSet>) { + // I guess that most header names are less than 16 bytes + let headers_str = concat_strs(headers.iter(), ", ", headers.len(), 16); + res.set_raw_header("Access-Control-Expose-Headers", headers_str); +} + +fn concat_strs( + strs: impl Iterator>, + join: &'static str, + len: usize, + guess_size: usize, +) -> String { + let mut output = String::with_capacity(len * (guess_size + join.len())); + for (idx, s) in strs.enumerate() { + output.push_str(s.as_ref()); + if idx < len - 1 { + output.push_str(", "); + } + } + output +} diff --git a/contrib/lib/src/databases.rs b/contrib/lib/src/databases.rs index c2f8b48138..120e8a5a85 100644 --- a/contrib/lib/src/databases.rs +++ b/contrib/lib/src/databases.rs @@ -375,9 +375,11 @@ pub extern crate r2d2; -#[cfg(any(feature = "diesel_sqlite_pool", - feature = "diesel_postgres_pool", - feature = "diesel_mysql_pool"))] +#[cfg(any( + feature = "diesel_sqlite_pool", + feature = "diesel_postgres_pool", + feature = "diesel_mysql_pool" +))] pub extern crate diesel; use std::collections::BTreeMap; @@ -388,28 +390,43 @@ use rocket::config::{self, Value}; use self::r2d2::ManageConnection; -#[doc(hidden)] pub use rocket_contrib_codegen::*; +#[doc(hidden)] +pub use rocket_contrib_codegen::*; -#[cfg(feature = "postgres_pool")] pub extern crate postgres; -#[cfg(feature = "postgres_pool")] pub extern crate r2d2_postgres; +#[cfg(feature = "postgres_pool")] +pub extern crate postgres; +#[cfg(feature = "postgres_pool")] +pub extern crate r2d2_postgres; -#[cfg(feature = "mysql_pool")] pub extern crate mysql; -#[cfg(feature = "mysql_pool")] pub extern crate r2d2_mysql; +#[cfg(feature = "mysql_pool")] +pub extern crate mysql; +#[cfg(feature = "mysql_pool")] +pub extern crate r2d2_mysql; -#[cfg(feature = "sqlite_pool")] pub extern crate rusqlite; -#[cfg(feature = "sqlite_pool")] pub extern crate r2d2_sqlite; +#[cfg(feature = "sqlite_pool")] +pub extern crate r2d2_sqlite; +#[cfg(feature = "sqlite_pool")] +pub extern crate rusqlite; -#[cfg(feature = "cypher_pool")] pub extern crate rusted_cypher; -#[cfg(feature = "cypher_pool")] pub extern crate r2d2_cypher; +#[cfg(feature = "cypher_pool")] +pub extern crate r2d2_cypher; +#[cfg(feature = "cypher_pool")] +pub extern crate rusted_cypher; -#[cfg(feature = "redis_pool")] pub extern crate redis; -#[cfg(feature = "redis_pool")] pub extern crate r2d2_redis; +#[cfg(feature = "redis_pool")] +pub extern crate r2d2_redis; +#[cfg(feature = "redis_pool")] +pub extern crate redis; -#[cfg(feature = "mongodb_pool")] pub extern crate mongodb; -#[cfg(feature = "mongodb_pool")] pub extern crate r2d2_mongodb; +#[cfg(feature = "mongodb_pool")] +pub extern crate mongodb; +#[cfg(feature = "mongodb_pool")] +pub extern crate r2d2_mongodb; -#[cfg(feature = "memcache_pool")] pub extern crate memcache; -#[cfg(feature = "memcache_pool")] pub extern crate r2d2_memcache; +#[cfg(feature = "memcache_pool")] +pub extern crate memcache; +#[cfg(feature = "memcache_pool")] +pub extern crate r2d2_memcache; /// A structure representing a particular database configuration. /// @@ -540,23 +557,26 @@ pub enum ConfigError { /// ``` pub fn database_config<'a>( name: &str, - from: &'a config::Config + from: &'a config::Config, ) -> Result, ConfigError> { // Find the first `databases` config that's a table with a key of 'name' // equal to `name`. - let connection_config = from.get_table("databases") + let connection_config = from + .get_table("databases") .map_err(|_| ConfigError::MissingTable)? .get(name) .ok_or(ConfigError::MissingKey)? .as_table() .ok_or(ConfigError::MalformedConfiguration)?; - let maybe_url = connection_config.get("url") + let maybe_url = connection_config + .get("url") .ok_or(ConfigError::MissingUrl)?; let url = maybe_url.as_str().ok_or(ConfigError::MalformedUrl)?; - let pool_size = connection_config.get("pool_size") + let pool_size = connection_config + .get("pool_size") .and_then(Value::as_integer) .unwrap_or(from.workers as i64); @@ -568,30 +588,32 @@ pub fn database_config<'a>( extras.remove("url"); extras.remove("pool_size"); - Ok(DatabaseConfig { url, pool_size: pool_size as u32, extras: extras }) + Ok(DatabaseConfig { + url, + pool_size: pool_size as u32, + extras: extras, + }) } impl<'a> Display for ConfigError { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { - ConfigError::MissingTable => { - write!(f, "A table named `databases` was not found for this configuration") - }, - ConfigError::MissingKey => { - write!(f, "An entry in the `databases` table was not found for this key") - }, + ConfigError::MissingTable => write!( + f, + "A table named `databases` was not found for this configuration" + ), + ConfigError::MissingKey => write!( + f, + "An entry in the `databases` table was not found for this key" + ), ConfigError::MalformedConfiguration => { write!(f, "The configuration for this database is malformed") } - ConfigError::MissingUrl => { - write!(f, "The connection URL is missing for this database") - }, - ConfigError::MalformedUrl => { - write!(f, "The specified connection URL is malformed") - }, + ConfigError::MissingUrl => write!(f, "The connection URL is missing for this database"), + ConfigError::MalformedUrl => write!(f, "The specified connection URL is malformed"), ConfigError::InvalidPoolSize(invalid_size) => { write!(f, "'{}' is not a valid value for `pool_size`", invalid_size) - }, + } } } } @@ -688,7 +710,7 @@ impl<'a> Display for ConfigError { /// existing implementations of [`Poolable`]. pub trait Poolable: Send + Sized + 'static { /// The associated connection manager for the given connection type. - type Manager: ManageConnection; + type Manager: ManageConnection; /// The associated error type in the event that constructing the connection /// manager and/or the connection pool fails. type Error; @@ -705,7 +727,9 @@ impl Poolable for diesel::SqliteConnection { fn pool(config: DatabaseConfig) -> Result, Self::Error> { let manager = diesel::r2d2::ConnectionManager::new(config.url); - r2d2::Pool::builder().max_size(config.pool_size).build(manager) + r2d2::Pool::builder() + .max_size(config.pool_size) + .build(manager) } } @@ -716,7 +740,9 @@ impl Poolable for diesel::PgConnection { fn pool(config: DatabaseConfig) -> Result, Self::Error> { let manager = diesel::r2d2::ConnectionManager::new(config.url); - r2d2::Pool::builder().max_size(config.pool_size).build(manager) + r2d2::Pool::builder() + .max_size(config.pool_size) + .build(manager) } } @@ -727,7 +753,9 @@ impl Poolable for diesel::MysqlConnection { fn pool(config: DatabaseConfig) -> Result, Self::Error> { let manager = diesel::r2d2::ConnectionManager::new(config.url); - r2d2::Pool::builder().max_size(config.pool_size).build(manager) + r2d2::Pool::builder() + .max_size(config.pool_size) + .build(manager) } } @@ -738,10 +766,13 @@ impl Poolable for postgres::Connection { type Error = DbError; fn pool(config: DatabaseConfig) -> Result, Self::Error> { - let manager = r2d2_postgres::PostgresConnectionManager::new(config.url, r2d2_postgres::TlsMode::None) - .map_err(DbError::Custom)?; + let manager = + r2d2_postgres::PostgresConnectionManager::new(config.url, r2d2_postgres::TlsMode::None) + .map_err(DbError::Custom)?; - r2d2::Pool::builder().max_size(config.pool_size).build(manager) + r2d2::Pool::builder() + .max_size(config.pool_size) + .build(manager) .map_err(DbError::PoolError) } } @@ -754,7 +785,9 @@ impl Poolable for mysql::Conn { fn pool(config: DatabaseConfig) -> Result, Self::Error> { let opts = mysql::OptsBuilder::from_opts(config.url); let manager = r2d2_mysql::MysqlConnectionManager::new(opts); - r2d2::Pool::builder().max_size(config.pool_size).build(manager) + r2d2::Pool::builder() + .max_size(config.pool_size) + .build(manager) } } @@ -766,7 +799,9 @@ impl Poolable for rusqlite::Connection { fn pool(config: DatabaseConfig) -> Result, Self::Error> { let manager = r2d2_sqlite::SqliteConnectionManager::file(config.url); - r2d2::Pool::builder().max_size(config.pool_size).build(manager) + r2d2::Pool::builder() + .max_size(config.pool_size) + .build(manager) } } @@ -776,8 +811,12 @@ impl Poolable for rusted_cypher::GraphClient { type Error = r2d2::Error; fn pool(config: DatabaseConfig) -> Result, Self::Error> { - let manager = r2d2_cypher::CypherConnectionManager { url: config.url.to_string() }; - r2d2::Pool::builder().max_size(config.pool_size).build(manager) + let manager = r2d2_cypher::CypherConnectionManager { + url: config.url.to_string(), + }; + r2d2::Pool::builder() + .max_size(config.pool_size) + .build(manager) } } @@ -787,8 +826,11 @@ impl Poolable for redis::Connection { type Error = DbError; fn pool(config: DatabaseConfig) -> Result, Self::Error> { - let manager = r2d2_redis::RedisConnectionManager::new(config.url).map_err(DbError::Custom)?; - r2d2::Pool::builder().max_size(config.pool_size).build(manager) + let manager = + r2d2_redis::RedisConnectionManager::new(config.url).map_err(DbError::Custom)?; + r2d2::Pool::builder() + .max_size(config.pool_size) + .build(manager) .map_err(DbError::PoolError) } } @@ -799,8 +841,12 @@ impl Poolable for mongodb::db::Database { type Error = DbError; fn pool(config: DatabaseConfig) -> Result, Self::Error> { - let manager = r2d2_mongodb::MongodbConnectionManager::new_with_uri(config.url).map_err(DbError::Custom)?; - r2d2::Pool::builder().max_size(config.pool_size).build(manager).map_err(DbError::PoolError) + let manager = r2d2_mongodb::MongodbConnectionManager::new_with_uri(config.url) + .map_err(DbError::Custom)?; + r2d2::Pool::builder() + .max_size(config.pool_size) + .build(manager) + .map_err(DbError::PoolError) } } @@ -811,21 +857,25 @@ impl Poolable for memcache::Client { fn pool(config: DatabaseConfig) -> Result, Self::Error> { let manager = r2d2_memcache::MemcacheConnectionManager::new(config.url); - r2d2::Pool::builder().max_size(config.pool_size).build(manager).map_err(DbError::PoolError) + r2d2::Pool::builder() + .max_size(config.pool_size) + .build(manager) + .map_err(DbError::PoolError) } } #[cfg(test)] mod tests { + use super::{database_config, ConfigError::*}; + use rocket::{ + config::{Environment, Value}, + Config, + }; use std::collections::BTreeMap; - use rocket::{Config, config::{Environment, Value}}; - use super::{ConfigError::*, database_config}; #[test] fn no_database_entry_in_config_returns_error() { - let config = Config::build(Environment::Development) - .finalize() - .unwrap(); + let config = Config::build(Environment::Development).finalize().unwrap(); let database_config_result = database_config("dummy_db", &config); assert_eq!(Err(MissingTable), database_config_result); diff --git a/contrib/lib/src/helmet/helmet.rs b/contrib/lib/src/helmet/helmet.rs index 8f68d0b84c..055738ca88 100644 --- a/contrib/lib/src/helmet/helmet.rs +++ b/contrib/lib/src/helmet/helmet.rs @@ -1,8 +1,8 @@ use std::collections::HashMap; use std::sync::atomic::{AtomicBool, Ordering}; -use rocket::http::uncased::UncasedStr; use rocket::fairing::{Fairing, Info, Kind}; +use rocket::http::uncased::UncasedStr; use rocket::{Request, Response, Rocket}; use helmet::*; @@ -173,7 +173,7 @@ impl SpaceHelmet { if response.headers().contains(name.as_str()) { warn!("Space Helmet: response contains a '{}' header.", name); warn_!("Refusing to overwrite existing header."); - continue + continue; } // FIXME: Cache the rendered header. diff --git a/contrib/lib/src/helmet/policy.rs b/contrib/lib/src/helmet/policy.rs index bb40532e96..612055b42f 100644 --- a/contrib/lib/src/helmet/policy.rs +++ b/contrib/lib/src/helmet/policy.rs @@ -2,7 +2,7 @@ use std::borrow::Cow; -use rocket::http::{Header, uri::Uri, uncased::UncasedStr}; +use rocket::http::{uncased::UncasedStr, uri::Uri, Header}; use helmet::time::Duration; @@ -79,7 +79,7 @@ impl SubPolicy for P { } macro_rules! impl_policy { - ($T:ty, $name:expr) => ( + ($T:ty, $name:expr) => { impl Policy for $T { const NAME: &'static str = $name; @@ -87,7 +87,7 @@ macro_rules! impl_policy { self.into() } } - ) + }; } impl_policy!(XssFilter, "X-XSS-Protection"); @@ -141,7 +141,7 @@ pub enum Referrer { /// the full URL of TLS protected resources to insecure origins. Use with /// caution._ UnsafeUrl, - } +} /// Defaults to [`Referrer::NoReferrer`]. Tells the browser to omit the /// `Referer` header. @@ -211,14 +211,16 @@ impl Default for ExpectCt { impl<'a> Into> for &'a ExpectCt { fn into(self) -> Header<'static> { - let policy_string = match self { + let policy_string = match self { ExpectCt::Enforce(age) => format!("max-age={}, enforce", age.num_seconds()), ExpectCt::Report(age, uri) => { format!(r#"max-age={}, report-uri="{}""#, age.num_seconds(), uri) } - ExpectCt::ReportAndEnforce(age, uri) => { - format!("max-age={}, enforce, report-uri=\"{}\"", age.num_seconds(), uri) - } + ExpectCt::ReportAndEnforce(age, uri) => format!( + "max-age={}, enforce, report-uri=\"{}\"", + age.num_seconds(), + uri + ), }; Header::new(ExpectCt::NAME, policy_string) diff --git a/contrib/lib/src/json.rs b/contrib/lib/src/json.rs index b31d764409..39bef08a2f 100644 --- a/contrib/lib/src/json.rs +++ b/contrib/lib/src/json.rs @@ -17,17 +17,17 @@ extern crate serde; extern crate serde_json; -use std::ops::{Deref, DerefMut}; use std::io::{self, Read}; +use std::ops::{Deref, DerefMut}; -use rocket::request::Request; -use rocket::outcome::Outcome::*; -use rocket::data::{Outcome, Transform, Transform::*, Transformed, Data, FromData}; -use rocket::response::{self, Responder, content}; +use rocket::data::{Data, FromData, Outcome, Transform, Transform::*, Transformed}; use rocket::http::Status; +use rocket::outcome::Outcome::*; +use rocket::request::Request; +use rocket::response::{self, content, Responder}; -use self::serde::{Serialize, Serializer}; use self::serde::de::{Deserialize, Deserializer}; +use self::serde::{Serialize, Serializer}; #[doc(hidden)] pub use self::serde_json::{json_internal, json_internal_vec}; @@ -140,7 +140,7 @@ impl<'a, T: Deserialize<'a>> FromData<'a> for Json { let mut s = String::with_capacity(512); match d.open().take(size_limit).read_to_string(&mut s) { Ok(_) => Borrowed(Success(s)), - Err(e) => Borrowed(Failure((Status::BadRequest, JsonError::Io(e)))) + Err(e) => Borrowed(Failure((Status::BadRequest, JsonError::Io(e)))), } } @@ -165,12 +165,12 @@ impl<'a, T: Deserialize<'a>> FromData<'a> for Json { /// fails, an `Err` of `Status::InternalServerError` is returned. impl<'a, T: Serialize> Responder<'a> for Json { fn respond_to(self, req: &Request) -> response::Result<'a> { - serde_json::to_string(&self.0).map(|string| { - content::Json(string).respond_to(req).unwrap() - }).map_err(|e| { - error_!("JSON failed to serialize: {:?}", e); - Status::InternalServerError - }) + serde_json::to_string(&self.0) + .map(|string| content::Json(string).respond_to(req).unwrap()) + .map_err(|e| { + error_!("JSON failed to serialize: {:?}", e); + Status::InternalServerError + }) } } diff --git a/contrib/lib/src/lib.rs b/contrib/lib/src/lib.rs index ef17a8e3a4..0efabe43a5 100644 --- a/contrib/lib/src/lib.rs +++ b/contrib/lib/src/lib.rs @@ -1,7 +1,6 @@ #![feature(crate_visibility_modifier)] #![feature(never_type)] #![feature(doc_cfg)] - #![doc(html_root_url = "https://api.rocket.rs/v0.4")] #![doc(html_favicon_url = "https://rocket.rs/v0.4/images/favicon.ico")] #![doc(html_logo_url = "https://rocket.rs/v0.4/images/logo-boxed.png")] @@ -41,16 +40,33 @@ //! This crate is expected to grow with time, bringing in outside crates to be //! officially supported by Rocket. -#[allow(unused_imports)] #[macro_use] extern crate log; -#[allow(unused_imports)] #[macro_use] extern crate rocket; +#[allow(unused_imports)] +#[macro_use] +extern crate log; +#[allow(unused_imports)] +#[macro_use] +extern crate rocket; -#[cfg(feature="json")] #[macro_use] pub mod json; -#[cfg(feature="serve")] pub mod serve; -#[cfg(feature="msgpack")] pub mod msgpack; -#[cfg(feature="templates")] pub mod templates; -#[cfg(feature="uuid")] pub mod uuid; -#[cfg(feature="databases")] pub mod databases; -#[cfg(feature = "helmet")] pub mod helmet; +#[cfg(feature = "json")] +#[macro_use] +pub mod json; +#[cfg(feature = "cors")] +pub mod cors; +#[cfg(feature = "databases")] +pub mod databases; +#[cfg(feature = "helmet")] +pub mod helmet; +#[cfg(feature = "msgpack")] +pub mod msgpack; +#[cfg(feature = "serve")] +pub mod serve; +#[cfg(feature = "templates")] +pub mod templates; +#[cfg(feature = "uuid")] +pub mod uuid; -#[cfg(feature="databases")] extern crate rocket_contrib_codegen; -#[cfg(feature="databases")] #[doc(hidden)] pub use rocket_contrib_codegen::*; +#[cfg(feature = "databases")] +extern crate rocket_contrib_codegen; +#[cfg(feature = "databases")] +#[doc(hidden)] +pub use rocket_contrib_codegen::*; diff --git a/contrib/lib/src/msgpack.rs b/contrib/lib/src/msgpack.rs index 252eecae49..9d2404be1b 100644 --- a/contrib/lib/src/msgpack.rs +++ b/contrib/lib/src/msgpack.rs @@ -13,20 +13,20 @@ //! default-features = false //! features = ["msgpack"] //! ``` -extern crate serde; extern crate rmp_serde; +extern crate serde; -use std::ops::{Deref, DerefMut}; use std::io::{Cursor, Read}; +use std::ops::{Deref, DerefMut}; -use rocket::request::Request; +use rocket::data::{Data, FromData, Outcome, Transform, Transform::*, Transformed}; +use rocket::http::Status; use rocket::outcome::Outcome::*; -use rocket::data::{Outcome, Transform, Transform::*, Transformed, Data, FromData}; +use rocket::request::Request; use rocket::response::{self, Responder, Response}; -use rocket::http::Status; -use self::serde::Serialize; use self::serde::de::Deserialize; +use self::serde::Serialize; pub use self::rmp_serde::decode::Error; @@ -126,7 +126,7 @@ impl<'a, T: Deserialize<'a>> FromData<'a> for MsgPack { let size_limit = r.limits().get("msgpack").unwrap_or(LIMIT); match d.open().take(size_limit).read_to_end(&mut buf) { Ok(_) => Borrowed(Success(buf)), - Err(e) => Borrowed(Failure((Status::BadRequest, Error::InvalidDataRead(e)))) + Err(e) => Borrowed(Failure((Status::BadRequest, Error::InvalidDataRead(e)))), } } @@ -142,7 +142,7 @@ impl<'a, T: Deserialize<'a>> FromData<'a> for MsgPack { TypeMismatch(_) | OutOfRange | LengthMismatch(_) => { Failure((Status::UnprocessableEntity, e)) } - _ => Failure((Status::BadRequest, e)) + _ => Failure((Status::BadRequest, e)), } } } @@ -154,14 +154,12 @@ impl<'a, T: Deserialize<'a>> FromData<'a> for MsgPack { /// serialization fails, an `Err` of `Status::InternalServerError` is returned. impl Responder<'static> for MsgPack { fn respond_to(self, _: &Request) -> response::Result<'static> { - rmp_serde::to_vec(&self.0).map_err(|e| { - error_!("MsgPack failed to serialize: {:?}", e); - Status::InternalServerError - }).and_then(|buf| { - Response::build() - .sized_body(Cursor::new(buf)) - .ok() - }) + rmp_serde::to_vec(&self.0) + .map_err(|e| { + error_!("MsgPack failed to serialize: {:?}", e); + Status::InternalServerError + }) + .and_then(|buf| Response::build().sized_body(Cursor::new(buf)).ok()) } } diff --git a/contrib/lib/src/serve.rs b/contrib/lib/src/serve.rs index 8c34496332..278d3e4baa 100644 --- a/contrib/lib/src/serve.rs +++ b/contrib/lib/src/serve.rs @@ -14,13 +14,13 @@ //! features = ["serve"] //! ``` -use std::path::{PathBuf, Path}; +use std::path::{Path, PathBuf}; -use rocket::{Request, Data, Route}; -use rocket::http::{Method, Status, uri::Segments}; use rocket::handler::{Handler, Outcome}; -use rocket::response::NamedFile; +use rocket::http::{uri::Segments, Method, Status}; use rocket::outcome::IntoOutcome; +use rocket::response::NamedFile; +use rocket::{Data, Request, Route}; /// A bitset representing configurable options for the [`StaticFiles`] handler. /// @@ -204,7 +204,10 @@ impl StaticFiles { /// } /// ``` pub fn new>(path: P, options: Options) -> Self { - StaticFiles { root: path.as_ref().into(), options } + StaticFiles { + root: path.as_ref().into(), + options, + } } } @@ -241,7 +244,8 @@ impl Handler for StaticFiles { // Otherwise, we're handling segments. Get the segments as a `PathBuf`, // only allowing dotfiles if the user allowed it. let allow_dotfiles = self.options.contains(Options::DotFiles); - let path = req.get_segments::(0) + let path = req + .get_segments::(0) .and_then(|res| res.ok()) .and_then(|segments| segments.into_path_buf(allow_dotfiles).ok()) .map(|path| self.root.join(path)) diff --git a/contrib/lib/src/templates/context.rs b/contrib/lib/src/templates/context.rs index 89f824d60f..910dcf7885 100644 --- a/contrib/lib/src/templates/context.rs +++ b/contrib/lib/src/templates/context.rs @@ -1,5 +1,5 @@ -use std::path::{Path, PathBuf}; use std::collections::HashMap; +use std::path::{Path, PathBuf}; use templates::{glob, Engines, TemplateInfo}; @@ -35,20 +35,27 @@ impl Context { continue; } - let data_type = data_type_str.as_ref() + let data_type = data_type_str + .as_ref() .and_then(|ext| ContentType::from_extension(ext)) .unwrap_or(ContentType::HTML); - templates.insert(name, TemplateInfo { - path: path.to_path_buf(), - extension: ext.to_string(), - data_type, - }); + templates.insert( + name, + TemplateInfo { + path: path.to_path_buf(), + extension: ext.to_string(), + data_type, + }, + ); } } - Engines::init(&templates) - .map(|engines| Context { root, templates, engines } ) + Engines::init(&templates).map(|engines| Context { + root, + templates, + engines, + }) } } @@ -57,12 +64,12 @@ fn remove_extension>(path: P) -> PathBuf { let path = path.as_ref(); let stem = match path.file_stem() { Some(stem) => stem, - None => return path.to_path_buf() + None => return path.to_path_buf(), }; match path.parent() { Some(parent) => parent.join(stem), - None => PathBuf::from(stem) + None => PathBuf::from(stem), } } @@ -72,7 +79,9 @@ fn split_path(root: &Path, path: &Path) -> (String, Option) { let rel_path = path.strip_prefix(root).unwrap().to_path_buf(); let path_no_ext = remove_extension(&rel_path); let data_type = path_no_ext.extension(); - let mut name = remove_extension(&path_no_ext).to_string_lossy().into_owned(); + let mut name = remove_extension(&path_no_ext) + .to_string_lossy() + .into_owned(); // Ensure template name consistency on Windows systems if cfg!(windows) { @@ -127,6 +136,9 @@ mod tests { assert_eq!(name_for("dir/index.hbs"), "dir/index"); assert_eq!(name_for("dir/index.html.tera"), "dir/index"); assert_eq!(name_for("index.template.html.hbs"), "index.template"); - assert_eq!(name_for("subdir/index.template.html.hbs"), "subdir/index.template"); + assert_eq!( + name_for("subdir/index.template.html.hbs"), + "subdir/index.template" + ); } } diff --git a/contrib/lib/src/templates/engine.rs b/contrib/lib/src/templates/engine.rs index a63c19798d..515e30da91 100644 --- a/contrib/lib/src/templates/engine.rs +++ b/contrib/lib/src/templates/engine.rs @@ -1,14 +1,18 @@ use std::collections::HashMap; -use templates::{TemplateInfo, serde::Serialize}; +use templates::{serde::Serialize, TemplateInfo}; -#[cfg(feature = "tera_templates")] use templates::tera::Tera; -#[cfg(feature = "handlebars_templates")] use templates::handlebars::Handlebars; +#[cfg(feature = "handlebars_templates")] +use templates::handlebars::Handlebars; +#[cfg(feature = "tera_templates")] +use templates::tera::Tera; crate trait Engine: Send + Sync + 'static { const EXT: &'static str; - fn init(templates: &[(&str, &TemplateInfo)]) -> Option where Self: Sized; + fn init(templates: &[(&str, &TemplateInfo)]) -> Option + where + Self: Sized; fn render(&self, name: &str, context: C) -> Option; } @@ -59,13 +63,16 @@ pub struct Engines { impl Engines { crate const ENABLED_EXTENSIONS: &'static [&'static str] = &[ - #[cfg(feature = "tera_templates")] Tera::EXT, - #[cfg(feature = "handlebars_templates")] Handlebars::EXT, + #[cfg(feature = "tera_templates")] + Tera::EXT, + #[cfg(feature = "handlebars_templates")] + Handlebars::EXT, ]; crate fn init(templates: &HashMap) -> Option { fn inner(templates: &HashMap) -> Option { - let named_templates = templates.iter() + let named_templates = templates + .iter() .filter(|&(_, i)| i.extension == E::EXT) .map(|(k, i)| (k.as_str(), i)) .collect::>(); @@ -77,12 +84,12 @@ impl Engines { #[cfg(feature = "tera_templates")] tera: match inner::(templates) { Some(tera) => tera, - None => return None + None => return None, }, #[cfg(feature = "handlebars_templates")] handlebars: match inner::(templates) { Some(hb) => hb, - None => return None + None => return None, }, }) } @@ -91,7 +98,7 @@ impl Engines { &self, name: &str, info: &TemplateInfo, - context: C + context: C, ) -> Option { #[cfg(feature = "tera_templates")] { diff --git a/contrib/lib/src/templates/fairing.rs b/contrib/lib/src/templates/fairing.rs index 71e138f191..cb54ba5527 100644 --- a/contrib/lib/src/templates/fairing.rs +++ b/contrib/lib/src/templates/fairing.rs @@ -1,8 +1,8 @@ -use templates::{DEFAULT_TEMPLATE_DIR, Context, Engines}; +use templates::{Context, Engines, DEFAULT_TEMPLATE_DIR}; -use rocket::Rocket; use rocket::config::ConfigError; use rocket::fairing::{Fairing, Info, Kind}; +use rocket::Rocket; crate use self::context::ContextManager; @@ -20,7 +20,7 @@ mod context { ContextManager(ctxt) } - crate fn context<'a>(&'a self) -> impl Deref + 'a { + crate fn context<'a>(&'a self) -> impl Deref + 'a { &self.0 } @@ -35,8 +35,8 @@ mod context { extern crate notify; use std::ops::{Deref, DerefMut}; - use std::sync::{RwLock, Mutex}; use std::sync::mpsc::{channel, Receiver}; + use std::sync::{Mutex, RwLock}; use templates::{Context, Engines}; @@ -75,7 +75,7 @@ mod context { } } - crate fn context<'a>(&'a self) -> impl Deref + 'a { + crate fn context<'a>(&'a self) -> impl Deref + 'a { self.context.read().unwrap() } @@ -83,7 +83,7 @@ mod context { self.watcher.is_some() } - fn context_mut<'a>(&'a self) -> impl DerefMut + 'a { + fn context_mut<'a>(&'a self) -> impl DerefMut + 'a { self.context.write().unwrap() } @@ -166,7 +166,8 @@ impl Fairing for TemplateFairing { #[cfg(debug_assertions)] fn on_request(&self, req: &mut ::rocket::Request, _data: &::rocket::Data) { - let cm = req.guard::<::rocket::State>() + let cm = req + .guard::<::rocket::State>() .expect("Template ContextManager registered in on_attach"); cm.reload_if_needed(&*self.custom_callback); diff --git a/contrib/lib/src/templates/metadata.rs b/contrib/lib/src/templates/metadata.rs index 7e8333850f..7191dc87dd 100644 --- a/contrib/lib/src/templates/metadata.rs +++ b/contrib/lib/src/templates/metadata.rs @@ -1,6 +1,6 @@ -use rocket::{Request, State, Outcome}; use rocket::http::Status; use rocket::request::{self, FromRequest}; +use rocket::{Outcome, Request, State}; use templates::ContextManager; @@ -91,7 +91,8 @@ impl<'a, 'r> FromRequest<'a, 'r> for Metadata<'a> { type Error = (); fn from_request(request: &'a Request) -> request::Outcome { - request.guard::>() + request + .guard::>() .succeeded() .and_then(|cm| Some(Outcome::Success(Metadata(cm.inner())))) .unwrap_or_else(|| { diff --git a/contrib/lib/src/templates/mod.rs b/contrib/lib/src/templates/mod.rs index 1666c6c734..8cb25c3736 100644 --- a/contrib/lib/src/templates/mod.rs +++ b/contrib/lib/src/templates/mod.rs @@ -111,40 +111,44 @@ //! [`Template::custom()`]: templates::Template::custom() //! [`Template::render()`]: templates::Template::render() +extern crate glob; extern crate serde; extern crate serde_json; -extern crate glob; -#[cfg(feature = "tera_templates")] pub extern crate tera; -#[cfg(feature = "tera_templates")] mod tera_templates; +#[cfg(feature = "tera_templates")] +pub extern crate tera; +#[cfg(feature = "tera_templates")] +mod tera_templates; -#[cfg(feature = "handlebars_templates")] pub extern crate handlebars; -#[cfg(feature = "handlebars_templates")] mod handlebars_templates; +#[cfg(feature = "handlebars_templates")] +pub extern crate handlebars; +#[cfg(feature = "handlebars_templates")] +mod handlebars_templates; +mod context; mod engine; mod fairing; -mod context; mod metadata; -pub use self::engine::Engines; -pub use self::metadata::Metadata; crate use self::context::Context; +pub use self::engine::Engines; crate use self::fairing::ContextManager; +pub use self::metadata::Metadata; use self::engine::Engine; use self::fairing::TemplateFairing; -use self::serde::Serialize; -use self::serde_json::{Value, to_value}; use self::glob::glob; +use self::serde::Serialize; +use self::serde_json::{to_value, Value}; use std::borrow::Cow; use std::path::PathBuf; -use rocket::{Rocket, State}; -use rocket::request::Request; use rocket::fairing::Fairing; -use rocket::response::{self, Content, Responder}; use rocket::http::{ContentType, Status}; +use rocket::request::Request; +use rocket::response::{self, Content, Responder}; +use rocket::{Rocket, State}; const DEFAULT_TEMPLATE_DIR: &str = "templates"; @@ -204,7 +208,7 @@ const DEFAULT_TEMPLATE_DIR: &str = "templates"; #[derive(Debug)] pub struct Template { name: Cow<'static, str>, - value: Option + value: Option, } #[derive(Debug)] @@ -214,7 +218,7 @@ crate struct TemplateInfo { /// The extension for the engine of this template. extension: String, /// The extension before the engine extension in the template, if any. - data_type: ContentType + data_type: ContentType, } impl Template { @@ -277,9 +281,12 @@ impl Template { /// } /// ``` pub fn custom(f: F) -> impl Fairing - where F: Fn(&mut Engines) + Send + Sync + 'static + where + F: Fn(&mut Engines) + Send + Sync + 'static, { - TemplateFairing { custom_callback: Box::new(f) } + TemplateFairing { + custom_callback: Box::new(f), + } } /// Render the template named `name` with the context `context`. The @@ -300,9 +307,14 @@ impl Template { /// let template = Template::render("index", context); #[inline] pub fn render(name: S, context: C) -> Template - where S: Into>, C: Serialize + where + S: Into>, + C: Serialize, { - Template { name: name.into(), value: to_value(context).ok() } + Template { + name: name.into(), + value: to_value(context).ok(), + } } /// Render the template named `name` with the context `context` into a @@ -342,16 +354,24 @@ impl Template { /// ``` #[inline] pub fn show(rocket: &Rocket, name: S, context: C) -> Option - where S: Into>, C: Serialize + where + S: Into>, + C: Serialize, { - let ctxt = rocket.state::().map(ContextManager::context).or_else(|| { - warn!("Uninitialized template context: missing fairing."); - info!("To use templates, you must attach `Template::fairing()`."); - info!("See the `Template` documentation for more information."); - None - })?; + let ctxt = rocket + .state::() + .map(ContextManager::context) + .or_else(|| { + warn!("Uninitialized template context: missing fairing."); + info!("To use templates, you must attach `Template::fairing()`."); + info!("See the `Template` documentation for more information."); + None + })?; - Template::render(name, context).finalize(&ctxt).ok().map(|v| v.0) + Template::render(name, context) + .finalize(&ctxt) + .ok() + .map(|v| v.0) } /// Actually render this template given a template context. This method is @@ -387,12 +407,17 @@ impl Template { /// rendering fails, an `Err` of `Status::InternalServerError` is returned. impl Responder<'static> for Template { fn respond_to(self, req: &Request) -> response::Result<'static> { - let ctxt = req.guard::>().succeeded().ok_or_else(|| { - error_!("Uninitialized template context: missing fairing."); - info_!("To use templates, you must attach `Template::fairing()`."); - info_!("See the `Template` documentation for more information."); - Status::InternalServerError - })?.inner().context(); + let ctxt = req + .guard::>() + .succeeded() + .ok_or_else(|| { + error_!("Uninitialized template context: missing fairing."); + info_!("To use templates, you must attach `Template::fairing()`."); + info_!("See the `Template` documentation for more information."); + Status::InternalServerError + })? + .inner() + .context(); let (render, content_type) = self.finalize(&ctxt)?; Content(content_type, render).respond_to(req) diff --git a/contrib/lib/src/templates/tera_templates.rs b/contrib/lib/src/templates/tera_templates.rs index dc07c24b07..8a832086bd 100644 --- a/contrib/lib/src/templates/tera_templates.rs +++ b/contrib/lib/src/templates/tera_templates.rs @@ -9,11 +9,19 @@ impl Engine for Tera { fn init(templates: &[(&str, &TemplateInfo)]) -> Option { // Create the Tera instance. let mut tera = Tera::default(); - let ext = [".html.tera", ".htm.tera", ".xml.tera", ".html", ".htm", ".xml"]; + let ext = [ + ".html.tera", + ".htm.tera", + ".xml.tera", + ".html", + ".htm", + ".xml", + ]; tera.autoescape_on(ext.to_vec()); // Collect into a tuple of (name, path) for Tera. - let tera_templates = templates.iter() + let tera_templates = templates + .iter() .map(|&(name, info)| (&info.path, Some(name))) .collect::>(); diff --git a/contrib/lib/src/uuid.rs b/contrib/lib/src/uuid.rs index 814c18141b..f13317b1a1 100644 --- a/contrib/lib/src/uuid.rs +++ b/contrib/lib/src/uuid.rs @@ -17,11 +17,11 @@ pub extern crate uuid as uuid_crate; use std::fmt; -use std::str::FromStr; use std::ops::Deref; +use std::str::FromStr; -use rocket::request::{FromParam, FromFormValue}; use rocket::http::RawStr; +use rocket::request::{FromFormValue, FromParam}; pub use self::uuid_crate::parser::ParseError; @@ -149,9 +149,9 @@ impl PartialEq for Uuid { #[cfg(test)] mod test { use super::uuid_crate; - use super::Uuid; use super::FromParam; use super::FromStr; + use super::Uuid; #[test] fn test_from_str() { diff --git a/contrib/lib/tests/databases.rs b/contrib/lib/tests/databases.rs index cdffb61e1e..7aea5b9cef 100644 --- a/contrib/lib/tests/databases.rs +++ b/contrib/lib/tests/databases.rs @@ -15,10 +15,10 @@ mod databases_tests { #[cfg(all(feature = "databases", feature = "sqlite_pool"))] #[cfg(test)] mod rusqlite_integration_test { - use std::collections::BTreeMap; use rocket::config::{Config, Environment, Value}; - use rocket_contrib::databases::rusqlite; use rocket_contrib::database; + use rocket_contrib::databases::rusqlite; + use std::collections::BTreeMap; #[database("test_db")] struct SqliteDb(pub rusqlite::Connection); @@ -40,7 +40,9 @@ mod rusqlite_integration_test { // Rusqlite's `transaction()` method takes `&mut self`; this tests the // presence of a `DerefMut` trait on the generated connection type. let tx = conn.transaction().unwrap(); - let _: i32 = tx.query_row("SELECT 1", &[], |row| row.get(0)).expect("get row"); + let _: i32 = tx + .query_row("SELECT 1", &[], |row| row.get(0)) + .expect("get row"); tx.commit().expect("committed transaction"); } @@ -57,6 +59,8 @@ mod rusqlite_integration_test { let rocket = rocket::custom(config).attach(SqliteDb::fairing()); let conn = SqliteDb::get_one(&rocket).expect("unable to get connection"); - let _: i32 = conn.query_row("SELECT 1", &[], |row| row.get(0)).expect("get row"); + let _: i32 = conn + .query_row("SELECT 1", &[], |row| row.get(0)) + .expect("get row"); } } diff --git a/contrib/lib/tests/helmet.rs b/contrib/lib/tests/helmet.rs index c45a21dcc8..58cd2720e2 100644 --- a/contrib/lib/tests/helmet.rs +++ b/contrib/lib/tests/helmet.rs @@ -6,23 +6,24 @@ extern crate rocket; #[cfg(feature = "helmet")] mod helmet_tests { - extern crate time; extern crate rocket_contrib; + extern crate time; use rocket; - use rocket::http::{Status, uri::Uri}; + use rocket::http::{uri::Uri, Status}; use rocket::local::{Client, LocalResponse}; use self::rocket_contrib::helmet::*; use self::time::Duration; - #[get("/")] fn hello() { } + #[get("/")] + fn hello() {} macro_rules! assert_header { ($response:ident, $name:expr, $value:expr) => { match $response.headers().get_one($name) { Some(value) => assert_eq!(value, $value), - None => panic!("missing header '{}' with value '{}'", $name, $value) + None => panic!("missing header '{}' with value '{}'", $name, $value), } }; } @@ -42,7 +43,7 @@ mod helmet_tests { let response = client.get("/").dispatch(); assert_eq!(response.status(), Status::Ok); $closure(response) - }} + }}; } #[test] @@ -121,17 +122,29 @@ mod helmet_tests { let helmet = SpaceHelmet::default() .enable(Frame::AllowFrom(allow_uri)) .enable(XssFilter::EnableReport(report_uri)) - .enable(ExpectCt::ReportAndEnforce(Duration::seconds(30), enforce_uri)); + .enable(ExpectCt::ReportAndEnforce( + Duration::seconds(30), + enforce_uri, + )); dispatch!(helmet, |response: LocalResponse| { - assert_header!(response, "X-Frame-Options", - "ALLOW-FROM https://www.google.com"); + assert_header!( + response, + "X-Frame-Options", + "ALLOW-FROM https://www.google.com" + ); - assert_header!(response, "X-XSS-Protection", - "1; report=https://www.google.com"); + assert_header!( + response, + "X-XSS-Protection", + "1; report=https://www.google.com" + ); - assert_header!(response, "Expect-CT", - "max-age=30, enforce, report-uri=\"https://www.google.com\""); + assert_header!( + response, + "Expect-CT", + "max-age=30, enforce, report-uri=\"https://www.google.com\"" + ); }); } } diff --git a/contrib/lib/tests/static_files.rs b/contrib/lib/tests/static_files.rs index d1c168e87f..6ee0e0fb51 100644 --- a/contrib/lib/tests/static_files.rs +++ b/contrib/lib/tests/static_files.rs @@ -5,13 +5,13 @@ extern crate rocket_contrib; #[cfg(feature = "static")] mod static_tests { - use std::{io::Read, fs::File}; use std::path::{Path, PathBuf}; + use std::{fs::File, io::Read}; - use rocket::{self, Rocket}; - use rocket_contrib::serve::{StaticFiles, Options}; use rocket::http::Status; use rocket::local::Client; + use rocket::{self, Rocket}; + use rocket_contrib::serve::{Options, StaticFiles}; fn static_root() -> PathBuf { Path::new(env!("CARGO_MANIFEST_DIR")) @@ -26,7 +26,10 @@ mod static_tests { .mount("/no_index", StaticFiles::new(&root, Options::None)) .mount("/dots", StaticFiles::new(&root, Options::DotFiles)) .mount("/index", StaticFiles::new(&root, Options::Index)) - .mount("/both", StaticFiles::new(&root, Options::DotFiles | Options::Index)) + .mount( + "/both", + StaticFiles::new(&root, Options::DotFiles | Options::Index), + ) } static REGULAR_FILES: &[&str] = &[ @@ -36,15 +39,9 @@ mod static_tests { "other/hello.txt", ]; - static HIDDEN_FILES: &[&str] = &[ - ".hidden", - "inner/.hideme", - ]; + static HIDDEN_FILES: &[&str] = &[".hidden", "inner/.hideme"]; - static INDEXED_DIRECTORIES: &[&str] = &[ - "", - "inner/", - ]; + static INDEXED_DIRECTORIES: &[&str] = &["", "inner/"]; fn assert_file(client: &Client, prefix: &str, path: &str, exists: bool) { let full_path = format!("/{}", Path::new(prefix).join(path).display()); @@ -59,7 +56,8 @@ mod static_tests { let mut file = File::open(path).expect("open file"); let mut expected_contents = String::new(); - file.read_to_string(&mut expected_contents).expect("read file"); + file.read_to_string(&mut expected_contents) + .expect("read file"); assert_eq!(response.body_string(), Some(expected_contents)); } else { assert_eq!(response.status(), Status::NotFound); @@ -67,7 +65,9 @@ mod static_tests { } fn assert_all(client: &Client, prefix: &str, paths: &[&str], exist: bool) { - paths.iter().for_each(|path| assert_file(client, prefix, path, exist)) + paths + .iter() + .for_each(|path| assert_file(client, prefix, path, exist)) } #[test] diff --git a/contrib/lib/tests/templates.rs b/contrib/lib/tests/templates.rs index 2c267715c3..f2f4c53099 100644 --- a/contrib/lib/tests/templates.rs +++ b/contrib/lib/tests/templates.rs @@ -1,7 +1,8 @@ #![feature(proc_macro_hygiene, decl_macro)] #[cfg(feature = "templates")] -#[macro_use] extern crate rocket; +#[macro_use] +extern crate rocket; #[cfg(feature = "templates")] extern crate rocket_contrib; @@ -10,47 +11,57 @@ extern crate rocket_contrib; mod templates_tests { use std::path::{Path, PathBuf}; - use rocket::{Rocket, http::RawStr}; use rocket::config::{Config, Environment}; - use rocket_contrib::templates::{Template, Metadata}; + use rocket::{http::RawStr, Rocket}; + use rocket_contrib::templates::{Metadata, Template}; #[get("//")] fn template_check(md: Metadata, engine: &RawStr, name: &RawStr) -> Option<()> { match md.contains_template(&format!("{}/{}", engine, name)) { true => Some(()), - false => None + false => None, } } #[get("/is_reloading")] fn is_reloading(md: Metadata) -> Option<()> { - if md.reloading() { Some(()) } else { None } + if md.reloading() { + Some(()) + } else { + None + } } fn template_root() -> PathBuf { - Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("templates") + Path::new(env!("CARGO_MANIFEST_DIR")) + .join("tests") + .join("templates") } fn rocket() -> Rocket { let config = Config::build(Environment::Development) - .extra("template_dir", template_root().to_str().expect("template directory")) + .extra( + "template_dir", + template_root().to_str().expect("template directory"), + ) .expect("valid configuration"); - ::rocket::custom(config).attach(Template::fairing()) + ::rocket::custom(config) + .attach(Template::fairing()) .mount("/", routes![template_check, is_reloading]) } #[cfg(feature = "tera_templates")] mod tera_tests { use super::*; - use std::collections::HashMap; use rocket::http::Status; use rocket::local::Client; + use std::collections::HashMap; - const UNESCAPED_EXPECTED: &'static str - = "\nh_start\ntitle: _test_\nh_end\n\n\n