Skip to content

Commit

Permalink
Add support for mapped columns
Browse files Browse the repository at this point in the history
  • Loading branch information
atugushev committed Apr 8, 2024
1 parent 02045c4 commit 801183a
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions sqlalchemy_utils/i18n.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import MappedColumn
from sqlalchemy.sql.expression import ColumnElement

from .exceptions import ImproperlyConfigured
Expand Down Expand Up @@ -35,7 +36,7 @@ def cast_locale(obj, locale, attr):
"""
if callable(locale):
try:
locale = locale(obj, attr.key)
locale = locale(obj, _get_attr_key(attr))
except TypeError:
try:
locale = locale(obj)
Expand Down Expand Up @@ -83,26 +84,26 @@ def getter_factory(self, attr):
def getter(obj):
current_locale = cast_locale(obj, self.current_locale, attr)
try:
return getattr(obj, attr.key)[current_locale]
return getattr(obj, _get_attr_key(attr))[current_locale]
except (TypeError, KeyError):
default_locale = cast_locale(obj, self.default_locale, attr)
try:
return getattr(obj, attr.key)[default_locale]
return getattr(obj, _get_attr_key(attr))[default_locale]
except (TypeError, KeyError):
return self.default_value
return getter

def setter_factory(self, attr):
def setter(obj, value):
if getattr(obj, attr.key) is None:
setattr(obj, attr.key, {})
if getattr(obj, _get_attr_key(attr)) is None:
setattr(obj, _get_attr_key(attr), {})
locale = cast_locale(obj, self.current_locale, attr)
getattr(obj, attr.key)[locale] = value
getattr(obj, _get_attr_key(attr))[locale] = value
return setter

def expr_factory(self, attr):
def expr(cls):
cls_attr = getattr(cls, attr.key)
cls_attr = getattr(cls, _get_attr_key(attr))
current_locale = cast_locale_expr(cls, self.current_locale, attr)
default_locale = cast_locale_expr(cls, self.default_locale, attr)
return sa.func.coalesce(
Expand All @@ -117,3 +118,9 @@ def __call__(self, attr):
fset=self.setter_factory(attr),
expr=self.expr_factory(attr)
)


def _get_attr_key(attr):
if isinstance(attr, MappedColumn):
return attr.column.key
return attr.key

0 comments on commit 801183a

Please sign in to comment.