11 :
m_model_name(config.value(
"model_name",
"PassThrough")) {
15 const char* jet_key = config.contains(
"jet_variables")
16 ?
"jet_variables" :
"variables";
18 if (config.contains(jet_key) && config[jet_key].is_array()) {
19 SaltModelGraphConfig::InputNodeConfig input_node;
20 input_node.name =
"jets";
22 for (const auto& var : config[jet_key]) {
23 std::string input_name = var.at(
"input").get<std::string>();
24 std::string output_name = var.at(
"output").get<std::string>();
26 m_jet_input_names.push_back(input_name);
27 m_jet_output_names.push_back(output_name);
30 input_node.variables.emplace_back(input_name,0.0,1.0);
33 m_output_config.emplace_back(
35 ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
40 if (!input_node.variables.empty()) {
41 m_graph_config.inputs.push_back(std::move(input_node));
46 if (config.contains(
"constituents") && config[
"constituents"].is_array()) {
47 for (const auto& cnode : config[
"constituents"]) {
48 std::string node_name = cnode.at(
"node_name").get<std::string>();
50 if (!cnode.contains(
"variables") || !cnode[
"variables"].is_array()) {
51 throw std::runtime_error(
52 "PassThroughSaltModel: constituent node '" + node_name
53 +
"' must have a 'variables' array");
57 cn.node_name = node_name;
60 SaltModelGraphConfig::InputNodeConfig seq_node;
61 seq_node.name = node_name;
63 for (const auto& var : cnode[
"variables"]) {
64 std::string input_name = var.at(
"input").get<std::string>();
65 std::string output_name = var.at(
"output").get<std::string>();
66 std::string var_type = var.value(
"type",
"float");
67 float fp16_scale = var.value(
"scale", 1.0f);
69 cn.output_names.push_back(output_name);
70 cn.var_types.push_back(var_type);
73 seq_node.variables.emplace_back(input_name, 0.0, 1.0);
80 if (var.contains(
"cast")) {
81 const auto& cast_val = var.at(
"cast");
82 if (!cast_val.is_object()) {
83 throw std::runtime_error(
84 "PassThroughSaltModel: 'cast' for variable '" + input_name
85 +
"' must be an object {exp, man}");
87 int exp_bits = cast_val.value(
"exp", 8);
88 int man_bits = cast_val.value(
"man", 7);
89 if (exp_bits < 2 || exp_bits > 8 || man_bits < 0 || man_bits > 23) {
90 throw std::runtime_error(
91 "PassThroughSaltModel: cast {exp,man} requires 2<=exp<=8 and 0<=man<=23, got exp="
92 + std::to_string(exp_bits) +
" man=" + std::to_string(man_bits));
94 m_output_config.push_back(
95 SaltModelOutput(output_name, SaltModelOutput::OutputType::VECTRUNCFLOAT, fp16_scale, exp_bits, man_bits));
97 ONNXTensorElementDataType onnx_type;
98 if (var_type ==
"int") {
99 onnx_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
100 } else if (var_type ==
"char") {
101 onnx_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
103 onnx_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
105 m_output_config.emplace_back(output_name, onnx_type, 1);
109 cn.num_vars = cn.output_names.size();
110 if (!cnode.contains(
"input_key")) {
111 throw std::runtime_error(
112 "PassThroughSaltModel: constituent node '" + node_name
113 +
"' must define 'input_key' (e.g. \"tracks\", \"flows\", "
114 "\"hits\", \"electrons\", \"muons\", \"clusters\", \"towers\")");
116 cn.input_key = cnode.at(
"input_key").get<std::string>();
117 m_constituent_nodes.push_back(std::move(cn));
118 m_graph_config.input_sequences.push_back(std::move(seq_node));
122 if (m_jet_input_names.empty() && m_constituent_nodes.empty()) {
123 throw std::runtime_error(
124 "PassThroughSaltModel: JSON config must contain "
125 "'jet_variables'/'variables' and/or 'constituents'");
130 std::map<std::string, Inputs>& gnn_inputs)
const
136 auto it = gnn_inputs.find(
"jets");
137 if (it == gnn_inputs.end()) {
138 throw std::runtime_error(
139 "PassThroughSaltModel: expected 'jets' in gnn_inputs");
142 const auto&
data = it->second.first;
143 const auto& shape = it->second.second;
145 if (shape.size() != 2 || shape[0] != 1) {
146 throw std::runtime_error(
147 "PassThroughSaltModel: unexpected jets shape");
150 size_t num_vars =
static_cast<size_t>(shape[1]);
152 throw std::runtime_error(
153 "PassThroughSaltModel: jets size mismatch: got "
154 + std::to_string(num_vars) +
" expected "
158 for (
size_t i = 0; i < num_vars; ++i) {
169 for (
size_t v = 0; v <
node.output_names.size(); ++v) {
170 if (
node.var_types[v] ==
"char") {
171 output.vecChar[
node.output_names[v]] = {};
173 output.vecFloat[
node.output_names[v]] = {};
178 auto it = gnn_inputs.find(cn.input_key);
179 if (it == gnn_inputs.end()) {
184 std::string available;
185 for (
const auto& kv : gnn_inputs) {
186 if (!available.empty()) available +=
", ";
187 available +=
"'" + kv.first +
"'";
189 throw std::runtime_error(
190 "PassThroughSaltModel: constituent node '" + cn.node_name
191 +
"' has input_key='" + cn.input_key
192 +
"' but no loader registered under that key. Available keys: ["
193 + available +
"]. Check ConstituentsInputConfig.output_name in "
194 +
"ConstituentsLoader.cxx.");
197 const auto&
data = it->second.first;
198 const auto& shape = it->second.second;
201 int64_t n_constituents = (shape.size() >= 1) ? shape[0] : 0;
202 int64_t n_vars = (shape.size() >= 2) ? shape[1] : 0;
204 if (n_constituents == 0 || n_vars == 0) {
209 if (
static_cast<size_t>(n_vars) != cn.num_vars) {
210 throw std::runtime_error(
211 "PassThroughSaltModel: constituent '" + cn.node_name
212 +
"' vars mismatch: got " + std::to_string(n_vars)
213 +
" expected " + std::to_string(cn.num_vars));
220 for (
size_t v = 0; v < cn.num_vars; ++v) {
221 if (cn.var_types[v] ==
"char") {
222 std::vector<char> col(n_constituents);
223 for (int64_t c = 0; c < n_constituents; ++c) {
224 col[c] =
static_cast<char>(
data[c * n_vars + v]);
226 output.vecChar[cn.output_names[v]] = std::move(col);
228 std::vector<float> col(n_constituents);
229 for (int64_t c = 0; c < n_constituents; ++c) {
230 col[c] =
data[c * n_vars + v];
232 output.vecFloat[cn.output_names[v]] = std::move(col);