23 return StatusCode::SUCCESS;
// Fragment: creates the ONNX Runtime session stored in m_session and logs
// name/type information for every model input and output.
// NOTE(review): the enclosing function signature, the closing braces and several
// interior lines are not visible in this excerpt.
28 Ort::SessionOptions session_options;
29 Ort::AllocatorWithDefaultOptions allocator;
// Session options: one intra-op thread, basic graph optimisation level.
30 session_options.SetIntraOpNumThreads(1);
31 session_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
// The session is built from the environment returned by m_svc->env() and the model at `path`.
32 m_session = std::make_unique<Ort::Session>(
m_svc->env(), path.c_str(), session_options);
34 ATH_MSG_INFO(
"Created ONNX runtime session with model " << path);
36 size_t num_input_nodes = m_session->GetInputCount();
// Walk over all model inputs and report their name and element type.
39 for (std::size_t i = 0; i < num_input_nodes; i++) {
// NOTE(review): release() detaches the raw char* from the object returned by
// GetInputNameAllocated; no matching free is visible in this excerpt, so the
// name buffer looks leaked -- confirm, or keep the returned owner alive instead.
41 char* input_name = m_session->GetInputNameAllocated(i, allocator).release();
43 <<
" name= " << input_name);
46 Ort::TypeInfo type_info = m_session->GetInputTypeInfo(i);
47 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
48 ONNXTensorElementDataType
type = tensor_info.GetElementType();
50 <<
" type= " <<
type);
// Same report for the model outputs; output_node_dims is reused for each output's shape.
63 std::vector<int64_t> output_node_dims;
64 size_t num_output_nodes = m_session->GetOutputCount();
68 for (std::size_t i = 0; i < num_output_nodes; i++) {
// NOTE(review): same release() pattern as for the input names -- no free is
// visible here either; confirm the buffer is not leaked.
70 char* output_name = m_session->GetOutputNameAllocated(i, allocator).release();
72 <<
" name= " << output_name);
75 Ort::TypeInfo type_info = m_session->GetOutputTypeInfo(i);
76 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
77 ONNXTensorElementDataType
type = tensor_info.GetElementType();
79 <<
" type= " <<
type);
82 output_node_dims = tensor_info.GetShape();
83 ATH_MSG_INFO(
"Output " << i <<
" : num_dims= " << output_node_dims.size());
// Negative dimensions are overwritten with 1 before being logged
// (presumably these mark dynamic axes -- TODO confirm).
84 for (std::size_t j = 0; j < output_node_dims.size(); j++) {
85 if (output_node_dims[j] < 0) output_node_dims[j] = 1;
86 ATH_MSG_INFO(
"Output" << i <<
" : dim " << j <<
"= " << output_node_dims[j]);
90 return StatusCode::SUCCESS;
// Fragment: static lookup tables keyed by calorimeter sample number.
// calo_numbers lists the 19 sample numbers handled here (4 is absent from the list).
// fixed_r_numbers/fixed_r_vals pair 12 sample numbers with a fixed value each, and
// fixed_z_numbers/fixed_z_vals do the same for the remaining 7 sample numbers
// (presumably radial and z positions in mm -- TODO confirm units).
// NOTE(review): the closing line of the fixed_r_vals initializer is not visible in this excerpt.
134 constexpr std::array<int,19> calo_numbers{1,2,3,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20};
135 constexpr std::array<int,12> fixed_r_numbers = {1,2,3,12,13,14,15,16,17,18,19,20};
136 constexpr std::array<double,12> fixed_r_vals = {1532.18, 1723.89, 1923.02, 2450.00, 2995.00, 3630.00, 3215.00,
137 3630.00, 2246.50, 2450.00, 2870.00, 3480.00
139 constexpr std::array<int, 7> fixed_z_numbers = {5,6,7,8,9,10,11};
140 constexpr std::array<double, 7> fixed_z_vals = {3790.03, 3983.68, 4195.84, 4461.25, 4869.50, 5424.50, 5905.00};
// Runtime maps from sample number to the fixed value, filled index-by-index from the
// parallel arrays above. The first loop is bounded by the values array and the second
// by the numbers array; each pair of arrays has the same length, so both are in range.
141 std::unordered_map<int, double> r_calo_dict;
142 std::unordered_map<int, double> z_calo_dict;
143 for(
size_t i=0; i<fixed_r_vals.size(); i++) r_calo_dict[fixed_r_numbers[i]] = fixed_r_vals[i];
144 for(
size_t i=0; i<fixed_z_numbers.size(); i++) z_calo_dict[fixed_z_numbers[i]] = fixed_z_vals[i];
// Fragment: allocates the network input buffer and selects the clusters linked to the
// track that pass the energy, eta and distance cuts.
// inputnn is a flat buffer of 5430 floats, initialised to zero.
146 std::vector<float> inputnn;
147 inputnn.assign(5430, 0.0);
148 std::vector<eflowRecCluster*> matchedClusters;
149 const std::vector<eflowTrackClusterLink*>& links = ptr->getClusterMatches();
// Per-sample track eta/phi (19 entries each), derived from the track's calo points.
151 std::array<double, 19> etatotal =
getEtaTrackCalo(ptr->getTrackCaloPoints());
152 std::array<double, 19> phitotal =
getPhiTrackCalo(ptr->getTrackCaloPoints());
// Track direction as {eta, phi}.
154 const std::array<double, 2> track{ptr->getTrack()->eta(), ptr->getTrack()->phi()};
156 for(
auto *clink : links){
// Despite its name, `cell` holds the object returned by clink->getCluster()->getCluster().
157 auto *cell = clink->getCluster()->getCluster();
// Energy is scaled by 1e-3 (presumably MeV -> GeV -- TODO confirm).
158 float clusterE = cell->e()*1e-3;
159 float clusterEta = cell->eta();
// Skip clusters with negative energy, energy above 1e4 (scaled units) or |eta| > 2.5.
161 if (clusterE < 0.0 || clusterE > 1e4f || std::abs(clusterEta) > 2.5)
continue;
163 constexpr bool cutOnR =
false;
// Cluster {eta, phi} and its offsets from the track direction.
// NOTE(review): part2 is a plain phi difference; the lines that define R are not
// visible here, so it cannot be confirmed whether the difference is wrapped into [-pi, pi].
165 std::array<double, 2> p{clink->getCluster()->getCluster()->eta(), clink->getCluster()->getCluster()->phi()};
166 double part1 = p[0] - track[0];
167 double part2 = p[1] - track[1];
// Clusters with R >= 1.2 are dropped (R is defined on lines missing from this excerpt).
173 if(R >= 1.2)
continue;
176 matchedClusters.push_back(clink->getCluster());
// Fragment: chooses the reference point (eta_ctr, phi_ctr) that cell and track
// coordinates are later expressed relative to.
// NOTE(review): the declarations of trk_em_eta/trk_em_phi, eta_ctr/phi_ctr and the first
// branch condition are on lines missing from this excerpt.
// `cells` collects 5-component rows that are later copied into the network input.
179 std::vector<std::array<double, 5>> cells;
182 bool trk_bool_em[2] = {
false,
false};
// A projection is flagged valid when |eta| < 2.5 and |phi| <= pi.
187 for(
int i =0; i<2; i++) {
188 trk_bool_em[i] = std::abs(trk_em_eta[i]) < 2.5 && std::abs(trk_em_phi[i]) <=
M_PI;
// Number of valid projections among the two (0, 1 or 2).
190 int nProj_em = (int)trk_bool_em[0] + (
int)trk_bool_em[1];
// Takes entry 0 if it is valid, otherwise entry 1
// (presumably the nProj_em == 1 branch -- the condition line is not visible).
193 eta_ctr = trk_bool_em[0] ? trk_em_eta[0] : trk_em_eta[1];
194 phi_ctr = trk_bool_em[0] ? trk_em_phi[0] : trk_em_phi[1];
195 }
else if(nProj_em==2) {
// Two valid projections: arithmetic mean of both.
// NOTE(review): a plain mean of two phi values does not account for the +-pi wrap-around -- confirm this is acceptable.
196 eta_ctr = (trk_em_eta[0] + trk_em_eta[1]) / 2.0;
197 phi_ctr = (trk_em_phi[0] + trk_em_phi[1]) / 2.0;
// Uses the track's own eta/phi (presumably the fallback when no projection is valid -- branch keyword not visible).
199 eta_ctr = ptr->getTrack()->eta();
200 phi_ctr = ptr->getTrack()->phi();
// Fragment: appends one row to `cells` for every sufficiently energetic calorimeter cell
// of each matched cluster.
// NOTE(review): the declaration of `cell` inside the inner loop and the tail of the
// emplace_back call are on lines missing from this excerpt.
205 for(
auto *cptr : matchedClusters){
206 auto *clustlink = cptr->getCluster();
208 for(
auto it_cell = clustlink->cell_begin(); it_cell != clustlink->cell_end(); it_cell++){
// Cell energy multiplied by the iterator's weight and scaled by 1e-3
// (presumably MeV -> GeV -- TODO confirm); cells below 0.005 are skipped.
210 float cellE = cell->e()*(it_cell.weight())*1e-3f;
211 if(cellE < 0.005)
continue;
// Detector description element of the cell; cx, cy are its x and y coordinates.
212 const auto *theDDE=it_cell->caloDDE();
213 double cx=theDDE->x();
214 double cy=theDDE->y();
// Row starts with {energy, eta - eta_ctr, phi - phi_ctr, ...}; the remaining two
// components are not visible here.
216 cells.emplace_back( std::array<double, 5> { cellE,
217 theDDE->eta() - eta_ctr,
218 theDDE->phi() - phi_ctr,
// Fragment: for each calorimeter sample, records the track projection as
// {eta, phi, rPerp, sample number} in trk_full, or zeros when the projection is invalid.
// NOTE(review): braces, the else keywords and the line that sets trk_bool[j] are not
// visible in this excerpt.
225 std::vector<bool> trk_bool(calo_numbers.size(),
false);
226 std::vector<std::array<double,4>> trk_full(calo_numbers.size());
// phitotal and calo_numbers both have 19 entries, so j indexes them in parallel.
227 for(
size_t j=0; j<phitotal.size(); j++) {
228 int cnum = calo_numbers[j];
229 double eta = etatotal[j];
230 double phi = phitotal[j];
// Projection is treated as valid when |eta| < 2.5 and |phi| <= pi.
231 if(std::abs(
eta) < 2.5 && std::abs(
phi) <=
M_PI) {
233 trk_full[j][0] =
eta;
234 trk_full[j][1] =
phi;
235 trk_full[j][3] = cnum;
// rPerp comes straight from r_calo_dict when the sample is listed there; otherwise it is
// derived from the z_calo_dict value as z * 2*exp(|eta|) / (exp(2|eta|) - 1), which is
// algebraically z / sinh(|eta|). A sample found in neither map raises std::runtime_error.
// NOTE(review): the z-based expression divides by zero when eta == 0 -- confirm this cannot occur.
236 double rPerp =-99999;
237 if(
auto itr = r_calo_dict.find(cnum); itr != r_calo_dict.end()) rPerp = itr->second;
238 else if(
auto itr = z_calo_dict.find(cnum); itr != z_calo_dict.end())
240 double z = itr->second;
242 double aeta = std::abs(
eta);
243 rPerp =
z*2.*std::exp(aeta)/(std::exp(2.0*aeta)-1.0);
246 throw std::runtime_error(
"Calo sample num not found in dicts..");
248 trk_full[j][2] = rPerp;
// Zeroes the whole entry (presumably the invalid-projection branch -- else keyword not visible).
250 trk_full[j].fill(0.0);
// Fragment: appends rows describing the track itself to `cells`.
// NOTE(review): closing braces, the else keyword and the assignments to some trk_arr
// components are on lines missing from this excerpt.
// Track momentum magnitude |1 / (q/p)| scaled by 1e-3 (presumably MeV -> GeV -- TODO confirm).
253 double trackP = std::abs(1. / ptr->getTrack()->qOverP()) * 1e-3;
// Number of true entries in trk_bool.
254 int trk_proj_num = std::accumulate(trk_bool.begin(), trk_bool.end(), 0);
// No valid projection: add a single row at the track's own eta/phi relative to the centre,
// with component 3 set to 1532.18 (the same number as the first entry of fixed_r_vals).
255 if(trk_proj_num ==0) {
257 std::array<double,5> trk_arr{};
260 trk_arr[1] = ptr->getTrack()->eta() - eta_ctr;
261 trk_arr[2] = ptr->getTrack()->phi() - phi_ctr;
262 trk_arr[3] = 1532.18;
265 cells.emplace_back(trk_arr);
// Otherwise add one row per valid projection; the momentum is divided equally among them.
267 for(
size_t i =0; i<calo_numbers.size(); i++) {
268 if(!trk_bool[i])
continue;
269 std::array<double,5> trk_arr{};
270 trk_arr[0]= trackP/double(trk_proj_num);
271 trk_arr[1]= trk_full[i][0] - eta_ctr;
272 trk_arr[2]= trk_full[i][1] - phi_ctr;
273 trk_arr[3]= trk_full[i][2];
276 cells.emplace_back(trk_arr);
// Fragment: copies each 5-component row of `cells` into the flat inputnn buffer at offset
// `index`, then returns predictedEnergy.
// NOTE(review): the declaration and update of `index`, the body of the bounds check and the
// code that computes predictedEnergy are on lines missing from this excerpt.
281 for(
auto &in : cells){
282 std::copy(in.begin(), in.end(), inputnn.begin() +
index);
// Capacity guard against the buffer size (presumably stops filling once the buffer is full -- TODO confirm).
284 if(
index >=
static_cast<int>(inputnn.size()-4)) {
295 return predictedEnergy;