Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

forbid extra fields in BaseModel #44306

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions airflow/api_fastapi/core_api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,15 @@ class BaseModel(PydanticBaseModel):
"""

model_config = ConfigDict(from_attributes=True, populate_by_name=True)


class StrictBaseModel(BaseModel):
"""
StrictBaseModel is a base Pydantic model for REST API that does not allow any extra fields.

Use this class for models that should not have any extra fields in the payload.

:meta private:
"""

model_config = ConfigDict(from_attributes=True, populate_by_name=True, extra="forbid")
10 changes: 5 additions & 5 deletions airflow/api_fastapi/core_api/datamodels/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,19 @@

from pydantic import Field, field_validator

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
from airflow.utils.log.secrets_masker import redact


class DagScheduleAssetReference(BaseModel):
class DagScheduleAssetReference(StrictBaseModel):
"""DAG schedule reference serializer for assets."""

dag_id: str
created_at: datetime
updated_at: datetime


class TaskOutletAssetReference(BaseModel):
class TaskOutletAssetReference(StrictBaseModel):
"""Task outlet reference serializer for assets."""

dag_id: str
Expand Down Expand Up @@ -84,7 +84,7 @@ class AssetAliasCollectionResponse(BaseModel):
total_entries: int


class DagRunAssetReference(BaseModel):
class DagRunAssetReference(StrictBaseModel):
"""DAGRun serializer for asset responses."""

run_id: str
Expand Down Expand Up @@ -141,7 +141,7 @@ class QueuedEventCollectionResponse(BaseModel):
total_entries: int


class CreateAssetEventsBody(BaseModel):
class CreateAssetEventsBody(StrictBaseModel):
"""Create asset events request."""

asset_id: int
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/datamodels/backfills.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@

from datetime import datetime

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
from airflow.models.backfill import ReprocessBehavior


class BackfillPostBody(BaseModel):
class BackfillPostBody(StrictBaseModel):
"""Object used for create backfill request."""

dag_id: str
Expand Down
6 changes: 3 additions & 3 deletions airflow/api_fastapi/core_api/datamodels/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from pydantic import Discriminator, Field, Tag

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel

# Common Bulk Data Models
T = TypeVar("T")
Expand Down Expand Up @@ -57,7 +57,7 @@ class BulkActionNotOnExistence(enum.Enum):
SKIP = "skip"


class BulkBaseAction(BaseModel, Generic[T]):
class BulkBaseAction(StrictBaseModel, Generic[T]):
"""Base class for bulk actions."""

action: BulkAction = Field(..., description="The action to be performed on the entities.")
Expand Down Expand Up @@ -88,7 +88,7 @@ def _action_discriminator(action: Any) -> str:
return BulkAction(action["action"]).value


class BulkBody(BaseModel, Generic[T]):
class BulkBody(StrictBaseModel, Generic[T]):
"""Serializer for bulk entity operations."""

actions: list[
Expand Down
8 changes: 4 additions & 4 deletions airflow/api_fastapi/core_api/datamodels/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
# under the License.
from __future__ import annotations

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import StrictBaseModel


class ConfigOption(BaseModel):
class ConfigOption(StrictBaseModel):
"""Config option."""

key: str
Expand All @@ -32,7 +32,7 @@ def text_format(self):
return f"{self.key} = {self.value}"


class ConfigSection(BaseModel):
class ConfigSection(StrictBaseModel):
"""Config Section Schema."""

name: str
Expand All @@ -53,7 +53,7 @@ def text_format(self):
return f"[{self.name}]\n" + "\n".join(option.text_format for option in self.options) + "\n"


class Config(BaseModel):
class Config(StrictBaseModel):
"""List of config sections with their options."""

sections: list[ConfigSection]
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/datamodels/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pydantic import Field, field_validator
from pydantic_core.core_schema import ValidationInfo

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
from airflow.utils.log.secrets_masker import redact


Expand Down Expand Up @@ -76,7 +76,7 @@ class ConnectionTestResponse(BaseModel):


# Request Models
class ConnectionBody(BaseModel):
class ConnectionBody(StrictBaseModel):
"""Connection Serializer for requests body."""

connection_id: str = Field(serialization_alias="conn_id", max_length=200, pattern=r"^[\w.-]+$")
Expand Down
10 changes: 5 additions & 5 deletions airflow/api_fastapi/core_api/datamodels/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from pydantic import AwareDatetime, Field, NonNegativeInt, computed_field, model_validator

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
from airflow.models import DagRun
from airflow.utils import timezone
from airflow.utils.state import DagRunState
Expand All @@ -37,14 +37,14 @@ class DAGRunPatchStates(str, Enum):
FAILED = DagRunState.FAILED


class DAGRunPatchBody(BaseModel):
class DAGRunPatchBody(StrictBaseModel):
"""DAG Run Serializer for PATCH requests."""

state: DAGRunPatchStates | None = None
note: str | None = Field(None, max_length=1000)


class DAGRunClearBody(BaseModel):
class DAGRunClearBody(StrictBaseModel):
"""DAG Run serializer for clear endpoint body."""

dry_run: bool = True
Expand Down Expand Up @@ -78,7 +78,7 @@ class DAGRunCollectionResponse(BaseModel):
total_entries: int


class TriggerDAGRunPostBody(BaseModel):
class TriggerDAGRunPostBody(StrictBaseModel):
"""Trigger DAG Run Serializer for POST body."""

dag_run_id: str | None = None
Expand Down Expand Up @@ -109,7 +109,7 @@ def logical_date(self) -> datetime:
return timezone.utcnow()


class DAGRunsBatchBody(BaseModel):
class DAGRunsBatchBody(StrictBaseModel):
"""List DAG Runs body for batch endpoint."""

order_by: str | None = None
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/datamodels/dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
field_validator,
)

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
from airflow.api_fastapi.core_api.datamodels.dag_tags import DagTagResponse
from airflow.configuration import conf

Expand Down Expand Up @@ -92,7 +92,7 @@ def file_token(self) -> str:
return serializer.dumps(self.fileloc)


class DAGPatchBody(BaseModel):
class DAGPatchBody(StrictBaseModel):
"""Dag Serializer for updatable bodies."""

is_paused: bool
Expand Down
6 changes: 3 additions & 3 deletions airflow/api_fastapi/core_api/datamodels/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from pydantic import BeforeValidator, ConfigDict, Field

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel


def _call_function(function: Callable[[], int]) -> int:
Expand Down Expand Up @@ -60,7 +60,7 @@ class PoolCollectionResponse(BaseModel):
total_entries: int


class PoolPatchBody(BaseModel):
class PoolPatchBody(StrictBaseModel):
"""Pool serializer for patch bodies."""

model_config = ConfigDict(populate_by_name=True, from_attributes=True)
Expand All @@ -71,7 +71,7 @@ class PoolPatchBody(BaseModel):
include_deferred: bool | None = None


class PoolBody(BasePool):
class PoolBody(BasePool, StrictBaseModel):
"""Pool serializer for post bodies."""

pool: str = Field(alias="name", max_length=256)
Expand Down
8 changes: 4 additions & 4 deletions airflow/api_fastapi/core_api/datamodels/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
model_validator,
)

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
from airflow.api_fastapi.core_api.datamodels.job import JobResponse
from airflow.api_fastapi.core_api.datamodels.trigger import TriggerResponse
from airflow.utils.state import TaskInstanceState
Expand Down Expand Up @@ -97,7 +97,7 @@ class TaskDependencyCollectionResponse(BaseModel):
dependencies: list[TaskDependencyResponse]


class TaskInstancesBatchBody(BaseModel):
class TaskInstancesBatchBody(StrictBaseModel):
"""Task Instance body for get batch."""

dag_ids: list[str] | None = None
Expand Down Expand Up @@ -159,7 +159,7 @@ class TaskInstanceHistoryCollectionResponse(BaseModel):
total_entries: int


class ClearTaskInstancesBody(BaseModel):
class ClearTaskInstancesBody(StrictBaseModel):
"""Request body for Clear Task Instances endpoint."""

dry_run: bool = True
Expand Down Expand Up @@ -195,7 +195,7 @@ def validate_model(cls, data: Any) -> Any:
return data


class PatchTaskInstanceBody(BaseModel):
class PatchTaskInstanceBody(StrictBaseModel):
"""Request body for Clear Task Instances endpoint."""

new_state: TaskInstanceState | None = None
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/datamodels/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from pydantic import ConfigDict, Field, model_validator

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
from airflow.models.base import ID_LEN
from airflow.typing_compat import Self
from airflow.utils.log.secrets_masker import redact
Expand Down Expand Up @@ -52,7 +52,7 @@ def redact_val(self) -> Self:
return self


class VariableBody(BaseModel):
class VariableBody(StrictBaseModel):
"""Variable serializer for bodies."""

key: str = Field(max_length=ID_LEN)
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/datamodels/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def value_to_string(cls, v):
return str(v) if v is not None else None


class XComCollection(BaseModel):
"""List of XCom items."""
class XComCollectionResponse(BaseModel):
pierrejeambrun marked this conversation as resolved.
Show resolved Hide resolved
"""XCom Collection serializer for responses."""

xcom_entries: list[XComResponse]
total_entries: int
Loading