Skip to content

Commit

Permalink
Merge pull request #114 from pyiron/dataclass
Browse files Browse the repository at this point in the history
Parse dataclass
  • Loading branch information
samwaseda authored Jan 30, 2025
2 parents 722730f + 2b3fff1 commit c5e1267
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 5 deletions.
47 changes: 42 additions & 5 deletions pyiron_ontology/parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import TypeAlias, Any
import warnings

from semantikon.converter import parse_input_args, parse_output_args
from semantikon.converter import parse_input_args, parse_output_args, _meta_to_dict
from rdflib import Graph, Literal, RDF, RDFS, URIRef, OWL, PROV, Namespace
from pyiron_workflow import NOT_DATA, Workflow, Macro
from pyiron_workflow.node import Node
Expand Down Expand Up @@ -62,19 +63,54 @@ def get_inputs_and_outputs(node: Node) -> dict:
}


def _is_semantikon_class(dtype: type) -> bool:
return hasattr(dtype, "_is_semantikon_class") and dtype._is_semantikon_class


def _translate_has_value(
graph: Graph,
label: URIRef,
tag: str,
value: Any = None,
dtype: type | None = None,
units: URIRef | None = None,
parent: URIRef | None = None,
) -> Graph:
tag_uri = URIRef(tag + ".value")
graph.add((label, PNS.hasValue, tag_uri))
if value is not None:
graph.add((tag_uri, RDF.value, Literal(value)))
if units is not None:
graph.add((tag_uri, PNS.hasUnits, URIRef(units)))
if _is_semantikon_class(dtype):
warnings.warn(
"semantikon_class is experimental - triples may change in the future",
FutureWarning,
)
for k, v in dtype.__dict__.items():
if isinstance(v, type) and _is_semantikon_class(v):
_translate_has_value(
graph=graph,
label=label,
tag=tag + "." + k,
value=getattr(value, k, None),
dtype=v,
parent=tag_uri,
)
for k, v in dtype.__annotations__.items():
metadata = _meta_to_dict(v)
_translate_has_value(
graph=graph,
label=label,
tag=tag + "." + k,
value=getattr(value, k, None),
dtype=metadata["dtype"],
units=metadata.get("units", None),
parent=tag_uri,
)
else:
if parent is not None:
graph.add((tag_uri, RDFS.subClassOf, parent))
if value is not None:
graph.add((tag_uri, RDF.value, Literal(value)))
if units is not None:
graph.add((tag_uri, PNS.hasUnits, URIRef(units)))
return graph


Expand Down Expand Up @@ -156,6 +192,7 @@ def get_triples(
label=channel_label,
tag=tag,
value=d.get("value", None),
dtype=d.get("dtype", None),
units=d.get("units", None),
)
for t in _get_triples_from_restrictions(d):
Expand Down
53 changes: 53 additions & 0 deletions tests/unit/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
)
from pyiron_workflow import Workflow
from semantikon.typing import u
from semantikon.converter import semantikon_class
from dataclasses import dataclass
from rdflib import Namespace


Expand Down Expand Up @@ -221,5 +223,56 @@ def test_parsing_without_running(self):
)


@semantikon_class
@dataclass
class Input:
T: u(float, units="kelvin")
n: int
# This line should be removed with the next version of semantikon
_is_semantikon_class = True

class parameters:
_is_semantikon_class = True
a: int = 2


@semantikon_class
@dataclass
class Output:
E: u(float, units="electron_volt")
L: u(float, units="angstrom")
# This line should be removed with the next version of semantikon
_is_semantikon_class = True


@Workflow.wrap.as_function_node
def run_md(inp: Input) -> Output:
out = Output(E=1.0, L=2.0)
return out


class TestDataclass(unittest.TestCase):
def test_dataclass(self):
wf = Workflow("my_wf")
inp = Input(T=300.0, n=100)
inp.parameters.a = 1
wf.node = run_md(inp)
wf.run()
graph = parse_workflow(wf)
i_txt = "my_wf.node.inputs.inp"
o_txt = "my_wf.node.outputs.out"
triples = (
(URIRef(f"{i_txt}.n.value"), RDFS.subClassOf, URIRef(f"{i_txt}.value")),
(URIRef(f"{i_txt}.n.value"), RDF.value, Literal(100)),
(URIRef(f"{i_txt}.parameters.a.value"), RDF.value, Literal(1)),
(URIRef(o_txt), PNS.hasValue, URIRef(f"{o_txt}.E.value")),
)
s = graph.serialize(format="turtle")
for triple in triples:
self.assertEqual(
len(list(graph.triples(triple))), 1, msg=f"{triple} not found in {s}"
)


if __name__ == "__main__":
unittest.main()

0 comments on commit c5e1267

Please sign in to comment.