ATLAS Offline Software
ElectronDNNCalculator.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
3 */
4 
14 #include "ElectronDNNCalculator.h"
16 #include "TSystem.h"
17 #include "TFile.h"
18 #include <map>
19 #include <fstream>
20 #include <iostream>
21 #include "lwtnn/parse_json.hh"
22 #include "lwtnn/InputOrder.hh"
23 #include <Eigen/Dense>
24 
25 
27  const std::string& modelFileName,
28  const std::string& quantileFileName,
29  const std::vector<std::string>& variables,
30  const bool multiClass) :
31  asg::AsgMessagingForward(owner),
32  m_quantiles(variables.size()),
33  m_multiClass(multiClass),
34  m_variables(variables),
35  m_var_size(variables.size())
36 {
37  ATH_MSG_INFO("Initializing ElectronDNNCalculator...");
38 
39  if (modelFileName.empty()){
40  throw std::runtime_error("No file found at '" + modelFileName + "'");
41  }
42 
43  // Make an input file object
44  std::ifstream inputFile;
45 
46  // Open your trained model
47  ATH_MSG_INFO("Loading model: " << modelFileName.c_str());
48 
49  // Create input order for the NN, the data needs to be passed in this exact order
50  lwt::InputOrder order;
51  order.scalar.emplace_back("node_0", m_variables );
52 
53  // create the model
54  inputFile.open(modelFileName);
55  auto parsedGraph = lwt::parse_json_graph(inputFile);
56  // Test whether the number of outputs of the given network corresponds to the expected number
57  size_t nOutputs = parsedGraph.outputs.begin()->second.labels.size();
58  if (nOutputs != 6 && nOutputs != 1){
59  throw std::runtime_error("Given model does not have 1 or 6 outputs. Something seems to be wrong with the model file.");
60  }
61  else if (nOutputs == 1 && m_multiClass){
62  throw std::runtime_error("Given model has 1 output but config file specifies mutliclass. Something is wrong");
63  }
64  else if (nOutputs == 6 && !m_multiClass){
65  throw std::runtime_error("Given model has 6 output but config file does not specify mutliclass. Something is wrong");
66  }
67 
68  m_graph = std::make_unique<lwt::generic::FastGraph<float>>(parsedGraph, order);
69 
70  if (quantileFileName.empty()){
71  throw std::runtime_error("No file found at '" + quantileFileName + "'");
72  }
73 
74  // Open quantiletransformer file
75  ATH_MSG_INFO("Loading QuantileTransformer " << quantileFileName);
76  std::unique_ptr<TFile> qtfile(TFile::Open(quantileFileName.data()));
77  if (readQuantileTransformer((TTree*)qtfile->Get("tree")) == 0){
78  throw std::runtime_error("Could not load all variables for the QuantileTransformer");
79 
80  }
81 }
82 
83 
84 // takes the input variables, transforms them according to the given QuantileTransformer and predicts the DNN value(s)
85 Eigen::Matrix<float, -1, 1> ElectronDNNCalculator::calculate( const std::vector<double>& variableValues ) const
86 {
87 
88  if(variableValues.size() != m_var_size)
89  throw std::runtime_error("Passed vector of variables has wrong size");
90 
91  // Create the input for the model
92  Eigen::VectorXf inputVector(m_variables.size());
93 
94  // This has to be in the same order as the InputOrder was defined
95  for(uint i = 0; i < m_var_size; ++i){
96  inputVector(i) = transformInput(m_quantiles.at(i), variableValues.at(i));
97  }
98 
99  std::vector<Eigen::VectorXf> inp;
100  inp.emplace_back(std::move(inputVector));
101 
102  auto output = m_graph->compute(inp);
103  return output;
104 }
105 
106 
107 // transform the input based on a QuantileTransformer. quantiles are bins in the variable, while references are bins from 0 to 1
108 // The interpolation is done averaging the interpolation going from small to large bins and going from large to small bins
109 // to deal with non strictly monotonic rising bins.
110 double ElectronDNNCalculator::transformInput( const std::vector<double>& quantiles, double value ) const
111 {
112  int size = quantiles.size();
113 
114  // if given value is outside of range of the given quantiles return min (0) or max (1) of references
115  if (value <= quantiles[0]) return m_references[0];
116  if (value >= quantiles[size-1]) return m_references[size-1];
117 
118  // find the bin where the value is smaller than the next bin (going from low bins to large bins)
119  auto lowBound = std::lower_bound(quantiles.begin(), quantiles.end(), value);
120  int lowBin = lowBound - quantiles.begin() - 1;
121 
122  // get x and y values on left and right side from value while going up
123  double xLup = quantiles[lowBin], yLup = m_references[lowBin], xRup = quantiles[lowBin+1], yRup = m_references[lowBin+1];
124 
125  // find the bin where the value is larger than the next bin (going from large bins to low bins)
126  auto upperBound = std::upper_bound(quantiles.begin(), quantiles.end(), value);
127  int upperBin = upperBound - quantiles.begin();
128 
129  // get x and y values on left and right side from value while going down
130  double xRdown = quantiles[upperBin], yRdown = m_references[upperBin], xLdown = quantiles[upperBin-1], yLdown = m_references[upperBin-1];
131 
132  // calculate the gradients
133  double dydxup = ( yRup - yLup ) / ( xRup - xLup );
134  double dydxdown = ( yRdown - yLdown ) / ( xRdown - xLdown );
135 
136  // average linear interpolation of up and down case
137  return 0.5 * ((yLup + dydxup * (value - xLup)) + (yLdown + dydxdown * (value - xLdown)));
138 }
139 
140 
141 // Read the information needed for the QuantileTransformer from a ROOT TTree
143 {
144  int sc(1);
145  // the reference bins to which the variables will be transformed to
146  double references;
147  sc = tree->SetBranchAddress("references", &references) == -5 ? 0 : 1;
148 
149  std::map<std::string, double> readVars;
150  for ( const auto& var : m_variables ){
151  sc = tree->SetBranchAddress(TString(var), &readVars[var]) == -5 ? 0 : 1;
152  }
153 
154  for (int ientry = 0; ientry < tree->GetEntries(); ientry++){
155  tree->GetEntry(ientry);
156  m_references.push_back(references);
157  for(uint ivar = 0; ivar < m_var_size; ++ivar){
158  m_quantiles.at(ivar).push_back(readVars[m_variables.at(ivar)]);
159  }
160  }
161  return sc;
162 }
AsgElectronSelectorTool
Electron selector tool to select signal electrons using the ElectronDNNCalculator retrieve a score ba...
Definition: AsgElectronSelectorTool.h:27
beamspotnt.var
var
Definition: bin/beamspotnt.py:1394
ElectronDNNCalculator::m_graph
std::unique_ptr< lwt::generic::FastGraph< float > > m_graph
DNN interface via lwtnn.
Definition: ElectronDNNCalculator.h:46
ChangeHistoRange.lowBound
lowBound
Definition: ChangeHistoRange.py:29
ElectronDNNCalculator::ElectronDNNCalculator
ElectronDNNCalculator(AsgElectronSelectorTool *owner, const std::string &modelFileName, const std::string &quantileFileName, const std::vector< std::string > &variablesName, const bool multiClass)
Constructor of the class.
Definition: ElectronDNNCalculator.cxx:26
ElectronDNNCalculator::transformInput
double transformInput(const std::vector< double > &quantiles, double value) const
transform the input variables according to a given QuantileTransformer.
Definition: ElectronDNNCalculator.cxx:110
checkCoolLatestUpdate.variables
variables
Definition: checkCoolLatestUpdate.py:13
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
ElectronDNNCalculator::m_references
std::vector< double > m_references
Reference values for the QuantileTransformer. Basically just equidistant bins between 0 and 1.
Definition: ElectronDNNCalculator.h:51
ElectronDNNCalculator::m_multiClass
bool m_multiClass
Whether the used model is a multiclass model or not.
Definition: ElectronDNNCalculator.h:53
ElectronDNNCalculator::m_variables
std::vector< std::string > m_variables
Model variables.
Definition: ElectronDNNCalculator.h:55
tree
TChain * tree
Definition: tile_monitor.h:30
asg
Definition: DataHandleTestTool.h:28
athena.value
value
Definition: athena.py:124
references
std::vector< reference > references
Definition: hcg.cxx:522
AthenaPoolTestRead.sc
sc
Definition: AthenaPoolTestRead.py:27
python.setupRTTAlg.size
int size
Definition: setupRTTAlg.py:39
uint
unsigned int uint
Definition: LArOFPhaseFill.cxx:20
ElectronDNNCalculator::calculate
Eigen::Matrix< float, -1, 1 > calculate(const std::vector< double > &) const
Get the prediction of the DNN model.
Definition: ElectronDNNCalculator.cxx:85
CaloCondBlobAlgs_fillNoiseFromASCII.inputFile
string inputFile
Definition: CaloCondBlobAlgs_fillNoiseFromASCII.py:17
lumiFormat.i
int i
Definition: lumiFormat.py:85
ElectronDNNCalculator::readQuantileTransformer
int readQuantileTransformer(TTree *tree)
read the bins and values of the QuantileTransformer to transform the input variables.
Definition: ElectronDNNCalculator.cxx:142
RTTAlgmain.Matrix
list Matrix
Definition: RTTAlgmain.py:19
mc.order
order
Configure Herwig7.
Definition: mc.Herwig7_Dijet.py:12
plotBeamSpotCompare.ivar
int ivar
Definition: plotBeamSpotCompare.py:383
ElectronDNNCalculator.h
ElectronDNNCalculator::m_quantiles
std::vector< std::vector< double > > m_quantiles
Quantile values for each variable that needs to be transformed with the QuantileTransformer.
Definition: ElectronDNNCalculator.h:49
merge.output
output
Definition: merge.py:17
PathResolver.h
lwtDev::parse_json_graph
GraphConfig parse_json_graph(std::istream &json)
Definition: parse_json.cxx:71
ElectronDNNCalculator::m_var_size
uint m_var_size
Definition: ElectronDNNCalculator.h:56
readCCLHist.float
float
Definition: readCCLHist.py:83