ATLAS Offline Software
Loading...
Searching...
No Matches
LArNNRawChannelBuilder.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2024-2025 CERN for the benefit of the ATLAS collaboration
3 */
4
7
8#include "GaudiKernel/SystemOfUnits.h"
11
17
18#include <onnxruntime_cxx_api.h>
22//
23#include <cmath>
24#include <map>
25#include <vector>
26#include <fstream>
27#include <sstream>
28#include <typeinfo>
29
30using namespace cool;
31
// Body of LArNNRawChannelBuilder::initialize().
// NOTE(review): this is a doxygen extract — the function signature line
// (original line 32) is not visible here; the member list elsewhere in this
// dump shows it as "StatusCode initialize() override".
// Initializes every StoreGate read/write/conditions handle key, validates the
// DSP-threshold configuration, retrieves the LAr online and calo-cell ID
// helpers from the detector store, and acquires the ONNX runtime service.
33 ATH_CHECK(m_digitKey.initialize());
34 ATH_CHECK(m_rawChannelKey.initialize());
35 ATH_CHECK(m_pedestalKey.initialize());
36 ATH_CHECK(m_adc2MeVKey.initialize());
37 ATH_CHECK(m_cablingKey.initialize() );
38 ATH_CHECK(m_ofcKey.initialize());
39 ATH_CHECK(m_shapeKey.initialize());
// When time/quality come from the DB, at least one of the Run1/Run2 DSP
// threshold keys must be configured; fail fast at initialization otherwise.
42 if (m_useDBFortQ) {
43 if (m_run1DSPThresholdsKey.empty() && m_run2DSPThresholdsKey.empty()) {
44 ATH_MSG_ERROR ("useDB requested but neither Run1DSPThresholdsKey nor Run2DSPThresholdsKey initialized.");
45 return StatusCode::FAILURE;
46 }
47 }
// Conditions folder holding the NN cluster-assignment BLOB read in execute().
48 ATH_CHECK(m_nnClustersDb.initialize());
49
// ID helpers used per digit in execute() for channel naming and cell hashing.
50 ATH_CHECK(detStore()->retrieve(m_onlineId,"LArOnlineID"));
51 ATH_CHECK(detStore()->retrieve(m_calocellID,"CaloCell_ID"));
52
// ONNX runtime service provides the Ort::Env used to build sessions later.
53 ATH_CHECK(m_onnxRuntimeSvc.retrieve());
54
55 return StatusCode::SUCCESS;
56}
57
58
// Per-event execution: decode the cluster-assignment BLOB from the conditions
// DB, instantiate one ONNX session per NN cluster, then for each LAr digit run
// the matching network on pedestal-subtracted, normalised samples to estimate
// the amplitude and fill the output LArRawChannelContainer.
// NOTE(review): this is a doxygen extract; original source lines 75, 116, 121,
// 126, 269 and 279 are absent, so the declarations of chanIt, inputContainer,
// pedHdl, cabling, prov and outputContainer are not visible in this view.
59StatusCode LArNNRawChannelBuilder::execute(const EventContext& ctx) const {
60 Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
61 Ort::SessionOptions session_options;
62 session_options.SetIntraOpNumThreads(1);
63
// hashIdToCluster: offline calo-cell hash -> NN cluster index (-1 = unassigned).
// clusterToOnnx: one ONNX session per cluster, deserialized from the BLOB.
64 std::vector<int> hashIdToCluster;
65 std::vector<std::shared_ptr<Ort::Session>> clusterToOnnx;
66
67 const CondAttrListCollection *catr{nullptr};
68 ATH_CHECK(SG::get(catr, m_nnClustersDb, ctx));
69
70 if(!catr){
71 ATH_MSG_ERROR("CondAttrListCollection can't be opened");
72 return StatusCode::FAILURE;
73 }
74
// NOTE(review): the iterator 'chanIt' is declared on a line not shown in this
// extract — presumably iterating catr's channel/attribute-list pairs; confirm
// against the full source.
76 const coral::Blob& bls = chanIt->second["clusters"].data<coral::Blob>();
77 const unsigned char* blobData = static_cast<const unsigned char*>(bls.startingAddress());
78
79 // Reading BLOB part by part
80 int blob_ctr = 0;
81 // Nb of IDs encoded as 3 bytes (big-endian)
82 int nHash = static_cast<unsigned char>(blobData[blob_ctr]) << 16 |
83 static_cast<unsigned char>(blobData[blob_ctr+1]) << 8 |
84 static_cast<unsigned char>(blobData[blob_ctr+2]);
85 blob_ctr += 3;
86 hashIdToCluster.resize(nHash,-1);
87
88 // Nb of clusters encoded as 2 bytes (big-endian)
89 int nCluster = static_cast<unsigned char>(blobData[blob_ctr]) << 8 |
90 static_cast<unsigned char>(blobData[blob_ctr+1]);
91 blob_ctr += 2;
92 clusterToOnnx.resize(nCluster,nullptr);
93
94 // Reading clusters for each ID (2 bytes per hash)
95 for(int i=0; i<nHash;i++){
96 int cluster;
97 cluster = static_cast<unsigned char>(blobData[blob_ctr]) << 8 |
98 static_cast<unsigned char>(blobData[blob_ctr+1]);
99 blob_ctr += 2;
100 hashIdToCluster[i] = cluster;
101 }
102
103 // Creating ONNX model instances
104 for(int i=0; i<nCluster; i++){
105 // Size of the instance written in the BLOB (3 bytes), followed by the
// serialized model bytes themselves.
106 int nnInstanceSize = static_cast<unsigned char>(blobData[blob_ctr]) << 16 |
107 static_cast<unsigned char>(blobData[blob_ctr+1]) << 8 |
108 static_cast<unsigned char>(blobData[blob_ctr+2]);
109 blob_ctr += 3;
110 std::vector<char> nnInstanceContent(blobData + blob_ctr, blobData + blob_ctr + nnInstanceSize);
111 blob_ctr += nnInstanceSize;
112 // One session per model, loaded from the in-memory model bytes
113 clusterToOnnx[i] = std::make_shared<Ort::Session>(m_onnxRuntimeSvc->env(), nnInstanceContent.data(), nnInstanceContent.size(), session_options);
114 }
115 //Get event inputs from read handles:
// NOTE(review): line 116 (the ReadHandle 'inputContainer' construction from
// m_digitKey) is missing from this extract.
117 ATH_CHECK(inputContainer.isValid());
118 //Write output via write handle
119 auto outputContainerLRPtr = std::make_unique<LArRawChannelContainer>();
120 //Get Conditions input
// NOTE(review): line 121 (the conditions handle 'pedHdl' from m_pedestalKey)
// is missing from this extract.
122 ATH_CHECK(pedHdl.isValid());
123 const ILArPedestal* peds = *pedHdl;
124 const LArADC2MeV* adc2MeVs{nullptr};
125 ATH_CHECK(SG::get(adc2MeVs, m_adc2MeVKey, ctx));
// NOTE(review): line 126 (the conditions handle 'cabling' from m_cablingKey)
// is missing from this extract.
127 ATH_CHECK(cabling.isValid());
128
129 // Same instance of input tensors are used
130 std::vector<Ort::Value> input_tensors;
131 // inputSamples variable memory is being used for the input tensors --> modify inputSamples to modify what's inside the input tensors
132 std::vector<std::vector<float>> inputSamples(24, std::vector<float>(1, 0.0f));
133 // Same shapes should be provided for every neural network
134 std::vector<std::vector<int64_t>> inputShape;
135 // Same input and output names should be provided for every neural network
136 std::vector<char*> input_names;
137 std::vector<const char*> output_names;
138 // Indices are sorted differently with the ORT, so it's needed to keep in memory to go faster than reading for each cell
139 std::vector<int> indicesOrder(24,-1);
140 // Boolean for the first iteration (one-time ORT metadata setup below)
141 int firstIter = 1;
142 //Loop over digits:
143 for (const LArDigit* digit : *inputContainer) {
144 const HWIdentifier id = digit->hardwareID();
145 Identifier idCell;
146 try {
147 idCell = (*cabling)->cnvToIdentifier(id);
148 } catch ( LArID_Exception & except ) {
149 ATH_MSG_DEBUG( "A Cabling exception was caught for channel 0x!"
150 << MSG::hex << id.get_compact() << MSG::dec );
151 continue ;
152 }
153 const IdentifierHash oflHash=m_calocellID->calo_cell_hash(idCell);
154 const bool connected = (*cabling)->isOnlineConnected(id);
155
156 ATH_MSG_VERBOSE("Working on channel " << m_onlineId->channel_name(id));
157 const std::vector<short>& samples = digit->samples();
158 const int gain = digit->gain();
159 const float pedestal_value = peds->pedestal(id, gain);
// NOTE(review): no bounds check that oflHash < hashIdToCluster.size() — TODO
// confirm the BLOB always covers the full cell-hash range.
160 const int clusterFromHash = hashIdToCluster[oflHash];
161 if (clusterFromHash<0){
162 ATH_MSG_ERROR("LArNNRawChannelBuilder::execute: clusterFromHash returned"<<clusterFromHash);
163 return StatusCode::FAILURE;
164
165 }
166 unsigned nnNumInputs = clusterToOnnx[clusterFromHash]->GetInputCount();
167 unsigned nnNumOutputs = clusterToOnnx[clusterFromHash]->GetOutputCount();
168
// One-time setup on the first processed digit: cache input/output names,
// tensor shapes and sample-offset indices. Assumes every cluster's network
// shares the same input layout (see comments on lines 133/135 above).
169 if(firstIter==1){
170 inputShape.resize(nnNumInputs);
171 indicesOrder.resize(nnNumInputs);
172 input_names.resize(nnNumInputs);
173 for(unsigned int i = 0; i < nnNumInputs; i++){
174 auto type_info = clusterToOnnx[clusterFromHash]->GetInputTypeInfo(i);
175 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
176 Ort::AllocatedStringPtr nnVariableNameStrPtr = clusterToOnnx[clusterFromHash]->GetInputNameAllocated(i, Ort::AllocatorWithDefaultOptions());
177 // Retrieving the ownership of the unique_ptr that was keeping the variable name
// NOTE(review): the released pointer is later freed with raw 'delete' (end of
// this function). Memory handed out by Ort::AllocatorWithDefaultOptions should
// instead be returned via the allocator's Free() (or simply kept inside the
// AllocatedStringPtr) — confirm against the ONNX Runtime C++ API before fixing.
178 char * nnVariableName = nnVariableNameStrPtr.release();
179 input_names[i] = nnVariableName;
180 inputSamples.resize(static_cast<int>(nnNumInputs));
// Input names encode a signed sample offset: "sample_mN" -> -N, "sample_pN" -> +N.
181 if(std::strlen(nnVariableName) <= 7){
182 ATH_MSG_ERROR("Input name must starts with \"sample_\", then \"m\" (< 0) or \"p\" (>= 0) and end with an index (example : sample_m2)");
183 return StatusCode::FAILURE;
184 }
185 if(!(std::strncmp(nnVariableName, "sample_", 7) == 0)){
186 ATH_MSG_ERROR("Input name must starts with \"sample_\", then \"m\" (< 0) or \"p\" (>= 0) and end with an index (example : sample_m2)");
187 return StatusCode::FAILURE;
188 }
189 char index_sign = nnVariableName[7];
190 int index = std::atoi(nnVariableName + 8);
191 if(index_sign == 'm'){
192 index*=-1;
193 }
194 else if(index_sign != 'p'){
195 ATH_MSG_ERROR("Wrong sign used, you have to use \"m\" (< 0) or \"p\" (>= 0)");
196 return StatusCode::FAILURE;
197 }
198 indicesOrder[i] = index;
// NOTE(review): abs() maps dynamic (-1) dimensions to 1 — presumably to pin a
// batch size of one; verify this is intended for all deployed models.
199 for(auto el : tensor_info.GetShape()){
200 inputShape[i].push_back((int) abs(el));
201 }
// The tensor aliases inputSamples[i]'s buffer: writing inputSamples per digit
// updates the data fed to Run() without recreating tensors.
202 input_tensors.push_back(Ort::Value::CreateTensor<float>(memory_info, inputSamples[i].data(), inputSamples[i].size(), inputShape[i].data(), inputShape[i].size()));
203 }
204 for(unsigned int i = 0; i < nnNumOutputs; i++){
205 auto outputName = clusterToOnnx[clusterFromHash]->GetOutputNameAllocated(i, Ort::AllocatorWithDefaultOptions());
206 output_names.push_back(outputName.release());
207 }
208 }
209
210 firstIter = 0;
// Per digit: re-parse each cached input name's signed offset and fill the
// aliased tensor buffers with pedestal-subtracted samples normalised by
// (4096 - pedestal).
211 for(unsigned int i = 0; i < nnNumInputs; i++){
212 char index_sign = input_names[i][7];
213 int index = std::atoi(input_names[i] + 8);
214 if(index_sign == 'm'){
215 index*=-1;
216 }
217 else if(index_sign != 'p'){
218 ATH_MSG_ERROR("Wrong sign used, you have to use \"m\" (< 0) or \"p\" (>= 0)");
219 return StatusCode::FAILURE;
220 }
// NOTE(review): samples[index+m_firstSample] is not range-checked — TODO
// confirm upstream guarantees the digit holds enough samples for all offsets.
221 inputSamples[i][0] = ((samples[index+m_firstSample]-pedestal_value)/(4096.0-pedestal_value));
222 }
223 //The following autos will resolve either into vectors or vector-proxies
224 const auto& adc2mev = adc2MeVs->ADC2MEV(id, gain);
225
226 if (ATH_UNLIKELY(pedestal_value == ILArPedestal::ERRORCODE)) {
227 if (!connected) continue; //No conditions for disconnected channel, who cares?
228 ATH_MSG_ERROR("No valid pedestal for connected channel " << m_onlineId->channel_name(id)
229 << " gain " << gain);
230 return StatusCode::FAILURE;
231 }
232
233 if (ATH_UNLIKELY(adc2mev.size() < 2)) {
234 if (!connected) continue; //No conditions for disconnected channel, who cares?
235 ATH_MSG_ERROR("No valid ADC2MeV for connected channel " << m_onlineId->channel_name(id)
236 << " gain " << gain);
237 return StatusCode::FAILURE;
238 }
239
240 // Compute amplitude
241 float An = 0;
242 float A = 0;
243 bool saturated = false;
244 // Check saturation AND discount pedestal on samples used by the NN
// NOTE(review): samp_no_ped is filled but not read afterwards in the visible
// code — possibly dead, or used on an elided line; confirm before removing.
245 std::vector<float>samp_no_ped(nnNumInputs, 0.0);
246 for (unsigned int i = 0; i < nnNumInputs; i++) {
247 int index = indicesOrder[i]+m_firstSample;
248 if (samples[index] == 4096 || samples[index] == 0) saturated = true;
249 samp_no_ped[i] = samples[index]-pedestal_value;
250 }
251
252 std::vector<Ort::Value> outputs;
253
// Inference: input_tensors alias inputSamples, which was refreshed above.
254 outputs = clusterToOnnx[clusterFromHash]->Run(Ort::RunOptions{nullptr}, input_names.data(), input_tensors.data(), input_tensors.size(), output_names.data(), output_names.size());
255
256 //normalised output
257 An = outputs.front().GetTensorMutableData<float>()[0];
258
259 //taking the normalisation into account
260 A = An*(4096.0-pedestal_value);
261
262 //Apply Ramp
263 const float E = adc2mev[0]+A*adc2mev[1];
264
// Time and quality are not estimated by the NN: stored as zeros.
265 uint16_t iquaShort = 0;
266 float tau = 0;
267
268
// NOTE(review): line 269 (declaration/initialisation of 'prov', the LArProv
// provenance word) is missing from this extract.
270 if (saturated) prov |= LArProv::SATURATED;
271
272
// Energy and time are rounded to the nearest integer for LArRawChannel.
273 outputContainerLRPtr->emplace_back(id, static_cast<int>(std::floor(E+0.5)),
274 static_cast<int>(std::floor(tau+0.5)),
275 iquaShort, prov, (CaloGain::CaloGain)gain);
276
277 }
278
// NOTE(review): line 279 (the WriteHandle 'outputContainer' construction from
// m_rawChannelKey) is missing from this extract.
280
// Free the name strings released from the ORT-allocated unique_ptrs.
// NOTE(review): raw 'delete' mismatches the ORT default allocator's allocation
// (and would be 'delete[]' even for new[]-allocated strings) — prefer
// Ort::AllocatorWithDefaultOptions().Free(el); confirm against the ORT API.
281 for(auto el : input_names){
282 delete el;
283 }
284
285 for(auto el : output_names){
286 delete el;
287 }
288 ATH_CHECK(outputContainer.record(std::move(outputContainerLRPtr) ) );
289
290 return StatusCode::SUCCESS;
291}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_VERBOSE(x)
#define ATH_MSG_DEBUG(x)
#define ATH_UNLIKELY(x)
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
const ServiceHandle< StoreGateSvc > & detStore() const
This class is a collection of AttributeLists where each one is associated with a channel number.
const_iterator begin() const
Access to Chan/AttributeList pairs via iterators.
ChanAttrListMap::const_iterator const_iterator
virtual float pedestal(const HWIdentifier &id, int gain) const =0
This is a "hash" representation of an Identifier.
const LArVectorProxy ADC2MEV(const HWIdentifier &id, int gain) const
Definition LArADC2MeV.h:32
Liquid Argon digit base class.
Definition LArDigit.h:25
Exception class for LAr Identifiers.
SG::ReadCondHandleKey< CondAttrListCollection > m_nnClustersDb
SG::ReadCondHandleKey< LArDSPThresholdsComplete > m_run1DSPThresholdsKey
SG::WriteHandleKey< LArRawChannelContainer > m_rawChannelKey
SG::ReadCondHandleKey< AthenaAttributeList > m_run2DSPThresholdsKey
StatusCode execute(const EventContext &ctx) const override
Gaudi::Property< bool > m_useDBFortQ
SG::ReadCondHandleKey< ILArPedestal > m_pedestalKey
SG::ReadCondHandleKey< LArADC2MeV > m_adc2MeVKey
SG::ReadCondHandleKey< ILArOFC > m_ofcKey
const CaloCell_ID * m_calocellID
SG::ReadHandleKey< LArDigitContainer > m_digitKey
SG::ReadCondHandleKey< LArOnOffIdMapping > m_cablingKey
StatusCode initialize() override
Gaudi::Property< int > m_firstSample
SG::ReadCondHandleKey< ILArShape > m_shapeKey
const LArOnlineID * m_onlineId
ServiceHandle< AthOnnx::IOnnxRuntimeSvc > m_onnxRuntimeSvc
virtual bool isValid() override final
Can the handle be successfully dereferenced?
StatusCode record(std::unique_ptr< T > data)
Record a const object to the store.
const T * get(const ReadCondHandleKey< T > &key, const EventContext &ctx)
Convenience function to retrieve an object given a ReadCondHandleKey.
Definition index.py:1
hold the test vectors and ease the comparison