ATLAS Offline Software
Classes | Public Member Functions | Public Attributes | Private Types | Private Member Functions | Private Attributes | List of all members
FlavorTagDiscriminants::GNN Class Reference

#include <GNN.h>

Collaboration diagram for FlavorTagDiscriminants::GNN:

Classes

struct  Decorators
 

Public Member Functions

 GNN (const std::string &nnFile, const GNNOptions &opts)
 
 GNN (const GNN &, const GNNOptions &opts)
 
 GNN (const std::string &nnFile, const FlipTagConfig &flip_config=FlipTagConfig::STANDARD, const std::map< std::string, std::string > &variableRemapping={}, const TrackLinkType trackLinkType=TrackLinkType::TRACK_PARTICLE, float defaultOutputValue=NAN)
 
 GNN (GNN &&)
 
 GNN (const GNN &)
 
virtual ~GNN ()
 
virtual void decorate (const xAOD::BTagging &btag) const
 
virtual void decorate (const xAOD::Jet &jet) const
 
virtual void decorateWithDefaults (const SG::AuxElement &jet) const
 
void decorate (const xAOD::Jet &jet, const SG::AuxElement &decorated) const
 
virtual std::set< std::string > getDecoratorKeys () const
 
virtual std::set< std::string > getAuxInputKeys () const
 
virtual std::set< std::string > getConstituentAuxInputKeys () const
 

Public Attributes

std::shared_ptr< const OnnxUtil > m_onnxUtil
 

Private Types

using TPC = xAOD::TrackParticleContainer
 
using TrackLinks = std::vector< ElementLink< TPC > >
 
template<typename T >
using Dec = SG::AuxElement::Decorator< T >
 
template<typename T >
using Decs = std::vector< std::pair< std::string, Dec< T > >>
 

Private Member Functions

 GNN (std::shared_ptr< const OnnxUtil >, const GNNOptions &opts)
 
std::tuple< FTagDataDependencyNames, std::set< std::string > > createDecorators (const OnnxUtil::OutputConfig &outConfig, const FTagOptions &options)
 

Private Attributes

SG::AuxElement::ConstAccessor< ElementLink< xAOD::JetContainer > > m_jetLink
 
std::string m_input_node_name
 
std::vector< internal::VarFromBTag > m_varsFromBTag
 
std::vector< internal::VarFromJet > m_varsFromJet
 
std::vector< std::shared_ptr< IConstituentsLoader > > m_constituentsLoaders
 
Decorators m_decorators
 
float m_defaultValue
 
FTagDataDependencyNames m_dataDependencyNames
 

Detailed Description

Definition at line 39 of file GNN.h.

Member Typedef Documentation

◆ Dec

template<typename T >
using FlavorTagDiscriminants::GNN::Dec = SG::AuxElement::Decorator<T>
private

Definition at line 74 of file GNN.h.

◆ Decs

template<typename T >
using FlavorTagDiscriminants::GNN::Decs = std::vector<std::pair<std::string, Dec<T> >>
private

Definition at line 77 of file GNN.h.

◆ TPC

Definition at line 70 of file GNN.h.

◆ TrackLinks

Definition at line 71 of file GNN.h.

Constructor & Destructor Documentation

◆ GNN() [1/6]

FlavorTagDiscriminants::GNN::GNN ( const std::string &  nnFile,
const GNNOptions &  opts 
)

Definition at line 28 of file GNN.cxx.

28  :
29  GNN(getOnnxUtil(nn_file), o)
30  {
31  }

◆ GNN() [2/6]

FlavorTagDiscriminants::GNN::GNN ( const GNN &  old,
const GNNOptions &  opts 
)

Definition at line 33 of file GNN.cxx.

33  :
34  GNN(old.m_onnxUtil, o)
35  {
36  }

◆ GNN() [3/6]

FlavorTagDiscriminants::GNN::GNN ( const std::string &  nnFile,
const FlipTagConfig &  flip_config = FlipTagConfig::STANDARD,
const std::map< std::string, std::string > &  variableRemapping = {},
const TrackLinkType  trackLinkType = TrackLinkType::TRACK_PARTICLE,
float  defaultOutputValue = NAN 
)

Definition at line 82 of file GNN.cxx.

86  :
87  GNN( file, GNNOptions { flip, remap, link_type, def_out_val} )
88  {}

◆ GNN() [4/6]

FlavorTagDiscriminants::GNN::GNN ( GNN &&  )
default

◆ GNN() [5/6]

FlavorTagDiscriminants::GNN::GNN ( const GNN & )
default

◆ ~GNN()

FlavorTagDiscriminants::GNN::~GNN ( )
virtual default

◆ GNN() [6/6]

FlavorTagDiscriminants::GNN::GNN ( std::shared_ptr< const OnnxUtil >  util,
const GNNOptions &  opts 
)
private

Definition at line 38 of file GNN.cxx.

38  :
40  m_jetLink(jetLinkName),
41  m_defaultValue(o.default_output_value)
42  {
43 
44  // Extract metadata from the ONNX file, primarily about the model's inputs.
45  auto lwt_config = m_onnxUtil->getLwtConfig();
46 
47  // Create configuration objects for data preprocessing.
48  auto [inputs, constituents_configs, options] = dataprep::createGetterConfig(
49  lwt_config, o.flip_config, o.variable_remapping, o.track_link_type);
50  for (auto config : constituents_configs){
51  switch (config.type){
53  m_constituentsLoaders.push_back(std::make_shared<TracksLoader>(config, options));
54  break;
56  m_constituentsLoaders.push_back(std::make_shared<IParticlesLoader>(config, options));
57  break;
58  }
59  }
60 
61  // Initialize jet and b-tagging input getters.
62  auto [vb, vj, ds] = dataprep::createBvarGetters(inputs);
63  m_varsFromBTag = vb;
64  m_varsFromJet = vj;
66 
67  // Retrieve the configuration for the model outputs.
68  OnnxUtil::OutputConfig gnn_output_config = m_onnxUtil->getOutputConfig();
69 
70  // Create the output decorators.
71  auto [dd, rd] = createDecorators(gnn_output_config, options);
73 
74  // Update dependencies and used remap from the constituents loaders.
75  for (auto loader : m_constituentsLoaders){
76  m_dataDependencyNames += loader->getDependencies();
77  rd.merge(loader->getUsedRemap());
78  }
79  dataprep::checkForUnusedRemaps(options.remap_scalar, rd);
80  }

Member Function Documentation

◆ createDecorators()

std::tuple< FTagDataDependencyNames, std::set< std::string > > FlavorTagDiscriminants::GNN::createDecorators ( const OnnxUtil::OutputConfig &  outConfig,
const FTagOptions &  options 
)
private

Definition at line 231 of file GNN.cxx.

231  {
233  Decorators decs;
234 
235  std::map<std::string, std::string> remap = options.remap_scalar;
236  std::set<std::string> usedRemap;
237 
238  // get the regex to rewrite the outputs if we're using flip taggers
239  auto flip_converters = dataprep::getNameFlippers(options.flip);
240  std::string context = "building negative tag b-btagger";
241 
242  for (const auto& outNode : outConfig) {
243  // the node's output name will be used to define the decoration name
244  std::string dec_name = outNode.name;
245 
246  // modify the deco name if we're using flip taggers
247  if (options.flip != FlipTagConfig::STANDARD) {
248  dec_name = str::sub_first(flip_converters, dec_name, context);
249  }
250 
251  // remap the deco name if necessary
252  dec_name = str::remapName(dec_name, remap, usedRemap);
253 
254  // keep track of dependencies for EDM bookkeeping
255  deps.bTagOutputs.insert(dec_name);
256 
257  // Create decorators based on output type and target
258  switch (outNode.type) {
260  m_decorators.jetFloat.emplace_back(outNode.name, Dec<float>(dec_name));
261  break;
263  m_decorators.jetVecChar.emplace_back(outNode.name, Dec<std::vector<char>>(dec_name));
264  break;
266  m_decorators.jetVecFloat.emplace_back(outNode.name, Dec<std::vector<float>>(dec_name));
267  break;
268  default:
269  throw std::logic_error("Unknown output data type");
270  }
271  }
272 
273  // Create decorators for links to the input tracks
274  if (!m_decorators.jetVecChar.empty() || !m_decorators.jetVecFloat.empty()) {
275  std::string name = m_onnxUtil->getModelName() + "_TrackLinks";
276  name = str::remapName(name, remap, usedRemap);
277  deps.bTagOutputs.insert(name);
278  m_decorators.jetTrackLinks.emplace_back(name, Dec<TrackLinks>(name));
279  }
280 
281  return std::make_tuple(deps, usedRemap);
282  }

◆ decorate() [1/3]

void FlavorTagDiscriminants::GNN::decorate ( const xAOD::BTagging &  btag) const
virtual

Definition at line 94 of file GNN.cxx.

94  {
95  /* tag a b-tagging object */
96  auto jetLink = m_jetLink(btag);
97  if (!jetLink.isValid()) {
98  throw std::runtime_error("invalid jetLink");
99  }
100  const xAOD::Jet& jet = **jetLink;
101  decorate(jet, btag);
102  }

◆ decorate() [2/3]

void FlavorTagDiscriminants::GNN::decorate ( const xAOD::Jet &  jet) const
virtual

Definition at line 104 of file GNN.cxx.

104  {
105  /* tag a jet */
106  decorate(jet, jet);
107  }

◆ decorate() [3/3]

void FlavorTagDiscriminants::GNN::decorate ( const xAOD::Jet &  jet,
const SG::AuxElement &  decorated 
) const

Definition at line 128 of file GNN.cxx.

128  {
129  /* Main function for decorating a jet or b-tagging object with GNN outputs. */
130  using namespace internal;
131 
132  // prepare input
133  // -------------
134  std::map<std::string, Inputs> gnn_inputs;
135 
136  // jet level inputs
137  std::vector<float> jet_feat;
138  for (const auto& getter: m_varsFromBTag) {
139  jet_feat.push_back(getter(btag).second);
140  }
141  for (const auto& getter: m_varsFromJet) {
142  jet_feat.push_back(getter(jet).second);
143  }
144  std::vector<int64_t> jet_feat_dim = {1, static_cast<int64_t>(jet_feat.size())};
145  Inputs jet_info(jet_feat, jet_feat_dim);
146  if (m_onnxUtil->getOnnxModelVersion() == OnnxModelVersion::V2) {
147  gnn_inputs.insert({"jets", jet_info});
148  } else {
149  gnn_inputs.insert({"jet_features", jet_info});
150  }
151 
152  // constituent level inputs
153  Tracks input_tracks;
154  for (auto loader : m_constituentsLoaders){
155  auto [input_name, input_data, input_objects] = loader->getData(jet, btag);
156  if (m_onnxUtil->getOnnxModelVersion() != OnnxModelVersion::V2) {
157  input_name.pop_back();
158  input_name.append("_features");
159  }
160  gnn_inputs.insert({input_name, input_data});
161 
162  // for now we only collect tracks for aux task decoration
163  // they have to be converted back from IParticle to TrackParticle first
164  if (loader->getType() == ConstituentsType::TRACK){
165  for (auto constituent : input_objects){
166  input_tracks.push_back(dynamic_cast<const xAOD::TrackParticle*>(constituent));
167  }
168  }
169  }
170 
171  // run inference
172  // -------------
173  auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_inputs);
174 
175  // decorate outputs
176  // ----------------
177 
178  // with old metadata, doesn't support writing aux tasks
179  if (m_onnxUtil->getOnnxModelVersion() == OnnxModelVersion::V0) {
180  for (const auto& dec: m_decorators.jetFloat) {
181  if (out_vf.at(dec.first).size() != 1){
182  throw std::logic_error("expected vectors of length 1 for float decorators");
183  }
184  dec.second(btag) = out_vf.at(dec.first).at(0);
185  }
186  }
187  // the new metadata format supports writing aux tasks
188  else if (m_onnxUtil->getOnnxModelVersion() == OnnxModelVersion::V1) {
189  // float outputs, e.g. jet probabilities
190  for (const auto& dec: m_decorators.jetFloat) {
191  dec.second(btag) = out_f.at(dec.first);
192  }
193  // vector outputs, e.g. track predictions
194  for (const auto& dec: m_decorators.jetVecChar) {
195  dec.second(btag) = out_vc.at(dec.first);
196  }
197  for (const auto& dec: m_decorators.jetVecFloat) {
198  dec.second(btag) = out_vf.at(dec.first);
199  }
200 
201  // decorate links to the input tracks to the b-tagging object
202  for (const auto& dec: m_decorators.jetTrackLinks) {
204  for (const xAOD::TrackParticle* it: input_tracks) {
205  TrackLinks::value_type link;
206  const auto* itc = dynamic_cast<const xAOD::TrackParticleContainer*>(
207  it->container());
208  link.toIndexedElement(*itc, it->index());
209  links.push_back(link);
210  }
211  dec.second(btag) = links;
212  }
213  }
214  else {
215  throw std::logic_error("unsupported ONNX metadata version");
216  }
217  } // end of decorate()

◆ decorateWithDefaults()

void FlavorTagDiscriminants::GNN::decorateWithDefaults ( const SG::AuxElement &  jet) const
virtual

Definition at line 109 of file GNN.cxx.

109  {
110  for (const auto& dec: m_decorators.jetFloat) {
111  dec.second(jet) = m_defaultValue;
112  }
113  // for some networks we need to set a lot of empty vectors as well
114  if (m_onnxUtil->getOnnxModelVersion() == OnnxModelVersion::V1) {
115  // vector outputs, e.g. track predictions
116  for (const auto& dec: m_decorators.jetVecChar) {
117  dec.second(jet) = {};
118  }
119  for (const auto& dec: m_decorators.jetVecFloat) {
120  dec.second(jet) = {};
121  }
122  for (const auto& dec: m_decorators.jetTrackLinks) {
123  dec.second(jet) = {};
124  }
125  }
126  }

◆ getAuxInputKeys()

std::set< std::string > FlavorTagDiscriminants::GNN::getAuxInputKeys ( ) const
virtual

Definition at line 223 of file GNN.cxx.

223  {
225  }

◆ getConstituentAuxInputKeys()

std::set< std::string > FlavorTagDiscriminants::GNN::getConstituentAuxInputKeys ( ) const
virtual

Definition at line 226 of file GNN.cxx.

226  {
228  }

◆ getDecoratorKeys()

std::set< std::string > FlavorTagDiscriminants::GNN::getDecoratorKeys ( ) const
virtual

Definition at line 220 of file GNN.cxx.

220  {
222  }

Member Data Documentation

◆ m_constituentsLoaders

std::vector<std::shared_ptr<IConstituentsLoader> > FlavorTagDiscriminants::GNN::m_constituentsLoaders
private

Definition at line 96 of file GNN.h.

◆ m_dataDependencyNames

FTagDataDependencyNames FlavorTagDiscriminants::GNN::m_dataDependencyNames
private

Definition at line 100 of file GNN.h.

◆ m_decorators

Decorators FlavorTagDiscriminants::GNN::m_decorators
private

Definition at line 98 of file GNN.h.

◆ m_defaultValue

float FlavorTagDiscriminants::GNN::m_defaultValue
private

Definition at line 99 of file GNN.h.

◆ m_input_node_name

std::string FlavorTagDiscriminants::GNN::m_input_node_name
private

Definition at line 93 of file GNN.h.

◆ m_jetLink

SG::AuxElement::ConstAccessor<ElementLink<xAOD::JetContainer> > FlavorTagDiscriminants::GNN::m_jetLink
private

Definition at line 92 of file GNN.h.

◆ m_onnxUtil

std::shared_ptr<const OnnxUtil> FlavorTagDiscriminants::GNN::m_onnxUtil

Definition at line 65 of file GNN.h.

◆ m_varsFromBTag

std::vector<internal::VarFromBTag> FlavorTagDiscriminants::GNN::m_varsFromBTag
private

Definition at line 94 of file GNN.h.

◆ m_varsFromJet

std::vector<internal::VarFromJet> FlavorTagDiscriminants::GNN::m_varsFromJet
private

Definition at line 95 of file GNN.h.


The documentation for this class was generated from the following files:
python.SystemOfUnits.second
int second
Definition: SystemOfUnits.py:120
FlavorTagDiscriminants::OnnxModelVersion::V2
@ V2
FlavorTagDiscriminants::dataprep::createGetterConfig
std::tuple< std::vector< FTagInputConfig >, std::vector< ConstituentsInputConfig >, FTagOptions > createGetterConfig(lwt::GraphConfig &graph_config, FlipTagConfig flip_config, std::map< std::string, std::string > remap_scalar, TrackLinkType track_link_type)
Definition: DataPrepUtilities.cxx:235
checkxAOD.ds
ds
Definition: Tools/PyUtils/bin/checkxAOD.py:257
FlavorTagDiscriminants::OnnxOutput::OutputType::VECCHAR
@ VECCHAR
FlavorTagDiscriminants::GNN::TrackLinks
std::vector< ElementLink< TPC > > TrackLinks
Definition: GNN.h:71
FlavorTagDiscriminants::GNN::GNN
GNN(const std::string &nnFile, const GNNOptions &opts)
Definition: GNN.cxx:28
FlavorTagDiscriminants::GNN::Decorators::jetFloat
Decs< float > jetFloat
Definition: GNN.h:80
FlavorTagDiscriminants::GNN::m_constituentsLoaders
std::vector< std::shared_ptr< IConstituentsLoader > > m_constituentsLoaders
Definition: GNN.h:96
FlavorTagDiscriminants::FlipTagConfig::STANDARD
@ STANDARD
FlavorTagDiscriminants::Inputs
std::pair< std::vector< float >, std::vector< int64_t > > Inputs
Definition: FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxUtil.h:28
FlavorTagDiscriminants::OnnxOutput::OutputType::FLOAT
@ FLOAT
FlavorTagDiscriminants::GNNOptions
Definition: GNNOptions.h:16
FlavorTagDiscriminants::GNN::m_onnxUtil
std::shared_ptr< const OnnxUtil > m_onnxUtil
Definition: GNN.h:65
FlavorTagDiscriminants::GNN::m_varsFromJet
std::vector< internal::VarFromJet > m_varsFromJet
Definition: GNN.h:95
skel.it
it
Definition: skel.GENtoEVGEN.py:423
FlavorTagDiscriminants::OnnxOutput::OutputType::VECFLOAT
@ VECFLOAT
FlavorTagDiscriminants::ConstituentsType::TRACK
@ TRACK
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
FlavorTagDiscriminants::GNN::m_dataDependencyNames
FTagDataDependencyNames m_dataDependencyNames
Definition: GNN.h:100
FlavorTagDiscriminants::FTagDataDependencyNames::bTagInputs
std::set< std::string > bTagInputs
Definition: FTagDataDependencyNames.h:14
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
FlavorTagDiscriminants::GNN::m_varsFromBTag
std::vector< internal::VarFromBTag > m_varsFromBTag
Definition: GNN.h:94
jet
Definition: JetCalibTools_PlotJESFactors.cxx:23
FlavorTagDiscriminants::FTagDataDependencyNames
Definition: FTagDataDependencyNames.h:12
FlavorTagDiscriminants::FTagDataDependencyNames::bTagOutputs
std::set< std::string > bTagOutputs
Definition: FTagDataDependencyNames.h:15
FlavorTagDiscriminants::str::sub_first
std::string sub_first(const StringRegexes &res, const std::string &var_name, const std::string &context)
Definition: PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/StringUtils.cxx:19
DMTest::links
links
Definition: CLinks_v1.cxx:22
file
TFile * file
Definition: tile_monitor.h:29
FlavorTagDiscriminants::GNN::Decorators::jetVecFloat
Decs< std::vector< float > > jetVecFloat
Definition: GNN.h:82
python.AtlRunQueryLib.options
options
Definition: AtlRunQueryLib.py:379
DataVector< xAOD::TrackParticle_v1 >
FlavorTagDiscriminants::GNN::Decorators::jetTrackLinks
Decs< TrackLinks > jetTrackLinks
Definition: GNN.h:83
FlavorTagDiscriminants::dataprep::checkForUnusedRemaps
void checkForUnusedRemaps(const std::map< std::string, std::string > &requested, const std::set< std::string > &used)
Definition: DataPrepUtilities.cxx:454
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:192
library_scraper.dd
list dd
Definition: library_scraper.py:46
FlavorTagDiscriminants::str::remapName
const std::string remapName(const std::string &name, std::map< std::string, std::string > &remap, std::set< std::string > &usedRemap)
Definition: PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/StringUtils.cxx:8
xAOD::Jet_v1
Class describing a jet.
Definition: Jet_v1.h:57
FlavorTagDiscriminants::GNN::m_jetLink
SG::AuxElement::ConstAccessor< ElementLink< xAOD::JetContainer > > m_jetLink
Definition: GNN.h:92
FlavorTagDiscriminants::FTagDataDependencyNames::trackInputs
std::set< std::string > trackInputs
Definition: FTagDataDependencyNames.h:13
remap
std::map< std::string, std::string > remap
list of directories to be explicitly remapped
Definition: hcg.cxx:92
CSV_InDetExporter.old
old
Definition: CSV_InDetExporter.py:145
FlavorTagDiscriminants::dataprep::createBvarGetters
std::tuple< std::vector< internal::VarFromBTag >, std::vector< internal::VarFromJet >, FTagDataDependencyNames > createBvarGetters(const std::vector< FTagInputConfig > &inputs)
Definition: DataPrepUtilities.cxx:349
FlavorTagDiscriminants::GNN::Dec
SG::AuxElement::Decorator< T > Dec
Definition: GNN.h:74
FlavorTagDiscriminants::Tracks
std::vector< const xAOD::TrackParticle * > Tracks
Definition: TracksLoader.h:36
xAOD::TrackParticle_v1
Class describing a TrackParticle.
Definition: TrackParticle_v1.h:43
FlavorTagDiscriminants::GNN::m_defaultValue
float m_defaultValue
Definition: GNN.h:99
FlavorTagDiscriminants::OnnxModelVersion::V0
@ V0
FlavorTagDiscriminants::ConstituentsType::IPARTICLE
@ IPARTICLE
FlavorTagDiscriminants::dataprep::getNameFlippers
StringRegexes getNameFlippers(const FlipTagConfig &flip_config)
Definition: DataPrepUtilities.cxx:197
FlavorTagDiscriminants::GNN::m_decorators
Decorators m_decorators
Definition: GNN.h:98
util
Definition: Reconstruction/MVAUtils/util/__init__.py:1
FlavorTagDiscriminants::OnnxModelVersion::V1
@ V1
FlavorTagDiscriminants::GNN::Decorators::jetVecChar
Decs< std::vector< char > > jetVecChar
Definition: GNN.h:81
FlavorTagDiscriminants::GNN::createDecorators
std::tuple< FTagDataDependencyNames, std::set< std::string > > createDecorators(const OnnxUtil::OutputConfig &outConfig, const FTagOptions &options)
Definition: GNN.cxx:231
FlavorTagDiscriminants::GNN::decorate
virtual void decorate(const xAOD::BTagging &btag) const
Definition: GNN.cxx:94