ATLAS Offline Software
LArNNRawChannelBuilder.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2024-2025 CERN for the benefit of the ATLAS collaboration
3  */
4 
7 
8 #include "GaudiKernel/SystemOfUnits.h"
11 
17 
18 #include <onnxruntime_cxx_api.h>
22 //
23 #include <cmath>
24 #include <map>
25 #include <vector>
26 #include <fstream>
27 #include <sstream>
28 #include <typeinfo>
29 
30 using namespace cool;
31 
33  ATH_CHECK(m_digitKey.initialize());
34  ATH_CHECK(m_rawChannelKey.initialize());
35  ATH_CHECK(m_pedestalKey.initialize());
36  ATH_CHECK(m_adc2MeVKey.initialize());
37  ATH_CHECK(m_cablingKey.initialize() );
38  ATH_CHECK(m_ofcKey.initialize());
39  ATH_CHECK(m_shapeKey.initialize());
40  ATH_CHECK(m_run1DSPThresholdsKey.initialize(SG::AllowEmpty) );
41  ATH_CHECK(m_run2DSPThresholdsKey.initialize(SG::AllowEmpty) );
42  if (m_useDBFortQ) {
43  if (m_run1DSPThresholdsKey.empty() && m_run2DSPThresholdsKey.empty()) {
44  ATH_MSG_ERROR ("useDB requested but neither Run1DSPThresholdsKey nor Run2DSPThresholdsKey initialized.");
45  return StatusCode::FAILURE;
46  }
47  }
48  ATH_CHECK(m_nnClustersDb.initialize());
49 
50  ATH_CHECK(detStore()->retrieve(m_onlineId,"LArOnlineID"));
51  ATH_CHECK(detStore()->retrieve(m_calocellID,"CaloCell_ID"));
52 
53  ATH_CHECK(m_onnxRuntimeSvc.retrieve());
54 
55  return StatusCode::SUCCESS;
56 }
57 
58 
59 StatusCode LArNNRawChannelBuilder::execute(const EventContext& ctx) const {
60  Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
61  Ort::SessionOptions session_options;
62  session_options.SetIntraOpNumThreads(1);
63 
64  std::vector<int> hashIdToCluster;
65  std::vector<std::shared_ptr<Ort::Session>> clusterToOnnx;
66 
67  const CondAttrListCollection *catr{nullptr};
68  ATH_CHECK(SG::get(catr, m_nnClustersDb, ctx));
69 
70  if(!catr){
71  ATH_MSG_ERROR("CondAttrListCollection can't be opened");
72  return StatusCode::FAILURE;
73  }
74 
75  CondAttrListCollection::const_iterator chanIt=catr->begin();
76  const coral::Blob& bls = chanIt->second["clusters"].data<coral::Blob>();
77  const unsigned char* blobData = static_cast<const unsigned char*>(bls.startingAddress());
78 
79  // Reading BLOB part by part
80  int blob_ctr = 0;
81  // Nb of IDs encoded as 3 bytes
82  int nHash = static_cast<unsigned char>(blobData[blob_ctr]) << 16 |
83  static_cast<unsigned char>(blobData[blob_ctr+1]) << 8 |
84  static_cast<unsigned char>(blobData[blob_ctr+2]);
85  blob_ctr += 3;
86  hashIdToCluster.resize(nHash,-1);
87 
88  // Nb of clusters encoded as 2 bytes
89  int nCluster = static_cast<unsigned char>(blobData[blob_ctr]) << 8 |
90  static_cast<unsigned char>(blobData[blob_ctr+1]);
91  blob_ctr += 2;
92  clusterToOnnx.resize(nCluster,nullptr);
93 
94  // Reading clusters for each ID
95  for(int i=0; i<nHash;i++){
96  int cluster;
97  cluster = static_cast<unsigned char>(blobData[blob_ctr]) << 8 |
98  static_cast<unsigned char>(blobData[blob_ctr+1]);
99  blob_ctr += 2;
100  hashIdToCluster[i] = cluster;
101  }
102 
103  // Creating ONNX model instances
104  for(int i=0; i<nCluster; i++){
105  // Size of the instance written in the BLOB
106  int nnInstanceSize = static_cast<unsigned char>(blobData[blob_ctr]) << 16 |
107  static_cast<unsigned char>(blobData[blob_ctr+1]) << 8 |
108  static_cast<unsigned char>(blobData[blob_ctr+2]);
109  blob_ctr += 3;
110  std::vector<char> nnInstanceContent(blobData + blob_ctr, blobData + blob_ctr + nnInstanceSize);
111  blob_ctr += nnInstanceSize;
112  // One session per model
113  clusterToOnnx[i] = std::make_shared<Ort::Session>(m_onnxRuntimeSvc->env(), nnInstanceContent.data(), nnInstanceContent.size(), session_options);
114  }
115  //Get event inputs from read handles:
116  SG::ReadHandle<LArDigitContainer>inputContainer(m_digitKey, ctx);
117  ATH_CHECK(inputContainer.isValid());
118  //Write output via write handle
119  auto outputContainerLRPtr = std::make_unique<LArRawChannelContainer>();
120  //Get Conditions input
121  SG::ReadCondHandle<ILArPedestal>pedHdl(m_pedestalKey, ctx);
122  ATH_CHECK(pedHdl.isValid());
123  const ILArPedestal* peds = *pedHdl;
124  const LArADC2MeV* adc2MeVs{nullptr};
125  ATH_CHECK(SG::get(adc2MeVs, m_adc2MeVKey, ctx));
127  ATH_CHECK(cabling.isValid());
128 
129  // Same instance of input tensors are used
130  std::vector<Ort::Value> input_tensors;
131  // inputSamples variable memory is being used for the input tensors --> modify inputSamples to modify what's inside the inout tensors
132  std::vector<std::vector<float>> inputSamples(24, std::vector<float>(1, 0.0f));
133  // Same shapes should be provided for every neural networks
134  std::vector<std::vector<int64_t>> inputShape;
135  // Same input and output names sould be provided for every neural networks
136  std::vector<char*> input_names;
137  std::vector<const char*> output_names;
138  // Indices are sorted differently with the ORT, so it's needed to keep in memory to go faster than reading for each cell
139  std::vector<int> indicesOrder(24,-1);
140  // Boolean for the first iteration
141  int firstIter = 1;
142  //Loop over digits:
143  for (const LArDigit* digit : *inputContainer) {
144  const HWIdentifier id = digit->hardwareID();
145  Identifier idCell;
146  try {
147  idCell = (*cabling)->cnvToIdentifier(id);
148  } catch ( LArID_Exception & except ) {
149  ATH_MSG_DEBUG( "A Cabling exception was caught for channel 0x!"
150  << MSG::hex << id.get_compact() << MSG::dec );
151  continue ;
152  }
153  const IdentifierHash oflHash=m_calocellID->calo_cell_hash(idCell);
154  const bool connected = (*cabling)->isOnlineConnected(id);
155 
156  ATH_MSG_VERBOSE("Working on channel " << m_onlineId->channel_name(id));
157  const std::vector<short>& samples = digit->samples();
158  const int gain = digit->gain();
159  const float pedestal_value = peds->pedestal(id, gain);
160  const int clusterFromHash = hashIdToCluster[oflHash];
161  if (clusterFromHash<0){
162  ATH_MSG_ERROR("LArNNRawChannelBuilder::execute: clusterFromHash returned"<<clusterFromHash);
163  return StatusCode::FAILURE;
164 
165  }
166  unsigned nnNumInputs = clusterToOnnx[clusterFromHash]->GetInputCount();
167  unsigned nnNumOutputs = clusterToOnnx[clusterFromHash]->GetOutputCount();
168 
169  if(firstIter==1){
170  inputShape.resize(nnNumInputs);
171  indicesOrder.resize(nnNumInputs);
172  input_names.resize(nnNumInputs);
173  for(unsigned int i = 0; i < nnNumInputs; i++){
174  auto type_info = clusterToOnnx[clusterFromHash]->GetInputTypeInfo(i);
175  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
176  Ort::AllocatedStringPtr nnVariableNameStrPtr = clusterToOnnx[clusterFromHash]->GetInputNameAllocated(i, Ort::AllocatorWithDefaultOptions());
177  // Retrieving the ownership of the unique_ptr that was keeping the variable name
178  char * nnVariableName = nnVariableNameStrPtr.release();
179  input_names[i] = nnVariableName;
180  inputSamples.resize(static_cast<int>(nnNumInputs));
181  if(std::strlen(nnVariableName) <= 7){
182  ATH_MSG_ERROR("Input name must starts with \"sample_\", then \"m\" (< 0) or \"p\" (>= 0) and end with an index (example : sample_m2)");
183  return StatusCode::FAILURE;
184  }
185  if(!(std::strncmp(nnVariableName, "sample_", 7) == 0)){
186  ATH_MSG_ERROR("Input name must starts with \"sample_\", then \"m\" (< 0) or \"p\" (>= 0) and end with an index (example : sample_m2)");
187  return StatusCode::FAILURE;
188  }
189  char index_sign = nnVariableName[7];
190  int index = std::atoi(nnVariableName + 8);
191  if(index_sign == 'm'){
192  index*=-1;
193  }
194  else if(index_sign != 'p'){
195  ATH_MSG_ERROR("Wrong sign used, you have to use \"m\" (< 0) or \"p\" (>= 0)");
196  return StatusCode::FAILURE;
197  }
198  indicesOrder[i] = index;
199  for(auto el : tensor_info.GetShape()){
200  inputShape[i].push_back((int) abs(el));
201  }
202  input_tensors.push_back(Ort::Value::CreateTensor<float>(memory_info, inputSamples[i].data(), inputSamples[i].size(), inputShape[i].data(), inputShape[i].size()));
203  }
204  for(unsigned int i = 0; i < nnNumOutputs; i++){
205  auto outputName = clusterToOnnx[clusterFromHash]->GetOutputNameAllocated(i, Ort::AllocatorWithDefaultOptions());
206  output_names.push_back(outputName.release());
207  }
208  }
209 
210  firstIter = 0;
211  for(unsigned int i = 0; i < nnNumInputs; i++){
212  char index_sign = input_names[i][7];
213  int index = std::atoi(input_names[i] + 8);
214  if(index_sign == 'm'){
215  index*=-1;
216  }
217  else if(index_sign != 'p'){
218  ATH_MSG_ERROR("Wrong sign used, you have to use \"m\" (< 0) or \"p\" (>= 0)");
219  return StatusCode::FAILURE;
220  }
221  inputSamples[i][0] = ((samples[index+m_firstSample]-pedestal_value)/(4096.0-pedestal_value));
222  }
223  //The following autos will resolve either into vectors or vector-proxies
224  const auto& adc2mev = adc2MeVs->ADC2MEV(id, gain);
225 
226  if (ATH_UNLIKELY(pedestal_value == ILArPedestal::ERRORCODE)) {
227  if (!connected) continue; //No conditions for disconencted channel, who cares?
228  ATH_MSG_ERROR("No valid pedestal for connected channel " << m_onlineId->channel_name(id)
229  << " gain " << gain);
230  return StatusCode::FAILURE;
231  }
232 
233  if (ATH_UNLIKELY(adc2mev.size() < 2)) {
234  if (!connected) continue; //No conditions for disconencted channel, who cares?
235  ATH_MSG_ERROR("No valid ADC2MeV for connected channel " << m_onlineId->channel_name(id)
236  << " gain " << gain);
237  return StatusCode::FAILURE;
238  }
239 
240  // Compute amplitude
241  float An = 0;
242  float A = 0;
243  bool saturated = false;
244  // Check saturation AND discount pedestal on samples used by the NN
245  std::vector<float>samp_no_ped(nnNumInputs, 0.0);
246  for (unsigned int i = 0; i < nnNumInputs; i++) {
247  int index = indicesOrder[i]+m_firstSample;
248  if (samples[index] == 4096 || samples[index] == 0) saturated = true;
249  samp_no_ped[i] = samples[index]-pedestal_value;
250  }
251 
252  std::vector<Ort::Value> outputs;
253 
254  outputs = clusterToOnnx[clusterFromHash]->Run(Ort::RunOptions{nullptr}, input_names.data(), input_tensors.data(), input_tensors.size(), output_names.data(), output_names.size());
255 
256  //normalised output
257  An = outputs.front().GetTensorMutableData<float>()[0];
258 
259  //taking the normalisation into account
260  A = An*(4096.0-pedestal_value);
261 
262  //Apply Ramp
263  const float E = adc2mev[0]+A*adc2mev[1];
264 
265  uint16_t iquaShort = 0;
266  float tau = 0;
267 
268 
270  if (saturated) prov |= LArProv::SATURATED;
271 
272 
273  outputContainerLRPtr->emplace_back(id, static_cast<int>(std::floor(E+0.5)),
274  static_cast<int>(std::floor(tau+0.5)),
275  iquaShort, prov, (CaloGain::CaloGain)gain);
276 
277  }
278 
279  SG::WriteHandle<LArRawChannelContainer>outputContainer(m_rawChannelKey, ctx);
280 
281  for(auto el : input_names){
282  delete el;
283  }
284 
285  for(auto el : output_names){
286  delete el;
287  }
288  ATH_CHECK(outputContainer.record(std::move(outputContainerLRPtr) ) );
289 
290  return StatusCode::SUCCESS;
291 }
python.PyKernel.retrieve
def retrieve(aClass, aKey=None)
Definition: PyKernel.py:110
ILArPedestal::pedestal
virtual float pedestal(const HWIdentifier &id, int gain) const =0
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
LArNNRawChannelBuilder::initialize
StatusCode initialize() override
Definition: LArNNRawChannelBuilder.cxx:32
SG::ReadCondHandle
Definition: ReadCondHandle.h:44
CoraCoolDatabaseSvc.h
SG::ReadHandle
Definition: StoreGate/StoreGate/ReadHandle.h:67
index
Definition: index.py:1
LArRawChannelBuilderAlg.h
ILArPedestal
Definition: ILArPedestal.h:12
LArNNRawChannelBuilder.h
CaloCondBlobAlgs_fillNoiseFromASCII.gain
gain
Definition: CaloCondBlobAlgs_fillNoiseFromASCII.py:109
ReadCellNoiseFromCool.cabling
cabling
Definition: ReadCellNoiseFromCool.py:154
SG::ReadCondHandle::isValid
bool isValid()
Definition: ReadCondHandle.h:210
python.subdetectors.tile.Blob
Blob
Definition: tile.py:17
ATH_UNLIKELY
#define ATH_UNLIKELY(x)
Definition: AthUnlikelyMacros.h:17
ATH_MSG_VERBOSE
#define ATH_MSG_VERBOSE(x)
Definition: AthMsgStreamMacros.h:28
checkRpcDigits.digit
digit
Definition: checkRpcDigits.py:186
CondAttrListCollection
This class is a collection of AttributeLists where each one is associated with a channel number....
Definition: CondAttrListCollection.h:52
cool
Definition: CoolTagInfo.h:12
CaloCell_ID.h
python.setupRTTAlg.size
int size
Definition: setupRTTAlg.py:39
xAOD::saturated
setScaleOne setStatusOne saturated
Definition: gFexGlobalRoI_v1.cxx:51
A
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
xAOD::uint16_t
setWord1 uint16_t
Definition: eFexEMRoI_v1.cxx:93
LArDigit
Liquid Argon digit base class.
Definition: LArDigit.h:25
lumiFormat.i
int i
Definition: lumiFormat.py:85
SG::get
const T * get(const ReadCondHandleKey< T > &key, const EventContext &ctx)
Convenience function to retrieve an object given a ReadCondHandleKey.
Definition: ReadCondHandle.h:287
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
plotIsoValidation.el
el
Definition: plotIsoValidation.py:197
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
hist_file_dump.f
f
Definition: hist_file_dump.py:140
SG::ReadHandle::isValid
virtual bool isValid() override final
Can the handle be successfully dereferenced?
python.CreateTierZeroArgdict.outputs
outputs
Definition: CreateTierZeroArgdict.py:189
LArDSPThresholdsFlat.h
python.PyKernel.detStore
detStore
Definition: PyKernel.py:41
ILArPedestal::ERRORCODE
@ ERRORCODE
Definition: ILArPedestal.h:47
LArProv::PEAKNN
@ PEAKNN
Definition: LArProvenance.h:20
PathResolver.h
LArProv::PEDDB
@ PEDDB
Definition: LArProvenance.h:26
VP1PartSpect::E
@ E
Definition: VP1PartSpectFlags.h:21
LArNNRawChannelBuilder::execute
StatusCode execute(const EventContext &ctx) const override
Definition: LArNNRawChannelBuilder.cxx:59
LArDigitContainer.h
LArProv::RAMPDB
@ RAMPDB
Definition: LArProvenance.h:23
lumiFormat.outputName
string outputName
Definition: lumiFormat.py:65
CaloGain::CaloGain
CaloGain
Definition: CaloGain.h:11
LArADC2MeV
Definition: LArADC2MeV.h:21
SG::WriteHandle
Definition: StoreGate/StoreGate/WriteHandle.h:73
DeMoScan.index
string index
Definition: DeMoScan.py:362
SG::WriteHandle::record
StatusCode record(std::unique_ptr< T > data)
Record a const object to the store.
CondAttrListCollection::const_iterator
ChanAttrListMap::const_iterator const_iterator
Definition: CondAttrListCollection.h:63
LArProvenance.h
CoraCoolDatabase.h
CxxUtils::atoi
int atoi(std::string_view str)
Helper functions to unpack numbers decoded in string into integers and doubles The strings are requir...
Definition: Control/CxxUtils/Root/StringUtils.cxx:85
CoraCoolDatabaseSvcFactory.h
SG::AllowEmpty
@ AllowEmpty
Definition: StoreGate/StoreGate/VarHandleKey.h:30
LArID_Exception
Exception class for LAr Identifiers.
Definition: LArID_Exception.h:20
LArRawChannelContainer.h
LArOnlineID.h
LArProv::SATURATED
@ SATURATED
Definition: LArProvenance.h:31
Identifier
Definition: IdentifierFieldParser.cxx:14