ATLAS Offline Software
ElectronDNNCalculator.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
3 */
4 
#include "ElectronDNNCalculator.h"
#include "TSystem.h"
#include "TFile.h"
#include "TTree.h"
#include "lwtnn/parse_json.hh"
#include "lwtnn/InputOrder.hh"
#include <Eigen/Dense>
#include <algorithm>
#include <cmath>
#include <fstream>
#include <iostream>
#include <limits>
#include <map>
#include <stdexcept>
24 
25 
27  const std::string& modelFileName,
28  const std::string& quantileFileName,
29  const std::vector<std::string>& variables,
30  const bool multiClass,
31  const bool newVars) :
32  asg::AsgMessagingForward(owner),
33  m_multiClass(multiClass),
34  m_newVars(newVars)
35 {
36  ATH_MSG_INFO("Initializing ElectronDNNCalculator...");
37 
38  if (modelFileName.empty()){
39  throw std::runtime_error("No file found at '" + modelFileName + "'");
40  }
41 
42  // Make an input file object
43  std::ifstream inputFile;
44 
45  // Open your trained model
46  ATH_MSG_INFO("Loading model: " << modelFileName.c_str());
47 
48  // Create input order for the NN, the data needs to be passed in this exact order
49  lwt::InputOrder order;
50  // TODO: for latest DNN `inputVariables` has the same content
51  // as `variables` (including order), check if valid for
52  // old dnn and if yes use `variables` directly.
53  std::vector<std::string> inputVariables;
54  if(m_newVars){
55  inputVariables = {"d0significance", "dPOverP",
56  "deltaEta1", "deltaPhiRescaled2", "trans_TRTPID",
57  "nPixHitsPlusDeadSensors", "nSCTHitsPlusDeadSensors",
58  "EoverP", "eta", "et", "Rhad1", "Rhad", "f3", "f1",
59  "weta2", "Rphi", "Reta", "Eratio", "wtots1", "SCTWeightedCharge","qd0"};
60  }
61  else {
62  inputVariables = {"d0", "d0significance", "dPOverP",
63  "deltaEta1", "deltaPhiRescaled2", "trans_TRTPID",
64  "nPixHitsPlusDeadSensors", "nSCTHitsPlusDeadSensors",
65  "EoverP", "eta", "et", "Rhad1", "Rhad", "f3", "f1",
66  "weta2", "Rphi", "Reta", "Eratio", "wtots1"};
67  }
68  order.scalar.emplace_back("node_0", inputVariables );
69 
70  // create the model
71  inputFile.open(modelFileName);
72  auto parsedGraph = lwt::parse_json_graph(inputFile);
73  // Test whether the number of outputs of the given network corresponds to the expected number
74  size_t nOutputs = parsedGraph.outputs.begin()->second.labels.size();
75  if (nOutputs != 6 && nOutputs != 1){
76  throw std::runtime_error("Given model does not have 1 or 6 outputs. Something seems to be wrong with the model file.");
77  }
78  else if (nOutputs == 1 && m_multiClass){
79  throw std::runtime_error("Given model has 1 output but config file specifies mutliclass. Something is wrong");
80  }
81  else if (nOutputs == 6 && !m_multiClass){
82  throw std::runtime_error("Given model has 6 output but config file does not specify mutliclass. Something is wrong");
83  }
84 
85  m_graph = std::make_unique<lwt::generic::FastGraph<float>>(parsedGraph, order);
86 
87  if (quantileFileName.empty()){
88  throw std::runtime_error("No file found at '" + quantileFileName + "'");
89  }
90 
91  // Open quantiletransformer file
92  ATH_MSG_INFO("Loading QuantileTransformer " << quantileFileName);
93  std::unique_ptr<TFile> qtfile(TFile::Open(quantileFileName.data()));
94  if (readQuantileTransformer((TTree*)qtfile->Get("tree"), variables) == 0){
95  throw std::runtime_error("Could not load all variables for the QuantileTransformer");
96 
97  }
98 }
99 
100 
101 // takes the input variables, transforms them according to the given QuantileTransformer and predicts the DNN value(s)
103 {
104  // Create the input for the model
105  Eigen::VectorXf inputVector(21);
106 
107  // This has to be in the same order as the InputOrder was defined
108 
109  if(m_newVars){
110  inputVector(0) = transformInput( m_quantiles.d0significance, varsStruct.d0significance);
111  inputVector(1) = transformInput( m_quantiles.dPOverP, varsStruct.dPOverP);
112  inputVector(2) = transformInput( m_quantiles.deltaEta1, varsStruct.deltaEta1);
113  inputVector(3) = transformInput( m_quantiles.deltaPhiRescaled2, varsStruct.deltaPhiRescaled2);
114  inputVector(4) = transformInput( m_quantiles.trans_TRTPID, varsStruct.trans_TRTPID);
117  inputVector(7) = transformInput( m_quantiles.EoverP, varsStruct.EoverP);
118  inputVector(8) = transformInput( m_quantiles.eta, varsStruct.eta);
119  inputVector(9) = transformInput( m_quantiles.et, varsStruct.et);
120  inputVector(10) = transformInput( m_quantiles.Rhad1, varsStruct.Rhad1);
121  inputVector(11) = transformInput( m_quantiles.Rhad, varsStruct.Rhad);
122  inputVector(12) = transformInput( m_quantiles.f3, varsStruct.f3);
123  inputVector(13) = transformInput( m_quantiles.f1, varsStruct.f1);
124  inputVector(14) = transformInput( m_quantiles.weta2, varsStruct.weta2);
125  inputVector(15) = transformInput( m_quantiles.Rphi, varsStruct.Rphi);
126  inputVector(16) = transformInput( m_quantiles.Reta, varsStruct.Reta);
127  inputVector(17) = transformInput( m_quantiles.Eratio, varsStruct.Eratio);
128  inputVector(18) = transformInput( m_quantiles.wtots1, varsStruct.wtots1);
129  inputVector(19) = transformInput( m_quantiles.SCTWeightedCharge, varsStruct.SCTWeightedCharge);
130  inputVector(20) = transformInput( m_quantiles.qd0, varsStruct.qd0);
131  }
132  else {
133  inputVector(0) = transformInput( m_quantiles.d0, varsStruct.d0);
134  inputVector(1) = transformInput( m_quantiles.d0significance, varsStruct.d0significance);
135  inputVector(2) = transformInput( m_quantiles.dPOverP, varsStruct.dPOverP);
136  inputVector(3) = transformInput( m_quantiles.deltaEta1, varsStruct.deltaEta1);
137  inputVector(4) = transformInput( m_quantiles.deltaPhiRescaled2, varsStruct.deltaPhiRescaled2);
138  inputVector(5) = transformInput( m_quantiles.trans_TRTPID, varsStruct.trans_TRTPID);
141  inputVector(8) = transformInput( m_quantiles.EoverP, varsStruct.EoverP);
142  inputVector(9) = transformInput( m_quantiles.eta, varsStruct.eta);
143  inputVector(10) = transformInput( m_quantiles.et, varsStruct.et);
144  inputVector(11) = transformInput( m_quantiles.Rhad1, varsStruct.Rhad1);
145  inputVector(12) = transformInput( m_quantiles.Rhad, varsStruct.Rhad);
146  inputVector(13) = transformInput( m_quantiles.f3, varsStruct.f3);
147  inputVector(14) = transformInput( m_quantiles.f1, varsStruct.f1);
148  inputVector(15) = transformInput( m_quantiles.weta2, varsStruct.weta2);
149  inputVector(16) = transformInput( m_quantiles.Rphi, varsStruct.Rphi);
150  inputVector(17) = transformInput( m_quantiles.Reta, varsStruct.Reta);
151  inputVector(18) = transformInput( m_quantiles.Eratio, varsStruct.Eratio);
152  inputVector(19) = transformInput( m_quantiles.wtots1, varsStruct.wtots1);
153  }
154 
155  std::vector<Eigen::VectorXf> inp;
156  inp.emplace_back(std::move(inputVector));
157 
158  auto output = m_graph->compute(inp);
159  return output;
160 }
161 
162 
163 // transform the input based on a QuantileTransformer. quantiles are bins in the variable, while references are bins from 0 to 1
164 // The interpolation is done averaging the interpolation going from small to large bins and going from large to small bins
165 // to deal with non strictly monotonic rising bins.
166 double ElectronDNNCalculator::transformInput( const std::vector<double>& quantiles, double value ) const
167 {
168  int size = quantiles.size();
169 
170  // if given value is outside of range of the given quantiles return min (0) or max (1) of references
171  if (value <= quantiles[0]) return m_references[0];
172  if (value >= quantiles[size-1]) return m_references[size-1];
173 
174  // find the bin where the value is smaller than the next bin (going from low bins to large bins)
175  auto lowBound = std::lower_bound(quantiles.begin(), quantiles.end(), value);
176  int lowBin = lowBound - quantiles.begin() - 1;
177 
178  // get x and y values on left and right side from value while going up
179  double xLup = quantiles[lowBin], yLup = m_references[lowBin], xRup = quantiles[lowBin+1], yRup = m_references[lowBin+1];
180 
181  // find the bin where the value is larger than the next bin (going from large bins to low bins)
182  auto upperBound = std::upper_bound(quantiles.begin(), quantiles.end(), value);
183  int upperBin = upperBound - quantiles.begin();
184 
185  // get x and y values on left and right side from value while going down
186  double xRdown = quantiles[upperBin], yRdown = m_references[upperBin], xLdown = quantiles[upperBin-1], yLdown = m_references[upperBin-1];
187 
188  // calculate the gradients
189  double dydxup = ( yRup - yLup ) / ( xRup - xLup );
190  double dydxdown = ( yRdown - yLdown ) / ( xRdown - xLdown );
191 
192  // average linear interpolation of up and down case
193  return 0.5 * ((yLup + dydxup * (value - xLup)) + (yLdown + dydxdown * (value - xLdown)));
194 }
195 
196 
197 // Read the information needed for the QuantileTransformer from a ROOT TTree
198 int ElectronDNNCalculator::readQuantileTransformer( TTree* tree, const std::vector<std::string>& variables )
199 {
200  int sc(1);
201  // the reference bins to which the variables will be transformed to
202  double references;
203  sc = tree->SetBranchAddress("references", &references) == -5 ? 0 : 1;
204 
205  std::map<std::string, double> readVars;
206  for ( const auto& var : variables ){
207  sc = tree->SetBranchAddress(TString(var), &readVars[var]) == -5 ? 0 : 1;
208  }
209  for (int i = 0; i < tree->GetEntries(); i++){
210  tree->GetEntry(i);
211  m_references.push_back(references);
212  m_quantiles.d0significance.push_back(readVars["d0significance"]);
213  m_quantiles.dPOverP.push_back(readVars["dPOverP"]);
214  m_quantiles.deltaEta1.push_back(readVars["deltaEta1"]);
215  m_quantiles.deltaPhiRescaled2.push_back(readVars["deltaPhiRescaled2"]);
216  m_quantiles.trans_TRTPID.push_back(readVars["trans_TRTPID"]);
217  m_quantiles.nPixHitsPlusDeadSensors.push_back(readVars["nPixHitsPlusDeadSensors"]);
218  m_quantiles.nSCTHitsPlusDeadSensors.push_back(readVars["nSCTHitsPlusDeadSensors"]);
219  m_quantiles.EoverP.push_back(readVars["EoverP"]);
220  m_quantiles.eta.push_back(readVars["eta"]);
221  m_quantiles.et.push_back(readVars["et"]);
222  m_quantiles.Rhad1.push_back(readVars["Rhad1"]);
223  m_quantiles.Rhad.push_back(readVars["Rhad"]);
224  m_quantiles.f3.push_back(readVars["f3"]);
225  m_quantiles.f1.push_back(readVars["f1"]);
226  m_quantiles.weta2.push_back(readVars["weta2"]);
227  m_quantiles.Rphi.push_back(readVars["Rphi"]);
228  m_quantiles.Reta.push_back(readVars["Reta"]);
229  m_quantiles.Eratio.push_back(readVars["Eratio"]);
230  m_quantiles.wtots1.push_back(readVars["wtots1"]);
231  if(m_newVars){
232  m_quantiles.SCTWeightedCharge.push_back(readVars["SCTWeightedCharge"]);
233  m_quantiles.qd0.push_back(readVars["qd0"]);
234  }
235  else {
236  m_quantiles.d0.push_back(readVars["d0"]);
237  }
238  }
239  return sc;
240 }
AsgElectronSelectorTool
Electron selector tool to select signal electrons using the ElectronDNNCalculator retrieve a score ba...
Definition: AsgElectronSelectorTool.h:27
MVAEnum::QTVars::nSCTHitsPlusDeadSensors
std::vector< double > nSCTHitsPlusDeadSensors
Definition: ElectronDNNCalculator.h:70
MVAEnum::QTVars::trans_TRTPID
std::vector< double > trans_TRTPID
Definition: ElectronDNNCalculator.h:66
MVAEnum::QTVars::weta2
std::vector< double > weta2
Definition: ElectronDNNCalculator.h:56
MVAEnum::MVACalcVars::nPixHitsPlusDeadSensors
double nPixHitsPlusDeadSensors
Definition: ElectronDNNCalculator.h:44
beamspotnt.var
var
Definition: bin/beamspotnt.py:1394
ElectronDNNCalculator::m_graph
std::unique_ptr< lwt::generic::FastGraph< float > > m_graph
DNN interface via lwtnn.
Definition: ElectronDNNCalculator.h:101
MVAEnum::QTVars::Rhad1
std::vector< double > Rhad1
Definition: ElectronDNNCalculator.h:54
ChangeHistoRange.lowBound
lowBound
Definition: ChangeHistoRange.py:29
MVAEnum::MVACalcVars::f3
double f3
Definition: ElectronDNNCalculator.h:27
ElectronDNNCalculator::transformInput
double transformInput(const std::vector< double > &quantiles, double value) const
transform the input variables according to a given QuantileTransformer.
Definition: ElectronDNNCalculator.cxx:166
checkCoolLatestUpdate.variables
variables
Definition: checkCoolLatestUpdate.py:13
MVAEnum::QTVars::et
std::vector< double > et
Definition: ElectronDNNCalculator.h:51
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
ElectronDNNCalculator::m_references
std::vector< double > m_references
Reference values for the QuantileTransformer. Basically just equidistant bins between 0 and 1.
Definition: ElectronDNNCalculator.h:106
MVAEnum::MVACalcVars::et
double et
Definition: ElectronDNNCalculator.h:26
MVAEnum::MVACalcVars::Reta
double Reta
Definition: ElectronDNNCalculator.h:30
MVAEnum::MVACalcVars::weta2
double weta2
Definition: ElectronDNNCalculator.h:31
ElectronDNNCalculator::m_multiClass
bool m_multiClass
Whether the used model is a multiclass model or not.
Definition: ElectronDNNCalculator.h:108
MVAEnum::MVACalcVars::d0significance
double d0significance
Definition: ElectronDNNCalculator.h:37
ElectronDNNCalculator::m_newVars
bool m_newVars
Whether the model uses old or new set of variables.
Definition: ElectronDNNCalculator.h:110
tree
TChain * tree
Definition: tile_monitor.h:30
asg
Definition: DataHandleTestTool.h:28
MVAEnum::QTVars::deltaEta1
std::vector< double > deltaEta1
Definition: ElectronDNNCalculator.h:59
MVAEnum::QTVars::d0
std::vector< double > d0
Definition: ElectronDNNCalculator.h:60
MVAEnum::MVACalcVars::d0
double d0
Definition: ElectronDNNCalculator.h:35
MVAEnum::QTVars::qd0
std::vector< double > qd0
Definition: ElectronDNNCalculator.h:61
athena.value
value
Definition: athena.py:122
MVAEnum::MVACalcVars::dPOverP
double dPOverP
Definition: ElectronDNNCalculator.h:39
references
std::vector< reference > references
Definition: hcg.cxx:522
MVAEnum::QTVars::EoverP
std::vector< double > EoverP
Definition: ElectronDNNCalculator.h:68
MVAEnum::QTVars::SCTWeightedCharge
std::vector< double > SCTWeightedCharge
Definition: ElectronDNNCalculator.h:71
MVAEnum::QTVars::Rhad
std::vector< double > Rhad
Definition: ElectronDNNCalculator.h:53
AthenaPoolTestRead.sc
sc
Definition: AthenaPoolTestRead.py:27
MVAEnum::MVACalcVars::Rphi
double Rphi
Definition: ElectronDNNCalculator.h:38
MVAEnum::QTVars::wtots1
std::vector< double > wtots1
Definition: ElectronDNNCalculator.h:67
python.setupRTTAlg.size
int size
Definition: setupRTTAlg.py:39
MVAEnum::QTVars::Reta
std::vector< double > Reta
Definition: ElectronDNNCalculator.h:55
MVAEnum::QTVars::f1
std::vector< double > f1
Definition: ElectronDNNCalculator.h:57
MVAEnum::QTVars::eta
std::vector< double > eta
Definition: ElectronDNNCalculator.h:50
CaloCondBlobAlgs_fillNoiseFromASCII.inputFile
string inputFile
Definition: CaloCondBlobAlgs_fillNoiseFromASCII.py:17
MVAEnum::MVACalcVars
Definition: ElectronDNNCalculator.h:24
lumiFormat.i
int i
Definition: lumiFormat.py:92
MVAEnum::MVACalcVars::eta
double eta
Definition: ElectronDNNCalculator.h:25
MVAEnum::QTVars::dPOverP
std::vector< double > dPOverP
Definition: ElectronDNNCalculator.h:64
RTTAlgmain.Matrix
list Matrix
Definition: RTTAlgmain.py:19
mc.order
order
Configure Herwig7.
Definition: mc.Herwig7_Dijet.py:12
MVAEnum::QTVars::Rphi
std::vector< double > Rphi
Definition: ElectronDNNCalculator.h:63
ElectronDNNCalculator.h
MVAEnum::MVACalcVars::qd0
double qd0
Definition: ElectronDNNCalculator.h:36
ElectronDNNCalculator::ElectronDNNCalculator
ElectronDNNCalculator(AsgElectronSelectorTool *owner, const std::string &modelFileName, const std::string &quantileFileName, const std::vector< std::string > &variablesName, const bool multiClass, const bool newVars)
Constructor of the class.
Definition: ElectronDNNCalculator.cxx:26
MVAEnum::MVACalcVars::Eratio
double Eratio
Definition: ElectronDNNCalculator.h:33
merge.output
output
Definition: merge.py:17
MVAEnum::MVACalcVars::wtots1
double wtots1
Definition: ElectronDNNCalculator.h:42
MVAEnum::MVACalcVars::deltaEta1
double deltaEta1
Definition: ElectronDNNCalculator.h:34
MVAEnum::MVACalcVars::f1
double f1
Definition: ElectronDNNCalculator.h:32
PathResolver.h
MVAEnum::MVACalcVars::Rhad1
double Rhad1
Definition: ElectronDNNCalculator.h:29
MVAEnum::MVACalcVars::Rhad
double Rhad
Definition: ElectronDNNCalculator.h:28
MVAEnum::QTVars::Eratio
std::vector< double > Eratio
Definition: ElectronDNNCalculator.h:58
MVAEnum::QTVars::deltaPhiRescaled2
std::vector< double > deltaPhiRescaled2
Definition: ElectronDNNCalculator.h:65
MVAEnum::MVACalcVars::trans_TRTPID
double trans_TRTPID
Definition: ElectronDNNCalculator.h:41
lwtDev::parse_json_graph
GraphConfig parse_json_graph(std::istream &json)
Definition: parse_json.cxx:71
MVAEnum::QTVars::f3
std::vector< double > f3
Definition: ElectronDNNCalculator.h:52
MVAEnum::MVACalcVars::EoverP
double EoverP
Definition: ElectronDNNCalculator.h:43
MVAEnum::MVACalcVars::nSCTHitsPlusDeadSensors
double nSCTHitsPlusDeadSensors
Definition: ElectronDNNCalculator.h:45
ElectronDNNCalculator::m_quantiles
MVAEnum::QTVars m_quantiles
Quantile values for each variable that needs to be transformed with the QuantileTransformer.
Definition: ElectronDNNCalculator.h:104
MVAEnum::MVACalcVars::deltaPhiRescaled2
double deltaPhiRescaled2
Definition: ElectronDNNCalculator.h:40
MVAEnum::MVACalcVars::SCTWeightedCharge
double SCTWeightedCharge
Definition: ElectronDNNCalculator.h:46
MVAEnum::QTVars::d0significance
std::vector< double > d0significance
Definition: ElectronDNNCalculator.h:62
MVAEnum::QTVars::nPixHitsPlusDeadSensors
std::vector< double > nPixHitsPlusDeadSensors
Definition: ElectronDNNCalculator.h:69
ElectronDNNCalculator::calculate
Eigen::Matrix< float, -1, 1 > calculate(const MVAEnum::MVACalcVars &varsStruct) const
Get the prediction of the DNN model.
Definition: ElectronDNNCalculator.cxx:102
ElectronDNNCalculator::readQuantileTransformer
int readQuantileTransformer(TTree *tree, const std::vector< std::string > &variables)
read the bins and values of the QuantileTransformer to transform the input variables.
Definition: ElectronDNNCalculator.cxx:198
readCCLHist.float
float
Definition: readCCLHist.py:83