Skip to content

Commit

Permalink
Specify tagger for register and promote (#174)
Browse files Browse the repository at this point in the history
* add test for gto.api.check_ref

* make error msg about incorrect name more precise

* allow to specify tagger

* rename GIT_COMMITTER to author

* more efficient git ref lookup

* use funcy

* use tagger instead of GIT_COMMITTER
  • Loading branch information
aguschin authored Jun 21, 2022
1 parent 3fdb382 commit 5d22e7a
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 25 deletions.
8 changes: 8 additions & 0 deletions gto/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def register(
bump_minor: bool = False,
bump_patch: bool = False,
stdout: bool = False,
author: Optional[str] = None,
author_email: Optional[str] = None,
):
"""Register new artifact version"""
return GitRegistry.from_repo(repo).register(
Expand All @@ -86,6 +88,8 @@ def register(
bump_minor=bump_minor,
bump_patch=bump_patch,
stdout=stdout,
author=author,
author_email=author_email,
)


Expand All @@ -101,6 +105,8 @@ def promote(
force: bool = False,
skip_registration: bool = False,
stdout: bool = False,
author: Optional[str] = None,
author_email: Optional[str] = None,
):
"""Assign stage to specific artifact version"""
return GitRegistry.from_repo(repo).promote(
Expand All @@ -114,6 +120,8 @@ def promote(
force=force,
skip_registration=skip_registration,
stdout=stdout,
author=author,
author_email=author_email,
)


Expand Down
1 change: 1 addition & 0 deletions gto/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def assert_name_is_valid(name):
if not check_name_is_valid(name):
raise ValidationError(
f"Invalid value '{name}'. Only alphanumeric characters, '-', '/' are allowed."
"Value must be of len >= 2, must with a letter and end with a letter or a number."
)


Expand Down
14 changes: 11 additions & 3 deletions gto/registry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Union
from typing import Optional, Union

import git
from git import InvalidGitRepositoryError, NoSuchPathError, Repo
Expand Down Expand Up @@ -94,7 +94,7 @@ def find_artifact(
name, create_new=create_new # type: ignore
)

def register(
def register( # pylint: disable=too-many-locals
self,
name,
ref,
Expand All @@ -104,6 +104,8 @@ def register(
bump_minor=False,
bump_patch=False,
stdout=False,
author: Optional[str] = None,
author_email: Optional[str] = None,
):
"""Register artifact version"""
assert_name_is_valid(name)
Expand Down Expand Up @@ -152,6 +154,8 @@ def register(
version,
ref,
message=message or f"Registering artifact {name} version {version}",
author=author,
author_email=author_email,
)
registered_version = self.find_artifact(name).find_version(
name=version, raise_if_not_found=True
Expand All @@ -162,7 +166,7 @@ def register(
)
return registered_version

def promote(
def promote( # pylint: disable=too-many-locals
self,
name,
stage,
Expand All @@ -174,6 +178,8 @@ def promote(
force=False,
skip_registration=False,
stdout=False,
author: Optional[str] = None,
author_email: Optional[str] = None,
) -> BasePromotion:
"""Assign stage to specific artifact version"""
assert_name_is_valid(name)
Expand Down Expand Up @@ -219,6 +225,8 @@ def promote(
message=message
or f"Promoting {name} version {promote_version} to stage {stage}",
simple=simple,
author=author,
author_email=author_email,
)
promotion = (
self.get_state()
Expand Down
53 changes: 43 additions & 10 deletions gto/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,16 +183,28 @@ def find(
return tags


def create_tag(repo, name, ref, message):
if all(c.hexsha != ref for c in repo.iter_commits()):
raise RefNotFound(ref=ref)
def create_tag(
repo: git.Repo,
name: str,
ref: str,
message: str,
tagger: str = None,
tagger_email: str = None,
):
try:
repo.commit(ref)
except (ValueError, git.BadName) as e:
raise RefNotFound(ref=ref) from e
if name in repo.refs:
raise TagExists(name=name)
repo.create_tag(
name,
ref=ref,
message=message,
)

env = {}
if tagger:
env["GIT_COMMITTER_NAME"] = tagger
if tagger_email:
env["GIT_COMMITTER_EMAIL"] = tagger_email

repo.git.tag(["-a", name, "-m", message, ref], env=env)


def version_from_tag(tag: git.Tag) -> BaseVersion:
Expand Down Expand Up @@ -278,12 +290,22 @@ def update_state(self, state: BaseRegistryState) -> BaseRegistryState:
class TagVersionManager(TagManager):
actions: FrozenSet[Action] = frozenset((Action.REGISTER,))

def register(self, name, version, ref, message):
def register(
self,
name,
version,
ref,
message,
author: Optional[str] = None,
author_email: Optional[str] = None,
):
create_tag(
self.repo,
name_tag(Action.REGISTER, name, version=version, repo=self.repo),
ref=ref,
message=message,
tagger=author,
tagger_email=author_email,
)

def check_ref(self, ref: str, state: BaseRegistryState):
Expand All @@ -307,12 +329,23 @@ def check_ref(self, ref: str, state: BaseRegistryState):
class TagStageManager(TagManager):
actions: FrozenSet[Action] = frozenset((Action.PROMOTE,))

def promote(self, name, stage, ref, message, simple):
def promote(
self,
name,
stage,
ref,
message,
simple,
author: Optional[str] = None,
author_email: Optional[str] = None,
):
create_tag(
self.repo,
name_tag(Action.PROMOTE, name, stage=stage, repo=self.repo, simple=simple),
ref=ref,
message=message,
tagger=author,
tagger_email=author_email,
)

def check_ref(self, ref: str, state: BaseRegistryState):
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"semver==3.0.0-dev.3",
"entrypoints",
"tabulate==0.8.9",
"funcy",
]


Expand Down
69 changes: 66 additions & 3 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""TODO: add more tests for API"""
import os
from contextlib import contextmanager
from typing import Callable, Tuple

import git
Expand Down Expand Up @@ -82,10 +84,21 @@ def test_register(repo_with_artifact):
)
repo.index.commit("Irrelevant action to create a git commit")
message = "Some message"
gto.api.register(repo.working_dir, name, "HEAD", message=message)
author = "GTO"
author_email = "[email protected]"
gto.api.register(
repo.working_dir,
name,
"HEAD",
message=message,
author=author,
author_email=author_email,
)
latest = gto.api.find_latest_version(repo.working_dir, name)
assert latest.name == vname2
assert latest.message == message
assert latest.author == author
assert latest.author_email == author_email


def test_promote(repo_with_artifact: Tuple[git.Repo, str]):
Expand All @@ -94,17 +107,19 @@ def test_promote(repo_with_artifact: Tuple[git.Repo, str]):
repo.create_tag("v1.0.0")
repo.create_tag("wrong-tag-unrelated")
message = "some msg"
author = "GTO"
author_email = "[email protected]"
gto.api.promote(
repo.working_dir,
name,
stage,
promote_ref="HEAD",
name_version="v0.0.1",
message=message,
author=author,
author_email=author_email,
)
promotion = gto.api.find_versions_in_stage(repo.working_dir, name, stage)
author = repo.commit().author.name
author_email = repo.commit().author.email
_check_obj(
promotion,
dict(
Expand Down Expand Up @@ -137,3 +152,51 @@ def test_promote_skip_registration(repo_with_artifact):
)
promotion = gto.api.find_versions_in_stage(repo.working_dir, name, stage)
assert not SemVer.is_valid(promotion.version)


@contextmanager
def environ(**overrides):
old = {name: os.environ[name] for name in overrides if name in os.environ}
to_del = set(overrides) - set(old)
try:
os.environ.update(overrides)
yield
finally:
os.environ.update(old)
for name in to_del:
os.environ.pop(name, None)


def test_check_ref(repo_with_artifact: Tuple[git.Repo, Callable]):
repo, name = repo_with_artifact # pylint: disable=unused-variable

NAME = "model"
VERSION = "v1.2.3"
GIT_AUTHOR_NAME = "Alexander Guschin"
GIT_AUTHOR_EMAIL = "[email protected]"
GIT_COMMITTER_NAME = "Oliwav"
GIT_COMMITTER_EMAIL = "[email protected]"

with environ(
GIT_AUTHOR_NAME=GIT_AUTHOR_NAME,
GIT_AUTHOR_EMAIL=GIT_AUTHOR_EMAIL,
GIT_COMMITTER_NAME=GIT_COMMITTER_NAME,
GIT_COMMITTER_EMAIL=GIT_COMMITTER_EMAIL,
):
gto.api.register(repo, name=NAME, ref="HEAD", version=VERSION)

info = gto.api.check_ref(repo, f"{NAME}@{VERSION}")["version"][NAME]
_check_obj(
info,
{
"artifact": NAME,
"name": VERSION,
"author": GIT_COMMITTER_NAME,
"author_email": GIT_COMMITTER_EMAIL,
"discovered": False,
"tag": f"{NAME}@{VERSION}",
"promotions": [],
"enrichments": [],
},
skip_keys={"commit_hexsha", "created_at", "message"},
)
20 changes: 19 additions & 1 deletion tests/test_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import pytest

from gto.constants import Action
from gto.tag import ActionSign, name_tag, parse_name
from gto.exceptions import RefNotFound, TagExists
from gto.tag import ActionSign, create_tag, name_tag, parse_name


def test_name_tag(empty_git_repo):
Expand Down Expand Up @@ -58,3 +59,20 @@ def test_parse_name():
)
def test_parse_wrong_names(tag_name):
assert not parse_name(tag_name, raise_on_fail=False)


def test_create_tag_bad_ref(repo_with_commit):
repo, _ = repo_with_commit
with pytest.raises(RefNotFound):
create_tag(repo, "name", ref="wrongref", message="msg")
with pytest.raises(RefNotFound):
create_tag(
repo, "name", ref="679dd96f8f22bef6505b9646803bf3c2afe94692", message="msg"
)


def test_create_tag_repeated_tagname(repo_with_commit):
repo, _ = repo_with_commit
create_tag(repo, "name", ref="HEAD", message="msg")
with pytest.raises(TagExists):
create_tag(repo, "name", ref="HEAD", message="msg")
20 changes: 12 additions & 8 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
from typing import Any, Dict, Sequence, Set, Union

from funcy import omit
from pydantic import BaseModel

# def _equals(a, b):
# # separate function is helpful for debug
# # cause you see dicts without skip_keys
# assert a == b

def _assert_equals(a, b):
# separate function is helpful for debug
# cause you see dicts without skip_keys
assert a == b, f"\n{a} \n!=\n {b}"


def _check_obj(
obj: BaseModel, values: Dict[str, Any], skip_keys: Union[Set[str], Sequence[str]]
):
obj_values = obj.dict(exclude=set(skip_keys))
assert obj_values == values
_assert_equals(obj_values, values)
# assert obj_values == values


def _check_dict(
obj: Dict[str, Any],
values: Dict[str, Any],
skip_keys: Union[Set[str], Sequence[str]],
):
obj_values = {k: v for k, v in obj.items() if k not in skip_keys}
values = {k: v for k, v in values.items() if k not in skip_keys}
assert obj_values == values
obj_values = omit(obj, skip_keys)
values = omit(values, skip_keys)
_assert_equals(obj_values, values)
# assert obj_values == values

0 comments on commit 5d22e7a

Please sign in to comment.