Skip to content

Commit

Permalink
Run Alembic migrations on Hyperion startup (#312)
Browse files Browse the repository at this point in the history
### Description

Before this pull request, we only created missing tables on Hyperion
startup. The developer had to run migration files manually when a
table was modified.

With this PR, Hyperion:
- creates an up-to-date database and stamps it as "head" when the database
was never initialized
- otherwise, upgrades the database to "head" by asking Alembic to run the
migration files

---------

Co-authored-by: armanddidierjean <[email protected]>
  • Loading branch information
Rotheem and armanddidierjean authored Feb 4, 2024
1 parent 2cd1e1c commit 520faaf
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 59 deletions.
2 changes: 1 addition & 1 deletion alembic.ini
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
level = WARN
handlers =
qualname = alembic

Expand Down
136 changes: 124 additions & 12 deletions app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@
from contextlib import asynccontextmanager
from typing import Literal

import alembic.command as alembic_command
import alembic.config as alembic_config
import alembic.migration as alembic_migration
import redis
from fastapi import FastAPI, Request, Response, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy.engine import Connection
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncEngine

from app import api
from app.core.config import Settings
Expand All @@ -29,20 +34,123 @@
from app.utils.types.groups_type import GroupType
from app.utils.types.module_list import ModuleList

# NOTE: We can not get loggers at the top of this file like we do in other files
# as the loggers are not yet initialized

async def create_db_tables(engine, drop_db, hyperion_error_logger):
"""Create db tables
Alembic should be used for any migration, this function can only create new tables and ensure that the necessary groups are available

def get_alembic_config(connection: Connection) -> alembic_config.Config:
    """
    Build the Alembic configuration object used when Alembic is invoked programmatically.

    The configuration is loaded from `alembic.ini`, and the given connection is stored
    in the configuration `attributes` under the `"connection"` key.
    """
    config = alembic_config.Config("alembic.ini")
    config.attributes["connection"] = connection
    return config


def get_alembic_current_revision(connection: Connection) -> str | None:
    """
    Return the current Alembic revision of the database, or `None` if there is none.

    WARNING: SQLAlchemy does not support `Inspection on an AsyncConnection`. The call to Alembic must be wrapped in a `run_sync` call.
    See https://alembic.sqlalchemy.org/en/latest/cookbook.html#programmatic-api-use-connection-sharing-with-asyncio for more information.

    Example usage:
    ```python
    async with engine.connect() as conn:
        revision = await conn.run_sync(get_alembic_current_revision)
    ```
    """
    return alembic_migration.MigrationContext.configure(
        connection
    ).get_current_revision()


def stamp_alembic_head(connection: Connection) -> None:
    """
    Stamp the database with the latest revision (`head`).

    WARNING: SQLAlchemy does not support `Inspection on an AsyncConnection`. The call to Alembic must be wrapped in a `run_sync` call.
    See https://alembic.sqlalchemy.org/en/latest/cookbook.html#programmatic-api-use-connection-sharing-with-asyncio for more information.

    Example usage:
    ```python
    async with engine.connect() as conn:
        await conn.run_sync(stamp_alembic_head)
    ```
    """
    # The configuration carries the connection that is passed in
    alembic_command.stamp(get_alembic_config(connection), "head")


def run_alembic_upgrade(connection: Connection) -> None:
    """
    Upgrade the database to the latest version (`head`) by running the Alembic upgrade command.

    WARNING: SQLAlchemy does not support `Inspection on an AsyncConnection`. The call to Alembic must be wrapped in a `run_sync` call.
    See https://alembic.sqlalchemy.org/en/latest/cookbook.html#programmatic-api-use-connection-sharing-with-asyncio for more information.

    Example usage:
    ```python
    async with engine.connect() as conn:
        await conn.run_sync(run_alembic_upgrade)
    ```
    """
    # The configuration carries the connection that is passed in
    alembic_command.upgrade(get_alembic_config(connection), "head")


async def update_db_tables(engine: AsyncEngine, drop_db: bool = False):
"""
async with engine.begin() as conn:
try:
If the database is not initialized, create the tables and stamp the database with the latest revision.
Otherwise, run the alembic upgrade command to upgrade the database to the latest version (`head`).
if drop_db is True, we will drop all tables before creating them again
"""

hyperion_error_logger = logging.getLogger("hyperion.error")

try:
async with engine.begin() as conn:
if drop_db:
# All tables should be dropped, including the alembic_version table
# or Hyperion will think that the database is up to date and will not initialize it
# when running tests a second time.
# To let SQLAlchemy drop the alembic_version table, we created an AlembicVersion model.
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
except Exception as error:
hyperion_error_logger.fatal(
f"Startup: Could not create tables in the database: {error}"
)

# run_sync is used to run a synchronous function in an asynchronous context
# the function `get_alembic_current_revision` will be called with "a synchronous-style Connection as the first argument"
# See https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#sqlalchemy.ext.asyncio.AsyncConnection.run_sync
alembic_current_revision = await conn.run_sync(get_alembic_current_revision)

if alembic_current_revision is None:
# We generate the database using SQLAlchemy
# in order not to have to run all migrations one by one
# See https://alembic.sqlalchemy.org/en/latest/cookbook.html#building-an-up-to-date-database-from-scratch
hyperion_error_logger.info(
"Startup: Database tables not created yet, creating them"
)

# Create all tables
await conn.run_sync(Base.metadata.create_all)
# We stamp the database with the latest revision so that
# alembic knows that the database is up to date
await conn.run_sync(stamp_alembic_head)
else:
hyperion_error_logger.info(
f"Startup: Database tables already created (current revision: {alembic_current_revision}), running migrations"
)
await conn.run_sync(run_alembic_upgrade)

hyperion_error_logger.info("Startup: Database tables updated")
except Exception as error:
hyperion_error_logger.fatal(
f"Startup: Could not create tables in the database: {error}"
)
raise


async def initialize_groups(SessionLocal, hyperion_error_logger):
Expand Down Expand Up @@ -126,9 +234,13 @@ async def startup(app: FastAPI):
):
hyperion_error_logger.info("Redis client not configured")

engine = get_db_engine(settings=settings)
await create_db_tables(engine, drop_db, hyperion_error_logger)
# Update database tables
engine = app.dependency_overrides.get(get_db_engine, get_db_engine)(
settings=settings
)
await update_db_tables(engine, drop_db)

# Initialize database tables
SessionLocal = app.dependency_overrides.get(
get_session_maker, get_session_maker
)()
Expand Down
14 changes: 14 additions & 0 deletions app/models/models_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,17 @@ class ModuleVisibility(Base):

root: Mapped[str] = mapped_column(String, primary_key=True)
allowed_group_id: Mapped[str] = mapped_column(String, primary_key=True)


class AlembicVersion(Base):
    """
    A table managed exclusively by Alembic, used to keep track of the database schema version.

    Declaring this model makes the tables known to the models exactly match the tables in the database.
    Without this model, SQLAlchemy `conn.run_sync(Base.metadata.drop_all)` would ignore this table
    and leave it in the database.

    WARNING: Hyperion should not modify this table.
    """

    __tablename__ = "alembic_version"

    # Only column declared for this table; it is the primary key
    version_num: Mapped[str] = mapped_column(String, primary_key=True)
105 changes: 63 additions & 42 deletions migrations/env.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,41 @@
"""Environment file defining the required functions for the alembic migration to work"""

import asyncio
import os
import re
from logging.config import fileConfig

from sqlalchemy.engine import Connection
from app.dependencies import get_db_engine, get_settings
from alembic import context

from app.database import Base
from app.dependencies import get_settings, get_db_engine

settings = get_settings()
engine = get_db_engine(settings=settings)

if settings.SQLITE_DB:
SQLALCHEMY_DATABASE_URL = (
f"sqlite+aiosqlite:///./{settings.SQLITE_DB}" # Connect to the test's database
)
else:
SQLALCHEMY_DATABASE_URL = f"postgresql+asyncpg://{settings.POSTGRES_USER}:{settings.POSTGRES_PASSWORD}@{settings.POSTGRES_HOST}/{settings.POSTGRES_DB}"

# from sqlalchemy import engine_from_config, pool
# from sqlalchemy.ext.asyncio import AsyncEngine


models_files = [x for x in os.listdir("./app/models") if re.match("models*", x)]
for models_file in models_files:
__import__(f"app.models.{models_file[:-3]}")
# from app.models import models_core

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config


# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# Don't disable existing loggers
# See https://stackoverflow.com/questions/42427487/using-alembic-config-main-redirects-log-output
# We could in the future use Hyperion loggers for Alembic
fileConfig(config.config_file_name, disable_existing_loggers=False)


# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Base.metadata

# models_files = [x for x in os.listdir("./app/models") if re.match("models*", x)]
# for models_file in models_files:
# __import__(f"app.models.{models_file[:-3]}")

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.


def run_migrations_offline():
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
Expand All @@ -62,44 +47,80 @@ def run_migrations_offline():
script output.
"""
url = SQLALCHEMY_DATABASE_URL
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=True,
)

with context.begin_transaction():
context.run_migrations()


def do_run_migrations(connection):
context.configure(
connection=connection, target_metadata=target_metadata, compare_type=True
)
def do_run_migrations(connection: Connection) -> None:
context.configure(connection=connection, target_metadata=target_metadata)

with context.begin_transaction():
context.run_migrations()


async def run_migrations_online():
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
async def run_async_migrations() -> None:
"""In this scenario we need to create an Engine
and associate a connection with the context.
If a connection is already present in the context config,
we will use it instead of creating a new one.
This connection should be set when invoking alembic programmatically.
See https://alembic.sqlalchemy.org/en/latest/cookbook.html#connection-sharing
When calling alembic from the CLI, we need to create a new connection.
"""
connectable = engine

async with connectable.connect() as connection:
connection = config.attributes.get("connection", None)

if connection is None:
# only create Engine if we don't have a Connection
# from the outside

# If we don't have a connection, we can safely assume that Hyperion is not running
# Migrations should have been called from the CLI. We thus want to point to the production database
# As we want to use the production database, we can call the `get_settings` function directly
# instead of using it as a dependency (`app.dependency_overrides.get(get_settings, get_settings)()`)
settings = get_settings()
connectable = get_db_engine(settings)

async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
else:
await connection.run_sync(do_run_migrations)

await connectable.dispose()

def run_migrations_online() -> None:
    """
    Run migrations in 'online' mode.

    If a `connection` attribute is present in the Alembic config, Alembic was invoked programmatically
    and we are already inside an event loop. A second event loop can not be started in the same thread,
    so `asyncio.run(run_async_migrations())` can not be used: the migrations are run synchronously
    on the provided connection with `do_run_migrations`.
    WARNING: SQLAlchemy does not support `Inspection on an AsyncConnection`. The programmatic call to Alembic must be wrapped in a `run_sync` call.
    See https://alembic.sqlalchemy.org/en/latest/cookbook.html#programmatic-api-use-connection-sharing-with-asyncio for more information.

    If no connection was provided, we assume we are not in an existing event loop
    (i.e. alembic was invoked from the CLI): a new event loop is created and the migrations are run in it.
    """
    shared_connection = config.attributes.get("connection", None)

    if shared_connection is not None:
        # Programmatic invocation: reuse the connection shared by the caller
        do_run_migrations(shared_connection)
        return

    # CLI invocation: no event loop is running yet, so we can start one
    asyncio.run(run_async_migrations())


if context.is_offline_mode():
run_migrations_offline()
else:
asyncio.run(run_migrations_online())
run_migrations_online()
10 changes: 6 additions & 4 deletions migrations/script.py.mako
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@ Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
Expand Down
Loading

0 comments on commit 520faaf

Please sign in to comment.