diff --git a/Changelog.md b/Changelog.md index 8d5c064a..97a3bf24 100644 --- a/Changelog.md +++ b/Changelog.md @@ -12,6 +12,8 @@ ### New Features +[#581] - Added `EntityResolver` for resolving unknown entities that would otherwise cause the parser to return an [`Error::UnrecognizedSymbol`] error. + ### Bug Fixes ### Misc Changes diff --git a/src/de/map.rs b/src/de/map.rs index 8c13554f..960f2d37 100644 --- a/src/de/map.rs +++ b/src/de/map.rs @@ -9,6 +9,7 @@ use crate::{ events::attributes::IterState, events::BytesStart, name::QName, + resolver::EntityResolver, }; use serde::de::value::BorrowedStrDeserializer; use serde::de::{self, DeserializeSeed, SeqAccess, Visitor}; @@ -165,13 +166,14 @@ enum ValueSource { /// /// - `'a` lifetime represents a parent deserializer, which could own the data /// buffer. -pub(crate) struct MapAccess<'de, 'a, R> +pub(crate) struct MapAccess<'de, 'a, R, E> where R: XmlRead<'de>, + E: EntityResolver, { /// Tag -- owner of attributes start: BytesStart<'de>, - de: &'a mut Deserializer<'de, R>, + de: &'a mut Deserializer<'de, R, E>, /// State of the iterator over attributes. Contains the next position in the /// inner `start` slice, from which next attribute should be parsed. iter: IterState, @@ -190,13 +192,14 @@ where has_value_field: bool, } -impl<'de, 'a, R> MapAccess<'de, 'a, R> +impl<'de, 'a, R, E> MapAccess<'de, 'a, R, E> where R: XmlRead<'de>, + E: EntityResolver { /// Create a new MapAccess pub fn new( - de: &'a mut Deserializer<'de, R>, + de: &'a mut Deserializer<'de, R, E>, start: BytesStart<'de>, fields: &'static [&'static str], ) -> Result { @@ -211,9 +214,10 @@ where } } -impl<'de, 'a, R> de::MapAccess<'de> for MapAccess<'de, 'a, R> +impl<'de, 'a, R, E> de::MapAccess<'de> for MapAccess<'de, 'a, R, E> where R: XmlRead<'de>, + E: EntityResolver { type Error = DeError; @@ -369,13 +373,14 @@ macro_rules! forward { /// A deserializer for a value of map or struct. 
That deserializer slightly /// differently processes events for a primitive types and sequences than /// a [`Deserializer`]. -struct MapValueDeserializer<'de, 'a, 'm, R> +struct MapValueDeserializer<'de, 'a, 'm, R, E> where R: XmlRead<'de>, + E: EntityResolver { /// Access to the map that created this deserializer. Gives access to the /// context, such as list of fields, that current map known about. - map: &'m mut MapAccess<'de, 'a, R>, + map: &'m mut MapAccess<'de, 'a, R, E>, /// Determines, should [`Deserializer::read_string_impl()`] expand the second /// level of tags or not. /// @@ -453,9 +458,10 @@ where allow_start: bool, } -impl<'de, 'a, 'm, R> MapValueDeserializer<'de, 'a, 'm, R> +impl<'de, 'a, 'm, R, E> MapValueDeserializer<'de, 'a, 'm, R, E> where R: XmlRead<'de>, + E: EntityResolver { /// Returns a next string as concatenated content of consequent [`Text`] and /// [`CData`] events, used inside [`deserialize_primitives!()`]. @@ -468,9 +474,10 @@ where } } -impl<'de, 'a, 'm, R> de::Deserializer<'de> for MapValueDeserializer<'de, 'a, 'm, R> +impl<'de, 'a, 'm, R, E> de::Deserializer<'de> for MapValueDeserializer<'de, 'a, 'm, R, E> where R: XmlRead<'de>, + E: EntityResolver { type Error = DeError; @@ -629,13 +636,14 @@ impl<'de> TagFilter<'de> { /// /// [`Text`]: crate::events::Event::Text /// [`CData`]: crate::events::Event::CData -struct MapValueSeqAccess<'de, 'a, 'm, R> +struct MapValueSeqAccess<'de, 'a, 'm, R, E> where R: XmlRead<'de>, + E: EntityResolver { /// Accessor to a map that creates this accessor and to a deserializer for /// a sequence items. - map: &'m mut MapAccess<'de, 'a, R>, + map: &'m mut MapAccess<'de, 'a, R, E>, /// Filter that determines whether a tag is a part of this sequence. 
/// /// When feature `overlapped-lists` is not activated, iteration will stop @@ -662,9 +670,10 @@ where } } -impl<'de, 'a, 'm, R> SeqAccess<'de> for MapValueSeqAccess<'de, 'a, 'm, R> +impl<'de, 'a, 'm, R, E> SeqAccess<'de> for MapValueSeqAccess<'de, 'a, 'm, R, E> where R: XmlRead<'de>, + E: EntityResolver { type Error = DeError; @@ -705,18 +714,20 @@ where //////////////////////////////////////////////////////////////////////////////////////////////////// /// A deserializer for a single item of a sequence. -struct SeqItemDeserializer<'de, 'a, 'm, R> +struct SeqItemDeserializer<'de, 'a, 'm, R, E> where R: XmlRead<'de>, + E: EntityResolver { /// Access to the map that created this deserializer. Gives access to the /// context, such as list of fields, that current map known about. - map: &'m mut MapAccess<'de, 'a, R>, + map: &'m mut MapAccess<'de, 'a, R, E>, } -impl<'de, 'a, 'm, R> SeqItemDeserializer<'de, 'a, 'm, R> +impl<'de, 'a, 'm, R, E> SeqItemDeserializer<'de, 'a, 'm, R, E> where R: XmlRead<'de>, + E: EntityResolver { /// Returns a next string as concatenated content of consequent [`Text`] and /// [`CData`] events, used inside [`deserialize_primitives!()`]. @@ -729,9 +740,10 @@ where } } -impl<'de, 'a, 'm, R> de::Deserializer<'de> for SeqItemDeserializer<'de, 'a, 'm, R> +impl<'de, 'a, 'm, R, E> de::Deserializer<'de> for SeqItemDeserializer<'de, 'a, 'm, R, E> where R: XmlRead<'de>, + E: EntityResolver { type Error = DeError; diff --git a/src/de/mod.rs b/src/de/mod.rs index 5f41431b..7aaf39d3 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -1843,6 +1843,7 @@ use crate::{ events::{BytesCData, BytesEnd, BytesStart, BytesText, Event}, name::QName, reader::Reader, + resolver::{EntityResolver, NoEntityResolver}, }; use serde::de::{self, Deserialize, DeserializeOwned, DeserializeSeed, SeqAccess, Visitor}; use std::borrow::Cow; @@ -1960,7 +1961,7 @@ impl<'a> PayloadEvent<'a> { /// An intermediate reader that consumes [`PayloadEvent`]s and produces final [`DeEvent`]s. 
/// [`PayloadEvent::Text`] events, that followed by any event except /// [`PayloadEvent::Text`] or [`PayloadEvent::CData`], are trimmed from the end. -struct XmlReader<'i, R: XmlRead<'i>> { +struct XmlReader<'i, R: XmlRead<'i>, E: EntityResolver = NoEntityResolver> { /// A source of low-level XML events reader: R, /// Intermediate event, that could be returned by the next call to `next()`. @@ -1968,15 +1969,24 @@ struct XmlReader<'i, R: XmlRead<'i>> { /// trailing spaces is not. Before the event will be returned, trimming of /// the spaces could be necessary lookahead: Result, DeError>, + + /// Used to resolve unknown entities that would otherwise cause the parser + /// to return an [`Error::UnrecognizedSymbol`] error. + entity_resolver: E } -impl<'i, R: XmlRead<'i>> XmlReader<'i, R> { - fn new(mut reader: R) -> Self { +impl<'i, R: XmlRead<'i>, E: EntityResolver> XmlReader<'i, R, E> { + fn new(reader: R) -> Self + where E: Default { + Self::with_resolver(reader, E::default()) + } + + fn with_resolver(mut reader: R, entity_resolver: E) -> Self { // Lookahead by one event immediately, so we do not need to check in the // loop if we need lookahead or not let lookahead = reader.next(); - Self { reader, lookahead } + Self { reader, lookahead, entity_resolver } } /// Read next event and put it in lookahead, return the current lookahead @@ -2032,7 +2042,7 @@ impl<'i, R: XmlRead<'i>> XmlReader<'i, R> { if self.need_trim_end() { e.inplace_trim_end(); } - Ok(e.unescape()?) + Ok(e.unescape_with(|e| self.entity_resolver.resolve(e))?) } PayloadEvent::CData(e) => Ok(e.decode()?), @@ -2051,10 +2061,13 @@ impl<'i, R: XmlRead<'i>> XmlReader<'i, R> { if self.need_trim_end() && e.inplace_trim_end() { continue; } - self.drain_text(e.unescape()?) + self.drain_text(e.unescape_with(|ent| self.entity_resolver.resolve(ent))?) 
} PayloadEvent::CData(e) => self.drain_text(e.decode()?), - PayloadEvent::DocType(_) => continue, + PayloadEvent::DocType(e) => { + self.entity_resolver.capture(e); + continue + }, PayloadEvent::Eof => Ok(DeEvent::Eof), }; } @@ -2171,12 +2184,12 @@ where //////////////////////////////////////////////////////////////////////////////////////////////////// /// A structure that deserializes XML into Rust values. -pub struct Deserializer<'de, R> +pub struct Deserializer<'de, R, E: EntityResolver = NoEntityResolver> where R: XmlRead<'de>, { /// An XML reader that streams events into this deserializer - reader: XmlReader<'de, R>, + reader: XmlReader<'de, R, E>, /// When deserializing sequences sometimes we have to skip unwanted events. /// That events should be stored and then replayed. This is a replay buffer, @@ -2231,7 +2244,13 @@ where peek: None, } } +} +impl<'de, R, E> Deserializer<'de, R, E> +where + R: XmlRead<'de>, + E: EntityResolver +{ /// Set the maximum number of events that could be skipped during deserialization /// of sequences. /// @@ -2561,20 +2580,50 @@ where /// instead, because it will borrow instead of copy. If you have `&[u8]` which /// is known to represent UTF-8, you can decode it first before using [`from_str`]. pub fn from_reader(reader: R) -> Self { + Self::with_resolver(reader, NoEntityResolver) + } +} + + +impl<'de, R, E> Deserializer<'de, IoReader, E> +where + R: BufRead, + E: EntityResolver, +{ + /// Create new deserializer that will copy data from the specified reader + /// into internal buffer. If you already have a string use [`Self::from_str`] + /// instead, because it will borrow instead of copy. If you have `&[u8]` which + /// is known to represent UTF-8, you can decode it first before using [`from_str`]. 
+ pub fn with_resolver(reader: R, entity_resolver: E) -> Self { let mut reader = Reader::from_reader(reader); reader.expand_empty_elements(true).check_end_names(true); - Self::new(IoReader { + let io_reader = IoReader { reader, start_trimmer: StartTrimmer::default(), buf: Vec::new(), - }) + }; + + Self { + reader: XmlReader::with_resolver(io_reader, entity_resolver), + + #[cfg(feature = "overlapped-lists")] + read: VecDeque::new(), + #[cfg(feature = "overlapped-lists")] + write: VecDeque::new(), + #[cfg(feature = "overlapped-lists")] + limit: None, + + #[cfg(not(feature = "overlapped-lists"))] + peek: None, + } } } -impl<'de, 'a, R> de::Deserializer<'de> for &'a mut Deserializer<'de, R> +impl<'de, 'a, R, E> de::Deserializer<'de> for &'a mut Deserializer<'de, R, E> where R: XmlRead<'de>, + E: EntityResolver { type Error = DeError; @@ -2710,9 +2759,10 @@ where /// /// Technically, multiple top-level elements violates XML rule of only one top-level /// element, but we consider this as several concatenated XML documents. 
-impl<'de, 'a, R> SeqAccess<'de> for &'a mut Deserializer<'de, R> +impl<'de, 'a, R, E> SeqAccess<'de> for &'a mut Deserializer<'de, R, E> where R: XmlRead<'de>, + E: EntityResolver { type Error = DeError; diff --git a/src/de/var.rs b/src/de/var.rs index 32295258..ed26dac8 100644 --- a/src/de/var.rs +++ b/src/de/var.rs @@ -3,35 +3,39 @@ use crate::{ de::simple_type::SimpleTypeDeserializer, de::{DeEvent, Deserializer, XmlRead, TEXT_KEY}, errors::serialize::DeError, + resolver::EntityResolver, }; use serde::de::value::BorrowedStrDeserializer; use serde::de::{self, DeserializeSeed, Deserializer as _, Visitor}; /// An enum access -pub struct EnumAccess<'de, 'a, R> +pub struct EnumAccess<'de, 'a, R, E> where R: XmlRead<'de>, + E: EntityResolver { - de: &'a mut Deserializer<'de, R>, + de: &'a mut Deserializer<'de, R, E>, } -impl<'de, 'a, R> EnumAccess<'de, 'a, R> +impl<'de, 'a, R, E> EnumAccess<'de, 'a, R, E> where R: XmlRead<'de>, + E: EntityResolver { - pub fn new(de: &'a mut Deserializer<'de, R>) -> Self { + pub fn new(de: &'a mut Deserializer<'de, R, E>) -> Self { EnumAccess { de } } } -impl<'de, 'a, R> de::EnumAccess<'de> for EnumAccess<'de, 'a, R> +impl<'de, 'a, R, E> de::EnumAccess<'de> for EnumAccess<'de, 'a, R, E> where R: XmlRead<'de>, + E: EntityResolver { type Error = DeError; - type Variant = VariantAccess<'de, 'a, R>; + type Variant = VariantAccess<'de, 'a, R, E>; - fn variant_seed(self, seed: V) -> Result<(V::Value, VariantAccess<'de, 'a, R>), DeError> + fn variant_seed(self, seed: V) -> Result<(V::Value, VariantAccess<'de, 'a, R, E>), DeError> where V: DeserializeSeed<'de>, { @@ -58,19 +62,21 @@ where } } -pub struct VariantAccess<'de, 'a, R> +pub struct VariantAccess<'de, 'a, R, E> where R: XmlRead<'de>, + E: EntityResolver { - de: &'a mut Deserializer<'de, R>, + de: &'a mut Deserializer<'de, R, E>, /// `true` if variant should be deserialized from a textual content /// and `false` if from tag is_text: bool, } -impl<'de, 'a, R> de::VariantAccess<'de> for 
VariantAccess<'de, 'a, R> +impl<'de, 'a, R, E> de::VariantAccess<'de> for VariantAccess<'de, 'a, R, E> where R: XmlRead<'de>, + E: EntityResolver { type Error = DeError; diff --git a/src/lib.rs b/src/lib.rs index 5d5d51c4..6cf97cc6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -61,6 +61,7 @@ pub mod escape { pub mod events; pub mod name; pub mod reader; +pub mod resolver; #[cfg(feature = "serialize")] pub mod se; /// Not an official API, public for integration tests diff --git a/src/resolver.rs b/src/resolver.rs new file mode 100644 index 00000000..977ebcc9 --- /dev/null +++ b/src/resolver.rs @@ -0,0 +1,78 @@ +//! Entity resolver module + +use crate::events::BytesText; + +/// Used to resolve unknown entities while parsing +/// ``` +/// # use std::collections::BTreeMap; +/// # use std::borrow::Cow; +/// # use std::io::{BufReader, Cursor}; +/// # use serde::Deserialize; +/// use quick_xml::{ +/// de::Deserializer, +/// resolver::EntityResolver, +/// events::BytesText +/// }; +/// +/// #[derive(Default)] +/// struct DocTypeEntityResolver(BTreeMap); +/// +/// impl EntityResolver for DocTypeEntityResolver { +/// fn capture(&mut self, doctype: BytesText) { +/// fn extract_entities(doctype: BytesText) -> impl IntoIterator { +/// // "Implement doctype parsing to extract entities" +/// [(String::from("e1"), String::from("entity 1"))] +/// } +/// // Doctype parsing is not included in this library +/// let entities = extract_entities(doctype); +/// for (name, value) in entities.into_iter() { +/// self.0.insert(name, value); +/// } +/// } +/// +/// fn resolve(&self, entity: &str) -> Option<&str> { +/// self.0.get(entity).map(|s| s.as_str()) +/// } +/// } +/// +/// let xml_reader = BufReader::new(Cursor::new( +/// r#" +/// ]> +/// +/// &e1; +/// "#.as_bytes() +/// )); +/// +/// let mut de = Deserializer::with_resolver( +/// xml_reader, +/// DocTypeEntityResolver::default() +/// ); +/// let data: BTreeMap> = BTreeMap::deserialize(&mut de).unwrap(); +/// +/// 
assert_eq!(data.get("entity_one").as_ref(), "entity 1") +/// +/// ``` +pub trait EntityResolver { + /// Called on contents of [`Event::DocType`] to capture declared entities. + /// Can be called multiple times, for each parsed `` declaration. + fn capture(&mut self, doctype: BytesText); + + /// Called when an entity needs to be resolved. + /// + /// `None` is returned if a suitable value can not be found. + /// In that case an [`Error::UnrecognizedSymbol`] will be returned. + fn resolve(&self, entity: &str) -> Option<&str>; +} + +/// An EntityResolver that always returns None. +#[derive(Default, Copy, Clone)] +pub struct NoEntityResolver; + +impl EntityResolver for NoEntityResolver{ + fn capture(&mut self, _: BytesText){ + } + + fn resolve<'entity>(&self, _: &str) -> Option<&str> { + None + } +} diff --git a/tests/serde-de.rs b/tests/serde-de.rs index 907a2cc6..a90c2bbb 100644 --- a/tests/serde-de.rs +++ b/tests/serde-de.rs @@ -6427,3 +6427,71 @@ mod borrow { ); } } + + +/// Test for entity resolver +mod resolve { + use std::collections::BTreeMap; + use std::iter::FromIterator; + + use super::{Deserialize, Deserializer}; + use quick_xml::{ + resolver::EntityResolver, + }; + struct TestEntityResolver { + capture_called: bool + } + + impl EntityResolver for TestEntityResolver { + fn capture(&mut self, doctype: quick_xml::events::BytesText) { + self.capture_called = true; + let as_cow = doctype.into_inner(); + let str_doc = String::from_utf8_lossy(as_cow.as_ref()); + assert_eq!(str_doc.as_ref(), r#"dict[ ]"#); + } + + fn resolve(&self, entity: &str) -> Option<&str> { + assert!(self.capture_called); + match entity { + "t1" => Some("test_one"), + "t2" => Some("test_two"), + _ => None + } + } + } + + #[test] + fn resolve_custom_entity() { + use std::io::{Cursor, BufReader}; + use std::borrow::Cow; + + let reader = BufReader::new(Cursor::new( + r#" + ]> + + + &t1; + &t2; + non-entity + "#.as_bytes() + )); + let resolver = TestEntityResolver{ + capture_called: false + }; + 
let mut de = Deserializer::with_resolver( + reader, + resolver + ); + + let data: BTreeMap> = BTreeMap::deserialize(&mut de).unwrap(); + assert_eq!( + data, + BTreeMap::from_iter([ + // Comment to prevent formatting in one line + (String::from("entity_one"), Cow::Owned("test_one".into())), + (String::from("entity_two"), Cow::Owned("test_two".into())), + (String::from("entity_three"), Cow::Borrowed("non-entity")), + ]) + ); + } +} \ No newline at end of file