Skip to content

Commit

Permalink
fix: refine the implementation of copy_behaviors
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed Jul 19, 2024
1 parent 0eff78c commit 6c79999
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 22 deletions.
15 changes: 7 additions & 8 deletions src/awkward/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import os
import struct
import sys
import typing
from collections.abc import Collection

import numpy as np # noqa: TID251
Expand Down Expand Up @@ -105,16 +104,16 @@ def unique_list(items: Collection[T]) -> list[T]:
return result


def copy_behaviors(existing_class: typing.Any, new_class: typing.Any, behavior: dict):
def copy_behaviors(from_name: str, to_name: str, behavior: dict):
output = {}

oldname = existing_class.__name__
newname = new_class.__name__

for key, value in behavior.items():
if oldname in key:
if not isinstance(key, str) and "*" not in key:
new_tuple = tuple(newname if k == oldname else k for k in key)
if isinstance(key, str):
if key == from_name:
output[to_name] = value
else:
if from_name in key:
new_tuple = tuple(to_name if k == from_name else k for k in key)
output[new_tuple] = value

return output
22 changes: 8 additions & 14 deletions tests/test_2433_copy_behaviors.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ def __eq__(self, other):
ak.behavior[numpy.add, "VectorTwoD", "VectorTwoD"] = lambda v1, v2: v1.add(v2)
assert v + v == v_added

# instead of registering every operator again, just copy the behaviors of
# another class to this class
ak.behavior.update(
ak._util.copy_behaviors("VectorTwoD", "VectorTwoDAgain", ak.behavior)
)

# second sub-class
@ak.mixin_class(ak.behavior)
class VectorTwoDAgain(VectorTwoD):
Expand All @@ -81,17 +87,14 @@ class VectorTwoDAgain(VectorTwoD):
with_name="VectorTwoDAgain",
behavior=ak.behavior,
)
# add method works but the binary operator does not
assert v.add(v) == v_added
with pytest.raises(TypeError):
v + v
assert v + v == v_added

# instead of registering every operator again, just copy the behaviors of
# another class to this class
ak.behavior.update(
ak._util.copy_behaviors(VectorTwoD, VectorTwoDAgain, ak.behavior)
ak._util.copy_behaviors("VectorTwoDAgain", "VectorTwoDAgainAgain", ak.behavior)
)
assert v + v == v_added

# third sub-class
@ak.mixin_class(ak.behavior)
Expand All @@ -112,14 +115,5 @@ class VectorTwoDAgainAgain(VectorTwoDAgain):
with_name="VectorTwoDAgainAgain",
behavior=ak.behavior,
)
# add method works but the binary operator does not
assert v.add(v) == v_added
with pytest.raises(TypeError):
v + v

# instead of registering every operator again, just copy the behaviors of
# another class to this class
ak.behavior.update(
ak._util.copy_behaviors(VectorTwoDAgain, VectorTwoDAgainAgain, ak.behavior)
)
assert v + v == v_added

0 comments on commit 6c79999

Please sign in to comment.