Skip to content

Commit

Permalink
Add pydantic =1 support (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Sep 30, 2024
1 parent 453f51f commit 92a1396
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 53 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ jobs:
make lint
make test
make docs-build
mamba install --name descent --yes "pydantic <2"
make test
- name: CodeCov
uses: codecov/[email protected]
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ format:
$(CONDA_ENV_RUN) ruff check --fix --select I $(PACKAGE_DIR)

test:
$(CONDA_ENV_RUN) pytest -v --cov=$(PACKAGE_NAME) --cov-report=xml --color=yes $(PACKAGE_DIR)/tests/
$(CONDA_ENV_RUN) pytest -v --cov=$(PACKAGE_NAME) --cov-append --cov-report=xml --color=yes $(PACKAGE_DIR)/tests/

docs-build:
$(CONDA_ENV_RUN) mkdocs build
Expand Down
143 changes: 92 additions & 51 deletions descent/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,48 +26,53 @@ def _unflatten_tensors(
return tensors


class _PotentialKey(pydantic.BaseModel):
"""
TODO: Needed until interchange upgrades to pydantic >=2
"""
if pydantic.__version__.startswith("1."):
_PotentialKey = openff.interchange.models.PotentialKey
PotentialKeyList = list[_PotentialKey]
else:

id: str
mult: int | None = None
associated_handler: str | None = None
bond_order: float | None = None

def __hash__(self) -> int:
return hash((self.id, self.mult, self.associated_handler, self.bond_order))
class _PotentialKey(pydantic.BaseModel):
"""
def __eq__(self, other: object) -> bool:
import openff.interchange.models
TODO: Needed until interchange upgrades to pydantic >=2
"""

return (
isinstance(other, (_PotentialKey, openff.interchange.models.PotentialKey))
and self.id == other.id
and self.mult == other.mult
and self.associated_handler == other.associated_handler
and self.bond_order == other.bond_order
)
id: str
mult: int | None = None
associated_handler: str | None = None
bond_order: float | None = None

def __hash__(self) -> int:
return hash((self.id, self.mult, self.associated_handler, self.bond_order))

def __eq__(self, other: object) -> bool:
import openff.interchange.models

return (
isinstance(
other, (_PotentialKey, openff.interchange.models.PotentialKey)
)
and self.id == other.id
and self.mult == other.mult
and self.associated_handler == other.associated_handler
and self.bond_order == other.bond_order
)

def _convert_keys(value: typing.Any) -> typing.Any:
if not isinstance(value, list):
return value

def _convert_keys(value: typing.Any) -> typing.Any:
if not isinstance(value, list):
value = [
_PotentialKey(**v.dict())
if isinstance(v, openff.interchange.models.PotentialKey)
else v
for v in value
]
return value

value = [
_PotentialKey(**v.dict())
if isinstance(v, openff.interchange.models.PotentialKey)
else v
for v in value
PotentialKeyList = typing.Annotated[
list[_PotentialKey], pydantic.BeforeValidator(_convert_keys)
]
return value


PotentialKeyList = typing.Annotated[
list[_PotentialKey], pydantic.BeforeValidator(_convert_keys)
]


class AttributeConfig(pydantic.BaseModel):
Expand All @@ -89,17 +94,35 @@ class AttributeConfig(pydantic.BaseModel):
"none indicates no constraint.",
)

@pydantic.model_validator(mode="after")
def _validate_keys(self):
"""Ensure that the keys in `scales` and `limits` match `cols`."""
if pydantic.__version__.startswith("1."):

if any(key not in self.cols for key in self.scales):
raise ValueError("cannot scale non-trainable parameters")
@pydantic.root_validator
def _validate_keys(cls, values):
cols = values.get("cols")

if any(key not in self.cols for key in self.limits):
raise ValueError("cannot clamp non-trainable parameters")
scales = values.get("scales")
limits = values.get("limits")

return self
if any(key not in cols for key in scales):
raise ValueError("cannot scale non-trainable parameters")
if any(key not in cols for key in limits):
raise ValueError("cannot clamp non-trainable parameters")

return values

else:

@pydantic.model_validator(mode="after")
def _validate_keys(self):
"""Ensure that the keys in `scales` and `limits` match `cols`."""

if any(key not in self.cols for key in self.scales):
raise ValueError("cannot scale non-trainable parameters")

if any(key not in self.cols for key in self.limits):
raise ValueError("cannot clamp non-trainable parameters")

return self


class ParameterConfig(AttributeConfig):
Expand All @@ -118,18 +141,36 @@ class ParameterConfig(AttributeConfig):
"If ``None``, no parameters will be excluded.",
)

@pydantic.model_validator(mode="after")
def _validate_include_exclude(self):
"""Ensure that the keys in `include` and `exclude` are disjoint."""
if pydantic.__version__.startswith("1."):

@pydantic.root_validator
def _validate_include_exclude(cls, values):
include = values.get("include")
exclude = values.get("exclude")

if include is not None and exclude is not None:
include = {*include}
exclude = {*exclude}

if include & exclude:
raise ValueError("cannot include and exclude the same parameter")

return values

else:

@pydantic.model_validator(mode="after")
def _validate_include_exclude(self):
"""Ensure that the keys in `include` and `exclude` are disjoint."""

if self.include is not None and self.exclude is not None:
include = {*self.include}
exclude = {*self.exclude}
if self.include is not None and self.exclude is not None:
include = {*self.include}
exclude = {*self.exclude}

if include & exclude:
raise ValueError("cannot include and exclude the same parameter")
if include & exclude:
raise ValueError("cannot include and exclude the same parameter")

return self
return self


class Trainable:
Expand Down
5 changes: 4 additions & 1 deletion devtools/envs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ dependencies:

# Core packages
- smee >=0.10.0
- pydantic-units # TODO: Remove this line once smee deps are updated

- pytorch
- pydantic
- pyarrow
- datasets

- pydantic

### Levenberg Marquardt
- scipy

Expand Down Expand Up @@ -57,5 +59,6 @@ dependencies:
- mkdocs-literate-nav
- mkdocstrings
- mkdocstrings-python
- griffe <1
- black
- mike

0 comments on commit 92a1396

Please sign in to comment.