From 3b4cacb80746a178d27ebd5dbbc049329296ce1a Mon Sep 17 00:00:00 2001 From: "hajohajo@helsinki.fi" Date: Tue, 22 Sep 2020 15:52:27 +0200 Subject: [PATCH] Add a record and a producer for holding tensorflow graphDef --- DataFormats/TrackTfGraph/BuildFile.xml | 5 ++ .../interface/TfGraphDefWrapper.h | 16 +++++ .../TrackTfGraph/plugins/BuildFile.xml | 11 ++++ .../plugins/TfGraphDefProducer.cc | 66 +++++++++++++++++++ .../TrackTfGraph/src/ES_TfGraphDefWrapper.cc | 4 ++ .../TrackTfGraph/src/TfGraphDefWrapper.cc | 7 ++ TrackingTools/Records/interface/Records.h | 1 + .../Records/interface/TfGraphRecord.h | 22 +++++++ TrackingTools/Records/src/TfGraphRecord.cc | 12 ++++ 9 files changed, 144 insertions(+) create mode 100644 DataFormats/TrackTfGraph/BuildFile.xml create mode 100644 DataFormats/TrackTfGraph/interface/TfGraphDefWrapper.h create mode 100644 DataFormats/TrackTfGraph/plugins/BuildFile.xml create mode 100644 DataFormats/TrackTfGraph/plugins/TfGraphDefProducer.cc create mode 100644 DataFormats/TrackTfGraph/src/ES_TfGraphDefWrapper.cc create mode 100644 DataFormats/TrackTfGraph/src/TfGraphDefWrapper.cc create mode 100644 TrackingTools/Records/interface/TfGraphRecord.h create mode 100644 TrackingTools/Records/src/TfGraphRecord.cc diff --git a/DataFormats/TrackTfGraph/BuildFile.xml b/DataFormats/TrackTfGraph/BuildFile.xml new file mode 100644 index 0000000000000..af42c46e755e8 --- /dev/null +++ b/DataFormats/TrackTfGraph/BuildFile.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/DataFormats/TrackTfGraph/interface/TfGraphDefWrapper.h b/DataFormats/TrackTfGraph/interface/TfGraphDefWrapper.h new file mode 100644 index 0000000000000..16c001be5667a --- /dev/null +++ b/DataFormats/TrackTfGraph/interface/TfGraphDefWrapper.h @@ -0,0 +1,16 @@ +#ifndef TrackTfGraph_TfGraphDefWrapper_h +#define TrackTfGraph_TfGraphDefWrapper_h + +#include "PhysicsTools/TensorFlow/interface/TensorFlow.h" + +class TfGraphDefWrapper { + public: + TfGraphDefWrapper(tensorflow::GraphDef*); + tensorflow::GraphDef* GetGraphDef() const; + + private: + tensorflow::GraphDef* graphDef_; +}; + + +#endif diff --git a/DataFormats/TrackTfGraph/plugins/BuildFile.xml b/DataFormats/TrackTfGraph/plugins/BuildFile.xml new file mode 100644 index 0000000000000..8bbd37716de49 --- /dev/null +++ b/DataFormats/TrackTfGraph/plugins/BuildFile.xml @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/DataFormats/TrackTfGraph/plugins/TfGraphDefProducer.cc b/DataFormats/TrackTfGraph/plugins/TfGraphDefProducer.cc new file mode 100644 index 0000000000000..6522db9061b6a --- /dev/null +++ b/DataFormats/TrackTfGraph/plugins/TfGraphDefProducer.cc @@ -0,0 +1,66 @@ +// -*- C++ -*- +// +// Package: test/TFGraphDefProducer +// Class: TFGraphDefProducer +// +/**\class TFGraphDefProducer + Description: Produces TfGraphRecord into the event containing a tensorflow GraphDef object that can be used for running inference on a pretrained network +*/ +// +// Original Author: Joona Havukainen +// Created: Fri, 24 Jul 2020 08:04:00 GMT +// +// + +// system include files +#include + +// user include files +#include "FWCore/Framework/interface/ModuleFactory.h" +#include "FWCore/Framework/interface/ESProducer.h" + +#include "FWCore/Framework/interface/ESHandle.h" +#include "TrackingTools/Records/interface/TfGraphRecord.h" +#include "DataFormats/TrackTfGraph/interface/TfGraphDefWrapper.h" + +// class declaration + +class TfGraphDefProducer : public edm::ESProducer { +public: + TfGraphDefProducer(const edm::ParameterSet&); + + using ReturnType = std::unique_ptr; + + ReturnType produce(const TfGraphRecord&); + + static void fillDescriptions(edm::ConfigurationDescriptions& descriptions); + +private: + TfGraphDefWrapper wrapper_; + + // ----------member data --------------------------- +}; + +TfGraphDefProducer::TfGraphDefProducer(const edm::ParameterSet& iConfig): +wrapper_(TfGraphDefWrapper(tensorflow::loadGraphDef(iConfig.getParameter("FileName").fullPath()))) +{ + auto componentName = iConfig.getParameter("ComponentName"); + setWhatProduced(this, componentName); +} + +// ------------ method called to produce the data ------------ +std::unique_ptr TfGraphDefProducer::produce(const TfGraphRecord& iRecord) { + return std::unique_ptr(&wrapper_); +} + +void TfGraphDefProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) { + edm::ParameterSetDescription desc; + desc.add("ComponentName", "tfGraphDef"); + desc.add("FileName", edm::FileInPath()); + descriptions.add("tfGraphDefProducer", desc); +} + +//define this as a plug-in +#include "FWCore/PluginManager/interface/ModuleDef.h" +#include "FWCore/Framework/interface/MakerMacros.h" +DEFINE_FWK_EVENTSETUP_MODULE(TfGraphDefProducer); diff --git a/DataFormats/TrackTfGraph/src/ES_TfGraphDefWrapper.cc b/DataFormats/TrackTfGraph/src/ES_TfGraphDefWrapper.cc new file mode 100644 index 0000000000000..34e6a3cc68ebf --- /dev/null +++ b/DataFormats/TrackTfGraph/src/ES_TfGraphDefWrapper.cc @@ -0,0 +1,4 @@ +#include "DataFormats/TrackTfGraph/interface/TfGraphDefWrapper.h" +#include "FWCore/Utilities/interface/typelookup.h" + +TYPELOOKUP_DATA_REG(TfGraphDefWrapper); diff --git a/DataFormats/TrackTfGraph/src/TfGraphDefWrapper.cc b/DataFormats/TrackTfGraph/src/TfGraphDefWrapper.cc new file mode 100644 index 0000000000000..eb60489b30ec4 --- /dev/null +++ b/DataFormats/TrackTfGraph/src/TfGraphDefWrapper.cc @@ -0,0 +1,7 @@ +#include "DataFormats/TrackTfGraph/interface/TfGraphDefWrapper.h" + +TfGraphDefWrapper::TfGraphDefWrapper(tensorflow::GraphDef* graph) {graphDef_ = graph;} + +tensorflow::GraphDef* TfGraphDefWrapper::GetGraphDef() const { + return graphDef_; +} diff --git a/TrackingTools/Records/interface/Records.h b/TrackingTools/Records/interface/Records.h index f5e9e7d7e41b6..fd776205f933c 100644 --- a/TrackingTools/Records/interface/Records.h +++ b/TrackingTools/Records/interface/Records.h @@ -2,3 +2,4 @@ #include "TrackingTools/Records/interface/DetIdAssociatorRecord.h" #include "TrackingTools/Records/interface/TransientRecHitRecord.h" #include "TrackingTools/Records/interface/TransientTrackRecord.h" +#include "TrackingTools/Records/interface/TfGraphRecord.h" diff --git a/TrackingTools/Records/interface/TfGraphRecord.h b/TrackingTools/Records/interface/TfGraphRecord.h new file mode 100644 index 0000000000000..f5fec7147a11b --- /dev/null +++ b/TrackingTools/Records/interface/TfGraphRecord.h @@ -0,0 +1,22 @@ +#ifndef TfGraphRecord_TfGraphRecord_h +#define TfGraphRecord_TfGraphRecord_h +// -*- C++ -*- +// +// Package: TrackingTools/Records +// Class : TfGraphRecord +// +/**\class TfGraphRecord TfGraphRecord.h TrackingTools/Records/interface/TfGraphRecord.h + Description: Class to hold Record of a Tensorflow GraphDef that can be used to serve a pretrained tensorflow model for inference + Usage: + Used by DataFormats/TrackTfGraph to produce the GraphRecord and RecoTrack/FinalTrackSelection/plugins/TrackTfClassifier.cc to evaluate a track using the graph. +*/ +// +// Author: Joona Havukainen +// Created: Fri, 24 Jul 2020 07:39:35 GMT +// + +#include "FWCore/Framework/interface/EventSetupRecordImplementation.h" + +class TfGraphRecord : public edm::eventsetup::EventSetupRecordImplementation {}; + +#endif diff --git a/TrackingTools/Records/src/TfGraphRecord.cc b/TrackingTools/Records/src/TfGraphRecord.cc new file mode 100644 index 0000000000000..edcbf58893369 --- /dev/null +++ b/TrackingTools/Records/src/TfGraphRecord.cc @@ -0,0 +1,12 @@ +// -*- C++ -*- +// +// Package: TrackingTools/Records +// Class : TfGraphRecord +// +// Author: Joona Havukainen +// Created: Fri, 24 Jul 2020 07:39:35 GMT + +#include "TrackingTools/Records/interface/TfGraphRecord.h" +#include "FWCore/Framework/interface/eventsetuprecord_registration_macro.h" + +EVENTSETUP_RECORD_REG(TfGraphRecord);