ATLAS Offline Software
GNN.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
10 
11 #include "xAODBTagging/BTagging.h"
12 #include "xAODJet/JetContainer.h"
13 
15 
16 namespace {
17  const std::string jetLinkName = "jetLink";
18 
19  auto getOnnxUtil(const std::string& nn_file) {
20  using namespace FlavorTagDiscriminants;
21  std::string fullPathToOnnxFile = PathResolverFindCalibFile(nn_file);
22  return std::make_shared<const OnnxUtil>(fullPathToOnnxFile);
23  }
24 }
25 
26 namespace FlavorTagDiscriminants {
27 
28  GNN::GNN(const std::string& nn_file, const GNNOptions& o):
29  GNN(getOnnxUtil(nn_file), o)
30  {
31  }
32 
33  GNN::GNN(const GNN& old, const GNNOptions& o):
34  GNN(old.m_onnxUtil, o)
35  {
36  }
37 
38  GNN::GNN(std::shared_ptr<const OnnxUtil> util, const GNNOptions& o):
39  m_onnxUtil(util),
40  m_jetLink(jetLinkName),
41  m_defaultValue(o.default_output_value)
42  {
43 
44  // Extract metadata from the ONNX file, primarily about the model's inputs.
45  auto lwt_config = m_onnxUtil->getLwtConfig();
46 
47  // Create configuration objects for data preprocessing.
48  auto [inputs, constituents_configs, options] = dataprep::createGetterConfig(
49  lwt_config, o.flip_config, o.variable_remapping, o.track_link_type);
50  for (auto config : constituents_configs){
51  switch (config.type){
53  m_constituentsLoaders.push_back(std::make_shared<TracksLoader>(config, options));
54  break;
56  m_constituentsLoaders.push_back(std::make_shared<IParticlesLoader>(config, options));
57  break;
58  }
59  }
60 
61  // Initialize jet and b-tagging input getters.
62  auto [vb, vj, ds] = dataprep::createBvarGetters(inputs);
63  m_varsFromBTag = vb;
64  m_varsFromJet = vj;
66 
67  // Retrieve the configuration for the model outputs.
68  OnnxUtil::OutputConfig gnn_output_config = m_onnxUtil->getOutputConfig();
69 
70  // Create the output decorators.
71  auto [dd, rd] = createDecorators(gnn_output_config, options);
73 
74  // Update dependencies and used remap from the constituents loaders.
75  for (auto loader : m_constituentsLoaders){
76  m_dataDependencyNames += loader->getDependencies();
77  rd.merge(loader->getUsedRemap());
78  }
79  dataprep::checkForUnusedRemaps(options.remap_scalar, rd);
80  }
81 
82  GNN::GNN(const std::string& file,
83  const FlipTagConfig& flip,
84  const std::map<std::string, std::string>& remap,
85  const TrackLinkType link_type,
86  float def_out_val):
87  GNN( file, GNNOptions { flip, remap, link_type, def_out_val} )
88  {}
89 
90  GNN::GNN(GNN&&) = default;
91  GNN::GNN(const GNN&) = default;
92  GNN::~GNN() = default;
93 
94  void GNN::decorate(const xAOD::BTagging& btag) const {
95  /* tag a b-tagging object */
96  auto jetLink = m_jetLink(btag);
97  if (!jetLink.isValid()) {
98  throw std::runtime_error("invalid jetLink");
99  }
100  const xAOD::Jet& jet = **jetLink;
101  decorate(jet, btag);
102  }
103 
104  void GNN::decorate(const xAOD::Jet& jet) const {
105  /* tag a jet */
106  decorate(jet, jet);
107  }
108 
110  for (const auto& dec: m_decorators.jetFloat) {
111  dec.second(jet) = m_defaultValue;
112  }
113  // for some networks we need to set a lot of empty vectors as well
114  if (m_onnxUtil->getOnnxModelVersion() == OnnxModelVersion::V1) {
115  // vector outputs, e.g. track predictions
116  for (const auto& dec: m_decorators.jetVecChar) {
117  dec.second(jet) = {};
118  }
119  for (const auto& dec: m_decorators.jetVecFloat) {
120  dec.second(jet) = {};
121  }
122  for (const auto& dec: m_decorators.jetTrackLinks) {
123  dec.second(jet) = {};
124  }
125  }
126  }
127 
128  void GNN::decorate(const xAOD::Jet& jet, const SG::AuxElement& btag) const {
129  /* Main function for decorating a jet or b-tagging object with GNN outputs. */
130  using namespace internal;
131 
132  // prepare input
133  // -------------
134  std::map<std::string, Inputs> gnn_inputs;
135 
136  // jet level inputs
137  std::vector<float> jet_feat;
138  for (const auto& getter: m_varsFromBTag) {
139  jet_feat.push_back(getter(btag).second);
140  }
141  for (const auto& getter: m_varsFromJet) {
142  jet_feat.push_back(getter(jet).second);
143  }
144  std::vector<int64_t> jet_feat_dim = {1, static_cast<int64_t>(jet_feat.size())};
145  Inputs jet_info(jet_feat, jet_feat_dim);
146  if (m_onnxUtil->getOnnxModelVersion() == OnnxModelVersion::V2) {
147  gnn_inputs.insert({"jets", jet_info});
148  } else {
149  gnn_inputs.insert({"jet_features", jet_info});
150  }
151 
152  // constituent level inputs
153  Tracks input_tracks;
154  for (auto loader : m_constituentsLoaders){
155  auto [input_name, input_data, input_objects] = loader->getData(jet, btag);
156  if (m_onnxUtil->getOnnxModelVersion() != OnnxModelVersion::V2) {
157  input_name.pop_back();
158  input_name.append("_features");
159  }
160  gnn_inputs.insert({input_name, input_data});
161 
162  // for now we only collect tracks for aux task decoration
163  // they have to be converted back from IParticle to TrackParticle first
164  if (loader->getType() == ConstituentsType::TRACK){
165  for (auto constituent : input_objects){
166  input_tracks.push_back(dynamic_cast<const xAOD::TrackParticle*>(constituent));
167  }
168  }
169  }
170 
171  // run inference
172  // -------------
173  auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_inputs);
174 
175  // decorate outputs
176  // ----------------
177 
178  // with old metadata, doesn't support writing aux tasks
179  if (m_onnxUtil->getOnnxModelVersion() == OnnxModelVersion::V0) {
180  for (const auto& dec: m_decorators.jetFloat) {
181  if (out_vf.at(dec.first).size() != 1){
182  throw std::logic_error("expected vectors of length 1 for float decorators");
183  }
184  dec.second(btag) = out_vf.at(dec.first).at(0);
185  }
186  }
187  // the new metadata format supports writing aux tasks
188  else if (m_onnxUtil->getOnnxModelVersion() == OnnxModelVersion::V1) {
189  // float outputs, e.g. jet probabilities
190  for (const auto& dec: m_decorators.jetFloat) {
191  dec.second(btag) = out_f.at(dec.first);
192  }
193  // vector outputs, e.g. track predictions
194  for (const auto& dec: m_decorators.jetVecChar) {
195  dec.second(btag) = out_vc.at(dec.first);
196  }
197  for (const auto& dec: m_decorators.jetVecFloat) {
198  dec.second(btag) = out_vf.at(dec.first);
199  }
200 
201  // decorate links to the input tracks to the b-tagging object
202  for (const auto& dec: m_decorators.jetTrackLinks) {
204  for (const xAOD::TrackParticle* it: input_tracks) {
205  TrackLinks::value_type link;
206  const auto* itc = dynamic_cast<const xAOD::TrackParticleContainer*>(
207  it->container());
208  link.toIndexedElement(*itc, it->index());
209  links.push_back(link);
210  }
211  dec.second(btag) = links;
212  }
213  }
214  else {
215  throw std::logic_error("unsupported ONNX metadata version");
216  }
217  } // end of decorate()
218 
219  // Dependencies
220  std::set<std::string> GNN::getDecoratorKeys() const {
222  }
223  std::set<std::string> GNN::getAuxInputKeys() const {
225  }
226  std::set<std::string> GNN::getConstituentAuxInputKeys() const {
228  }
229 
230  std::tuple<FTagDataDependencyNames, std::set<std::string>>
231  GNN::createDecorators(const OnnxUtil::OutputConfig& outConfig, const FTagOptions& options) {
233  Decorators decs;
234 
235  std::map<std::string, std::string> remap = options.remap_scalar;
236  std::set<std::string> usedRemap;
237 
238  // get the regex to rewrite the outputs if we're using flip taggers
239  auto flip_converters = dataprep::getNameFlippers(options.flip);
240  std::string context = "building negative tag b-btagger";
241 
242  for (const auto& outNode : outConfig) {
243  // the node's output name will be used to define the decoration name
244  std::string dec_name = outNode.name;
245 
246  // modify the deco name if we're using flip taggers
247  if (options.flip != FlipTagConfig::STANDARD) {
248  dec_name = str::sub_first(flip_converters, dec_name, context);
249  }
250 
251  // remap the deco name if necessary
252  dec_name = str::remapName(dec_name, remap, usedRemap);
253 
254  // keep track of dependencies for EDM bookkeeping
255  deps.bTagOutputs.insert(dec_name);
256 
257  // Create decorators based on output type and target
258  switch (outNode.type) {
260  m_decorators.jetFloat.emplace_back(outNode.name, Dec<float>(dec_name));
261  break;
263  m_decorators.jetVecChar.emplace_back(outNode.name, Dec<std::vector<char>>(dec_name));
264  break;
266  m_decorators.jetVecFloat.emplace_back(outNode.name, Dec<std::vector<float>>(dec_name));
267  break;
268  default:
269  throw std::logic_error("Unknown output data type");
270  }
271  }
272 
273  // Create decorators for links to the input tracks
274  if (!m_decorators.jetVecChar.empty() || !m_decorators.jetVecFloat.empty()) {
275  std::string name = m_onnxUtil->getModelName() + "_TrackLinks";
276  name = str::remapName(name, remap, usedRemap);
277  deps.bTagOutputs.insert(name);
279  }
280 
281  return std::make_tuple(deps, usedRemap);
282  }
283 
284 } // end of namespace FlavorTagDiscriminants
python.SystemOfUnits.second
int second
Definition: SystemOfUnits.py:120
FlavorTagDiscriminants::GNNOptions::flip_config
FlipTagConfig flip_config
Definition: GNNOptions.h:17
BTagTrackIpAccessor.h
FlavorTagDiscriminants::OnnxModelVersion::V2
@ V2
FlavorTagDiscriminants::dataprep::createGetterConfig
std::tuple< std::vector< FTagInputConfig >, std::vector< ConstituentsInputConfig >, FTagOptions > createGetterConfig(lwt::GraphConfig &graph_config, FlipTagConfig flip_config, std::map< std::string, std::string > remap_scalar, TrackLinkType track_link_type)
Definition: DataPrepUtilities.cxx:235
checkxAOD.ds
ds
Definition: Tools/PyUtils/bin/checkxAOD.py:257
FlavorTagDiscriminants::OnnxOutput::OutputType::VECCHAR
@ VECCHAR
FlavorTagDiscriminants::GNN::TrackLinks
std::vector< ElementLink< TPC > > TrackLinks
Definition: GNN.h:71
FlavorTagDiscriminants::GNN::GNN
GNN(const std::string &nnFile, const GNNOptions &opts)
Definition: GNN.cxx:28
FlavorTagDiscriminants::GNN::Decorators::jetFloat
Decs< float > jetFloat
Definition: GNN.h:80
FlavorTagDiscriminants::GNN::m_constituentsLoaders
std::vector< std::shared_ptr< IConstituentsLoader > > m_constituentsLoaders
Definition: GNN.h:96
FlavorTagDiscriminants::FlipTagConfig::STANDARD
@ STANDARD
BTagging.h
FlavorTagDiscriminants
This file contains "getter" functions used for accessing tagger inputs from the EDM.
Definition: AssociationEnums.h:11
FlavorTagDiscriminants::Inputs
std::pair< std::vector< float >, std::vector< int64_t > > Inputs
Definition: FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxUtil.h:28
FlavorTagDiscriminants::OnnxOutput::OutputType::FLOAT
@ FLOAT
FlavorTagDiscriminants::GNNOptions
Definition: GNNOptions.h:16
FlavorTagDiscriminants::GNN::m_onnxUtil
std::shared_ptr< const OnnxUtil > m_onnxUtil
Definition: GNN.h:65
SG::AuxElement
Base class for elements of a container that can have aux data.
Definition: AuxElement.h:446
FlavorTagDiscriminants::GNN::m_varsFromJet
std::vector< internal::VarFromJet > m_varsFromJet
Definition: GNN.h:95
skel.it
it
Definition: skel.GENtoEVGEN.py:423
FlavorTagDiscriminants::OnnxOutput::OutputType::VECFLOAT
@ VECFLOAT
FlavorTagDiscriminants::ConstituentsType::TRACK
@ TRACK
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
FlavorTagDiscriminants::GNN::m_dataDependencyNames
FTagDataDependencyNames m_dataDependencyNames
Definition: GNN.h:100
FlavorTagDiscriminants::FTagDataDependencyNames::bTagInputs
std::set< std::string > bTagInputs
Definition: FTagDataDependencyNames.h:14
GNNOptions.h
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
FlavorTagDiscriminants::GNN::getDecoratorKeys
virtual std::set< std::string > getDecoratorKeys() const
Definition: GNN.cxx:220
FlavorTagDiscriminants::GNN::m_varsFromBTag
std::vector< internal::VarFromBTag > m_varsFromBTag
Definition: GNN.h:94
FlavorTagDiscriminants::GNNOptions::variable_remapping
std::map< std::string, std::string > variable_remapping
Definition: GNNOptions.h:18
FlavorTagDiscriminants::GNN::getConstituentAuxInputKeys
virtual std::set< std::string > getConstituentAuxInputKeys() const
Definition: GNN.cxx:226
jet
Definition: JetCalibTools_PlotJESFactors.cxx:23
FlavorTagDiscriminants::FTagDataDependencyNames
Definition: FTagDataDependencyNames.h:12
FlavorTagDiscriminants::FTagDataDependencyNames::bTagOutputs
std::set< std::string > bTagOutputs
Definition: FTagDataDependencyNames.h:15
FlavorTagDiscriminants::str::sub_first
std::string sub_first(const StringRegexes &res, const std::string &var_name, const std::string &context)
Definition: PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/StringUtils.cxx:19
SG::Decorator
Helper class to provide type-safe access to aux data.
Definition: Decorator.h:58
FlavorTagDiscriminants::GNN::getAuxInputKeys
virtual std::set< std::string > getAuxInputKeys() const
Definition: GNN.cxx:223
DMTest::links
links
Definition: CLinks_v1.cxx:22
FlavorTagDiscriminants::GNNOptions::track_link_type
TrackLinkType track_link_type
Definition: GNNOptions.h:19
file
TFile * file
Definition: tile_monitor.h:29
FlavorTagDiscriminants::TrackLinkType
TrackLinkType
Definition: AssociationEnums.h:12
FlavorTagDiscriminants::GNN::Decorators::jetVecFloat
Decs< std::vector< float > > jetVecFloat
Definition: GNN.h:82
python.AtlRunQueryLib.options
options
Definition: AtlRunQueryLib.py:379
xAOD::BTagging_v1
Definition: BTagging_v1.h:39
DataVector< xAOD::TrackParticle_v1 >
FlavorTagDiscriminants::GNN::Decorators::jetTrackLinks
Decs< TrackLinks > jetTrackLinks
Definition: GNN.h:83
FlavorTagDiscriminants::dataprep::checkForUnusedRemaps
void checkForUnusedRemaps(const std::map< std::string, std::string > &requested, const std::set< std::string > &used)
Definition: DataPrepUtilities.cxx:454
PathResolver.h
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:195
GNN.h
FlavorTagDiscriminants::FTagOptions
Definition: DataPrepUtilities.h:45
library_scraper.dd
list dd
Definition: library_scraper.py:46
FlavorTagDiscriminants::str::remapName
const std::string remapName(const std::string &name, std::map< std::string, std::string > &remap, std::set< std::string > &usedRemap)
Definition: PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/StringUtils.cxx:8
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:431
xAOD::Jet_v1
Class describing a jet.
Definition: Jet_v1.h:57
FlavorTagDiscriminants::GNN::m_jetLink
SG::AuxElement::ConstAccessor< ElementLink< xAOD::JetContainer > > m_jetLink
Definition: GNN.h:92
FlavorTagDiscriminants::GNN
Definition: GNN.h:40
FlavorTagDiscriminants::FTagDataDependencyNames::trackInputs
std::set< std::string > trackInputs
Definition: FTagDataDependencyNames.h:13
FlavorTagDiscriminants::GNN::decorateWithDefaults
virtual void decorateWithDefaults(const SG::AuxElement &jet) const
Definition: GNN.cxx:109
JetContainer.h
remap
std::map< std::string, std::string > remap
list of directories to be explicitly remapped
Definition: hcg.cxx:92
CSV_InDetExporter.old
old
Definition: CSV_InDetExporter.py:145
FlavorTagDiscriminants::dataprep::createBvarGetters
std::tuple< std::vector< internal::VarFromBTag >, std::vector< internal::VarFromJet >, FTagDataDependencyNames > createBvarGetters(const std::vector< FTagInputConfig > &inputs)
Definition: DataPrepUtilities.cxx:349
FlavorTagDiscriminants::GNN::~GNN
virtual ~GNN()
FlavorTagDiscriminants::Tracks
std::vector< const xAOD::TrackParticle * > Tracks
Definition: TracksLoader.h:36
xAOD::TrackParticle_v1
Class describing a TrackParticle.
Definition: TrackParticle_v1.h:43
FlavorTagDiscriminants::GNN::m_defaultValue
float m_defaultValue
Definition: GNN.h:99
FlavorTagDiscriminants::FlipTagConfig
FlipTagConfig
Definition: FlipTagEnums.h:14
FlavorTagDiscriminants::OnnxModelVersion::V0
@ V0
FlavorTagDiscriminants::ConstituentsType::IPARTICLE
@ IPARTICLE
FlavorTagDiscriminants::dataprep::getNameFlippers
StringRegexes getNameFlippers(const FlipTagConfig &flip_config)
Definition: DataPrepUtilities.cxx:197
FlavorTagDiscriminants::GNN::m_decorators
Decorators m_decorators
Definition: GNN.h:98
FlavorTagDiscriminants::GNN::Decorators
Definition: GNN.h:79
util
Definition: Reconstruction/MVAUtils/util/__init__.py:1
FlavorTagDiscriminants::OnnxModelVersion::V1
@ V1
FlavorTagDiscriminants::GNN::Decorators::jetVecChar
Decs< std::vector< char > > jetVecChar
Definition: GNN.h:81
StringUtils.h
FlavorTagDiscriminants::GNN::createDecorators
std::tuple< FTagDataDependencyNames, std::set< std::string > > createDecorators(const OnnxUtil::OutputConfig &outConfig, const FTagOptions &options)
Definition: GNN.cxx:231
OnnxUtil.h
FlavorTagDiscriminants::GNN::decorate
virtual void decorate(const xAOD::BTagging &btag) const
Definition: GNN.cxx:94