21 #include "lwtnn/parse_json.hh"
22 #include "lwtnn/InputOrder.hh"
23 #include <Eigen/Dense>
27 const std::string& modelFileName,
28 const std::string& quantileFileName,
29 const std::vector<std::string>&
variables,
30 const bool multiClass) :
31 asg::AsgMessagingForward(owner),
33 m_multiClass(multiClass),
39 if (modelFileName.empty()){
40 throw std::runtime_error(
"No file found at '" + modelFileName +
"'");
47 ATH_MSG_INFO(
"Loading model: " << modelFileName.c_str());
50 lwt::InputOrder
order;
57 size_t nOutputs = parsedGraph.outputs.begin()->second.labels.size();
58 if (nOutputs != 6 && nOutputs != 1){
59 throw std::runtime_error(
"Given model does not have 1 or 6 outputs. Something seems to be wrong with the model file.");
62 throw std::runtime_error(
"Given model has 1 output but config file specifies mutliclass. Something is wrong");
65 throw std::runtime_error(
"Given model has 6 output but config file does not specify mutliclass. Something is wrong");
68 m_graph = std::make_unique<lwt::generic::FastGraph<float>>(parsedGraph,
order);
70 if (quantileFileName.empty()){
71 throw std::runtime_error(
"No file found at '" + quantileFileName +
"'");
75 ATH_MSG_INFO(
"Loading QuantileTransformer " << quantileFileName);
76 std::unique_ptr<TFile> qtfile(TFile::Open(quantileFileName.data()));
78 throw std::runtime_error(
"Could not load all variables for the QuantileTransformer");
89 throw std::runtime_error(
"Passed vector of variables has wrong size");
99 std::vector<Eigen::VectorXf> inp;
100 inp.emplace_back(std::move(inputVector));
112 int size = quantiles.size();
119 auto lowBound = std::lower_bound(quantiles.begin(), quantiles.end(),
value);
120 int lowBin =
lowBound - quantiles.begin() - 1;
123 double xLup = quantiles[lowBin], yLup =
m_references[lowBin], xRup = quantiles[lowBin+1], yRup =
m_references[lowBin+1];
126 auto upperBound = std::upper_bound(quantiles.begin(), quantiles.end(),
value);
127 int upperBin = upperBound - quantiles.begin();
130 double xRdown = quantiles[upperBin], yRdown =
m_references[upperBin], xLdown = quantiles[upperBin-1], yLdown =
m_references[upperBin-1];
133 double dydxup = ( yRup - yLup ) / ( xRup - xLup );
134 double dydxdown = ( yRdown - yLdown ) / ( xRdown - xLdown );
137 return 0.5 * ((yLup + dydxup * (
value - xLup)) + (yLdown + dydxdown * (
value - xLdown)));
149 std::map<std::string, double> readVars;
151 sc =
tree->SetBranchAddress(TString(
var), &readVars[
var]) == -5 ? 0 : 1;
154 for (
int ientry = 0; ientry <
tree->GetEntries(); ientry++){
155 tree->GetEntry(ientry);