ATLAS Offline Software
HbbTag.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
3 */
7 
8 // EDM includes
9 #include "xAODJet/Jet.h"
10 #include "xAODBTagging/BTagging.h"
12 // atlas utilities
14 
15 // external libraries
16 #include "lwtnn/LightweightGraph.hh"
17 #include "lwtnn/parse_json.hh"
18 
19 // c++ core
20 #include <set>
21 #include <fstream>
22 #include <filesystem>
23 
24 
25 namespace {
26 
27  template <typename T>
28  class BTagPairGetter
29  {
30  public:
31  BTagPairGetter(const std::string& key);
32  std::pair<std::string, double> operator()(const xAOD::Jet& jet);
33  private:
34  std::string m_key;
36  };
37 
38  using Pg = std::function<std::pair<std::string, double>(const xAOD::Jet&)>;
39  Pg makePairGetter(const std::string& key);
40 
41  void requireOverwrite(std::map<std::string, double>& target,
42  const std::pair<std::string, double>& value);
43 
44 
45 }
46 
47 namespace FlavorTagDiscriminants {
48 
50  m_parent_link("Parent"),
51  m_subjet_link_getter(config.subjet_link_name),
52  m_n_subjets(0),
53  m_min_subjet_pt(config.min_subjet_pt)
54  {
55  namespace fs = std::filesystem;
56  // setup NN
57  fs::path nn_path = config.input_file_path;
58  if (!fs::exists(nn_path)) {
59  nn_path = PathResolverFindCalibFile(nn_path.string());
60  if (nn_path.empty()) {
61  throw std::runtime_error(
62  "no file found at '" + config.input_file_path.string() + "'");
63  }
64  }
65  std::ifstream input_stream(nn_path.string());
66  lwt::GraphConfig graph_cfg = lwt::parse_json_graph(input_stream);
67  m_graph.reset(new lwt::LightweightGraph(graph_cfg));
68 
69  // setup large-R jet getters and defaults
71  // add the getters
72  for (const std::string& key: keys.fatjet) {
73  m_fat_jet_getters.push_back(makePairGetter(key));
74  }
75  for (const std::string& key: keys.subjet) {
76  m_subjet_getters.push_back(makePairGetter(key));
77  }
78  m_defaults = keys.defaults;
79  m_n_subjets = keys.n_subjets;
80 
81  // setup outputs
82  for (const auto& output: graph_cfg.outputs) {
83  const std::string& node_name = output.first;
84  const lwt::OutputNodeConfig& node = output.second;
85  NodeWriter node_writer;
86  for (const std::string& varname: node.labels) {
87  std::string write_name = node_name + "_" + varname;
88  node_writer.emplace_back(varname, write_name);
89  }
90  m_outputs.emplace_back(node_name, node_writer);
91  }
92  }
93  HbbTag::HbbTag(HbbTag&&) = default;
95 
96  void HbbTag::decorate(const xAOD::Jet& jet) const {
97  namespace hk = hbb_key;
98  std::map<std::string, std::map<std::string, double>> inputs = m_defaults;
99  for (const auto& getter: m_fat_jet_getters) {
100  requireOverwrite(inputs.at(hk::fatjet),getter(jet));
101  }
102 
103  std::vector<const xAOD::IParticle*> subjets;
104  const xAOD::Jet* parent = *m_parent_link(jet);
105  if (!parent) throw std::runtime_error("can't resolve parent jet");
106  for (const auto& link: m_subjet_link_getter(*parent)) {
107  const xAOD::IParticle* subjet = *link;
108  if (!subjet) throw std::runtime_error("can't resolve subjet link");
109  if (subjet->pt() >= m_min_subjet_pt) {
110  subjets.push_back(subjet);
111  }
112  }
113  std::sort(subjets.begin(), subjets.end(),
114  [](auto* a, auto* b) { return a->pt() > b->pt(); });
115 
116  size_t n_jets = std::min(subjets.size(), m_n_subjets);
117  for (size_t jet_n = 0; jet_n < n_jets; jet_n++) {
118  const auto* subjet = dynamic_cast<const xAOD::Jet*>(subjets.at(jet_n));
119  if (!subjet) throw std::runtime_error("IParticle is not a Jet");
120  std::string subjet_name = hk::subjet + std::to_string(jet_n);
121  for (const auto& getter: m_subjet_getters) {
122  requireOverwrite(inputs.at(subjet_name),getter(*subjet));
123  }
124  }
125 
126  // calculate and write
127  for (const auto& node: m_outputs) {
128  const auto& result = m_graph->compute(inputs, {}, node.first);
129  for (const auto& var_writer: node.second) {
130  var_writer.second(jet) = result.at(var_writer.first);
131  }
132  }
133 
134  }
135 
136 }
137 
138 
139 namespace {
140  // implemenation
141  template <typename T>
142  BTagPairGetter<T>::BTagPairGetter(const std::string& key):
143  m_key(key), m_getter(key)
144  {
145  }
146  template <typename T>
147  std::pair<std::string, double>
148  BTagPairGetter<T>::operator()(const xAOD::Jet& jet) {
150  if (!btag) throw std::runtime_error("can't find btagging object");
151  return {m_key, m_getter(*btag)};
152  }
153 
154 
155  Pg makePairGetter(const std::string& key) {
156  namespace hk = FlavorTagDiscriminants::hbb_key;
157  if (key == hk::pt) {
158  return [](const xAOD::Jet& j) -> std::pair<std::string, double> {
159  return {hk::pt, j.pt()};
160  };
161  } else if (key == hk::eta) {
162  return [](const xAOD::Jet& j) -> std::pair<std::string, double> {
163  return {hk::eta, j.eta()};
164  };
165  } else {
166  // for now we assume everything we read from b-tagging is a float,
167  // this is only true for DL1 scores.
168  return BTagPairGetter<float>(key);
169  }
170  }
171 
// Overwrite the entry of `target` named `value.first` with `value.second`.
//
// Every network input must already have a default value before a getter may
// fill it. A name that is not yet present is therefore treated as a
// configuration error: this throws std::logic_error instead of silently
// inserting a new entry into the input map.
void requireOverwrite(std::map<std::string, double>& target,
                      const std::pair<std::string, double>& value) {
  const auto itr = target.find(value.first);
  if (itr == target.end()) {
    throw std::logic_error("can't find a default value for " + value.first);
  }
  itr->second = value.second;
}
180 
181 }
BTaggingUtilities.h
FlavorTagDiscriminants::HbbTag::decorate
void decorate(const xAOD::Jet &jet) const
Definition: HbbTag.cxx:96
Jet.h
FlavorTagDiscriminants::hbb_key::subjet
const std::string subjet
Definition: HbbConstants.h:18
FlavorTagDiscriminants::HbbTag::m_defaults
std::map< std::string, std::map< std::string, double > > m_defaults
Definition: HbbTag.h:44
get_generator_info.result
result
Definition: get_generator_info.py:21
athena.path
path
python interpreter configuration --------------------------------------—
Definition: athena.py:126
FlavorTagDiscriminants::HbbTag::m_outputs
std::vector< std::pair< std::string, NodeWriter > > m_outputs
Definition: HbbTag.h:50
BTagging.h
FlavorTagDiscriminants
This file contains "getter" functions used for accessing tagger inputs from the EDM.
Definition: AssociationEnums.h:11
eta
Scalar eta() const
pseudorapidity method
Definition: AmgMatrixBasePlugin.h:79
FlavorTagDiscriminants::HbbTag::m_graph
std::unique_ptr< lwt::LightweightGraph > m_graph
Definition: HbbTag.h:43
FlavorTagDiscriminants::HbbTag::m_fat_jet_getters
std::vector< Pg > m_fat_jet_getters
Definition: HbbTag.h:37
FlavorTagDiscriminants::HbbTag::NodeWriter
std::vector< std::pair< std::string, Decorator< float > >> NodeWriter
Definition: HbbTag.h:49
test_pyathena.pt
pt
Definition: test_pyathena.py:11
athena.value
value
Definition: athena.py:122
SG::ConstAccessor
Helper class to provide constant type-safe access to aux data.
Definition: ConstAccessor.h:54
xAOD::IParticle
Class providing the definition of the 4-vector interface.
Definition: Event/xAOD/xAODBase/xAODBase/IParticle.h:40
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
FlavorTagDiscriminants::HbbGraphConfig
Definition: HbbGraphConfig.h:23
jet
Definition: JetCalibTools_PlotJESFactors.cxx:23
FlavorTagDiscriminants::HbbTag::m_n_subjets
size_t m_n_subjets
Definition: HbbTag.h:39
FlavorTagDiscriminants::HbbTag::~HbbTag
~HbbTag()
Definition: HbbTag.cxx:94
test_pyathena.parent
parent
Definition: test_pyathena.py:15
FlavorTagDiscriminants::HbbTag::m_parent_link
SG::AuxElement::ConstAccessor< JetLink > m_parent_link
Definition: HbbTag.h:33
xAOD::BTagging_v1
Definition: BTagging_v1.h:39
FlavorTagDiscriminants::hbb_key
Definition: HbbConstants.h:14
min
#define min(a, b)
Definition: cfImp.cxx:40
merge.output
output
Definition: merge.py:17
PathResolver.h
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
plotBeamSpotMon.b
b
Definition: plotBeamSpotMon.py:77
FlavorTagDiscriminants::getHbbGraphConfig
HbbGraphConfig getHbbGraphConfig(const lwt::GraphConfig &cfg)
Definition: HbbGraphConfig.cxx:29
FlavorTagDiscriminants::HbbTag::HbbTag
HbbTag(const HbbTagConfig &config)
Definition: HbbTag.cxx:49
HbbConstants.h
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:431
xAOD::BTaggingUtilities::getBTagging
const BTagging * getBTagging(const SG::AuxElement &part)
Access the default xAOD::BTagging object associated to an object.
Definition: BTaggingUtilities.cxx:37
xAOD::Jet_v1
Class describing a jet.
Definition: Jet_v1.h:57
lwtDev::parse_json_graph
GraphConfig parse_json_graph(std::istream &json)
Definition: parse_json.cxx:71
FlavorTagDiscriminants::HbbTagConfig
Definition: HbbTagConfig.h:13
FlavorTagDiscriminants::hbb_key::fatjet
const std::string fatjet
Definition: HbbConstants.h:19
LArG4AODNtuplePlotter.varname
def varname(hname)
Definition: LArG4AODNtuplePlotter.py:37
a
TList * a
Definition: liststreamerinfos.cxx:10
FlavorTagDiscriminants::HbbTag::m_subjet_getters
std::vector< Pg > m_subjet_getters
Definition: HbbTag.h:38
Herwig7_QED_EvtGen_ll.fs
dictionary fs
Definition: Herwig7_QED_EvtGen_ll.py:17
FlavorTagDiscriminants::HbbTag
Definition: HbbTag.h:23
python.Bindings.keys
keys
Definition: Control/AthenaPython/python/Bindings.py:790
HbbGraphConfig.h
FlavorTagDiscriminants::HbbTag::m_subjet_link_getter
SG::AuxElement::ConstAccessor< PartLink > m_subjet_link_getter
Definition: HbbTag.h:35
COOLRates.target
target
Definition: COOLRates.py:1106
python.dummyaccess.exists
def exists(filename)
Definition: dummyaccess.py:9
FlavorTagDiscriminants::HbbTag::m_min_subjet_pt
double m_min_subjet_pt
Definition: HbbTag.h:40
HbbTag.h
node
Definition: memory_hooks-stdcmalloc.h:74
mapkey::key
key
Definition: TElectronEfficiencyCorrectionTool.cxx:37