Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add field-level db comments when a field description exists #1293

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,7 @@ def get_config(name: str) -> Any:
# If it was passed by kwargs, ensure it's also set in config
set_config_value(model=new_cls, parameter="table", value=config_table)
for k, v in get_model_fields(new_cls).items():
# convert FieldInfo definitions into sqlalchemy columns
col = get_column_from_field(v)
setattr(new_cls, k, col)
# Set a config flag to tell FastAPI that this should be read with a field
Expand All @@ -575,6 +576,12 @@ def get_config(name: str) -> Any:
# TODO: remove this in the future
set_config_value(model=new_cls, parameter="read_with_orm_mode", value=True)

# enables field-level docstrings on the pydanatic `description` field, which we then copy into
# sa_args, which is persisted to sql table comments
set_config_value(
model=new_cls, parameter="use_attribute_docstrings", value=True
)

config_registry = get_config("registry")
if config_registry is not Undefined:
config_registry = cast(registry, config_registry)
Expand Down Expand Up @@ -635,6 +642,7 @@ def __init__(
rel_args.extend(rel_info.sa_relationship_args)
if rel_info.sa_relationship_kwargs:
rel_kwargs.update(rel_info.sa_relationship_kwargs)
# this where RelationshipInfo objects are converted to lazy column evaluators
rel_value = relationship(relationship_to, *rel_args, **rel_kwargs)
setattr(cls, rel_name, rel_value) # Fix #315
# SQLAlchemy no longer uses dict_
Expand Down Expand Up @@ -702,21 +710,32 @@ def get_sqlalchemy_type(field: Any) -> Any:
raise ValueError(f"{type_} has no matching SQLAlchemy type")


def get_column_from_field(field: Any) -> Column: # type: ignore
def get_column_from_field(field: PydanticFieldInfo | FieldInfo) -> Column: # type: ignore
"""
Takes a field definition, which can either come from the sqlmodel FieldInfo class or the pydantic variant of that class,
and converts it into a sqlalchemy Column object.
"""
if IS_PYDANTIC_V2:
field_info = field
else:
field_info = field.field_info

sa_column = getattr(field_info, "sa_column", Undefined)
if isinstance(sa_column, Column):
# if a db field comment is not already defined, and a description exists on the field, add it to the column definition
if not sa_column.comment and field_info.description:
sa_column.comment = field_info.description

return sa_column
sa_type = get_sqlalchemy_type(field)

primary_key = getattr(field_info, "primary_key", Undefined)
if primary_key is Undefined:
primary_key = False

index = getattr(field_info, "index", Undefined)
if index is Undefined:
index = False

nullable = not primary_key and is_field_noneable(field)
# Override derived nullability if the nullable property is set explicitly
# on the field
Expand Down Expand Up @@ -746,19 +765,37 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
"index": index,
"unique": unique,
}

sa_default = Undefined
if field_info.default_factory:
sa_default = field_info.default_factory
elif field_info.default is not Undefined:
sa_default = field_info.default
if sa_default is not Undefined:
kwargs["default"] = sa_default

sa_column_args = getattr(field_info, "sa_column_args", Undefined)
if sa_column_args is not Undefined:
args.extend(list(cast(Sequence[Any], sa_column_args)))

sa_column_kwargs = getattr(field_info, "sa_column_kwargs", Undefined)

if field_info.description:
if sa_column_kwargs is Undefined:
sa_column_kwargs = {}

assert isinstance(sa_column_kwargs, dict)

# only update comments if not already set
if "comment" not in sa_column_kwargs:
sa_column_kwargs["comment"] = field_info.description

if sa_column_kwargs is not Undefined:
kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))

sa_type = get_sqlalchemy_type(field)

# if sa_column is not specified, then the column is constructed here
return Column(sa_type, *args, **kwargs) # type: ignore


Expand Down
6 changes: 6 additions & 0 deletions sqlmodel/sql/sqltypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@


class AutoString(types.TypeDecorator): # type: ignore
"""
Determines the best sqlalchemy string type based on the database dialect.

For example, when using Postgres this will return sqlalchemy's String()
"""

impl = types.String
cache_ok = True
mysql_default_length = 255
Expand Down
Loading