diff --git a/bindings/python/docs/source/changelog.rst b/bindings/python/docs/source/changelog.rst index 8f3a12f5..95ff8b5e 100644 --- a/bindings/python/docs/source/changelog.rst +++ b/bindings/python/docs/source/changelog.rst @@ -4,7 +4,7 @@ Changelog Changes in Version 1.3.0 ------------------------ - Support for Polars -- Support for PyArrow.DataTypes: large_list, large_string +- Support for PyArrow.DataTypes: large_list, large_string, date32, date64 Changes in Version 1.2.0 ------------------------ diff --git a/bindings/python/pymongoarrow/api.py b/bindings/python/pymongoarrow/api.py index b850c628..4a59ebb9 100644 --- a/bindings/python/pymongoarrow/api.py +++ b/bindings/python/pymongoarrow/api.py @@ -28,7 +28,8 @@ from bson.raw_bson import RawBSONDocument from numpy import ndarray from pyarrow import Schema as ArrowSchema -from pyarrow import Table +from pyarrow import Table, timestamp +from pyarrow.types import is_date32, is_date64 from pymongo.bulk import BulkWriteError from pymongo.common import MAX_WRITE_BATCH_SIZE @@ -430,6 +431,17 @@ def write(collection, tabular): } tab_size = len(tabular) if isinstance(tabular, Table): + # Convert date objects to datetime objects. + changed = False + new_types = [] + for dtype in tabular.schema.types: + if is_date32(dtype) or is_date64(dtype): + changed = True + dtype = timestamp("ms") # noqa: PLW2901 + new_types.append(dtype) + if changed: + cols = [tabular.column(i).cast(new_types[i]) for i in range(tabular.num_columns)] + tabular = Table.from_arrays(cols, names=tabular.column_names) _validate_schema(tabular.schema.types) elif isinstance(tabular, pd.DataFrame): _validate_schema(ArrowSchema.from_pandas(tabular).types) diff --git a/bindings/python/pymongoarrow/context.py b/bindings/python/pymongoarrow/context.py index 2b006d22..0f08b058 100644 --- a/bindings/python/pymongoarrow/context.py +++ b/bindings/python/pymongoarrow/context.py @@ -21,6 +21,8 @@ BinaryBuilder, BoolBuilder, CodeBuilder, + Date32Builder, + Date64Builder, DatetimeBuilder, Decimal128Builder, DocumentBuilder, @@ -45,6 +47,8 @@ _BsonArrowTypes.array: ListBuilder, _BsonArrowTypes.binary: BinaryBuilder, _BsonArrowTypes.code: CodeBuilder, + _BsonArrowTypes.date32: Date32Builder, + _BsonArrowTypes.date64: Date64Builder, } except ImportError: pass diff --git a/bindings/python/pymongoarrow/lib.pyx b/bindings/python/pymongoarrow/lib.pyx index ac4c7716..0f29d214 100644 --- a/bindings/python/pymongoarrow/lib.pyx +++ b/bindings/python/pymongoarrow/lib.pyx @@ -60,6 +60,11 @@ cdef const bson_t* bson_reader_read_safe(bson_reader_t* stream_reader) except? N raise InvalidBSON("Could not read BSON document stream") return doc + +# Placeholder numbers for the date types. +cdef uint8_t ARROW_TYPE_DATE32 = 100 +cdef uint8_t ARROW_TYPE_DATE64 = 101 + _builder_type_map = { BSON_TYPE_INT32: Int32Builder, BSON_TYPE_INT64: Int64Builder, @@ -73,6 +78,8 @@ _builder_type_map = { BSON_TYPE_ARRAY: ListBuilder, BSON_TYPE_BINARY: BinaryBuilder, BSON_TYPE_CODE: CodeBuilder, + ARROW_TYPE_DATE32: Date32Builder, + ARROW_TYPE_DATE64: Date64Builder, } _field_type_map = { @@ -188,6 +195,8 @@ cdef void process_raw_bson_stream(const uint8_t * docstream, size_t length, obje cdef Decimal128Builder dec128_builder cdef ListBuilder list_builder cdef DocumentBuilder doc_builder + cdef Date32Builder date32_builder + cdef Date64Builder date64_builder # Build up a map of the builders. for key, value in context.builder_map.items(): @@ -249,6 +258,10 @@ cdef void process_raw_bson_stream(const uint8_t * docstream, size_t length, obje bson_iter_binary (&doc_iter, &subtype, &val_buf_len, &val_buf) builder = BinaryBuilder(subtype) + elif builder_type == Date32Builder: + builder = Date32Builder() + elif builder_type == Date64Builder: + builder = Date64Builder() else: builder = builder_type() if arr_value_builder is None: @@ -346,6 +359,18 @@ cdef void process_raw_bson_stream(const uint8_t * docstream, size_t length, obje double_builder.append_raw(bson_iter_as_double(&doc_iter)) else: double_builder.append_null() + elif ftype == ARROW_TYPE_DATE32: + date32_builder = builder + if value_t == BSON_TYPE_DATE_TIME: + date32_builder.append_raw(bson_iter_date_time(&doc_iter)) + else: + date32_builder.append_null() + elif ftype == ARROW_TYPE_DATE64: + date64_builder = builder + if value_t == BSON_TYPE_DATE_TIME: + date64_builder.append_raw(bson_iter_date_time(&doc_iter)) + else: + date64_builder.append_null() elif ftype == BSON_TYPE_DATE_TIME: datetime_builder = builder if value_t == BSON_TYPE_DATE_TIME: @@ -626,6 +651,78 @@ cdef class DatetimeBuilder(_ArrayBuilderBase): cdef shared_ptr[CTimestampBuilder] unwrap(self): return self.builder +cdef class Date64Builder(_ArrayBuilderBase): + cdef: + shared_ptr[CDate64Builder] builder + DataType dtype + + def __cinit__(self, MemoryPool memory_pool=None): + cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + self.builder.reset(new CDate64Builder(pool)) + self.type_marker = ARROW_TYPE_DATE64 + + cdef append_raw(self, int64_t value): + self.builder.get().Append(value) + + cpdef append(self, value): + self.builder.get().Append(value) + + cpdef append_null(self): + self.builder.get().AppendNull() + + def __len__(self): + return self.builder.get().length() + + @property + def unit(self): + return self.dtype + + cpdef finish(self): + cdef shared_ptr[CArray] out + with nogil: + self.builder.get().Finish(&out) + return pyarrow_wrap_array(out) + + cdef shared_ptr[CDate64Builder] unwrap(self): + return self.builder + +cdef class Date32Builder(_ArrayBuilderBase): + cdef: + shared_ptr[CDate32Builder] builder + DataType dtype + + def __cinit__(self, MemoryPool memory_pool=None): + cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + self.builder.reset(new CDate32Builder(pool)) + self.type_marker = ARROW_TYPE_DATE32 + + cdef append_raw(self, int64_t value): + # Convert from milliseconds to days (1000*60*60*24) + cdef int32_t seconds_val = value // 86400000 + self.builder.get().Append(seconds_val) + + cpdef append(self, value): + self.builder.get().Append(value) + + cpdef append_null(self): + self.builder.get().AppendNull() + + def __len__(self): + return self.builder.get().length() + + @property + def unit(self): + return self.dtype + + cpdef finish(self): + cdef shared_ptr[CArray] out + with nogil: + self.builder.get().Finish(&out) + return pyarrow_wrap_array(out) + + cdef shared_ptr[CDate32Builder] unwrap(self): + return self.builder + cdef class BoolBuilder(_ArrayBuilderBase): cdef: diff --git a/bindings/python/pymongoarrow/types.py b/bindings/python/pymongoarrow/types.py index 3a692753..915de2cf 100644 --- a/bindings/python/pymongoarrow/types.py +++ b/bindings/python/pymongoarrow/types.py @@ -53,6 +53,8 @@ class _BsonArrowTypes(enum.Enum): array = 10 binary = 11 code = 12 + date32 = 13 + date64 = 14 # Custom Extension Types. @@ -266,6 +268,8 @@ def get_numpy_type(type): _atypes.is_boolean: _BsonArrowTypes.bool, _atypes.is_struct: _BsonArrowTypes.document, _atypes.is_list: _BsonArrowTypes.array, + _atypes.is_date32: _BsonArrowTypes.date32, + _atypes.is_date64: _BsonArrowTypes.date64, _atypes.is_large_string: _BsonArrowTypes.string, _atypes.is_large_list: _BsonArrowTypes.array, } diff --git a/bindings/python/pyproject.toml b/bindings/python/pyproject.toml index 76c30989..42961619 100644 --- a/bindings/python/pyproject.toml +++ b/bindings/python/pyproject.toml @@ -95,7 +95,7 @@ archs = "x86_64 arm64" [tool.pytest.ini_options] minversion = "7" addopts = ["-ra", "--strict-config", "--strict-markers", "--durations=5", "--junitxml=xunit-results/TEST-results.xml"] -testpaths = ["test"] +testpaths = ["test", "test/pandas_types"] log_cli_level = "INFO" norecursedirs = ["test/*"] faulthandler_timeout = 1500 diff --git a/bindings/python/test/test_arrow.py b/bindings/python/test/test_arrow.py index 88973041..92af4be4 100644 --- a/bindings/python/test/test_arrow.py +++ b/bindings/python/test/test_arrow.py @@ -15,7 +15,7 @@ import tempfile import unittest import unittest.mock as mock -from datetime import datetime +from datetime import date, datetime from test import client_context from test.utils import AllowListEventListener, NullsTestMixin @@ -28,6 +28,8 @@ Table, bool_, csv, + date32, + date64, decimal256, field, int32, @@ -316,6 +318,21 @@ def test_write_schema_validation(self): with self.assertRaises(ValueError): self.round_trip(data, Schema(schema)) + def test_date_types(self): + schema, data = self._create_data() + self.round_trip(data, Schema(schema)) + + schema = {"_id": int32(), "date32": date32(), "date64": date64()} + data = Table.from_pydict( + { + "_id": [i for i in range(2)], + "date32": [date(2012, 1, 1) for _ in range(2)], + "date64": [datetime(2012, 1, 1) for _ in range(2)], + }, + ArrowSchema(schema), + ) + self.round_trip(data, Schema(schema)) + @mock.patch.object(Collection, "insert_many", side_effect=Collection.insert_many, autospec=True) def test_write_batching(self, mock): schema = { @@ -349,6 +366,8 @@ def _create_nested_data(self, nested_elem=None): "ObjectId": [ObjectId().binary for i in range(3)], "Decimal128": [Decimal128(str(i)).bid for i in range(3)], "Code": [str(i) for i in range(3)], + "date32": [date(2012, 1, 1) for i in range(3)], + "date64": [date(2012, 1, 1) for i in range(3)], } def inner(i): @@ -363,6 +382,8 @@ def inner(i): Binary=Binary(bytes(i), 10), ObjectId=ObjectId().binary, Code=str(i), + date32=date(2012, 1, 1), + date64=date(2014, 1, 1), ) if nested_elem: inner_dict["list"] = [nested_elem] diff --git a/bindings/python/test/test_builders.py b/bindings/python/test/test_builders.py index acf32f54..2a19e103 100644 --- a/bindings/python/test/test_builders.py +++ b/bindings/python/test/test_builders.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import calendar -from datetime import datetime, timedelta +from datetime import date, datetime, timedelta from unittest import TestCase from bson import Binary, Code, Decimal128, ObjectId @@ -22,6 +22,8 @@ BinaryBuilder, BoolBuilder, CodeBuilder, + Date32Builder, + Date64Builder, DatetimeBuilder, Decimal128Builder, DocumentBuilder, @@ -62,7 +64,7 @@ def setUp(self): self.data_type = int64() -class TestDate64Builder(TestCase): +class TestDatetimeBuilder(TestCase): def test_default_unit(self): # Check default unit builder = DatetimeBuilder() @@ -281,3 +283,40 @@ def test_simple(self): self.assertEqual(arr.null_count, 1) self.assertEqual(len(arr), 5) self.assertEqual(arr.to_pylist(), codes + [None]) + + +class TestDate32Builder(TestCase): + def test_simple(self): + epoch = date(1970, 1, 1) + values = [date(2012, 1, 1), date(2012, 1, 2), date(2014, 4, 5)] + builder = Date32Builder() + builder.append(values[0].toordinal() - epoch.toordinal()) + builder.append_values([v.toordinal() - epoch.toordinal() for v in values[1:]]) + builder.append_null() + arr = builder.finish() + + self.assertIsInstance(arr, Array) + self.assertEqual(arr.null_count, 1) + self.assertEqual(len(arr), 4) + self.assertEqual(arr.to_pylist(), values + [None]) + + +class TestDate64Builder(TestCase): + def test_simple(self): + def msec_since_epoch(d): + epoch = datetime(1970, 1, 1) + d = datetime.fromordinal(d.toordinal()) + diff = d - epoch + return diff.total_seconds() * 1000 + + values = [date(2012, 1, 1), date(2012, 1, 2), date(2014, 4, 5)] + builder = Date64Builder() + builder.append(msec_since_epoch(values[0])) + builder.append_values([msec_since_epoch(v) for v in values[1:]]) + builder.append_null() + arr = builder.finish() + + self.assertIsInstance(arr, Array) + self.assertEqual(arr.null_count, 1) + self.assertEqual(len(arr), 4) + self.assertEqual(arr.to_pylist(), values + [None])