Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert existing association types to LCRelation collections #81

Merged
merged 8 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions k4EDM4hep2LcioConv/include/k4EDM4hep2LcioConv/k4EDM4hep2LcioConv.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,21 @@ void resolveRelations(ObjectMappingT& typeMapping);
template <typename ObjectMappingT, typename ObjectMappingU>
void resolveRelations(ObjectMappingT& updateMaps, const ObjectMappingU& lookupMaps);

/**
* Convert the passed associatoin collections to LCRelation collections
*/
template <typename ObjectMappingT>
std::vector<std::tuple<std::string, std::unique_ptr<lcio::LCCollection>>> createLCRelationCollections(
const std::vector<std::tuple<std::string, const podio::CollectionBase*>>& associationCollections,
const ObjectMappingT& objectMaps);

/**
* Create an LCRelation collection from the passed Association Collection
*/
template <typename AssocCollT, typename FromMapT, typename ToMapT>
std::unique_ptr<lcio::LCCollection> createLCRelationCollection(const AssocCollT& associations, const FromMapT& fromMap,
const ToMapT& toMap);

template <typename ObjectMappingT>
[[deprecated("Use resolveRelations instead")]] void FillMissingCollections(ObjectMappingT& update_pairs) {
resolveRelations(update_pairs);
Expand Down
120 changes: 118 additions & 2 deletions k4EDM4hep2LcioConv/include/k4EDM4hep2LcioConv/k4EDM4hep2LcioConv.ipp
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
#include "k4EDM4hep2LcioConv/MappingUtils.h"

#include <cassert>
#include <cmath>
#include <edm4hep/MCRecoCaloAssociationCollection.h>
#include <edm4hep/MCRecoCaloParticleAssociationCollection.h>
#include <edm4hep/MCRecoClusterParticleAssociationCollection.h>
#include <edm4hep/MCRecoParticleAssociationCollection.h>
#include <edm4hep/MCRecoTrackParticleAssociationCollection.h>
#include <edm4hep/MCRecoTrackerAssociationCollection.h>
#include <edm4hep/RecoParticleVertexAssociationCollection.h>

#include <UTIL/LCRelationNavigator.h>

#include "TMath.h"

Expand Down Expand Up @@ -744,4 +751,113 @@ void resolveRelations(ObjectMappingT& update_pairs, const ObjectMappingU& lookup
resolveRelationsClusters(update_pairs.clusters, lookup_pairs.caloHits);
}

template <typename ObjectMappingT>
std::vector<std::tuple<std::string, std::unique_ptr<lcio::LCCollection>>> createLCRelationCollections(
const std::vector<std::tuple<std::string, const podio::CollectionBase*>>& associationCollections,
const ObjectMappingT& objectMaps) {
std::vector<std::tuple<std::string, std::unique_ptr<lcio::LCCollection>>> relationColls{};
relationColls.reserve(associationCollections.size());

for (const auto& [name, coll] : associationCollections) {
if (const auto assocs = dynamic_cast<const edm4hep::MCRecoParticleAssociationCollection*>(coll)) {
relationColls.emplace_back(name,
createLCRelationCollection(*assocs, objectMaps.recoParticles, objectMaps.mcParticles));
} else if (const auto assocs = dynamic_cast<const edm4hep::MCRecoCaloAssociationCollection*>(coll)) {
relationColls.emplace_back(name,
createLCRelationCollection(*assocs, objectMaps.caloHits, objectMaps.simCaloHits));
} else if (const auto assocs = dynamic_cast<const edm4hep::MCRecoTrackerAssociationCollection*>(coll)) {
relationColls.emplace_back(
name, createLCRelationCollection(*assocs, objectMaps.trackerHits, objectMaps.simTrackerHits));
} else if (const auto assocs = dynamic_cast<const edm4hep::MCRecoCaloParticleAssociationCollection*>(coll)) {
relationColls.emplace_back(name,
createLCRelationCollection(*assocs, objectMaps.caloHits, objectMaps.mcParticles));
} else if (const auto assocs = dynamic_cast<const edm4hep::MCRecoClusterParticleAssociationCollection*>(coll)) {
relationColls.emplace_back(name,
createLCRelationCollection(*assocs, objectMaps.clusters, objectMaps.mcParticles));
} else if (const auto assocs = dynamic_cast<const edm4hep::MCRecoTrackParticleAssociationCollection*>(coll)) {
relationColls.emplace_back(name, createLCRelationCollection(*assocs, objectMaps.tracks, objectMaps.mcParticles));
} else if (const auto assocs = dynamic_cast<const edm4hep::RecoParticleVertexAssociationCollection*>(coll)) {
relationColls.emplace_back(name,
createLCRelationCollection(*assocs, objectMaps.recoParticles, objectMaps.vertices));
} else {
std::cerr << "Trying to create an LCRelation collection from a " << coll->getTypeName()
<< " which is not supported" << std::endl;
}
}

return relationColls;
}

namespace detail {
template <typename T>
constexpr const char* getTypeName();

#define DEFINE_TYPE_NAME(type) \
template <> \
constexpr const char* getTypeName<IMPL::type##Impl>() { \
return #type; \
}

DEFINE_TYPE_NAME(MCParticle);
DEFINE_TYPE_NAME(SimTrackerHit);
DEFINE_TYPE_NAME(SimCalorimeterHit);
DEFINE_TYPE_NAME(Track);
DEFINE_TYPE_NAME(TrackerHit);
DEFINE_TYPE_NAME(Vertex);
DEFINE_TYPE_NAME(ReconstructedParticle);
DEFINE_TYPE_NAME(Cluster);
DEFINE_TYPE_NAME(CalorimeterHit);

#undef DEFINE_TYPE_NAME
} // namespace detail

template <typename AssocCollT, typename FromMapT, typename ToMapT>
std::unique_ptr<lcio::LCCollection> createLCRelationCollection(const AssocCollT& associations, const FromMapT& fromMap,
const ToMapT& toMap) {
using FromLCIOT = std::remove_pointer_t<k4EDM4hep2LcioConv::detail::key_t<FromMapT>>;
using ToLCIOT = std::remove_pointer_t<k4EDM4hep2LcioConv::detail::key_t<ToMapT>>;

auto lcioColl = std::make_unique<lcio::LCCollectionVec>(lcio::LCIO::LCRELATION);
lcioColl->parameters().setValue("FromType", detail::getTypeName<FromLCIOT>());
lcioColl->parameters().setValue("ToType", detail::getTypeName<ToLCIOT>());

for (const auto assoc : associations) {
auto lcioRel = new lcio::LCRelationImpl{};
lcioRel->setWeight(assoc.getWeight());

const auto edm4hepFrom = assoc.getRec();
const auto lcioFrom = k4EDM4hep2LcioConv::detail::mapLookupFrom(edm4hepFrom, fromMap);
if (lcioFrom) {
lcioRel->setFrom(lcioFrom.value());
} else {
std::cerr << "Cannot find an object for building an LCRelation of type " << detail::getTypeName<FromLCIOT>()
<< std::endl;
}

if constexpr (std::is_same_v<AssocCollT, edm4hep::RecoParticleVertexAssociationCollection>) {
const auto edm4hepTo = assoc.getVertex();
const auto lcioTo = k4EDM4hep2LcioConv::detail::mapLookupFrom(edm4hepTo, toMap);
if (lcioTo) {
lcioRel->setTo(lcioTo.value());
} else {
std::cerr << "Cannot find an objects for building an LCRelation of type " << detail::getTypeName<ToLCIOT>()
<< std::endl;
}
} else {
const auto edm4hepTo = assoc.getSim();
const auto lcioTo = k4EDM4hep2LcioConv::detail::mapLookupFrom(edm4hepTo, toMap);
if (lcioTo) {
lcioRel->setTo(lcioTo.value());
} else {
std::cerr << "Cannot find an objects for building an LCRelation of type " << detail::getTypeName<ToLCIOT>()
<< std::endl;
}
}

lcioColl->addElement(lcioRel);
}

return lcioColl;
}

} // namespace EDM4hep2LCIOConv
12 changes: 11 additions & 1 deletion k4EDM4hep2LcioConv/src/k4EDM4hep2LcioConv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

#include "edm4hep/Constants.h"
#include "edm4hep/utils/ParticleIDUtils.h"
#include <edm4hep/ParticleIDCollection.h>

#include "UTIL/PIDHandler.h"
#include <edm4hep/ParticleIDCollection.h>

#include <algorithm>
#include <limits>
Expand Down Expand Up @@ -73,6 +73,10 @@ std::unique_ptr<lcio::LCEventImpl> convertEvent(const podio::Frame& edmEvent, co
// collection
std::vector<ParticleIDConvData> pidCollections{};

// We convert these at the very end, once all the necessary information is
// available
std::vector<std::tuple<std::string, const podio::CollectionBase*>> associations{};

const auto& collections = edmEvent.getAvailableCollections();
for (const auto& name : collections) {
const auto edmCollection = edmEvent.get(name);
Expand Down Expand Up @@ -123,6 +127,8 @@ std::unique_ptr<lcio::LCEventImpl> convertEvent(const podio::Frame& edmEvent, co
} else if (dynamic_cast<const edm4hep::CaloHitContributionCollection*>(edmCollection)) {
// "converted" during relation resolving later
continue;
} else if (edmCollection->getTypeName().find("Association") != std::string_view::npos) {
associations.emplace_back(name, edmCollection);
} else {
std::cerr << "Error trying to convert requested " << edmCollection->getValueTypeName() << " with name " << name
<< "\n"
Expand All @@ -144,6 +150,10 @@ std::unique_ptr<lcio::LCEventImpl> convertEvent(const podio::Frame& edmEvent, co

resolveRelations(objectMappings);

for (auto& [name, coll] : createLCRelationCollections(associations, objectMappings)) {
lcioEvent->addCollection(coll.release(), name);
}

return lcioEvent;
}

Expand Down
2 changes: 2 additions & 0 deletions tests/edm4hep_roundtrip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ int main() {
ASSERT_SAME_OR_ABORT(edm4hep::ParticleIDCollection, "ParticleID_coll_1");
ASSERT_SAME_OR_ABORT(edm4hep::ParticleIDCollection, "ParticleID_coll_2");
ASSERT_SAME_OR_ABORT(edm4hep::ParticleIDCollection, "ParticleID_coll_3");
ASSERT_SAME_OR_ABORT(edm4hep::MCRecoParticleAssociationCollection, "mcRecoAssocs");
ASSERT_SAME_OR_ABORT(edm4hep::MCRecoCaloAssociationCollection, "mcCaloHitsAssocs");

return 0;
}
2 changes: 2 additions & 0 deletions tests/edm4hep_to_lcio.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ int main() {
ASSERT_COMPARE_OR_EXIT(edm4hep::RawTimeSeriesCollection)
ASSERT_COMPARE_OR_EXIT(edm4hep::ClusterCollection)
ASSERT_COMPARE_OR_EXIT(edm4hep::VertexCollection)
ASSERT_COMPARE_OR_EXIT(edm4hep::MCRecoParticleAssociationCollection)
ASSERT_COMPARE_OR_EXIT(edm4hep::MCRecoCaloAssociationCollection)
}

return 0;
Expand Down
37 changes: 14 additions & 23 deletions tests/src/CompareEDM4hepEDM4hep.cc
Original file line number Diff line number Diff line change
@@ -1,27 +1,4 @@
#include "CompareEDM4hepEDM4hep.h"
#include "ComparisonUtils.h"

#include "edm4hep/CalorimeterHitCollection.h"
#include "edm4hep/ClusterCollection.h"
#include "edm4hep/MCParticleCollection.h"
#include "edm4hep/ParticleIDCollection.h"
#include "edm4hep/ReconstructedParticleCollection.h"
#include "edm4hep/SimCalorimeterHitCollection.h"
#include "edm4hep/TrackCollection.h"
#include "edm4hep/TrackerHitPlaneCollection.h"

#include <edm4hep/TrackState.h>
#include <iostream>
#include <podio/RelationRange.h>

#define REQUIRE_SAME(expected, actual, msg) \
{ \
if (!((expected) == (actual))) { \
std::cerr << msg << " are not the same (expected: " << (expected) << ", actual: " << (actual) << ")" \
<< std::endl; \
return false; \
} \
}

bool compare(const edm4hep::CalorimeterHitCollection& origColl,
const edm4hep::CalorimeterHitCollection& roundtripColl) {
Expand Down Expand Up @@ -325,3 +302,17 @@ bool compare(const edm4hep::ParticleIDCollection& origColl, const edm4hep::Parti
}
return true;
}

bool compare(const edm4hep::RecoParticleVertexAssociationCollection& origColl,
const edm4hep::RecoParticleVertexAssociationCollection& roundtripColl) {
REQUIRE_SAME(origColl.size(), roundtripColl.size(), "collection sizes");
for (size_t i = 0; i < origColl.size(); ++i) {
const auto origAssoc = origColl[i];
const auto assoc = roundtripColl[i];

REQUIRE_SAME(origAssoc.getWeight(), assoc.getWeight(), "weight in association " << i);
REQUIRE_SAME(origAssoc.getVertex().id(), assoc.getVertex().id(), "vertex in association " << i);
REQUIRE_SAME(origAssoc.getRec().id(), assoc.getRec().id(), "reco particle in association " << i);
}
return true;
}
29 changes: 29 additions & 0 deletions tests/src/CompareEDM4hepEDM4hep.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
#ifndef K4EDM4HEP2LCIOCONV_TEST_COMPAREEDM4HEPEDM4HEP_H
#define K4EDM4HEP2LCIOCONV_TEST_COMPAREEDM4HEPEDM4HEP_H

#include "ComparisonUtils.h"
#include "EDM4hep2LCIOUtilities.h"

#include <iostream>

#define REQUIRE_SAME(expected, actual, msg) \
{ \
if (!((expected) == (actual))) { \
std::cerr << msg << " are not the same (expected: " << (expected) << ", actual: " << (actual) << ")" \
<< std::endl; \
return false; \
} \
}

bool compare(const edm4hep::CalorimeterHitCollection& origColl, const edm4hep::CalorimeterHitCollection& roundtripColl);

bool compare(const edm4hep::MCParticleCollection& origColl, const edm4hep::MCParticleCollection& roundtripColl);
Expand All @@ -24,4 +36,21 @@ bool compare(const edm4hep::ReconstructedParticleCollection& origColl,

bool compare(const edm4hep::ParticleIDCollection& origColl, const edm4hep::ParticleIDCollection& roundtripColl);

bool compare(const edm4hep::RecoParticleVertexAssociationCollection& origColl,
const edm4hep::RecoParticleVertexAssociationCollection& roundtripColl);

template <typename AssociationCollT>
bool compare(const AssociationCollT& origColl, const AssociationCollT& roundtripColl) {
REQUIRE_SAME(origColl.size(), roundtripColl.size(), "collection sizes");
for (size_t i = 0; i < origColl.size(); ++i) {
const auto origAssoc = origColl[i];
const auto assoc = roundtripColl[i];

REQUIRE_SAME(origAssoc.getWeight(), assoc.getWeight(), "weight in association " << i);
REQUIRE_SAME(origAssoc.getSim().id(), assoc.getSim().id(), "MC part(icle) in association " << i);
REQUIRE_SAME(origAssoc.getRec().id(), assoc.getRec().id(), "reco part(icle) in association " << i);
}
return true;
}

#endif // K4EDM4HEP2LCIOCONV_TEST_COMPAREEDM4HEPEDM4HEP_H
Loading
Loading