21 #include "lwtnn/parse_json.hh"
22 #include "lwtnn/InputOrder.hh"
23 #include <Eigen/Dense>
27 const std::string& modelFileName,
28 const std::string& quantileFileName,
29 const std::vector<std::string>&
variables,
30 const bool multiClass,
32 asg::AsgMessagingForward(owner),
33 m_multiClass(multiClass),
38 if (modelFileName.empty()){
39 throw std::runtime_error(
"No file found at '" + modelFileName +
"'");
46 ATH_MSG_INFO(
"Loading model: " << modelFileName.c_str());
49 lwt::InputOrder
order;
53 std::vector<std::string> inputVariables;
55 inputVariables = {
"d0significance",
"dPOverP",
56 "deltaEta1",
"deltaPhiRescaled2",
"trans_TRTPID",
57 "nPixHitsPlusDeadSensors",
"nSCTHitsPlusDeadSensors",
58 "EoverP",
"eta",
"et",
"Rhad1",
"Rhad",
"f3",
"f1",
59 "weta2",
"Rphi",
"Reta",
"Eratio",
"wtots1",
"SCTWeightedCharge",
"qd0"};
62 inputVariables = {
"d0",
"d0significance",
"dPOverP",
63 "deltaEta1",
"deltaPhiRescaled2",
"trans_TRTPID",
64 "nPixHitsPlusDeadSensors",
"nSCTHitsPlusDeadSensors",
65 "EoverP",
"eta",
"et",
"Rhad1",
"Rhad",
"f3",
"f1",
66 "weta2",
"Rphi",
"Reta",
"Eratio",
"wtots1"};
68 order.scalar.emplace_back(
"node_0", inputVariables );
74 size_t nOutputs = parsedGraph.outputs.begin()->second.labels.size();
75 if (nOutputs != 6 && nOutputs != 1){
76 throw std::runtime_error(
"Given model does not have 1 or 6 outputs. Something seems to be wrong with the model file.");
79 throw std::runtime_error(
"Given model has 1 output but config file specifies mutliclass. Something is wrong");
82 throw std::runtime_error(
"Given model has 6 output but config file does not specify mutliclass. Something is wrong");
85 m_graph = std::make_unique<lwt::generic::FastGraph<float>>(parsedGraph,
order);
87 if (quantileFileName.empty()){
88 throw std::runtime_error(
"No file found at '" + quantileFileName +
"'");
92 ATH_MSG_INFO(
"Loading QuantileTransformer " << quantileFileName);
93 std::unique_ptr<TFile> qtfile(TFile::Open(quantileFileName.data()));
95 throw std::runtime_error(
"Could not load all variables for the QuantileTransformer");
105 Eigen::VectorXf inputVector(21);
155 std::vector<Eigen::VectorXf> inp;
156 inp.emplace_back(std::move(inputVector));
168 int size = quantiles.size();
175 auto lowBound = std::lower_bound(quantiles.begin(), quantiles.end(),
value);
176 int lowBin =
lowBound - quantiles.begin() - 1;
179 double xLup = quantiles[lowBin], yLup =
m_references[lowBin], xRup = quantiles[lowBin+1], yRup =
m_references[lowBin+1];
182 auto upperBound = std::upper_bound(quantiles.begin(), quantiles.end(),
value);
183 int upperBin = upperBound - quantiles.begin();
186 double xRdown = quantiles[upperBin], yRdown =
m_references[upperBin], xLdown = quantiles[upperBin-1], yLdown =
m_references[upperBin-1];
189 double dydxup = ( yRup - yLup ) / ( xRup - xLup );
190 double dydxdown = ( yRdown - yLdown ) / ( xRdown - xLdown );
193 return 0.5 * ((yLup + dydxup * (
value - xLup)) + (yLdown + dydxdown * (
value - xLdown)));
205 std::map<std::string, double> readVars;
207 sc =
tree->SetBranchAddress(TString(
var), &readVars[
var]) == -5 ? 0 : 1;
209 for (
int i = 0;
i <
tree->GetEntries();
i++){