From d7cf72d8d591632889967b37eb4367f9f9f889e7 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 15 Jan 2024 12:23:05 +0000 Subject: [PATCH] Int extraction (#1155) --- .gitignore | 3 +++ src/errors/types.rs | 2 +- src/errors/value_exception.rs | 2 +- src/input/input_python.rs | 10 ++++---- src/input/return_enums.rs | 6 ++--- src/lookup_key.rs | 2 +- src/serializers/infer.rs | 5 +++- src/serializers/type_serializers/literal.rs | 4 +-- src/tools.rs | 28 +++++++++++++++------ tests/benchmarks/test_micro_benchmarks.py | 12 +++++++++ tests/validators/test_int.py | 15 ++++++++--- 11 files changed, 64 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index 6c9ace5f4..efffcbf69 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,6 @@ node_modules/ /foobar.py /python/pydantic_core/*.so /src/self_schema.py + +# samply +/profile.json diff --git a/src/errors/types.rs b/src/errors/types.rs index cfa96221e..eddd7dbaa 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -786,7 +786,7 @@ impl From for Number { impl FromPyObject<'_> for Number { fn extract(obj: &PyAny) -> PyResult { - if let Ok(int) = extract_i64(obj) { + if let Some(int) = extract_i64(obj) { Ok(Number::Int(int)) } else if let Ok(float) = obj.extract::() { Ok(Number::Float(float)) diff --git a/src/errors/value_exception.rs b/src/errors/value_exception.rs index a88610eef..68f93d463 100644 --- a/src/errors/value_exception.rs +++ b/src/errors/value_exception.rs @@ -122,7 +122,7 @@ impl PydanticCustomError { let key: &PyString = key.downcast()?; if let Ok(py_str) = value.downcast::() { message = message.replace(&format!("{{{}}}", key.to_str()?), py_str.to_str()?); - } else if let Ok(value_int) = extract_i64(value) { + } else if let Some(value_int) = extract_i64(value) { message = message.replace(&format!("{{{}}}", key.to_str()?), &value_int.to_string()); } else { // fallback for anything else just in case diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 5d4d4826c..2eaaebe10 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -96,7 +96,7 @@ impl AsLocItem for PyAny { fn as_loc_item(&self) -> LocItem { if let Ok(py_str) = self.downcast::() { py_str.to_string_lossy().as_ref().into() - } else if let Ok(key_int) = extract_i64(self) { + } else if let Some(key_int) = extract_i64(self) { key_int.into() } else { safe_repr(self).to_string().into() @@ -292,7 +292,7 @@ impl<'a> Input<'a> for PyAny { if !strict { if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::BoolParsing)? { return str_as_bool(self, &cow_str).map(ValidationMatch::lax); - } else if let Ok(int) = extract_i64(self) { + } else if let Some(int) = extract_i64(self) { return int_as_bool(self, int).map(ValidationMatch::lax); } else if let Ok(float) = self.extract::() { if let Ok(int) = float_as_int(self, float) { @@ -635,7 +635,7 @@ impl<'a> Input<'a> for PyAny { bytes_as_time(self, py_bytes.as_bytes(), microseconds_overflow_behavior) } else if PyBool::is_exact_type_of(self) { Err(ValError::new(ErrorTypeDefaults::TimeType, self)) - } else if let Ok(int) = extract_i64(self) { + } else if let Some(int) = extract_i64(self) { int_as_time(self, int, 0) } else if let Ok(float) = self.extract::() { float_as_time(self, float) @@ -669,7 +669,7 @@ impl<'a> Input<'a> for PyAny { bytes_as_datetime(self, py_bytes.as_bytes(), microseconds_overflow_behavior) } else if PyBool::is_exact_type_of(self) { Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)) - } else if let Ok(int) = extract_i64(self) { + } else if let Some(int) = extract_i64(self) { int_as_datetime(self, int, 0) } else if let Ok(float) = self.extract::() { float_as_datetime(self, float) @@ -706,7 +706,7 @@ impl<'a> Input<'a> for PyAny { bytes_as_timedelta(self, str.as_bytes(), microseconds_overflow_behavior) } else if let Ok(py_bytes) = self.downcast::() { bytes_as_timedelta(self, py_bytes.as_bytes(), microseconds_overflow_behavior) - } else if let Ok(int) = extract_i64(self) { + } else if let Some(int) = extract_i64(self) { Ok(int_as_duration(self, int)?.into()) } else if let Ok(float) = self.extract::() { Ok(float_as_duration(self, float)?.into()) diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index fa70880ca..905b895f9 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -23,7 +23,7 @@ use pyo3::PyTypeInfo; use serde::{ser::Error, Serialize, Serializer}; use crate::errors::{py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ValError, ValLineError, ValResult}; -use crate::tools::py_err; +use crate::tools::{extract_i64, py_err}; use crate::validators::{CombinedValidator, Exactness, ValidationState, Validator}; use super::input_string::StringMapping; @@ -863,7 +863,7 @@ pub enum EitherInt<'a> { impl<'a> EitherInt<'a> { pub fn upcast(py_any: &'a PyAny) -> ValResult { // Safety: we know that py_any is a python int - if let Ok(int_64) = py_any.extract::() { + if let Some(int_64) = extract_i64(py_any) { Ok(Self::I64(int_64)) } else { let big_int: BigInt = py_any.extract()?; @@ -1021,7 +1021,7 @@ impl<'a> Rem for &'a Int { impl<'a> FromPyObject<'a> for Int { fn extract(obj: &'a PyAny) -> PyResult { - if let Ok(i) = obj.extract::() { + if let Some(i) = extract_i64(obj) { Ok(Int::I64(i)) } else if let Ok(b) = obj.extract::() { Ok(Int::Big(b)) diff --git a/src/lookup_key.rs b/src/lookup_key.rs index e145c1f41..e4d2dcce7 100644 --- a/src/lookup_key.rs +++ b/src/lookup_key.rs @@ -429,7 +429,7 @@ impl PathItem { } else { Ok(Self::Pos(usize_key)) } - } else if let Ok(int_key) = extract_i64(obj) { + } else if let Some(int_key) = extract_i64(obj) { if index == 0 { py_err!(PyTypeError; "The first item in an alias path should be a string") } else { diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index 13c20062b..e39ca38f8 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -123,7 +123,10 @@ pub(crate) fn infer_to_python_known( // `bool` and `None` can't be subclasses, `ObType::Int`, `ObType::Float`, `ObType::Str` refer to exact types ObType::None | ObType::Bool | ObType::Int | ObType::Str => value.into_py(py), // have to do this to make sure subclasses of for example str are upcast to `str` - ObType::IntSubclass => extract_i64(value)?.into_py(py), + ObType::IntSubclass => match extract_i64(value) { + Some(v) => v.into_py(py), + None => return py_err!(PyTypeError; "expected int, got {}", safe_repr(value)), + }, ObType::Float | ObType::FloatSubclass => { let v = value.extract::()?; if (v.is_nan() || v.is_infinite()) && extra.config.inf_nan_mode == InfNanMode::Null { diff --git a/src/serializers/type_serializers/literal.rs b/src/serializers/type_serializers/literal.rs index 846b8843f..d6b08afa5 100644 --- a/src/serializers/type_serializers/literal.rs +++ b/src/serializers/type_serializers/literal.rs @@ -46,7 +46,7 @@ impl BuildSerializer for LiteralSerializer { repr_args.push(item.repr()?.extract()?); if let Ok(bool) = item.downcast::() { expected_py.append(bool)?; - } else if let Ok(int) = extract_i64(item) { + } else if let Some(int) = extract_i64(item) { expected_int.insert(int); } else if let Ok(py_str) = item.downcast::() { expected_str.insert(py_str.to_str()?.to_string()); @@ -79,7 +79,7 @@ impl LiteralSerializer { fn check<'a>(&self, value: &'a PyAny, extra: &Extra) -> PyResult> { if extra.check.enabled() { if !self.expected_int.is_empty() && !PyBool::is_type_of(value) { - if let Ok(int) = extract_i64(value) { + if let Some(int) = extract_i64(value) { if self.expected_int.contains(&int) { return Ok(OutputValue::OkInt(int)); } diff --git a/src/tools.rs b/src/tools.rs index af58131f5..bdc41583c 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -1,9 +1,9 @@ use std::borrow::Cow; -use pyo3::exceptions::{PyKeyError, PyTypeError}; +use pyo3::exceptions::PyKeyError; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyInt, PyString}; -use pyo3::{intern, FromPyObject, PyTypeInfo}; +use pyo3::types::{PyDict, PyString}; +use pyo3::{ffi, intern, FromPyObject}; pub trait SchemaDict<'py> { fn get_as(&'py self, key: &PyString) -> PyResult> @@ -99,10 +99,24 @@ pub fn safe_repr(v: &PyAny) -> Cow { } } -pub fn extract_i64(v: &PyAny) -> PyResult { - if PyInt::is_type_of(v) { - v.extract() +/// Extract an i64 from a python object more quickly, see +/// https://github.com/PyO3/pyo3/pull/3742#discussion_r1451763928 +#[cfg(not(any(target_pointer_width = "32", windows, PyPy)))] +pub fn extract_i64(obj: &PyAny) -> Option { + let val = unsafe { ffi::PyLong_AsLong(obj.as_ptr()) }; + if val == -1 && PyErr::occurred(obj.py()) { + unsafe { ffi::PyErr_Clear() }; + None } else { - py_err!(PyTypeError; "expected int, got {}", safe_repr(v)) + Some(val) + } +} + +#[cfg(any(target_pointer_width = "32", windows, PyPy))] +pub fn extract_i64(v: &PyAny) -> Option { + if v.is_instance_of::() { + v.extract().ok() + } else { + None } } diff --git a/tests/benchmarks/test_micro_benchmarks.py b/tests/benchmarks/test_micro_benchmarks.py index f1ec32eef..c2320427c 100644 --- a/tests/benchmarks/test_micro_benchmarks.py +++ b/tests/benchmarks/test_micro_benchmarks.py @@ -1232,6 +1232,18 @@ def test_strict_int(benchmark): benchmark(v.validate_python, 42) +@pytest.mark.benchmark(group='strict_int') +def test_strict_int_fails(benchmark): + v = SchemaValidator(core_schema.int_schema(strict=True)) + + @benchmark + def t(): + try: + v.validate_python(()) + except ValidationError: + pass + + @pytest.mark.benchmark(group='int_range') def test_int_range(benchmark): v = SchemaValidator(core_schema.int_schema(gt=0, lt=100)) diff --git a/tests/validators/test_int.py b/tests/validators/test_int.py index 80dd1cf73..35a13f6a7 100644 --- a/tests/validators/test_int.py +++ b/tests/validators/test_int.py @@ -29,6 +29,8 @@ ('123456789123456.00001', Err('Input should be a valid integer, unable to parse string as an integer')), (int(1e10), int(1e10)), (i64_max, i64_max), + (i64_max + 1, i64_max + 1), + (i64_max * 2, i64_max * 2), pytest.param( 12.5, Err('Input should be a valid integer, got a number with a fractional part [type=int_from_float'), @@ -106,10 +108,15 @@ def test_int(input_value, expected): @pytest.mark.parametrize( 'input_value,expected', [ - (Decimal('1'), 1), - (Decimal('1.0'), 1), - (i64_max, i64_max), - (i64_max + 1, i64_max + 1), + pytest.param(Decimal('1'), 1), + pytest.param(Decimal('1.0'), 1), + pytest.param(i64_max, i64_max, id='i64_max'), + pytest.param(i64_max + 1, i64_max + 1, id='i64_max+1'), + pytest.param( + -1, + Err('Input should be greater than 0 [type=greater_than, input_value=-1, input_type=int]'), + id='-1', + ), ( -i64_max + 1, Err('Input should be greater than 0 [type=greater_than, input_value=-9223372036854775806, input_type=int]'),