Skip to content

Commit

Permalink
Updating error messages.
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Aug 6, 2024
1 parent ad25509 commit eada410
Showing 1 changed file with 115 additions and 8 deletions.
123 changes: 115 additions & 8 deletions tokenizers/src/decoders/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub mod wordpiece;
pub use super::pre_tokenizers::byte_level;
pub use super::pre_tokenizers::metaspace;

use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};

use crate::decoders::bpe::BPEDecoder;
use crate::decoders::byte_fallback::ByteFallback;
Expand All @@ -24,7 +24,7 @@ use crate::pre_tokenizers::byte_level::ByteLevel;
use crate::pre_tokenizers::metaspace::Metaspace;
use crate::{Decoder, Result};

#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Serialize, Clone, Debug)]
#[serde(untagged)]
pub enum DecoderWrapper {
BPE(BPEDecoder),
Expand All @@ -39,6 +39,116 @@ pub enum DecoderWrapper {
ByteFallback(ByteFallback),
}

/// Hand-written deserialization for [`DecoderWrapper`].
///
/// Input carrying a recognized `"type"` field is dispatched straight to the
/// matching decoder, so a failure reports that decoder's own error (for
/// example ``missing field `decoders` ``) rather than a generic
/// "did not match any variant" message. Any other input falls back to trying
/// every decoder in declaration order through an untagged enum.
impl<'de> Deserialize<'de> for DecoderWrapper {
    /// Deserializes a `DecoderWrapper` from `deserializer`.
    ///
    /// # Errors
    ///
    /// Returns an error when the underlying deserializer fails, when the
    /// payload does not fit the decoder named by its `"type"` field, or when
    /// an input without a recognized `"type"` matches none of the decoders.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        /// Input that carries a recognized `"type"` discriminant; every other
        /// field is collected into `rest`.
        #[derive(Deserialize)]
        pub struct Tagged {
            #[serde(rename = "type")]
            variant: EnumType,
            #[serde(flatten)]
            rest: serde_json::Value,
        }

        /// Accepted values of the `"type"` field.
        #[derive(Serialize, Deserialize)]
        pub enum EnumType {
            BPEDecoder,
            ByteLevel,
            WordPiece,
            Metaspace,
            CTC,
            Sequence,
            Replace,
            Fuse,
            Strip,
            ByteFallback,
        }

        /// First pass over the input: `Tagged` is tried first, and anything
        /// that does not fit it is kept verbatim as `Legacy`.
        #[derive(Deserialize)]
        #[serde(untagged)]
        pub enum DecoderHelper {
            Tagged(Tagged),
            Legacy(serde_json::Value),
        }

        /// Untagged mirror of `DecoderWrapper`, used for the legacy path.
        #[derive(Deserialize)]
        #[serde(untagged)]
        pub enum DecoderUntagged {
            BPE(BPEDecoder),
            ByteLevel(ByteLevel),
            WordPiece(WordPiece),
            Metaspace(Metaspace),
            CTC(CTC),
            Sequence(Sequence),
            Replace(Replace),
            Fuse(Fuse),
            Strip(Strip),
            ByteFallback(ByteFallback),
        }

        // Propagate deserializer failures (e.g. malformed input) to the caller
        // instead of panicking: deserialization must never abort the process.
        let helper = DecoderHelper::deserialize(deserializer)?;
        Ok(match helper {
            DecoderHelper::Tagged(model) => {
                // Rebuild the full object: the concrete decoder receives the
                // flattened fields together with the original "type" entry.
                let mut values: serde_json::Map<String, serde_json::Value> =
                    serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?;
                values.insert(
                    "type".to_string(),
                    serde_json::to_value(&model.variant).map_err(serde::de::Error::custom)?,
                );
                let values = serde_json::Value::Object(values);
                match model.variant {
                    EnumType::BPEDecoder => DecoderWrapper::BPE(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                    EnumType::ByteLevel => DecoderWrapper::ByteLevel(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                    EnumType::WordPiece => DecoderWrapper::WordPiece(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                    EnumType::Metaspace => DecoderWrapper::Metaspace(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                    EnumType::CTC => DecoderWrapper::CTC(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                    EnumType::Sequence => DecoderWrapper::Sequence(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                    EnumType::Replace => DecoderWrapper::Replace(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                    EnumType::Fuse => DecoderWrapper::Fuse(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                    EnumType::Strip => DecoderWrapper::Strip(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                    EnumType::ByteFallback => DecoderWrapper::ByteFallback(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                }
            }
            DecoderHelper::Legacy(value) => {
                // No usable "type" field: try each decoder in turn.
                let untagged: DecoderUntagged =
                    serde_json::from_value(value).map_err(serde::de::Error::custom)?;
                match untagged {
                    DecoderUntagged::BPE(dec) => DecoderWrapper::BPE(dec),
                    DecoderUntagged::ByteLevel(dec) => DecoderWrapper::ByteLevel(dec),
                    DecoderUntagged::WordPiece(dec) => DecoderWrapper::WordPiece(dec),
                    DecoderUntagged::Metaspace(dec) => DecoderWrapper::Metaspace(dec),
                    DecoderUntagged::CTC(dec) => DecoderWrapper::CTC(dec),
                    DecoderUntagged::Sequence(dec) => DecoderWrapper::Sequence(dec),
                    DecoderUntagged::Replace(dec) => DecoderWrapper::Replace(dec),
                    DecoderUntagged::Fuse(dec) => DecoderWrapper::Fuse(dec),
                    DecoderUntagged::Strip(dec) => DecoderWrapper::Strip(dec),
                    DecoderUntagged::ByteFallback(dec) => DecoderWrapper::ByteFallback(dec),
                }
            }
        })
    }
}

impl Decoder for DecoderWrapper {
fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
match self {
Expand Down Expand Up @@ -98,7 +208,7 @@ mod tests {
match parse {
Err(err) => assert_eq!(
format!("{err}"),
"data did not match any variant of untagged enum DecoderWrapper"
"data did not match any variant of untagged enum DecoderUntagged"
),
_ => panic!("Expected error"),
}
Expand All @@ -108,18 +218,15 @@ mod tests {
match parse {
Err(err) => assert_eq!(
format!("{err}"),
"data did not match any variant of untagged enum DecoderWrapper"
"data did not match any variant of untagged enum DecoderUntagged"
),
_ => panic!("Expected error"),
}

let json = r#"{"type":"Sequence","prepend_scheme":"always"}"#;
let parse = serde_json::from_str::<DecoderWrapper>(json);
match parse {
Err(err) => assert_eq!(
format!("{err}"),
"data did not match any variant of untagged enum DecoderWrapper"
),
Err(err) => assert_eq!(format!("{err}"), "missing field `decoders`"),
_ => panic!("Expected error"),
}
}
Expand Down

0 comments on commit eada410

Please sign in to comment.