ATLAS Offline Software
Loading...
Searching...
No Matches
PassThroughSaltModel.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4
6#include <stdexcept>
7
8namespace FlavorTagInference {
9
10 PassThroughSaltModel::PassThroughSaltModel(const nlohmann::json& config)
11 : m_model_name(config.value("model_name", "PassThrough")) {
12
13 // ── Scalar jet variables ──
14 // Support both "variables" (legacy) and "jet_variables" keys
15 const char* jet_key = config.contains("jet_variables")
16 ? "jet_variables" : "variables";
17
18 if (config.contains(jet_key) && config[jet_key].is_array()) {
19 SaltModelGraphConfig::InputNodeConfig input_node;
20 input_node.name = "jets";
21
22 for (const auto& var : config[jet_key]) {
23 std::string input_name = var.at("input").get<std::string>();
24 std::string output_name = var.at("output").get<std::string>();
25
26 m_jet_input_names.push_back(input_name);
27 m_jet_output_names.push_back(output_name);
28
29 // Graph config: identity normalisation (offset=0, scale=1)
30 input_node.variables.emplace_back(input_name,0.0,1.0);
31
32 // Output config: scalar float (rank 0)
33 m_output_config.emplace_back(
34 output_name,
35 ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
36 0 // rank 0 = scalar float
37 );
38 }
39
40 if (!input_node.variables.empty()) {
41 m_graph_config.inputs.push_back(std::move(input_node));
42 }
43 }
44
45 // ── Constituent variables (tracks, electrons, muons, flows) ──
46 if (config.contains("constituents") && config["constituents"].is_array()) {
47 for (const auto& cnode : config["constituents"]) {
48 std::string node_name = cnode.at("node_name").get<std::string>();
49
50 if (!cnode.contains("variables") || !cnode["variables"].is_array()) {
51 throw std::runtime_error(
52 "PassThroughSaltModel: constituent node '" + node_name
53 + "' must have a 'variables' array");
54 }
55
56 ConstituentNode cn;
57 cn.node_name = node_name;
58
59 // Build graph config input_sequence for this constituent type
60 SaltModelGraphConfig::InputNodeConfig seq_node;
61 seq_node.name = node_name;
62
63 for (const auto& var : cnode["variables"]) {
64 std::string input_name = var.at("input").get<std::string>();
65 std::string output_name = var.at("output").get<std::string>();
66 std::string var_type = var.value("type", "float");
67 float fp16_scale = var.value("scale", 1.0f);
68
69 cn.output_names.push_back(output_name);
70 cn.var_types.push_back(var_type);
71
72 // Graph config: identity normalisation
73 seq_node.variables.emplace_back(input_name, 0.0, 1.0);
74
75 // Output config: vector type based on JSON "type"/"cast" fields.
76 // "cast" is an object {exp: E, man: M} specifying reduced-precision
77 // float32 with E/M mantissa/exponent bits (default E=8, M=7).
78 // Non-float variables use "type" instead ("int" → INT32, "char" → INT8)
79 // and never have "cast".
80 if (var.contains("cast")) {
81 const auto& cast_val = var.at("cast");
82 if (!cast_val.is_object()) {
83 throw std::runtime_error(
84 "PassThroughSaltModel: 'cast' for variable '" + input_name
85 + "' must be an object {exp, man}");
86 }
87 int exp_bits = cast_val.value("exp", 8);
88 int man_bits = cast_val.value("man", 7);
89 if (exp_bits < 2 || exp_bits > 8 || man_bits < 0 || man_bits > 23) {
90 throw std::runtime_error(
91 "PassThroughSaltModel: cast {exp,man} requires 2<=exp<=8 and 0<=man<=23, got exp="
92 + std::to_string(exp_bits) + " man=" + std::to_string(man_bits));
93 }
94 m_output_config.push_back(
95 SaltModelOutput(output_name, SaltModelOutput::OutputType::VECTRUNCFLOAT, fp16_scale, exp_bits, man_bits));
96 } else {
97 ONNXTensorElementDataType onnx_type;
98 if (var_type == "int") {
99 onnx_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
100 } else if (var_type == "char") {
101 onnx_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
102 } else {
103 onnx_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
104 }
105 m_output_config.emplace_back(output_name, onnx_type, 1);
106 }
107 }
108
109 cn.num_vars = cn.output_names.size();
110 if (!cnode.contains("input_key")) {
111 throw std::runtime_error(
112 "PassThroughSaltModel: constituent node '" + node_name
113 + "' must define 'input_key' (e.g. \"tracks\", \"flows\", "
114 "\"hits\", \"electrons\", \"muons\", \"clusters\", \"towers\")");
115 }
116 cn.input_key = cnode.at("input_key").get<std::string>();
117 m_constituent_nodes.push_back(std::move(cn));
118 m_graph_config.input_sequences.push_back(std::move(seq_node));
119 }
120 }
121
122 if (m_jet_input_names.empty() && m_constituent_nodes.empty()) {
123 throw std::runtime_error(
124 "PassThroughSaltModel: JSON config must contain "
125 "'jet_variables'/'variables' and/or 'constituents'");
126 }
127 }
128
130 std::map<std::string, Inputs>& gnn_inputs) const
131 {
132 InferenceOutput output;
133
134 // ── Scalar jet variables ──
135 if (!m_jet_input_names.empty()) {
136 auto it = gnn_inputs.find("jets");
137 if (it == gnn_inputs.end()) {
138 throw std::runtime_error(
139 "PassThroughSaltModel: expected 'jets' in gnn_inputs");
140 }
141
142 const auto& data = it->second.first; // vector<float>
143 const auto& shape = it->second.second; // vector<int64_t>
144
145 if (shape.size() != 2 || shape[0] != 1) {
146 throw std::runtime_error(
147 "PassThroughSaltModel: unexpected jets shape");
148 }
149
150 size_t num_vars = static_cast<size_t>(shape[1]);
151 if (num_vars != m_jet_input_names.size()) {
152 throw std::runtime_error(
153 "PassThroughSaltModel: jets size mismatch: got "
154 + std::to_string(num_vars) + " expected "
155 + std::to_string(m_jet_input_names.size()));
156 }
157
158 for (size_t i = 0; i < num_vars; ++i) {
159 output.singleFloat[m_jet_output_names[i]] = data[i];
160 }
161 }
162
163 // ── Constituent variables ──
164 // The loaders produce flat tensors with shape [N_constituents, N_vars]
165 // in row-major order: element [c, v] = data[c * N_vars + v].
166 // We transpose back to per-variable vectors for the output.
167 for (const auto& cn : m_constituent_nodes) {
168 auto writeEmpty = [&](const ConstituentNode& node) {
169 for (size_t v = 0; v < node.output_names.size(); ++v) {
170 if (node.var_types[v] == "char") {
171 output.vecChar[node.output_names[v]] = {};
172 } else {
173 output.vecFloat[node.output_names[v]] = {};
174 }
175 }
176 };
177
178 auto it = gnn_inputs.find(cn.input_key);
179 if (it == gnn_inputs.end()) {
180 // Config bug: JSON input_key does not match any registered loader's
181 // output_name (set in ConstituentsLoader.cxx, e.g. "tracks", "flows",
182 // "electrons", "muons", "clusters", "towers", "hits" — all plural).
183 // Fail loudly with the list of available keys so the typo is obvious.
184 std::string available;
185 for (const auto& kv : gnn_inputs) {
186 if (!available.empty()) available += ", ";
187 available += "'" + kv.first + "'";
188 }
189 throw std::runtime_error(
190 "PassThroughSaltModel: constituent node '" + cn.node_name
191 + "' has input_key='" + cn.input_key
192 + "' but no loader registered under that key. Available keys: ["
193 + available + "]. Check ConstituentsInputConfig.output_name in "
194 + "ConstituentsLoader.cxx.");
195 }
196
197 const auto& data = it->second.first;
198 const auto& shape = it->second.second;
199
200 // shape = [N_constituents, N_vars]
201 int64_t n_constituents = (shape.size() >= 1) ? shape[0] : 0;
202 int64_t n_vars = (shape.size() >= 2) ? shape[1] : 0;
203
204 if (n_constituents == 0 || n_vars == 0) {
205 writeEmpty(cn);
206 continue;
207 }
208
209 if (static_cast<size_t>(n_vars) != cn.num_vars) {
210 throw std::runtime_error(
211 "PassThroughSaltModel: constituent '" + cn.node_name
212 + "' vars mismatch: got " + std::to_string(n_vars)
213 + " expected " + std::to_string(cn.num_vars));
214 }
215
216 // Transpose: extract column v from the flat row-major tensor.
217 // Float and int variables go to vecFloat (GNN converts int at
218 // decoration time). Char variables go to vecChar for 1-byte
219 // storage (hit counts, quality flags).
220 for (size_t v = 0; v < cn.num_vars; ++v) {
221 if (cn.var_types[v] == "char") {
222 std::vector<char> col(n_constituents);
223 for (int64_t c = 0; c < n_constituents; ++c) {
224 col[c] = static_cast<char>(data[c * n_vars + v]);
225 }
226 output.vecChar[cn.output_names[v]] = std::move(col);
227 } else {
228 std::vector<float> col(n_constituents);
229 for (int64_t c = 0; c < n_constituents; ++c) {
230 col[c] = data[c * n_vars + v];
231 }
232 output.vecFloat[cn.output_names[v]] = std::move(col);
233 }
234 }
235 }
236
237 return output;
238 }
239
244
248
252
253 const std::string& PassThroughSaltModel::getModelName() const {
254 return m_model_name;
255 }
256
257} // namespace FlavorTagInference
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
if(pathvar)
InferenceOutput runInference(std::map< std::string, Inputs > &gnn_inputs) const override
const std::string & getModelName() const override
std::vector< ConstituentNode > m_constituent_nodes
SaltModelGraphConfig::GraphConfig m_graph_config
std::vector< std::string > m_jet_output_names
Scalar jet variables: output names (matching by index).
PassThroughSaltModel(const nlohmann::json &config)
Construct from a JSON config with the format: { "model_name": "PassThrough", "jet_variables": [ {"inp...
const SaltModelGraphConfig::GraphConfig getGraphConfig() const override
const OutputConfig & getOutputConfig() const override
SaltModelVersion getSaltModelVersion() const override
std::vector< std::string > m_jet_input_names
Scalar jet variables: input names (graph config order).
Definition node.h:24
This file contains "getter" functions used for accessing tagger inputs from the EDM.
std::vector< SaltModelOutput > OutputConfig
Definition ISaltModel.h:36