59 {
60 Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
61 Ort::SessionOptions session_options;
62 session_options.SetIntraOpNumThreads(1);
63
64 std::vector<int> hashIdToCluster;
65 std::vector<std::shared_ptr<Ort::Session>> clusterToOnnx;
66
67 const CondAttrListCollection *catr{nullptr};
69
70 if(!catr){
72 return StatusCode::FAILURE;
73 }
74
76 const coral::Blob& bls = chanIt->second["clusters"].data<coral::Blob>();
77 const unsigned char* blobData = static_cast<const unsigned char*>(bls.startingAddress());
78
79
80 int blob_ctr = 0;
81
82 int nHash = static_cast<unsigned char>(blobData[blob_ctr]) << 16 |
83 static_cast<unsigned char>(blobData[blob_ctr+1]) << 8 |
84 static_cast<unsigned char>(blobData[blob_ctr+2]);
85 blob_ctr += 3;
86 hashIdToCluster.resize(nHash,-1);
87
88
89 int nCluster = static_cast<unsigned char>(blobData[blob_ctr]) << 8 |
90 static_cast<unsigned char>(blobData[blob_ctr+1]);
91 blob_ctr += 2;
92 clusterToOnnx.resize(nCluster,nullptr);
93
94
95 for(
int i=0;
i<nHash;
i++){
96 int cluster;
97 cluster = static_cast<unsigned char>(blobData[blob_ctr]) << 8 |
98 static_cast<unsigned char>(blobData[blob_ctr+1]);
99 blob_ctr += 2;
100 hashIdToCluster[
i] = cluster;
101 }
102
103
104 for(
int i=0;
i<nCluster;
i++){
105
106 int nnInstanceSize = static_cast<unsigned char>(blobData[blob_ctr]) << 16 |
107 static_cast<unsigned char>(blobData[blob_ctr+1]) << 8 |
108 static_cast<unsigned char>(blobData[blob_ctr+2]);
109 blob_ctr += 3;
110 std::vector<char> nnInstanceContent(blobData + blob_ctr, blobData + blob_ctr + nnInstanceSize);
111 blob_ctr += nnInstanceSize;
112
113 clusterToOnnx[
i] = std::make_shared<Ort::Session>(
m_onnxRuntimeSvc->env(), nnInstanceContent.data(), nnInstanceContent.size(), session_options);
114 }
115
116 SG::ReadHandle<LArDigitContainer>inputContainer(
m_digitKey, ctx);
118
119 auto outputContainerLRPtr = std::make_unique<LArRawChannelContainer>();
120
123 const ILArPedestal* peds = *pedHdl;
124 const LArADC2MeV* adc2MeVs{nullptr};
128
129
130 std::vector<Ort::Value> input_tensors;
131
132 std::vector<std::vector<float>> inputSamples(24, std::vector<float>(1, 0.0f));
133
134 std::vector<std::vector<int64_t>> inputShape;
135
136 std::vector<char*> input_names;
137 std::vector<const char*> output_names;
138
139 std::vector<int> indicesOrder(24,-1);
140
141 int firstIter = 1;
142
143 for (const LArDigit* digit : *inputContainer) {
144 const HWIdentifier
id =
digit->hardwareID();
145 Identifier idCell;
146 try {
147 idCell = (*cabling)->cnvToIdentifier(id);
148 } catch ( LArID_Exception & except ) {
149 ATH_MSG_DEBUG(
"A Cabling exception was caught for channel 0x!"
150 << MSG::hex << id.get_compact() << MSG::dec );
151 continue ;
152 }
153 const IdentifierHash oflHash=
m_calocellID->calo_cell_hash(idCell);
154 const bool connected = (*cabling)->isOnlineConnected(id);
155
157 const std::vector<short>& samples =
digit->samples();
159 const float pedestal_value = peds->
pedestal(
id, gain);
160 const int clusterFromHash = hashIdToCluster[oflHash];
161 if (clusterFromHash<0){
162 ATH_MSG_ERROR(
"LArNNRawChannelBuilder::execute: clusterFromHash returned"<<clusterFromHash);
163 return StatusCode::FAILURE;
164
165 }
166 unsigned nnNumInputs = clusterToOnnx[clusterFromHash]->GetInputCount();
167 unsigned nnNumOutputs = clusterToOnnx[clusterFromHash]->GetOutputCount();
168
169 if(firstIter==1){
170 inputShape.resize(nnNumInputs);
171 indicesOrder.resize(nnNumInputs);
172 input_names.resize(nnNumInputs);
173 for(
unsigned int i = 0;
i < nnNumInputs;
i++){
174 auto type_info = clusterToOnnx[clusterFromHash]->GetInputTypeInfo(i);
175 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
176 Ort::AllocatedStringPtr nnVariableNameStrPtr = clusterToOnnx[clusterFromHash]->GetInputNameAllocated(i, Ort::AllocatorWithDefaultOptions());
177
178 char * nnVariableName = nnVariableNameStrPtr.release();
179 input_names[
i] = nnVariableName;
180 inputSamples.resize(static_cast<int>(nnNumInputs));
181 if(std::strlen(nnVariableName) <= 7){
182 ATH_MSG_ERROR(
"Input name must starts with \"sample_\", then \"m\" (< 0) or \"p\" (>= 0) and end with an index (example : sample_m2)");
183 return StatusCode::FAILURE;
184 }
185 if(!(std::strncmp(nnVariableName, "sample_", 7) == 0)){
186 ATH_MSG_ERROR(
"Input name must starts with \"sample_\", then \"m\" (< 0) or \"p\" (>= 0) and end with an index (example : sample_m2)");
187 return StatusCode::FAILURE;
188 }
189 char index_sign = nnVariableName[7];
190 int index = std::atoi(nnVariableName + 8);
191 if(index_sign == 'm'){
193 }
194 else if(index_sign != 'p'){
195 ATH_MSG_ERROR(
"Wrong sign used, you have to use \"m\" (< 0) or \"p\" (>= 0)");
196 return StatusCode::FAILURE;
197 }
199 for(auto el : tensor_info.GetShape()){
200 inputShape[
i].push_back((
int) abs(el));
201 }
202 input_tensors.push_back(Ort::Value::CreateTensor<float>(memory_info, inputSamples[i].
data(), inputSamples[i].size(), inputShape[i].
data(), inputShape[i].size()));
203 }
204 for(
unsigned int i = 0;
i < nnNumOutputs;
i++){
205 auto outputName = clusterToOnnx[clusterFromHash]->GetOutputNameAllocated(i, Ort::AllocatorWithDefaultOptions());
207 }
208 }
209
210 firstIter = 0;
211 for(
unsigned int i = 0;
i < nnNumInputs;
i++){
212 char index_sign = input_names[
i][7];
213 int index = std::atoi(input_names[i] + 8);
214 if(index_sign == 'm'){
216 }
217 else if(index_sign != 'p'){
218 ATH_MSG_ERROR(
"Wrong sign used, you have to use \"m\" (< 0) or \"p\" (>= 0)");
219 return StatusCode::FAILURE;
220 }
221 inputSamples[
i][0] = ((samples[
index+
m_firstSample]-pedestal_value)/(4096.0-pedestal_value));
222 }
223
224 const auto& adc2mev = adc2MeVs->
ADC2MEV(
id, gain);
225
227 if (!connected) continue;
229 << " gain " << gain);
230 return StatusCode::FAILURE;
231 }
232
234 if (!connected) continue;
236 << " gain " << gain);
237 return StatusCode::FAILURE;
238 }
239
240
241 float An = 0;
244
245 std::vector<float>samp_no_ped(nnNumInputs, 0.0);
246 for (
unsigned int i = 0;
i < nnNumInputs;
i++) {
248 if (samples[index] == 4096 || samples[index] == 0)
saturated =
true;
249 samp_no_ped[
i] = samples[
index]-pedestal_value;
250 }
251
252 std::vector<Ort::Value>
outputs;
253
254 outputs = clusterToOnnx[clusterFromHash]->Run(Ort::RunOptions{
nullptr}, input_names.data(), input_tensors.data(), input_tensors.size(), output_names.data(), output_names.size());
255
256
257 An =
outputs.front().GetTensorMutableData<
float>()[0];
258
259
260 A = An*(4096.0-pedestal_value);
261
262
263 const float E = adc2mev[0]+
A*adc2mev[1];
264
266 float tau = 0;
267
268
271
272
273 outputContainerLRPtr->emplace_back(id, static_cast<int>(std::floor(E+0.5)),
274 static_cast<int>(std::floor(tau+0.5)),
276
277 }
278
279 SG::WriteHandle<LArRawChannelContainer>outputContainer(
m_rawChannelKey, ctx);
280
281 for(auto el : input_names){
283 }
284
285 for(auto el : output_names){
287 }
288 ATH_CHECK(outputContainer.record(std::move(outputContainerLRPtr) ) );
289
290 return StatusCode::SUCCESS;
291}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_VERBOSE(x)
char data[hepevt_bytes_allocation_ATLAS]
const_iterator begin() const
Access to Chan/AttributeList pairs via iterators.
ChanAttrListMap::const_iterator const_iterator
virtual float pedestal(const HWIdentifier &id, int gain) const =0
const LArVectorProxy ADC2MEV(const HWIdentifier &id, int gain) const
SG::ReadCondHandleKey< CondAttrListCollection > m_nnClustersDb
SG::WriteHandleKey< LArRawChannelContainer > m_rawChannelKey
SG::ReadCondHandleKey< ILArPedestal > m_pedestalKey
SG::ReadCondHandleKey< LArADC2MeV > m_adc2MeVKey
const CaloCell_ID * m_calocellID
SG::ReadHandleKey< LArDigitContainer > m_digitKey
SG::ReadCondHandleKey< LArOnOffIdMapping > m_cablingKey
Gaudi::Property< int > m_firstSample
const LArOnlineID * m_onlineId
ServiceHandle< AthOnnx::IOnnxRuntimeSvc > m_onnxRuntimeSvc
const T * get(const ReadCondHandleKey< T > &key, const EventContext &ctx)
Convenience function to retrieve an object given a ReadCondHandleKey.
setScaleOne setStatusOne saturated