// ATLAS Offline Software: ElectronDNNCalculator.cxx
// (Source listing recovered from the Doxygen documentation page of this file.)
/*
  Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
*/

#include "TSystem.h"
#include "TFile.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "lwtnn/parse_json.hh"
#include "lwtnn/InputOrder.hh"
#include <Eigen/Dense>
24
25
27 const std::string& modelFileName,
28 const std::string& quantileFileName,
29 const std::vector<std::string>& variables,
30 const bool multiClass) :
32 m_quantiles(variables.size()),
33 m_multiClass(multiClass),
34 m_variables(variables),
35 m_var_size(variables.size())
36{
37 ATH_MSG_INFO("Initializing ElectronDNNCalculator...");
38
39 if (modelFileName.empty()){
40 throw std::runtime_error("No file found at '" + modelFileName + "'");
41 }
42
43 // Make an input file object
44 std::ifstream inputFile;
45
46 // Open your trained model
47 ATH_MSG_INFO("Loading model: " << modelFileName.c_str());
48
49 // Create input order for the NN, the data needs to be passed in this exact order
50 lwt::InputOrder order;
51 order.scalar.emplace_back("node_0", m_variables );
52
53 // create the model
54 inputFile.open(modelFileName);
55 auto parsedGraph = lwt::parse_json_graph(inputFile);
56 // Test whether the number of outputs of the given network corresponds to the expected number
57 size_t nOutputs = parsedGraph.outputs.begin()->second.labels.size();
58 if (nOutputs != 6 && nOutputs != 1){
59 throw std::runtime_error("Given model does not have 1 or 6 outputs. Something seems to be wrong with the model file.");
60 }
61 else if (nOutputs == 1 && m_multiClass){
62 throw std::runtime_error("Given model has 1 output but config file specifies mutliclass. Something is wrong");
63 }
64 else if (nOutputs == 6 && !m_multiClass){
65 throw std::runtime_error("Given model has 6 output but config file does not specify mutliclass. Something is wrong");
66 }
67
68 m_graph = std::make_unique<lwt::generic::FastGraph<float>>(parsedGraph, order);
69
70 if (quantileFileName.empty()){
71 throw std::runtime_error("No file found at '" + quantileFileName + "'");
72 }
73
74 // Open quantiletransformer file
75 ATH_MSG_INFO("Loading QuantileTransformer " << quantileFileName);
76 std::unique_ptr<TFile> qtfile(TFile::Open(quantileFileName.data()));
77 if (readQuantileTransformer((TTree*)qtfile->Get("tree")) == 0){
78 throw std::runtime_error("Could not load all variables for the QuantileTransformer");
79
80 }
81}
82
83
84// takes the input variables, transforms them according to the given QuantileTransformer and predicts the DNN value(s)
85Eigen::Matrix<float, -1, 1> ElectronDNNCalculator::calculate( const std::vector<double>& variableValues ) const
86{
87
88 if(variableValues.size() != m_var_size)
89 throw std::runtime_error("Passed vector of variables has wrong size");
90
91 // Create the input for the model
92 Eigen::VectorXf inputVector(m_variables.size());
93
94 // This has to be in the same order as the InputOrder was defined
95 for(uint i = 0; i < m_var_size; ++i){
96 inputVector(i) = transformInput(m_quantiles.at(i), variableValues.at(i));
97 }
98
99 std::vector<Eigen::VectorXf> inp;
100 inp.emplace_back(std::move(inputVector));
101
102 auto output = m_graph->compute(inp);
103 return output;
104}
105
106
107// transform the input based on a QuantileTransformer. quantiles are bins in the variable, while references are bins from 0 to 1
108// The interpolation is done averaging the interpolation going from small to large bins and going from large to small bins
109// to deal with non strictly monotonic rising bins.
110double ElectronDNNCalculator::transformInput( const std::vector<double>& quantiles, double value ) const
111{
112 int size = quantiles.size();
113
114 // if given value is outside of range of the given quantiles return min (0) or max (1) of references
115 if (value <= quantiles[0]) return m_references[0];
116 if (value >= quantiles[size-1]) return m_references[size-1];
117
118 // find the bin where the value is smaller than the next bin (going from low bins to large bins)
119 auto lowBound = std::lower_bound(quantiles.begin(), quantiles.end(), value);
120 int lowBin = lowBound - quantiles.begin() - 1;
121
122 // get x and y values on left and right side from value while going up
123 double xLup = quantiles[lowBin], yLup = m_references[lowBin], xRup = quantiles[lowBin+1], yRup = m_references[lowBin+1];
124
125 // find the bin where the value is larger than the next bin (going from large bins to low bins)
126 auto upperBound = std::upper_bound(quantiles.begin(), quantiles.end(), value);
127 int upperBin = upperBound - quantiles.begin();
128
129 // get x and y values on left and right side from value while going down
130 double xRdown = quantiles[upperBin], yRdown = m_references[upperBin], xLdown = quantiles[upperBin-1], yLdown = m_references[upperBin-1];
131
132 // calculate the gradients
133 double dydxup = ( yRup - yLup ) / ( xRup - xLup );
134 double dydxdown = ( yRdown - yLdown ) / ( xRdown - xLdown );
135
136 // average linear interpolation of up and down case
137 return 0.5 * ((yLup + dydxup * (value - xLup)) + (yLdown + dydxdown * (value - xLdown)));
138}
139
140
141// Read the information needed for the QuantileTransformer from a ROOT TTree
143{
144 int sc(1);
145 // the reference bins to which the variables will be transformed to
146 double references;
147 sc = tree->SetBranchAddress("references", &references) == -5 ? 0 : 1;
148
149 std::map<std::string, double> readVars;
150 for ( const auto& var : m_variables ){
151 sc = tree->SetBranchAddress(TString(var), &readVars[var]) == -5 ? 0 : 1;
152 }
153
154 for (int ientry = 0; ientry < tree->GetEntries(); ientry++){
155 tree->GetEntry(ientry);
156 m_references.push_back(references);
157 for(uint ivar = 0; ivar < m_var_size; ++ivar){
158 m_quantiles.at(ivar).push_back(readVars[m_variables.at(ivar)]);
159 }
160 }
161 return sc;
162}
/*
 * Cross-reference notes recovered from the Doxygen page of this file
 * (symbol, followed by its documentation where the page provided one):
 *
 *  #define ATH_MSG_INFO(x)
 *  unsigned int uint
 *  static Double_t sc
 *  (entry whose name was lost in extraction; presumably AsgElectronSelectorTool)
 *      Electron selector tool to select signal electrons using the
 *      ElectronDNNCalculator retrieve a score ba... (truncated on the source page)
 *  Eigen::Matrix<float, -1, 1> calculate(const std::vector<double>&) const
 *      Get the prediction of the DNN model.
 *  std::vector<std::vector<double>> m_quantiles
 *      Quantile values for each variable that needs to be transformed with the QuantileTransformer.
 *  std::unique_ptr<lwt::generic::FastGraph<float>> m_graph
 *      DNN interface via lwtnn.
 *  double transformInput(const std::vector<double>& quantiles, double value) const
 *      Transform the input variables according to a given QuantileTransformer.
 *  bool m_multiClass
 *      Whether the used model is a multiclass model or not.
 *  std::vector<std::string> m_variables
 *      Model variables.
 *  int readQuantileTransformer(TTree* tree)
 *      Read the bins and values of the QuantileTransformer to transform the input variables.
 *  std::vector<double> m_references
 *      Reference values for the QuantileTransformer. Basically just equidistant bins between 0 and 1.
 *  ElectronDNNCalculator(AsgElectronSelectorTool* owner, const std::string& modelFileName,
 *                        const std::string& quantileFileName,
 *                        const std::vector<std::string>& variablesName, const bool multiClass)
 *      Constructor of the class.
 *  AsgMessagingForward(T* owner)
 *      Forwarding constructor.
 *  std::vector<reference> references
 *      Definition hcg.cxx:523
 *  TChain* tree
 */