Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use annotations to declare schema #2656

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

#### New Features & Functionality

...
- Add type annotations as a way to declare schema

#### Bug Fixes

Expand Down
8 changes: 6 additions & 2 deletions plugins/ibis/superduper_ibis/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from ibis.expr.datatypes import dtype
from superduper.components.datatype import BaseDataType, File, Vector
from superduper.components.datatype import (
BaseDataType,
FileItem,
Vector,
)
from superduper.components.schema import ID, FieldType, Schema

SPECIAL_ENCODABLES_FIELDS = {
File: "str",
FileItem: "str",
}


Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ target-version = ["py38"]
ignore_missing_imports = true
no_implicit_optional = true
warn_unused_ignores = true
disable_error_code = ["has-type", "attr-defined", "assignment", "misc", "override", "call-arg", "import-untyped"]
disable_error_code = ["has-type", "attr-defined", "assignment", "misc", "override", "call-arg", "import-untyped", "no-redef", "valid-type", "valid-newtype"]

[tool.pytest.ini_options]
addopts = "-W ignore"
Expand Down Expand Up @@ -136,6 +136,7 @@ ignore = [
"D401",
"D102",
"E402",
"F403"
]
exclude = ["templates", "superduper/templates"]

Expand Down
8 changes: 4 additions & 4 deletions superduper/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from superduper.base.leaf import Leaf, import_item
from superduper.base.variables import _replace_variables
from superduper.components.component import Component
from superduper.components.datatype import BaseDataType, Blob, File
from superduper.components.datatype import BaseDataType, Blob, FileItem
from superduper.components.schema import Schema, get_schema
from superduper.misc.reference import parse_reference
from superduper.misc.special_dicts import MongoStyleDict, SuperDuperFlatEncode
Expand Down Expand Up @@ -300,7 +300,7 @@ def decode(
)

def my_getter(x):
return File(path=r[KEY_FILES].get(x.split(':')[-1]), db=db)
return FileItem(path=r[KEY_FILES].get(x.split(':')[-1]), db=db)

if r.get(KEY_FILES):
getters.add_getter('file', my_getter)
Expand Down Expand Up @@ -527,7 +527,7 @@ def _deep_flat_encode(
blobs[r.identifier] = r.bytes
return '&:blob:' + r.identifier

if isinstance(r, File):
if isinstance(r, FileItem):
files[r.identifier] = r.path
return '&:file:' + r.identifier

Expand Down Expand Up @@ -723,7 +723,7 @@ def _get_component(db, path):

def _get_file_callback(db):
def callback(ref):
return File(identifier=ref, db=db)
return FileItem(identifier=ref, db=db)

return callback

Expand Down
11 changes: 11 additions & 0 deletions superduper/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,17 @@ def __new__(cls, name, bases, dct):
new_cls._fields[field.name] = 'default'
elif annotation is t.Callable or _is_optional_callable(annotation):
new_cls._fields[field.name] = 'default'
# a hack...
elif 'superduper.misc.typing' in str(annotation):
annotation = str(annotation)
import re

match1 = re.match('^typing\.Optional\[(.*)\]$', annotation)
match2 = re.match('^t\.Optional\[(.*)\]$', annotation)
match = match1 or match2
if match:
annotation = match.groups()[0]
new_cls._fields[field.name] = annotation.split('.')[-1]
except KeyError:
continue
return new_cls
Expand Down
80 changes: 50 additions & 30 deletions superduper/components/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@

from superduper import CFG
from superduper.base.leaf import Leaf
from superduper.components.component import Component
from superduper.components.component import Component, ComponentMeta

Decode = t.Callable[[bytes], t.Any]
Encode = t.Callable[[t.Any], bytes]

INBUILT_DATATYPES = {}


class DataTypeFactory:
"""Abstract class for creating a DataType # noqa."""
Expand All @@ -43,7 +45,21 @@ def create(data: t.Any) -> "BaseDataType":
raise NotImplementedError


class BaseDataType(Component):
class DataTypeMeta(ComponentMeta):
"""Metaclass for the `Model` class and descendants # noqa."""

def __new__(mcls, name, bases, dct):
"""Create a new class with merged docstrings # noqa."""
cls = super().__new__(mcls, name, bases, dct)
try:
instance = cls(cls.__name__)
INBUILT_DATATYPES[cls.__name__] = instance
except TypeError:
pass
return cls


class BaseDataType(Component, metaclass=DataTypeMeta):
"""Base class for datatype."""

type_id: t.ClassVar[str] = 'datatype'
Expand Down Expand Up @@ -180,7 +196,7 @@ def decode_data(self, item):
return pickle.loads(item)


class PickleSerializer(_Artifact, _PickleMixin, BaseDataType):
class Pickle(_Artifact, _PickleMixin, BaseDataType):
"""Serializer with pickle."""


Expand All @@ -196,26 +212,28 @@ def decode_data(self, item):
return dill.loads(item)


class DillSerializer(_Artifact, _DillMixin, BaseDataType):
class Dill(_Artifact, _DillMixin, BaseDataType):
"""Serializer with dill.

This is also the default serializer.
>>> from superduper.components.datatype import DEFAULT_SERIALIZER
"""


class _DillEncoder(_Encodable, _DillMixin, BaseDataType):
class DillEncoder(_Encodable, _DillMixin, BaseDataType):
"""Encoder with dill."""

...


class FileType(BaseDataType):
class File(BaseDataType):
"""Type for encoding files on disk."""

encodable: t.ClassVar[str] = 'file'

def encode_data(self, item):
assert os.path.exists(item)
return File(path=item)
return FileItem(path=item)

def decode_data(self, item):
return item
Expand Down Expand Up @@ -254,7 +272,7 @@ def unpack(self):
pass


class File(Saveable):
class FileItem(Saveable):
"""Placeholder for a file.

:param path: Path to file.
Expand Down Expand Up @@ -313,25 +331,27 @@ def reference(self):

json_encoder = JSON('json')
pickle_encoder = PickleEncoder('pickle_encoder')
pickle_serializer = PickleSerializer('pickle_serializer')
dill_encoder = _DillEncoder('dill_encoder')
dill_serializer = DillSerializer('dill_serializer')
file = FileType('file')

DEFAULT_ENCODER = PickleEncoder('default_encoder')
DEFAULT_SERIALIZER = DillSerializer('default')


INBUILT_DATATYPES = {
dt.identifier: dt
for dt in [
json_encoder,
pickle_encoder,
pickle_serializer,
dill_encoder,
dill_serializer,
file,
DEFAULT_SERIALIZER,
DEFAULT_ENCODER,
]
}
pickle_serializer = Pickle('pickle_serializer')
dill_encoder = DillEncoder('dill_encoder')
dill_serializer = Dill('dill_serializer')
file = File('file')


INBUILT_DATATYPES.update(
{
dt.identifier: dt
for dt in [
json_encoder,
pickle_encoder,
pickle_serializer,
dill_encoder,
dill_serializer,
file,
]
}
)

DEFAULT_ENCODER = INBUILT_DATATYPES['PickleEncoder']
DEFAULT_SERIALIZER = INBUILT_DATATYPES['Dill']
INBUILT_DATATYPES['default'] = DEFAULT_SERIALIZER
INBUILT_DATATYPES['Blob'] = INBUILT_DATATYPES['Pickle']
1 change: 1 addition & 0 deletions superduper/components/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ def __new__(mcls, name, bases, dct):
return cls


# TODO there are a lot of redundant parameters here
class Model(Component, metaclass=ModelMeta):
"""Base class for components which can predict.

Expand Down
6 changes: 3 additions & 3 deletions superduper/components/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import typing as t

from superduper import Component, logging
from superduper.components.datatype import File, file
from superduper.components.datatype import FileItem, file


class Plugin(Component):
Expand All @@ -24,7 +24,7 @@ class Plugin(Component):
cache_path: str = "~/.superduper/plugins"

def __post_init__(self, db):
if isinstance(self.path, File):
if isinstance(self.path, FileItem):
self._prepare_plugin()
else:
path_name = os.path.basename(self.path.rstrip("/"))
Expand Down Expand Up @@ -92,7 +92,7 @@ def _pip_install(self, requirement_path):

def _prepare_plugin(self):
plugin_name_tag = f"{self.identifier}"
assert isinstance(self.path, File)
assert isinstance(self.path, FileItem)
cache_path = os.path.expanduser(self.cache_path)
uuid_path = os.path.join(cache_path, self.uuid)
# Check if plugin is already in cache
Expand Down
9 changes: 9 additions & 0 deletions superduper/misc/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import typing as t

from superduper.components.datatype import *

File = t.NewType('File', t.AnyStr)
Blob = t.NewType('Blob', t.Callable)
Dill = t.NewType('Dill', t.Callable)
Pickle = t.NewType('Pickle', t.Callable)
JSON = t.NewType('JSON', t.Dict)
4 changes: 2 additions & 2 deletions test/unittest/component/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from superduper import Component, Schema, Table
from superduper.components.datatype import (
Blob,
File,
FileItem,
dill_serializer,
file,
pickle_encoder,
Expand Down Expand Up @@ -104,7 +104,7 @@ def test_schema_with_file(db, tmp_file):
r = db['documents'].select().tolist()[0]

# loaded document contains a pointer to the file
assert isinstance(r['my_file'], File)
assert isinstance(r['my_file'], FileItem)

# however the path has not been populated
assert not r['my_file'].path
Expand Down
4 changes: 2 additions & 2 deletions test/unittest/misc/test_auto_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ def test_infer_datatype():

assert infer_datatype({"a": 1}).identifier == "json"

assert infer_datatype({"a": np.array([1, 2, 3])}).identifier == "default_encoder"
assert infer_datatype({"a": np.array([1, 2, 3])}).identifier == "PickleEncoder"

assert (
infer_datatype(pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})).identifier
== "default_encoder"
== "PickleEncoder"
)

with pytest.raises(UnsupportedDatatype):
Expand Down
35 changes: 35 additions & 0 deletions test/unittest/misc/test_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from superduper.components.component import Component
from superduper.components.datatype import File, Pickle
from superduper.misc import typing as t


class MyComponent(Component):
path: t.File
my_func: t.Blob


def new_func(x):
return x + 1


def test_annotations():
s = MyComponent.build_class_schema()

assert isinstance(s.fields['path'], File)
assert isinstance(s.fields['my_func'], Pickle)

import tempfile

with tempfile.NamedTemporaryFile() as tmp:
print(tmp.name)
tmp.write('test'.encode())

my_component = MyComponent('my_c', path=tmp.name, my_func=new_func)
r = my_component.encode()

assert len(r['_blobs']) == 1
assert len(r['_files']) == 1

import pprint

pprint.pprint(r)
Loading