Skip to content

Commit

Permalink
Add a record and a producer for holding tensorflow graphDef
Browse files Browse the repository at this point in the history
  • Loading branch information
[email protected] authored and [Minxi committed Nov 9, 2020
1 parent 92154ca commit 3b4cacb
Show file tree
Hide file tree
Showing 9 changed files with 144 additions and 0 deletions.
5 changes: 5 additions & 0 deletions DataFormats/TrackTfGraph/BuildFile.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<use name="PhysicsTools/TensorFlow"/>
<use name="RecoTracker/Record"/>
<export>
<lib name="1"/>
</export>
16 changes: 16 additions & 0 deletions DataFormats/TrackTfGraph/interface/TfGraphDefWrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef TrackTfGraph_TfGraphDefWrapper_h
#define TrackTfGraph_TfGraphDefWrapper_h

#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

class TfGraphDefWrapper {
public:
TfGraphDefWrapper(tensorflow::GraphDef*);
tensorflow::GraphDef* GetGraphDef() const;

private:
tensorflow::GraphDef* graphDef_;
};


#endif
11 changes: 11 additions & 0 deletions DataFormats/TrackTfGraph/plugins/BuildFile.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
<use name="FWCore/Framework"/>
<use name="FWCore/ParameterSet"/>
<use name="FWCore/PluginManager"/>
<use name="PhysicsTools/TensorFlow"/>
<use name="FWCore/Utilities"/>
<use name="RecoTracker/Record"/>
<use name="DataFormats/TrackTfGraph"/>
<flags EDM_PLUGIN="1"/>
<export>
<lib name="1"/>
</export>
66 changes: 66 additions & 0 deletions DataFormats/TrackTfGraph/plugins/TfGraphDefProducer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// -*- C++ -*-
//
// Package: test/TFGraphDefProducer
// Class: TFGraphDefProducer
//
/**\class TFGraphDefProducer
Description: Produces TfGraphRecord into the event containing a tensorflow GraphDef object that can be used for running inference on a pretrained network
*/
//
// Original Author: Joona Havukainen
// Created: Fri, 24 Jul 2020 08:04:00 GMT
//
//

// system include files
#include <memory>

// user include files
#include "FWCore/Framework/interface/ModuleFactory.h"
#include "FWCore/Framework/interface/ESProducer.h"

#include "FWCore/Framework/interface/ESHandle.h"
#include "TrackingTools/Records/interface/TfGraphRecord.h"
#include "DataFormats/TrackTfGraph/interface/TfGraphDefWrapper.h"

// class declaration

class TfGraphDefProducer : public edm::ESProducer {
public:
TfGraphDefProducer(const edm::ParameterSet&);

using ReturnType = std::unique_ptr<TfGraphDefWrapper>;

ReturnType produce(const TfGraphRecord&);

static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);

private:
TfGraphDefWrapper wrapper_;

// ----------member data ---------------------------
};

TfGraphDefProducer::TfGraphDefProducer(const edm::ParameterSet& iConfig):
wrapper_(TfGraphDefWrapper(tensorflow::loadGraphDef(iConfig.getParameter<edm::FileInPath>("FileName").fullPath())))
{
auto componentName = iConfig.getParameter<std::string>("ComponentName");
setWhatProduced(this, componentName);
}

// ------------ method called to produce the data ------------
std::unique_ptr<TfGraphDefWrapper> TfGraphDefProducer::produce(const TfGraphRecord& iRecord) {
return std::unique_ptr<TfGraphDefWrapper>(&wrapper_);
}

void TfGraphDefProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
edm::ParameterSetDescription desc;
desc.add<std::string>("ComponentName", "tfGraphDef");
desc.add<edm::FileInPath>("FileName", edm::FileInPath());
descriptions.add("tfGraphDefProducer", desc);
}

//define this as a plug-in
#include "FWCore/PluginManager/interface/ModuleDef.h"
#include "FWCore/Framework/interface/MakerMacros.h"
DEFINE_FWK_EVENTSETUP_MODULE(TfGraphDefProducer);
4 changes: 4 additions & 0 deletions DataFormats/TrackTfGraph/src/ES_TfGraphDefWrapper.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "DataFormats/TrackTfGraph/interface/TfGraphDefWrapper.h"
#include "FWCore/Utilities/interface/typelookup.h"

TYPELOOKUP_DATA_REG(TfGraphDefWrapper);
7 changes: 7 additions & 0 deletions DataFormats/TrackTfGraph/src/TfGraphDefWrapper.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#include "DataFormats/TrackTfGraph/interface/TfGraphDefWrapper.h"

TfGraphDefWrapper::TfGraphDefWrapper(tensorflow::GraphDef* graph) {graphDef_ = graph;}

tensorflow::GraphDef* TfGraphDefWrapper::GetGraphDef() const {
return graphDef_;
}
1 change: 1 addition & 0 deletions TrackingTools/Records/interface/Records.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
#include "TrackingTools/Records/interface/DetIdAssociatorRecord.h"
#include "TrackingTools/Records/interface/TransientRecHitRecord.h"
#include "TrackingTools/Records/interface/TransientTrackRecord.h"
#include "TrackingTools/Records/interface/TfGraphRecord.h"
22 changes: 22 additions & 0 deletions TrackingTools/Records/interface/TfGraphRecord.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef TfGraphRecord_TfGraphRecord_h
#define TfGraphRecord_TfGraphRecord_h
// -*- C++ -*-
//
// Package: TrackingTools/Records
// Class : TfGraphRecord
//
/**\class TfGraphRecord TfGraphRecord.h TrackingTools/Records/interface/TfGraphRecord.h
Description: Class to hold Record of a Tensorflow GraphDef that can be used to serve a pretrained tensorflow model for inference
Usage:
Used by DataFormats/TrackTfGraph to produce the GraphRecord and RecoTrack/FinalTrackSelection/plugins/TrackTfClassifier.cc to evaluate a track using the graph.
*/
//
// Author: Joona Havukainen
// Created: Fri, 24 Jul 2020 07:39:35 GMT
//

#include "FWCore/Framework/interface/EventSetupRecordImplementation.h"

class TfGraphRecord : public edm::eventsetup::EventSetupRecordImplementation<TfGraphRecord> {};

#endif
12 changes: 12 additions & 0 deletions TrackingTools/Records/src/TfGraphRecord.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// -*- C++ -*-
//
// Package: TrackingTools/Records
// Class : TfGraphRecord
//
// Author: Joona Havukainen
// Created: Fri, 24 Jul 2020 07:39:35 GMT

#include "TrackingTools/Records/interface/TfGraphRecord.h"
#include "FWCore/Framework/interface/eventsetuprecord_registration_macro.h"

EVENTSETUP_RECORD_REG(TfGraphRecord);

0 comments on commit 3b4cacb

Please sign in to comment.