11 :
m_model_name(config.value(
"model_name",
"PassThrough")) {
15 const char* jet_key = config.contains(
"jet_variables")
16 ?
"jet_variables" :
"variables";
18 if (config.contains(jet_key) && config[jet_key].is_array()) {
19 SaltModelGraphConfig::InputNodeConfig input_node;
20 input_node.name =
"jets";
22 for (const auto& var : config[jet_key]) {
23 std::string input_name = var.at(
"input").get<std::string>();
24 std::string output_name = var.at(
"output").get<std::string>();
26 m_jet_input_names.push_back(input_name);
27 m_jet_output_names.push_back(output_name);
30 input_node.variables.emplace_back(input_name,0.0,1.0);
33 m_output_config.emplace_back(
35 ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
40 if (!input_node.variables.empty()) {
41 m_graph_config.inputs.push_back(std::move(input_node));
46 if (config.contains(
"constituents") && config[
"constituents"].is_array()) {
47 for (const auto& cnode : config[
"constituents"]) {
48 std::string node_name = cnode.at(
"node_name").get<std::string>();
50 if (!cnode.contains(
"variables") || !cnode[
"variables"].is_array()) {
51 throw std::runtime_error(
52 "PassThroughSaltModel: constituent node '" + node_name
53 +
"' must have a 'variables' array");
57 cn.node_name = node_name;
60 SaltModelGraphConfig::InputNodeConfig seq_node;
61 seq_node.name = node_name;
63 for (const auto& var : cnode[
"variables"]) {
64 std::string input_name = var.at(
"input").get<std::string>();
65 std::string output_name = var.at(
"output").get<std::string>();
66 std::string var_type = var.value(
"type",
"float");
67 float fp16_scale = var.value(
"scale", 1.0f);
69 cn.output_names.push_back(output_name);
70 cn.var_types.push_back(var_type);
73 seq_node.variables.emplace_back(input_name, 0.0, 1.0);
80 if (var.contains(
"cast")) {
81 const auto& cast_val = var.at(
"cast");
82 if (!cast_val.is_object()) {
83 throw std::runtime_error(
84 "PassThroughSaltModel: 'cast' for variable '" + input_name
85 +
"' must be an object {exp, man}");
87 int exp_bits = cast_val.value(
"exp", 8);
88 int man_bits = cast_val.value(
"man", 7);
89 if (exp_bits < 2 || exp_bits > 8 || man_bits < 0 || man_bits > 23) {
90 throw std::runtime_error(
91 "PassThroughSaltModel: cast {exp,man} requires 2<=exp<=8 and 0<=man<=23, got exp="
92 + std::to_string(exp_bits) +
" man=" + std::to_string(man_bits));
94 m_output_config.push_back(
95 SaltModelOutput(output_name, SaltModelOutput::OutputType::VECTRUNCFLOAT, fp16_scale, exp_bits, man_bits));
97 ONNXTensorElementDataType onnx_type;
98 if (var_type ==
"int") {
99 onnx_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
100 } else if (var_type ==
"char") {
101 onnx_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
103 onnx_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
105 m_output_config.emplace_back(output_name, onnx_type, 1);
109 cn.num_vars = cn.output_names.size();
110 if (!cnode.contains(
"input_key")) {
111 throw std::runtime_error(
112 "PassThroughSaltModel: constituent node '" + node_name
113 +
"' must define 'input_key' (e.g. \"tracks\", \"flows\", "
114 "\"hits\", \"electrons\", \"muons\", \"clusters\", \"towers\")");
116 cn.input_key = cnode.at(
"input_key").get<std::string>();
117 m_constituent_nodes.push_back(std::move(cn));
118 m_graph_config.input_sequences.push_back(std::move(seq_node));
122 if (m_jet_input_names.empty() && m_constituent_nodes.empty()) {
123 throw std::runtime_error(
124 "PassThroughSaltModel: JSON config must contain "
125 "'jet_variables'/'variables' and/or 'constituents'");
135 auto it = gnn_inputs.find(
"jets");
136 if (it == gnn_inputs.end()) {
137 throw std::runtime_error(
138 "PassThroughSaltModel: expected 'jets' in gnn_inputs");
141 const auto&
data = it->second.first;
142 const auto& shape = it->second.second;
144 if (shape.size() != 2 || shape[0] != 1) {
145 throw std::runtime_error(
146 "PassThroughSaltModel: unexpected jets shape");
149 size_t num_vars =
static_cast<size_t>(shape[1]);
151 throw std::runtime_error(
152 "PassThroughSaltModel: jets size mismatch: got "
153 + std::to_string(num_vars) +
" expected "
157 for (
size_t i = 0; i < num_vars; ++i) {
168 for (
size_t v = 0; v <
node.output_names.size(); ++v) {
169 if (
node.var_types[v] ==
"char") {
170 output.vecChar[
node.output_names[v]] = {};
172 output.vecFloat[
node.output_names[v]] = {};
177 auto it = gnn_inputs.find(cn.input_key);
178 if (it == gnn_inputs.end()) {
183 std::string available;
184 for (
const auto& kv : gnn_inputs) {
185 if (!available.empty()) available +=
", ";
186 available +=
"'" + kv.first +
"'";
188 throw std::runtime_error(
189 "PassThroughSaltModel: constituent node '" + cn.node_name
190 +
"' has input_key='" + cn.input_key
191 +
"' but no loader registered under that key. Available keys: ["
192 + available +
"]. Check ConstituentsInputConfig.output_name in "
193 +
"ConstituentsLoader.cxx.");
196 const auto&
data = it->second.first;
197 const auto& shape = it->second.second;
200 int64_t n_constituents = (shape.size() >= 1) ? shape[0] : 0;
201 int64_t n_vars = (shape.size() >= 2) ? shape[1] : 0;
203 if (n_constituents == 0 || n_vars == 0) {
208 if (
static_cast<size_t>(n_vars) != cn.num_vars) {
209 throw std::runtime_error(
210 "PassThroughSaltModel: constituent '" + cn.node_name
211 +
"' vars mismatch: got " + std::to_string(n_vars)
212 +
" expected " + std::to_string(cn.num_vars));
219 for (
size_t v = 0; v < cn.num_vars; ++v) {
220 if (cn.var_types[v] ==
"char") {
221 std::vector<char> col(n_constituents);
222 for (int64_t c = 0; c < n_constituents; ++c) {
223 col[c] =
static_cast<char>(
data[c * n_vars + v]);
225 output.vecChar[cn.output_names[v]] = std::move(col);
227 std::vector<float> col(n_constituents);
228 for (int64_t c = 0; c < n_constituents; ++c) {
229 col[c] =
data[c * n_vars + v];
231 output.vecFloat[cn.output_names[v]] = std::move(col);