Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-191 Add support for date32 and date64 types #192

Merged
merged 4 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bindings/python/docs/source/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
Changes in Version 1.3.0
------------------------
- Support for Polars
- Support for PyArrow.DataTypes: large_list, large_string
- Support for PyArrow.DataTypes: large_list, large_string, date32, date64

Changes in Version 1.2.0
------------------------
Expand Down
14 changes: 13 additions & 1 deletion bindings/python/pymongoarrow/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
from bson.raw_bson import RawBSONDocument
from numpy import ndarray
from pyarrow import Schema as ArrowSchema
from pyarrow import Table
from pyarrow import Table, timestamp
from pyarrow.types import is_date32, is_date64
from pymongo.bulk import BulkWriteError
from pymongo.common import MAX_WRITE_BATCH_SIZE

Expand Down Expand Up @@ -430,6 +431,17 @@ def write(collection, tabular):
}
tab_size = len(tabular)
if isinstance(tabular, Table):
# Convert date objects to datetime objects.
changed = False
new_types = []
for dtype in tabular.schema.types:
if is_date32(dtype) or is_date64(dtype):
changed = True
dtype = timestamp("ms") # noqa: PLW2901
new_types.append(dtype)
if changed:
cols = [tabular.column(i).cast(new_types[i]) for i in range(tabular.num_columns)]
tabular = Table.from_arrays(cols, names=tabular.column_names)
_validate_schema(tabular.schema.types)
elif isinstance(tabular, pd.DataFrame):
_validate_schema(ArrowSchema.from_pandas(tabular).types)
Expand Down
4 changes: 4 additions & 0 deletions bindings/python/pymongoarrow/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
BinaryBuilder,
BoolBuilder,
CodeBuilder,
Date32Builder,
Date64Builder,
DatetimeBuilder,
Decimal128Builder,
DocumentBuilder,
Expand All @@ -45,6 +47,8 @@
_BsonArrowTypes.array: ListBuilder,
_BsonArrowTypes.binary: BinaryBuilder,
_BsonArrowTypes.code: CodeBuilder,
_BsonArrowTypes.date32: Date32Builder,
_BsonArrowTypes.date64: Date64Builder,
}
except ImportError:
pass
Expand Down
97 changes: 97 additions & 0 deletions bindings/python/pymongoarrow/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ cdef const bson_t* bson_reader_read_safe(bson_reader_t* stream_reader) except? N
raise InvalidBSON("Could not read BSON document stream")
return doc


# Placeholder numbers for the date types.
# These values are stored as the ``type_marker`` of Date32Builder and
# Date64Builder, are used as keys in ``_builder_type_map``, and are compared
# against the field type while processing a BSON stream so that BSON datetime
# values can be routed into the matching date builder.
# NOTE(review): 100 and 101 are presumably chosen to stay clear of the
# BSON_TYPE_* codes used for the other builders -- confirm.
cdef uint8_t ARROW_TYPE_DATE32 = 100
cdef uint8_t ARROW_TYPE_DATE64 = 101

_builder_type_map = {
BSON_TYPE_INT32: Int32Builder,
BSON_TYPE_INT64: Int64Builder,
Expand All @@ -73,6 +78,8 @@ _builder_type_map = {
BSON_TYPE_ARRAY: ListBuilder,
BSON_TYPE_BINARY: BinaryBuilder,
BSON_TYPE_CODE: CodeBuilder,
ARROW_TYPE_DATE32: Date32Builder,
ARROW_TYPE_DATE64: Date64Builder,
}

_field_type_map = {
Expand Down Expand Up @@ -188,6 +195,8 @@ cdef void process_raw_bson_stream(const uint8_t * docstream, size_t length, obje
cdef Decimal128Builder dec128_builder
cdef ListBuilder list_builder
cdef DocumentBuilder doc_builder
cdef Date32Builder date32_builder
cdef Date64Builder date64_builder

# Build up a map of the builders.
for key, value in context.builder_map.items():
Expand Down Expand Up @@ -249,6 +258,10 @@ cdef void process_raw_bson_stream(const uint8_t * docstream, size_t length, obje
bson_iter_binary (&doc_iter, &subtype,
&val_buf_len, &val_buf)
builder = BinaryBuilder(subtype)
elif builder_type == Date32Builder:
builder = Date32Builder()
elif builder_type == Date64Builder:
builder = Date64Builder()
else:
builder = builder_type()
if arr_value_builder is None:
Expand Down Expand Up @@ -346,6 +359,18 @@ cdef void process_raw_bson_stream(const uint8_t * docstream, size_t length, obje
double_builder.append_raw(bson_iter_as_double(&doc_iter))
else:
double_builder.append_null()
elif ftype == ARROW_TYPE_DATE32:
date32_builder = builder
if value_t == BSON_TYPE_DATE_TIME:
date32_builder.append_raw(bson_iter_date_time(&doc_iter))
else:
date32_builder.append_null()
elif ftype == ARROW_TYPE_DATE64:
date64_builder = builder
if value_t == BSON_TYPE_DATE_TIME:
date64_builder.append_raw(bson_iter_date_time(&doc_iter))
else:
date64_builder.append_null()
elif ftype == BSON_TYPE_DATE_TIME:
datetime_builder = builder
if value_t == BSON_TYPE_DATE_TIME:
Expand Down Expand Up @@ -626,6 +651,78 @@ cdef class DatetimeBuilder(_ArrayBuilderBase):
cdef shared_ptr[CTimestampBuilder] unwrap(self):
return self.builder

cdef class Date64Builder(_ArrayBuilderBase):
    # Array builder that wraps an Arrow C++ ``CDate64Builder``.
    # Values are appended unchanged; callers supply milliseconds since the
    # UNIX epoch.
    cdef:
        shared_ptr[CDate64Builder] builder
        # NOTE(review): ``dtype`` is never assigned in this class, so ``unit``
        # presumably returns None unless a base class sets it -- confirm.
        DataType dtype

    def __cinit__(self, MemoryPool memory_pool=None):
        # Allocate the underlying C++ builder on the given memory pool
        # (``maybe_unbox_memory_pool`` resolves the pool when None is passed)
        # and tag this builder with the date64 type marker.
        cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
        self.builder.reset(new CDate64Builder(pool))
        self.type_marker = ARROW_TYPE_DATE64

    cdef append_raw(self, int64_t value):
        # C-level fast path: append an already-typed int64 value as-is.
        self.builder.get().Append(value)

    cpdef append(self, value):
        # Python-visible append; ``value`` is coerced by Cython to the
        # integer type expected by the C++ ``Append`` call.
        self.builder.get().Append(value)

    cpdef append_null(self):
        # Append a null (missing) entry.
        self.builder.get().AppendNull()

    def __len__(self):
        # Number of entries appended so far, nulls included.
        return self.builder.get().length()

    @property
    def unit(self):
        # Returns the ``dtype`` attribute (see the note on its declaration).
        return self.dtype

    cpdef finish(self):
        # Finalize the builder and return the result wrapped as a pyarrow
        # array. The C++ ``Finish`` call runs with the GIL released.
        cdef shared_ptr[CArray] out
        with nogil:
            self.builder.get().Finish(&out)
        return pyarrow_wrap_array(out)

    cdef shared_ptr[CDate64Builder] unwrap(self):
        # Expose the underlying C++ builder to other Cython code.
        return self.builder

cdef class Date32Builder(_ArrayBuilderBase):
    # Array builder that wraps an Arrow C++ ``CDate32Builder``.
    #
    # Note the two append paths take different units:
    #   * ``append_raw`` receives milliseconds since the UNIX epoch and
    #     converts them to whole days before appending.
    #   * ``append`` receives a value that is already a day count and
    #     appends it unchanged.
    cdef:
        shared_ptr[CDate32Builder] builder
        # NOTE(review): ``dtype`` is never assigned in this class, so ``unit``
        # presumably returns None unless a base class sets it -- confirm.
        DataType dtype

    def __cinit__(self, MemoryPool memory_pool=None):
        # Allocate the underlying C++ builder on the given memory pool and
        # tag this builder with the date32 type marker.
        cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
        self.builder.reset(new CDate32Builder(pool))
        self.type_marker = ARROW_TYPE_DATE32

    cdef append_raw(self, int64_t value):
        # Convert from milliseconds to days (1000*60*60*24).
        # With Cython's default ``cdivision=False`` the ``//`` operator floors
        # like Python, so pre-1970 timestamps map to the earlier day.
        # NOTE(review): confirm this module does not enable ``cdivision``.
        # The narrowing from int64 to int32 is made explicit with a cast.
        cdef int32_t days = <int32_t>(value // 86400000)
        self.builder.get().Append(days)

    cpdef append(self, value):
        # Python-visible append; ``value`` is a day count and is coerced by
        # Cython to the integer type expected by the C++ ``Append`` call.
        self.builder.get().Append(value)

    cpdef append_null(self):
        # Append a null (missing) entry.
        self.builder.get().AppendNull()

    def __len__(self):
        # Number of entries appended so far, nulls included.
        return self.builder.get().length()

    @property
    def unit(self):
        # Returns the ``dtype`` attribute (see the note on its declaration).
        return self.dtype

    cpdef finish(self):
        # Finalize the builder and return the result wrapped as a pyarrow
        # array. The C++ ``Finish`` call runs with the GIL released.
        cdef shared_ptr[CArray] out
        with nogil:
            self.builder.get().Finish(&out)
        return pyarrow_wrap_array(out)

    cdef shared_ptr[CDate32Builder] unwrap(self):
        # Expose the underlying C++ builder to other Cython code.
        return self.builder


cdef class BoolBuilder(_ArrayBuilderBase):
cdef:
Expand Down
4 changes: 4 additions & 0 deletions bindings/python/pymongoarrow/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class _BsonArrowTypes(enum.Enum):
array = 10
binary = 11
code = 12
date32 = 13
date64 = 14


# Custom Extension Types.
Expand Down Expand Up @@ -266,6 +268,8 @@ def get_numpy_type(type):
_atypes.is_boolean: _BsonArrowTypes.bool,
_atypes.is_struct: _BsonArrowTypes.document,
_atypes.is_list: _BsonArrowTypes.array,
_atypes.is_date32: _BsonArrowTypes.date32,
_atypes.is_date64: _BsonArrowTypes.date64,
_atypes.is_large_string: _BsonArrowTypes.string,
_atypes.is_large_list: _BsonArrowTypes.array,
}
Expand Down
2 changes: 1 addition & 1 deletion bindings/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ archs = "x86_64 arm64"
[tool.pytest.ini_options]
minversion = "7"
addopts = ["-ra", "--strict-config", "--strict-markers", "--durations=5", "--junitxml=xunit-results/TEST-results.xml"]
testpaths = ["test"]
testpaths = ["test", "test/pandas_types"]
log_cli_level = "INFO"
norecursedirs = ["test/*"]
faulthandler_timeout = 1500
Expand Down
23 changes: 22 additions & 1 deletion bindings/python/test/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import tempfile
import unittest
import unittest.mock as mock
from datetime import datetime
from datetime import date, datetime
from test import client_context
from test.utils import AllowListEventListener, NullsTestMixin

Expand All @@ -28,6 +28,8 @@
Table,
bool_,
csv,
date32,
date64,
decimal256,
field,
int32,
Expand Down Expand Up @@ -316,6 +318,21 @@ def test_write_schema_validation(self):
with self.assertRaises(ValueError):
self.round_trip(data, Schema(schema))

def test_date_types(self):
    """Round-trip the default data set, then a table with date columns."""
    default_schema, default_data = self._create_data()
    self.round_trip(default_data, Schema(default_schema))

    n_rows = 2
    date_schema = {"_id": int32(), "date32": date32(), "date64": date64()}
    columns = {
        "_id": list(range(n_rows)),
        "date32": [date(2012, 1, 1)] * n_rows,
        "date64": [datetime(2012, 1, 1)] * n_rows,
    }
    date_table = Table.from_pydict(columns, ArrowSchema(date_schema))
    self.round_trip(date_table, Schema(date_schema))

@mock.patch.object(Collection, "insert_many", side_effect=Collection.insert_many, autospec=True)
def test_write_batching(self, mock):
schema = {
Expand Down Expand Up @@ -349,6 +366,8 @@ def _create_nested_data(self, nested_elem=None):
"ObjectId": [ObjectId().binary for i in range(3)],
"Decimal128": [Decimal128(str(i)).bid for i in range(3)],
"Code": [str(i) for i in range(3)],
"date32": [date(2012, 1, 1) for i in range(3)],
"date64": [date(2012, 1, 1) for i in range(3)],
}

def inner(i):
Expand All @@ -363,6 +382,8 @@ def inner(i):
Binary=Binary(bytes(i), 10),
ObjectId=ObjectId().binary,
Code=str(i),
date32=date(2012, 1, 1),
date64=date(2014, 1, 1),
)
if nested_elem:
inner_dict["list"] = [nested_elem]
Expand Down
43 changes: 41 additions & 2 deletions bindings/python/test/test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import calendar
from datetime import datetime, timedelta
from datetime import date, datetime, timedelta
from unittest import TestCase

from bson import Binary, Code, Decimal128, ObjectId
Expand All @@ -22,6 +22,8 @@
BinaryBuilder,
BoolBuilder,
CodeBuilder,
Date32Builder,
Date64Builder,
DatetimeBuilder,
Decimal128Builder,
DocumentBuilder,
Expand Down Expand Up @@ -62,7 +64,7 @@ def setUp(self):
self.data_type = int64()


class TestDate64Builder(TestCase):
class TestDatetimeBuilder(TestCase):
def test_default_unit(self):
# Check default unit
builder = DatetimeBuilder()
Expand Down Expand Up @@ -281,3 +283,40 @@ def test_simple(self):
self.assertEqual(arr.null_count, 1)
self.assertEqual(len(arr), 5)
self.assertEqual(arr.to_pylist(), codes + [None])


class TestDate32Builder(TestCase):
    """Exercise Date32Builder with day offsets relative to the UNIX epoch."""

    def test_simple(self):
        epoch_ordinal = date(1970, 1, 1).toordinal()
        expected_dates = [date(2012, 1, 1), date(2012, 1, 2), date(2014, 4, 5)]
        day_offsets = [d.toordinal() - epoch_ordinal for d in expected_dates]
        first_offset, *remaining_offsets = day_offsets

        # Cover the single-value path, the bulk path and the null path.
        builder = Date32Builder()
        builder.append(first_offset)
        builder.append_values(remaining_offsets)
        builder.append_null()
        result = builder.finish()

        self.assertIsInstance(result, Array)
        self.assertEqual(result.null_count, 1)
        self.assertEqual(len(result), 4)
        self.assertEqual(result.to_pylist(), [*expected_dates, None])


class TestDate64Builder(TestCase):
    """Exercise Date64Builder with millisecond offsets from the UNIX epoch."""

    def test_simple(self):
        def msec_since_epoch(d):
            """Return whole milliseconds between the UNIX epoch and midnight of *d*.

            Floor-dividing by a one-millisecond timedelta yields an exact
            ``int``; the builder stores int64 values, so no float is passed
            and no implicit float-to-integer coercion is relied upon.
            """
            epoch = datetime(1970, 1, 1)
            midnight = datetime.fromordinal(d.toordinal())
            return (midnight - epoch) // timedelta(milliseconds=1)

        values = [date(2012, 1, 1), date(2012, 1, 2), date(2014, 4, 5)]
        builder = Date64Builder()
        # Cover the single-value path, the bulk path and the null path.
        builder.append(msec_since_epoch(values[0]))
        builder.append_values([msec_since_epoch(v) for v in values[1:]])
        builder.append_null()
        arr = builder.finish()

        self.assertIsInstance(arr, Array)
        self.assertEqual(arr.null_count, 1)
        self.assertEqual(len(arr), 4)
        self.assertEqual(arr.to_pylist(), [*values, None])
Loading