![Logo](../../ATLAS-Logo-Square-Blue-RGB.png) |
ATLAS Offline Software
|
Used by AsgElectronSelectorTool to calculate the score of a Python-trained DNN for electron ID, using lwtnn as the interface. Also applies a transformation to the input variables based on a QuantileTransformer.
More...
#include <ElectronDNNCalculator.h>
|
std::unique_ptr< lwt::generic::FastGraph< float > > | m_graph = 0 |
| DNN interface via lwtnn. More...
|
|
MVAEnum::QTVars | m_quantiles |
| Quantile values for each variable that needs to be transformed with the QuantileTransformer. More...
|
|
std::vector< double > | m_references |
| Reference values for the QuantileTransformer. Basically just equidistant bins between 0 and 1. More...
|
|
bool | m_multiClass |
| Whether the used model is a multiclass model or not. More...
|
|
bool | m_newVars |
| Whether the model uses old or new set of variables. More...
|
|
std::function< MsgStream &()> | m_msg |
| the message stream we use More...
|
|
Used by AsgElectronSelectorTool to calculate the score of a Python-trained DNN for electron ID, using lwtnn as the interface. Also applies a transformation to the input variables based on a QuantileTransformer.
- Author
- Lukas Ehrke
Definition at line 77 of file ElectronDNNCalculator.h.
◆ ElectronDNNCalculator()
ElectronDNNCalculator::ElectronDNNCalculator |
( |
AsgElectronSelectorTool * |
owner, |
|
|
const std::string & |
modelFileName, |
|
|
const std::string & |
quantileFileName, |
|
|
const std::vector< std::string > & |
variablesName, |
|
|
const bool |
multiClass, |
|
|
const bool |
newVars |
|
) |
| |
Constructor of the class.
Definition at line 26 of file ElectronDNNCalculator.cxx.
38 if (modelFileName.empty()){
39 throw std::runtime_error(
"No file found at '" + modelFileName +
"'");
46 ATH_MSG_INFO(
"Loading model: " << modelFileName.c_str());
49 lwt::InputOrder
order;
53 std::vector<std::string> inputVariables;
55 inputVariables = {
"d0significance",
"dPOverP",
56 "deltaEta1",
"deltaPhiRescaled2",
"trans_TRTPID",
57 "nPixHitsPlusDeadSensors",
"nSCTHitsPlusDeadSensors",
58 "EoverP",
"eta",
"et",
"Rhad1",
"Rhad",
"f3",
"f1",
59 "weta2",
"Rphi",
"Reta",
"Eratio",
"wtots1",
"SCTWeightedCharge",
"qd0"};
62 inputVariables = {
"d0",
"d0significance",
"dPOverP",
63 "deltaEta1",
"deltaPhiRescaled2",
"trans_TRTPID",
64 "nPixHitsPlusDeadSensors",
"nSCTHitsPlusDeadSensors",
65 "EoverP",
"eta",
"et",
"Rhad1",
"Rhad",
"f3",
"f1",
66 "weta2",
"Rphi",
"Reta",
"Eratio",
"wtots1"};
68 order.scalar.emplace_back(
"node_0", inputVariables );
74 size_t nOutputs = parsedGraph.outputs.begin()->second.labels.size();
75 if (nOutputs != 6 && nOutputs != 1){
76 throw std::runtime_error(
"Given model does not have 1 or 6 outputs. Something seems to be wrong with the model file.");
79 throw std::runtime_error(
"Given model has 1 output but config file specifies mutliclass. Something is wrong");
82 throw std::runtime_error(
"Given model has 6 output but config file does not specify mutliclass. Something is wrong");
85 m_graph = std::make_unique<lwt::generic::FastGraph<float>>(parsedGraph,
order);
87 if (quantileFileName.empty()){
88 throw std::runtime_error(
"No file found at '" + quantileFileName +
"'");
92 ATH_MSG_INFO(
"Loading QuantileTransformer " << quantileFileName);
93 std::unique_ptr<TFile> qtfile(TFile::Open(quantileFileName.data()));
95 throw std::runtime_error(
"Could not load all variables for the QuantileTransformer");
◆ ~ElectronDNNCalculator()
ElectronDNNCalculator::~ElectronDNNCalculator |
( |
| ) |
|
|
inline |
◆ calculate()
Get the prediction of the DNN model.
Definition at line 102 of file ElectronDNNCalculator.cxx.
105 Eigen::VectorXf inputVector(21);
155 std::vector<Eigen::VectorXf> inp;
156 inp.emplace_back(std::move(inputVector));
◆ msg() [1/2]
MsgStream & asg::AsgMessagingForward::msg |
( |
| ) |
const |
|
inherited |
The standard message stream.
- Returns
- A reference to the default message stream of this object.
Definition at line 24 of file AsgMessagingForward.cxx.
◆ msg() [2/2]
MsgStream & asg::AsgMessagingForward::msg |
( |
const MSG::Level |
lvl | ) |
const |
|
inherited |
The standard message stream.
- Parameters
-
lvl | The message level to set the stream to |
- Returns
- A reference to the default message stream, set to level "lvl"
Definition at line 29 of file AsgMessagingForward.cxx.
◆ msgLvl()
bool asg::AsgMessagingForward::msgLvl |
( |
const MSG::Level |
lvl | ) |
const |
|
inherited |
Test the output level of the object.
- Parameters
-
lvl | The message level to test against |
- Returns
- boolean indicating if messages at the given level will be printed
-
true
If messages at level "lvl" will be printed
Definition at line 11 of file AsgMessagingForward.cxx.
14 if (
msg.level() <= lvl)
◆ readQuantileTransformer()
int ElectronDNNCalculator::readQuantileTransformer |
( |
TTree * |
tree, |
|
|
const std::vector< std::string > & |
variables |
|
) |
| |
|
private |
read the bins and values of the QuantileTransformer to transform the input variables.
Definition at line 198 of file ElectronDNNCalculator.cxx.
205 std::map<std::string, double> readVars;
207 sc =
tree->SetBranchAddress(TString(
var), &readVars[
var]) == -5 ? 0 : 1;
209 for (
int i = 0;
i <
tree->GetEntries();
i++){
◆ transformInput()
double ElectronDNNCalculator::transformInput |
( |
const std::vector< double > & |
quantiles, |
|
|
double |
value |
|
) |
| const |
|
private |
transform the input variables according to a given QuantileTransformer.
Definition at line 166 of file ElectronDNNCalculator.cxx.
168 int size = quantiles.size();
175 auto lowBound = std::lower_bound(quantiles.begin(), quantiles.end(),
value);
176 int lowBin =
lowBound - quantiles.begin() - 1;
179 double xLup = quantiles[lowBin], yLup =
m_references[lowBin], xRup = quantiles[lowBin+1], yRup =
m_references[lowBin+1];
182 auto upperBound = std::upper_bound(quantiles.begin(), quantiles.end(),
value);
183 int upperBin = upperBound - quantiles.begin();
186 double xRdown = quantiles[upperBin], yRdown =
m_references[upperBin], xLdown = quantiles[upperBin-1], yLdown =
m_references[upperBin-1];
189 double dydxup = ( yRup - yLup ) / ( xRup - xLup );
190 double dydxdown = ( yRdown - yLdown ) / ( xRdown - xLdown );
193 return 0.5 * ((yLup + dydxup * (
value - xLup)) + (yLdown + dydxdown * (
value - xLdown)));
◆ m_graph
std::unique_ptr<lwt::generic::FastGraph<float> > ElectronDNNCalculator::m_graph = 0 |
|
private |
◆ m_msg
std::function<MsgStream& ()> asg::AsgMessagingForward::m_msg |
|
privateinherited |
the message stream we use
This used to be a simple pointer to the MsgStream itself, but in AthenaMT the actual object used is local to the thread. So instead of pointing to it directly we are now using a function to look it up, which will get the thread-local object.
Definition at line 77 of file AsgMessagingForward.h.
◆ m_multiClass
bool ElectronDNNCalculator::m_multiClass |
|
private |
◆ m_newVars
bool ElectronDNNCalculator::m_newVars |
|
private |
◆ m_quantiles
Quantile values for each variable that needs to be transformed with the QuantileTransformer.
Definition at line 104 of file ElectronDNNCalculator.h.
◆ m_references
std::vector<double> ElectronDNNCalculator::m_references |
|
private |
Reference values for the QuantileTransformer. Basically just equidistant bins between 0 and 1.
Definition at line 106 of file ElectronDNNCalculator.h.
The documentation for this class was generated from the following files:
std::vector< double > nSCTHitsPlusDeadSensors
std::function< MsgStream &()> m_msg
the message stream we use
std::vector< double > trans_TRTPID
std::vector< double > weta2
double nPixHitsPlusDeadSensors
std::unique_ptr< lwt::generic::FastGraph< float > > m_graph
DNN interface via lwtnn.
std::vector< double > Rhad1
double transformInput(const std::vector< double > &quantiles, double value) const
transform the input variables according to a given QuantileTransformer.
base class to forward messages to another class
std::vector< double > m_references
Reference values for the QuantileTransformer. Basically just equidistant bins between 0 and 1.
bool m_multiClass
Whether the used model is a multiclass model or not.
bool m_newVars
Whether the model uses old or new set of variables.
std::vector< double > deltaEta1
std::vector< double > qd0
std::vector< reference > references
std::vector< double > EoverP
std::vector< double > SCTWeightedCharge
MsgStream & msg() const
The standard message stream.
std::vector< double > Rhad
std::vector< double > wtots1
std::vector< double > Reta
std::vector< double > eta
std::vector< double > dPOverP
std::vector< double > Rphi
std::vector< double > Eratio
std::vector< double > deltaPhiRescaled2
GraphConfig parse_json_graph(std::istream &json)
double nSCTHitsPlusDeadSensors
MVAEnum::QTVars m_quantiles
Quantile values for each variable that needs to be transformed with the QuantileTransformer.
std::vector< double > d0significance
std::vector< double > nPixHitsPlusDeadSensors
int readQuantileTransformer(TTree *tree, const std::vector< std::string > &variables)
read the bins and values of the QuantileTransformer to transform the input variables.