diff --git a/airflow/api_fastapi/core_api/datamodels/pools.py b/airflow/api_fastapi/core_api/datamodels/pools.py index 1cc3838b23912..137392094cb5d 100644 --- a/airflow/api_fastapi/core_api/datamodels/pools.py +++ b/airflow/api_fastapi/core_api/datamodels/pools.py @@ -19,7 +19,7 @@ from typing import Annotated, Callable -from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, field_validator +from pydantic import BaseModel, BeforeValidator, ConfigDict, Field def _call_function(function: Callable[[], int]) -> int: @@ -81,16 +81,3 @@ class PoolPostBulkBody(BaseModel): """Pools serializer for post bodies.""" pools: list[PoolPostBody] - - @field_validator("pools", mode="after") - def validate_pools(cls, v: list[PoolPostBody]) -> list[PoolPostBody]: - pool_set = set() - duplicates = [] - for pool in v: - if pool.pool in pool_set: - duplicates.append(pool.pool) - else: - pool_set.add(pool.pool) - if duplicates: - raise ValueError(f"Pool name should be unique, found duplicates: {duplicates}") - return v diff --git a/airflow/api_fastapi/core_api/routes/public/pools.py b/airflow/api_fastapi/core_api/routes/public/pools.py index b485fce47cf72..2a39cf1a82597 100644 --- a/airflow/api_fastapi/core_api/routes/public/pools.py +++ b/airflow/api_fastapi/core_api/routes/public/pools.py @@ -22,6 +22,7 @@ from fastapi.exceptions import RequestValidationError from pydantic import ValidationError from sqlalchemy import delete, select +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from airflow.api_fastapi.common.db.common import get_session, paginated_select @@ -173,12 +174,14 @@ def post_pool( session: Annotated[Session, Depends(get_session)], ) -> PoolResponse: """Create a Pool.""" - pool = session.scalar(select(Pool).where(Pool.pool == body.pool)) - if pool is not None: - raise HTTPException(status.HTTP_409_CONFLICT, f"Pool with name: `{body.pool}` already exists") pool = Pool(**body.model_dump()) session.add(pool) + try: + session.commit() + except IntegrityError: + session.rollback() + raise HTTPException(status.HTTP_409_CONFLICT, f"Pool with name: `{body.pool}` already exists") return PoolResponse.model_validate(pool, from_attributes=True) @@ -197,18 +200,17 @@ def post_pools( session: Annotated[Session, Depends(get_session)], ) -> PoolCollectionResponse: """Create multiple pools.""" - # Check if any of the pools already exists - pools_names = [pool.pool for pool in body.pools] - existing_pools = session.scalars(select(Pool.pool).where(Pool.pool.in_(pools_names))).all() - if existing_pools: + pools = [Pool(**body.model_dump()) for body in body.pools] + session.add_all(pools) + try: + session.commit() + except IntegrityError as e: + session.rollback() raise HTTPException( status.HTTP_409_CONFLICT, - detail=f"Pools with names: `{existing_pools}` already exist", + detail=f"One or more pools already exists. Error: {e}", ) - pools = [Pool(**body.model_dump()) for body in body.pools] - session.add_all(pools) - return PoolCollectionResponse( pools=[PoolResponse.model_validate(pool, from_attributes=True) for pool in pools], total_entries=len(pools), diff --git a/tests/api_fastapi/core_api/routes/public/test_pools.py b/tests/api_fastapi/core_api/routes/public/test_pools.py index eb56d50b7e33a..c74421a900f41 100644 --- a/tests/api_fastapi/core_api/routes/public/test_pools.py +++ b/tests/api_fastapi/core_api/routes/public/test_pools.py @@ -71,7 +71,6 @@ def _create_pool_in_test( response = test_client.post("/public/pools/", json=body) assert response.status_code == expected_status_code - body = response.json() assert response.json() == expected_response if check_count: assert session.query(Pool).count() == n_pools + 1 @@ -464,16 +463,8 @@ class TestPostPools(TestPoolsEndpoint): {"name": "my_pool", "slots": 12}, ] }, - 422, - { - "detail": [ - { - "loc": ["body", "pools"], - "msg": "Value error, Pool name should be unique, found duplicates: ['my_pool']", - "type": "value_error", - } - ] - }, + 409, + {}, ), ], ) @@ -485,11 +476,7 @@ def test_post_pools(self, test_client, session, body, expected_status_code, expe response_json = response.json() if expected_status_code == 201: assert response_json == expected_response - elif expected_status_code == 422: - assert response_json["detail"][0]["loc"] == expected_response["detail"][0]["loc"] - assert response_json["detail"][0]["msg"] == expected_response["detail"][0]["msg"] - assert response_json["detail"][0]["type"] == expected_response["detail"][0]["type"] - if expected_status_code == 201: assert session.query(Pool).count() == n_pools + 2 else: + # since different database backend return different error messages, we just check the status code assert session.query(Pool).count() == n_pools