Skip to content

Commit

Permalink
Handle case when nodes are replaced by integer indices from caller
Browse files Browse the repository at this point in the history
  • Loading branch information
eberrigan committed Sep 25, 2024
1 parent 4c8bdd6 commit 6444378
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 26 deletions.
33 changes: 14 additions & 19 deletions sleap/skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,18 +203,24 @@ def _encode_links(self, links: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
encoded_link[key] = encoded_value

encoded_links.append(encoded_link)
print(f"Encoded links: {encoded_links}")
return encoded_links

def _encode_node(self, node: Node) -> Dict[str, Any]:
def _encode_node(self, node: Union["Node", int]) -> Dict[str, Any]:
"""Encodes a Node object.
Args:
node: The Node object to encode.
node: The Node object to encode or integer index. The latter requires that
the class has the `idx_to_node` attribute set.
Returns:
The encoded `Node` object as a dictionary.
"""
if isinstance(node, int):
# We sometimes have the node object already replaced by its index (when
# `node_to_idx` is provided). In this case, we assume that the node object
# will be handled by the caller, so just return the index.
return node

# Check if object has been encoded before
first_encoding = self._is_first_encoding(node)
py_id = self._get_or_assign_id(node, first_encoding)
Expand All @@ -228,7 +234,7 @@ def _encode_node(self, node: Node) -> Dict[str, Any]:
# Reference by py/id
return {"py/id": py_id}

def _encode_edge_type(self, edge_type: EdgeType) -> Dict[str, Any]:
def _encode_edge_type(self, edge_type: "EdgeType") -> Dict[str, Any]:
"""Encodes an EdgeType object.
Args:
Expand Down Expand Up @@ -269,8 +275,6 @@ def _get_or_assign_id(self, obj: Any, first_encoding: bool) -> int:
py_id = len(self._encoded_objects) + 1 # py/id starts at 1
# Assign the py/id to the object and store it in _encoded_objects
self._encoded_objects[obj_id] = py_id
print(f"Assigned py_id: {py_id} to object: {obj} with id: {obj_id}")
print(f"Returning py_id: {self._encoded_objects[obj_id]} for object: {obj}")
return self._encoded_objects[obj_id]

def _is_first_encoding(self, obj: Any) -> bool:
Expand All @@ -284,8 +288,6 @@ def _is_first_encoding(self, obj: Any) -> bool:
"""
obj_id = id(obj)
first_time = obj_id not in self._encoded_objects
print(f"Length of encoded objects: {len(self._encoded_objects)}")
print(f"Is first time: {first_time} for object: {obj}")
return first_time


Expand Down Expand Up @@ -1141,7 +1143,7 @@ def to_dict(obj: "Skeleton", node_to_idx: Optional[Dict[Node, int]] = None) -> D

# This is a weird hack to serialize the whole _graph into a dict.
# I use the underlying to_json and parse it.
return json.loads(obj.to_json(node_to_idx))
return json.loads(obj.to_json(node_to_idx=node_to_idx))

@classmethod
def from_dict(cls, d: Dict, node_to_idx: Dict[Node, int] = None) -> "Skeleton":
Expand Down Expand Up @@ -1205,17 +1207,14 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str:
"""
jsonpickle.set_encoder_options("simplejson", sort_keys=True, indent=4)
if node_to_idx is not None:
indexed_node_graph = nx.relabel_nodes(
G=self._graph, mapping=node_to_idx
) # map nodes to int
print(f"indexed_node_graph: {indexed_node_graph}")
# Map Nodes to int
indexed_node_graph = nx.relabel_nodes(G=self._graph, mapping=node_to_idx)
else:
# Keep graph nodes as Node objects
indexed_node_graph = self._graph
print(f"indexed_node_graph: {indexed_node_graph}")

# Encode to JSON
graph = json_graph.node_link_data(indexed_node_graph)
print(f"graph: {graph}")

# SLEAP v1.3.0 added `description` and `preview_image` to `Skeleton`, but saving
# these fields breaks data format compatibility. Currently, these are only
Expand All @@ -1227,14 +1226,10 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str:
"description": self.description,
"preview_image": self.preview_image,
}
print(f"data: {data}")
else:
data = graph
print(f"data: {data}")

# json_str = jsonpickle.encode(data)
json_str = SkeletonEncoder.encode(data)
print(f"json_str: {json_str}")

return json_str

Expand Down
7 changes: 0 additions & 7 deletions tests/test_skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,6 @@ def test_decoded_encoded_Skeleton_from_load_json(fly_legs_skeleton_json):
# Check that the decoded skeleton is the same as the original skeleton
assert skeleton.matches(decoded_skeleton)

# Read the JSON string from the fixture
with open(fly_legs_skeleton_json, "r") as f:
original_json_str = f.read()

# Check that the original JSON string is the same as the encoded JSON string
assert json.loads(original_json_str) == json.loads(encoded_json_str)


@pytest.mark.parametrize(
"skeleton_fixture_name", ["flies13_skeleton", "skeleton", "stickman"]
Expand Down

0 comments on commit 6444378

Please sign in to comment.