Skip to content

Commit

Permalink
remove unneeded bits
Browse files Browse the repository at this point in the history
  • Loading branch information
timgraham committed Jan 4, 2025
1 parent 3e3ee82 commit 14af7cb
Showing 1 changed file with 4 additions and 23 deletions.
27 changes: 4 additions & 23 deletions django_mongodb/fields/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class ArrayField(CheckFieldDefaultMixin, Field):

def __init__(self, base_field, size=None, **kwargs):
self.base_field = base_field
self.db_collation = getattr(self.base_field, "db_collation", None)
self.size = size
if self.size:
self.default_validators = [
Expand Down Expand Up @@ -72,7 +71,6 @@ def check(self, **kwargs):
)
)
else:
# Remove the field name checks as they are not needed here.
base_checks = self.base_field.check()
if base_checks:
error_messages = "\n ".join(
Expand Down Expand Up @@ -114,17 +112,10 @@ def description(self):
def db_type(self, connection):
return "array"

def db_parameters(self, connection):
db_params = super().db_parameters(connection)
db_params["collation"] = self.db_collation
return db_params

def get_placeholder(self, value, compiler, connection):
return f"%s::{self.db_type(connection)}"

def get_db_prep_value(self, value, connection, prepared=False):
if isinstance(value, list | tuple):
# Workaround for https://code.djangoproject.com/ticket/35982
# (fixed in Django 5.2).
if isinstance(self.base_field, DecimalField):
return [self.base_field.get_db_prep_save(i, connection) for i in value]
return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value]
Expand All @@ -144,7 +135,7 @@ def deconstruct(self):

def to_python(self, value):
if isinstance(value, str):
# Assume we're deserializing
# Assume value is being deserialized,
vals = json.loads(value)
value = [self.base_field.to_python(val) for val in vals]
return value
Expand Down Expand Up @@ -236,9 +227,7 @@ def as_mql(self, compiler, connection):

class ArrayRHSMixin:
def __init__(self, lhs, rhs):
# Don't wrap arrays that contains only None values, psycopg doesn't
# allow this.
if isinstance(rhs, tuple | list) and any(self._rhs_not_none_values(rhs)):
if isinstance(rhs, tuple | list):
expressions = []
for value in rhs:
if not hasattr(value, "resolve_expression"):
Expand All @@ -248,13 +237,6 @@ def __init__(self, lhs, rhs):
rhs = Array(*expressions)
super().__init__(lhs, rhs)

def _rhs_not_none_values(self, rhs):
for x in rhs:
if isinstance(x, list | tuple):
yield from self._rhs_not_none_values(x)
elif x is not None:
yield True


@ArrayField.register_lookup
class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
Expand Down Expand Up @@ -298,8 +280,7 @@ def get_prep_lookup(self):
values = super().get_prep_lookup()
if hasattr(values, "resolve_expression"):
return values
# In.process_rhs() expects values to be hashable, so convert lists
# to tuples.
# process_rhs() expects hashable values, so convert lists to tuples.
prepared_values = []
for value in values:
if hasattr(value, "resolve_expression"):
Expand Down

0 comments on commit 14af7cb

Please sign in to comment.