17 const std::string jetLinkName =
"jetLink";
19 auto getOnnxUtil(
const std::string& nn_file) {
22 return std::make_shared<const OnnxUtil>(fullPathToOnnxFile);
29 GNN(getOnnxUtil(nn_file), o)
40 m_jetLink(jetLinkName),
41 m_defaultValue(o.default_output_value)
50 for (
auto config : constituents_configs){
68 OnnxUtil::OutputConfig gnn_output_config =
m_onnxUtil->getOutputConfig();
77 rd.merge(loader->getUsedRemap());
84 const std::map<std::string, std::string>&
remap,
97 if (!jetLink.isValid()) {
98 throw std::runtime_error(
"invalid jetLink");
117 dec.second(
jet) = {};
120 dec.second(
jet) = {};
123 dec.second(
jet) = {};
130 using namespace internal;
134 std::map<std::string, Inputs> gnn_inputs;
137 std::vector<float> jet_feat;
139 jet_feat.push_back(getter(btag).
second);
144 std::vector<int64_t> jet_feat_dim = {1,
static_cast<int64_t
>(jet_feat.size())};
145 Inputs jet_info(jet_feat, jet_feat_dim);
147 gnn_inputs.insert({
"jets", jet_info});
149 gnn_inputs.insert({
"jet_features", jet_info});
155 auto [input_name, input_data, input_objects] = loader->getData(
jet, btag);
157 input_name.pop_back();
158 input_name.append(
"_features");
160 gnn_inputs.insert({input_name, input_data});
165 for (
auto constituent : input_objects){
173 auto [out_f, out_vc, out_vf] =
m_onnxUtil->runInference(gnn_inputs);
181 if (out_vf.at(dec.first).size() != 1){
182 throw std::logic_error(
"expected vectors of length 1 for float decorators");
184 dec.second(btag) = out_vf.at(dec.first).at(0);
191 dec.second(btag) = out_f.at(dec.first);
195 dec.second(btag) = out_vc.at(dec.first);
198 dec.second(btag) = out_vf.at(dec.first);
205 TrackLinks::value_type link;
208 link.toIndexedElement(*itc,
it->index());
209 links.push_back(link);
211 dec.second(btag) =
links;
215 throw std::logic_error(
"unsupported ONNX metadata version");
230 std::tuple<FTagDataDependencyNames, std::set<std::string>>
235 std::map<std::string, std::string>
remap =
options.remap_scalar;
236 std::set<std::string> usedRemap;
240 std::string context =
"building negative tag b-btagger";
242 for (
const auto& outNode : outConfig) {
244 std::string dec_name = outNode.name;
258 switch (outNode.type) {
269 throw std::logic_error(
"Unknown output data type");
281 return std::make_tuple(deps, usedRemap);