-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Track classifier using TensorFlow - continuation #32128
Changes from all commits
0fa14ea
4e13557
7267832
98e7321
5af40b9
92585cc
3212d2f
c591c92
46fafc4
bd640f2
529444b
c328f8d
9157510
4a4bc10
51415cd
cc7be66
054af00
3298b2b
a90da53
e2e192b
8d9efec
7928542
1ab0b82
106ff9d
e371301
0db6299
ba03448
0b6f9d0
235b281
a762ac2
1db5432
26625ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#ifndef RecoTracker_FinalTrackSelectors_TfGraphDefWrapper_h | ||
#define RecoTracker_FinalTrackSelectors_TfGraphDefWrapper_h | ||
|
||
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h" | ||
|
||
class TfGraphDefWrapper { | ||
public: | ||
TfGraphDefWrapper(tensorflow::Session*); | ||
~TfGraphDefWrapper() { tensorflow::closeSession(session_); }; | ||
TfGraphDefWrapper(const TfGraphDefWrapper&) = delete; | ||
TfGraphDefWrapper& operator=(const TfGraphDefWrapper&) = delete; | ||
TfGraphDefWrapper(TfGraphDefWrapper&&) = delete; | ||
TfGraphDefWrapper& operator=(TfGraphDefWrapper&&) = delete; | ||
const tensorflow::Session* getSession() const; | ||
|
||
private: | ||
tensorflow::Session* session_; | ||
}; | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
// -*- C++ -*- | ||
// | ||
// Package: RecoTracker/FinalTrackSelectors | ||
// Class: TFGraphDefProducer | ||
// | ||
/**\class TFGraphDefProducer | ||
Description: Produces TfGraphRecord into the event containing a tensorflow GraphDef object that can be used for running inference on a pretrained network | ||
*/ | ||
// | ||
// Original Author: Joona Havukainen | ||
// Created: Fri, 24 Jul 2020 08:04:00 GMT | ||
// | ||
// | ||
|
||
// system include files | ||
#include <memory> | ||
|
||
// user include files | ||
#include "FWCore/Framework/interface/ModuleFactory.h" | ||
#include "FWCore/Framework/interface/ESProducer.h" | ||
|
||
#include "FWCore/Framework/interface/ESHandle.h" | ||
#include "TrackingTools/Records/interface/TfGraphRecord.h" | ||
#include "RecoTracker/FinalTrackSelectors/interface/TfGraphDefWrapper.h" | ||
|
||
// class declaration | ||
|
||
class TfGraphDefProducer : public edm::ESProducer { | ||
public: | ||
TfGraphDefProducer(const edm::ParameterSet&); | ||
using ReturnType = std::unique_ptr<TfGraphDefWrapper>; | ||
|
||
ReturnType produce(const TfGraphRecord&); | ||
|
||
static void fillDescriptions(edm::ConfigurationDescriptions& descriptions); | ||
|
||
private: | ||
const std::string filename_; | ||
// ----------member data --------------------------- | ||
}; | ||
|
||
TfGraphDefProducer::TfGraphDefProducer(const edm::ParameterSet& iConfig) | ||
: filename_(iConfig.getParameter<edm::FileInPath>("FileName").fullPath()) { | ||
auto componentName = iConfig.getParameter<std::string>("ComponentName"); | ||
setWhatProduced(this, componentName); | ||
} | ||
|
||
// ------------ method called to produce the data ------------ | ||
std::unique_ptr<TfGraphDefWrapper> TfGraphDefProducer::produce(const TfGraphRecord& iRecord) { | ||
return std::make_unique<TfGraphDefWrapper>(tensorflow::createSession(tensorflow::loadGraphDef(filename_), 1)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this actually leak the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. looking into it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for pointing it out! Per the recent CMSSW ML documentation, we need to indeed delete the graphDef by hand after closeSession (for some reason I was under the impression closeSession deletes the associated graphs). I will open a PR to fix this. I don't seem to observe any detectable effect on the memory, though. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fix merged in #32945 |
||
} | ||
|
||
void TfGraphDefProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) { | ||
edm::ParameterSetDescription desc; | ||
desc.add<std::string>("ComponentName", "tfGraphDef"); | ||
desc.add<edm::FileInPath>("FileName"); | ||
descriptions.add("tfGraphDefProducer", desc); | ||
} | ||
|
||
//define this as a plug-in | ||
#include "FWCore/PluginManager/interface/ModuleDef.h" | ||
#include "FWCore/Framework/interface/MakerMacros.h" | ||
DEFINE_FWK_EVENTSETUP_MODULE(TfGraphDefProducer); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
#include "RecoTracker/FinalTrackSelectors/interface/TrackMVAClassifier.h" | ||
|
||
#include "FWCore/Framework/interface/EventSetup.h" | ||
#include "FWCore/Framework/interface/global/EDProducer.h" | ||
#include "DataFormats/TrackReco/interface/Track.h" | ||
#include "DataFormats/VertexReco/interface/Vertex.h" | ||
#include "FWCore/Framework/interface/ConsumesCollector.h" | ||
#include "getBestVertex.h" | ||
|
||
#include "TrackingTools/Records/interface/TfGraphRecord.h" | ||
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h" | ||
#include "RecoTracker/FinalTrackSelectors/interface/TfGraphDefWrapper.h" | ||
|
||
namespace { | ||
class TfDnn { | ||
public: | ||
TfDnn(const edm::ParameterSet& cfg) | ||
: tfDnnLabel_(cfg.getParameter<std::string>("tfDnnLabel")), | ||
session_(nullptr) | ||
|
||
{} | ||
|
||
static const char* name() { return "TrackTfClassifier"; } | ||
|
||
static void fillDescriptions(edm::ParameterSetDescription& desc) { | ||
desc.add<std::string>("tfDnnLabel", "trackSelectionTf"); | ||
} | ||
|
||
void beginStream() {} | ||
|
||
void initEvent(const edm::EventSetup& es) { | ||
if (session_ == nullptr) { | ||
edm::ESHandle<TfGraphDefWrapper> tfDnnHandle; | ||
es.get<TfGraphRecord>().get(tfDnnLabel_, tfDnnHandle); | ||
Comment on lines
+33
to
+34
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
you can see this PR for an example: https://github.com/cms-sw/cmssw/pull/32246/files There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hello, it seems like I need to include the new header file for esConsumes(), but I am not sure which header is supposed to used in here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. esConsumes needs to come from the EDProducer, or passed via ConsumesCollector. This will require a few changes to the other Track MVA classes. So that this PR here doesn't get stuck or expand in scope too much, let's defer the ESGetToken update for a follow-up. |
||
session_ = tfDnnHandle.product()->getSession(); | ||
} | ||
} | ||
|
||
float operator()(reco::Track const& trk, | ||
reco::BeamSpot const& beamSpot, | ||
reco::VertexCollection const& vertices) const { | ||
const auto& bestVertex = getBestVertex(trk, vertices); | ||
|
||
tensorflow::Tensor input1(tensorflow::DT_FLOAT, {1, 29}); | ||
tensorflow::Tensor input2(tensorflow::DT_FLOAT, {1, 1}); | ||
|
||
input1.matrix<float>()(0, 0) = trk.pt(); | ||
input1.matrix<float>()(0, 1) = trk.innerMomentum().x(); | ||
input1.matrix<float>()(0, 2) = trk.innerMomentum().y(); | ||
input1.matrix<float>()(0, 3) = trk.innerMomentum().z(); | ||
input1.matrix<float>()(0, 4) = trk.innerMomentum().rho(); | ||
input1.matrix<float>()(0, 5) = trk.outerMomentum().x(); | ||
input1.matrix<float>()(0, 6) = trk.outerMomentum().y(); | ||
input1.matrix<float>()(0, 7) = trk.outerMomentum().z(); | ||
input1.matrix<float>()(0, 8) = trk.outerMomentum().rho(); | ||
input1.matrix<float>()(0, 9) = trk.ptError(); | ||
input1.matrix<float>()(0, 10) = trk.dxy(bestVertex); | ||
input1.matrix<float>()(0, 11) = trk.dz(bestVertex); | ||
input1.matrix<float>()(0, 12) = trk.dxy(beamSpot.position()); | ||
input1.matrix<float>()(0, 13) = trk.dz(beamSpot.position()); | ||
input1.matrix<float>()(0, 14) = trk.dxyError(); | ||
input1.matrix<float>()(0, 15) = trk.dzError(); | ||
input1.matrix<float>()(0, 16) = trk.normalizedChi2(); | ||
input1.matrix<float>()(0, 17) = trk.eta(); | ||
input1.matrix<float>()(0, 18) = trk.phi(); | ||
input1.matrix<float>()(0, 19) = trk.etaError(); | ||
input1.matrix<float>()(0, 20) = trk.phiError(); | ||
input1.matrix<float>()(0, 21) = trk.hitPattern().numberOfValidPixelHits(); | ||
input1.matrix<float>()(0, 22) = trk.hitPattern().numberOfValidStripHits(); | ||
input1.matrix<float>()(0, 23) = trk.ndof(); | ||
input1.matrix<float>()(0, 24) = trk.hitPattern().numberOfLostTrackerHits(reco::HitPattern::MISSING_INNER_HITS); | ||
input1.matrix<float>()(0, 25) = trk.hitPattern().numberOfLostTrackerHits(reco::HitPattern::MISSING_OUTER_HITS); | ||
input1.matrix<float>()(0, 26) = | ||
trk.hitPattern().trackerLayersTotallyOffOrBad(reco::HitPattern::MISSING_INNER_HITS); | ||
input1.matrix<float>()(0, 27) = | ||
trk.hitPattern().trackerLayersTotallyOffOrBad(reco::HitPattern::MISSING_OUTER_HITS); | ||
input1.matrix<float>()(0, 28) = trk.hitPattern().trackerLayersWithoutMeasurement(reco::HitPattern::TRACK_HITS); | ||
|
||
//Original algo as its own input, it will enter the graph so that it gets one-hot encoded, as is the preferred | ||
//format for categorical inputs, where the labels do not have any metric amongst them | ||
input2.matrix<float>()(0, 0) = trk.originalAlgo(); | ||
|
||
//The names for the input tensors get locked when freezing the trained tensorflow model. The NamedTensors must | ||
//match those names | ||
tensorflow::NamedTensorList inputs; | ||
inputs.resize(2); | ||
inputs[0] = tensorflow::NamedTensor("x", input1); | ||
inputs[1] = tensorflow::NamedTensor("y", input2); | ||
std::vector<tensorflow::Tensor> outputs; | ||
|
||
//evaluate the input | ||
tensorflow::run(const_cast<tensorflow::Session*>(session_), inputs, {"Identity"}, &outputs); | ||
//scale output to be [-1, 1] due to convention | ||
float output = 2.0 * outputs[0].matrix<float>()(0, 0) - 1.0; | ||
return output; | ||
} | ||
|
||
const std::string tfDnnLabel_; | ||
const tensorflow::Session* session_; | ||
}; | ||
|
||
using TrackTfClassifier = TrackMVAClassifier<TfDnn>; | ||
} // namespace | ||
#include "FWCore/PluginManager/interface/ModuleDef.h" | ||
#include "FWCore/Framework/interface/MakerMacros.h" | ||
|
||
DEFINE_FWK_MODULE(TrackTfClassifier); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from RecoTracker.FinalTrackSelectors.tfGraphDefProducer_cfi import tfGraphDefProducer as _tfGraphDefProducer | ||
trackSelectionTf = _tfGraphDefProducer.clone( | ||
ComponentName = "trackSelectionTf", | ||
FileName = "RecoTracker/FinalTrackSelectors/data/TrackTfClassifier/QCDFlatPU_QCDHighPt_ZEE_DisplacedSUSY_2020.pb" | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
#include "RecoTracker/FinalTrackSelectors/interface/TfGraphDefWrapper.h" | ||
#include "FWCore/Utilities/interface/typelookup.h" | ||
|
||
TYPELOOKUP_DATA_REG(TfGraphDefWrapper); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#include "RecoTracker/FinalTrackSelectors/interface/TfGraphDefWrapper.h" | ||
|
||
TfGraphDefWrapper::TfGraphDefWrapper(tensorflow::Session* session) : session_(session) {} | ||
const tensorflow::Session* TfGraphDefWrapper::getSession() const { | ||
return const_cast<const tensorflow::Session*>(session_); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
|
||
#for dnn classifier | ||
from Configuration.ProcessModifiers.trackdnn_cff import trackdnn | ||
from RecoTracker.IterativeTracking.dnnQualityCuts import qualityCutDictionary | ||
|
||
### STEP 0 ### | ||
|
||
|
@@ -315,11 +316,11 @@ | |
qualityCuts = [-0.95,-0.85,-0.75] | ||
)) | ||
|
||
from RecoTracker.FinalTrackSelectors.TrackLwtnnClassifier_cfi import * | ||
from RecoTracker.FinalTrackSelectors.trackSelectionLwtnn_cfi import * | ||
trackdnn.toReplaceWith(initialStep, TrackLwtnnClassifier.clone( | ||
from RecoTracker.FinalTrackSelectors.TrackTfClassifier_cfi import * | ||
from RecoTracker.FinalTrackSelectors.trackSelectionTf_cfi import * | ||
trackdnn.toReplaceWith(initialStep, TrackTfClassifier.clone( | ||
src = 'initialStepTracks', | ||
qualityCuts = [0.0, 0.3, 0.6] | ||
qualityCuts = qualityCutDictionary["InitialStep"] | ||
)) | ||
(trackdnn & fastSim).toModify(initialStep,vertices = 'firstStepPrimaryVerticesBeforeMixing') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this change with the whitespace and quotes is not strictly needed, please revert There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should be addressed for the final version |
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given the destructor, it would be safer to make this class non-copyable and non-movable. In principle movability could be achieved, but that would require changes in the destructor. (currently an accidental copy of
TfGraphDefWrapper
could lead to premature closing of the Session)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we have some suggestions to implement it? Maybe try something like ‘TfGraphDefWrapper & operator=(TfGraphDefWrapper&) = delete’ ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Matti for catching this!
@minxiyang per the rule of five https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#Rc-five, if you define a custom destructor, you should explicitly state what happens to the move and copy constructors&operators. In this case, they should be deleted: