ATLAS Offline Software
ElectronDNNCalculator Class Reference

Used by AsgElectronSelectorTool to calculate the score of a Python-trained DNN for electron identification, using lwtnn as the inference interface. It also transforms the input variables with a QuantileTransformer before they are passed to the network.

#include <ElectronDNNCalculator.h>


Public Member Functions

 ElectronDNNCalculator (AsgElectronSelectorTool *owner, const std::string &modelFileName, const std::string &quantileFileName, const std::vector< std::string > &variablesName, const bool multiClass)
 Constructor of the class.
 
 ~ElectronDNNCalculator ()
 Standard destructor.
 
Eigen::Matrix< float, -1, 1 > calculate (const std::vector< double > &) const
 Get the prediction of the DNN model.
 
bool msgLvl (const MSG::Level lvl) const
 Test the output level of the object.
 
MsgStream & msg () const
 The standard message stream.
 
MsgStream & msg (const MSG::Level lvl) const
 The standard message stream.
 

Private Member Functions

double transformInput (const std::vector< double > &quantiles, double value) const
 Transform the input variables according to a given QuantileTransformer.
 
int readQuantileTransformer (TTree *tree)
 Read the bins and values of the QuantileTransformer used to transform the input variables.
 

Private Attributes

std::unique_ptr< lwt::generic::FastGraph< float > > m_graph = 0
 DNN interface via lwtnn.
 
std::vector< std::vector< double > > m_quantiles
 Quantile values for each variable that needs to be transformed with the QuantileTransformer.
 
std::vector< double > m_references
 Reference values for the QuantileTransformer: equidistant values between 0 and 1.
 
bool m_multiClass
 Whether the used model is a multiclass model or not.
 
std::vector< std::string > m_variables
 Model variables.
 
uint m_var_size
 Number of model variables.
 
std::function< MsgStream &()> m_msg
 The message stream we use.
 

Detailed Description

Used by AsgElectronSelectorTool to calculate the score of a Python-trained DNN for electron identification, using lwtnn as the inference interface. It also transforms the input variables with a QuantileTransformer before they are passed to the network.

Author
Lukas Ehrke

Definition at line 23 of file ElectronDNNCalculator.h.

Constructor & Destructor Documentation

◆ ElectronDNNCalculator()

ElectronDNNCalculator::ElectronDNNCalculator ( AsgElectronSelectorTool *  owner,
const std::string &  modelFileName,
const std::string &  quantileFileName,
const std::vector< std::string > &  variablesName,
const bool  multiClass 
)

Constructor of the class.

Definition at line 26 of file ElectronDNNCalculator.cxx.

ElectronDNNCalculator::ElectronDNNCalculator(AsgElectronSelectorTool* owner,
                                             const std::string& modelFileName,
                                             const std::string& quantileFileName,
                                             const std::vector<std::string>& variables,
                                             const bool multiClass) :
  asg::AsgMessagingForward(owner),
  m_quantiles(variables.size()),
  m_multiClass(multiClass),
  m_variables(variables),
  m_var_size(variables.size())
{
  ATH_MSG_INFO("Initializing ElectronDNNCalculator...");

  if (modelFileName.empty()){
    throw std::runtime_error("No file found at '" + modelFileName + "'");
  }

  // Make an input file object
  std::ifstream inputFile;

  // Open your trained model
  ATH_MSG_INFO("Loading model: " << modelFileName.c_str());

  // Create input order for the NN, the data needs to be passed in this exact order
  lwt::InputOrder order;
  order.scalar.emplace_back("node_0", m_variables );

  // create the model
  inputFile.open(modelFileName);
  auto parsedGraph = lwt::parse_json_graph(inputFile);
  // Test whether the number of outputs of the given network corresponds to the expected number
  size_t nOutputs = parsedGraph.outputs.begin()->second.labels.size();
  if (nOutputs != 6 && nOutputs != 1){
    throw std::runtime_error("Given model does not have 1 or 6 outputs. Something seems to be wrong with the model file.");
  }
  else if (nOutputs == 1 && m_multiClass){
    throw std::runtime_error("Given model has 1 output but config file specifies mutliclass. Something is wrong");
  }
  else if (nOutputs == 6 && !m_multiClass){
    throw std::runtime_error("Given model has 6 output but config file does not specify mutliclass. Something is wrong");
  }

  m_graph = std::make_unique<lwt::generic::FastGraph<float>>(parsedGraph, order);

  if (quantileFileName.empty()){
    throw std::runtime_error("No file found at '" + quantileFileName + "'");
  }

  // Open quantiletransformer file
  ATH_MSG_INFO("Loading QuantileTransformer " << quantileFileName);
  std::unique_ptr<TFile> qtfile(TFile::Open(quantileFileName.data()));
  if (readQuantileTransformer((TTree*)qtfile->Get("tree")) == 0){
    throw std::runtime_error("Could not load all variables for the QuantileTransformer");
  }
}
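The constructor accepts only two model shapes: one output with multiClass off, or six outputs with multiClass on. The following is a minimal standalone sketch of that consistency rule. The helper name checkOutputCount and its error messages are hypothetical and are not part of ElectronDNNCalculator.

```cpp
#include <cstddef>
#include <stdexcept>

// Hypothetical helper (not part of ElectronDNNCalculator): mirrors the
// constructor's rule that a model has either 1 output (binary) or
// 6 outputs (multiclass), and that the multiClass flag must agree.
inline void checkOutputCount(std::size_t nOutputs, bool multiClass)
{
    if (nOutputs != 1 && nOutputs != 6) {
        throw std::runtime_error("model must have 1 or 6 outputs");
    }
    // Six outputs require multiClass == true, one output requires false.
    if ((nOutputs == 6) != multiClass) {
        throw std::runtime_error("output count does not match multiClass flag");
    }
}
```

Calling checkOutputCount(1, false) or checkOutputCount(6, true) returns normally; any other combination throws std::runtime_error.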

◆ ~ElectronDNNCalculator()

ElectronDNNCalculator::~ElectronDNNCalculator ( )
inline

Standard destructor.

Definition at line 34 of file ElectronDNNCalculator.h.

~ElectronDNNCalculator() {};

Member Function Documentation

◆ calculate()

Eigen::Matrix< float, -1, 1 > ElectronDNNCalculator::calculate ( const std::vector< double > &  variableValues) const

Get the prediction of the DNN model.

Definition at line 85 of file ElectronDNNCalculator.cxx.

Eigen::Matrix<float, -1, 1> ElectronDNNCalculator::calculate(const std::vector<double>& variableValues) const
{
  if(variableValues.size() != m_var_size)
    throw std::runtime_error("Passed vector of variables has wrong size");

  // Create the input for the model
  Eigen::VectorXf inputVector(m_variables.size());

  // This has to be in the same order as the InputOrder was defined
  for(uint i = 0; i < m_var_size; ++i){
    inputVector(i) = transformInput(m_quantiles.at(i), variableValues.at(i));
  }

  std::vector<Eigen::VectorXf> inp;
  inp.emplace_back(std::move(inputVector));

  auto output = m_graph->compute(inp);
  return output;
}
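The input-preparation half of calculate() can be illustrated without Eigen or lwtnn. The sketch below is an assumption-laden stand-in: prepareInputs is a hypothetical function, the transform is passed in as a callable rather than being a member function, and a std::vector<float> replaces Eigen::VectorXf. It shows only the size check and the index-aligned pairing of each value with its quantile table.

```cpp
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the first half of calculate(): validate the
// input size, then transform value i with quantile table i so that the
// result keeps the order the network expects.
inline std::vector<float> prepareInputs(
    const std::vector<double>& values,
    const std::vector<std::vector<double>>& quantiles,
    const std::function<double(const std::vector<double>&, double)>& transform)
{
    if (values.size() != quantiles.size()) {
        throw std::runtime_error("Passed vector of variables has wrong size");
    }
    std::vector<float> inputs;
    inputs.reserve(values.size());
    for (std::size_t i = 0; i < values.size(); ++i) {
        // The network in this class works in float, so narrow here.
        inputs.push_back(static_cast<float>(transform(quantiles[i], values[i])));
    }
    return inputs;
}
```

Because the pairing is purely positional, the caller has to supply the values in the same order as the variable names given to the constructor.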

◆ msg() [1/2]

MsgStream & asg::AsgMessagingForward::msg ( ) const
inherited

The standard message stream.

Returns
A reference to the default message stream of this object.

Definition at line 24 of file AsgMessagingForward.cxx.

MsgStream& asg::AsgMessagingForward::msg() const
{
  return m_msg();
}

◆ msg() [2/2]

MsgStream & asg::AsgMessagingForward::msg ( const MSG::Level  lvl) const
inherited

The standard message stream.

Parameters
lvl  The message level to set the stream to
Returns
A reference to the default message stream, set to level "lvl"

Definition at line 29 of file AsgMessagingForward.cxx.

MsgStream& asg::AsgMessagingForward::msg(const MSG::Level lvl) const
{
  MsgStream& msg = m_msg ();
  msg << lvl;
  return msg;
}

◆ msgLvl()

bool asg::AsgMessagingForward::msgLvl ( const MSG::Level  lvl) const
inherited

Test the output level of the object.

Parameters
lvl  The message level to test against
Returns
true if messages at level "lvl" will be printed, false otherwise

Definition at line 11 of file AsgMessagingForward.cxx.

bool asg::AsgMessagingForward::msgLvl(const MSG::Level lvl) const
{
  MsgStream& msg = m_msg();
  if (msg.level() <= lvl)
  {
    msg << lvl;
    return true;
  } else
  {
    return false;
  }
}

◆ readQuantileTransformer()

int ElectronDNNCalculator::readQuantileTransformer ( TTree *  tree)
private

Read the bins and values of the QuantileTransformer used to transform the input variables.

Definition at line 142 of file ElectronDNNCalculator.cxx.

int ElectronDNNCalculator::readQuantileTransformer(TTree* tree)
{
  int sc(1);
  // the reference bins to which the variables will be transformed
  double references;
  sc = tree->SetBranchAddress("references", &references) == -5 ? 0 : 1;

  std::map<std::string, double> readVars;
  for ( const auto& var : m_variables ){
    sc = tree->SetBranchAddress(TString(var), &readVars[var]) == -5 ? 0 : 1;
  }

  for (int ientry = 0; ientry < tree->GetEntries(); ientry++){
    tree->GetEntry(ientry);
    m_references.push_back(references);
    for(uint ivar = 0; ivar < m_var_size; ++ivar){
      m_quantiles.at(ivar).push_back(readVars[m_variables.at(ivar)]);
    }
  }
  return sc;
}
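In the listing, sc is reassigned after every SetBranchAddress call, so the returned status reflects only the last variable that was bound. The value -5 corresponds to ROOT's kMissingBranch status. As a sketch of an alternative, assuming one wants any missing branch to be reported, the status could be accumulated instead. The helper bindAll and the bind callable are hypothetical; bind stands in for a call such as tree->SetBranchAddress(name, ptr).

```cpp
#include <functional>
#include <string>
#include <vector>

// Status code that the listing compares against (ROOT's kMissingBranch).
constexpr int kMissingBranchStatus = -5;

// Hypothetical helper: calls `bind` once per branch name and returns
// 1 only if no call reported a missing branch, 0 otherwise.
inline int bindAll(const std::vector<std::string>& names,
                   const std::function<int(const std::string&)>& bind)
{
    int sc = 1;
    for (const auto& name : names) {
        if (bind(name) == kMissingBranchStatus) {
            sc = 0;  // remember the failure; later successes do not reset it
        }
    }
    return sc;
}
```

With this shape, a missing first branch still yields 0 even when every later branch binds successfully.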

◆ transformInput()

double ElectronDNNCalculator::transformInput ( const std::vector< double > &  quantiles,
double  value 
) const
private

Transform the input variables according to a given QuantileTransformer.

Definition at line 110 of file ElectronDNNCalculator.cxx.

double ElectronDNNCalculator::transformInput(const std::vector<double>& quantiles, double value) const
{
  int size = quantiles.size();

  // if given value is outside of range of the given quantiles return min (0) or max (1) of references
  if (value <= quantiles[0]) return m_references[0];
  if (value >= quantiles[size-1]) return m_references[size-1];

  // find the bin where the value is smaller than the next bin (going from low bins to large bins)
  auto lowBound = std::lower_bound(quantiles.begin(), quantiles.end(), value);
  int lowBin = lowBound - quantiles.begin() - 1;

  // get x and y values on left and right side from value while going up
  double xLup = quantiles[lowBin], yLup = m_references[lowBin], xRup = quantiles[lowBin+1], yRup = m_references[lowBin+1];

  // find the bin where the value is larger than the next bin (going from large bins to low bins)
  auto upperBound = std::upper_bound(quantiles.begin(), quantiles.end(), value);
  int upperBin = upperBound - quantiles.begin();

  // get x and y values on left and right side from value while going down
  double xRdown = quantiles[upperBin], yRdown = m_references[upperBin], xLdown = quantiles[upperBin-1], yLdown = m_references[upperBin-1];

  // calculate the gradients
  double dydxup = ( yRup - yLup ) / ( xRup - xLup );
  double dydxdown = ( yRdown - yLdown ) / ( xRdown - xLdown );

  // average linear interpolation of up and down case
  return 0.5 * ((yLup + dydxup * (value - xLup)) + (yLdown + dydxdown * (value - xLdown)));
}
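The same interpolation can be tried outside the class. The sketch below is a free-function restatement of the listing, with the references passed as an argument instead of being read from m_references; the name quantileTransform is hypothetical. It assumes both vectors are non-empty, have equal length, and that the quantiles are sorted in ascending order.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical free-function version of transformInput().
// Assumes: quantiles sorted ascending, references.size() == quantiles.size().
inline double quantileTransform(const std::vector<double>& quantiles,
                                const std::vector<double>& references,
                                double value)
{
    // Values outside the quantile range are clamped to the end references.
    if (value <= quantiles.front()) return references.front();
    if (value >= quantiles.back()) return references.back();

    // "Up" segment: first quantile >= value and its left neighbour.
    const auto hiUp = static_cast<std::size_t>(
        std::lower_bound(quantiles.begin(), quantiles.end(), value) - quantiles.begin());
    // "Down" segment: first quantile > value and its left neighbour.
    const auto hiDown = static_cast<std::size_t>(
        std::upper_bound(quantiles.begin(), quantiles.end(), value) - quantiles.begin());

    // Linear interpolation on the segment [lo, hi].
    auto interpolate = [&](std::size_t lo, std::size_t hi) {
        const double slope = (references[hi] - references[lo]) /
                             (quantiles[hi] - quantiles[lo]);
        return references[lo] + slope * (value - quantiles[lo]);
    };

    // Average of both cases; they differ only when value sits on a
    // repeated quantile, where lower_bound and upper_bound disagree.
    return 0.5 * (interpolate(hiUp - 1, hiUp) + interpolate(hiDown - 1, hiDown));
}
```

With quantiles {0, 1, 1, 2} and references {0, 1/3, 2/3, 1}, a value of 1 lies on the repeated quantile: the up case gives 1/3, the down case gives 2/3, and the average is 0.5. Neither denominator can be zero, because the up segment always satisfies quantiles[lo] < value and the down segment always satisfies value < quantiles[hi].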

Member Data Documentation

◆ m_graph

std::unique_ptr<lwt::generic::FastGraph<float> > ElectronDNNCalculator::m_graph = 0
private

DNN interface via lwtnn.

Definition at line 46 of file ElectronDNNCalculator.h.

◆ m_msg

std::function<MsgStream& ()> asg::AsgMessagingForward::m_msg
private inherited

The message stream we use.

This used to be a simple pointer to the MsgStream itself, but in AthenaMT the actual object used is local to the thread. So instead of pointing to it directly we are now using a function to look it up, which will get the thread-local object.

Definition at line 77 of file AsgMessagingForward.h.
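The lookup-function design can be sketched in isolation. Everything below is illustrative: FakeStream stands in for MsgStream, threadLocalStream is a hypothetical lookup, and Forwarder is a simplified stand-in for asg::AsgMessagingForward (its msgLvl only compares levels and does not set the stream level as the real one does).

```cpp
#include <functional>
#include <string>
#include <utility>

// Hypothetical minimal stream type standing in for MsgStream.
struct FakeStream {
    int level = 3;       // current output threshold (illustrative scale)
    std::string buffer;  // collected message text
};

// Hypothetical lookup: each thread that calls this gets its own instance.
inline FakeStream& threadLocalStream()
{
    thread_local FakeStream stream;
    return stream;
}

// Simplified forwarder: stores a function that finds the stream rather
// than a pointer to one fixed stream object.
class Forwarder {
public:
    explicit Forwarder(std::function<FakeStream&()> lookup)
        : m_msg(std::move(lookup)) {}

    FakeStream& msg() const { return m_msg(); }

    // True if messages at level lvl would be printed.
    bool msgLvl(int lvl) const { return m_msg().level <= lvl; }

private:
    std::function<FakeStream&()> m_msg;
};
```

Since the lookup runs on every call, a Forwarder shared between threads resolves to the calling thread's stream each time, which a stored pointer could not do.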

◆ m_multiClass

bool ElectronDNNCalculator::m_multiClass
private

Whether the used model is a multiclass model or not.

Definition at line 53 of file ElectronDNNCalculator.h.

◆ m_quantiles

std::vector<std::vector<double> > ElectronDNNCalculator::m_quantiles
private

Quantile values for each variable that needs to be transformed with the QuantileTransformer.

Definition at line 49 of file ElectronDNNCalculator.h.

◆ m_references

std::vector<double> ElectronDNNCalculator::m_references
private

Reference values for the QuantileTransformer: equidistant values between 0 and 1.

Definition at line 51 of file ElectronDNNCalculator.h.
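The class reads these values from the "references" branch of the quantile file and does not compute them. Purely to illustrate what equidistant values between 0 and 1 look like, the following sketch generates such a grid, assuming both endpoints are included. The helper makeReferences is hypothetical.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: n equidistant values from 0 to 1, endpoints included.
// For n == 1 the single value is 0; for n == 0 the result is empty.
inline std::vector<double> makeReferences(std::size_t n)
{
    std::vector<double> refs(n, 0.0);
    if (n < 2) {
        return refs;
    }
    for (std::size_t i = 0; i < n; ++i) {
        refs[i] = static_cast<double>(i) / static_cast<double>(n - 1);
    }
    return refs;
}
```

For n = 5 this yields 0, 0.25, 0.5, 0.75 and 1, one reference per stored quantile.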

◆ m_var_size

uint ElectronDNNCalculator::m_var_size
private

Number of model variables, set from the size of the variable list in the constructor.

Definition at line 56 of file ElectronDNNCalculator.h.

◆ m_variables

std::vector<std::string> ElectronDNNCalculator::m_variables
private

Model variables.

Definition at line 55 of file ElectronDNNCalculator.h.


The documentation for this class was generated from the following files:
ElectronDNNCalculator.h
ElectronDNNCalculator.cxx