Skip to content

Commit

Permalink
Added EntityResolver
Browse files Browse the repository at this point in the history
  • Loading branch information
pigeonhands committed Mar 30, 2023
1 parent bbf536f commit d9f2446
Show file tree
Hide file tree
Showing 7 changed files with 256 additions and 39 deletions.
2 changes: 2 additions & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

### New Features

[#581] - Added `EntityResolver` for resolving unknown entities that would otherwise cause the parser to reurn an [`Error::UnrecognizedSymbol`] error.

### Bug Fixes

### Misc Changes
Expand Down
44 changes: 28 additions & 16 deletions src/de/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::{
events::attributes::IterState,
events::BytesStart,
name::QName,
resolver::EntityResolver,
};
use serde::de::value::BorrowedStrDeserializer;
use serde::de::{self, DeserializeSeed, SeqAccess, Visitor};
Expand Down Expand Up @@ -165,13 +166,14 @@ enum ValueSource {
///
/// - `'a` lifetime represents a parent deserializer, which could own the data
/// buffer.
pub(crate) struct MapAccess<'de, 'a, R>
pub(crate) struct MapAccess<'de, 'a, R, E>
where
R: XmlRead<'de>,
E: EntityResolver,
{
/// Tag -- owner of attributes
start: BytesStart<'de>,
de: &'a mut Deserializer<'de, R>,
de: &'a mut Deserializer<'de, R, E>,
/// State of the iterator over attributes. Contains the next position in the
/// inner `start` slice, from which next attribute should be parsed.
iter: IterState,
Expand All @@ -190,13 +192,14 @@ where
has_value_field: bool,
}

impl<'de, 'a, R> MapAccess<'de, 'a, R>
impl<'de, 'a, R, E> MapAccess<'de, 'a, R, E>
where
R: XmlRead<'de>,
E: EntityResolver
{
/// Create a new MapAccess
pub fn new(
de: &'a mut Deserializer<'de, R>,
de: &'a mut Deserializer<'de, R, E>,
start: BytesStart<'de>,
fields: &'static [&'static str],
) -> Result<Self, DeError> {
Expand All @@ -211,9 +214,10 @@ where
}
}

impl<'de, 'a, R> de::MapAccess<'de> for MapAccess<'de, 'a, R>
impl<'de, 'a, R, E> de::MapAccess<'de> for MapAccess<'de, 'a, R, E>
where
R: XmlRead<'de>,
E: EntityResolver
{
type Error = DeError;

Expand Down Expand Up @@ -369,13 +373,14 @@ macro_rules! forward {
/// A deserializer for a value of map or struct. That deserializer slightly
/// differently processes events for a primitive types and sequences than
/// a [`Deserializer`].
struct MapValueDeserializer<'de, 'a, 'm, R>
struct MapValueDeserializer<'de, 'a, 'm, R, E>
where
R: XmlRead<'de>,
E: EntityResolver
{
/// Access to the map that created this deserializer. Gives access to the
/// context, such as list of fields, that current map known about.
map: &'m mut MapAccess<'de, 'a, R>,
map: &'m mut MapAccess<'de, 'a, R, E>,
/// Determines, should [`Deserializer::read_string_impl()`] expand the second
/// level of tags or not.
///
Expand Down Expand Up @@ -453,9 +458,10 @@ where
allow_start: bool,
}

impl<'de, 'a, 'm, R> MapValueDeserializer<'de, 'a, 'm, R>
impl<'de, 'a, 'm, R, E> MapValueDeserializer<'de, 'a, 'm, R, E>
where
R: XmlRead<'de>,
E: EntityResolver
{
/// Returns a next string as concatenated content of consequent [`Text`] and
/// [`CData`] events, used inside [`deserialize_primitives!()`].
Expand All @@ -468,9 +474,10 @@ where
}
}

impl<'de, 'a, 'm, R> de::Deserializer<'de> for MapValueDeserializer<'de, 'a, 'm, R>
impl<'de, 'a, 'm, R, E> de::Deserializer<'de> for MapValueDeserializer<'de, 'a, 'm, R, E>
where
R: XmlRead<'de>,
E: EntityResolver
{
type Error = DeError;

Expand Down Expand Up @@ -629,13 +636,14 @@ impl<'de> TagFilter<'de> {
///
/// [`Text`]: crate::events::Event::Text
/// [`CData`]: crate::events::Event::CData
struct MapValueSeqAccess<'de, 'a, 'm, R>
struct MapValueSeqAccess<'de, 'a, 'm, R, E>
where
R: XmlRead<'de>,
E: EntityResolver
{
/// Accessor to a map that creates this accessor and to a deserializer for
/// a sequence items.
map: &'m mut MapAccess<'de, 'a, R>,
map: &'m mut MapAccess<'de, 'a, R, E>,
/// Filter that determines whether a tag is a part of this sequence.
///
/// When feature `overlapped-lists` is not activated, iteration will stop
Expand All @@ -662,9 +670,10 @@ where
}
}

impl<'de, 'a, 'm, R> SeqAccess<'de> for MapValueSeqAccess<'de, 'a, 'm, R>
impl<'de, 'a, 'm, R, E> SeqAccess<'de> for MapValueSeqAccess<'de, 'a, 'm, R, E>
where
R: XmlRead<'de>,
E: EntityResolver
{
type Error = DeError;

Expand Down Expand Up @@ -705,18 +714,20 @@ where
////////////////////////////////////////////////////////////////////////////////////////////////////

/// A deserializer for a single item of a sequence.
struct SeqItemDeserializer<'de, 'a, 'm, R>
struct SeqItemDeserializer<'de, 'a, 'm, R, E>
where
R: XmlRead<'de>,
E: EntityResolver
{
/// Access to the map that created this deserializer. Gives access to the
/// context, such as list of fields, that current map known about.
map: &'m mut MapAccess<'de, 'a, R>,
map: &'m mut MapAccess<'de, 'a, R, E>,
}

impl<'de, 'a, 'm, R> SeqItemDeserializer<'de, 'a, 'm, R>
impl<'de, 'a, 'm, R, E> SeqItemDeserializer<'de, 'a, 'm, R, E>
where
R: XmlRead<'de>,
E: EntityResolver
{
/// Returns a next string as concatenated content of consequent [`Text`] and
/// [`CData`] events, used inside [`deserialize_primitives!()`].
Expand All @@ -729,9 +740,10 @@ where
}
}

impl<'de, 'a, 'm, R> de::Deserializer<'de> for SeqItemDeserializer<'de, 'a, 'm, R>
impl<'de, 'a, 'm, R, E> de::Deserializer<'de> for SeqItemDeserializer<'de, 'a, 'm, R, E>
where
R: XmlRead<'de>,
E: EntityResolver
{
type Error = DeError;

Expand Down
76 changes: 63 additions & 13 deletions src/de/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1843,6 +1843,7 @@ use crate::{
events::{BytesCData, BytesEnd, BytesStart, BytesText, Event},
name::QName,
reader::Reader,
resolver::{EntityResolver, NoEntityResolver},
};
use serde::de::{self, Deserialize, DeserializeOwned, DeserializeSeed, SeqAccess, Visitor};
use std::borrow::Cow;
Expand Down Expand Up @@ -1960,23 +1961,32 @@ impl<'a> PayloadEvent<'a> {
/// An intermediate reader that consumes [`PayloadEvent`]s and produces final [`DeEvent`]s.
/// [`PayloadEvent::Text`] events, that followed by any event except
/// [`PayloadEvent::Text`] or [`PayloadEvent::CData`], are trimmed from the end.
struct XmlReader<'i, R: XmlRead<'i>> {
struct XmlReader<'i, R: XmlRead<'i>, E: EntityResolver = NoEntityResolver> {
/// A source of low-level XML events
reader: R,
/// Intermediate event, that could be returned by the next call to `next()`.
/// If that is the `Text` event then leading spaces already trimmed, but
/// trailing spaces is not. Before the event will be returned, trimming of
/// the spaces could be necessary
lookahead: Result<PayloadEvent<'i>, DeError>,

/// Used to resolve unknown entities that would otherwise cause the parser
/// to return an [`Error::UnrecognizedSymbol`] error.
entity_resolver: E
}

impl<'i, R: XmlRead<'i>> XmlReader<'i, R> {
fn new(mut reader: R) -> Self {
impl<'i, R: XmlRead<'i>, E: EntityResolver> XmlReader<'i, R, E> {
fn new(reader: R) -> Self
where E: Default {
Self::with_resolver(reader, E::default())
}

fn with_resolver(mut reader: R, entity_resolver: E) -> Self {
// Lookahead by one event immediately, so we do not need to check in the
// loop if we need lookahead or not
let lookahead = reader.next();

Self { reader, lookahead }
Self { reader, lookahead, entity_resolver }
}

/// Read next event and put it in lookahead, return the current lookahead
Expand Down Expand Up @@ -2032,7 +2042,7 @@ impl<'i, R: XmlRead<'i>> XmlReader<'i, R> {
if self.need_trim_end() {
e.inplace_trim_end();
}
Ok(e.unescape()?)
Ok(e.unescape_with(|e| self.entity_resolver.resolve(e))?)
}
PayloadEvent::CData(e) => Ok(e.decode()?),

Expand All @@ -2051,10 +2061,13 @@ impl<'i, R: XmlRead<'i>> XmlReader<'i, R> {
if self.need_trim_end() && e.inplace_trim_end() {
continue;
}
self.drain_text(e.unescape()?)
self.drain_text(e.unescape_with(|ent| self.entity_resolver.resolve(ent))?)
}
PayloadEvent::CData(e) => self.drain_text(e.decode()?),
PayloadEvent::DocType(_) => continue,
PayloadEvent::DocType(e) => {
self.entity_resolver.capture(e);
continue
},
PayloadEvent::Eof => Ok(DeEvent::Eof),
};
}
Expand Down Expand Up @@ -2171,12 +2184,12 @@ where
////////////////////////////////////////////////////////////////////////////////////////////////////

/// A structure that deserializes XML into Rust values.
pub struct Deserializer<'de, R>
pub struct Deserializer<'de, R, E: EntityResolver = NoEntityResolver>
where
R: XmlRead<'de>,
{
/// An XML reader that streams events into this deserializer
reader: XmlReader<'de, R>,
reader: XmlReader<'de, R, E>,

/// When deserializing sequences sometimes we have to skip unwanted events.
/// That events should be stored and then replayed. This is a replay buffer,
Expand Down Expand Up @@ -2231,7 +2244,13 @@ where
peek: None,
}
}
}

impl<'de, R, E> Deserializer<'de, R, E>
where
R: XmlRead<'de>,
E: EntityResolver
{
/// Set the maximum number of events that could be skipped during deserialization
/// of sequences.
///
Expand Down Expand Up @@ -2561,20 +2580,50 @@ where
/// instead, because it will borrow instead of copy. If you have `&[u8]` which
/// is known to represent UTF-8, you can decode it first before using [`from_str`].
pub fn from_reader(reader: R) -> Self {
Self::with_resolver(reader, NoEntityResolver)
}
}


impl<'de, R, E> Deserializer<'de, IoReader<R>, E>
where
R: BufRead,
E: EntityResolver,
{
/// Create new deserializer that will copy data from the specified reader
/// into internal buffer. If you already have a string use [`Self::from_str`]
/// instead, because it will borrow instead of copy. If you have `&[u8]` which
/// is known to represent UTF-8, you can decode it first before using [`from_str`].
pub fn with_resolver(reader: R, entity_resolver: E) -> Self {
let mut reader = Reader::from_reader(reader);
reader.expand_empty_elements(true).check_end_names(true);

Self::new(IoReader {
let io_reader = IoReader {
reader,
start_trimmer: StartTrimmer::default(),
buf: Vec::new(),
})
};

Self {
reader: XmlReader::with_resolver(io_reader, entity_resolver),

#[cfg(feature = "overlapped-lists")]
read: VecDeque::new(),
#[cfg(feature = "overlapped-lists")]
write: VecDeque::new(),
#[cfg(feature = "overlapped-lists")]
limit: None,

#[cfg(not(feature = "overlapped-lists"))]
peek: None,
}
}
}

impl<'de, 'a, R> de::Deserializer<'de> for &'a mut Deserializer<'de, R>
impl<'de, 'a, R, E> de::Deserializer<'de> for &'a mut Deserializer<'de, R, E>
where
R: XmlRead<'de>,
E: EntityResolver
{
type Error = DeError;

Expand Down Expand Up @@ -2710,9 +2759,10 @@ where
///
/// Technically, multiple top-level elements violates XML rule of only one top-level
/// element, but we consider this as several concatenated XML documents.
impl<'de, 'a, R> SeqAccess<'de> for &'a mut Deserializer<'de, R>
impl<'de, 'a, R, E> SeqAccess<'de> for &'a mut Deserializer<'de, R, E>
where
R: XmlRead<'de>,
E: EntityResolver
{
type Error = DeError;

Expand Down
Loading

0 comments on commit d9f2446

Please sign in to comment.