ATLAS Offline Software
IParticleWriter.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
6 
7 #include "PrimitiveHelpers.h"
8 
9 #include "xAODBase/IParticle.h"
10 #include "HDF5Utils/Writer.h"
11 
12 // needed concrete classes for ElementLink access
14 
15 namespace {
// NOTE(review): the declaration this template header belongs to
// (original line 17, presumably the `Writer_t` alias used below as
// Writer_t<0> / Writer_t<1>) is missing from this view — confirm
// against the repository.
// In_t / Consumer_t: per-particle input type and consumer type, taken
// from Writer_t<1> and used by all the IParticle writers below.
16  template <unsigned int N>
18  using In_t = Writer_t<1>::input_type;
19  using Consumer_t = Writer_t<1>::consumer_type;
// Writer for the per-fill entry counts of awkward arrays. The count type
// is unsigned char, which is why IParticleAwkwardWriter::fill rejects
// more than std::numeric_limits<unsigned char>::max() entries per fill.
20  using CountWriter_t = H5Utils::Writer<0, unsigned char>;
21 
// Register one "custom" primitive on consumer `c`. Custom primitives are
// computed from the IParticle interface (pt, eta, phi, px, py, pz, mass,
// plus "GeV"-suffixed variants of pt/px/py/pz/mass) or are the flag
// "valid", rather than being read directly from aux data.
//
// `a` maps each input particle to the object the value is computed from
// (presumably the particle itself for detail::defaultAccessor; a
// LinkGetter follows an ElementLink). If `a` returns nullptr the stored
// float is NAN.
//
// Compression: under Primitive::Type::PRECISION_CUSTOM, `half` is meant
// to be full precision; otherwise it is half precision (see the comment
// inside matchGeV).
//
// Throws std::logic_error if `p` is not custom, if a non-GeV name is
// requested for a GeV-capable variable while `half` is STANDARD, or if
// `p.source` is not a known custom name.
//
// NOTE(review): original lines 33-34 (rest of the `half` initializer)
// and 75 (the statement in matchGeV's exact-match branch) are missing
// from this view; confirm against the repository.
22  template <typename A=detail::defaultAccessor_t<Consumer_t>>
23  void addCustomType(Consumer_t& c,
24  const Primitive& p,
25  A a=detail::defaultAccessor<Consumer_t>) {
26  using I = In_t;
27 
28  if (!detail::isCustom(p)) {
29  throw std::logic_error("called addCustomType on non-custom type");
30  }
31  bool force_precision = p.type == Primitive::Type::PRECISION_CUSTOM;
32  const H5Utils::Compression half = force_precision ?
35 
36  const std::string s = p.source;
37  const std::string t = p.target;
38 
39  // check for match
40  bool m = false;
41 
// Register a float output named `t`. `mult` rescales the value (0.001
// for the "GeV" variants, i.e. MeV -> GeV); NAN is the default value
// passed to the consumer and is also returned when `a` gives nullptr.
42  auto add = [&c, a, t](auto f, auto compression, float mult = 1.0) {
43  c.add<float>(
44  t,
45  [a, f, mult](I in) -> float {
46  const auto* associated = a(in);
47  if (!associated) return NAN;
48  return f(associated)*mult;
49  },
50  NAN, compression);
51  };
52 
53  // these do most of the matching work
54  auto match = [&add, s, &m, half](const std::string& n, auto func) {
55  if (s == n) {
56  add(func, half);
57  m = true;
58  }
59  };
60  // This matches two possible strings: the original string and one
61  // with a "GeV" suffix. The purpose of the suffixed version is to
62  // allow us to store things at half precision (since most things
63  // stored in MeV would overflow half precision floats).
64  auto matchGeV = [&add, s, &m, half](const std::string& n, auto func) {
65  if (s == n) {
66  // Sort of convoluted logic here: if we force higher precision
67  // we set "half" to be full precision. But if someone asks for
68  // full precision in a variable that must be stored that way
69  // anyway, we throw an exception.
70  if (half == H5Utils::Compression::STANDARD) {
71  throw std::logic_error(
72  "asked for a full precision version of a variable that can"
73  " not be stored at half precision: " + s);
74  }
76  m = true;
77  } else if (s == n + "GeV") {
78  add(func, half, 0.001);
79  m = true;
80  }
81  };
82 
83  // match cases
84  matchGeV("pt", [](auto in) {return in->pt(); });
85  match("eta", [](auto in) {return in->eta(); });
86  match("phi", [](auto in) {return in->phi(); });
87  matchGeV("px", [](auto in) {return in->p4().Px(); });
88  matchGeV("py", [](auto in) {return in->p4().Py(); });
89  matchGeV("pz", [](auto in) {return in->p4().Pz(); });
90  matchGeV("mass", [](auto in) {return in->m(); });
91 
92  if (m) return;
93 
// NOTE(review): unlike every other custom output (named `t` via `add`),
// "valid" is registered under the source name `s`. Its lambda also never
// calls `a`, so it is `true` for every written entry even when `a` would
// return nullptr (e.g. an unset link). `false` sits in the same argument
// position as the NAN default used for floats. Confirm whether `t` and
// `a(in) != nullptr` were intended.
94  if (s == "valid") {
95  c.add<bool>(s, [](I) {return true; }, false);
96  return;
97  }
// NOTE(review): message typo — "unknow known" presumably should read
// "unknown"; left unchanged here because it is runtime text.
98  throw std::logic_error("unknow known custom primitive: " + s);
99  }
100 
101 
102  // stuff to retrieve associated links
103  template <typename T, typename R=SG::AuxElement>
104  class LinkGetter
105  {
106  public:
107  LinkGetter(std::string name);
108  const R* operator()(In_t in) const;
109  private:
110  using LinkAccessor = SG::AuxElement::ConstAccessor<ElementLink<T>>;
111  const LinkAccessor m_accessor;
112  const std::string m_linkName;
113  };
114  template <typename T, typename R>
115  LinkGetter<T,R>::LinkGetter(std::string name):
116  m_accessor(name),
117  m_linkName(name)
118  {}
119  template <typename T, typename R>
120  const R* LinkGetter<T,R>::operator()(In_t in) const {
121  auto elink = m_accessor(*in);
122  // isDefault should generally indicate that the element link has
123  // been created but not set to point to any specific place. This
124  // is distinct from being created and pointing to a particle that
125  // has been thinned or slimmed away. So it's safe to use as a "not
126  // set" code.
127  if (elink.isDefault()) {
128  return nullptr;
129  }
130  // If the link is _not_ default, but is also invalid, then we're
131  // trying to access something that has been removed. We treat this
132  // as an error, because the behavior becomes dependent on the
133  // format.
134  if (!elink.isValid()) {
135  throw std::runtime_error("invalid link " + m_linkName);
136  }
137  return *elink;
138  }
139 
140  CountWriter_t::consumer_type getOffsetConsumer()
141  {
142  using CountIn_t = CountWriter_t::input_type;
143  CountWriter_t::consumer_type c;
144  c.add<CountIn_t>("count", [](CountIn_t i) { return i; });
145  return c;
146  }
147 
148 }
149 
150 
151 // IParticleWriter implementations
// Abstract interface implemented by the format-specific writers defined
// below (padded 2d, awkward, flat). IParticleWriter owns one through the
// std::unique_ptr returned by getWriter().
// NOTE(review): the class-head line (original line 153, presumably
// `class IParticleWriterBase`) is missing from this view.
152 namespace details {
154  {
155  public:
156  virtual ~IParticleWriterBase() = default;
// Write one batch of particles (one fill() call's worth).
157  virtual void fill(const std::vector<const xAOD::IParticle*>& info) = 0;
// Flush the underlying H5Utils writer(s).
158  virtual void flush() = 0;
159  };
160 }
161 namespace {
162  // implementation for writer for 2d arrays
163  class IParticle2dWriter: public details::IParticleWriterBase
164  {
165  private:
166  Writer_t<1> m_writer;
167  public:
168  IParticle2dWriter(H5::Group& group,
169  const std::string& n,
170  Consumer_t c,
171  long long unsigned size):
172  m_writer(group, n, c, {{size}}) {}
173  ~IParticle2dWriter() = default;
174 
175  void fill(const std::vector<In_t>& v) override {
176  m_writer.fill(v);
177  }
178  void flush() override {
179  m_writer.flush();
180  }
181  };
182 
183  // implementation for writer for awkward arrays
184  class IParticleAwkwardWriter: public details::IParticleWriterBase
185  {
186  private:
187  H5::Group m_group;
188  Writer_t<0> m_writer;
189  CountWriter_t m_counts;
190  public:
191  IParticleAwkwardWriter(H5::Group& parent,
192  const std::string& n,
193  Consumer_t c):
194  m_group(parent.createGroup(n)),
195  m_writer(m_group, "raw", c),
196  m_counts(m_group, "counts", getOffsetConsumer())
197  {}
198  ~IParticleAwkwardWriter() = default;
199 
200  void fill(const std::vector<In_t>& v) override {
201  using Count_t = CountWriter_t::input_type;
202  constexpr auto max = std::numeric_limits<Count_t>::max();
203  auto n_entries = v.size();
204  if (n_entries > max) {
205  throw std::overflow_error(
206  "number of entries exceeds maximum for this datatype "
207  "[" + std::to_string(n_entries) + " > " + std::to_string(max) + "]"
208  );
209  }
210  for (const auto& e: v) m_writer.fill(e);
211  m_counts.fill(n_entries);
212  }
213  void flush() override {
214  m_writer.flush();
215  m_counts.flush();
216  }
217  };
218 
219 
220  // implementation for writer for awkward arrays
221  class IParticleFlatWriter: public details::IParticleWriterBase
222  {
223  private:
224  H5::Group m_group;
225  Writer_t<0> m_writer;
226  public:
227  IParticleFlatWriter(H5::Group& parent,
228  const std::string& n,
229  Consumer_t c):
230  m_writer(parent, n, c)
231  {}
232  ~IParticleFlatWriter() = default;
233 
234  void fill(const std::vector<In_t>& v) override {
235  for (const auto& e: v) m_writer.fill(e);
236  }
237  void flush() override {
238  m_writer.flush();
239  }
240  };
241 
// Build the writer implementation selected by cfg.format:
//  - AWKWARD: IParticleAwkwardWriter (cfg.maximum_size must be 0)
//  - FLAT:    IParticleFlatWriter    (cfg.maximum_size must be 0)
//  - PADDED:  IParticle2dWriter sized by cfg.maximum_size
//             (cfg.maximum_size must be nonzero)
// Throws std::domain_error for an inconsistent maximum_size or an
// unrecognised format.
// NOTE(review): `af` is introduced on original line 246 (presumably
// `using af = IParticleWriterConfig::ArrayFormat;`), which is missing
// from this view.
242  std::unique_ptr<details::IParticleWriterBase> getWriter(
243  H5::Group& g,
244  const IParticleWriterConfig& cfg,
245  Consumer_t c) {
247  switch (cfg.format) {
248  case af::AWKWARD: {
249  if (cfg.maximum_size != 0) {
250  throw std::domain_error(
251  "no maximum_size should be specified for awkward arrays");
252  }
253  return std::make_unique<IParticleAwkwardWriter>(g, cfg.name, c);
254  }
255  case af::FLAT: {
256  if (cfg.maximum_size != 0) {
257  throw std::domain_error(
258  "no maximum_size should be specified for flat arrays");
259  }
260  return std::make_unique<IParticleFlatWriter>(g, cfg.name, c);
261  }
262  case af::PADDED: {
263  if (cfg.maximum_size == 0) {
264  throw std::domain_error(
265  "maximum_size should be specified for padded 2d arrays");
266  }
267  return std::make_unique<IParticle2dWriter>(
268  g, cfg.name, c, cfg.maximum_size);
269  }
270  default: {
271  throw std::domain_error("unknown array format");
272  }
273  }
274  }
275 }
276 
277 
// NOTE(review): the first line of this definition (original line 278,
// presumably `IParticleWriter::IParticleWriter(`) is missing from this view.
//
// Build a consumer from cfg.inputs, then create the format-specific writer
// under `group` via getWriter(). Each input is read either from the
// particle itself (empty link_name) or from the object reached through the
// named ElementLink. Custom primitives go through addCustomType; all
// others go through detail::addInput.
279  H5::Group& group,
280  const IParticleWriterConfig& cfg)
281 {
// NOTE(review): `input_type` is not referenced anywhere below.
282  using input_type = In_t;
283  using IPC = xAOD::IParticleContainer;
284  using IP = xAOD::IParticle;
285  Consumer_t c;
286  for (const auto& input: cfg.inputs) {
287  if (input.link_name.empty()) {
288  const auto& primitive = input.input;
289  if (detail::isCustom(primitive.type)) {
290  addCustomType(c, primitive);
291  } else {
292  detail::addInput(c, primitive);
293  }
294  } else {
295  // else we have some association to follow
296  std::string n = input.link_name;
297  // unfortunately we need a special case for b-tagging
// NOTE(review): this branch always uses detail::addInput, even for custom
// primitives, unlike the IParticle branch below; confirm custom types are
// never expected through "btaggingLink".
298  if (n == "btaggingLink") {
299  LinkGetter<xAOD::BTaggingContainer> getter(n);
300  detail::addInput(c, input.input, getter);
301  } else {
302  // everything else is an IParticle
303  LinkGetter<IPC,IP> getter(n);
304  if (detail::isCustom(input.input.type)) {
305  addCustomType(c, input.input, getter);
306  } else {
307  detail::addInput(c, input.input, getter);
308  }
309  }
310  }
311  }
312  m_writer = getWriter(group, cfg, c);
313 }
314 
315 
317 
318 
319 void IParticleWriter::fill(const std::vector<In_t>& info) {
320  m_writer->fill(info);
321 }
322 
323 
// NOTE(review): the opening line of this definition (original line 324,
// presumably `void IParticleWriter::flush() {`) is missing from this view.
// Forward flush() to the format-specific writer chosen in the constructor.
325  m_writer->flush();
326 }
grepfile.info
info
Definition: grepfile.py:38
AllowedVariables::e
e
Definition: AsgElectronSelectorTool.cxx:37
TrigDefs::Group
Group
Properties of a chain group.
Definition: GroupProperties.h:13
python.SystemOfUnits.s
int s
Definition: SystemOfUnits.py:131
python.SystemOfUnits.m
int m
Definition: SystemOfUnits.py:91
IParticle.h
IParticleWriter::m_writer
std::unique_ptr< details::IParticleWriterBase > m_writer
Definition: IParticleWriter.h:33
H5Utils::Compression::STANDARD
@ STANDARD
IParticleWriter::fill
void fill(const std::vector< const xAOD::IParticle * > &)
Definition: IParticleWriter.cxx:319
IParticleWriterConfig::ArrayFormat
ArrayFormat
Definition: IParticleWriterConfig.h:27
max
constexpr double max()
Definition: ap_fixedTest.cxx:33
PrimitiveHelpers.h
IParticleWriter::IParticleWriter
IParticleWriter(H5::Group &output_group, const IParticleWriterConfig &)
Definition: IParticleWriter.cxx:278
Primitive
Definition: Primitive.h:10
xAOD::IParticleContainer
DataVector< IParticle > IParticleContainer
Simple convenience declaration of IParticleContainer.
Definition: xAOD/xAODBase/xAODBase/IParticleContainer.h:32
detail::addInput
void addInput(T &c, const Primitive &input, A a=defaultAccessor< T >)
Definition: PrimitiveHelpers.h:50
SG::ConstAccessor
Helper class to provide constant type-safe access to aux data.
Definition: ConstAccessor.h:55
read_hist_ntuple.t
t
Definition: read_hist_ntuple.py:5
IParticleWriter::flush
void flush()
Definition: IParticleWriter.cxx:324
xAOD::IParticle
Class providing the definition of the 4-vector interface.
Definition: Event/xAOD/xAODBase/xAODBase/IParticle.h:41
details::IParticleWriterBase::~IParticleWriterBase
virtual ~IParticleWriterBase()=default
Writer.h
IParticleWriter.h
H5Utils::Compression::HALF_PRECISION
@ HALF_PRECISION
H5Utils::Compression
Compression
Definition: CompressionEnums.h:11
python.setupRTTAlg.size
int size
Definition: setupRTTAlg.py:39
Primitive::Type::PRECISION_CUSTOM
@ PRECISION_CUSTOM
A
python.utils.AtlRunQueryDQUtils.p
p
Definition: AtlRunQueryDQUtils.py:210
H5Utils::Writer
Writer.
Definition: Writer.h:349
details::IParticleWriterBase
Definition: IParticleWriter.cxx:154
details
Definition: IParticleWriter.h:21
lumiFormat.i
int i
Definition: lumiFormat.py:85
python.CaloCondTools.g
g
Definition: CaloCondTools.py:15
beamspotman.n
n
Definition: beamspotman.py:731
IParticleWriterConfig
Definition: IParticleWriterConfig.h:21
PlotPulseshapeFromCool.input
input
Definition: PlotPulseshapeFromCool.py:106
test_pyathena.parent
parent
Definition: test_pyathena.py:15
add
bool add(const std::string &hname, TKey *tobj)
Definition: fastadd.cxx:55
hist_file_dump.f
f
Definition: hist_file_dump.py:135
AnalysisUtils::Delta::R
double R(const INavigable4Momentum *p1, const double v_eta, const double v_phi)
Definition: AnalysisMisc.h:49
BTaggingContainer.h
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:228
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
WriteCaloSwCorrections.cfg
cfg
Definition: WriteCaloSwCorrections.py:23
details::IParticleWriterBase::fill
virtual void fill(const std::vector< const xAOD::IParticle * > &info)=0
python.PyAthena.v
v
Definition: PyAthena.py:154
a
TList * a
Definition: liststreamerinfos.cxx:10
CaloLCW_tf.group
group
Definition: CaloLCW_tf.py:28
detail::isCustom
bool isCustom(const Primitive &p)
Definition: PrimitiveHelpers.cxx:10
IParticleWriterConfig.h
IParticleWriter::~IParticleWriter
~IParticleWriter()
I
#define I(x, y, z)
Definition: MD5.cxx:116
python.compressB64.c
def c
Definition: compressB64.py:93
details::IParticleWriterBase::flush
virtual void flush()=0
match
bool match(std::string s1, std::string s2)
match the individual directories of two strings
Definition: hcg.cxx:356
python.BeamSpotUpdate.compression
compression
Definition: BeamSpotUpdate.py:188