8 #include "GaudiKernel/SystemOfUnits.h"
18 #include <onnxruntime_cxx_api.h>
43 if (m_run1DSPThresholdsKey.empty() && m_run2DSPThresholdsKey.empty()) {
44 ATH_MSG_ERROR (
"useDB requested but neither Run1DSPThresholdsKey nor Run2DSPThresholdsKey initialized.");
45 return StatusCode::FAILURE;
55 return StatusCode::SUCCESS;
60 Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
61 Ort::SessionOptions session_options;
62 session_options.SetIntraOpNumThreads(1);
64 std::vector<int> hashIdToCluster;
65 std::vector<std::shared_ptr<Ort::Session>> clusterToOnnx;
72 return StatusCode::FAILURE;
77 const unsigned char* blobData =
static_cast<const unsigned char*
>(bls.startingAddress());
82 int nHash =
static_cast<unsigned char>(blobData[blob_ctr]) << 16 |
83 static_cast<unsigned char>(blobData[blob_ctr+1]) << 8 |
84 static_cast<unsigned char>(blobData[blob_ctr+2]);
86 hashIdToCluster.resize(nHash,-1);
89 int nCluster =
static_cast<unsigned char>(blobData[blob_ctr]) << 8 |
90 static_cast<unsigned char>(blobData[blob_ctr+1]);
92 clusterToOnnx.resize(nCluster,
nullptr);
95 for(
int i=0;
i<nHash;
i++){
97 cluster =
static_cast<unsigned char>(blobData[blob_ctr]) << 8 |
98 static_cast<unsigned char>(blobData[blob_ctr+1]);
100 hashIdToCluster[
i] = cluster;
104 for(
int i=0;
i<nCluster;
i++){
106 int nnInstanceSize =
static_cast<unsigned char>(blobData[blob_ctr]) << 16 |
107 static_cast<unsigned char>(blobData[blob_ctr+1]) << 8 |
108 static_cast<unsigned char>(blobData[blob_ctr+2]);
110 std::vector<char> nnInstanceContent(blobData + blob_ctr, blobData + blob_ctr + nnInstanceSize);
111 blob_ctr += nnInstanceSize;
113 clusterToOnnx[
i] = std::make_shared<Ort::Session>(m_onnxRuntimeSvc->env(), nnInstanceContent.data(), nnInstanceContent.size(), session_options);
119 auto outputContainerLRPtr = std::make_unique<LArRawChannelContainer>();
130 std::vector<Ort::Value> input_tensors;
132 std::vector<std::vector<float>> inputSamples(24, std::vector<float>(1, 0.0
f));
134 std::vector<std::vector<int64_t>> inputShape;
136 std::vector<char*> input_names;
137 std::vector<const char*> output_names;
139 std::vector<int> indicesOrder(24,-1);
144 const HWIdentifier
id =
digit->hardwareID();
147 idCell = (*cabling)->cnvToIdentifier(
id);
149 ATH_MSG_DEBUG(
"A Cabling exception was caught for channel 0x!"
150 << MSG::hex <<
id.get_compact() << MSG::dec );
153 const IdentifierHash oflHash=m_calocellID->calo_cell_hash(idCell);
154 const bool connected = (*cabling)->isOnlineConnected(
id);
156 ATH_MSG_VERBOSE(
"Working on channel " << m_onlineId->channel_name(
id));
157 const std::vector<short>& samples =
digit->samples();
160 const int clusterFromHash = hashIdToCluster[oflHash];
161 if (clusterFromHash<0){
162 ATH_MSG_ERROR(
"LArNNRawChannelBuilder::execute: clusterFromHash returned"<<clusterFromHash);
163 return StatusCode::FAILURE;
166 unsigned nnNumInputs = clusterToOnnx[clusterFromHash]->GetInputCount();
167 unsigned nnNumOutputs = clusterToOnnx[clusterFromHash]->GetOutputCount();
170 inputShape.resize(nnNumInputs);
171 indicesOrder.resize(nnNumInputs);
172 input_names.resize(nnNumInputs);
173 for(
unsigned int i = 0;
i < nnNumInputs;
i++){
174 auto type_info = clusterToOnnx[clusterFromHash]->GetInputTypeInfo(
i);
175 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
176 Ort::AllocatedStringPtr nnVariableNameStrPtr = clusterToOnnx[clusterFromHash]->GetInputNameAllocated(
i, Ort::AllocatorWithDefaultOptions());
178 char * nnVariableName = nnVariableNameStrPtr.release();
179 input_names[
i] = nnVariableName;
180 inputSamples.resize(
static_cast<int>(nnNumInputs));
181 if(std::strlen(nnVariableName) <= 7){
182 ATH_MSG_ERROR(
"Input name must starts with \"sample_\", then \"m\" (< 0) or \"p\" (>= 0) and end with an index (example : sample_m2)");
183 return StatusCode::FAILURE;
185 if(!(std::strncmp(nnVariableName,
"sample_", 7) == 0)){
186 ATH_MSG_ERROR(
"Input name must starts with \"sample_\", then \"m\" (< 0) or \"p\" (>= 0) and end with an index (example : sample_m2)");
187 return StatusCode::FAILURE;
189 char index_sign = nnVariableName[7];
191 if(index_sign ==
'm'){
194 else if(index_sign !=
'p'){
195 ATH_MSG_ERROR(
"Wrong sign used, you have to use \"m\" (< 0) or \"p\" (>= 0)");
196 return StatusCode::FAILURE;
199 for(
auto el : tensor_info.GetShape()){
200 inputShape[
i].push_back((
int) abs(
el));
202 input_tensors.push_back(Ort::Value::CreateTensor<float>(memory_info, inputSamples[
i].
data(), inputSamples[
i].
size(), inputShape[
i].
data(), inputShape[
i].
size()));
204 for(
unsigned int i = 0;
i < nnNumOutputs;
i++){
205 auto outputName = clusterToOnnx[clusterFromHash]->GetOutputNameAllocated(
i, Ort::AllocatorWithDefaultOptions());
211 for(
unsigned int i = 0;
i < nnNumInputs;
i++){
212 char index_sign = input_names[
i][7];
214 if(index_sign ==
'm'){
217 else if(index_sign !=
'p'){
218 ATH_MSG_ERROR(
"Wrong sign used, you have to use \"m\" (< 0) or \"p\" (>= 0)");
219 return StatusCode::FAILURE;
221 inputSamples[
i][0] = ((samples[
index+m_firstSample]-pedestal_value)/(4096.0-pedestal_value));
224 const auto& adc2mev = adc2MeVs->ADC2MEV(
id,
gain);
227 if (!connected)
continue;
228 ATH_MSG_ERROR(
"No valid pedestal for connected channel " << m_onlineId->channel_name(
id)
229 <<
" gain " <<
gain);
230 return StatusCode::FAILURE;
234 if (!connected)
continue;
235 ATH_MSG_ERROR(
"No valid ADC2MeV for connected channel " << m_onlineId->channel_name(
id)
236 <<
" gain " <<
gain);
237 return StatusCode::FAILURE;
245 std::vector<float>samp_no_ped(nnNumInputs, 0.0);
246 for (
unsigned int i = 0;
i < nnNumInputs;
i++) {
247 int index = indicesOrder[
i]+m_firstSample;
249 samp_no_ped[
i] = samples[
index]-pedestal_value;
252 std::vector<Ort::Value>
outputs;
254 outputs = clusterToOnnx[clusterFromHash]->Run(Ort::RunOptions{
nullptr}, input_names.data(), input_tensors.data(), input_tensors.size(), output_names.data(), output_names.size());
257 An =
outputs.front().GetTensorMutableData<
float>()[0];
260 A = An*(4096.0-pedestal_value);
263 const float E = adc2mev[0]+
A*adc2mev[1];
273 outputContainerLRPtr->emplace_back(
id,
static_cast<int>(std::floor(
E+0.5)),
274 static_cast<int>(std::floor(tau+0.5)),
281 for(
auto el : input_names){
285 for(
auto el : output_names){
290 return StatusCode::SUCCESS;