Skip to content

Commit

Permalink
Julia preprocessing (#311)
Browse files Browse the repository at this point in the history
* adding pre-processing/"include" logic for julia

* Add templates for Julia code
  • Loading branch information
soumilbaldota authored Aug 16, 2022
1 parent e129173 commit b0c600e
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 12 deletions.
2 changes: 2 additions & 0 deletions python/generator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def __init__(self, name, **kwargs):
self.array_size = kwargs.pop('array_size', None)

self.includes = set()
self.jl_imports = set()

if kwargs:
raise ValueError(f"Unused kwargs in MemberVariable: {kwargs.keys()}")
Expand All @@ -149,6 +150,7 @@ def __init__(self, name, **kwargs):

self.full_type = rf'std::array<{self.array_type}, {self.array_size}>'
self.includes.add('#include <array>')
self.jl_imports.add('using StaticArrays')

self.is_builtin = self.full_type in BUILTIN_TYPES
# We still have to check if this type is a valid fixed width type that we
Expand Down
71 changes: 59 additions & 12 deletions python/podio_class_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,13 @@ def _eval_template(self, template, data):
def _write_file(self, name, content):
"""Write the content to file. Dispatch to the correct directory depending on
whether it is a header or a .cc file."""
if name.endswith("h"):
if name.endswith("h") or name.endswith("jl"):
fullname = os.path.join(self.install_dir, self.package_name, name)
else:
fullname = os.path.join(self.install_dir, "src", name)
if not self.dryrun:
self.generated_files.append(fullname)
if self.clang_format:
if self.clang_format and not name.endswith('jl'):
with subprocess.Popen(self.clang_format, stdin=subprocess.PIPE, stdout=subprocess.PIPE) as cfproc:
content = cfproc.communicate(input=content.encode())[0].decode()

Expand All @@ -168,14 +168,20 @@ def get_fn_format(tmpl):
'Obj': 'Obj',
'SIOBlock': 'SIOBlock',
'Collection': 'Collection',
'CollectionData': 'CollectionData'}
'CollectionData': 'CollectionData',
'MutableStruct': 'Struct',
'JuliaCollection': 'Collection'
}

return f'{prefix.get(tmpl, "")}{{name}}{postfix.get(tmpl, "")}.{{end}}'

endings = {
'Data': ('h',),
'Component': ('h',),
'PrintInfo': ('h',),
'MutableStruct': ('jl',),
'Constructor': ('jl',),
'JuliaCollection': ('jl',),
}.get(template_base, ('h', 'cc'))

fn_templates = []
Expand All @@ -193,25 +199,28 @@ def _fill_templates(self, template_base, data):
data['package_name'] = self.package_name
data['use_get_syntax'] = self.get_syntax
data['incfolder'] = self.incfolder

for filename, template in self._get_filenames_templates(template_base, data['class'].bare_type):
self._write_file(filename, self._eval_template(template, data))

def _process_component(self, name, component):
"""Process one component"""
includes = set()
includes, includes_jl = set(), set()
includes.update(*(m.includes for m in component['Members']))

includes_jl.update(*(m.jl_imports for m in component['Members']))
for member in component['Members']:
if member.full_type in self.reader.components or member.array_type in self.reader.components:
includes.add(self._build_include(member.bare_type))
includes_jl.add(self._build_include(member.bare_type, julia=True))

includes.update(component.get("ExtraCode", {}).get("includes", "").split('\n'))

component['includes'] = self._sort_includes(includes)
component['includes_jl'] = {'struct': includes_jl, 'constructor': includes_jl}
component['class'] = DataType(name)

self._fill_templates('Component', component)
self._fill_templates('MutableStruct', component)
self._fill_templates('Constructor', component)

def _process_datatype(self, name, definition):
"""Process one datatype"""
Expand All @@ -222,15 +231,40 @@ def _process_datatype(self, name, definition):
self._fill_templates('Obj', datatype)
self._fill_templates('Collection', datatype)
self._fill_templates('CollectionData', datatype)
self._fill_templates('MutableStruct', datatype)
self._fill_templates('Constructor', datatype)
self._fill_templates('JuliaCollection', datatype)

if 'SIO' in self.io_handlers:
self._fill_templates('SIOBlock', datatype)

def _preprocess_for_julia(self, datatype):
"""Do the preprocessing that is necessary for Julia code generation"""
includes_jl, includes_jl_struct = set(), set()
for relation in datatype['OneToManyRelations'] + datatype['OneToOneRelations'] + datatype['VectorMembers']:
if self._needs_include(relation) and not relation.is_builtin:
includes_jl.add(self._build_include(relation.bare_type, julia=True, is_struct=True))
# if datatype['class'].bare_type != relation.bare_type:
# includes_jl.add(self._build_include(relation.bare_type + 'Collection', julia=True))
for member in datatype['VectorMembers']:
if self._needs_include(member) and not member.is_builtin:
includes_jl_struct.add(self._build_include(member.bare_type, julia=True))
datatype['includes_jl']['constructor'].update((includes_jl))
datatype['includes_jl']['struct'].update((includes_jl_struct))

@staticmethod
def _get_julia_params(datatype):
"""Get the relations as parameteric types for MutableStructs"""
params = set()
for relation in datatype['OneToManyRelations'] + datatype['OneToOneRelations']:
if not relation.is_builtin:
params.add(relation.bare_type)
return list(params)

def _preprocess_for_obj(self, datatype):
"""Do the preprocessing that is necessary for the Obj classes"""
fwd_declarations = {}
includes, includes_cc = set(), set()

for relation in datatype['OneToOneRelations']:
if relation.full_type != datatype['class'].full_type:
if relation.namespace not in fwd_declarations:
Expand Down Expand Up @@ -359,29 +393,38 @@ def _preprocess_datatype(self, name, definition):
data = deepcopy(definition)
data['class'] = DataType(name)
data['includes_data'] = self._get_member_includes(definition["Members"])
data['is_pod'] = self._is_pod_type(definition["Members"])
data['includes_jl'] = {'constructor': self._get_member_includes(definition["Members"], julia=True),
'struct': self._get_member_includes(definition["Members"], julia=True)}
data['is_pod'] = self._is_pod_type(definition['Members'])
data['params_jl'] = self._get_julia_params(data)
self._preprocess_for_class(data)
self._preprocess_for_obj(data)
self._preprocess_for_collection(data)
self._preprocess_for_julia(data)

return data

def _get_member_includes(self, members):
def _get_member_includes(self, members, julia=False):
"""Process all members and gather the necessary includes"""
includes = set()
includes, includes_jl = set(), set()
includes.update(*(m.includes for m in members))
includes_jl.update(*(m.jl_imports for m in members))
for member in members:
if member.is_array and not member.is_builtin_array:
includes.add(self._build_include(member.array_bare_type))
includes_jl.add(self._build_include(member.array_bare_type, julia=True))

for stl_type in ClassDefinitionValidator.allowed_stl_types:
if member.full_type == 'std::' + stl_type:
includes.add(f"#include <{stl_type}>")

if self._needs_include(member):
includes.add(self._build_include(member.bare_type))
includes_jl.add(self._build_include(member.bare_type, julia=True))

return self._sort_includes(includes)
if not julia:
return self._sort_includes(includes)
return includes_jl

def _write_cmake_lists_file(self):
"""Write the names of all generated header and src files into cmake lists"""
Expand Down Expand Up @@ -439,8 +482,12 @@ def _create_selection_xml(self):
'datatypes': [DataType(d) for d in self.reader.datatypes]}
self._write_file('selection.xml', self._eval_template('selection.xml.jinja2', data))

def _build_include(self, classname):
def _build_include(self, classname, julia=False, is_struct=False):
"""Return the include statement."""
if is_struct:
return f'include("{classname}Struct.jl")'
if julia:
return f'include("{classname}.jl")'
if self.include_subfolder:
classname = os.path.join(self.package_name, classname)
return f'#include "{classname}.h"'
Expand Down
26 changes: 26 additions & 0 deletions python/templates/Constructor.jl.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
include("{{ class.bare_type }}Struct.jl")
{% for include in includes_jl['constructor'] %}
{{ include }}
{% endfor %}
function {{ class.bare_type }}()
return {{ class.bare_type }}{% if params_jl %}{ {{ params_jl | join(',') }} }{% endif %}(
{% for member in Members %}
{% if member.is_array %}
{{ member.julia_type }}(undef),
{% elif member.is_builtin %}
{{ member.julia_type }}(0),
{% else %}
{{ member.julia_type }}(),
{% endif %}
{% endfor %}
{% for relation in OneToManyRelations %}
Vector{ {{ relation.julia_type }} }(),
{% endfor %}
{% for relation in OneToOneRelations %}
nothing,
{% endfor %}
{% for member in VectorMembers %}
Vector{ {{ member.julia_type }} }([]),
{% endfor %}
)
end
2 changes: 2 additions & 0 deletions python/templates/JuliaCollection.jl.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
include("{{ class.bare_type }}.jl")
{{ class.bare_type }}Collection = Vector{ {{ class.bare_type }}{% if params_jl %}{ {{ params_jl | join(',') }} }{% endif %} }
18 changes: 18 additions & 0 deletions python/templates/MutableStruct.jl.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{% for include in includes_jl['struct'] %}
{{ include }}
{% endfor %}
mutable struct {{ class.bare_type }}{% if params_jl %}{ {{ params_jl | join('T, ') }} }{% endif %}

{% for member in Members %}
{{ member.name }}::{{ member.julia_type }}
{% endfor %}
{% for relation in OneToManyRelations %}
{{ relation.name }}::Vector{ {{ relation.julia_type }}T }
{% endfor %}
{% for relation in OneToOneRelations %}
{{ relation.name }}::Union{Nothing, {{ relation.julia_type }}T }
{% endfor %}
{% for member in VectorMembers %}
{{ member.name }}::Vector{ {{ member.julia_type }} }
{% endfor %}
end
97 changes: 97 additions & 0 deletions tests/unittest.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
include("../Tmp/data/ExampleMC.jl")
include("../Tmp/data/ExampleWithVectorMember.jl")
include("../Tmp/data/ExampleForCyclicDependency1.jl")
include("../Tmp/data/ExampleForCyclicDependency2.jl")
include("../Tmp/data/ExampleWithOneRelation.jl")
include("../Tmp/data/ExampleCluster.jl")

using Test
@testset "Julia Bindings" begin
@testset "Relations" begin

mcp1 = ExampleMC()
mcp1.PDG = 2212

mcp2 = ExampleMC()
mcp2.PDG = 2212

mcp3 = ExampleMC()
mcp3.PDG = 1
push!(mcp3.parents,mcp1)

mcp4 = ExampleMC()
mcp4.PDG = -2
push!(mcp4.parents,mcp2)

mcp5 = ExampleMC()
mcp5.PDG = -24
push!(mcp5.parents,mcp1)
push!(mcp5.parents,mcp2)


mcp1.PDG = 12
mcp2.PDG = 13

# passes if values are changed in parents

@test mcp3.parents[1].PDG == 12
@test mcp4.parents[1].PDG == 13
@test mcp5.parents[1].PDG == 12
@test mcp5.parents[2].PDG == 13
end

@testset "Vector Members" begin

m1 = ExampleWithVectorMember()
m1.count = Float32[1,2,3,4,5]
m1.count[5] = 6

@test m1.count[5] == 6
@test m1.count[1] == 1
@test m1.count[2] == 2
@test m1.count[3] == 3
@test m1.count[4] == 4
end

@testset "Cyclic Dependency" begin

cd1 = ExampleForCyclicDependency1()
cd2 = ExampleForCyclicDependency2()
cd1.ref = cd2
cd2.ref = cd1

@test cd1.ref === cd2
@test cd2.ref === cd1
end

@testset "One To One Relations" begin

c1 = ExampleCluster()
c1.energy = Float64(5)

c2 = ExampleWithOneRelation()
c2.cluster = c1

@test c2.cluster.energy == Float64(5)

end

@testset "Collections" begin
mcp1 = ExampleMC()
mcp1.PDG = 2212
mcp2 = ExampleMC()
mcp2.PDG = 2212
mcp3 = ExampleMC()
mcp3.PDG = 1
push!(mcp3.parents,mcp1)
a = ExampleMCCollection([mcp1,mcp2,mcp3])
mc1=a[1]
mc2=a[2]
mc3=a[3]
@test mc1.PDG == 2212
@test mc2.PDG == 2212
@test mc3.PDG == 1
@test length(mc3.parents)== 1
@test mc3.parents[1] == mc1
end
end;

0 comments on commit b0c600e

Please sign in to comment.