Skip to content

Commit

Permalink
Julia preprocessing (#311)
Browse files Browse the repository at this point in the history
* adding pre-processing/"include" logic for julia

* Add templates for Julia code
  • Loading branch information
soumilbaldota authored and tmadlener committed Aug 16, 2022
1 parent cd3cd96 commit 14870c8
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 10 deletions.
2 changes: 2 additions & 0 deletions python/generator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def __init__(self, name, **kwargs):
self.array_size = kwargs.pop('array_size', None)

self.includes = set()
self.jl_imports = set()

if kwargs:
raise ValueError(f"Unused kwargs in MemberVariable: {list(kwargs.keys())}")
Expand All @@ -149,6 +150,7 @@ def __init__(self, name, **kwargs):

self.full_type = rf'std::array<{self.array_type}, {self.array_size}>'
self.includes.add('#include <array>')
self.jl_imports.add('using StaticArrays')

self.is_builtin = self.full_type in BUILTIN_TYPES

Expand Down
82 changes: 72 additions & 10 deletions python/podio_class_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,13 @@ def _eval_template(self, template, data):
def _write_file(self, name, content):
"""Write the content to file. Dispatch to the correct directory depending on
whether it is a header or a .cc file."""
if name.endswith("h"):
if name.endswith("h") or name.endswith("jl"):
fullname = os.path.join(self.install_dir, self.package_name, name)
else:
fullname = os.path.join(self.install_dir, "src", name)
if not self.dryrun:
self.generated_files.append(fullname)
if self.clang_format:
if self.clang_format and not name.endswith('jl'):
with subprocess.Popen(self.clang_format, stdin=subprocess.PIPE, stdout=subprocess.PIPE) as cfproc:
content = cfproc.communicate(input=content.encode())[0].decode()

Expand All @@ -171,14 +171,20 @@ def get_fn_format(tmpl):
'Obj': 'Obj',
'SIOBlock': 'SIOBlock',
'Collection': 'Collection',
'CollectionData': 'CollectionData'}
'CollectionData': 'CollectionData',
'MutableStruct': 'Struct',
'JuliaCollection': 'Collection'
}

return f'{prefix.get(tmpl, "")}{{name}}{postfix.get(tmpl, "")}.{{end}}'

endings = {
'Data': ('h',),
'Component': ('h',),
'PrintInfo': ('h',),
'MutableStruct': ('jl',),
'Constructor': ('jl',),
'JuliaCollection': ('jl',),
}.get(template_base, ('h', 'cc'))

fn_templates = []
Expand All @@ -196,25 +202,28 @@ def _fill_templates(self, template_base, data):
data['package_name'] = self.package_name
data['use_get_syntax'] = self.get_syntax
data['incfolder'] = self.incfolder

for filename, template in self._get_filenames_templates(template_base, data['class'].bare_type):
self._write_file(filename, self._eval_template(template, data))

def _process_component(self, name, component):
"""Process one component"""
includes = set()
includes, includes_jl = set(), set()
includes.update(*(m.includes for m in component['Members']))

includes_jl.update(*(m.jl_imports for m in component['Members']))
for member in component['Members']:
if not (member.is_builtin or member.is_builtin_array):
includes.add(self._build_include(member))
includes_jl.add(self._build_julia_include(member))

includes.update(component.get("ExtraCode", {}).get("includes", "").split('\n'))

component['includes'] = self._sort_includes(includes)
component['includes_jl'] = {'struct': includes_jl, 'constructor': includes_jl}
component['class'] = DataType(name)

self._fill_templates('Component', component)
self._fill_templates('MutableStruct', component)
self._fill_templates('Constructor', component)

def _process_datatype(self, name, definition):
"""Process one datatype"""
Expand All @@ -225,15 +234,40 @@ def _process_datatype(self, name, definition):
self._fill_templates('Obj', datatype)
self._fill_templates('Collection', datatype)
self._fill_templates('CollectionData', datatype)
self._fill_templates('MutableStruct', datatype)
self._fill_templates('Constructor', datatype)
self._fill_templates('JuliaCollection', datatype)

if 'SIO' in self.io_handlers:
self._fill_templates('SIOBlock', datatype)

def _preprocess_for_julia(self, datatype):
"""Do the preprocessing that is necessary for Julia code generation"""
includes_jl, includes_jl_struct = set(), set()
for relation in datatype['OneToManyRelations'] + datatype['OneToOneRelations'] + datatype['VectorMembers']:
includes_jl.add(self._build_julia_include(relation, is_struct=True))
# if datatype['class'].bare_type != relation.bare_type:
# includes_jl.add(self._build_include(relation.bare_type + 'Collection', julia=True))

for member in datatype['VectorMembers']:
if self._needs_include(member) and not member.is_builtin:
includes_jl_struct.add(self._build_julia_include(member, julia=True))
datatype['includes_jl']['constructor'].update((includes_jl))
datatype['includes_jl']['struct'].update((includes_jl_struct))

@staticmethod
def _get_julia_params(datatype):
"""Get the relations as parameteric types for MutableStructs"""
params = set()
for relation in datatype['OneToManyRelations'] + datatype['OneToOneRelations']:
if not relation.is_builtin:
params.add(relation.bare_type)
return list(params)

def _preprocess_for_obj(self, datatype):
"""Do the preprocessing that is necessary for the Obj classes"""
fwd_declarations = {}
includes, includes_cc = set(), set()

for relation in datatype['OneToOneRelations']:
if relation.full_type != datatype['class'].full_type:
if relation.namespace not in fwd_declarations:
Expand Down Expand Up @@ -363,26 +397,35 @@ def _preprocess_datatype(self, name, definition):
data = deepcopy(definition)
data['class'] = DataType(name)
data['includes_data'] = self._get_member_includes(definition["Members"])
data['includes_jl'] = {'constructor': self._get_member_includes(definition["Members"], julia=True),
'struct': self._get_member_includes(definition["Members"], julia=True)}
data['params_jl'] = self._get_julia_params(data)
self._preprocess_for_class(data)
self._preprocess_for_obj(data)
self._preprocess_for_collection(data)
self._preprocess_for_julia(data)

return data

def _get_member_includes(self, members):
def _get_member_includes(self, members, julia=False):
"""Process all members and gather the necessary includes"""
includes = set()
includes, includes_jl = set(), set()
includes.update(*(m.includes for m in members))
includes_jl.update(*(m.jl_imports for m in members))
for member in members:
if member.is_array and not member.is_builtin_array:
include_from = IncludeFrom.INTERNAL
if self.upstream_edm and member.array_type in self.upstream_edm.components:
include_from = IncludeFrom.EXTERNAL
includes.add(self._build_include_for_class(member.array_bare_type, include_from))
includes_jl.add(self._build_julia_include_for_class(member.array_bare_type, include_from))

includes.add(self._build_include(member))
includes_jl.add(self._build_julia_include(member))

return self._sort_includes(includes)
if not julia:
return self._sort_includes(includes)
return includes_jl

def _write_cmake_lists_file(self):
"""Write the names of all generated header and src files into cmake lists"""
Expand Down Expand Up @@ -452,6 +495,25 @@ def _build_include_for_class(self, classname, include_from: IncludeFrom) -> str:
# the generated code)
return ''

def _build_julia_include(self, member, is_struct=False) -> str:
"""Return the include statement for julia"""
return self._build_julia_include_for_class(member.bare_type, self._needs_include(member.full_type), is_struct)

def _build_julia_include_for_class(self, classname, include_from: IncludeFrom, is_struct=False) -> str:
"""Return the include statement for julia for this specific class"""
if include_from == IncludeFrom.INTERNAL:
# If we have an internal include all includes should be relative
inc_folder = ''
if include_from == IncludeFrom.EXTERNAL:
inc_folder = f'{self.upstream_edm.options["includeSubfolder"]}'
if include_from == IncludeFrom.NOWHERE:
# We don't need an include in this case
return ''

if is_struct:
return f'include("{inc_folder}{classname}Struct.jl")'
return f'include("{inc_folder}{classname}.jl")'

def _sort_includes(self, includes):
"""Sort the includes in order to try to have the std includes at the bottom"""
package_includes = sorted(i for i in includes if self.package_name in i)
Expand Down
26 changes: 26 additions & 0 deletions python/templates/Constructor.jl.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
include("{{ class.bare_type }}Struct.jl")
{% for include in includes_jl['constructor'] %}
{{ include }}
{% endfor %}
function {{ class.bare_type }}()
return {{ class.bare_type }}{% if params_jl %}{ {{ params_jl | join(',') }} }{% endif %}(
{% for member in Members %}
{% if member.is_array %}
{{ member.julia_type }}(undef),
{% elif member.is_builtin %}
{{ member.julia_type }}(0),
{% else %}
{{ member.julia_type }}(),
{% endif %}
{% endfor %}
{% for relation in OneToManyRelations %}
Vector{ {{ relation.julia_type }} }(),
{% endfor %}
{% for relation in OneToOneRelations %}
nothing,
{% endfor %}
{% for member in VectorMembers %}
Vector{ {{ member.julia_type }} }([]),
{% endfor %}
)
end
2 changes: 2 additions & 0 deletions python/templates/JuliaCollection.jl.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
include("{{ class.bare_type }}.jl")
{{ class.bare_type }}Collection = Vector{ {{ class.bare_type }}{% if params_jl %}{ {{ params_jl | join(',') }} }{% endif %} }
18 changes: 18 additions & 0 deletions python/templates/MutableStruct.jl.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{% for include in includes_jl['struct'] %}
{{ include }}
{% endfor %}
mutable struct {{ class.bare_type }}{% if params_jl %}{ {{ params_jl | join('T, ') }} }{% endif %}

{% for member in Members %}
{{ member.name }}::{{ member.julia_type }}
{% endfor %}
{% for relation in OneToManyRelations %}
{{ relation.name }}::Vector{ {{ relation.julia_type }}T }
{% endfor %}
{% for relation in OneToOneRelations %}
{{ relation.name }}::Union{Nothing, {{ relation.julia_type }}T }
{% endfor %}
{% for member in VectorMembers %}
{{ member.name }}::Vector{ {{ member.julia_type }} }
{% endfor %}
end
97 changes: 97 additions & 0 deletions tests/unittest.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
include("../Tmp/data/ExampleMC.jl")
include("../Tmp/data/ExampleWithVectorMember.jl")
include("../Tmp/data/ExampleForCyclicDependency1.jl")
include("../Tmp/data/ExampleForCyclicDependency2.jl")
include("../Tmp/data/ExampleWithOneRelation.jl")
include("../Tmp/data/ExampleCluster.jl")

using Test
@testset "Julia Bindings" begin
@testset "Relations" begin

mcp1 = ExampleMC()
mcp1.PDG = 2212

mcp2 = ExampleMC()
mcp2.PDG = 2212

mcp3 = ExampleMC()
mcp3.PDG = 1
push!(mcp3.parents,mcp1)

mcp4 = ExampleMC()
mcp4.PDG = -2
push!(mcp4.parents,mcp2)

mcp5 = ExampleMC()
mcp5.PDG = -24
push!(mcp5.parents,mcp1)
push!(mcp5.parents,mcp2)


mcp1.PDG = 12
mcp2.PDG = 13

# passes if values are changed in parents

@test mcp3.parents[1].PDG == 12
@test mcp4.parents[1].PDG == 13
@test mcp5.parents[1].PDG == 12
@test mcp5.parents[2].PDG == 13
end

@testset "Vector Members" begin

m1 = ExampleWithVectorMember()
m1.count = Float32[1,2,3,4,5]
m1.count[5] = 6

@test m1.count[5] == 6
@test m1.count[1] == 1
@test m1.count[2] == 2
@test m1.count[3] == 3
@test m1.count[4] == 4
end

@testset "Cyclic Dependency" begin

cd1 = ExampleForCyclicDependency1()
cd2 = ExampleForCyclicDependency2()
cd1.ref = cd2
cd2.ref = cd1

@test cd1.ref === cd2
@test cd2.ref === cd1
end

@testset "One To One Relations" begin

c1 = ExampleCluster()
c1.energy = Float64(5)

c2 = ExampleWithOneRelation()
c2.cluster = c1

@test c2.cluster.energy == Float64(5)

end

@testset "Collections" begin
mcp1 = ExampleMC()
mcp1.PDG = 2212
mcp2 = ExampleMC()
mcp2.PDG = 2212
mcp3 = ExampleMC()
mcp3.PDG = 1
push!(mcp3.parents,mcp1)
a = ExampleMCCollection([mcp1,mcp2,mcp3])
mc1=a[1]
mc2=a[2]
mc3=a[3]
@test mc1.PDG == 2212
@test mc2.PDG == 2212
@test mc3.PDG == 1
@test length(mc3.parents)== 1
@test mc3.parents[1] == mc1
end
end;

0 comments on commit 14870c8

Please sign in to comment.