Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Julia preprocessing #311

Merged
merged 18 commits into from
Aug 16, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/generator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def __init__(self, name, **kwargs):
self.array_size = kwargs.pop('array_size', None)

self.includes = set()
self.jl_imports = set()

if kwargs:
raise ValueError(f"Unused kwargs in MemberVariable: {kwargs.keys()}")
Expand All @@ -149,6 +150,7 @@ def __init__(self, name, **kwargs):

self.full_type = rf'std::array<{self.array_type}, {self.array_size}>'
self.includes.add('#include <array>')
self.jl_imports.add('using StaticArrays')

self.is_builtin = self.full_type in BUILTIN_TYPES
# We still have to check if this type is a valid fixed width type that we
Expand Down
40 changes: 32 additions & 8 deletions python/podio_class_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _eval_template(self, template, data):
def _write_file(self, name, content):
"""Write the content to file. Dispatch to the correct directory depending on
whether it is a header or a .cc file."""
if name.endswith("h"):
if name.endswith("h") or name.endswith("jl"):
fullname = os.path.join(self.install_dir, self.package_name, name)
else:
fullname = os.path.join(self.install_dir, "src", name)
Expand Down Expand Up @@ -176,6 +176,7 @@ def get_fn_format(tmpl):
'Data': ('h',),
'Component': ('h',),
'PrintInfo': ('h',),
'Julia': ('jl',),
}.get(template_base, ('h', 'cc'))

fn_templates = []
Expand Down Expand Up @@ -212,6 +213,7 @@ def _process_component(self, name, component):
component['class'] = DataType(name)

self._fill_templates('Component', component)
self._fill_templates('Julia', component)

def _process_datatype(self, name, definition):
"""Process one datatype"""
Expand All @@ -222,21 +224,23 @@ def _process_datatype(self, name, definition):
self._fill_templates('Obj', datatype)
self._fill_templates('Collection', datatype)
self._fill_templates('CollectionData', datatype)
self._fill_templates('Julia', datatype)

if 'SIO' in self.io_handlers:
self._fill_templates('SIOBlock', datatype)

def _preprocess_for_obj(self, datatype):
"""Do the preprocessing that is necessary for the Obj classes"""
fwd_declarations = {}
includes, includes_cc = set(), set()
includes, includes_cc, includes_jl = set(), set(), set()

for relation in datatype['OneToOneRelations']:
if relation.full_type != datatype['class'].full_type:
if relation.namespace not in fwd_declarations:
fwd_declarations[relation.namespace] = []
fwd_declarations[relation.namespace].append(relation.bare_type)
includes_cc.add(self._build_include(relation.bare_type))
includes_jl.add(self._build_include(relation.bare_type, julia=True))

if datatype['VectorMembers'] or datatype['OneToManyRelations']:
includes.add('#include <vector>')
Expand All @@ -245,12 +249,15 @@ def _preprocess_for_obj(self, datatype):
if not relation.is_builtin:
if relation.full_type == datatype['class'].full_type:
includes_cc.add(self._build_include(datatype['class'].bare_type))
includes_jl.add(self._build_include(datatype['class'].bare_type, julia=True))
else:
includes.add(self._build_include(relation.bare_type))
includes_jl.add(self._build_include(relation.bare_type, julia=True))

datatype['forward_declarations_obj'] = fwd_declarations
datatype['includes_obj'] = self._sort_includes(includes)
datatype['includes_cc_obj'] = self._sort_includes(includes_cc)
datatype['includes_jl'].update(self._sort_includes(includes_jl))
trivial_types = datatype['VectorMembers'] or datatype['OneToManyRelations'] or datatype['OneToOneRelations']
datatype['is_trivial_type'] = trivial_types

Expand All @@ -259,6 +266,7 @@ def _preprocess_for_class(self, datatype):
includes = set(datatype['includes_data'])
fwd_declarations = {}
includes_cc = set()
includes_jl = set()

for member in datatype["Members"]:
if self.expose_pod_members and not member.is_builtin and not member.is_array:
Expand All @@ -271,6 +279,7 @@ def _preprocess_for_class(self, datatype):
fwd_declarations[relation.namespace].append(relation.bare_type)
fwd_declarations[relation.namespace].append('Mutable' + relation.bare_type)
includes_cc.add(self._build_include(relation.bare_type))
includes_jl.add(self._build_include(relation.bare_type, julia=True))

if datatype['VectorMembers'] or datatype['OneToManyRelations']:
includes.add('#include <vector>')
Expand All @@ -279,10 +288,12 @@ def _preprocess_for_class(self, datatype):
for relation in datatype['OneToManyRelations']:
if self._needs_include(relation):
includes.add(self._build_include(relation.bare_type))
includes_jl.add(self._build_include(relation.bare_type, julia=True))

for vectormember in datatype['VectorMembers']:
if vectormember.full_type in self.reader.components:
includes.add(self._build_include(vectormember.bare_type))
includes_jl.add(self._build_include(vectormember.bare_type, julia=True))

includes.update(datatype.get('ExtraCode', {}).get('includes', '').split('\n'))
# TODO: in principle only the mutable classes would need these includes! # pylint: disable=fixme
Expand All @@ -299,21 +310,26 @@ def _preprocess_for_class(self, datatype):
datatype['includes'] = self._sort_includes(includes)
datatype['includes_cc'] = self._sort_includes(includes_cc)
datatype['forward_declarations'] = fwd_declarations
datatype['includes_jl'] = set()
datatype['includes_jl'].update(self._sort_includes(includes_jl))

def _preprocess_for_collection(self, datatype):
"""Do the necessary preprocessing for the collection"""
includes_cc, includes = set(), set()
includes_cc, includes, includes_jl = set(), set(), set()

for relation in datatype['OneToManyRelations'] + datatype['OneToOneRelations']:
if datatype['class'].bare_type != relation.bare_type:
includes_cc.add(self._build_include(relation.bare_type + 'Collection'))
includes_jl.add(self._build_include(relation.bare_type + 'Collection', julia=True))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this necessary? Respectively can you use the generated files with this? I am asking because as far as I can see there is no XYZCollection.jl that we generate here(?)

includes.add(self._build_include(relation.bare_type))
includes_jl.add(self._build_include(relation.bare_type, julia=True))

if datatype['VectorMembers']:
includes_cc.add('#include <numeric>')

datatype['includes_coll_cc'] = self._sort_includes(includes_cc)
datatype['includes_coll_data'] = self._sort_includes(includes)
datatype['includes_jl'].update(self._sort_includes(includes_jl))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the _sort_includes work as expected with the Julia includes?


# the ostream operator needs a bit of help from the python side in the form
# of some pre processing but also in the form of formatting, both are done
Expand Down Expand Up @@ -359,29 +375,35 @@ def _preprocess_datatype(self, name, definition):
data = deepcopy(definition)
data['class'] = DataType(name)
data['includes_data'] = self._get_member_includes(definition["Members"])
data['includes_data_jl'] = self._get_member_includes(definition["Members"], julia=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this necessary? In the templates I don't see that we use includes_data_jl

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see they are used in the templates now. Is there a reason why they have to go to a different list than the usual includes_jl?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think they can go in the same list. I will change it in the next commit.

data['is_pod'] = self._is_pod_type(definition["Members"])
self._preprocess_for_class(data)
self._preprocess_for_obj(data)
self._preprocess_for_collection(data)

return data

def _get_member_includes(self, members):
def _get_member_includes(self, members, julia=False):
"""Process all members and gather the necessary includes"""
includes = set()
includes, includes_jl = set(), set()
includes.update(*(m.includes for m in members))
includes_jl.update(*(m.jl_imports for m in members))
for member in members:
if member.is_array and not member.is_builtin_array:
includes.add(self._build_include(member.array_bare_type))
includes_jl.add(self._build_include(member.array_bare_type, julia=True))

for stl_type in ClassDefinitionValidator.allowed_stl_types:
if member.full_type == 'std::' + stl_type:
includes.add(f"#include <{stl_type}>")

if self._needs_include(member):
includes.add(self._build_include(member.bare_type))
includes_jl.add(self._build_include(member.bare_type, julia=True))

return self._sort_includes(includes)
if not julia:
return self._sort_includes(includes)
return includes_jl

def _write_cmake_lists_file(self):
"""Write the names of all generated header and src files into cmake lists"""
Expand Down Expand Up @@ -439,11 +461,13 @@ def _create_selection_xml(self):
'datatypes': [DataType(d) for d in self.reader.datatypes]}
self._write_file('selection.xml', self._eval_template('selection.xml.jinja2', data))

def _build_include(self, classname):
def _build_include(self, classname, julia=False):
"""Return the include statement."""
if self.include_subfolder:
classname = os.path.join(self.package_name, classname)
return f'#include "{classname}.h"'
if not julia:
return f'#include "{classname}.h"'
return f'include("{classname}.jl")'

def _sort_includes(self, includes):
"""Sort the includes in order to try to have the std includes at the bottom"""
Expand Down
18 changes: 18 additions & 0 deletions python/templates/Julia.jl.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{% for include in includes_jl %}
{{ include }}
{% endfor %}

mutable struct {{ class.bare_type }}
{% for member in Members %}
{{member.name}}::{{member.julia_type }}
tmadlener marked this conversation as resolved.
Show resolved Hide resolved
{% endfor %}
{% for relation in OneToManyRelations %}
{{relation.name}}::{{relation.julia_type }}
tmadlener marked this conversation as resolved.
Show resolved Hide resolved
{% endfor %}
{% for relation in OneToOneRelations %}
{{relation.name}}::{{relation.julia_type }}
{% endfor %}
{% for member in VectorMembers %}
{{member.name}}::Vector{ {{member.julia_type }} }
{% endfor %}
end