Dataclasses - Improve the performance of _dataclass_{get,set}state #103032

Open
sobolevn opened this issue Mar 25, 2023 · 7 comments
Assignees
Labels
3.13 (bugs and security fixes), performance (Performance or resource usage), stdlib (Python modules in the Lib dir), topic-dataclasses, type-feature (A feature request or enhancement)

Comments

@sobolevn
Member

Feature or enhancement

I noticed this comment yesterday:

cpython/Lib/dataclasses.py

Lines 1128 to 1131 in 1fd603f

# _dataclass_getstate and _dataclass_setstate are needed for pickling frozen
# classes with slots. These could be slightly more performant if we generated
# the code instead of iterating over fields. But that can be a project for
# another day, if performance becomes an issue.

So, out of curiosity, I decided to try it out. How fast can I make it?
The results are here:

Testing ZeroFields
Dump time (before): 0.229174 sec
Dump time (after) : 0.147407 sec

Testing OneField
Dump time (before): 0.260065 sec
Dump time (after) : 0.197960 sec

Testing EightFields
Dump time (before): 0.331278 sec
Dump time (after) : 0.192171 sec

Testing NestedFields
Dump time (before): 0.846335 sec
Dump time (after) : 0.470739 sec

Here's the simple benchmark that I was using:

import pickle
from dataclasses import dataclass
from timeit import timeit

def perf(e):
    print('Testing', e.__class__.__name__)

    d = pickle.dumps(e)

    iterations = 10000
    t1 = timeit(lambda: pickle.dumps(e), number=iterations)
    t2 = timeit(lambda: pickle.loads(d), number=iterations)
    print(f'Dump time: {t1:.6f} sec')
    print(f'Load time: {t2:.6f} sec')
    print()


@dataclass(frozen=True, slots=True)
class ZeroFields:
    pass

perf(ZeroFields())


@dataclass(frozen=True, slots=True)
class OneField:
    foo: str

perf(OneField('l'))


@dataclass(frozen=True, slots=True)
class EightFields:
    foo: str
    bar: int
    baz: int
    spam: list[int]
    eggs: dict[str, str]
    x: bool
    y: bool
    z: bool

perf(EightFields(
    "a", 1, 2, [1, 2, 3, 4, 5], {'a': 'a', 'b': 'b', 'c': 'c'},
    True, False, True,
))


@dataclass(frozen=True, slots=True)
class NestedFields:
    z1: ZeroFields
    z2: ZeroFields
    e1: EightFields
    e2: EightFields

e = EightFields(
    "a", 1, 2, [1, 2, 3, 4, 5], {'a': 'a', 'b': 'b', 'c': 'c'},
    True, False, True,
)
perf(NestedFields(ZeroFields(), ZeroFields(), e, e))

Here's the very rough version of what I am planning to do:

        cls.__getstate__ = _create_fn(
            '__getstate__',
            ['self'],
            # A trailing comma after each item keeps a one-field state a tuple.
            [f"return ({''.join(f'self.{f.name}, ' for f in fields(cls))})"],
        )

Things to do:

  1. Refactor the current example code to be in line with the other code generators
  2. Add similar support for __setstate__

Does it look like a good enough speed up to make this change?

CC @ericvsmith and @carljm

@sobolevn sobolevn added type-feature A feature request or enhancement stdlib Python modules in the Lib dir labels Mar 25, 2023
@sobolevn sobolevn self-assigned this Mar 25, 2023
@AlexWaygood AlexWaygood added performance Performance or resource usage 3.12 bugs and security fixes labels Mar 25, 2023
@sobolevn
Member Author

Final timings:

Testing ZeroFields
-- before
Dump time: 0.230313 sec
Load time: 0.174264 sec
-- after
Dump time: 0.149132 sec
Load time: 0.105592 sec

Testing OneField
-- before
Dump time: 0.275664 sec
Load time: 0.219554 sec
-- after
Dump time: 0.164988 sec
Load time: 0.123629 sec

Testing EightFields
-- before
Dump time: 0.340355 sec
Load time: 0.355984 sec
-- after
Dump time: 0.181618 sec
Load time: 0.242010 sec

Testing NestedFields
-- before
Dump time: 0.840898 sec
Load time: 0.779367 sec
-- after
Dump time: 0.447484 sec
Load time: 0.453657 sec

But, since we now do more work during dataclass creation, I wanted to measure this effect as well. Here's my small benchmark:

from dataclasses import dataclass
from timeit import timeit

def create_zero():
    @dataclass(frozen=True, slots=True)
    class ZeroFields:
        pass
    return ZeroFields

ZeroFields = create_zero()

def create_one():
    @dataclass(frozen=True, slots=True)
    class OneField:
        foo: str
    return OneField

def create_eight():
    @dataclass(frozen=True, slots=True)
    class EightFields:
        foo: str
        bar: int
        baz: int
        spam: list[int]
        eggs: dict[str, str]
        x: bool
        y: bool
        z: bool
    return EightFields

EightFields = create_eight()

def create_nested():
    @dataclass(frozen=True, slots=True)
    class NestedFields:
        z1: ZeroFields
        z2: ZeroFields
        e1: EightFields
        e2: EightFields
    return NestedFields

for f in [create_zero, create_one, create_eight, create_nested]:
    print("Testing", f.__name__)
    res = timeit(f, number=100)
    print('Result', res)
    print()

Here are the results:

Testing create_zero
-- before
Result 0.19914910700026667
-- after
Result 0.2348911329972907

Testing create_one
-- before
Result 0.22826643000007607
-- after
Result 0.2645908749982482

Testing create_eight
-- before
Result 0.3498685700033093
-- after
Result 0.4295154959982028

Testing create_nested
-- before
Result 0.41222004499286413
-- after
Result 0.4236956989989267

So, we are basically trading startup time for runtime speed.
I am not sure which one is more important here.

One more thing: note that the pickle benchmark uses number=10000, while the creation benchmark uses only number=100. So, if we care about absolute time, the extra cost per class creation is far larger than the gain per pickle call, and I guess making the dataclass creation slower is not worth it.
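
To make the comparison concrete, the arithmetic can be sketched from the EightFields numbers posted above. This is only a rough estimate: the timings come from single runs, and the variable names are made up for illustration.

```python
# Timings copied from the EightFields results in this thread (seconds).
PICKLE_ITERATIONS = 10_000  # number= used in the pickle benchmark
CREATE_ITERATIONS = 100     # number= used in the creation benchmark

dump_before, dump_after = 0.340355, 0.181618
load_before, load_after = 0.355984, 0.242010
create_before, create_after = 0.3498685700033093, 0.4295154959982028

# Time saved by one dump + load round trip, in microseconds.
saved_per_round_trip_us = (
    (dump_before - dump_after) + (load_before - load_after)
) / PICKLE_ITERATIONS * 1e6

# Extra time spent creating one class, in microseconds.
extra_per_class_us = (create_after - create_before) / CREATE_ITERATIONS * 1e6

# Round trips needed before the creation overhead is repaid.
break_even = extra_per_class_us / saved_per_round_trip_us

print(f"saved per round trip:     {saved_per_round_trip_us:.1f} us")
print(f"extra per class creation: {extra_per_class_us:.1f} us")
print(f"break-even after about {break_even:.0f} round trips")
```

With these figures the creation overhead is repaid after roughly 30 dump/load round trips of an EightFields instance; a class that is never pickled only pays the cost.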

Here's my final patch:

cls.__getstate__, cls.__setstate__ = _dataclass_states(cls)

and:

def _dataclass_states(cls):
    getters = []
    setters = []
    for index, f in enumerate(fields(cls)):
        getters.append(f'self.{f.name}')
        setters.append(f'object.__setattr__(self, "{f.name}", state[{index}])')

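    getstate = _create_fn(
        '__getstate__',
        ('self',),
        # A trailing comma after each item keeps a one-field state a tuple;
        # a bare "(self.foo)" would return the value itself.
        [f'return ({"".join(g + ", " for g in getters)})'],
    )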
    setstate = _create_fn(
        '__setstate__',
        ('self', 'state'),
        setters if setters else ['pass'],
    )
    return getstate, setstate

Please, share your feedback and ideas.

@pochmann
Contributor

pochmann commented Mar 26, 2023

I guess making the dataclass creation slower is not worth it.

Maybe an attrgetter prebuilt at data class creation would have both fast creation and fast application? Benchmark for application, using EightFields:

 259 ±  3 ns  hardcoded
 334 ±  4 ns  attrgetter_prebuilt
 412 ±  6 ns  attrgetter_prebuilt_wrapped
2377 ± 42 ns  current
2873 ± 40 ns  attrgetter_on_the_fly

3.10.6 (main, Jan  7 2023, 10:15:17) [GCC 12.2.0]

Note that attrgetter_on_the_fly includes creation time, so is only included to show that the creation of the attrgetter is fast. And attrgetter_prebuilt_wrapped is included because I'm just not sure whether attrgetter_prebuilt can be used directly, like cls.__getstate__ = attrgetter(...).

Code

Attempt This Online!

from dataclasses import dataclass, fields
from timeit import timeit
from operator import attrgetter
from statistics import mean, stdev
import sys

@dataclass(frozen=True, slots=True)
class EightFields:
    foo: str
    bar: int
    baz: int
    spam: list[int]
    eggs: dict[str, str]
    x: bool
    y: bool
    z: bool

data = EightFields(
    "a", 1, 2, [1, 2, 3, 4, 5], {'a': 'a', 'b': 'b', 'c': 'c'},
    True, False, True,
)

def current(self):
    return [getattr(self, f.name) for f in fields(self)]

def hardcoded(self):
    return self.foo, self.bar, self.baz, self.spam, self.eggs, self.x, self.y, self.z

def attrgetter_on_the_fly(self):
    return attrgetter(*map(attrgetter('name'), fields(self)))(self)

attrgetter_prebuilt = attrgetter(*map(attrgetter('name'), fields(data)))

def attrgetter_prebuilt_wrapped(self):
    return attrgetter_prebuilt(self)

funcs = current, hardcoded, attrgetter_on_the_fly, attrgetter_prebuilt, attrgetter_prebuilt_wrapped

for f in funcs:
    print(f(data))
print()

times = {f: [] for f in funcs}
def stats(f):
    ts = [t * 1e9 for t in sorted(times[f])[:5]]
    return f'{round(mean(ts)):4} ± {round(stdev(ts)):2} ns '
for _ in range(25):
    for f in funcs:
        number = 10000
        t = timeit(lambda: f(data), number=number) / number
        times[f].append(t)
for f in sorted(funcs, key=stats):
    print(stats(f), getattr(f, '__name__', 'attrgetter_prebuilt'))

print()
print(sys.version)

@sobolevn
Member Author

attrgetter seems like an interesting idea. However, it does not work on ZeroFields:

Traceback (most recent call last):
  File "/Users/sobolev/Desktop/cpython/ex.py", line 18, in <module>
    @dataclass(frozen=True, slots=True)
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sobolev/Desktop/cpython/Lib/dataclasses.py", line 1253, in wrap
    return _process_class(cls, init, repr, eq, order, unsafe_hash,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sobolev/Desktop/cpython/Lib/dataclasses.py", line 1121, in _process_class
    cls = _add_slots(cls, frozen, weakref_slot)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sobolev/Desktop/cpython/Lib/dataclasses.py", line 1226, in _add_slots
    cls.__getstate__ = attrgetter(*map(attrgetter('name'), fields(cls)))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: attrgetter expected 1 argument, got 0

So, we need to evaluate map(attrgetter('name'), fields(cls)) early and convert it to a tuple,
and have different code paths for empty types and types with fields.
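
The edge cases can be reproduced in isolation. The sketch below uses a `SimpleNamespace` as a stand-in object; besides the zero-field failure from the traceback, it shows a second case worth keeping in mind for OneField: with exactly one name, `attrgetter` returns the bare attribute rather than a 1-tuple.

```python
from operator import attrgetter
from types import SimpleNamespace

obj = SimpleNamespace(foo="a", bar=1)  # stand-in for a dataclass instance

# Two or more names: the result is a tuple.
two = attrgetter("foo", "bar")(obj)

# Exactly one name: the result is the bare attribute, not a 1-tuple.
one = attrgetter("foo")(obj)

# Zero names: constructing the attrgetter itself fails.
zero_error = None
try:
    attrgetter()
except TypeError as exc:
    zero_error = exc

print(two)
print(repr(one))
print(type(zero_error).__name__, zero_error)
```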

Moreover, a plain cls.__getstate__ = attrgetter(*map(attrgetter('name'), fields(cls))) does not work anyway; a wrapper is required. All in all, this is the final version:

from operator import attrgetter
def __getstate__(self):
    if self.__dataclass_attrgetter__ is None:
        return ()
    return self.__dataclass_attrgetter__(self)

if fields(cls):
    cls.__dataclass_attrgetter__ = attrgetter(*map(attrgetter('name'), fields(cls)))
else:
    cls.__dataclass_attrgetter__ = None

cls.__getstate__ = __getstate__

And here are the timings:

Testing ZeroFields
Dump time: 0.182998 sec

Testing OneField
Dump time: 0.186928 sec

Testing EightFields
Dump time: 0.315626 sec

Testing NestedFields
Dump time: 0.564596 sec

With only one generated method (we still have to do something similar for __setstate__), the creation times are:

Testing create_zero
-- before
Result 0.19914910700026667
-- hardcode
Result 0.2348911329972907
-- getstate attrgetter
Result 0.22353017799468944

Testing create_one
-- before
Result 0.22826643000007607
-- hardcode
Result 0.2645908749982482
-- getstate attrgetter
Result 0.2843174829977215

Testing create_eight
-- before
Result 0.3498685700033093
-- hardcode
Result 0.4295154959982028
-- getstate attrgetter
Result 0.4447074749987223

Testing create_nested
-- before
Result 0.41222004499286413
-- hardcode
Result 0.4236956989989267
-- getstate attrgetter
Result 0.31111940600385424

So, as you can see, the creation time with only one method is slower than with the hardcoded patch, and the dump times are also worse.

Maybe there's something we can improve? But in its current state, it is not worth it.

@pochmann
Contributor

pochmann commented Mar 26, 2023

My first instinct is: "Let's fix attrgetter". Don't you hate it when things refuse to work with empty things...

Second instinct: Don't check every time, use a function instead of None:

from operator import attrgetter
def __getstate__(self):
    return self.__dataclass_attrgetter__(self)

if fields_ := fields(cls):
    cls.__dataclass_attrgetter__ = attrgetter(*map(attrgetter('name'), fields_))
else:
    cls.__dataclass_attrgetter__ = lambda _: ()

cls.__getstate__ = __getstate__

Can you tell why the wrapper is needed?
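
A likely explanation, sketched under the assumption that `__getstate__` is looked up on the instance and called with no arguments: `attrgetter` objects do not implement `__get__`, so unlike plain functions they are not turned into bound methods when stored on a class. `Plain` is a hypothetical stand-in for a dataclass.

```python
from operator import attrgetter


class Plain:  # hypothetical stand-in for a dataclass
    def __init__(self):
        self.foo = "a"
        self.bar = 1


getter = attrgetter("foo", "bar")

# Stored directly: attribute access returns the attrgetter unchanged.
Plain.direct = getter
# Stored behind a lambda: functions are descriptors, so access binds self.
Plain.wrapped = lambda self: getter(self)

obj = Plain()
has_get = hasattr(getter, "__get__")

direct_error = None
try:
    obj.direct()  # the attrgetter receives no argument at all
except TypeError as exc:
    direct_error = exc

wrapped_result = obj.wrapped()

print(has_get)
print(type(direct_error).__name__, direct_error)
print(wrapped_result)
```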

Testing create_nested
-- before
Result 0.41222004499286413
-- hardcode
Result 0.4236956989989267
-- getstate attrgetter
Result 0.31111940600385424

That looks odd. Does that mean that creating the attrgetter took negative time?

@sobolevn
Member Author

That looks odd. Does that mean that creating the attrgetter took negative time?

I think I've just messed something up while copy-pasting 🤔
Please, let me double check the results.

@pochmann
Contributor

pochmann commented Mar 26, 2023

Actually, without the extra attribute (and calling fields(cls) only once):

from operator import attrgetter

if fields_ := fields(cls):
    getter = attrgetter(*map(attrgetter('name'), fields_))
    cls.__getstate__ = lambda self: getter(self)
else:
    cls.__getstate__ = lambda self: ()

@carljm
Member

carljm commented Mar 27, 2023

Every dataclass must be created; many dataclasses are never pickled at all. So I am not super excited about trading creation time for pickling performance (even though the latter could happen many times per class.) Haven't followed all of the numbers or options here (looks like we don't have reliable numbers for the latest attrgetter option(s) yet?) but IMO for this to be worth it we should have very small impact on dataclass creation time.


6 participants