Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unpackable decorator, use asdict() #233

Merged
merged 9 commits into from
Mar 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## Unreleased

- Cleanup: removed unnecessary decorator `@unpackable`
[PR #233](https://github.com/appliedAI-Initiative/pyDVL/pull/233)
- Stopping criteria: fixed problem with `StandardError` and enable proper composition
of index convergence statuses. Fixed a bug with `n_jobs` in
`truncated_montecarlo_shapley`.
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ ray[default] >= 0.8
tox<4.0.0
tox-wheel
types-tqdm
twine
twine==4.0.2
AnesBenmerzoug marked this conversation as resolved.
Show resolved Hide resolved
16 changes: 6 additions & 10 deletions src/pydvl/utils/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@

.. code-block:: python

cached_fun = memcached(**cache_options)(fun, signature=custom_signature)
cached_fun = memcached(**asdict(cache_options))(fun, signature=custom_signature)

If you are running experiments with the same :class:`~pydvl.utils.utility.Utility`
but different datasets, this will lead to evaluations of the utility on new data
Expand All @@ -93,7 +93,7 @@
import socket
import uuid
import warnings
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from functools import wraps
from hashlib import blake2b
from io import BytesIO
Expand Down Expand Up @@ -164,7 +164,7 @@ def memcached(
.. code-block:: python
:caption: Example usage

cached_fun = memcached(**cache_options)(heavy_computation)
cached_fun = memcached(**asdict(cache_options))(heavy_computation)

:param client_config: configuration for `pymemcache's Client()
<https://pymemcache.readthedocs.io/en/stable/apidoc/pymemcache.client.base.html>`_.
Expand Down Expand Up @@ -198,9 +198,8 @@ def connect(config: MemcachedClientConfig):
"""First tries to establish a connection, then tries setting and
getting a value."""
try:
test_config: Dict = dict(**config)
client = RetryingClient(
Client(**test_config),
Client(**asdict(config)),
attempts=3,
retry_delay=0.1,
retry_for=[MemcacheUnexpectedCloseError],
Expand Down Expand Up @@ -294,7 +293,7 @@ def __setstate__(self, d: dict):
"""Restores a client connection after loading from a pickle."""
self.config = d["config"]
self.stats = d["stats"]
self.client = Client(**self.config)
self.client = Client(**asdict(self.config))
self._signature = signature

def get_key_value(self, key: bytes):
Expand Down Expand Up @@ -325,9 +324,6 @@ def get_key_value(self, key: bytes):
Wrapped.__qualname__ = ".".join(reversed(patched))

# TODO: pick from some config file or something
config = MemcachedClientConfig()
if client_config is not None:
config.update(client_config) # type: ignore
return Wrapped(config)
return Wrapped(client_config or MemcachedClientConfig())

return wrapper
4 changes: 0 additions & 4 deletions src/pydvl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from pymemcache.serde import PickleSerde

from .types import unpackable

PICKLE_VERSION = 5 # python >= 3.8

__all__ = ["ParallelConfig", "MemcachedClientConfig", "MemcachedConfig"]
Expand All @@ -27,7 +25,6 @@ class ParallelConfig:
logging_level: int = logging.WARNING


@unpackable
@dataclass
class MemcachedClientConfig:
"""Configuration of the memcached client.
Expand All @@ -53,7 +50,6 @@ class MemcachedClientConfig:
serde: PickleSerde = PickleSerde(pickle_version=PICKLE_VERSION)


@unpackable
@dataclass
class MemcachedConfig:
"""Configuration for :func:`~pydvl.utils.caching.memcached`, providing
Expand Down
47 changes: 0 additions & 47 deletions src/pydvl/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,53 +27,6 @@ def score(self, x: NDArray, y: NDArray) -> float:
pass


def unpackable(cls: Type) -> Type:
"""A class decorator that allows unpacking of all attributes of an object
with the double asterisk operator.

:Example:

>>> @unpackable
... @dataclass
... class Schtuff:
... a: int
... b: str
>>> x = Schtuff(a=1, b='meh')
>>> d = dict(**x)
"""

def keys(self):
return self.__dict__.keys()

def __getitem__(self, item):
return getattr(self, item)

def __len__(self):
return len(self.keys())

def __iter__(self):
for k in self.keys():
yield getattr(self, k)

# HACK: I needed this somewhere else
def update(self, values: dict):
for k, v in values.items():
setattr(self, k, v)

def items(self):
for k in self.keys():
yield k, getattr(self, k)

setattr(cls, "keys", keys)
setattr(cls, "__getitem__", __getitem__)
setattr(cls, "__len__", __len__)
setattr(cls, "__iter__", __iter__)
setattr(cls, "update", update)
setattr(cls, "items", items)

return cls


def maybe_add_argument(fun: Callable, new_arg: str):
"""Wraps a function to accept the given keyword parameter if it doesn't
already.
Expand Down
6 changes: 5 additions & 1 deletion src/pydvl/utils/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""
import logging
import warnings
from dataclasses import asdict
from typing import Dict, FrozenSet, Iterable, Optional, Tuple, Union, cast

import numpy as np
Expand Down Expand Up @@ -141,7 +142,10 @@ def __init__(

def _initialize_utility_wrapper(self):
if self.enable_cache:
self._utility_wrapper = memcached(**self.cache_options)( # type: ignore
# asdict() is recursive, but we want client_config to remain a dataclass
options = asdict(self.cache_options)
options["client_config"] = self.cache_options.client_config
self._utility_wrapper = memcached(**options)( # type: ignore
self._utility, signature=self._signature
)
else:
Expand Down
5 changes: 3 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
import os
from collections import defaultdict
from dataclasses import asdict
from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Type

import numpy as np
Expand Down Expand Up @@ -132,7 +133,7 @@ def memcache_client_config(memcached_service) -> MemcachedClientConfig:
client_config = MemcachedClientConfig(
server=memcached_service, connect_timeout=1.0, timeout=1, no_delay=True
)
Client(**client_config).flush_all()
Client(**asdict(client_config)).flush_all()
return client_config


Expand All @@ -141,7 +142,7 @@ def memcached_client(memcache_client_config) -> Tuple[Client, MemcachedClientCon
from pymemcache.client import Client

try:
c = Client(**memcache_client_config)
c = Client(**asdict(memcache_client_config))
c.flush_all()
return c, memcache_client_config
except Exception as e:
Expand Down