diff --git a/src/kio/serial/readers.py b/src/kio/serial/readers.py index a1c3d1dd..2e27c6c9 100644 --- a/src/kio/serial/readers.py +++ b/src/kio/serial/readers.py @@ -23,6 +23,7 @@ from kio.static.primitive import u32 from kio.static.primitive import u64 +from .errors import OutOfBoundValue from .errors import UnexpectedNull T = TypeVar("T") @@ -192,15 +193,20 @@ def read_timedelta_i64(buffer: IO[bytes]) -> i64Timedelta: return datetime.timedelta(milliseconds=read_int64(buffer)) # type: ignore[return-value] +def _tz_aware_from_i64(timestamp: i64) -> TZAware: + dt = datetime.datetime.fromtimestamp(timestamp / 1000, datetime.UTC) + try: + return TZAware.truncate(dt) + except TypeError as exception: + raise OutOfBoundValue("Read invalid value for datetime") from exception + + def read_datetime_i64(buffer: IO[bytes]) -> TZAware: - return datetime.datetime.fromtimestamp( # type: ignore[return-value] - read_int64(buffer) / 1000, - datetime.UTC, - ) + return _tz_aware_from_i64(read_int64(buffer)) def read_nullable_datetime_i64(buffer: IO[bytes]) -> TZAware | None: timestamp = read_int64(buffer) if timestamp == -1: return None - return TZAware.fromtimestamp(timestamp / 1000) + return _tz_aware_from_i64(timestamp) diff --git a/src/kio/static/primitive.py b/src/kio/static/primitive.py index 2b971ae5..37206321 100644 --- a/src/kio/static/primitive.py +++ b/src/kio/static/primitive.py @@ -4,6 +4,7 @@ import math from collections.abc import Callable from typing import Final +from typing import Self from ._phantom import Phantom from ._phantom import Predicate @@ -177,7 +178,12 @@ def i64_timedeltas( def is_tz_aware(dt: datetime.datetime) -> bool: - return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None + return ( + dt.tzinfo is not None + and dt.tzinfo.utcoffset(dt) is not None + and dt.microsecond == 0 + and dt.timestamp() >= 0 + ) class TZAware( @@ -186,15 +192,59 @@ class TZAware( bound=datetime.datetime, predicate=is_tz_aware, ): + """ + Type describing all datetime.datetime instances that are timezone aware, + have millisecond precision, and have a non-negative unix timestamp + representation. + + - Timezone awareness means any object lacking timezone data is excluded. + - Millisecond precision means any object with microsecond != 0 is excluded. + - Kafka uses -1 to represent NULL, so negative unix timestamps are not + supported. + """ + tzinfo: datetime.tzinfo @classmethod def __hypothesis_hook__(cls) -> None: - from hypothesis.strategies import datetimes + from hypothesis import assume + from hypothesis.strategies import SearchStrategy + from hypothesis.strategies import composite + from hypothesis.strategies import integers from hypothesis.strategies import register_type_strategy from hypothesis.strategies import timezones - register_type_strategy( - cls, - datetimes(timezones=timezones()), # type: ignore[arg-type] - ) + min_ts = 0 + max_ts = int(datetime.datetime.max.replace(tzinfo=datetime.UTC).timestamp()) - 1 + + @composite + def milli_second_precision_tz_aware_datetimes( + draw: Callable, + timestamp_strategy: SearchStrategy[int] = integers(min_ts, max_ts), + timezone_strategy: SearchStrategy[datetime.tzinfo] = timezones(), + ) -> TZAware: + """ + Generate millisecond precision datetime objects that are representable both + within the legal boundaries of UTC timestamps, and within the boundaries of + Python datetime objects (i.e. with 0 < year 10_000. + """ + try: + return TZAware.parse( + datetime.datetime.fromtimestamp( + draw(timestamp_strategy), + tz=datetime.UTC, + ).astimezone(draw(timezone_strategy)) + ) + except OverflowError: + # Both timestamps and dates have an upper limit. This means that the + # upper boundary for timestamps cannot be represented in all timezones. + # For timezones where the date wraps around to the year 10_000, an + # OverflowError occurs. + assume(False) + raise # make mypy aware this branch always raises + + register_type_strategy(cls, milli_second_precision_tz_aware_datetimes()) + + @classmethod + def truncate(cls, value: datetime.datetime) -> Self: + return cls.parse(value.replace(microsecond=0)) diff --git a/tests/serial/test_readers.py b/tests/serial/test_readers.py index 32c22ae4..f563c765 100644 --- a/tests/serial/test_readers.py +++ b/tests/serial/test_readers.py @@ -1,3 +1,4 @@ +import datetime import io import struct import sys @@ -6,12 +7,14 @@ import pytest +from kio.serial.errors import OutOfBoundValue from kio.serial.errors import UnexpectedNull from kio.serial.readers import Reader from kio.serial.readers import read_compact_string from kio.serial.readers import read_compact_string_as_bytes from kio.serial.readers import read_compact_string_as_bytes_nullable from kio.serial.readers import read_compact_string_nullable +from kio.serial.readers import read_datetime_i64 from kio.serial.readers import read_float64 from kio.serial.readers import read_int8 from kio.serial.readers import read_int16 @@ -19,6 +22,7 @@ from kio.serial.readers import read_int64 from kio.serial.readers import read_legacy_bytes from kio.serial.readers import read_legacy_string +from kio.serial.readers import read_nullable_datetime_i64 from kio.serial.readers import read_nullable_legacy_bytes from kio.serial.readers import read_nullable_legacy_string from kio.serial.readers import read_uint8 @@ -28,6 +32,7 @@ from kio.serial.readers import read_unsigned_varint from kio.serial.readers import read_uuid from kio.static.constants import uuid_zero +from kio.static.primitive import TZAware class IntReaderContract: @@ -371,3 +376,74 @@ def test_can_read_uuid4(self, buffer: io.BytesIO) -> None: buffer.write(value.bytes) buffer.seek(0) assert read_uuid(buffer) == value + + +class TestReadDatetimeI64: + reader = read_datetime_i64 + lower_limit = datetime.datetime.fromtimestamp(0, tz=datetime.UTC) + lower_limit_as_bytes = struct.pack(">q", 0) + upper_limit = datetime.datetime.fromtimestamp(253402300799, datetime.UTC) + upper_limit_as_bytes = struct.pack(">q", int(upper_limit.timestamp() * 1000)) + + @classmethod + def read(cls, buffer: IO[bytes]) -> TZAware: + return cls.reader(buffer) + + def test_can_read_lower_limit(self, buffer: io.BytesIO) -> None: + buffer.write(self.lower_limit_as_bytes) + buffer.seek(0) + assert self.lower_limit == self.read(buffer) + + def test_can_read_upper_limit(self, buffer: io.BytesIO) -> None: + buffer.write(self.upper_limit_as_bytes) + buffer.seek(0) + assert self.upper_limit == self.read(buffer) + + # As -1 is special null marker, also test with -2. + @pytest.mark.parametrize("value", (-1, -2)) + def test_raises_out_of_bound_value_for_negative_values( + self, + value: int, + buffer: io.BytesIO, + ) -> None: + buffer.write(struct.pack(">q", value)) + buffer.seek(0) + with pytest.raises(OutOfBoundValue): + self.read(buffer) + + +class TestReadNullableDatetimeI64: + reader = read_nullable_datetime_i64 + null_as_bytes = struct.pack(">q", -1) + lower_limit = datetime.datetime.fromtimestamp(0, tz=datetime.UTC) + lower_limit_as_bytes = struct.pack(">q", 0) + upper_limit = datetime.datetime.fromtimestamp(253402300799, datetime.UTC) + upper_limit_as_bytes = struct.pack(">q", int(upper_limit.timestamp() * 1000)) + + @classmethod + def read(cls, buffer: IO[bytes]) -> TZAware | None: + return cls.reader(buffer) + + def test_can_read_null(self, buffer: io.BytesIO) -> None: + buffer.write(self.null_as_bytes) + buffer.seek(0) + assert self.read(buffer) is None + + def test_can_read_lower_limit(self, buffer: io.BytesIO) -> None: + buffer.write(self.lower_limit_as_bytes) + buffer.seek(0) + assert self.lower_limit == self.read(buffer) + + def test_can_read_upper_limit(self, buffer: io.BytesIO) -> None: + buffer.write(self.upper_limit_as_bytes) + buffer.seek(0) + assert self.upper_limit == self.read(buffer) + + def test_raises_out_of_bound_value_for_negative_values( + self, + buffer: io.BytesIO, + ) -> None: + buffer.write(struct.pack(">q", -2)) + buffer.seek(0) + with pytest.raises(OutOfBoundValue): + self.read(buffer)