diff --git a/examples/http_headers.rs b/examples/http_headers.rs index e50d0a69..aa071b39 100644 --- a/examples/http_headers.rs +++ b/examples/http_headers.rs @@ -35,11 +35,11 @@ impl Context for HttpHeaders {} impl HttpContext for HttpHeaders { fn on_http_request_headers(&mut self, _: usize) -> Action { for (name, value) in &self.get_http_request_headers() { - trace!("#{} -> {}: {:?}", self.context_id, name, value); + trace!("#{} -> {}: {}", self.context_id, name, value); } match self.get_http_request_header(":path") { - Some(path) if path == b"/hello" => { + Some(path) if path == "/hello" => { self.send_http_response( 200, vec![("Hello", "World"), ("Powered-By", "proxy-wasm")], @@ -53,7 +53,7 @@ impl HttpContext for HttpHeaders { fn on_http_response_headers(&mut self, _: usize) -> Action { for (name, value) in &self.get_http_response_headers() { - trace!("#{} <- {}: {:?}", self.context_id, name, value); + trace!("#{} <- {}: {}", self.context_id, name, value); } Action::Continue } diff --git a/src/hostcalls.rs b/src/hostcalls.rs index 198cc060..e1bc08da 100644 --- a/src/hostcalls.rs +++ b/src/hostcalls.rs @@ -181,7 +181,7 @@ extern "C" { } /// Returns all key-value pairs from a given map. -pub fn get_map(map_type: MapType) -> Result> { +pub fn get_map(map_type: MapType) -> Result> { unsafe { let mut return_data: *mut u8 = null_mut(); let mut return_size: usize = 0; @@ -264,7 +264,7 @@ extern "C" { /// # Ok(()) /// # } /// ``` -pub fn get_map_value(map_type: MapType, key: K) -> Result> +pub fn get_map_value(map_type: MapType, key: K) -> Result> where K: AsRef, { @@ -280,7 +280,9 @@ where ) { Status::Ok => { if !return_data.is_null() { - Ok(Vec::from_raw_parts(return_data, return_size, return_size)).map(Option::from) + Ok(Vec::from_raw_parts(return_data, return_size, return_size)) + .map(HeaderValue::from) + .map(Option::from) } else { Ok(None) } @@ -923,7 +925,7 @@ pub fn done() -> Result<()> { mod utils { use crate::error::Result; - use crate::types::Bytes; + use crate::types::{Bytes, HeaderValue}; use std::convert::TryFrom; pub(super) fn serialize_property_path

(path: &[P]) -> Bytes @@ -970,7 +972,7 @@ mod utils { bytes } - pub(super) fn deserialize_map(bytes: &[u8]) -> Result> { + pub(super) fn deserialize_map(bytes: &[u8]) -> Result> { let mut map = Vec::new(); if bytes.is_empty() { return Ok(map); @@ -985,7 +987,7 @@ mod utils { let size = u32::from_le_bytes(<[u8; 4]>::try_from(&bytes[s + 4..s + 8])?) as usize; let value = bytes[p..p + size].to_vec(); p += size + 1; - map.push((String::from_utf8(key)?, value)); + map.push((String::from_utf8(key)?, value.into())); } Ok(map) } diff --git a/src/traits.rs b/src/traits.rs index 42ecb1ef..58d0d55f 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -75,7 +75,7 @@ pub trait Context { ) { } - fn get_http_call_response_headers(&self) -> Vec<(String, Bytes)> { + fn get_http_call_response_headers(&self) -> Vec<(String, HeaderValue)> { hostcalls::get_map(MapType::HttpCallResponseHeaders).unwrap() } @@ -83,7 +83,7 @@ pub trait Context { hostcalls::get_buffer(BufferType::HttpCallResponseBody, start, max_size).unwrap() } - fn get_http_call_response_trailers(&self) -> Vec<(String, Bytes)> { + fn get_http_call_response_trailers(&self) -> Vec<(String, HeaderValue)> { hostcalls::get_map(MapType::HttpCallResponseTrailers).unwrap() } @@ -165,7 +165,7 @@ pub trait HttpContext: Context { Action::Continue } - fn get_http_request_headers(&self) -> Vec<(String, Bytes)> { + fn get_http_request_headers(&self) -> Vec<(String, HeaderValue)> { hostcalls::get_map(MapType::HttpRequestHeaders).unwrap() } @@ -173,7 +173,7 @@ pub trait HttpContext: Context { hostcalls::set_map(MapType::HttpRequestHeaders, &headers).unwrap() } - fn get_http_request_header(&self, name: &str) -> Option { + fn get_http_request_header(&self, name: &str) -> Option { hostcalls::get_map_value(MapType::HttpRequestHeaders, &name).unwrap() } @@ -197,7 +197,7 @@ pub trait HttpContext: Context { Action::Continue } - fn get_http_request_trailers(&self) -> Vec<(String, Bytes)> { + fn get_http_request_trailers(&self) -> Vec<(String, HeaderValue)> { hostcalls::get_map(MapType::HttpRequestTrailers).unwrap() } @@ -205,7 +205,7 @@ pub trait HttpContext: Context { hostcalls::set_map(MapType::HttpRequestTrailers, &trailers).unwrap() } - fn get_http_request_trailer(&self, name: &str) -> Option { + fn get_http_request_trailer(&self, name: &str) -> Option { hostcalls::get_map_value(MapType::HttpRequestTrailers, &name).unwrap() } @@ -225,7 +225,7 @@ pub trait HttpContext: Context { Action::Continue } - fn get_http_response_headers(&self) -> Vec<(String, Bytes)> { + fn get_http_response_headers(&self) -> Vec<(String, HeaderValue)> { hostcalls::get_map(MapType::HttpResponseHeaders).unwrap() } @@ -233,7 +233,7 @@ pub trait HttpContext: Context { hostcalls::set_map(MapType::HttpResponseHeaders, &headers).unwrap() } - fn get_http_response_header(&self, name: &str) -> Option { + fn get_http_response_header(&self, name: &str) -> Option { hostcalls::get_map_value(MapType::HttpResponseHeaders, &name).unwrap() } @@ -257,7 +257,7 @@ pub trait HttpContext: Context { Action::Continue } - fn get_http_response_trailers(&self) -> Vec<(String, Bytes)> { + fn get_http_response_trailers(&self) -> Vec<(String, HeaderValue)> { hostcalls::get_map(MapType::HttpResponseTrailers).unwrap() } @@ -265,7 +265,7 @@ pub trait HttpContext: Context { hostcalls::set_map(MapType::HttpResponseTrailers, &headers).unwrap() } - fn get_http_response_trailer(&self, name: &str) -> Option { + fn get_http_response_trailer(&self, name: &str) -> Option { hostcalls::get_map_value(MapType::HttpResponseTrailers, &name).unwrap() } diff --git a/src/types.rs b/src/types.rs index a287a2c1..1041c847 100644 --- a/src/types.rs +++ b/src/types.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::fmt; +use std::hash::{Hash, Hasher}; + use crate::traits::*; pub type NewRootContext = fn(context_id: u32) -> Box; @@ -77,3 +80,249 @@ pub enum PeerType { } pub type Bytes = Vec; + +/// Represents an HTTP header value that is not necessarily a UTF-8 encoded string. +#[derive(Eq)] +pub struct HeaderValue { + inner: Result, +} + +impl HeaderValue { + fn new(inner: Result) -> Self { + HeaderValue { inner } + } + + pub fn into_vec(self) -> Vec { + match self.inner { + Ok(string) => string.into_bytes(), + Err(bytes) => bytes, + } + } + + pub fn into_string_or_vec(self) -> Result> { + self.inner + } +} + +impl From> for HeaderValue { + #[inline] + fn from(data: Vec) -> Self { + Self::new(match String::from_utf8(data) { + Ok(string) => Ok(string), + Err(err) => Err(err.into_bytes()), + }) + } +} + +impl From for HeaderValue { + #[inline] + fn from(string: String) -> Self { + Self::new(Ok(string)) + } +} + +impl fmt::Display for HeaderValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.inner { + Ok(ref string) => fmt::Display::fmt(string, f), + Err(ref bytes) => fmt::Debug::fmt(bytes, f), + } + } +} + +impl fmt::Debug for HeaderValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.inner { + Ok(ref string) => fmt::Debug::fmt(string, f), + Err(ref bytes) => fmt::Debug::fmt(bytes, f), + } + } +} + +impl AsRef<[u8]> for HeaderValue { + #[inline] + fn as_ref(&self) -> &[u8] { + match self.inner { + Ok(ref string) => string.as_bytes(), + Err(ref bytes) => bytes.as_slice(), + } + } +} + +impl PartialEq for HeaderValue { + #[inline] + fn eq(&self, other: &HeaderValue) -> bool { + self.inner == other.inner + } +} + +impl PartialEq for HeaderValue { + #[inline] + fn eq(&self, other: &String) -> bool { + self.as_ref() == other.as_bytes() + } +} + +impl PartialEq> for HeaderValue { + #[inline] + fn eq(&self, other: &Vec) -> bool { + self.as_ref() == other.as_slice() + } +} + +impl PartialEq<&str> for HeaderValue { + #[inline] + fn eq(&self, other: &&str) -> bool { + self.as_ref() == other.as_bytes() + } +} + +impl PartialEq<&[u8]> for HeaderValue { + #[inline] + fn eq(&self, other: &&[u8]) -> bool { + self.as_ref() == *other + } +} + +impl PartialEq for String { + #[inline] + fn eq(&self, other: &HeaderValue) -> bool { + *other == *self + } +} + +impl PartialEq for Vec { + #[inline] + fn eq(&self, other: &HeaderValue) -> bool { + *other == *self + } +} + +impl PartialEq for &str { + #[inline] + fn eq(&self, other: &HeaderValue) -> bool { + *other == *self + } +} + +impl PartialEq for &[u8] { + #[inline] + fn eq(&self, other: &HeaderValue) -> bool { + *other == *self + } +} + +impl Hash for HeaderValue { + #[inline] + fn hash(&self, state: &mut H) { + match self.inner { + Ok(ref string) => Hash::hash(string, state), + Err(ref bytes) => Hash::hash(bytes, state), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + #[test] + fn test_header_value_display_string() { + let string: HeaderValue = String::from("utf-8 encoded string").into(); + + assert_eq!(format!("{}", string), "utf-8 encoded string"); + } + + #[test] + fn test_header_value_display_bytes() { + let bytes: HeaderValue = vec![144u8, 145u8, 146u8].into(); + + assert_eq!(format!("{}", bytes), "[144, 145, 146]"); + } + + #[test] + fn test_header_value_debug_string() { + let string: HeaderValue = String::from("utf-8 encoded string").into(); + + assert_eq!(format!("{:?}", string), "\"utf-8 encoded string\""); + } + + #[test] + fn test_header_value_debug_bytes() { + let bytes: HeaderValue = vec![144u8, 145u8, 146u8].into(); + + assert_eq!(format!("{:?}", bytes), "[144, 145, 146]"); + } + + #[test] + fn test_header_value_as_ref() { + fn receive(value: T) + where + T: AsRef<[u8]>, + { + value.as_ref(); + } + + let string: HeaderValue = String::from("utf-8 encoded string").into(); + receive(string); + + let bytes: HeaderValue = vec![144u8, 145u8, 146u8].into(); + receive(bytes); + } + + #[test] + fn test_header_value_eq_string() { + let string: HeaderValue = String::from("utf-8 encoded string").into(); + + assert_eq!(string, "utf-8 encoded string"); + assert_eq!(string, b"utf-8 encoded string" as &[u8]); + + assert_eq!("utf-8 encoded string", string); + assert_eq!(b"utf-8 encoded string" as &[u8], string); + + assert_eq!(string, string); + } + + #[test] + fn test_header_value_eq_bytes() { + let bytes: HeaderValue = vec![144u8, 145u8, 146u8].into(); + + assert_eq!(bytes, vec![144u8, 145u8, 146u8]); + assert_eq!(bytes, b"\x90\x91\x92" as &[u8]); + + assert_eq!(vec![144u8, 145u8, 146u8], bytes); + assert_eq!(b"\x90\x91\x92" as &[u8], bytes); + + assert_eq!(bytes, bytes); + } + + fn hash(t: &T) -> u64 { + let mut h = DefaultHasher::new(); + t.hash(&mut h); + h.finish() + } + + #[test] + fn test_header_value_hash_string() { + let string: HeaderValue = String::from("utf-8 encoded string").into(); + + assert_eq!(hash(&string), hash(&"utf-8 encoded string")); + assert_ne!(hash(&string), hash(&b"utf-8 encoded string")); + + assert_eq!(hash(&"utf-8 encoded string"), hash(&string)); + assert_ne!(hash(&b"utf-8 encoded string"), hash(&string)); + } + + #[test] + fn test_header_value_hash_bytes() { + let bytes: HeaderValue = vec![144u8, 145u8, 146u8].into(); + + assert_eq!(hash(&bytes), hash(&vec![144u8, 145u8, 146u8])); + assert_eq!(hash(&bytes), hash(&[144u8, 145u8, 146u8])); + + assert_eq!(hash(&vec![144u8, 145u8, 146u8]), hash(&bytes)); + assert_eq!(hash(&[144u8, 145u8, 146u8]), hash(&bytes)); + } +}