21#include "lwtnn/parse_json.hh"
22#include "lwtnn/InputOrder.hh"
27 const std::string& modelFileName,
28 const std::string& quantileFileName,
29 const std::vector<std::string>& variables,
30 const bool multiClass) :
39 if (modelFileName.empty()){
40 throw std::runtime_error(
"No file found at '" + modelFileName +
"'");
44 std::ifstream inputFile;
47 ATH_MSG_INFO(
"Loading model: " << modelFileName.c_str());
50 lwt::InputOrder order;
54 inputFile.open(modelFileName);
55 auto parsedGraph = lwt::parse_json_graph(inputFile);
57 size_t nOutputs = parsedGraph.outputs.begin()->second.labels.size();
58 if (nOutputs != 6 && nOutputs != 1){
59 throw std::runtime_error(
"Given model does not have 1 or 6 outputs. Something seems to be wrong with the model file.");
62 throw std::runtime_error(
"Given model has 1 output but config file specifies mutliclass. Something is wrong");
65 throw std::runtime_error(
"Given model has 6 output but config file does not specify mutliclass. Something is wrong");
68 m_graph = std::make_unique<lwt::generic::FastGraph<float>>(parsedGraph, order);
70 if (quantileFileName.empty()){
71 throw std::runtime_error(
"No file found at '" + quantileFileName +
"'");
75 ATH_MSG_INFO(
"Loading QuantileTransformer " << quantileFileName);
76 std::unique_ptr<TFile> qtfile(TFile::Open(quantileFileName.data()));
78 throw std::runtime_error(
"Could not load all variables for the QuantileTransformer");
89 throw std::runtime_error(
"Passed vector of variables has wrong size");
99 std::vector<Eigen::VectorXf> inp;
100 inp.emplace_back(std::move(inputVector));
102 auto output =
m_graph->compute(inp);
112 int size = quantiles.size();
116 if (value >= quantiles[size-1])
return m_references[size-1];
119 auto lowBound = std::lower_bound(quantiles.begin(), quantiles.end(), value);
120 int lowBin = lowBound - quantiles.begin() - 1;
123 double xLup = quantiles[lowBin], yLup =
m_references[lowBin], xRup = quantiles[lowBin+1], yRup =
m_references[lowBin+1];
126 auto upperBound = std::upper_bound(quantiles.begin(), quantiles.end(), value);
127 int upperBin = upperBound - quantiles.begin();
130 double xRdown = quantiles[upperBin], yRdown =
m_references[upperBin], xLdown = quantiles[upperBin-1], yLdown =
m_references[upperBin-1];
133 double dydxup = ( yRup - yLup ) / ( xRup - xLup );
134 double dydxdown = ( yRdown - yLdown ) / ( xRdown - xLdown );
137 return 0.5 * ((yLup + dydxup * (value - xLup)) + (yLdown + dydxdown * (value - xLdown)));
149 std::map<std::string, double> readVars;
151 sc =
tree->SetBranchAddress(TString(var), &readVars[var]) == -5 ? 0 : 1;
154 for (
int ientry = 0; ientry <
tree->GetEntries(); ientry++){
155 tree->GetEntry(ientry);
Eigen::Matrix< float, -1, 1 > calculate(const std::vector< double > &) const
Get the prediction of the DNN model.
std::vector< std::vector< double > > m_quantiles
Quantile values for each variable that needs to be transformed with the QuantileTransformer.
std::unique_ptr< lwt::generic::FastGraph< float > > m_graph
DNN interface via lwtnn.
double transformInput(const std::vector< double > &quantiles, double value) const
Transform a single input value according to the given QuantileTransformer quantiles.
bool m_multiClass
Whether the used model is a multiclass model or not.
std::vector< std::string > m_variables
Model variables.
int readQuantileTransformer(TTree *tree)
Read the bins and values of the QuantileTransformer used to transform the input variables.
std::vector< double > m_references
Reference values for the QuantileTransformer, i.e. equidistant bins between 0 and 1.
ElectronDNNCalculator(AsgElectronSelectorTool *owner, const std::string &modelFileName, const std::string &quantileFileName, const std::vector< std::string > &variablesName, const bool multiClass)
Constructor of the class.
AsgMessagingForward(T *owner)
forwarding constructor
std::vector< reference > references