ATLAS Offline Software
ElectronDNNCalculator Class Reference

Used by AsgElectronSelectorTool to calculate the score of a Python-trained DNN for electron ID, using lwtnn as the interface.

#include <ElectronDNNCalculator.h>


Public Member Functions

 ElectronDNNCalculator (AsgElectronSelectorTool *owner, const std::string &modelFileName, const std::string &quantileFileName, const std::vector< std::string > &variablesName, const bool multiClass)
 Constructor of the class.
 ~ElectronDNNCalculator ()
 Standard destructor.
Eigen::Matrix< float, -1, 1 > calculate (const std::vector< double > &) const
 Get the prediction of the DNN model.
bool msgLvl (const MSG::Level lvl) const
 Test the output level of the object.
MsgStream & msg () const
 The standard message stream.
MsgStream & msg (const MSG::Level lvl) const
 The standard message stream.

Private Member Functions

double transformInput (const std::vector< double > &quantiles, double value) const
 Transform an input variable according to a given QuantileTransformer.
int readQuantileTransformer (TTree *tree)
 Read the bins and values of the QuantileTransformer used to transform the input variables.

Private Attributes

std::unique_ptr< lwt::generic::FastGraph< float > > m_graph = 0
 DNN interface via lwtnn.
std::vector< std::vector< double > > m_quantiles
 Quantile values for each variable that needs to be transformed with the QuantileTransformer.
std::vector< double > m_references
 Reference values for the QuantileTransformer: equidistant bins between 0 and 1.
bool m_multiClass
 Whether the used model is a multiclass model or not.
std::vector< std::string > m_variables
 Model variables.
uint m_var_size
 Number of model variables (the size of m_variables).
std::function< MsgStream &()> m_msg
 the message stream we use

Detailed Description

Used by AsgElectronSelectorTool to calculate the score of a Python-trained DNN for electron ID, using lwtnn as the interface.

It also transforms the input variables with a QuantileTransformer before they are passed to the network.

Author
Lukas Ehrke

Definition at line 23 of file ElectronDNNCalculator.h.

Constructor & Destructor Documentation

◆ ElectronDNNCalculator()

ElectronDNNCalculator::ElectronDNNCalculator ( AsgElectronSelectorTool * owner,
const std::string & modelFileName,
const std::string & quantileFileName,
const std::vector< std::string > & variablesName,
const bool multiClass )

Constructor of the class.

Definition at line 26 of file ElectronDNNCalculator.cxx.

: asg::AsgMessagingForward(owner),
  m_quantiles(variablesName.size()),
  m_multiClass(multiClass),
  m_variables(variablesName),
  m_var_size(variablesName.size())
{
  ATH_MSG_INFO("Initializing ElectronDNNCalculator...");

  if (modelFileName.empty()){
    throw std::runtime_error("No model file name given");
  }

  // Open the trained model
  ATH_MSG_INFO("Loading model: " << modelFileName);
  std::ifstream inputFile(modelFileName);
  if (!inputFile){
    throw std::runtime_error("Could not open model file '" + modelFileName + "'");
  }

  // Create the input order for the NN; the data must be passed in exactly this order
  lwt::InputOrder order;
  order.scalar.emplace_back("node_0", m_variables);

  // Parse the model and check that its number of outputs matches the configuration
  auto parsedGraph = lwt::parse_json_graph(inputFile);
  size_t nOutputs = parsedGraph.outputs.begin()->second.labels.size();
  if (nOutputs != 6 && nOutputs != 1){
    throw std::runtime_error("Given model does not have 1 or 6 outputs. Something seems to be wrong with the model file.");
  }
  else if (nOutputs == 1 && m_multiClass){
    throw std::runtime_error("Given model has 1 output but config file specifies multiclass. Something is wrong");
  }
  else if (nOutputs == 6 && !m_multiClass){
    throw std::runtime_error("Given model has 6 outputs but config file does not specify multiclass. Something is wrong");
  }

  m_graph = std::make_unique<lwt::generic::FastGraph<float>>(parsedGraph, order);

  if (quantileFileName.empty()){
    throw std::runtime_error("No QuantileTransformer file name given");
  }

  // Open the QuantileTransformer file
  ATH_MSG_INFO("Loading QuantileTransformer " << quantileFileName);
  std::unique_ptr<TFile> qtfile(TFile::Open(quantileFileName.c_str()));
  if (!qtfile || qtfile->IsZombie()){
    throw std::runtime_error("Could not open QuantileTransformer file '" + quantileFileName + "'");
  }
  TTree* qtTree = dynamic_cast<TTree*>(qtfile->Get("tree"));
  if (!qtTree || readQuantileTransformer(qtTree) == 0){
    throw std::runtime_error("Could not load all variables for the QuantileTransformer");
  }
}
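
The constructor accepts a model only if its number of outputs agrees with the multiClass flag: 1 output for a binary model, 6 for a multiclass one. The sketch below isolates that rule in plain C++ so it can be tested without lwtnn. The function checkOutputCount and the two constants are hypothetical names, and the error messages are illustrative rather than the ones the class throws.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>

// Hypothetical constants: 1 output for a binary model, 6 for a multiclass one.
constexpr std::size_t kBinaryOutputs = 1;
constexpr std::size_t kMulticlassOutputs = 6;

// Throws std::runtime_error unless the model's output count is one of the
// supported values AND is consistent with the multiClass configuration flag.
void checkOutputCount(std::size_t nOutputs, bool multiClass) {
  if (nOutputs != kBinaryOutputs && nOutputs != kMulticlassOutputs) {
    throw std::runtime_error("Model has " + std::to_string(nOutputs) +
                             " outputs; expected 1 or 6");
  }
  const std::size_t expected = multiClass ? kMulticlassOutputs : kBinaryOutputs;
  if (nOutputs != expected) {
    throw std::runtime_error("Model has " + std::to_string(nOutputs) +
                             " outputs but multiClass is " +
                             (multiClass ? "true" : "false"));
  }
}
```

Checking the count right after parsing, before the graph is built, means a mismatched model/config pair fails at initialization instead of producing a wrongly shaped score vector later.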

◆ ~ElectronDNNCalculator()

ElectronDNNCalculator::~ElectronDNNCalculator ( )
inline

Standard destructor.

Definition at line 34 of file ElectronDNNCalculator.h.

{}

Member Function Documentation

◆ calculate()

Eigen::Matrix< float, -1, 1 > ElectronDNNCalculator::calculate ( const std::vector< double > & variableValues) const

Get the prediction of the DNN model.

Definition at line 85 of file ElectronDNNCalculator.cxx.

{
  if (variableValues.size() != m_var_size)
    throw std::runtime_error("Passed vector of variables has wrong size");

  // Create the input for the model
  Eigen::VectorXf inputVector(m_var_size);

  // This has to be in the same order as the InputOrder was defined
  for (uint i = 0; i < m_var_size; ++i){
    inputVector(i) = transformInput(m_quantiles.at(i), variableValues.at(i));
  }

  std::vector<Eigen::VectorXf> inp;
  inp.emplace_back(std::move(inputVector));

  auto output = m_graph->compute(inp);
  return output;
}
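
calculate() reduces to three steps: check the input length, transform each value with its own quantile mapping in InputOrder order, and pass the result to the network. A minimal sketch of that flow follows, assuming std::function stand-ins (the hypothetical Transform and Model aliases) in place of transformInput() and lwtnn's FastGraph, and std::vector<float> in place of Eigen vectors.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins: one transform per input variable (like
// transformInput() with that variable's quantiles) and a model callable
// (like FastGraph<float>::compute()).
using Transform = std::function<double(double)>;
using Model = std::function<std::vector<float>(const std::vector<float>&)>;

std::vector<float> calculateSketch(const std::vector<Transform>& transforms,
                                   const Model& model,
                                   const std::vector<double>& values) {
  // Reject inputs whose length differs from the configured variable count.
  if (values.size() != transforms.size()) {
    throw std::runtime_error("Passed vector of variables has wrong size");
  }
  // Transform position by position: transforms[i] must belong to values[i].
  std::vector<float> input(values.size());
  for (std::size_t i = 0; i < values.size(); ++i) {
    input[i] = static_cast<float>(transforms[i](values[i]));  // float network input
  }
  return model(input);
}
```

The key invariant is positional: transforms[i] must describe the same variable as values[i], just as m_quantiles[i] must match the i-th entry of the InputOrder.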

◆ msg() [1/2]

MsgStream & asg::AsgMessagingForward::msg ( ) const
inherited

The standard message stream.

Returns
A reference to the default message stream of this object.

Definition at line 24 of file AsgMessagingForward.cxx.

{
  return m_msg();
}

◆ msg() [2/2]

MsgStream & asg::AsgMessagingForward::msg ( const MSG::Level lvl) const
inherited

The standard message stream.

Parameters
lvl The message level to set the stream to
Returns
A reference to the default message stream, set to level "lvl"

Definition at line 29 of file AsgMessagingForward.cxx.

{
  MsgStream& msg = m_msg ();
  msg << lvl;
  return msg;
}

◆ msgLvl()

bool asg::AsgMessagingForward::msgLvl ( const MSG::Level lvl) const
inherited

Test the output level of the object.

Parameters
lvl The message level to test against
Returns
Boolean indicating whether messages at the given level will be printed
true if messages at level "lvl" will be printed

Definition at line 11 of file AsgMessagingForward.cxx.

{
  MsgStream& msg = m_msg();
  if (msg.level() <= lvl)
  {
    msg << lvl;
    return true;
  } else
  {
    return false;
  }
}
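
msgLvl() both answers whether a message at lvl would be printed and, if so, switches the stream to that level. The sketch below reproduces that logic with a hypothetical StreamState struct in place of MsgStream; the Level enum only mirrors the ordering of Gaudi's MSG::Level values and is not the Gaudi header.

```cpp
#include <cassert>

// Local stand-in mirroring the ordering of Gaudi's MSG::Level.
enum class Level { Nil = 0, Verbose, Debug, Info, Warning, Error, Fatal, Always };

// Hypothetical stand-in for MsgStream: the output threshold plus the level
// that subsequent messages will be tagged with.
struct StreamState {
  Level threshold = Level::Info;
  Level current = Level::Info;
};

// Same decision as msgLvl(): if `lvl` passes the threshold, switch the
// stream to `lvl` and return true; otherwise leave the stream untouched.
bool msgLvlSketch(StreamState& s, Level lvl) {
  if (s.threshold <= lvl) {
    s.current = lvl;
    return true;
  }
  return false;
}
```

The usual reason for such a check is to skip assembling expensive debug output when it would be discarded anyway.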

◆ readQuantileTransformer()

int ElectronDNNCalculator::readQuantileTransformer ( TTree * tree)
private

Read the bins and values of the QuantileTransformer used to transform the input variables.

Definition at line 142 of file ElectronDNNCalculator.cxx.

{
  int sc(1);
  // The reference bins to which the variables will be transformed
  double references = 0.;
  // SetBranchAddress returns -5 (TTree::kMissingBranch) if the branch does not exist;
  // any missing branch marks the whole read as failed
  if (tree->SetBranchAddress("references", &references) == -5) sc = 0;

  std::map<std::string, double> readVars;
  for ( const auto& var : m_variables ){
    if (tree->SetBranchAddress(var.c_str(), &readVars[var]) == -5) sc = 0;
  }

  for (Long64_t ientry = 0; ientry < tree->GetEntries(); ientry++){
    tree->GetEntry(ientry);
    m_references.push_back(references);
    for (uint ivar = 0; ivar < m_var_size; ++ivar){
      m_quantiles.at(ivar).push_back(readVars[m_variables.at(ivar)]);
    }
  }
  return sc;
}
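
readQuantileTransformer() turns one TTree entry per quantile into a reference vector plus one quantile vector per model variable, and returns 1 only if every branch was found. The status has to accumulate across branches: if each lookup simply overwrote it, a missing branch early in the list could be masked by a later successful one. The sketch below shows the same bookkeeping without ROOT, assuming a hypothetical Columns map (branch name to per-entry values) in place of the TTree; readQuantilesSketch is likewise a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical column store standing in for the TTree:
// branch name -> one value per entry.
using Columns = std::map<std::string, std::vector<double>>;

// Returns 1 if "references" and every requested variable exist, 0 otherwise.
// A failure is never reset by a later success.
int readQuantilesSketch(const Columns& tree,
                        const std::vector<std::string>& variables,
                        std::vector<double>& references,
                        std::vector<std::vector<double>>& quantiles) {
  int sc = 1;
  references.clear();
  quantiles.assign(variables.size(), std::vector<double>{});

  const auto ref = tree.find("references");
  if (ref == tree.end()) {
    sc = 0;
  } else {
    references = ref->second;
  }

  for (std::size_t i = 0; i < variables.size(); ++i) {
    const auto it = tree.find(variables[i]);
    if (it == tree.end()) {
      sc = 0;  // remember the failure, keep reading the other variables
      continue;
    }
    quantiles[i] = it->second;  // quantiles[i] follows the order of `variables`
  }
  return sc;
}
```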

◆ transformInput()

double ElectronDNNCalculator::transformInput ( const std::vector< double > & quantiles,
double value ) const
private

Transform an input variable according to a given QuantileTransformer.

Definition at line 110 of file ElectronDNNCalculator.cxx.

{
  int size = quantiles.size();

  // If the value is outside the range of the quantiles, return the min (0) or max (1) of the references
  if (value <= quantiles[0]) return m_references[0];
  if (value >= quantiles[size-1]) return m_references[size-1];

  // Going up: lower_bound gives the first quantile >= value, so value lies in [lowBin, lowBin+1]
  auto lowBound = std::lower_bound(quantiles.begin(), quantiles.end(), value);
  int lowBin = lowBound - quantiles.begin() - 1;

  // x and y values to the left and right of the value, going up
  double xLup = quantiles[lowBin], yLup = m_references[lowBin], xRup = quantiles[lowBin+1], yRup = m_references[lowBin+1];

  // Going down: upper_bound gives the first quantile > value, so value lies in [upperBin-1, upperBin]
  auto upperBound = std::upper_bound(quantiles.begin(), quantiles.end(), value);
  int upperBin = upperBound - quantiles.begin();

  // x and y values to the right and left of the value, going down
  double xRdown = quantiles[upperBin], yRdown = m_references[upperBin], xLdown = quantiles[upperBin-1], yLdown = m_references[upperBin-1];

  // Calculate the gradients
  double dydxup = ( yRup - yLup ) / ( xRup - xLup );
  double dydxdown = ( yRdown - yLdown ) / ( xRdown - xLdown );

  // Average the linear interpolations of the up and down cases
  return 0.5 * ((yLup + dydxup * (value - xLup)) + (yLdown + dydxdown * (value - xLdown)));
}
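
transformInput() interpolates linearly between neighbouring quantiles twice, once from std::lower_bound (first quantile >= value) and once from std::upper_bound (first quantile > value), and averages the two results. This appears to follow the averaging that scikit-learn's QuantileTransformer applies in its transform step, and it gives a well-defined answer when a value lands on a run of repeated quantiles. The standalone version below assumes the references are passed in explicitly rather than read from m_references; quantileTransform is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

double quantileTransform(const std::vector<double>& quantiles,
                         const std::vector<double>& references,
                         double value) {
  assert(!quantiles.empty() && quantiles.size() == references.size());

  // Clip: values outside the quantile range map to the first/last reference.
  if (value <= quantiles.front()) return references.front();
  if (value >= quantiles.back()) return references.back();

  // Linear interpolation on the interval [i-1, i] of the quantile grid.
  // After clipping, quantiles[i-1] < quantiles[i] for both indices used below.
  auto interpolate = [&](std::size_t i) {
    const double slope = (references[i] - references[i - 1]) /
                         (quantiles[i] - quantiles[i - 1]);
    return references[i - 1] + slope * (value - quantiles[i - 1]);
  };

  // Upward pass: first quantile >= value.
  const auto up = static_cast<std::size_t>(
      std::lower_bound(quantiles.begin(), quantiles.end(), value) - quantiles.begin());
  // Downward pass: first quantile > value.
  const auto down = static_cast<std::size_t>(
      std::upper_bound(quantiles.begin(), quantiles.end(), value) - quantiles.begin());

  // For distinct quantiles both passes agree; on repeated quantiles the
  // average lands between the references of the repeated run.
  return 0.5 * (interpolate(up) + interpolate(down));
}
```

For example, with quantiles {0, 1, 1, 2} and references {0, 0.25, 0.75, 1}, a value of 1 interpolates to 0.25 on the way up and 0.75 on the way down, so the function returns 0.5.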

Member Data Documentation

◆ m_graph

std::unique_ptr<lwt::generic::FastGraph<float> > ElectronDNNCalculator::m_graph = 0
private

DNN interface via lwtnn.

Definition at line 46 of file ElectronDNNCalculator.h.

◆ m_msg

std::function<MsgStream& ()> asg::AsgMessagingForward::m_msg
private, inherited

the message stream we use

This used to be a simple pointer to the MsgStream itself, but in AthenaMT the actual object used is local to the thread. So instead of pointing to it directly we are now using a function to look it up, which will get the thread-local object.

Definition at line 77 of file AsgMessagingForward.h.
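
The idea behind m_msg is to store how to find the stream rather than a pointer to one particular stream object, so every call resolves to the object belonging to the calling thread. The sketch below illustrates that pattern with std::ostringstream standing in for MsgStream; threadLocalStream and Forwarder are hypothetical names, and the real lookup in AthenaMT is not shown here.

```cpp
#include <cassert>
#include <functional>
#include <sstream>
#include <string>

// Hypothetical lookup: each thread gets its own stream instance.
std::ostringstream& threadLocalStream() {
  thread_local std::ostringstream stream;
  return stream;
}

// Hypothetical forwarder: holds a callable, not a stream pointer, so msg()
// always returns the stream of whichever thread is calling.
struct Forwarder {
  std::function<std::ostringstream&()> m_msg = threadLocalStream;
  std::ostringstream& msg() const { return m_msg(); }
};
```

Two forwarders used on the same thread share one stream, while the same forwarder used from another thread would write to that thread's own stream.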

◆ m_multiClass

bool ElectronDNNCalculator::m_multiClass
private

Whether the used model is a multiclass model or not.

Definition at line 53 of file ElectronDNNCalculator.h.

◆ m_quantiles

std::vector<std::vector<double> > ElectronDNNCalculator::m_quantiles
private

Quantile values for each variable that needs to be transformed with the QuantileTransformer.

Definition at line 49 of file ElectronDNNCalculator.h.

◆ m_references

std::vector<double> ElectronDNNCalculator::m_references
private

Reference values for the QuantileTransformer: equidistant bins between 0 and 1.

Definition at line 51 of file ElectronDNNCalculator.h.
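
If the QuantileTransformer file was exported from scikit-learn (an assumption; this page only shows that it is a TTree with a "references" branch), its references are numpy.linspace(0, 1, n_quantiles). The hypothetical helper below builds the same equidistant grid, which can be handy for checking a file's "references" branch; the class itself only reads these values and never computes them.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: n equidistant values from 0 to 1 inclusive,
// analogous to numpy.linspace(0, 1, n).
std::vector<double> equidistantReferences(std::size_t n) {
  assert(n >= 2);  // need at least both endpoints
  std::vector<double> refs(n);
  for (std::size_t i = 0; i < n; ++i) {
    refs[i] = static_cast<double>(i) / static_cast<double>(n - 1);
  }
  return refs;
}
```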

◆ m_var_size

uint ElectronDNNCalculator::m_var_size
private

Number of model variables (the size of m_variables).

Definition at line 56 of file ElectronDNNCalculator.h.

◆ m_variables

std::vector<std::string> ElectronDNNCalculator::m_variables
private

Model variables.

Definition at line 55 of file ElectronDNNCalculator.h.


The documentation for this class was generated from the following files: