-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a record and a producer for holding tensorflow graphDef
- Loading branch information
1 parent
92154ca
commit 3b4cacb
Showing
9 changed files
with
144 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
<use name="PhysicsTools/TensorFlow"/> | ||
<use name="RecoTracker/Record"/> | ||
<export> | ||
<lib name="1"/> | ||
</export> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#ifndef TrackTfGraph_TfGraphDefWrapper_h | ||
#define TrackTfGraph_TfGraphDefWrapper_h | ||
|
||
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h" | ||
|
||
class TfGraphDefWrapper { | ||
public: | ||
TfGraphDefWrapper(tensorflow::GraphDef*); | ||
tensorflow::GraphDef* GetGraphDef() const; | ||
|
||
private: | ||
tensorflow::GraphDef* graphDef_; | ||
}; | ||
|
||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
<use name="FWCore/Framework"/> | ||
<use name="FWCore/ParameterSet"/> | ||
<use name="FWCore/PluginManager"/> | ||
<use name="PhysicsTools/TensorFlow"/> | ||
<use name="FWCore/Utilities"/> | ||
<use name="RecoTracker/Record"/> | ||
<use name="DataFormats/TrackTfGraph"/> | ||
<flags EDM_PLUGIN="1"/> | ||
<export> | ||
<lib name="1"/> | ||
</export> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
// -*- C++ -*- | ||
// | ||
// Package: test/TFGraphDefProducer | ||
// Class: TFGraphDefProducer | ||
// | ||
/**\class TFGraphDefProducer | ||
Description: Produces TfGraphRecord into the event containing a tensorflow GraphDef object that can be used for running inference on a pretrained network | ||
*/ | ||
// | ||
// Original Author: Joona Havukainen | ||
// Created: Fri, 24 Jul 2020 08:04:00 GMT | ||
// | ||
// | ||
|
||
// system include files | ||
#include <memory> | ||
|
||
// user include files | ||
#include "FWCore/Framework/interface/ModuleFactory.h" | ||
#include "FWCore/Framework/interface/ESProducer.h" | ||
|
||
#include "FWCore/Framework/interface/ESHandle.h" | ||
#include "TrackingTools/Records/interface/TfGraphRecord.h" | ||
#include "DataFormats/TrackTfGraph/interface/TfGraphDefWrapper.h" | ||
|
||
// class declaration | ||
|
||
class TfGraphDefProducer : public edm::ESProducer { | ||
public: | ||
TfGraphDefProducer(const edm::ParameterSet&); | ||
|
||
using ReturnType = std::unique_ptr<TfGraphDefWrapper>; | ||
|
||
ReturnType produce(const TfGraphRecord&); | ||
|
||
static void fillDescriptions(edm::ConfigurationDescriptions& descriptions); | ||
|
||
private: | ||
TfGraphDefWrapper wrapper_; | ||
|
||
// ----------member data --------------------------- | ||
}; | ||
|
||
TfGraphDefProducer::TfGraphDefProducer(const edm::ParameterSet& iConfig): | ||
wrapper_(TfGraphDefWrapper(tensorflow::loadGraphDef(iConfig.getParameter<edm::FileInPath>("FileName").fullPath()))) | ||
{ | ||
auto componentName = iConfig.getParameter<std::string>("ComponentName"); | ||
setWhatProduced(this, componentName); | ||
} | ||
|
||
// ------------ method called to produce the data ------------ | ||
std::unique_ptr<TfGraphDefWrapper> TfGraphDefProducer::produce(const TfGraphRecord& iRecord) { | ||
return std::unique_ptr<TfGraphDefWrapper>(&wrapper_); | ||
} | ||
|
||
void TfGraphDefProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) { | ||
edm::ParameterSetDescription desc; | ||
desc.add<std::string>("ComponentName", "tfGraphDef"); | ||
desc.add<edm::FileInPath>("FileName", edm::FileInPath()); | ||
descriptions.add("tfGraphDefProducer", desc); | ||
} | ||
|
||
//define this as a plug-in | ||
#include "FWCore/PluginManager/interface/ModuleDef.h" | ||
#include "FWCore/Framework/interface/MakerMacros.h" | ||
DEFINE_FWK_EVENTSETUP_MODULE(TfGraphDefProducer); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
#include "DataFormats/TrackTfGraph/interface/TfGraphDefWrapper.h" | ||
#include "FWCore/Utilities/interface/typelookup.h" | ||
|
||
TYPELOOKUP_DATA_REG(TfGraphDefWrapper); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#include "DataFormats/TrackTfGraph/interface/TfGraphDefWrapper.h" | ||
|
||
TfGraphDefWrapper::TfGraphDefWrapper(tensorflow::GraphDef* graph) {graphDef_ = graph;} | ||
|
||
tensorflow::GraphDef* TfGraphDefWrapper::GetGraphDef() const { | ||
return graphDef_; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
#ifndef TfGraphRecord_TfGraphRecord_h | ||
#define TfGraphRecord_TfGraphRecord_h | ||
// -*- C++ -*- | ||
// | ||
// Package: TrackingTools/Records | ||
// Class : TfGraphRecord | ||
// | ||
/**\class TfGraphRecord TfGraphRecord.h TrackingTools/Records/interface/TfGraphRecord.h | ||
Description: Class to hold Record of a Tensorflow GraphDef that can be used to serve a pretrained tensorflow model for inference | ||
Usage: | ||
Used by DataFormats/TrackTfGraph to produce the GraphRecord and RecoTrack/FinalTrackSelection/plugins/TrackTfClassifier.cc to evaluate a track using the graph. | ||
*/ | ||
// | ||
// Author: Joona Havukainen | ||
// Created: Fri, 24 Jul 2020 07:39:35 GMT | ||
// | ||
|
||
#include "FWCore/Framework/interface/EventSetupRecordImplementation.h" | ||
|
||
class TfGraphRecord : public edm::eventsetup::EventSetupRecordImplementation<TfGraphRecord> {}; | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// -*- C++ -*- | ||
// | ||
// Package: TrackingTools/Records | ||
// Class : TfGraphRecord | ||
// | ||
// Author: Joona Havukainen | ||
// Created: Fri, 24 Jul 2020 07:39:35 GMT | ||
|
||
#include "TrackingTools/Records/interface/TfGraphRecord.h" | ||
#include "FWCore/Framework/interface/eventsetuprecord_registration_macro.h" | ||
|
||
EVENTSETUP_RECORD_REG(TfGraphRecord); |