16 #include "lwtnn/LightweightGraph.hh"
17 #include "lwtnn/parse_json.hh"
// Constructs a getter for the b-tagging variable named `key`. The value is
// later reported under this name (presumably stored as m_key by the
// out-of-line definition, whose initializer list is not visible here).
31 BTagPairGetter(
const std::string&
key);
// Returns {m_key, value}, where value is m_getter applied to a b-tagging
// object (presumably the one associated with `jet` -- TODO confirm).
// The definition throws std::runtime_error("can't find btagging object")
// when that object is missing.
32 std::pair<std::string, double> operator()(
const xAOD::Jet&
jet);
// Type-erased callable mapping a jet to a (variable name, value) pair;
// this is the type returned by makePairGetter.
38 using Pg = std::function<std::pair<std::string, double>(
const xAOD::Jet&)>;
// Builds a Pg for the input variable named `key`. The definition returns a
// dedicated lambda for some keys (the selection conditions are not visible
// here) and otherwise falls back to BTagPairGetter<float>(key).
39 Pg makePairGetter(
const std::string&
key);
// Replaces the value of the existing entry for value.first in `target` with
// value.second. Throws std::logic_error when no existing entry is found, so
// callers must pre-populate `target` with a default for every key written.
// NOTE(review): the definition's error message reads "can't fine a default
// value for" -- "fine" is a typo for "find".
41 void requireOverwrite(std::map<std::string, double>&
target,
42 const std::pair<std::string, double>&
value);
50 m_parent_link(
"Parent"),
51 m_subjet_link_getter(
config.subjet_link_name),
53 m_min_subjet_pt(
config.min_subjet_pt)
55 namespace fs = std::filesystem;
60 if (nn_path.empty()) {
61 throw std::runtime_error(
62 "no file found at '" +
config.input_file_path.string() +
"'");
65 std::ifstream input_stream(nn_path.string());
67 m_graph.reset(
new lwt::LightweightGraph(graph_cfg));
72 for (
const std::string&
key:
keys.fatjet) {
75 for (
const std::string&
key:
keys.subjet) {
82 for (
const auto&
output: graph_cfg.outputs) {
83 const std::string& node_name =
output.first;
84 const lwt::OutputNodeConfig&
node =
output.second;
87 std::string write_name = node_name +
"_" +
varname;
88 node_writer.emplace_back(
varname, write_name);
90 m_outputs.emplace_back(node_name, node_writer);
97 namespace hk = hbb_key;
103 std::vector<const xAOD::IParticle*> subjets;
105 if (!
parent)
throw std::runtime_error(
"can't resolve parent jet");
108 if (!
subjet)
throw std::runtime_error(
"can't resolve subjet link");
110 subjets.push_back(
subjet);
113 std::sort(subjets.begin(), subjets.end(),
114 [](
auto*
a,
auto*
b) { return a->pt() > b->pt(); });
117 for (
size_t jet_n = 0; jet_n < n_jets; jet_n++) {
118 const auto*
subjet =
dynamic_cast<const xAOD::Jet*
>(subjets.at(jet_n));
119 if (!
subjet)
throw std::runtime_error(
"IParticle is not a Jet");
122 requireOverwrite(
inputs.at(subjet_name),getter(*
subjet));
129 for (
const auto& var_writer:
node.second) {
130 var_writer.second(
jet) =
result.at(var_writer.first);
141 template <
typename T>
142 BTagPairGetter<T>::BTagPairGetter(
const std::string&
key):
146 template <
typename T>
147 std::pair<std::string, double>
150 if (!btag)
throw std::runtime_error(
"can't find btagging object");
151 return {m_key, m_getter(*btag)};
155 Pg makePairGetter(
const std::string&
key) {
158 return [](
const xAOD::Jet& j) -> std::pair<std::string, double> {
162 return [](
const xAOD::Jet& j) -> std::pair<std::string, double> {
168 return BTagPairGetter<float>(
key);
172 void requireOverwrite(std::map<std::string, double>&
target,
173 const std::pair<std::string, double>&
value) {
175 if (itr ==
target.end()) {
176 throw std::logic_error(
"can't fine a default value for " +
value.first);
178 itr->second =
value.second;