Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update to Django 5.2 #199

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
uses: actions/checkout@v4
with:
repository: 'mongodb-forks/django'
ref: 'mongodb-5.0.x'
ref: 'mongodb-5.2.x'
path: 'django_repo'
persist-credentials: false
- name: Install system packages for Django's Python test dependencies
Expand Down
2 changes: 1 addition & 1 deletion django_mongodb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "5.0a0"
__version__ = "5.2a0"

# Check Django compatibility before other imports which may fail if the
# wrong version of Django is installed.
Expand Down
71 changes: 40 additions & 31 deletions django_mongodb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from django.utils.functional import cached_property
from pymongo import ASCENDING, DESCENDING

from .base import Cursor
from .query import MongoQuery, wrap_database_errors


Expand Down Expand Up @@ -91,7 +90,7 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
group[alias] = {"$addToSet": rhs}
replacing_expr = sub_expr.copy()
replacing_expr.set_source_expressions([inner_column])
replacing_expr.set_source_expressions([inner_column, None])
else:
group[alias] = sub_expr.as_mql(self, self.connection)
replacing_expr = inner_column
Expand Down Expand Up @@ -241,11 +240,10 @@ def execute_sql(
self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
):
self.pre_sql_setup()
columns = self.get_columns()
try:
query = self.build_query(
# Avoid $project (columns=None) if unneeded.
columns
self.columns
if self.query.annotations or not self.query.default_cols or self.query.distinct
else None
)
Expand All @@ -259,10 +257,10 @@ def execute_sql(
except StopIteration:
return None # No result
else:
return self._make_result(obj, columns)
return self._make_result(obj, self.columns)
# result_type is MULTI
cursor.batch_size(chunk_size)
result = self.cursor_iter(cursor, chunk_size, columns)
result = self.cursor_iter(cursor, chunk_size, self.columns)
if not chunked_fetch:
# If using non-chunked reads, read data into memory.
return list(result)
Expand Down Expand Up @@ -394,7 +392,8 @@ def build_query(self, columns=None):
query.subqueries = self.subqueries
return query

def get_columns(self):
@cached_property
def columns(self):
"""
Return a tuple of (name, expression) with the columns and annotations
which should be loaded by the query.
Expand All @@ -403,12 +402,6 @@ def get_columns(self):
columns = (
self.get_default_columns(select_mask) if self.query.default_cols else self.query.select
)
# Populate QuerySet.select_related() data.
related_columns = []
if self.query.select_related:
self.get_related_selections(related_columns, select_mask)
if related_columns:
related_columns, _ = zip(*related_columns, strict=True)

annotation_idx = 1

Expand All @@ -427,11 +420,28 @@ def project_field(column):
annotation_idx += 1
return target, column

return (
tuple(map(project_field, columns))
+ tuple(self.annotations.items())
+ tuple(map(project_field, related_columns))
)
selected = []
if self.query.selected is None:
selected = [
*(project_field(col) for col in columns),
*self.annotations.items(),
]
else:
for expression in self.query.selected.values():
# Reference to an annotation.
if isinstance(expression, str):
alias, expression = expression, self.annotations[expression]
# Reference to a column.
elif isinstance(expression, int):
alias, expression = project_field(columns[expression])
selected.append((alias, expression))
# Populate QuerySet.select_related() data.
related_columns = []
if self.query.select_related:
self.get_related_selections(related_columns, select_mask)
if related_columns:
related_columns, _ = zip(*related_columns, strict=True)
return tuple(selected) + tuple(map(project_field, related_columns))

@cached_property
def base_table(self):
Expand Down Expand Up @@ -472,14 +482,17 @@ def get_combinator_queries(self):
query.get_compiler(self.using, self.connection, self.elide_empty)
for query in self.query.combined_queries
]
main_query_columns = self.get_columns()
main_query_fields, _ = zip(*main_query_columns, strict=True)
main_query_fields, _ = zip(*self.columns, strict=True)
for compiler_ in compilers:
try:
# If the columns list is limited, then all combined queries
# must have the same columns list. Set the selects defined on
# the query on all combined queries, if not already set.
if not compiler_.query.values_select and self.query.values_select:
selected = self.query.selected
if selected is not None and compiler_.query.selected is None:
compiler_.query = compiler_.query.clone()
compiler_.query.set_values(selected)
elif not compiler_.query.values_select and self.query.values_select:
compiler_.query = compiler_.query.clone()
compiler_.query.set_values(
(
Expand All @@ -490,7 +503,7 @@ def get_combinator_queries(self):
)
compiler_.pre_sql_setup()
compiler_.column_indices = self.column_indices
columns = compiler_.get_columns()
columns = compiler_.columns
parts.append((compiler_.build_query(columns), compiler_, columns))
except EmptyResultSet:
# Omit the empty queryset with UNION.
Expand Down Expand Up @@ -528,7 +541,7 @@ def get_combinator_queries(self):
combinator_pipeline = inner_pipeline
if not self.query.combinator_all:
ids = defaultdict(dict)
for alias, expr in main_query_columns:
for alias, expr in self.columns:
# Unfold foreign fields.
if isinstance(expr, Col) and expr.alias != self.collection_name:
ids[expr.alias][expr.target.column] = expr.as_mql(self, self.connection)
Expand Down Expand Up @@ -633,10 +646,9 @@ def explain_query(self):
)
# Build the query pipeline.
self.pre_sql_setup()
columns = self.get_columns()
query = self.build_query(
# Avoid $project (columns=None) if unneeded.
columns if self.query.annotations or not self.query.default_cols else None
self.columns if self.query.annotations or not self.query.default_cols else None
)
pipeline = query.get_pipeline()
# Explain the pipeline.
Expand Down Expand Up @@ -692,15 +704,12 @@ def collection_name(self):

class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
def execute_sql(self, result_type=MULTI):
cursor = Cursor()
try:
query = self.build_query()
except EmptyResultSet:
rowcount = 0
return 0
else:
rowcount = query.delete()
cursor.rowcount = rowcount
return cursor
return query.delete()

def check_query(self):
super().check_query()
Expand Down Expand Up @@ -796,7 +805,7 @@ def build_query(self, columns=None):
compiler.pre_sql_setup(with_col_aliases=False)
# Avoid $project (columns=None) if unneeded.
columns = (
compiler.get_columns()
compiler.columns
if self.query.annotations or not self.query.default_cols or self.query.distinct
else None
)
Expand Down
11 changes: 7 additions & 4 deletions django_mongodb/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,9 @@ def order_by(self, compiler, connection):
def query(self, compiler, connection, lookup_name=None):
subquery_compiler = self.get_compiler(connection=connection)
subquery_compiler.pre_sql_setup(with_col_aliases=False)
columns = subquery_compiler.get_columns()
field_name, expr = columns[0]
field_name, expr = subquery_compiler.columns[0]
subquery = subquery_compiler.build_query(
columns
subquery_compiler.columns
if subquery_compiler.query.annotations or not subquery_compiler.query.default_cols
else None
)
Expand Down Expand Up @@ -179,7 +178,11 @@ def ref(self, compiler, connection): # noqa: ARG001
if isinstance(self.source, Col) and self.source.alias != compiler.collection_name
else ""
)
return f"${prefix}{self.refs}"
if hasattr(self, "ordinal"):
refs, _ = compiler.columns[self.ordinal - 1]
else:
refs = self.refs
return f"${prefix}{refs}"


def star(self, compiler, connection): # noqa: ARG001
Expand Down
26 changes: 24 additions & 2 deletions django_mongodb/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
greatest_least_ignores_nulls = True
has_json_object_function = False
has_native_json_field = True
rounds_to_even = True
supports_boolean_expr_in_select_clause = True
supports_collation_on_charfield = False
supports_column_check_constraints = False
Expand Down Expand Up @@ -46,6 +47,12 @@ class DatabaseFeatures(BaseDatabaseFeatures):
uses_savepoints = False

_django_test_expected_failures = {
# $concat only supports strings, not int
"db_functions.text.test_concat.ConcatTests.test_concat_non_str",
# QuerySet.order_by() with annotation transform doesn't work:
# "Expression $mod takes exactly 2 arguments. 1 were passed in"
# https://github.com/django/django/commit/b0ad41198b3e333f57351e3fce5a1fb47f23f376
"aggregation.tests.AggregateTestCase.test_order_by_aggregate_transform",
# 'NulledTransform' object has no attribute 'as_mql'.
"lookup.tests.LookupTests.test_exact_none_transform",
# "Save with update_fields did not affect any rows."
Expand All @@ -56,8 +63,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
# Pattern lookups that use regexMatch don't work on JSONField:
# Unsupported conversion from array to string in $convert
"model_fields.test_jsonfield.TestQuerying.test_icontains",
# MongoDB gives ROUND(365, -1)=360 instead of 370 like other databases.
"db_functions.math.test_round.RoundTests.test_integer_with_negative_precision",
# Truncating in another timezone doesn't work becauase MongoDB converts
# the result back to UTC.
"db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_trunc_func_with_timezone",
Expand All @@ -78,10 +83,14 @@ class DatabaseFeatures(BaseDatabaseFeatures):
# Connection creation doesn't follow the usual Django API.
"backends.tests.ThreadTests.test_pass_connection_between_threads",
"backends.tests.ThreadTests.test_default_connection_thread_local",
"test_utils.tests.DisallowedDatabaseQueriesTests.test_disallowed_thread_database_connection",
# Object of type ObjectId is not JSON serializable.
"auth_tests.test_views.LoginTest.test_login_session_without_hash_session_key",
# GenericRelation.value_to_string() assumes integer pk.
"contenttypes_tests.test_fields.GenericRelationTests.test_value_to_string",
# pymongo.errors.WriteError: Performing an update on the path '_id'
# would modify the immutable field '_id'
"migrations.test_operations.OperationTests.test_composite_pk_operations",
}
# $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3.
_django_test_expected_failures_bitwise = {
Expand Down Expand Up @@ -170,6 +179,7 @@ def django_test_expected_failures(self):
"fixtures.tests.FixtureLoadingTests.test_loading_and_dumping",
"m2m_through_regress.test_multitable.MultiTableTests.test_m2m_prefetch_proxied",
"m2m_through_regress.test_multitable.MultiTableTests.test_m2m_prefetch_reverse_proxied",
"many_to_many.tests.ManyToManyQueryTests.test_prefetch_related_no_queries_optimization_disabled",
"many_to_many.tests.ManyToManyTests.test_add_after_prefetch",
"many_to_many.tests.ManyToManyTests.test_add_then_remove_after_prefetch",
"many_to_many.tests.ManyToManyTests.test_clear_after_prefetch",
Expand All @@ -192,9 +202,13 @@ def django_test_expected_failures(self):
"prefetch_related.tests.Ticket21410Tests",
"queryset_pickle.tests.PickleabilityTestCase.test_pickle_prefetch_related_with_m2m_and_objects_deletion",
"serializers.test_json.JsonSerializerTestCase.test_serialize_prefetch_related_m2m",
"serializers.test_json.JsonSerializerTestCase.test_serialize_prefetch_related_m2m_with_natural_keys",
"serializers.test_jsonl.JsonlSerializerTestCase.test_serialize_prefetch_related_m2m",
"serializers.test_jsonl.JsonlSerializerTestCase.test_serialize_prefetch_related_m2m_with_natural_keys",
"serializers.test_xml.XmlSerializerTestCase.test_serialize_prefetch_related_m2m",
"serializers.test_xml.XmlSerializerTestCase.test_serialize_prefetch_related_m2m_with_natural_keys",
"serializers.test_yaml.YamlSerializerTestCase.test_serialize_prefetch_related_m2m",
"serializers.test_yaml.YamlSerializerTestCase.test_serialize_prefetch_related_m2m_with_natural_keys",
},
"AutoField not supported.": {
"bulk_create.tests.BulkCreateTests.test_bulk_insert_nullable_fields",
Expand Down Expand Up @@ -381,7 +395,11 @@ def django_test_expected_failures(self):
"delete.tests.DeletionTests.test_only_referenced_fields_selected",
"expressions.tests.ExistsTests.test_optimizations",
"lookup.tests.LookupTests.test_in_ignore_none",
"lookup.tests.LookupTests.test_lookup_direct_value_rhs_unwrapped",
"lookup.tests.LookupTests.test_textfield_exact_null",
"many_to_many.tests.ManyToManyQueryTests.test_count_join_optimization_disabled",
"many_to_many.tests.ManyToManyQueryTests.test_exists_join_optimization_disabled",
"many_to_many.tests.ManyToManyTests.test_custom_default_manager_exists_count",
"migrations.test_commands.MigrateTests.test_migrate_syncdb_app_label",
"migrations.test_commands.MigrateTests.test_migrate_syncdb_deferred_sql_executed_with_schemaeditor",
"queries.tests.ExistsSql.test_exists",
Expand Down Expand Up @@ -429,6 +447,7 @@ def django_test_expected_failures(self):
"raw_query.tests.RawQueryTests",
"schema.test_logging.SchemaLoggerTests.test_extra_args",
"schema.tests.SchemaTests.test_remove_constraints_capital_letters",
"test_utils.tests.AllowedDatabaseQueriesTests.test_allowed_database_copy_queries",
"timezones.tests.LegacyDatabaseTests.test_cursor_execute_accepts_naive_datetime",
"timezones.tests.LegacyDatabaseTests.test_cursor_execute_returns_naive_datetime",
"timezones.tests.LegacyDatabaseTests.test_raw_sql",
Expand Down Expand Up @@ -587,6 +606,9 @@ def django_test_expected_failures(self):
"foreign_object.tests.MultiColumnFKTests",
"foreign_object.tests.TestExtraJoinFilterQ",
},
"Tuple lookups are not supported.": {
"foreign_object.test_tuple_lookups.TupleLookupsTests",
},
"Custom lookups are not supported.": {
"custom_lookups.tests.BilateralTransformTests",
"custom_lookups.tests.LookupTests.test_basic_lookup",
Expand Down
7 changes: 4 additions & 3 deletions django_mongodb/lookups.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from django.db import NotSupportedError
from django.db.models.fields.related_lookups import In, MultiColSource, RelatedIn
from django.db.models.expressions import ColPairs
from django.db.models.fields.related_lookups import In, RelatedIn
from django.db.models.lookups import (
BuiltinLookup,
FieldGetDbPrepValueIterableMixin,
Expand Down Expand Up @@ -34,8 +35,8 @@ def field_resolve_expression_parameter(self, compiler, connection, sql, param):


def in_(self, compiler, connection):
if isinstance(self.lhs, MultiColSource):
raise NotImplementedError("MultiColSource is not supported.")
if isinstance(self.lhs, ColPairs):
raise NotImplementedError("ColPairs is not supported.")
db_rhs = getattr(self.rhs, "_db", None)
if db_rhs is not None and db_rhs != connection.alias:
raise ValueError(
Expand Down
22 changes: 0 additions & 22 deletions django_mongodb/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,28 +179,6 @@ def execute_sql_flush(self, tables):
if not options.get("capped", False):
collection.delete_many({})

def prep_lookup_value(self, value, field, lookup):
"""
Perform type-conversion on `value` before using as a filter parameter.
"""
if getattr(field, "rel", None) is not None:
field = field.rel.get_related_field()
field_kind = field.get_internal_type()

if lookup in ("in", "range"):
return [
self._prep_lookup_value(subvalue, field, field_kind, lookup) for subvalue in value
]
return self._prep_lookup_value(value, field, field_kind, lookup)

def _prep_lookup_value(self, value, field, field_kind, lookup):
if value is None:
return None

if field_kind == "DecimalField":
value = self.adapt_decimalfield_value(value, field.max_digits, field.decimal_places)
return value

def explain_query_prefix(self, format=None, **options):
# Validate options.
validated_options = {}
Expand Down
7 changes: 3 additions & 4 deletions django_mongodb/query_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ def process_lhs(node, compiler, connection):
# node is a Func or Expression, possibly with multiple source expressions.
result = []
for expr in node.get_source_expressions():
if expr is None:
continue
try:
result.append(expr.as_mql(compiler, connection))
except FullResultSet:
Expand Down Expand Up @@ -40,10 +42,7 @@ def process_rhs(node, compiler, connection):
value = value[0]
if hasattr(node, "prep_lookup_value_mongo"):
value = node.prep_lookup_value_mongo(value)
# No need to prepare expressions like F() objects.
if hasattr(rhs, "resolve_expression"):
return value
return connection.ops.prep_lookup_value(value, node.lhs.output_field, node.lookup_name)
return value


def regex_match(field, regex_vals, insensitive=False):
Expand Down
Loading
Loading