Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Julia preprocessing #311

Merged
merged 18 commits into from
Aug 16, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/generator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def __init__(self, name, **kwargs):
self.array_size = kwargs.pop('array_size', None)

self.includes = set()
self.jl_imports = set()

if kwargs:
raise ValueError(f"Unused kwargs in MemberVariable: {kwargs.keys()}")
Expand All @@ -149,6 +150,7 @@ def __init__(self, name, **kwargs):

self.full_type = rf'std::array<{self.array_type}, {self.array_size}>'
self.includes.add('#include <array>')
self.jl_imports.add('using StaticArrays')

self.is_builtin = self.full_type in BUILTIN_TYPES
# We still have to check if this type is a valid fixed width type that we
Expand Down
71 changes: 59 additions & 12 deletions python/podio_class_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,13 @@ def _eval_template(self, template, data):
def _write_file(self, name, content):
"""Write the content to file. Dispatch to the correct directory depending on
whether it is a header or a .cc file."""
if name.endswith("h"):
if name.endswith("h") or name.endswith("jl"):
fullname = os.path.join(self.install_dir, self.package_name, name)
else:
fullname = os.path.join(self.install_dir, "src", name)
if not self.dryrun:
self.generated_files.append(fullname)
if self.clang_format:
if self.clang_format and not name.endswith('jl'):
with subprocess.Popen(self.clang_format, stdin=subprocess.PIPE, stdout=subprocess.PIPE) as cfproc:
content = cfproc.communicate(input=content.encode())[0].decode()

Expand All @@ -168,14 +168,20 @@ def get_fn_format(tmpl):
'Obj': 'Obj',
'SIOBlock': 'SIOBlock',
'Collection': 'Collection',
'CollectionData': 'CollectionData'}
'CollectionData': 'CollectionData',
'MutableStruct': 'Struct',
'JuliaCollection': 'Collection'
}

return f'{prefix.get(tmpl, "")}{{name}}{postfix.get(tmpl, "")}.{{end}}'

endings = {
'Data': ('h',),
'Component': ('h',),
'PrintInfo': ('h',),
'MutableStruct': ('jl',),
'Constructor': ('jl',),
'JuliaCollection': ('jl',),
}.get(template_base, ('h', 'cc'))

fn_templates = []
Expand All @@ -193,25 +199,28 @@ def _fill_templates(self, template_base, data):
data['package_name'] = self.package_name
data['use_get_syntax'] = self.get_syntax
data['incfolder'] = self.incfolder

for filename, template in self._get_filenames_templates(template_base, data['class'].bare_type):
self._write_file(filename, self._eval_template(template, data))

def _process_component(self, name, component):
"""Process one component"""
includes = set()
includes, includes_jl = set(), set()
includes.update(*(m.includes for m in component['Members']))

includes_jl.update(*(m.jl_imports for m in component['Members']))
for member in component['Members']:
if member.full_type in self.reader.components or member.array_type in self.reader.components:
includes.add(self._build_include(member.bare_type))
includes_jl.add(self._build_include(member.bare_type, julia=True))

includes.update(component.get("ExtraCode", {}).get("includes", "").split('\n'))

component['includes'] = self._sort_includes(includes)
component['includes_jl'] = {'struct': includes_jl, 'constructor': includes_jl}
component['class'] = DataType(name)

self._fill_templates('Component', component)
self._fill_templates('MutableStruct', component)
self._fill_templates('Constructor', component)

def _process_datatype(self, name, definition):
"""Process one datatype"""
Expand All @@ -222,15 +231,40 @@ def _process_datatype(self, name, definition):
self._fill_templates('Obj', datatype)
self._fill_templates('Collection', datatype)
self._fill_templates('CollectionData', datatype)
self._fill_templates('MutableStruct', datatype)
self._fill_templates('Constructor', datatype)
self._fill_templates('JuliaCollection', datatype)

if 'SIO' in self.io_handlers:
self._fill_templates('SIOBlock', datatype)

def _preprocess_for_julia(self, datatype):
"""Do the preprocessing that is necessary for Julia code generation"""
includes_jl, includes_jl_struct = set(), set()
for relation in datatype['OneToManyRelations'] + datatype['OneToOneRelations'] + datatype['VectorMembers']:
if self._needs_include(relation) and not relation.is_builtin:
includes_jl.add(self._build_include(relation.bare_type, julia=True, is_struct=True))
# if datatype['class'].bare_type != relation.bare_type:
# includes_jl.add(self._build_include(relation.bare_type + 'Collection', julia=True))
for member in datatype['VectorMembers']:
if self._needs_include(member) and not member.is_builtin:
includes_jl_struct.add(self._build_include(member.bare_type, julia=True))
datatype['includes_jl']['constructor'].update((includes_jl))
datatype['includes_jl']['struct'].update((includes_jl_struct))

@staticmethod
def _get_julia_params(datatype):
"""Get the relations as parameteric types for MutableStructs"""
params = set()
for relation in datatype['OneToManyRelations'] + datatype['OneToOneRelations']:
if not relation.is_builtin:
params.add(relation.bare_type)
return list(params)

def _preprocess_for_obj(self, datatype):
"""Do the preprocessing that is necessary for the Obj classes"""
fwd_declarations = {}
includes, includes_cc = set(), set()

for relation in datatype['OneToOneRelations']:
if relation.full_type != datatype['class'].full_type:
if relation.namespace not in fwd_declarations:
Expand Down Expand Up @@ -359,29 +393,38 @@ def _preprocess_datatype(self, name, definition):
data = deepcopy(definition)
data['class'] = DataType(name)
data['includes_data'] = self._get_member_includes(definition["Members"])
data['is_pod'] = self._is_pod_type(definition["Members"])
data['includes_jl'] = {'constructor': self._get_member_includes(definition["Members"], julia=True),
'struct': self._get_member_includes(definition["Members"], julia=True)}
data['is_pod'] = self._is_pod_type(definition['Members'])
data['params_jl'] = self._get_julia_params(data)
self._preprocess_for_class(data)
self._preprocess_for_obj(data)
self._preprocess_for_collection(data)
self._preprocess_for_julia(data)

return data

def _get_member_includes(self, members):
def _get_member_includes(self, members, julia=False):
"""Process all members and gather the necessary includes"""
includes = set()
includes, includes_jl = set(), set()
includes.update(*(m.includes for m in members))
includes_jl.update(*(m.jl_imports for m in members))
for member in members:
if member.is_array and not member.is_builtin_array:
includes.add(self._build_include(member.array_bare_type))
includes_jl.add(self._build_include(member.array_bare_type, julia=True))

for stl_type in ClassDefinitionValidator.allowed_stl_types:
if member.full_type == 'std::' + stl_type:
includes.add(f"#include <{stl_type}>")

if self._needs_include(member):
includes.add(self._build_include(member.bare_type))
includes_jl.add(self._build_include(member.bare_type, julia=True))

return self._sort_includes(includes)
if not julia:
return self._sort_includes(includes)
return includes_jl

def _write_cmake_lists_file(self):
"""Write the names of all generated header and src files into cmake lists"""
Expand Down Expand Up @@ -439,8 +482,12 @@ def _create_selection_xml(self):
'datatypes': [DataType(d) for d in self.reader.datatypes]}
self._write_file('selection.xml', self._eval_template('selection.xml.jinja2', data))

def _build_include(self, classname):
def _build_include(self, classname, julia=False, is_struct=False):
"""Return the include statement."""
if is_struct:
return f'include("{classname}Struct.jl")'
if julia:
return f'include("{classname}.jl")'
if self.include_subfolder:
classname = os.path.join(self.package_name, classname)
return f'#include "{classname}.h"'
Expand Down
26 changes: 26 additions & 0 deletions python/templates/Constructor.jl.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
include("{{ class.bare_type }}Struct.jl")
{% for include in includes_jl['constructor'] %}
{{ include }}
{% endfor %}
function {{ class.bare_type }}()
return {{ class.bare_type }}{% if params_jl %}{ {% for relation in params_jl %} {{ relation }}, {% endfor %} }{% endif %}(
soumilbaldota marked this conversation as resolved.
Show resolved Hide resolved
{% for member in Members %}
{% if member.is_array %}
{{ member.julia_type }}(undef),
tmadlener marked this conversation as resolved.
Show resolved Hide resolved
{% elif member.is_builtin %}
{{ member.julia_type }}(0),
{% else %}
{{ member.julia_type }}(),
{% endif %}
{% endfor %}
{% for relation in OneToManyRelations %}
Vector{ {{ relation.julia_type }} }(),
{% endfor %}
{% for relation in OneToOneRelations %}
nothing,
{% endfor %}
{% for member in VectorMembers %}
Vector{ {{ member.julia_type }} }([]),
{% endfor %}
)
end
2 changes: 2 additions & 0 deletions python/templates/JuliaCollection.jl.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
include("{{ class.bare_type }}.jl")
{{ class.bare_type }}Collection = Vector{ {{ class.bare_type }}{% if params_jl %}{ {% for relation in params_jl %} {{ relation }}, {% endfor %} }{% endif %} }
soumilbaldota marked this conversation as resolved.
Show resolved Hide resolved
18 changes: 18 additions & 0 deletions python/templates/MutableStruct.jl.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{% for include in includes_jl['struct'] %}
{{ include }}
{% endfor %}
mutable struct {{ class.bare_type }}{% if params_jl %}{ {% for relation in params_jl %} {{ relation }}T, {% endfor %} }{% endif %}
soumilbaldota marked this conversation as resolved.
Show resolved Hide resolved

{% for member in Members %}
{{ member.name }}::{{ member.julia_type }}
{% endfor %}
{% for relation in OneToManyRelations %}
{{ relation.name }}::Vector{ {{ relation.julia_type }}T }
{% endfor %}
{% for relation in OneToOneRelations %}
{{ relation.name }}::Union{Nothing, {{ relation.julia_type }}T }
{% endfor %}
{% for member in VectorMembers %}
{{ member.name }}::Vector{ {{ member.julia_type }} }
{% endfor %}
end
78 changes: 78 additions & 0 deletions tests/unittest.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
include("../Tmp/data/ExampleMC.jl")
include("../Tmp/data/ExampleWithVectorMember.jl")
include("../Tmp/data/ExampleForCyclicDependency1.jl")
include("../Tmp/data/ExampleForCyclicDependency2.jl")
include("../Tmp/data/ExampleWithOneRelation.jl")
include("../Tmp/data/ExampleCluster.jl")

using Test
@testset "All" begin
soumilbaldota marked this conversation as resolved.
Show resolved Hide resolved
@testset "Relations" begin

mcp1 = ExampleMC()
mcp1.PDG = 2212

mcp2 = ExampleMC()
mcp2.PDG = 2212

mcp3 = ExampleMC()
mcp3.PDG = 1
push!(mcp3.parents,mcp1)

mcp4 = ExampleMC()
mcp4.PDG = -2
push!(mcp4.parents,mcp2)

mcp5 = ExampleMC()
mcp5.PDG = -24
push!(mcp5.parents,mcp1)
push!(mcp5.parents,mcp2)


mcp1.PDG = 12
mcp2.PDG = 13

# passes if values are changed in parents

@test mcp3.parents[1].PDG == 12
@test mcp4.parents[1].PDG == 13
@test mcp5.parents[1].PDG == 12
@test mcp5.parents[2].PDG == 13
end

@testset "Vector Members" begin

m1 = ExampleWithVectorMember()
m1.count = Float32[1,2,3,4,5]
m1.count[5] = 6

@test m1.count[5] == 6
@test m1.count[1] == 1
@test m1.count[2] == 2
@test m1.count[3] == 3
@test m1.count[4] == 4
end

@testset "Cyclic Dependency" begin

cd1 = ExampleForCyclicDependency1()
cd2 = ExampleForCyclicDependency2()
cd1.ref = cd2
cd2.ref = cd1

@test cd1.ref === cd2
@test cd2.ref === cd1
end

@testset "One To One Relations" begin

c1 = ExampleCluster()
c1.energy = Float64(5)

c2 = ExampleWithOneRelation()
c2.cluster = c1

@test c2.cluster.energy == Float64(5)

end
end;