Skip to content

Commit

Permalink
Refactor duplicate pool insert handling
Browse files Browse the repository at this point in the history
- handle exception from db level instead of application level
  • Loading branch information
jason810496 committed Nov 20, 2024
1 parent f248fe2 commit 1ec682f
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 41 deletions.
15 changes: 1 addition & 14 deletions airflow/api_fastapi/core_api/datamodels/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from typing import Annotated, Callable

from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, field_validator
from pydantic import BaseModel, BeforeValidator, ConfigDict, Field


def _call_function(function: Callable[[], int]) -> int:
Expand Down Expand Up @@ -81,16 +81,3 @@ class PoolPostBulkBody(BaseModel):
"""Pools serializer for post bodies."""

pools: list[PoolPostBody]

@field_validator("pools", mode="after")
def validate_pools(cls, v: list[PoolPostBody]) -> list[PoolPostBody]:
pool_set = set()
duplicates = []
for pool in v:
if pool.pool in pool_set:
duplicates.append(pool.pool)
else:
pool_set.add(pool.pool)
if duplicates:
raise ValueError(f"Pool name should be unique, found duplicates: {duplicates}")
return v
24 changes: 13 additions & 11 deletions airflow/api_fastapi/core_api/routes/public/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from fastapi.exceptions import RequestValidationError
from pydantic import ValidationError
from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from airflow.api_fastapi.common.db.common import get_session, paginated_select
Expand Down Expand Up @@ -173,12 +174,14 @@ def post_pool(
session: Annotated[Session, Depends(get_session)],
) -> PoolResponse:
"""Create a Pool."""
pool = session.scalar(select(Pool).where(Pool.pool == body.pool))
if pool is not None:
raise HTTPException(status.HTTP_409_CONFLICT, f"Pool with name: `{body.pool}` already exists")
pool = Pool(**body.model_dump())

session.add(pool)
try:
session.commit()
except IntegrityError:
session.rollback()
raise HTTPException(status.HTTP_409_CONFLICT, f"Pool with name: `{body.pool}` already exists")

return PoolResponse.model_validate(pool, from_attributes=True)

Expand All @@ -197,18 +200,17 @@ def post_pools(
session: Annotated[Session, Depends(get_session)],
) -> PoolCollectionResponse:
"""Create multiple pools."""
# Check if any of the pools already exists
pools_names = [pool.pool for pool in body.pools]
existing_pools = session.scalars(select(Pool.pool).where(Pool.pool.in_(pools_names))).all()
if existing_pools:
pools = [Pool(**body.model_dump()) for body in body.pools]
session.add_all(pools)
try:
session.commit()
except IntegrityError as e:
session.rollback()
raise HTTPException(
status.HTTP_409_CONFLICT,
detail=f"Pools with names: `{existing_pools}` already exist",
detail=f"One or more pools already exists. Error: {e}",
)

pools = [Pool(**body.model_dump()) for body in body.pools]
session.add_all(pools)

return PoolCollectionResponse(
pools=[PoolResponse.model_validate(pool, from_attributes=True) for pool in pools],
total_entries=len(pools),
Expand Down
19 changes: 3 additions & 16 deletions tests/api_fastapi/core_api/routes/public/test_pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def _create_pool_in_test(
response = test_client.post("/public/pools/", json=body)
assert response.status_code == expected_status_code

body = response.json()
assert response.json() == expected_response
if check_count:
assert session.query(Pool).count() == n_pools + 1
Expand Down Expand Up @@ -464,16 +463,8 @@ class TestPostPools(TestPoolsEndpoint):
{"name": "my_pool", "slots": 12},
]
},
422,
{
"detail": [
{
"loc": ["body", "pools"],
"msg": "Value error, Pool name should be unique, found duplicates: ['my_pool']",
"type": "value_error",
}
]
},
409,
{},
),
],
)
Expand All @@ -485,11 +476,7 @@ def test_post_pools(self, test_client, session, body, expected_status_code, expe
response_json = response.json()
if expected_status_code == 201:
assert response_json == expected_response
elif expected_status_code == 422:
assert response_json["detail"][0]["loc"] == expected_response["detail"][0]["loc"]
assert response_json["detail"][0]["msg"] == expected_response["detail"][0]["msg"]
assert response_json["detail"][0]["type"] == expected_response["detail"][0]["type"]
if expected_status_code == 201:
assert session.query(Pool).count() == n_pools + 2
else:
# since different database backend return different error messages, we just check the status code
assert session.query(Pool).count() == n_pools

0 comments on commit 1ec682f

Please sign in to comment.