Skip to content

Commit

Permalink
chore: Remove unnecessary autoflush from tagging and key/value workfl…
Browse files Browse the repository at this point in the history
…ows (apache#26009)
  • Loading branch information
john-bodley authored and josedev-union committed Jan 22, 2024
1 parent 7fa145a commit f9a8817
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 68 deletions.
8 changes: 1 addition & 7 deletions superset/key_value/commands/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,7 @@ def validate(self) -> None:

def delete(self) -> bool:
filter_ = get_filter(self.resource, self.key)
entry = (
db.session.query(KeyValueEntry)
.filter_by(**filter_)
.autoflush(False)
.first()
)
if entry:
if entry := db.session.query(KeyValueEntry).filter_by(**filter_).first():
db.session.delete(entry)
db.session.commit()
return True
Expand Down
7 changes: 1 addition & 6 deletions superset/key_value/commands/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,7 @@ def validate(self) -> None:

def get(self) -> Optional[Any]:
filter_ = get_filter(self.resource, self.key)
entry = (
db.session.query(KeyValueEntry)
.filter_by(**filter_)
.autoflush(False)
.first()
)
entry = db.session.query(KeyValueEntry).filter_by(**filter_).first()
if entry and (entry.expires_on is None or entry.expires_on > datetime.now()):
return self.codec.decode(entry.value)
return None
5 changes: 1 addition & 4 deletions superset/key_value/commands/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,7 @@ def validate(self) -> None:
def update(self) -> Optional[Key]:
filter_ = get_filter(self.resource, self.key)
entry: KeyValueEntry = (
db.session.query(KeyValueEntry)
.filter_by(**filter_)
.autoflush(False)
.first()
db.session.query(KeyValueEntry).filter_by(**filter_).first()
)
if entry:
entry.value = self.codec.encode(self.value)
Expand Down
5 changes: 1 addition & 4 deletions superset/key_value/commands/upsert.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,7 @@ def validate(self) -> None:
def upsert(self) -> Key:
filter_ = get_filter(self.resource, self.key)
entry: KeyValueEntry = (
db.session.query(KeyValueEntry)
.filter_by(**filter_)
.autoflush(False)
.first()
db.session.query(KeyValueEntry).filter_by(**filter_).first()
)
if entry:
entry.value = self.codec.encode(self.value)
Expand Down
38 changes: 10 additions & 28 deletions superset/tags/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from typing import TYPE_CHECKING

from flask_appbuilder import Model
from sqlalchemy import Column, Enum, ForeignKey, Integer, String, Table, Text
from sqlalchemy import Column, Enum, ForeignKey, Integer, orm, String, Table, Text
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import relationship, Session, sessionmaker
from sqlalchemy.orm import relationship, sessionmaker
from sqlalchemy.orm.mapper import Mapper

from superset import security_manager
Expand All @@ -35,7 +35,7 @@
from superset.models.slice import Slice
from superset.models.sql_lab import Query

Session = sessionmaker(autoflush=False)
Session = sessionmaker()

user_favorite_tag_table = Table(
"user_favorite_tag",
Expand Down Expand Up @@ -111,7 +111,7 @@ class TaggedObject(Model, AuditMixinNullable):
tag = relationship("Tag", back_populates="objects", overlaps="tags")


def get_tag(name: str, session: Session, type_: TagType) -> Tag:
def get_tag(name: str, session: orm.Session, type_: TagType) -> Tag:
tag_name = name.strip()
tag = session.query(Tag).filter_by(name=tag_name, type=type_).one_or_none()
if tag is None:
Expand Down Expand Up @@ -148,7 +148,7 @@ def get_owners_ids(
@classmethod
def _add_owners(
cls,
session: Session,
session: orm.Session,
target: Dashboard | FavStar | Slice | Query | SqlaTable,
) -> None:
for owner_id in cls.get_owners_ids(target):
Expand All @@ -166,9 +166,7 @@ def after_insert(
connection: Connection,
target: Dashboard | FavStar | Slice | Query | SqlaTable,
) -> None:
session = Session(bind=connection)

try:
with Session(bind=connection) as session:
# add `owner:` tags
cls._add_owners(session, target)

Expand All @@ -179,8 +177,6 @@ def after_insert(
)
session.add(tagged_object)
session.commit()
finally:
session.close()

@classmethod
def after_update(
Expand All @@ -189,9 +185,7 @@ def after_update(
connection: Connection,
target: Dashboard | FavStar | Slice | Query | SqlaTable,
) -> None:
session = Session(bind=connection)

try:
with Session(bind=connection) as session:
# delete current `owner:` tags
query = (
session.query(TaggedObject.id)
Expand All @@ -210,8 +204,6 @@ def after_update(
# add `owner:` tags
cls._add_owners(session, target)
session.commit()
finally:
session.close()

@classmethod
def after_delete(
Expand All @@ -220,18 +212,14 @@ def after_delete(
connection: Connection,
target: Dashboard | FavStar | Slice | Query | SqlaTable,
) -> None:
session = Session(bind=connection)

try:
with Session(bind=connection) as session:
# delete row from `tagged_objects`
session.query(TaggedObject).filter(
TaggedObject.object_type == cls.object_type,
TaggedObject.object_id == target.id,
).delete()

session.commit()
finally:
session.close()


class ChartUpdater(ObjectUpdater):
Expand Down Expand Up @@ -271,8 +259,7 @@ class FavStarUpdater:
def after_insert(
cls, _mapper: Mapper, connection: Connection, target: FavStar
) -> None:
session = Session(bind=connection)
try:
with Session(bind=connection) as session:
name = f"favorited_by:{target.user_id}"
tag = get_tag(name, session, TagType.favorited_by)
tagged_object = TaggedObject(
Expand All @@ -282,15 +269,12 @@ def after_insert(
)
session.add(tagged_object)
session.commit()
finally:
session.close()

@classmethod
def after_delete(
cls, _mapper: Mapper, connection: Connection, target: FavStar
) -> None:
session = Session(bind=connection)
try:
with Session(bind=connection) as session:
name = f"favorited_by:{target.user_id}"
query = (
session.query(TaggedObject.id)
Expand All @@ -307,5 +291,3 @@ def after_delete(
)

session.commit()
finally:
session.close()
12 changes: 3 additions & 9 deletions tests/integration_tests/key_value/commands/create_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ def test_create_id_entry(app_context: AppContext, admin: User) -> None:
value=JSON_VALUE,
codec=JSON_CODEC,
).run()
entry = (
db.session.query(KeyValueEntry).filter_by(id=key.id).autoflush(False).one()
)
entry = db.session.query(KeyValueEntry).filter_by(id=key.id).one()
assert json.loads(entry.value) == JSON_VALUE
assert entry.created_by_fk == admin.id
db.session.delete(entry)
Expand All @@ -63,9 +61,7 @@ def test_create_uuid_entry(app_context: AppContext, admin: User) -> None:
key = CreateKeyValueCommand(
resource=RESOURCE, value=JSON_VALUE, codec=JSON_CODEC
).run()
entry = (
db.session.query(KeyValueEntry).filter_by(uuid=key.uuid).autoflush(False).one()
)
entry = db.session.query(KeyValueEntry).filter_by(uuid=key.uuid).one()
assert json.loads(entry.value) == JSON_VALUE
assert entry.created_by_fk == admin.id
db.session.delete(entry)
Expand Down Expand Up @@ -93,9 +89,7 @@ def test_create_pickle_entry(app_context: AppContext, admin: User) -> None:
value=PICKLE_VALUE,
codec=PICKLE_CODEC,
).run()
entry = (
db.session.query(KeyValueEntry).filter_by(id=key.id).autoflush(False).one()
)
entry = db.session.query(KeyValueEntry).filter_by(id=key.id).one()
assert type(pickle.loads(entry.value)) == type(PICKLE_VALUE)
assert entry.created_by_fk == admin.id
db.session.delete(entry)
Expand Down
6 changes: 2 additions & 4 deletions tests/integration_tests/key_value/commands/update_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_update_id_entry(
).run()
assert key is not None
assert key.id == ID_KEY
entry = db.session.query(KeyValueEntry).filter_by(id=ID_KEY).autoflush(False).one()
entry = db.session.query(KeyValueEntry).filter_by(id=ID_KEY).one()
assert json.loads(entry.value) == NEW_VALUE
assert entry.changed_by_fk == admin.id

Expand All @@ -79,9 +79,7 @@ def test_update_uuid_entry(
).run()
assert key is not None
assert key.uuid == UUID_KEY
entry = (
db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).autoflush(False).one()
)
entry = db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).one()
assert json.loads(entry.value) == NEW_VALUE
assert entry.changed_by_fk == admin.id

Expand Down
8 changes: 2 additions & 6 deletions tests/integration_tests/key_value/commands/upsert_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@ def test_upsert_id_entry(
).run()
assert key is not None
assert key.id == ID_KEY
entry = (
db.session.query(KeyValueEntry).filter_by(id=int(ID_KEY)).autoflush(False).one()
)
entry = db.session.query(KeyValueEntry).filter_by(id=int(ID_KEY)).one()
assert json.loads(entry.value) == NEW_VALUE
assert entry.changed_by_fk == admin.id

Expand All @@ -81,9 +79,7 @@ def test_upsert_uuid_entry(
).run()
assert key is not None
assert key.uuid == UUID_KEY
entry = (
db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).autoflush(False).one()
)
entry = db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).one()
assert json.loads(entry.value) == NEW_VALUE
assert entry.changed_by_fk == admin.id

Expand Down

0 comments on commit f9a8817

Please sign in to comment.