ATLAS Offline Software
Loading...
Searching...
No Matches
PassThroughSaltModel.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4
6#include <stdexcept>
7
8namespace FlavorTagInference {
9
10 PassThroughSaltModel::PassThroughSaltModel(const nlohmann::json& config)
11 : m_model_name(config.value("model_name", "PassThrough")) {
12
13 // ── Scalar jet variables ──
14 // Support both "variables" (legacy) and "jet_variables" keys
15 const char* jet_key = config.contains("jet_variables")
16 ? "jet_variables" : "variables";
17
18 if (config.contains(jet_key) && config[jet_key].is_array()) {
19 SaltModelGraphConfig::InputNodeConfig input_node;
20 input_node.name = "jets";
21
22 for (const auto& var : config[jet_key]) {
23 std::string input_name = var.at("input").get<std::string>();
24 std::string output_name = var.at("output").get<std::string>();
25
26 m_jet_input_names.push_back(input_name);
27 m_jet_output_names.push_back(output_name);
28
29 // Graph config: identity normalisation (offset=0, scale=1)
30 input_node.variables.emplace_back(input_name,0.0,1.0);
31
32 // Output config: scalar float (rank 0)
33 m_output_config.emplace_back(
34 output_name,
35 ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
36 0 // rank 0 = scalar float
37 );
38 }
39
40 if (!input_node.variables.empty()) {
41 m_graph_config.inputs.push_back(std::move(input_node));
42 }
43 }
44
45 // ── Constituent variables (tracks, electrons, muons, flows) ──
46 if (config.contains("constituents") && config["constituents"].is_array()) {
47 for (const auto& cnode : config["constituents"]) {
48 std::string node_name = cnode.at("node_name").get<std::string>();
49
50 if (!cnode.contains("variables") || !cnode["variables"].is_array()) {
51 throw std::runtime_error(
52 "PassThroughSaltModel: constituent node '" + node_name
53 + "' must have a 'variables' array");
54 }
55
56 ConstituentNode cn;
57 cn.node_name = node_name;
58
59 // Build graph config input_sequence for this constituent type
60 SaltModelGraphConfig::InputNodeConfig seq_node;
61 seq_node.name = node_name;
62
63 for (const auto& var : cnode["variables"]) {
64 std::string input_name = var.at("input").get<std::string>();
65 std::string output_name = var.at("output").get<std::string>();
66 std::string var_type = var.value("type", "float");
67 float fp16_scale = var.value("scale", 1.0f);
68
69 cn.output_names.push_back(output_name);
70 cn.var_types.push_back(var_type);
71
72 // Graph config: identity normalisation
73 seq_node.variables.emplace_back(input_name, 0.0, 1.0);
74
75 // Output config: vector type based on JSON "type"/"cast" fields.
76 // "cast" is an object {exp: E, man: M} specifying reduced-precision
77 // float32 with E/M mantissa/exponent bits (default E=8, M=7).
78 // Non-float variables use "type" instead ("int" → INT32, "char" → INT8)
79 // and never have "cast".
80 if (var.contains("cast")) {
81 const auto& cast_val = var.at("cast");
82 if (!cast_val.is_object()) {
83 throw std::runtime_error(
84 "PassThroughSaltModel: 'cast' for variable '" + input_name
85 + "' must be an object {exp, man}");
86 }
87 int exp_bits = cast_val.value("exp", 8);
88 int man_bits = cast_val.value("man", 7);
89 if (exp_bits < 2 || exp_bits > 8 || man_bits < 0 || man_bits > 23) {
90 throw std::runtime_error(
91 "PassThroughSaltModel: cast {exp,man} requires 2<=exp<=8 and 0<=man<=23, got exp="
92 + std::to_string(exp_bits) + " man=" + std::to_string(man_bits));
93 }
94 m_output_config.push_back(
95 SaltModelOutput(output_name, SaltModelOutput::OutputType::VECTRUNCFLOAT, fp16_scale, exp_bits, man_bits));
96 } else {
97 ONNXTensorElementDataType onnx_type;
98 if (var_type == "int") {
99 onnx_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
100 } else if (var_type == "char") {
101 onnx_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
102 } else {
103 onnx_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
104 }
105 m_output_config.emplace_back(output_name, onnx_type, 1);
106 }
107 }
108
109 cn.num_vars = cn.output_names.size();
110 if (!cnode.contains("input_key")) {
111 throw std::runtime_error(
112 "PassThroughSaltModel: constituent node '" + node_name
113 + "' must define 'input_key' (e.g. \"tracks\", \"flows\", "
114 "\"hits\", \"electrons\", \"muons\", \"clusters\", \"towers\")");
115 }
116 cn.input_key = cnode.at("input_key").get<std::string>();
117 m_constituent_nodes.push_back(std::move(cn));
118 m_graph_config.input_sequences.push_back(std::move(seq_node));
119 }
120 }
121
122 if (m_jet_input_names.empty() && m_constituent_nodes.empty()) {
123 throw std::runtime_error(
124 "PassThroughSaltModel: JSON config must contain "
125 "'jet_variables'/'variables' and/or 'constituents'");
126 }
127 }
128
130 {
131 InferenceOutput output;
132
133 // ── Scalar jet variables ──
134 if (!m_jet_input_names.empty()) {
135 auto it = gnn_inputs.find("jets");
136 if (it == gnn_inputs.end()) {
137 throw std::runtime_error(
138 "PassThroughSaltModel: expected 'jets' in gnn_inputs");
139 }
140
141 const auto& data = it->second.first; // vector<float>
142 const auto& shape = it->second.second; // vector<int64_t>
143
144 if (shape.size() != 2 || shape[0] != 1) {
145 throw std::runtime_error(
146 "PassThroughSaltModel: unexpected jets shape");
147 }
148
149 size_t num_vars = static_cast<size_t>(shape[1]);
150 if (num_vars != m_jet_input_names.size()) {
151 throw std::runtime_error(
152 "PassThroughSaltModel: jets size mismatch: got "
153 + std::to_string(num_vars) + " expected "
154 + std::to_string(m_jet_input_names.size()));
155 }
156
157 for (size_t i = 0; i < num_vars; ++i) {
158 output.singleFloat[m_jet_output_names[i]] = data[i];
159 }
160 }
161
162 // ── Constituent variables ──
163 // The loaders produce flat tensors with shape [N_constituents, N_vars]
164 // in row-major order: element [c, v] = data[c * N_vars + v].
165 // We transpose back to per-variable vectors for the output.
166 for (const auto& cn : m_constituent_nodes) {
167 auto writeEmpty = [&](const ConstituentNode& node) {
168 for (size_t v = 0; v < node.output_names.size(); ++v) {
169 if (node.var_types[v] == "char") {
170 output.vecChar[node.output_names[v]] = {};
171 } else {
172 output.vecFloat[node.output_names[v]] = {};
173 }
174 }
175 };
176
177 auto it = gnn_inputs.find(cn.input_key);
178 if (it == gnn_inputs.end()) {
179 // Config bug: JSON input_key does not match any registered loader's
180 // output_name (set in ConstituentsLoader.cxx, e.g. "tracks", "flows",
181 // "electrons", "muons", "clusters", "towers", "hits" — all plural).
182 // Fail loudly with the list of available keys so the typo is obvious.
183 std::string available;
184 for (const auto& kv : gnn_inputs) {
185 if (!available.empty()) available += ", ";
186 available += "'" + kv.first + "'";
187 }
188 throw std::runtime_error(
189 "PassThroughSaltModel: constituent node '" + cn.node_name
190 + "' has input_key='" + cn.input_key
191 + "' but no loader registered under that key. Available keys: ["
192 + available + "]. Check ConstituentsInputConfig.output_name in "
193 + "ConstituentsLoader.cxx.");
194 }
195
196 const auto& data = it->second.first;
197 const auto& shape = it->second.second;
198
199 // shape = [N_constituents, N_vars]
200 int64_t n_constituents = (shape.size() >= 1) ? shape[0] : 0;
201 int64_t n_vars = (shape.size() >= 2) ? shape[1] : 0;
202
203 if (n_constituents == 0 || n_vars == 0) {
204 writeEmpty(cn);
205 continue;
206 }
207
208 if (static_cast<size_t>(n_vars) != cn.num_vars) {
209 throw std::runtime_error(
210 "PassThroughSaltModel: constituent '" + cn.node_name
211 + "' vars mismatch: got " + std::to_string(n_vars)
212 + " expected " + std::to_string(cn.num_vars));
213 }
214
215 // Transpose: extract column v from the flat row-major tensor.
216 // Float and int variables go to vecFloat (GNN converts int at
217 // decoration time). Char variables go to vecChar for 1-byte
218 // storage (hit counts, quality flags).
219 for (size_t v = 0; v < cn.num_vars; ++v) {
220 if (cn.var_types[v] == "char") {
221 std::vector<char> col(n_constituents);
222 for (int64_t c = 0; c < n_constituents; ++c) {
223 col[c] = static_cast<char>(data[c * n_vars + v]);
224 }
225 output.vecChar[cn.output_names[v]] = std::move(col);
226 } else {
227 std::vector<float> col(n_constituents);
228 for (int64_t c = 0; c < n_constituents; ++c) {
229 col[c] = data[c * n_vars + v];
230 }
231 output.vecFloat[cn.output_names[v]] = std::move(col);
232 }
233 }
234 }
235
236 return output;
237 }
238
243
247
251
252 const std::string& PassThroughSaltModel::getModelName() const {
253 return m_model_name;
254 }
255
256} // namespace FlavorTagInference
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
if(pathvar)
const std::string & getModelName() const override
std::vector< ConstituentNode > m_constituent_nodes
SaltModelGraphConfig::GraphConfig m_graph_config
std::vector< std::string > m_jet_output_names
Scalar jet variables: output names (matching by index).
PassThroughSaltModel(const nlohmann::json &config)
Construct from a JSON config with the format: { "model_name": "PassThrough", "jet_variables": [ {"inp...
const SaltModelGraphConfig::GraphConfig getGraphConfig() const override
InferenceOutput runInference(InputMap &gnn_inputs) const override
const OutputConfig & getOutputConfig() const override
SaltModelVersion getSaltModelVersion() const override
std::vector< std::string > m_jet_input_names
Scalar jet variables: input names (graph config order).
Definition node.h:24
This file contains "getter" functions used for accessing tagger inputs from the EDM.
std::vector< SaltModelOutput > OutputConfig
Definition ISaltModel.h:38
std::map< std::string, Inputs, std::less<> > InputMap
Definition ISaltModel.h:37