Skip to content

Commit

Permalink
removed pydantic (#86)
Browse files Browse the repository at this point in the history
* remove custom init in model

* remove unrelated

* lint

* lint

* removed pydantic

* lint

* lint

* lint

* fixed stuff
  • Loading branch information
DerThorsten authored Nov 16, 2023
1 parent 39e9f81 commit 64341e4
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 70 deletions.
77 changes: 40 additions & 37 deletions empack/file_patterns.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,41 @@
import fnmatch
import re
from typing import Optional, Union

import yaml
from pydantic import BaseModel, Field, PrivateAttr, RootModel


class FilePatternsModelBase(BaseModel, extra="forbid"):
pass


# match based on a regex
class RegexPattern(FilePatternsModelBase):
regex: str
_pattern: str = PrivateAttr()
class RegexPattern:
    """Match paths against a regular expression.

    The expression is compiled once, when the object is constructed.
    """

    def __init__(self, regex):
        """Compile ``regex`` for later matching.

        Args:
            regex: Regular expression source, as accepted by ``re.compile``.

        Raises:
            re.error: If ``regex`` is not a valid regular expression.
        """
        self._pattern = re.compile(regex)

    def match(self, path):
        """Return True if the compiled expression matches ``path``.

        Args:
            path: The string to test.

        Returns:
            ``True`` when ``Pattern.match`` finds a match, else ``False``.
        """
        # The pattern is always compiled in __init__, so no lazy
        # re-compilation is needed (or possible: the raw regex string is
        # not kept on the instance).
        return self._pattern.match(path) is not None


class UnixPattern(FilePatternsModelBase):
pattern: str
class UnixPattern:
    """Match paths against a shell-style wildcard pattern via ``fnmatch``."""

    def __init__(self, pattern):
        """Store the wildcard ``pattern`` used by :meth:`match`."""
        self.pattern = pattern

    def match(self, path):
        """Return the result of ``fnmatch.fnmatch`` for ``path``."""
        glob = self.pattern
        return fnmatch.fnmatch(path, glob)


class FilePattern(RootModel):
root: Union[RegexPattern, UnixPattern]

def match(self, path):
return self.root.match(path)

class FileFilter:
def __init__(self, include_patterns=None, exclude_patterns=None):
    """Build the include and exclude pattern lists from plain dicts.

    Args:
        include_patterns: Optional iterable of dicts. Each dict is expanded
            as keyword arguments into ``UnixPattern`` (when it has a
            ``pattern`` key) or ``RegexPattern`` (when it has a ``regex``
            key). ``None`` means no include patterns.
        exclude_patterns: Same shape as ``include_patterns``. ``None``
            means no exclude patterns.

    Raises:
        ValueError: If an entry has neither a ``pattern`` nor a ``regex``
            key.
    """

    def pattern_from_dict(**d):
        # "pattern" is checked first, so it wins the dispatch when both
        # keys are present.
        if "pattern" in d:
            return UnixPattern(**d)
        if "regex" in d:
            return RegexPattern(**d)
        raise ValueError("pattern or regex must be provided")

    if include_patterns is None:
        include_patterns = []
    if exclude_patterns is None:
        exclude_patterns = []
    self.include_patterns = [pattern_from_dict(**p) for p in include_patterns]
    self.exclude_patterns = [pattern_from_dict(**p) for p in exclude_patterns]

def match(self, path):
include = False
Expand All @@ -48,13 +46,25 @@ def match(self, path):
for ep in self.exclude_patterns:
if ep.match(path):
return False

return include


class PkgFileFilter(BaseModel, extra="forbid"):
packages: dict[str, Union[FileFilter, list[FileFilter]]]
default: FileFilter
class PkgFileFilter:
def __init__(self, packages, default=None):
    """Build per-package filters and the optional default filter.

    Args:
        packages: Mapping of package name to either a dict (expanded as
            keyword arguments into one ``FileFilter``) or a list of such
            dicts (one ``FileFilter`` per entry).
        default: Optional dict expanded into the fallback ``FileFilter``;
            when ``None`` the fallback is ``None``.

    Raises:
        ValueError: If a package value is neither a dict nor a list.
    """
    self.packages = {}
    for pkg_name, spec in packages.items():
        if isinstance(spec, dict):
            parsed = FileFilter(**spec)
        elif isinstance(spec, list):
            parsed = [FileFilter(**entry) for entry in spec]
        else:
            err = f"invalid value for package {pkg_name}: {spec}"
            raise ValueError(err)
        self.packages[pkg_name] = parsed

    self.default = None if default is None else FileFilter(**default)

def get_filter_for_pkg(self, pkg_name):
    """Return the entry stored for ``pkg_name``, or ``self.default`` if absent."""
    try:
        return self.packages[pkg_name]
    except KeyError:
        return self.default
Expand All @@ -74,21 +84,14 @@ def merge(self, *others):
self.packages[pkg_name] = filters


# when multiple config files are provided, the default
# must be optional for the additional configs, otherwise
# they would always overwrite the main default config
class AdditionalPkgFileFilter(BaseModel, extra="forbid"):
packages: dict[str, Union[FileFilter, list[FileFilter]]]
default: Optional[FileFilter] = None


def pkg_file_filter_from_yaml(path, *extra_path):
    """Load a ``PkgFileFilter`` from a YAML file and merge in extra configs.

    Args:
        path: Path of the main YAML config file.
        *extra_path: Paths of additional YAML config files. Each one is
            loaded into its own ``PkgFileFilter`` and passed to ``merge``
            on the main filter, in the order given.

    Returns:
        The ``PkgFileFilter`` built from ``path`` after all merges.
    """
    with open(path) as pack_config_file:
        pack_config = yaml.safe_load(pack_config_file)
    # The top-level keys of the YAML document become keyword arguments.
    pkg_file_filter = PkgFileFilter(**pack_config)

    # A distinct loop name keeps the ``path`` argument from being shadowed.
    for extra in extra_path:
        with open(extra) as pack_config_file:
            pack_config = yaml.safe_load(pack_config_file)
        additional_pkg_file_filter = PkgFileFilter(**pack_config)
        pkg_file_filter.merge(additional_pkg_file_filter)
    return pkg_file_filter
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ classifiers = [
dependencies = [
"appdirs",
"networkx",
"pydantic>=2,<3",
"pyyaml",
"requests",
"typer",
Expand Down
46 changes: 14 additions & 32 deletions tests/test_filter.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,24 @@
from empack.file_patterns import FileFilter, FilePattern, pkg_file_filter_from_yaml


def test_regex_pattern():
fp = FilePattern.model_validate(
{
"regex": R"^(?!.*\/tests\/).*((.*.\.py$)|(.*.\.so$))|(.*dateutil-zoneinfo\.tar\.gz$)",
}
)
assert fp.match("/home/fu/bar.py")
assert fp.match("/home/fu/bar.so")
assert not fp.match("/home/tests/fu/bar.py")
assert not fp.match("/home/tests/fu/bar.so")
assert fp.match("/hometests/fu/bar.py")
assert fp.match("/hometests/fu/bar.so")
from empack.file_patterns import FileFilter, UnixPattern, pkg_file_filter_from_yaml


def test_unix_pattern():
    """Check ``UnixPattern`` against matching and non-matching paths."""
    # A bare "*.py" matches a full path, and must not match ".pyc".
    fp = UnixPattern(pattern="*.py")
    assert fp.match("/home/fu/bar.py")
    assert not fp.match("/hometests/fu/bar.pyc")

    # A directory-scoped glob only matches paths containing "/tests/".
    fp = UnixPattern(pattern="**/tests/*")
    assert fp.match("/home/tests/bar")
    assert not fp.match("/home/fu/bar")


def test_file_filter():
fp = FileFilter.model_validate(
{
"include_patterns": [
{"pattern": "*.py"},
{"pattern": "*.so"},
{"pattern": "*matplotlibrc"},
],
"exclude_patterns": [{"pattern": "**/tests/*"}],
}
fp = FileFilter(
include_patterns=[
dict(pattern="*.py"),
dict(pattern="*.so"),
dict(pattern="*matplotlibrc"),
],
exclude_patterns=[dict(pattern="**/tests/*")],
)
assert fp.match(
"/tmp/xeus-python-kernel/envs/xeus-python-kernel/lib/python3.10/" # noqa: S108
Expand All @@ -49,7 +33,7 @@ def test_file_filter():


def test_empty_file_filter():
fp = FileFilter.model_validate({"include_patterns": [], "exclude_patterns": []})
fp = FileFilter(include_patterns=[], exclude_patterns=[])
assert not fp.match("/home/fu/bar.py")
assert not fp.match("/home/fu/bar.so")
assert not fp.match("/home/tests/fu/bar.py")
Expand All @@ -59,11 +43,9 @@ def test_empty_file_filter():


def test_dataset_filter():
fp = FileFilter.model_validate(
{
"include_patterns": [{"pattern": "**/sklearn/datasets/**"}],
"exclude_patterns": [],
}
fp = FileFilter(
include_patterns=[dict(pattern="**/sklearn/datasets/**")],
exclude_patterns=[],
)
assert fp.match("/home/fu/sklearn/datasets/some/folder.txt")
assert fp.match("/home/fu/sklearn/datasets/some/folder.py")
Expand Down

0 comments on commit 64341e4

Please sign in to comment.