ATLAS Offline Software
Loading...
Searching...
No Matches
IParticleWriter.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3*/
6
7#include "PrimitiveHelpers.h"
8
10#include "HDF5Utils/Writer.h"
11
12// needed concrete classes for ElementLink access
14
15namespace {
// Aliases shared by everything below.
//  - Writer_t<N>: alias template whose header is the next line; the line that
//    completes it is not shown in this listing. TODO confirm its target type.
//  - In_t: element type handed to the writers. It must be
//    const xAOD::IParticle*, because the concrete writers'
//    fill(const std::vector<In_t>&) overrides
//    IParticleWriterBase::fill(const std::vector<const xAOD::IParticle*>&).
//  - Consumer_t: holds the named per-particle value functions registered via
//    add() (see addCustomType).
//  - CountWriter_t: scalar writer for the per-event "counts" dataset of
//    awkward arrays. Its unsigned char element type limits how many entries
//    a single fill() may write (see IParticleAwkwardWriter::fill).
16 template <unsigned int N>
18 using In_t = Writer_t<1>::input_type;
19 using Consumer_t = Writer_t<1>::consumer_type;
20 using CountWriter_t = H5Utils::Writer<0, unsigned char>;
21
// Register one "custom" primitive p on consumer c.
//
// Recognised names (matched against p.source):
//  - "eta" and "phi";
//  - "pt", "px", "py", "pz" and "mass", each optionally with a "GeV" suffix;
//  - "valid".
// Kinematic values are read from the object returned by the accessor `a`.
// `a` is the third parameter; its declaration line is not shown in this
// listing. The constructor passes a LinkGetter here to follow a link. When
// `a` yields nullptr, the value written is NAN.
//
// Throws std::logic_error in three cases:
//  - p is not a custom primitive;
//  - p.source matches none of the names above;
//  - from matchGeV: the full-precision request described in the comment there
//    (its guarding condition is on a line not shown in this listing).
22 template <typename A=detail::defaultAccessor_t<Consumer_t>>
23 void addCustomType(Consumer_t& c,
24 const Primitive& p,
26 using I = In_t;
27
28 if (!detail::isCustom(p)) {
29 throw std::logic_error("called addCustomType on non-custom type");
30 }
// `half` is the compression used for matched values. Per the comment in
// matchGeV, it is full precision when PRECISION_CUSTOM is requested.
31 bool force_precision = p.type == Primitive::Type::PRECISION_CUSTOM;
32 const H5Utils::Compression half = force_precision ?
35
// s is the name matched below; t is the name the column is added under
// (the first argument to c.add).
36 const std::string s = p.source;
37 const std::string t = p.target;
38
39 // check for match
// m is set to true once one of the kinematic names below matches.
40 bool m = false;
41
// add: register column t. Its value is f(object returned by a) * mult, or
// NAN when a returns nullptr. NAN is also passed as c.add's third argument
// (presumably the default / fill value; confirm in H5Utils).
42 auto add = [&c, a, t](auto f, auto compression, float mult = 1.0) {
43 c.add(
44 t,
45 [a, f, mult](I in) -> float {
46 const auto* associated = a(in);
47 if (!associated) return NAN;
48 return f(associated)*mult;
49 },
50 NAN, compression);
51 };
52
53 // these do most of the matching work
// match: if s equals n exactly, register func with `half` compression.
54 auto match = [&add, s, &m, half](const std::string& n, auto func) {
55 if (s == n) {
56 add(func, half);
57 m = true;
58 }
59 };
60 // This matches two possible strings: the original string and one
61 // with a "GeV" suffix. The purpose of the suffixed version is to
62 // allow us to store things at half precision (since most things
63 // stored in MeV would overflow half precision floats).
// The "GeV" branch multiplies the value by 0.001 before storing it with
// `half` compression.
64 auto matchGeV = [&add, s, &m, half](const std::string& n, auto func) {
65 if (s == n) {
66 // Sort of convoluted logic here: if we force higher precision
67 // we set "half" to be full precision. But if someone asks for
68 // full precision in a variable that must be stored that way
69 // anyway, we throw an exception.
71 throw std::logic_error(
72 "asked for a full precision version of a variable that can"
73 " not be stored at half precision: " + s);
74 }
76 m = true;
77 } else if (s == n + "GeV") {
78 add(func, half, 0.001);
79 m = true;
80 }
81 };
82
83 // match cases
84 matchGeV("pt", [](auto in) {return in->pt(); });
85 match("eta", [](auto in) {return in->eta(); });
86 match("phi", [](auto in) {return in->phi(); });
87 matchGeV("px", [](auto in) {return in->p4().Px(); });
88 matchGeV("py", [](auto in) {return in->p4().Py(); });
89 matchGeV("pz", [](auto in) {return in->p4().Pz(); });
90 matchGeV("mass", [](auto in) {return in->m(); });
91
// A kinematic name matched and its column was registered.
92 if (m) return;
93
// "valid" writes true for every entry, with false as c.add's third argument.
// NOTE(review): it is registered under s (p.source) rather than t
// (p.target), unlike the kinematic columns above. It also ignores the
// accessor `a`, so it stays true even when a linked object is missing.
// Confirm both are intended.
94 if (s == "valid") {
95 c.add(s, [](I) {return true; }, false);
96 return;
97 }
// NOTE(review): this message reads "unknow known"; "unknown" was probably
// intended.
98 throw std::logic_error("unknow known custom primitive: " + s);
99 }
100
101
102 // stuff to retrieve associated links
// LinkGetter<T,R> is a callable accessor. Given a particle, it reads the
// link stored under the name passed to the constructor and returns the
// linked object as a const R*. It returns nullptr if the link is default
// (never set), and operator() throws std::runtime_error if the link is set
// but invalid.
// The LinkAccessor alias used for m_accessor is declared on a line not shown
// in this listing. Presumably it is SG::ConstAccessor<ElementLink<T>>; TODO
// confirm.
103 template <typename T, typename R=SG::AuxElement>
104 class LinkGetter
105 {
106 public:
// `name` is both the key the link is read from and the link name reported
// in error messages.
107 LinkGetter(std::string name);
108 const R* operator()(In_t in) const;
109 private:
111 const LinkAccessor m_accessor;
112 const std::string m_linkName;
113 };
114 template <typename T, typename R>
115 LinkGetter<T,R>::LinkGetter(std::string name):
116 m_accessor(name),
117 m_linkName(name)
118 {}
119 template <typename T, typename R>
120 const R* LinkGetter<T,R>::operator()(In_t in) const {
121 auto elink = m_accessor(*in);
122 // isDefault should generally indicate that the element link has
123 // been created but not set to point to any specific place. This
124 // is distinct from being created and pointing to a particle that
125 // has been thinned or slimmed away. So it's safe to use as a "not
126 // set" code.
127 if (elink.isDefault()) {
128 return nullptr;
129 }
130 // If the link is _not_ default, but is also invalid, then we're
131 // trying to access something that has been removed. We treat this
132 // as an error, because the behavior becomes dependent on the
133 // format.
134 if (!elink.isValid()) {
135 throw std::runtime_error("invalid link " + m_linkName);
136 }
137 return *elink;
138 }
139
140 CountWriter_t::consumer_type getOffsetConsumer()
141 {
142 using CountIn_t = CountWriter_t::input_type;
143 CountWriter_t::consumer_type c;
144 c.add("count", [](CountIn_t i) { return i; });
145 return c;
146 }
147
148}
149
150
151// IParticleWriter implementations
// Abstract writer interface. The concrete writers in this file (padded 2d,
// awkward, flat) implement it, and IParticleWriter holds one in m_writer and
// forwards fill()/flush() to it. The class header line is not shown in this
// listing.
152namespace details {
154 {
155 public:
156 virtual ~IParticleWriterBase() = default;
// Write one batch of particles; called once per IParticleWriter::fill.
157 virtual void fill(const std::vector<const xAOD::IParticle*>& info) = 0;
// Flush the underlying H5Utils writer(s) held by the implementation.
158 virtual void flush() = 0;
159 };
160}
161namespace {
162 // implementation for writer for 2d arrays
163 class IParticle2dWriter: public details::IParticleWriterBase
164 {
165 private:
166 Writer_t<1> m_writer;
167 public:
168 IParticle2dWriter(H5::Group& group,
169 const std::string& n,
170 Consumer_t c,
171 long long unsigned size):
172 m_writer(group, n, c, {{size}}) {}
173 ~IParticle2dWriter() = default;
174
175 void fill(const std::vector<In_t>& v) override {
176 m_writer.fill(v);
177 }
178 void flush() override {
179 m_writer.flush();
180 }
181 };
182
183 // implementation for writer for awkward arrays
184 class IParticleAwkwardWriter: public details::IParticleWriterBase
185 {
186 private:
187 H5::Group m_group;
188 Writer_t<0> m_writer;
189 CountWriter_t m_counts;
190 public:
191 IParticleAwkwardWriter(H5::Group& parent,
192 const std::string& n,
193 Consumer_t c):
194 m_group(parent.createGroup(n)),
195 m_writer(m_group, "raw", c),
196 m_counts(m_group, "counts", getOffsetConsumer())
197 {}
198 ~IParticleAwkwardWriter() = default;
199
200 void fill(const std::vector<In_t>& v) override {
201 using Count_t = CountWriter_t::input_type;
202 constexpr auto max = std::numeric_limits<Count_t>::max();
203 auto n_entries = v.size();
204 if (n_entries > max) {
205 throw std::overflow_error(
206 "number of entries exceeds maximum for this datatype "
207 "[" + std::to_string(n_entries) + " > " + std::to_string(max) + "]"
208 );
209 }
210 for (const auto& e: v) m_writer.fill(e);
211 m_counts.fill(n_entries);
212 }
213 void flush() override {
214 m_writer.flush();
215 m_counts.flush();
216 }
217 };
218
219
220 // implementation for writer for awkward arrays
221 class IParticleFlatWriter: public details::IParticleWriterBase
222 {
223 private:
224 H5::Group m_group;
225 Writer_t<0> m_writer;
226 public:
227 IParticleFlatWriter(H5::Group& parent,
228 const std::string& n,
229 Consumer_t c):
230 m_writer(parent, n, c)
231 {}
232 ~IParticleFlatWriter() = default;
233
234 void fill(const std::vector<In_t>& v) override {
235 for (const auto& e: v) m_writer.fill(e);
236 }
237 void flush() override {
238 m_writer.flush();
239 }
240 };
241
// Build the concrete writer selected by cfg.format. It writes under group g,
// with the columns registered on consumer c, using cfg.name as the dataset
// or group name.
//  - AWKWARD: IParticleAwkwardWriter; cfg.maximum_size must be 0.
//  - FLAT:    IParticleFlatWriter; cfg.maximum_size must be 0.
//  - PADDED:  IParticle2dWriter; cfg.maximum_size must be nonzero and is
//             passed on as the fixed size.
// Throws std::domain_error on a size/format mismatch or an unknown format.
// The `af` alias used in the case labels is declared on a line not shown in
// this listing. Presumably it names the config's array-format enum; TODO
// confirm.
242 std::unique_ptr<details::IParticleWriterBase> getWriter(
243 H5::Group& g,
244 const IParticleWriterConfig& cfg,
245 Consumer_t c) {
247 switch (cfg.format) {
// Jagged storage: the size is implied by the per-fill counts.
248 case af::AWKWARD: {
249 if (cfg.maximum_size != 0) {
250 throw std::domain_error(
251 "no maximum_size should be specified for awkward arrays");
252 }
253 return std::make_unique<IParticleAwkwardWriter>(g, cfg.name, c);
254 }
// Flat storage: all particles in one dataset.
255 case af::FLAT: {
256 if (cfg.maximum_size != 0) {
257 throw std::domain_error(
258 "no maximum_size should be specified for flat arrays");
259 }
260 return std::make_unique<IParticleFlatWriter>(g, cfg.name, c);
261 }
// Padded storage needs an explicit fixed size.
262 case af::PADDED: {
263 if (cfg.maximum_size == 0) {
264 throw std::domain_error(
265 "maximum_size should be specified for padded 2d arrays");
266 }
267 return std::make_unique<IParticle2dWriter>(
268 g, cfg.name, c, cfg.maximum_size);
269 }
270 default: {
271 throw std::domain_error("unknown array format");
272 }
273 }
274 }
275}
276
277
279 H5::Group& group,
280 const IParticleWriterConfig& cfg)
281{
// Collect every requested column into one consumer. Then hand it to the
// writer implementation chosen by cfg.format (see getWriter, which may throw
// std::domain_error on an inconsistent configuration).
282 using IPC = xAOD::IParticleContainer;
283 using IP = xAOD::IParticle;
284 Consumer_t c;
285 for (const auto& input: cfg.inputs) {
286 if (input.link_name.empty()) {
// Value read from the particle itself (default accessor).
287 const auto& primitive = input.input;
288 if (detail::isCustom(primitive.type)) {
289 addCustomType(c, primitive);
290 } else {
291 detail::addInput(c, primitive);
292 }
293 } else {
294 // else we have some association to follow
295 std::string n = input.link_name;
296 // unfortunately we need a special case for b-tagging
// The linked BTagging object is accessed as SG::AuxElement (LinkGetter's
// default R), not as an IParticle.
// NOTE(review): unlike the branch below, custom primitives are not
// special-cased here; confirm they can never be requested via btaggingLink.
297 if (n == "btaggingLink") {
298 LinkGetter<xAOD::BTaggingContainer> getter(n);
299 detail::addInput(c, input.input, getter);
300 } else {
301 // everything else is an IParticle
302 LinkGetter<IPC,IP> getter(n);
303 if (detail::isCustom(input.input.type)) {
304 addCustomType(c, input.input, getter);
305 } else {
306 detail::addInput(c, input.input, getter);
307 }
308 }
309 }
310 }
311 m_writer = getWriter(group, cfg, c);
312}
313
314
316
317
318void IParticleWriter::fill(const std::vector<In_t>& info) {
319 m_writer->fill(info);
320}
321
322
// Delegate to the concrete writer's flush(). Each implementation in this
// file flushes its H5Utils writer(s): the awkward writer also flushes its
// "counts" writer.
324 m_writer->flush();
325}
static Double_t a
#define I(x, y, z)
Definition MD5.cxx:116
#define max(a, b)
Definition cfImp.cxx:41
Writer.
Definition Writer.h:350
void fill(const std::vector< const xAOD::IParticle * > &)
std::unique_ptr< details::IParticleWriterBase > m_writer
IParticleWriter(H5::Group &output_group, const IParticleWriterConfig &)
SG::ConstAccessor< T, ALLOC > ConstAccessor
Definition AuxElement.h:569
virtual ~IParticleWriterBase()=default
virtual void fill(const std::vector< const xAOD::IParticle * > &info)=0
Class providing the definition of the 4-vector interface.
bool add(const std::string &hname, TKey *tobj)
Definition fastadd.cxx:55
bool match(std::string s1, std::string s2)
match the individual directories of two strings
Definition hcg.cxx:357
double R(const INavigable4Momentum *p1, const double v_eta, const double v_phi)
void addInput(T &c, const Primitive &input, A a=defaultAccessor< T >)
bool isCustom(const Primitive &p)
auto defaultAccessor
DataVector< IParticle > IParticleContainer
Simple convenience declaration of IParticleContainer.
hold the test vectors and ease the comparison
void fill(H5::Group &out_file, size_t iterations)