feat: support Pydantic V2 #281

Merged (16 commits) on Feb 23, 2024
1 change: 1 addition & 0 deletions app/app.py
@@ -279,6 +279,7 @@ async def logging_middleware(
port = request.client.port
client_address = f"{ip_address}:{port}"
else:
ip_address = "0.0.0.0" # In case of a test (see https://github.com/encode/starlette/pull/2377)
client_address = "unknown"

settings: Settings = app.dependency_overrides.get(get_settings, get_settings)()
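For context, a minimal sketch of the guard this hunk extends: under Starlette's TestClient, `request.client` can be `None` (see the linked starlette PR), so the middleware falls back to placeholder values before logging. The function name and the response handling below are illustrative assumptions, not Hyperion's actual middleware.

```python
from starlette.requests import Request


async def logging_middleware_sketch(request: Request, call_next):
    # request.client may be None, e.g. when the app is driven by Starlette's TestClient
    if request.client is not None:
        ip_address = request.client.host
        port = request.client.port
        client_address = f"{ip_address}:{port}"
    else:
        ip_address = "0.0.0.0"  # placeholder so downstream formatting keeps working
        client_address = "unknown"

    response = await call_next(request)
    # ... log client_address and the response status here ...
    return response
```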
120 changes: 66 additions & 54 deletions app/core/config.py
@@ -1,8 +1,10 @@
from functools import cached_property
from typing import Any

from jose import jwk
from jose.exceptions import JWKError
from pydantic import BaseSettings, root_validator
from pydantic import computed_field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

from app.utils.auth import providers

@@ -20,13 +22,22 @@ class Settings(BaseSettings):
To access these settings, the `get_settings` dependency should be used.
"""

# By default, the settings are loaded from the `.env` file but this behaviour can be overridden by using
# `_env_file` parameter during instantiation
# Ex: `Settings(_env_file=".env.dev")`
# Without this property, the @cached_property decorator raises "TypeError: cannot pickle '_thread.RLock' object"
# See https://github.com/samuelcolvin/pydantic/issues/1241
model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="ignore"
)

# NOTE: Variables without a value should not be configured in this class, but added to the dotenv .env file

#####################################
# SMTP configuration using starttls #
#####################################

SMTP_ACTIVE: bool
SMTP_ACTIVE: bool = False
SMTP_PORT: int
SMTP_SERVER: str
SMTP_USERNAME: str
@@ -41,10 +52,10 @@ class Settings(BaseSettings):
# If the following parameters are not set, logging won't use the Matrix handler
# MATRIX_SERVER_BASE_URL is optional, the official Matrix server will be used if not configured
# Advanced note: the username and password will be used to request an access token. A custom Matrix client `Hyperion` is used to make all requests
MATRIX_SERVER_BASE_URL: str | None
MATRIX_TOKEN: str | None
MATRIX_LOG_ERROR_ROOM_ID: str | None
MATRIX_LOG_AMAP_ROOM_ID: str | None
MATRIX_SERVER_BASE_URL: str | None = None
MATRIX_TOKEN: str | None = None
MATRIX_LOG_ERROR_ROOM_ID: str | None = None
MATRIX_LOG_AMAP_ROOM_ID: str | None = None

########################
# Redis configuration #
@@ -54,7 +65,7 @@ class Settings(BaseSettings):
# If you want to use a custom configuration, a password and specific binds should be used to avoid security issues
REDIS_HOST: str
REDIS_PORT: int
REDIS_PASSWORD: str | None
REDIS_PASSWORD: str | None = None
REDIS_LIMIT: int
REDIS_WINDOW: int

@@ -72,11 +83,11 @@ class Settings(BaseSettings):
SQLITE_DB: str | None = (
None # If set, the application uses a SQLite database instead of PostgreSQL, for testing or development purposes (should be avoided if possible)
)
POSTGRES_HOST: str
POSTGRES_USER: str
POSTGRES_PASSWORD: str
POSTGRES_DB: str
DATABASE_DEBUG: bool # If True, the database will log all queries
POSTGRES_HOST: str = ""
POSTGRES_USER: str = ""
POSTGRES_PASSWORD: str = ""
POSTGRES_DB: str = ""
DATABASE_DEBUG: bool = False # If True, the database will log all queries

#####################
# Hyperion settings #
@@ -105,11 +116,11 @@ class Settings(BaseSettings):
# Tokens validity #
###################

USER_ACTIVATION_TOKEN_EXPIRE_HOURS = 24
PASSWORD_RESET_TOKEN_EXPIRE_HOURS = 12
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_MINUTES = 60 * 24 * 14 # 14 days
AUTHORIZATION_CODE_EXPIRE_MINUTES = 7
USER_ACTIVATION_TOKEN_EXPIRE_HOURS: int = 24
PASSWORD_RESET_TOKEN_EXPIRE_HOURS: int = 12
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
REFRESH_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 14 # 14 days
AUTHORIZATION_CODE_EXPIRE_MINUTES: int = 7

###############################################
# Authorization using OAuth or Openid connect #
@@ -129,7 +140,7 @@ class Settings(BaseSettings):
)

# Openid connect issuer name
AUTH_ISSUER = "hyperion"
AUTH_ISSUER: str = "hyperion"

# Add an AUTH_CLIENTS variable to the .env dotenv to configure auth clients
# This variable should have the format: [["client id", "client secret", "redirect_uri", "app.utils.auth.providers class name"]]
@@ -149,16 +160,19 @@ class Settings(BaseSettings):
# The combination of `@property` and `@lru_cache` should be replaced by `@cached_property`
# See https://docs.python.org/3.8/library/functools.html?highlight=#functools.cached_property

@computed_field # type: ignore[misc] # Current issue with mypy, see https://docs.pydantic.dev/2.0/usage/computed_fields/ and https://github.com/python/mypy/issues/1362
@cached_property
def RSA_PRIVATE_KEY(cls):
def RSA_PRIVATE_KEY(cls) -> Any:
return jwk.construct(cls.RSA_PRIVATE_PEM_STRING, algorithm="RS256")

@computed_field # type: ignore[misc]
@cached_property
def RSA_PUBLIC_KEY(cls):
def RSA_PUBLIC_KEY(cls) -> Any:
return cls.RSA_PRIVATE_KEY.public_key()

@computed_field # type: ignore[misc]
@cached_property
def RSA_PUBLIC_JWK(cls):
def RSA_PUBLIC_JWK(cls) -> dict[str, list[dict[str, str]]]:
JWK = cls.RSA_PUBLIC_KEY.to_dict()
JWK.update(
{
@@ -169,11 +183,12 @@ def RSA_PUBLIC_JWK(cls):
return {"keys": [JWK]}

# Tokens validity
USER_ACTIVATION_TOKEN_EXPIRES_HOURS = 24
PASSWORD_RESET_TOKEN_EXPIRES_HOURS = 12
USER_ACTIVATION_TOKEN_EXPIRES_HOURS: int = 24
PASSWORD_RESET_TOKEN_EXPIRES_HOURS: int = 12

# This property parses AUTH_CLIENTS to create a dictionary of auth clients:
# {"client_id": AuthClientClassInstance}
@computed_field # type: ignore[misc]
@cached_property
def KNOWN_AUTH_CLIENTS(cls) -> dict[str, providers.BaseAuthClient]:
clients = {}
@@ -205,58 +220,55 @@ def KNOWN_AUTH_CLIENTS(cls) -> dict[str, providers.BaseAuthClient]:
# Validators may be used to perform more complex validation
# For example, we can check that at least one of two optional fields is set or that the RSA key is provided and valid

# TODO: Pydantic 2.0 will allow to use `@model_validator`

@root_validator
def check_database_settings(cls, settings: dict):
@model_validator(mode="after")
def check_database_settings(self) -> "Settings":
"""
All fields are optional, but the dotenv should configure SQLITE_DB or a Postgres database
"""
SQLITE_DB = settings.get("SQLITE_DB")
POSTGRES_HOST = settings.get("POSTGRES_HOST")
POSTGRES_USER = settings.get("POSTGRES_USER")
POSTGRES_PASSWORD = settings.get("POSTGRES_PASSWORD")
POSTGRES_DB = settings.get("POSTGRES_DB")

if not (
SQLITE_DB
or (POSTGRES_HOST and POSTGRES_USER and POSTGRES_PASSWORD and POSTGRES_DB)
self.SQLITE_DB
or (
self.POSTGRES_HOST
and self.POSTGRES_USER
and self.POSTGRES_PASSWORD
and self.POSTGRES_DB
)
):
raise ValueError(
"Either SQLITE_DB or POSTGRES_HOST, POSTGRES_USER, POSTGRES_PASSWORD and POSTGRES_DB should be configured in the dotenv"
)

return settings

@root_validator
def check_secrets(cls, settings: dict):
ACCESS_TOKEN_SECRET_KEY = settings.get("ACCESS_TOKEN_SECRET_KEY")
RSA_PRIVATE_PEM_STRING = settings.get("RSA_PRIVATE_PEM_STRING")
return self

if not ACCESS_TOKEN_SECRET_KEY:
@model_validator(mode="after")
def check_secrets(self) -> "Settings":
if not self.ACCESS_TOKEN_SECRET_KEY:
raise ValueError(
"ACCESS_TOKEN_SECRET_KEY should be configured in the dotenv"
)

if not RSA_PRIVATE_PEM_STRING:
if not self.RSA_PRIVATE_PEM_STRING:
raise ValueError(
"RSA_PRIVATE_PEM_STRING should be configured in the dotenv"
)

try:
jwk.construct(RSA_PRIVATE_PEM_STRING, algorithm="RS256")
jwk.construct(self.RSA_PRIVATE_PEM_STRING, algorithm="RS256")
except JWKError as e:
raise ValueError("RSA_PRIVATE_PEM_STRING is not a valid RSA key", e)

return settings
return self

class Config:
# By default, the settings are loaded from the `.env` file but this behaviour can be overridden by using
# `_env_file` parameter during instantiation
# Ex: `Settings(_env_file=".env.dev")`
env_file = ".env"
env_file_encoding = "utf-8"
@model_validator(mode="after")
def init_cached_property(self) -> "Settings":
"""
Cached properties are not computed during the instantiation of the class, but when they are accessed for the first time.
By calling them in this validator, we force their initialization during the instantiation of the class.
This allows them to raise an error on Hyperion startup if they are not correctly configured, instead of failing at runtime.
"""
self.KNOWN_AUTH_CLIENTS
self.RSA_PRIVATE_KEY
self.RSA_PUBLIC_KEY
self.RSA_PUBLIC_JWK

# Without this property, @cached_property decorator raise "TypeError: cannot pickle '_thread.RLock' object"
# See https://github.com/samuelcolvin/pydantic/issues/1241
keep_untouched = (cached_property,)
return self
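Taken together, the config.py changes follow the standard Pydantic v2 settings recipe: `SettingsConfigDict` replaces the inner `Config` class, `@computed_field` stacked on `@cached_property` replaces the v1 cached-property workaround, and `@model_validator(mode="after")` replaces `@root_validator`. A condensed, self-contained sketch of that pattern (the field names are illustrative, not Hyperion's full Settings class):

```python
from functools import cached_property

from pydantic import computed_field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict


class ExampleSettings(BaseSettings):
    # Replaces the v1 inner `class Config`; `keep_untouched` is no longer needed
    model_config = SettingsConfigDict(
        env_file=".env", env_file_encoding="utf-8", extra="ignore"
    )

    # In v2, optional fields need an explicit default
    SQLITE_DB: str | None = None
    POSTGRES_HOST: str = ""

    @computed_field  # type: ignore[misc]
    @cached_property
    def DATABASE_LABEL(self) -> str:
        # Computed on first access, then cached on the instance
        return self.SQLITE_DB or self.POSTGRES_HOST or "unconfigured"

    @model_validator(mode="after")
    def check_database_settings(self) -> "ExampleSettings":
        # v2 replacement for `@root_validator`: validate the constructed instance
        if not (self.SQLITE_DB or self.POSTGRES_HOST):
            raise ValueError(
                "Either SQLITE_DB or POSTGRES_HOST should be configured in the dotenv"
            )
        return self
```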
4 changes: 2 additions & 2 deletions app/core/security.py
@@ -98,7 +98,7 @@ def create_access_token(
if expires_delta is None:
# We use the default value
expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode = data.dict(exclude_none=True)
to_encode = data.model_dump(exclude_none=True)
iat = datetime.utcnow()
expire_on = datetime.utcnow() + expires_delta
to_encode.update({"exp": expire_on, "iat": iat})
@@ -124,7 +124,7 @@ def create_access_token_RS256(
expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)

to_encode: dict[str, Any] = additional_data
to_encode.update(data.dict(exclude_none=True))
to_encode.update(data.model_dump(exclude_none=True))

iat = datetime.utcnow()
expire_on = datetime.utcnow() + expires_delta
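The same mechanical rename recurs in every cruds file below: Pydantic v2 deprecates `BaseModel.dict()` in favour of `model_dump()`, which accepts the same keyword arguments used here (`exclude_none`, `exclude`). A minimal illustration with a hypothetical model:

```python
from pydantic import BaseModel


class TokenData(BaseModel):  # hypothetical model, for illustration only
    sub: str
    scopes: str | None = None


data = TokenData(sub="user-id")

# Pydantic v1 (deprecated in v2):
# to_encode = data.dict(exclude_none=True)

# Pydantic v2 equivalent, as used throughout this PR:
to_encode = data.model_dump(exclude_none=True)
assert to_encode == {"sub": "user-id"}
```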
4 changes: 2 additions & 2 deletions app/cruds/cruds_advert.py
@@ -55,7 +55,7 @@ async def update_advertiser(
await db.execute(
update(models_advert.Advertiser)
.where(models_advert.Advertiser.id == advertiser_id)
.values(**advertiser_update.dict(exclude_none=True))
.values(**advertiser_update.model_dump(exclude_none=True))
)
try:
await db.commit()
@@ -127,7 +127,7 @@ async def update_advert(
await db.execute(
update(models_advert.Advert)
.where(models_advert.Advert.id == advert_id)
.values(**advert_update.dict(exclude_none=True))
.values(**advert_update.model_dump(exclude_none=True))
)
try:
await db.commit()
14 changes: 8 additions & 6 deletions app/cruds/cruds_amap.py
@@ -56,7 +56,7 @@ async def edit_product(
await db.execute(
update(models_amap.Product)
.where(models_amap.Product.id == product_id)
.values(**product_update.dict(exclude_none=True))
.values(**product_update.model_dump(exclude_none=True))
)
await db.commit()

@@ -139,7 +139,7 @@ async def create_delivery(
db: AsyncSession,
) -> models_amap.Delivery | None:
"""Create a new delivery in database and return it"""
db.add(models_amap.Delivery(**delivery.dict(exclude={"products_ids"})))
db.add(models_amap.Delivery(**delivery.model_dump(exclude={"products_ids"})))
try:
await db.commit()
except IntegrityError:
@@ -208,7 +208,7 @@ async def edit_delivery(
await db.execute(
update(models_amap.Delivery)
.where(models_amap.Delivery.id == delivery_id)
.values(**delivery.dict(exclude_none=True))
.values(**delivery.model_dump(exclude_none=True))
)
await db.commit()

@@ -254,7 +254,9 @@ async def add_order_to_delivery(
order: schemas_amap.OrderComplete,
):
db.add(
models_amap.Order(**order.dict(exclude={"products_ids", "products_quantity"}))
models_amap.Order(
**order.model_dump(exclude={"products_ids", "products_quantity"})
)
)
try:
await db.commit()
@@ -278,7 +280,7 @@ async def edit_order_without_products(
await db.execute(
update(models_amap.Order)
.where(models_amap.Order.order_id == order_id)
.values(**order.dict(exclude_none=True))
.values(**order.model_dump(exclude_none=True))
)
try:
await db.commit()
@@ -477,6 +479,6 @@ async def edit_information(
await db.execute(
update(models_amap.AmapInformation)
.where(models_amap.AmapInformation.unique_id == "information")
.values(**information_update.dict(exclude_none=True))
.values(**information_update.model_dump(exclude_none=True))
)
await db.commit()
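In the cruds modules the dumped fields are fed straight into a SQLAlchemy `update()`, so only the fields actually set on the schema reach the database. A self-contained sketch of that pattern, using stand-in model and schema names rather than the real `models_amap`/`schemas_amap` classes:

```python
from pydantic import BaseModel
from sqlalchemy import String, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Product(Base):  # minimal stand-in for the real SQLAlchemy model
    __tablename__ = "amap_product"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    name: Mapped[str | None]
    price: Mapped[float | None]


class ProductEdit(BaseModel):  # illustrative schema, not the real Pydantic schema
    name: str | None = None
    price: float | None = None


async def edit_product_sketch(
    product_id: str, product_update: ProductEdit, db: AsyncSession
) -> None:
    # exclude_none=True keeps unset fields out of the UPDATE, so they stay untouched
    await db.execute(
        update(Product)
        .where(Product.id == product_id)
        .values(**product_update.model_dump(exclude_none=True))
    )
    await db.commit()
```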
6 changes: 3 additions & 3 deletions app/cruds/cruds_booking.py
@@ -37,7 +37,7 @@ async def update_manager(
await db.execute(
update(models_booking.Manager)
.where(models_booking.Manager.id == manager_id)
.values(**manager_update.dict(exclude_none=True))
.values(**manager_update.model_dump(exclude_none=True))
)
try:
await db.commit()
@@ -130,7 +130,7 @@ async def get_booking_by_id(


async def create_booking(db: AsyncSession, booking: schemas_booking.BookingComplete):
db_booking = models_booking.Booking(**booking.dict())
db_booking = models_booking.Booking(**booking.model_dump())
db.add(db_booking)
try:
await db.commit()
@@ -145,7 +145,7 @@ async def edit_booking(
await db.execute(
update(models_booking.Booking)
.where(models_booking.Booking.id == booking_id)
.values(**booking.dict(exclude_none=True))
.values(**booking.model_dump(exclude_none=True))
)
try:
await db.commit()
2 changes: 1 addition & 1 deletion app/cruds/cruds_calendar.py
@@ -79,7 +79,7 @@ async def edit_event(
await db.execute(
update(models_calendar.Event)
.where(models_calendar.Event.id == event_id)
.values(**event.dict(exclude_none=True))
.values(**event.model_dump(exclude_none=True))
)
try:
await db.commit()
4 changes: 2 additions & 2 deletions app/cruds/cruds_campaign.py
@@ -271,7 +271,7 @@ async def update_list(
await db.execute(
update(models_campaign.Lists)
.where(models_campaign.Lists.id == list_id)
.values(**campaign_list.dict(exclude={"members"}, exclude_none=True))
.values(**campaign_list.model_dump(exclude={"members"}, exclude_none=True))
)

# We may need to recreate the list of members
@@ -292,7 +292,7 @@ async def update_list(
update(models_campaign.Lists)
.where(models_campaign.Lists.id == list_id)
.values(
**campaign_list.dict(exclude={"members"}, exclude_none=True),
**campaign_list.model_dump(exclude={"members"}, exclude_none=True),
)
)

4 changes: 2 additions & 2 deletions app/cruds/cruds_cinema.py
@@ -38,7 +38,7 @@ async def get_session_by_id(
async def create_session(
session: schemas_cinema.CineSessionComplete, db: AsyncSession
) -> models_cinema.Session:
db_session = models_cinema.Session(**session.dict())
db_session = models_cinema.Session(**session.model_dump())
db.add(db_session)
try:
await db.commit()
@@ -53,7 +53,7 @@ async def update_session(
await db.execute(
update(models_cinema.Session)
.where(models_cinema.Session.id == session_id)
.values(**session_update.dict(exclude_none=True))
.values(**session_update.model_dump(exclude_none=True))
)
await db.commit()

Expand Down
Loading