10 #include "CaloGeoHelpers/CaloSampling.h" 
   23       return StatusCode::SUCCESS;
 
   // Set up the ONNX Runtime session used by this tool.
   // NOTE(review): this listing carries embedded original line numbers and omits
   // some lines, so only the statements that are visible here are described.
   28    Ort::SessionOptions session_options;
 
   // Default-options allocator; the visible code later passes it to
   // GetInputNameAllocated / GetOutputNameAllocated.
   29    Ort::AllocatorWithDefaultOptions allocator;
 
   // Restrict the session to a single intra-op thread.
   30    session_options.SetIntraOpNumThreads(1);
 
   // Request the ORT_ENABLE_BASIC graph-optimization level.
   31    session_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
 
   // Create the session from `path` with the environment returned by m_svc->env()
   // and keep it in m_session (owned through std::unique_ptr).
   32    m_session = std::make_unique<Ort::Session>(
m_svc->env(), 
path.c_str(), session_options);
 
   // Walk over every input node of the session and report its name and element type.
   36     size_t num_input_nodes = m_session->GetInputCount();
 
   39     for (std::size_t 
i = 0; 
i < num_input_nodes; 
i++) {
 
   // release() detaches the name buffer from the smart pointer returned by
   // GetInputNameAllocated, so nothing frees it automatically.
   // NOTE(review): the visible lines never free input_name — confirm it is stored
   // (and eventually released) in the lines omitted from this listing.
   41         char* input_name = m_session->GetInputNameAllocated(
i, allocator).release();
 
   // Tail of a streamed log message (its opening line is not visible here).
   43                               << 
" name= " << input_name);
 
   // Query the element type of input i.
   46         Ort::TypeInfo type_info = m_session->GetInputTypeInfo(
i);
 
   47         auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
 
   48         ONNXTensorElementDataType 
type = tensor_info.GetElementType();
 
   // Tail of a second streamed log message carrying the element type.
   50                               << 
" type= " << 
type);
 
   // Walk over every output node of the session and report its name, element type
   // and shape. output_node_dims is overwritten on each iteration.
   63     std::vector<int64_t> output_node_dims;
 
   64     size_t num_output_nodes = m_session->GetOutputCount();
 
   68     for (std::size_t 
i = 0; 
i < num_output_nodes; 
i++) {
 
   // release() detaches the name buffer from the smart pointer returned by
   // GetOutputNameAllocated, so nothing frees it automatically.
   // NOTE(review): the visible lines never free output_name — confirm it is stored
   // (and eventually released) in the lines omitted from this listing.
   70         char* output_name = m_session->GetOutputNameAllocated(
i, allocator).release();
 
   // Tail of a streamed log message (its opening line is not visible here).
   72                                << 
" name= " << output_name);
 
   // Query the element type of output i.
   75         Ort::TypeInfo type_info = m_session->GetOutputTypeInfo(
i);
 
   76         auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
 
   77         ONNXTensorElementDataType 
type = tensor_info.GetElementType();
 
   // Tail of a second streamed log message carrying the element type.
   79                                << 
" type= " << 
type);
 
   // Fetch the declared shape of output i and log its rank.
   82         output_node_dims = tensor_info.GetShape();
 
   83         ATH_MSG_INFO(
"Output " << 
i << 
" : num_dims= " << output_node_dims.size());
 
   84         for (std::size_t j = 0; j < output_node_dims.size(); j++) {
 
   // Negative extents are replaced by 1 before being logged
   // (presumably these mark dynamic dimensions — TODO confirm).
   85             if (output_node_dims[j] < 0) output_node_dims[j] = 1;
 
   86             ATH_MSG_INFO(
"Output" << 
i << 
" : dim " << j << 
"= " << output_node_dims[j]);
 
   // Report success to the caller once the visible inspection is done.
   90     return StatusCode::SUCCESS;
 
   // Prepare the input tensor for one inference call and read back the first output value.
   // NOTE(review): the Run() call and the remaining CreateTensor arguments are not
   // visible in this listing.
   // CPU memory descriptor (arena allocator, default memory type).
   96     auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
 
   // Element count of the caller-supplied `tensor`.
   97     auto input_tensor_size = tensor.size();
 
   // Build a float tensor over tensor.data(); presumably this overload wraps the
   // existing buffer rather than copying it, so `tensor` must outlive the call — TODO confirm.
   99     Ort::Value input_tensor =
 
  100         Ort::Value::CreateTensor<float>(memory_info, tensor.data(), input_tensor_size,
 
   // Read the float payload of the first output tensor.
  106     const float *output_score_array = output_tensors.front().GetTensorData<
float>();
 
   // Only element 0 of that payload is used as the score.
  109     float output_score = output_score_array[0];
 
  // Lookup tables keyed by an integer calorimeter-sample number.
  // calo_numbers lists the 19 sample numbers handled below (4 is absent from the list).
  136      constexpr std::array<int,19> calo_numbers{1,2,3,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20};
 
  // fixed_r_numbers[k] is paired with fixed_r_vals[k]; the values are later read as
  // rPerp through r_calo_dict (units are not stated in this listing).
  137      constexpr std::array<int,12> fixed_r_numbers = {1,2,3,12,13,14,15,16,17,18,19,20};
 
  138      constexpr std::array<double,12> fixed_r_vals = {1532.18, 1723.89, 1923.02, 2450.00, 2995.00, 3630.00, 3215.00,
 
  139                                         3630.00, 2246.50, 2450.00, 2870.00, 3480.00
 
  // fixed_z_numbers[k] is paired with fixed_z_vals[k]; the values are later read as
  // `z` through z_calo_dict.
  141      constexpr std::array<int, 7> fixed_z_numbers = {5,6,7,8,9,10,11};
 
  142      constexpr std::array<double, 7> fixed_z_vals = {3790.03, 3983.68, 4195.84, 4461.25, 4869.50, 5424.50, 5905.00};
 
  // Runtime maps from sample number to its fixed r / fixed z value.
  // Together the two key sets cover every entry of calo_numbers.
  143      std::unordered_map<int, double> r_calo_dict;
 
  144      std::unordered_map<int, double> z_calo_dict;
 
  // Fill r_calo_dict (both arrays have 12 entries).
  145      for(
size_t i=0; 
i<fixed_r_vals.size(); 
i++) r_calo_dict[fixed_r_numbers[
i]] = fixed_r_vals[
i];
 
  // Fill z_calo_dict (both arrays have 7 entries).
  146      for(
size_t i=0; 
i<fixed_z_numbers.size(); 
i++) z_calo_dict[fixed_z_numbers[
i]] = fixed_z_vals[
i];
 
  // Select the clusters linked to the track that will feed the network input.
  // inputnn is the flat network input: 5430 floats, all initialised to zero.
  148      std::vector<float> inputnn;
 
  149      inputnn.assign(5430, 0.0);
 
  // Clusters that pass the selection below.
  150      std::vector<eflowRecCluster*> matchedClusters;
 
  // All track-cluster links reported by `ptr`.
  151      std::vector<eflowTrackClusterLink*> 
links = 
ptr->getClusterMatches();
 
  // (eta, phi) of the track taken from ptr->getTrack().
  156      const std::array<double, 2> 
track{
ptr->getTrack()->eta(), 
ptr->getTrack()->phi()};
 
  158      for(
auto *clink : 
links){
 
  // Despite the name, `cell` is the cluster object behind the link.
  159         auto *
cell = clink->getCluster()->getCluster();
 
  // Cluster energy scaled by 1e-3 (presumably MeV -> GeV — TODO confirm).
  160         float clusterE = 
cell->e()*1
e-3;
 
  161         float clusterEta = 
cell->eta();
 
  // Skip clusters with negative energy, scaled energy above 1e4, or |eta| > 2.5.
  163         if (clusterE < 0.0 || clusterE > 1e4f || std::abs(clusterEta) > 2.5) 
continue;
 
  // Compile-time switch, currently false.
  // NOTE(review): the lines that use cutOnR are not visible; the R >= 1.2 rejection
  // below is presumably guarded by it and therefore inactive — confirm.
  165         constexpr 
bool cutOnR = 
false;
 
  // (eta, phi) of the cluster.
  167             std::array<double, 2> 
p{clink->getCluster()->getCluster()->eta(), clink->getCluster()->getCluster()->phi()};
 
  // Reject when R is at least 1.2 (the computation of R is not visible here).
  175             if(
R >= 1.2) 
continue;
 
  // The cluster survived every visible cut: keep it.
  178         matchedClusters.push_back(clink->getCluster());
 
  // Choose the reference centre (eta_ctr, phi_ctr) against which cells and track
  // projections are later expressed.
  // Rows collected for the network; each row holds five doubles.
  181      std::vector<std::array<double, 5>> 
cells;
 
  // Validity flags for the two entries of trk_em_eta / trk_em_phi
  // (those arrays are filled in lines not visible here).
  184     bool trk_bool_em[2] = {
false,
false};
 
  // An entry counts as valid when |eta| < 2.5 and |phi| <= pi.
  189     for(
int i =0; 
i<2; 
i++) {
 
  190         trk_bool_em[
i] = std::abs(trk_em_eta[
i]) < 2.5 && std::abs(trk_em_phi[
i]) <= 
M_PI;
 
  // Number of valid entries (0, 1 or 2).
  192     int nProj_em = (
int)trk_bool_em[0] + (
int)trk_bool_em[1];
 
  // Branch taken before the `else if` below (its condition is not visible;
  // presumably nProj_em == 1): use whichever single entry is valid.
  195         eta_ctr = trk_bool_em[0] ? trk_em_eta[0] : trk_em_eta[1];
 
  196         phi_ctr = trk_bool_em[0] ? trk_em_phi[0] : trk_em_phi[1];
 
  // Both entries valid: use their arithmetic mean.
  // NOTE(review): a plain mean of two phi values ignores the wrap-around at +-pi
  // (e.g. 3.1 and -3.1 average to 0) — confirm this is acceptable.
  197     } 
else if(nProj_em==2) {
 
  198         eta_ctr = (trk_em_eta[0] + trk_em_eta[1]) / 2.0;
 
  199         phi_ctr = (trk_em_phi[0] + trk_em_phi[1]) / 2.0;
 
  // Remaining branch (its opening line is not visible): fall back to the track's own eta/phi.
  201         eta_ctr = 
ptr->getTrack()->eta();
 
  202         phi_ctr = 
ptr->getTrack()->phi();
 
  // Turn the cells of every selected cluster into rows of `cells`.
  207      for(
auto *cptr : matchedClusters){
 
  // Underlying cluster object whose cell range is iterated below.
  208         auto *clustlink = cptr->getCluster();
 
  210         for(
auto it_cell = clustlink->cell_begin(); it_cell != clustlink->cell_end(); it_cell++){
 
  // Weighted cell energy scaled by 1e-3 (presumably MeV -> GeV — TODO confirm).
  // `cell` is declared in a line that is not visible in this listing.
  212            float cellE = 
cell->e()*(it_cell.weight())*1
e-3
f;
 
  // Ignore cells whose scaled energy is below 0.005.
  213            if(cellE < 0.005) 
continue;
 
  // Detector-description element of the cell; dereferenced below without a visible null check.
  214            const auto *theDDE=it_cell->caloDDE();
 
  215            double cx=theDDE->x();
 
  216            double cy=theDDE->y();
 
  // Row layout as far as visible: [energy, eta - eta_ctr, phi - phi_ctr, ...];
  // the remaining entries are on lines not shown here.
  // NOTE(review): the phi difference is not wrapped into (-pi, pi] in the visible code — confirm.
  218            cells.emplace_back( std::array<double, 5> { cellE,
 
  219                     theDDE->eta() -  eta_ctr,
 
  220                     theDDE->phi() -  phi_ctr,
 
  // Collect, per calorimeter sample, the track projection [eta, phi, rPerp, sample number].
  // trk_bool and trk_full have one slot per entry of calo_numbers.
  227     std::vector<bool> trk_bool(calo_numbers.size(), 
false);
 
  228     std::vector<std::array<double,4>> trk_full(calo_numbers.size());
 
  // NOTE(review): the loop is bounded by phitotal.size() but indexes calo_numbers,
  // etatotal and trk_full — this assumes phitotal.size() <= calo_numbers.size() and
  // etatotal.size() >= phitotal.size(); confirm where those vectors are filled.
  229     for(
size_t j=0; j<phitotal.size(); j++) {
 
  230         int cnum = calo_numbers[j];
 
  231         double eta = etatotal[j];
 
  232         double phi = phitotal[j];
 
  // Same validity window as elsewhere in this listing: |eta| < 2.5 and |phi| <= pi.
  233         if(std::abs(
eta) < 2.5 && std::abs(
phi) <= 
M_PI) {
 
  235             trk_full[j][0] = 
eta;
 
  236             trk_full[j][1] = 
phi;
 
  237             trk_full[j][3] = cnum;
 
  // Sentinel that is overwritten when the sample has a fixed r value.
  238             double rPerp =-99999;
 
  // Samples with a fixed r value take it directly from r_calo_dict.
  239             if(
auto itr = r_calo_dict.find(cnum); itr != r_calo_dict.end()) rPerp = itr->second;
 
  // Samples with a fixed z value start from z and |eta|; the expression that turns
  // them into rPerp is on lines not visible here.
  240             else if(
auto itr = z_calo_dict.find(cnum); itr != z_calo_dict.end())
 
  242                 double z = itr->second;
 
  244                    double aeta = std::abs(
eta);
 
  // A sample number present in neither map is reported with std::runtime_error.
  248                 throw std::runtime_error(
"Calo sample num not found in dicts..");
 
  250             trk_full[j][2] = rPerp;
 
  // Zero the whole slot (presumably the branch for projections outside the
  // validity window — the enclosing `else` is not visible).
  252             trk_full[j].fill(0.0);
 
  // Append the track itself to `cells` as one or more pseudo-cell rows.
  // |1 / (q/p)| scaled by 1e-3 (presumably MeV -> GeV — TODO confirm).
  // NOTE(review): qOverP() == 0 is not guarded in the visible code and would give an infinite value.
  255     double trackP = std::abs(1. / 
ptr->getTrack()->qOverP()) * 1
e-3;
 
  // Count of true entries in trk_bool.
  256     int trk_proj_num = 
std::accumulate(trk_bool.begin(), trk_bool.end(), 0);
 
  // No flagged projection: emit a single row built from the track's own eta/phi.
  257     if(trk_proj_num ==0) {
 
  // Value-initialised, so entries not assigned in the visible lines stay 0.
  259         std::array<double,5> trk_arr{};
 
  262         trk_arr[1] = 
ptr->getTrack()->eta() - eta_ctr;
 
  263         trk_arr[2] = 
ptr->getTrack()->phi() - phi_ctr;
 
  // 1532.18 equals fixed_r_vals[0], the value paired with sample number 1.
  264         trk_arr[3] = 1532.18; 
 
  267         cells.emplace_back(trk_arr);
 
  // Otherwise (the `else` line is not visible): one row per flagged projection.
  269         for(
size_t i =0; 
i<calo_numbers.size(); 
i++) {
 
  270             if(!trk_bool[
i]) 
continue;
 
  271             std::array<double,5> trk_arr{};
 
  // The track momentum is shared equally between the flagged projections.
  272             trk_arr[0]= trackP/
double(trk_proj_num);
 
  273             trk_arr[1]= trk_full[
i][0] - eta_ctr;
 
  274             trk_arr[2]= trk_full[
i][1] - phi_ctr;
 
  // rPerp stored for this sample in trk_full.
  275             trk_arr[3]= trk_full[
i][2];
 
  278             cells.emplace_back(trk_arr);
 
  // Walk over the collected rows to fill the flat network input.
  // NOTE(review): the copy statements and the declaration of `index` are not visible in this listing.
  283     for(
auto &in : 
cells){
 
  // Guard taken once `index` reaches inputnn.size() - 4 (its body is not visible;
  // presumably it stops filling before the buffer would overflow — confirm).
  286       if(
index >= 
static_cast<int>(inputnn.size()-4)) {
 
  // Hand the prediction back to the caller (how it is computed is not visible here).
  297     return predictedEnergy;
 
  // Rescale entries of inputnn in place, addressed as offsets i+0 .. i+3.
  // NOTE(review): the enclosing loop header (and therefore the stride of `i`) is
  // not visible in this listing.
  // Entry i+3: divided by 3630 when non-zero (3630.00 is the largest value in the
  // fixed_r_vals table seen elsewhere in this listing).
  304       auto &
f = inputnn[
i+3];
 
  305       if(
f!= 0.0
f) 
f/= 3630.f;
 
  // Entry i+0: bound by reference; its rescaling is on lines not shown here.
  306       auto &
e = inputnn[
i+0];
 
  // Entry i+1: divided by 0.7 when non-zero.
  311       auto &
eta = inputnn[
i+1];
 
  312       if(
eta!= 0.0) 
eta /= 0.7f;
 
  // Entry i+2: bound by reference; its rescaling is on lines not shown here.
  313       auto &
phi = inputnn[
i+2];
 
  // Consistency check after the loop.
  // NOTE(review): `>` does not fire when i == inputnn.size(); confirm against the
  // loop's increment and exit condition that this is the intended test.
  316     if(
i> inputnn.size()){
 
  325   return StatusCode::SUCCESS;