ATLAS Offline Software
TFCSMLCalorimeterSimulator.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
7 #include "CLHEP/Random/RandGauss.h"
8 #include <chrono>
9 
10 
11 
13 
15 
17  // Load the simulator
18  try {
20  } catch (std::exception &e) {
21  ATH_MSG_ERROR("Failed to load simulator from file " << filename << " with error " << e.what());
22  return false;
23  }
24 
25  if (m_onnx_model == nullptr) {
26  ATH_MSG_ERROR("Failed to load simulator from file " << filename);
27  return false;
28  }
29 
30  return true;
31 }
32 
34 
35  // Bring eta into the needed range
36  eta = std::abs(eta) * 10;
37 
38  // Initialize the energy and eta input vectors
39  std::vector<float> eta_vector(m_nEvents, eta);
40  std::vector<float> energy_vector(m_nEvents, energy);
41 
42  // sample the z vectors according to a standard normal distribution
43  std::vector<float> z_shape_vector(m_nEvents * m_nVoxels, 0.0);
44  std::vector<float> z_energy_vector(m_nEvents * m_nLayers, 0.0);
45  for (auto& z_shape : z_shape_vector) {
46  z_shape = CLHEP::RandGauss::shoot(simulstate.randomEngine(), 0.0, 1.0);
47  }
48  for (auto& z_energy : z_energy_vector) {
49  z_energy = CLHEP::RandGauss::shoot(simulstate.randomEngine(), 0.0, 1.0);
50  }
51 
52  // Prepare the inputs for the network
54 
55  int i = 0;
56  for (float thisEta : eta_vector) {
57  inputs["inn_eta_in"].insert(std::pair<std::string, double>("variable_" + std::to_string(i), thisEta));
58  i++;
59  }
60 
61  i = 0;
62  for (float thisEnergy : energy_vector) {
63  inputs["inn_einc_in"].insert(std::pair<std::string, double>("variable_" + std::to_string(i), thisEnergy));
64  i++;
65  }
66 
67  i = 0;
68  for (float z_shape : z_shape_vector) {
69  inputs["cfm_z_shape"].insert(std::pair<std::string, double>("variable_" + std::to_string(i), z_shape));
70  i++;
71  }
72 
73  i = 0;
74  for (float z_energy : z_energy_vector) {
75  inputs["inn_z_energy"].insert(std::pair<std::string, double>("variable_" + std::to_string(i), z_energy));
76  i++;
77  }
78 
79  // Compute the network outputs
81 
82  return outputs;
83 }
84 
86 
87  // Get the voxel energies
89 
90  // check if the output contains a nan
91  // If yes: retry up to 5 times
92  float first_output = outputs.begin()->second;
93  bool contains_nan = std::isnan(first_output);
94  if (contains_nan) {
95  int retry = 0;
96  while (contains_nan) {
97 
98  if (retry > 5) {
99  ATH_MSG_WARNING("Network output contains NaN. Giving up.");
100  break;
101  }
102 
103  ATH_MSG_WARNING("Network output contains NaN. Retrying.");
104  outputs = predictVoxels(simulstate, eta, energy);
105  first_output = outputs.begin()->second;
106  contains_nan = std::isnan(first_output);
107 
108  retry++;
109  }
110  }
111 
112  // Fill the event structure with the voxel energies
113  std::vector<unsigned int> bin_index_vector;
114  std::vector<float> E_vector;
115 
116  event_t event;
117 
118  long unsigned int layer_index = 0;
119  long unsigned int layer = m_used_layers.at(layer_index);
120 
121  for (long unsigned int voxel_index = 0; voxel_index < m_nVoxels; ++voxel_index) {
122 
123 
124  if (voxel_index == m_layer_boundaries[layer_index+1]) {
125  layer_index = layer_index + 1;
126  layer = m_used_layers.at(layer_index);
127  }
128 
129  float voxel_energy = outputs[std::to_string(voxel_index)];
130 
131  if (voxel_energy > 0) {
132  if (event.event_data.size() <= layer) {
133  event.event_data.resize(layer+1);
134  }
135  event.event_data.at(layer).bin_index_vector.push_back(voxel_index - m_layer_boundaries[layer_index]);
136 
137  // We need energy fractions, not MeV values
138  event.event_data.at(layer).E_vector.push_back(voxel_energy/energy);
139  }
140 
141  }
142 
143  return event;
144 
145 }
146 
148 
149  // For testing...
150  // This function sets the input dimensionality and the number of predicted layers
151  // to work with the currently best photon ML simulation model. 382 voxels spanned over the
152  // presampler, the three EMB layers and the first HCAL layer.
153  // This allows for easier testing calls.
154 
155  int nEvents = 1;
156  int nVoxels = 382;
157  int nLayers= 5;
158 
159  std::vector<float> eta_vector(nEvents, 2.0);
160  std::vector<float> energy_vector(nEvents, 65536.0);
161  std::vector<float> z_shape_vector(nEvents*nVoxels, 0.5);
162  std::vector<float> z_energy_vector(nEvents*nLayers, 0.5);
163 
165 
166  int i = 0;
167  for (float eta : eta_vector) {
168  inputs["inn_eta_in"].insert(std::pair<std::string, double>("variable_" + std::to_string(i), eta));
169  i++;
170  }
171 
172  i = 0;
173  for (float energy : energy_vector) {
174  inputs["inn_einc_in"].insert(std::pair<std::string, double>("variable_" + std::to_string(i), energy));
175  i++;
176  }
177 
178  i = 0;
179  for (float z_shape : z_shape_vector) {
180  inputs["cfm_z_shape"].insert(std::pair<std::string, double>("variable_" + std::to_string(i), z_shape));
181  i++;
182  }
183 
184  i = 0;
185  for (float z_energy : z_energy_vector) {
186  inputs["inn_z_energy"].insert(std::pair<std::string, double>("variable_" + std::to_string(i), z_energy));
187  i++;
188  }
189 
191 
193 
195 
196  return outputs;
197 
198 
199 }
200 
201 
203  ATH_MSG_INFO("ONNX AICalorimeterSimulator");
204 }
nEvents
const int nEvents
Definition: fbtTestBasics.cxx:78
AllowedVariables::e
e
Definition: AsgElectronSelectorTool.cxx:37
VNetworkBase::NetworkOutputs
std::map< std::string, double > NetworkOutputs
Format for network outputs.
Definition: VNetworkBase.h:100
TFCSMLCalorimeterSimulator::m_onnx_model
std::unique_ptr< VNetworkBase > m_onnx_model
Definition: TFCSMLCalorimeterSimulator.h:54
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
VNetworkBase::representNetworkOutputs
static std::string representNetworkOutputs(NetworkOutputs const &outputs, int maxValues=3)
String representation of network outputs.
Definition: VNetworkBase.cxx:57
eta
Scalar eta() const
pseudorapidity method
Definition: AmgMatrixBasePlugin.h:83
TFCSMLCalorimeterSimulator::TFCSMLCalorimeterSimulator
TFCSMLCalorimeterSimulator()
Definition: TFCSMLCalorimeterSimulator.cxx:12
TRT::Track::event
@ event
Definition: InnerDetector/InDetCalibEvent/TRT_CalibData/TRT_CalibData/TrackInfo.h:74
VNetworkBase::NetworkInputs
std::map< std::string, std::map< std::string, double > > NetworkInputs
Format for network inputs.
Definition: VNetworkBase.h:90
TFCSMLCalorimeterSimulator::getEvent
event_t getEvent(TFCSSimulationState &simulstate, float eta, float energy) const
Definition: TFCSMLCalorimeterSimulator.cxx:85
TFCSSimulationState::randomEngine
CLHEP::HepRandomEngine * randomEngine()
Definition: TFCSSimulationState.h:36
MuonR4::to_string
std::string to_string(const SectorProjector proj)
Definition: MsTrackSeeder.cxx:66
TFCSNetworkFactory::create
static std::unique_ptr< VNetworkBase > create(std::string input)
Given a string, make a network.
Definition: TFCSNetworkFactory.cxx:66
TFCSMLCalorimeterSimulator::loadSimulator
bool loadSimulator(std::string &filename)
Definition: TFCSMLCalorimeterSimulator.cxx:16
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
TFCSMLCalorimeterSimulator.h
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
TFCSMLCalorimeterSimulator::m_nVoxels
long unsigned int m_nVoxels
Definition: TFCSMLCalorimeterSimulator.h:62
ParticleGun_FastCalo_ChargeFlip_Config.energy
energy
Definition: ParticleGun_FastCalo_ChargeFlip_Config.py:78
lumiFormat.i
int i
Definition: lumiFormat.py:85
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
TRT::Hit::layer
@ layer
Definition: HitInfo.h:79
TFCSMLCalorimeterSimulator::event_t
Definition: TFCSMLCalorimeterSimulator.h:33
calibdata.exception
exception
Definition: calibdata.py:495
TFCSMLCalorimeterSimulator::~TFCSMLCalorimeterSimulator
virtual ~TFCSMLCalorimeterSimulator()
Definition: TFCSMLCalorimeterSimulator.cxx:14
TFCSMLCalorimeterSimulator::m_nLayers
long unsigned int m_nLayers
Definition: TFCSMLCalorimeterSimulator.h:63
TFCSMLCalorimeterSimulator::m_used_layers
std::vector< long unsigned int > m_used_layers
Definition: TFCSMLCalorimeterSimulator.h:61
python.CreateTierZeroArgdict.outputs
outputs
Definition: CreateTierZeroArgdict.py:189
TFCSMLCalorimeterSimulator::m_nEvents
int m_nEvents
Definition: TFCSMLCalorimeterSimulator.h:56
VNetworkBase::representNetworkInputs
static std::string representNetworkInputs(NetworkInputs const &inputs, int maxValues=3)
String representation of network inputs.
Definition: VNetworkBase.cxx:37
ATH_MSG_WARNING
#define ATH_MSG_WARNING(x)
Definition: AthMsgStreamMacros.h:32
CaloCellTimeCorrFiller.filename
filename
Definition: CaloCellTimeCorrFiller.py:23
TFCSMLCalorimeterSimulator::Print
void Print() const
Definition: TFCSMLCalorimeterSimulator.cxx:202
TFCSMLCalorimeterSimulator::predictVoxels
VNetworkBase::NetworkOutputs predictVoxels() const
Definition: TFCSMLCalorimeterSimulator.cxx:147
TFCSNetworkFactory.h
TFCSMLCalorimeterSimulator::m_layer_boundaries
std::vector< long unsigned int > m_layer_boundaries
Definition: TFCSMLCalorimeterSimulator.h:60
TFCSSimulationState
Definition: TFCSSimulationState.h:32