Skip to content

Commit

Permalink
add to_json functions using sleap-io logic
Browse files Browse the repository at this point in the history
  • Loading branch information
eberrigan committed Sep 10, 2024
1 parent 08b43ed commit 5decb55
Showing 1 changed file with 66 additions and 35 deletions.
101 changes: 66 additions & 35 deletions sleap/skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,9 +987,9 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str:
"""Convert the :class:`Skeleton` to a JSON representation.
Args:
node_to_idx: optional dict which maps :class:`Nodes` to index
node_to_idx: optional dict which maps :class:`Nodes`to index
in some list. This is used when saving
:class:`Labels` where we want to serialize the
:class:`Labels`where we want to serialize the
:class:`Nodes` outside the :class:`Skeleton` object.
If given, then we replace each :class:`Node` with
specified index before converting :class:`Skeleton`.
Expand All @@ -999,22 +999,29 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str:
Returns:
A string containing the JSON representation of the skeleton.
"""
# Logic taken from github.com/talmolab/sleap-io/io/slp.py::serialize_skeletons
# https://github.com/talmolab/sleap-io/blob/main/sleap_io/io/slp.py#L606
# Create a dictionary to store node data
# Create global list of nodes with all nodes from all skeletons.
nodes_dicts = []
node_to_id = {}
for node in self.nodes:
print(f'node: {node}')
if node not in node_to_id:
node_to_id[node] = node_to_idx[node] if node_to_idx is not None else len(node_to_id)
nodes_dicts.append({"name": node.name, "weight": 1.0})
print(f'node: {node}')
# Note: This ID is not the same as the node index in the skeleton in
# legacy SLEAP, but we do not retain this information in the labels, so
# IDs will be different.
#
# The weight is also kept fixed here, but technically this is not
# modified or used in legacy SLEAP either.
#
# TODO: Store legacy metadata in labels to get byte-level compatibility?
node_to_id[node] = len(node_to_id)
print(f'node_to_id: {node_to_id}')
print(f'nodes_dicts: {nodes_dicts}')

# Create a dictionary to store edge data
nodes_dicts.append({"name": node.name, "weight": 1.0})
print(f'nodes_dicts: {nodes_dicts}')

# Build links dicts for normal edges.
edges_dicts = []
for edge_ind, edge in enumerate(self.edges):
print(f'edge_ind: {edge_ind}')
print(f'edge: {edge}')
if edge_ind == 0:
edge_type = {
Expand All @@ -1027,18 +1034,33 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str:
else:
edge_type = {"py/id": 1}
print(f'edge_type: {edge_type}')

# Edges are stored as a list of tuples of nodes
# The source and target are the nodes in the tuple (edge) are the first and
# second nodes respectively
source = edge[0]
print(f'source: {source}')
print(f'node_to_id[source]: {node_to_id[source]}')
target = edge[1]
print(f'target: {target}')
print(f'node_to_id[target]: {node_to_id[target]}')
edges_dicts.append(
{
# Note: Insert idx is not the same as the edge index in the skeleton
# in legacy SLEAP.
"edge_insert_idx": edge_ind,
"key": 0, # Always 0.
"source": node_to_id[edge[0]],
"target": node_to_id[edge[1]],
"source": {"py/id": node_to_id[source]},
"target": {"py/id": node_to_id[target]},
"type": edge_type,
}
)
print(f'edges_dicts: {edges_dicts}')
print(f'edges_dicts: {edges_dicts}')

# Build links dicts for symmetry edges.
for symmetry_ind, symmetry in enumerate(self.symmetries):
print(f'symmetry_ind: {symmetry_ind}')
print(f'symmetry: {symmetry}')
if symmetry_ind == 0:
edge_type = {
"py/reduce": [
Expand All @@ -1050,35 +1072,44 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str:
edge_type = {"py/id": 2}

src, dst = tuple(symmetry.nodes)
print(f'src: {src}')
print(f'dst: {dst}')
edges_dicts.append(
{
"key": 0,
"source": node_to_id[src],
"target": node_to_id[dst],
"source": {"py/id": node_to_id[src]},
"target": {"py/id": node_to_id[dst]},
"type": edge_type,
}
)

# Create skeleton dict.
skeleton_dict = {
"directed": True,
"graph": {
"name": self.name,
"num_edges_inserted": len(self.edges),
},
"links": edges_dicts,
"multigraph": True,
# In the order in Skeleton.nodes and must match up with nodes_dicts.
"nodes": [{"id": node_to_id[node]} for node in self.nodes],
}

# Convert the skeleton dict to a JSON string using the standard json module
json_str = json.dumps(skeleton_dict, indent=4, sort_keys=True)

return json_str

if self.is_template:
skeleton_dict = {
"directed": True,
"graph": {
"name": self.name,
"num_edges_inserted": len(self.edges),
},
"links": edges_dicts,
"multigraph": True,
# In the order in Skeleton.nodes and must match up with nodes_dicts.
"nodes": [{"id": {"py/id": node_to_id[node]}} for node in self.nodes],
"description": self.description,
"preview_image": self.preview_image,
}
else:
skeleton_dict ={
"directed": True,
"graph": {
"name": self.name,
"num_edges_inserted": len(self.edges),
},
"links": edges_dicts,
"multigraph": True,
# In the order in Skeleton.nodes and must match up with nodes_dicts.
"nodes": [{"id": {"py/id": node_to_id[node]}} for node in self.nodes],}


# jsonpickle.set_encoder_options("simplejson", sort_keys=True, indent=4)
# if node_to_idx is not None:
# indexed_node_graph = nx.relabel_nodes(
Expand Down Expand Up @@ -1360,4 +1391,4 @@ def __hash__(self):


cattr.register_unstructure_hook(Skeleton, lambda skeleton: Skeleton.to_dict(skeleton))
cattr.register_structure_hook(Skeleton, lambda dicts, cls: Skeleton.from_dict(dicts))
cattr.register_structure_hook(Skeleton, lambda dicts, cls: Skeleton.from_dict(dicts))

0 comments on commit 5decb55

Please sign in to comment.