ATLAS Offline Software
Loading...
Searching...
No Matches
NnClusterizationFactory.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4
18
19
20
21#include "GaudiKernel/ITHistSvc.h"
24#include <onnxruntime_cxx_api.h>
26
27//for position estimate and clustering
35
36//get std::isnan()
37#include <cmath>
38#include <algorithm>
39#include <limits>
40
namespace {
  /// Convert a floating-point value to int, clamping it into the range of int.
  ///
  /// @param v  value to convert (any double, including +/-inf and NaN)
  /// @return   pair (converted value, coerced flag). The flag is true whenever the
  ///           value had to be altered to fit: v lay outside [INT_MIN, INT_MAX],
  ///           or v was NaN. A NaN input yields 0. This guard is needed because
  ///           std::clamp returns NaN unchanged, and converting NaN to int is
  ///           undefined behaviour. In-range values are truncated towards zero,
  ///           exactly as static_cast<int> does.
  std::pair<int, bool>
  coerceToIntRange(double v){
    if (std::isnan(v)) {
      return {0, true};
    }
    // Both limits are exactly representable as double, so clamping in double
    // and then casting can never overflow.
    constexpr double minint = std::numeric_limits<int>::min();
    constexpr double maxint = std::numeric_limits<int>::max();
    const double d = std::clamp(v, minint, maxint);
    return {static_cast<int>(d), d != v};
  }
}
51
52
53namespace InDet {
// Name patterns that initialize() matches (std::regex_match, whole string) against
// each entry of m_nnOrder to identify the network type. The array index is the
// network type; the kNumberParticlesNN entry is the number-of-particles network.
// In the position/error patterns, the ([0-9]) capture (group 1) holds the particle
// multiplicity the network was trained for. It is a single digit, so at most 9.
// initialize() reads it through the group index m_nParticleGroup[type].
// Every pattern accepts an empty suffix, "/", or "_" followed by anything.
// (The line naming the array is not visible in this view.)
54 const std::array<std::regex, NnClusterizationFactory::kNNetworkTypes>
56 std::regex("^NumberParticles(|/|_.*)$"),
57 std::regex("^ImpactPoints([0-9])P(|/|_.*)$"),
58 std::regex("^ImpactPointErrorsX([0-9])(|/|_.*)$"),
59 std::regex("^ImpactPointErrorsY([0-9])(|/|_.*)$"),
60 };
61
// Tool constructor (the first signature line is not visible in this view).
// Forwards its three arguments to the AthAlgTool base constructor and registers
// this tool under the NnClusterizationFactory interface.
63 const std::string& n, const IInterface* p)
64 : AthAlgTool(name, n, p){
65 declareInterface<NnClusterizationFactory>(this);
66 }
67
69 ATH_CHECK(m_chargeDataKey.initialize());
72 if (m_doRunI) {
74 } else {
76 }
77 // =0 means invalid in the following, but later on the values will be decremented by one and they indicate the index in the NN collection
79 m_NNId.clear();
80 m_NNId.resize( kNNetworkTypes -1 ) ;
81 // map networks to element in network collection
82 unsigned int nn_id=0;
83 std::smatch match_result;
84 for(const std::string &nn_name : m_nnOrder) {
85 ++nn_id;
86 for (unsigned int network_i=0; network_i<kNNetworkTypes; ++network_i) {
87 if (std::regex_match( nn_name, match_result, m_nnNames[network_i])) {
88 if (network_i == kNumberParticlesNN) {
89 m_nParticleNNId = nn_id;
90 } else {
91 if (m_nParticleGroup[network_i]>0) {
92 if (m_nParticleGroup[network_i]>=match_result.size()) {
93 ATH_MSG_ERROR("Regex and match group of particle multiplicity do not coincide (groups=" << match_result.size()
94 << " n particle group=" << m_nParticleGroup[network_i]
95 << "; type=" << network_i << ")");
96 }
97 int n_particles=std::stoi( match_result[m_nParticleGroup[network_i]].str());
98 if (n_particles<=0 or static_cast<unsigned int>(n_particles)>m_maxSubClusters) {
99 ATH_MSG_ERROR( "Failed to extract number of clusters the NN is meant for. Got " << match_result[m_nParticleGroup[network_i]].str()
100 << " But this is not in the valid range 1..." << m_maxSubClusters);
101 return StatusCode::FAILURE;
102 }
103 if (static_cast<unsigned int>(n_particles)>=m_NNId[network_i-1].size()) {
104 m_NNId[network_i-1].resize( n_particles );
105 }
106 m_NNId[network_i-1][n_particles-1] = nn_id;
107 } else {
108 if (m_NNId[network_i-1].empty()) {
109 m_NNId[network_i-1].resize(1);
110 }
111 m_NNId[network_i-1][0] = nn_id;
112 }
113 }
114 }
115 }
116 }
117 // check whether the NN IDs are all valid
118 // if valid decrease IDs by 1, because the ID is used as index in the NN collection.
119 if ((m_nParticleNNId==0) or (m_nParticleNNId>=m_nnOrder.size())) {
120 ATH_MSG_ERROR( "No NN specified to estimate the number of particles.");
121 return StatusCode::FAILURE;
122 }
124 ATH_MSG_VERBOSE("Expect NN " << s_nnTypeNames[0] << " at index " << m_nParticleNNId );
125 unsigned int type_i=0;
126 for (std::vector<unsigned int> &nn_id : m_NNId) {
127 ++type_i;
128 if (nn_id.empty()) {
129 ATH_MSG_ERROR( "No " << s_nnTypeNames[type_i] << " specified.");
130 return StatusCode::FAILURE;
131 }
132 if (m_nParticleGroup[type_i-1]>0 and nn_id.size() != m_maxSubClusters) {
133 ATH_MSG_ERROR( "Number of networks of type " << s_nnTypeNames[type_i] << " does match the maximum number of supported sub clusters " << m_maxSubClusters);
134 return StatusCode::FAILURE;
135 }
136 unsigned int n_particles=0;
137 for (unsigned int &a_nn_id : nn_id ) {
138 ++n_particles;
139 if ((a_nn_id==0) or (a_nn_id>m_nnOrder.size())) {
140 ATH_MSG_ERROR( "No " << s_nnTypeNames[type_i] << " specified for " << n_particles);
141 return StatusCode::FAILURE;
142 }
143 --a_nn_id;
144 ATH_MSG_VERBOSE("Expect NN " << s_nnTypeNames[type_i] << " for " << n_particles << " particle(s) at index " << a_nn_id );
145 }
146 }
147 ATH_CHECK( m_readKeyWithoutTrack.initialize( !m_readKeyWithoutTrack.key().empty() ) );
148 ATH_CHECK( m_readKeyWithTrack.initialize( !m_readKeyWithTrack.key().empty() ) );
149 ATH_CHECK( m_readKeyJSON.initialize( !m_readKeyJSON.key().empty() ) );
150 ATH_CHECK( m_readKeyONNX.initialize( !m_readKeyONNX.key().empty() ) );
151 return StatusCode::SUCCESS;
152 }
153
154
// Builds a flat std::vector NN input from raw (not normalised) cluster quantities.
// The signature line is not visible in this view; presumably this is one of the
// candidates for m_assembleInput. Layout:
//   m_sizeX*m_sizeY matrixOfToT values (u outer, s inner), m_sizeY y pitches,
//   layer, barrel/endcap flag, phi, theta and, only without track info, etaModule.
// The vector is sized by calculateVectorDimension() and pre-filled with NaN, so any
// slot that is not written stays NaN.
// NOTE(review): the writes are not bounds-checked; this relies on
// calculateVectorDimension() matching the layout above.
155 std::vector<double>
157 const auto vectorSize{calculateVectorDimension(input.useTrackInfo)};
158 const auto invalidValue{std::numeric_limits<double>::quiet_NaN()};
159 std::vector<double> inputData(vectorSize, invalidValue);
160 size_t vectorIndex{0};
161 for (unsigned int u=0;u<m_sizeX;u++){
162 for (unsigned int s=0;s<m_sizeY;s++){
163 inputData[vectorIndex++] = input.matrixOfToT[u][s];
164 }
165 }
166 for (unsigned int s=0;s<m_sizeY;s++){
167 inputData[vectorIndex++] = input.vectorOfPitchesY[s];
168 }
169 inputData[vectorIndex++] = input.ClusterPixLayer;
170 inputData[vectorIndex++] = input.ClusterPixBarrelEC;
171 inputData[vectorIndex++] = input.phi;
172 inputData[vectorIndex++] = input.theta;
// etaModule is only part of the input when no track information is used.
173 if (not input.useTrackInfo) inputData[vectorIndex] = input.etaModule;
174 return inputData;
175 }
176
// Builds a flat std::vector NN input with every quantity passed through the norm_*
// normalisation helpers. The signature line is not visible in this view;
// presumably this is one of the candidates for m_assembleInput. The layout matches
// the raw builder, with two differences: ToT values use norm_rawToT when m_useToT
// is set (norm_ToT otherwise), and the angles use the with-track normalisation
// (norm_phi/norm_theta) or the beam-spot one (norm_phiBS/norm_thetaBS plus
// etaModule). The vector is pre-filled with NaN.
177 std::vector<double>
179 const auto vectorSize{calculateVectorDimension(input.useTrackInfo)};
180 const auto invalidValue{std::numeric_limits<double>::quiet_NaN()};
181 std::vector<double> inputData(vectorSize, invalidValue);
182 size_t vectorIndex{0};
183 for (unsigned int u=0;u<m_sizeX;u++){
184 for (unsigned int s=0;s<m_sizeY;s++){
185 if (m_useToT){
186 inputData[vectorIndex++] = norm_rawToT(input.matrixOfToT[u][s]);
187 } else {
188 inputData[vectorIndex++] = norm_ToT(input.matrixOfToT[u][s]);
189 }
190 }
191 }
192 for (unsigned int s=0;s<m_sizeY;s++){
193 const double rawPitch(input.vectorOfPitchesY[s]);
194 const double normPitch(norm_pitch(rawPitch,m_addIBL));
// A NaN pitch is only reported; it is still written into the input vector.
195 if (std::isnan(normPitch)){
196 ATH_MSG_ERROR("NaN returned from norm_pitch, rawPitch = "<<rawPitch<<" addIBL = "<<m_addIBL);
197 }
198 inputData[vectorIndex++] = normPitch;
199 }
200 inputData[vectorIndex++] = norm_layerNumber(input.ClusterPixLayer);
201 inputData[vectorIndex++] = norm_layerType(input.ClusterPixBarrelEC);
202 if (input.useTrackInfo){
203 inputData[vectorIndex++] = norm_phi(input.phi);
204 inputData[vectorIndex] = norm_theta(input.theta);
205 } else {
206 inputData[vectorIndex++] = norm_phiBS(input.phi);
207 inputData[vectorIndex++] = norm_thetaBS(input.theta);
208 inputData[vectorIndex] = norm_etaModule(input.etaModule);
209 }
210 return inputData;
211 }
212
// Builds the raw (not normalised) NN input as a single Eigen vector, wrapped in a
// std::vector with one entry per input node; only one node is used. Presumably
// this feeds the ONNX/LWTNN paths (nnInputVector). The return type and signature
// lines are not visible in this view.
215 // we know the size to be
216 // - m_sizeX x m_sizeY pixel ToT values
217 // - m_sizeY pitch sizes in y
218 // - 2 values: detector location
219 // - 2 values: track incidence angles
220 // - optional: eta module
221 const auto vecSize{calculateVectorDimension(input.useTrackInfo)};
// Eigen does not initialise the coefficients of a dynamically sized vector.
222 Eigen::VectorXd valuesVector( vecSize );
223 // Fill it!
224 // Variable names here need to match the ones in the configuration...
225 // ...IN THE SAME ORDER!!!
226 // location in eigen matrix object where next element goes
227 int location(0);
// NOTE(review): unlike the std::vector builders, these loops walk the whole
// matrixOfToT / vectorOfPitchesY containers rather than m_sizeX x m_sizeY, so they
// assume the containers have exactly those dimensions. Eigen's operator[] is not
// bounds-checked in optimised builds. A size mismatch therefore either writes past
// the end or leaves slots uninitialised.
228 for (const auto & xvec: input.matrixOfToT){
229 for (const auto & xyElement : xvec){
230 valuesVector[location++] = xyElement;
231 }
232 }
233 for (const auto & pitch : input.vectorOfPitchesY) {
234 valuesVector[location++] = pitch;
235 }
236 valuesVector[location] = input.ClusterPixLayer;
237 location++;
238 valuesVector[location] = input.ClusterPixBarrelEC;
239 location++;
240 valuesVector[location] = input.phi;
241 location++;
242 valuesVector[location] = input.theta;
243 location++;
244 if (!input.useTrackInfo) {
245 valuesVector[location] = input.etaModule;
246 location++;
247 }
248 // We have only one node for now, so we just store things there.
249 // Format for use with lwtnn
250 std::vector<Eigen::VectorXd> vectorOfEigen;
251 vectorOfEigen.push_back(valuesVector);
252 return vectorOfEigen;
253 }
254
// Number-of-particles estimate for a cluster without track information. The first
// signature line is not visible in this view.
// Builds the NN input from the cluster and the beam-spot position and returns an
// empty vector if that fails. Otherwise it dispatches to one of three paths:
// - the TTrainedNetwork path, using the without-track conditions collection;
// - the ONNX path, when m_useONNX is set;
// - the LWTNN path otherwise.
// The lines that select the TTN path, read the conditions handle and build
// nnInputVector are not visible in this view.
255 std::vector<double>
257 Amg::Vector3D & beamSpotPosition) const{
258 double tanl=0;
259 NNinput input( createInput(pCluster,beamSpotPosition,tanl) );
260 if (!input) return {};
261 // If using old TTrainedNetworks, fetch correct ones for the
262 // without-track situation and call them now.
264 const std::vector<double> & inputData=(this->*m_assembleInput)(input);
266 if (!nn_collection.isValid()) {
267 ATH_MSG_FATAL( "Failed to get trained network collection with key " << m_readKeyWithoutTrack.key() );
268 return {};
269 }
270 return estimateNumberOfParticlesTTN(**nn_collection, inputData);
271 }
272 // Otherwise, prepare input vector and use ONNX or LWTNN networks.
274 if (m_useONNX) {
275 return estimateNumberOfParticlesONNX(nnInputVector[0]);
276 }
277 return estimateNumberOfParticlesLWTNN(nnInputVector);
278 }
279
// Number-of-particles estimate for a cluster with track information. The first
// signature line is not visible in this view.
// Builds the NN input with a dummy (0,0,0) beam spot, then overrides the incidence
// angles from the track parameters at the pixel surface. It then dispatches to the
// TTrainedNetwork, ONNX or LWTNN backend.
280 std::vector<double>
282 const Trk::Surface& pixelSurface,
283 const Trk::TrackParameters& trackParsAtSurface) const{
284 Amg::Vector3D dummyBS(0,0,0);
285 double tanl=0;
286 NNinput input( createInput(pCluster,dummyBS,tanl) );
287
288 if (!input) return {};
289 addTrackInfoToInput(input,pixelSurface,trackParsAtSurface,tanl);
// NOTE(review): inputData appears to be assembled before the backend is chosen, so
// it is also computed, unused, on the ONNX/LWTNN paths. The overload without track
// builds it only inside the TTN branch.
290 std::vector<double> inputData=(this->*m_assembleInput)(input);
291 // If using old TTrainedNetworks, fetch correct ones for the
292 // with-track situation and call them now.
// NOTE(review): the FATAL message below names m_readKeyWithoutTrack. The handle is
// read on a line not visible here. If that handle uses m_readKeyWithTrack, as the
// with-track positions overload's message suggests, the logged key is wrong.
295 if (!nn_collection.isValid()) {
296 ATH_MSG_FATAL( "Failed to get trained network collection with key " << m_readKeyWithoutTrack.key() );
297 return {};
298 }
299 return estimateNumberOfParticlesTTN(**nn_collection, inputData);
300 }
301 // Otherwise, prepare input vector and use ONNX or LWTNN networks.
303 if (m_useONNX) {
304 return estimateNumberOfParticlesONNX(nnInputVector[0]);
305 }
306 return estimateNumberOfParticlesLWTNN(nnInputVector);
307 }
308
// Runs the number-of-particles TTrainedNetwork, which is entry m_nParticleNNId of
// nn_collection, on the assembled input and returns its raw outputs. Returns an
// empty vector if the index is out of range or the network pointer is null.
// The signature line is not visible in this view.
// NOTE(review): the VERBOSE message reads entries 0..2 without checking the output
// size. The FATAL text also lacks a space before "is out of range."
309 std::vector<double>
311 const std::vector<double>& inputData) const{
312 ATH_MSG_DEBUG("Using TTN number network");
313 std::vector<double> resultNN_TTN{};
314 if (not (m_nParticleNNId < nn_collection.size())){ //note: m_nParticleNNId is unsigned
315 ATH_MSG_FATAL("NnClusterizationFactory::estimateNumberOfParticlesTTN: Index "<<m_nParticleNNId<< "is out of range.");
316 return resultNN_TTN;
317 }
318 auto *const pNetwork = nn_collection[m_nParticleNNId].get();
319 if (not pNetwork){
320 ATH_MSG_FATAL("NnClusterizationFactory::estimateNumberOfParticlesTTN: nullptr returned for TrainedNetwork");
321 return resultNN_TTN;
322 }
323 // dereference unique_ptr<TTrainedNetwork> then call calculateOutput :
324 resultNN_TTN = (*pNetwork.*m_calculateOutput)(inputData);
325 ATH_MSG_VERBOSE(" TTN Prob of n. particles (1): " << resultNN_TTN[0] <<
326 " (2): " << resultNN_TTN[1] <<
327 " (3): " << resultNN_TTN[2]);
328 return resultNN_TTN;
329 }
330
331
// Evaluates the number-of-particles LWTNN network, which is entry 0 of the LWTNN
// conditions collection (the error messages name m_readKeyJSON). Returns its first
// three outputs normalised to sum to one. If the collection is missing or empty, it
// returns {0,0,0}. The signature and handle lines are not visible in this view.
// NOTE(review): this assumes the network produces at least three outputs with a
// non-zero sum; neither assumption is checked.
332 std::vector<double>
334 std::vector<double> result(3,0.0);//ok as invalid result?
336 if (!lwtnn_collection.isValid()) {
337 ATH_MSG_FATAL( "Failed to get LWTNN network collection with key " << m_readKeyJSON.key() );
338 return result;
339 }
340 if (lwtnn_collection->empty()){
341 ATH_MSG_FATAL( "LWTNN network collection with key " << m_readKeyJSON.key()<<" is empty." );
342 return result;
343 }
344 ATH_MSG_DEBUG("Using lwtnn number network");
345 // Order of output matches order in JSON config in "outputs"
346 // Only 1 node here, simple compute function
347 Eigen::VectorXd discriminant = lwtnn_collection->at(0)->compute(input);
348 const double & num0 = discriminant[0];
349 const double & num1 = discriminant[1];
350 const double & num2 = discriminant[2];
351 // Get normalized predictions
352 const auto inverseSum = 1./(num0+num1+num2);
353 result[0] = num0 * inverseSum;
354 result[1] = num1 * inverseSum;
355 result[2] = num2 * inverseSum;
356 ATH_MSG_VERBOSE(" LWTNN Prob of n. particles (1): " << result[0] <<
357 " (2): " << result[1] <<
358 " (3): " << result[2]);
359 return result;
360 }
361
362
// Estimates sub-cluster positions for a cluster without track information, using
// the beam spot to build the NN input. The first signature line is not visible in
// this view.
// `errors` receives the covariance matrices produced by the chosen backend
// (TTrainedNetwork, ONNX or LWTNN). Returns an empty vector if the input cannot be
// built or the TTN conditions collection is invalid.
363 std::vector<Amg::Vector2D>
365 Amg::Vector3D & beamSpotPosition,
366 std::vector<Amg::MatrixX> & errors,
367 int numberSubClusters) const{
368 ATH_MSG_VERBOSE(" Starting to estimate positions...");
369 double tanl=0;
370 NNinput input( createInput(pCluster,beamSpotPosition,tanl) );
371 if (!input){
372 return {};
373 }
374 // If using old TTrainedNetworks, fetch correct ones for the
375 // without-track situation and call them now.
377 const std::vector<double> & inputData=(this->*m_assembleInput)(input);
379 if (!nn_collection.isValid()) {
380 ATH_MSG_FATAL( "Failed to get trained network collection with key " << m_readKeyWithoutTrack.key() );
381 return {};
382 }
383 // *(ReadCondHandle<>) returns a pointer rather than a reference ...
384 return estimatePositionsTTN(**nn_collection, inputData,input,pCluster,numberSubClusters,errors);
385 }
386 // Otherwise, prepare input vector and use ONNX or LWTNN networks.
388 if (m_useONNX) {
389 return estimatePositionsONNX(nnInputVector[0],input,pCluster,numberSubClusters,errors);
390 }
391 return estimatePositionsLWTNN(nnInputVector,input,pCluster,numberSubClusters,errors);
392 }
393
394
// Estimates sub-cluster positions for a cluster with track information. The first
// signature line is not visible in this view.
// Builds the NN input with a dummy (0,0,0) beam spot, overrides the incidence
// angles from the track parameters at the pixel surface, then dispatches to the
// TTrainedNetwork, ONNX or LWTNN backend. `errors` receives the covariance
// matrices produced by that backend.
395 std::vector<Amg::Vector2D>
397 const Trk::Surface& pixelSurface,
398 const Trk::TrackParameters& trackParsAtSurface,
399 std::vector<Amg::MatrixX> & errors,
400 int numberSubClusters) const{
401 ATH_MSG_VERBOSE(" Starting to estimate positions...");
402 Amg::Vector3D dummyBS(0,0,0);
403 double tanl=0;
404 NNinput input( createInput(pCluster, dummyBS, tanl) );
405 if (!input) return {};
406 addTrackInfoToInput(input,pixelSurface,trackParsAtSurface,tanl);
407 // If using old TTrainedNetworks, fetch correct ones for the
408 // with-track situation and call them now.
410 std::vector<double> inputData=(this->*m_assembleInput)(input);
412 if (!nn_collection.isValid()) {
413 ATH_MSG_FATAL( "Failed to get trained network collection with key " << m_readKeyWithTrack.key() );
414 return {};
415 }
416 return estimatePositionsTTN(**nn_collection, inputData,input,pCluster,numberSubClusters,errors);
417 }
418 // Otherwise, prepare input vector and use ONNX or LWTNN networks.
420 if (m_useONNX) {
421 return estimatePositionsONNX(nnInputVector[0],input,pCluster,numberSubClusters,errors);
422 }
423 return estimatePositionsLWTNN(nnInputVector,input,pCluster,numberSubClusters,errors);
424 }
425
// TTrainedNetwork position estimate. The signature line is not visible in this view.
// Only runs when 1 <= numberSubClusters < m_maxSubClusters. It then:
// - runs the position network for that multiplicity and converts its (x,y) pairs
//   with getPositionsFromOutput;
// - feeds the original input plus the first 2*numberSubClusters raw position
//   outputs to the x- and y-error networks;
// - turns their binned outputs into covariance matrices with
//   getErrorMatrixFromOutput.
// Positions are returned and their matrices are appended to `errors`. Returns an
// empty vector if numberSubClusters is out of range or a network index is invalid.
// applyRecentering is only used in the DEBUG printout here.
426 std::vector<Amg::Vector2D>
428 const std::vector<double>& inputData,
429 const NNinput& input,
430 const InDet::PixelCluster& pCluster,
431 int numberSubClusters,
432 std::vector<Amg::MatrixX> & errors) const{
433 bool applyRecentering=(!input.useTrackInfo and m_useRecenteringNNWithouTracks) or (input.useTrackInfo and m_useRecenteringNNWithTracks);
434 std::vector<Amg::Vector2D> allPositions{};
435 const auto endNnIdx = nn_collection.size();
// NOTE(review): initialize() accepts position/error networks for multiplicities
// 1..m_maxSubClusters, yet '<' here makes numberSubClusters == m_maxSubClusters
// silently return no positions. Confirm whether '<=' was intended.
436 if (numberSubClusters>0 and static_cast<unsigned int>(numberSubClusters) < m_maxSubClusters) {
437 const auto subClusterIndex = numberSubClusters-1;
438 // get position network id for the given cluster multiplicity then
439 // dereference unique_ptr<TTrainedNetwork> then call calculateOutput :
440 const auto networkIndex = m_NNId[kPositionNN-1].at(subClusterIndex);
441 //TTrainedNetworkCollection inherits from std::vector
442 if (not(networkIndex < endNnIdx)){
443 ATH_MSG_FATAL("estimatePositionsTTN: Requested collection index, "<< networkIndex << " is out of range.");
444 return allPositions;
445 }
// NOTE(review): unlike the number-of-particles TTN path, pNetwork is not
// null-checked before it is dereferenced.
446 auto *const pNetwork = nn_collection[networkIndex].get();
447 std::vector<double> position1P = (*pNetwork.*m_calculateOutput)(inputData);
448 std::vector<Amg::Vector2D> myPosition1=getPositionsFromOutput(position1P,input,pCluster);
449 assert( position1P.size() % 2 == 0);
// NOTE(review): getPositionsFromOutput returns an empty vector on failure. With
// DEBUG output enabled, myPosition1[i] below would then be out of range.
450 for (unsigned int i=0; i<position1P.size()/2 ; ++i) {
451 ATH_MSG_DEBUG(" Original RAW Estimated positions (" << i << ") x: " << back_posX(position1P[0+i*2],applyRecentering) << " y: " << back_posY(position1P[1+i*2]));
452 ATH_MSG_DEBUG(" Original estimated myPositions (" << i << ") x: " << myPosition1[i][Trk::locX] << " y: " << myPosition1[i][Trk::locY]);
453 }
// The error networks take the original input extended by the raw position outputs.
454 std::vector<double> inputDataNew=inputData;
455 inputDataNew.reserve( inputDataNew.size() + numberSubClusters*2);
456 assert( static_cast<unsigned int>(numberSubClusters*2) <= position1P.size() );
457 for (unsigned int i=0; i<static_cast<unsigned int>(numberSubClusters*2); ++i) {
458 inputDataNew.push_back(position1P[i]);
459 }
460 // get error network id for the given cluster multiplicity then
461 // dereference unique_ptr<TTrainedNetwork> then call calculateOutput :
462 const auto xNetworkIndex = m_NNId[kErrorXNN-1].at(subClusterIndex);
463 const auto yNetworkIndex = m_NNId[kErrorYNN-1].at(subClusterIndex);
464 if ((not (xNetworkIndex < endNnIdx)) or (not (yNetworkIndex < endNnIdx))){
465 ATH_MSG_FATAL("estimatePositionsTTN: A requested collection index, "<< xNetworkIndex << " or "<< yNetworkIndex << "is out of range.");
466 return allPositions;
467 }
468 auto *pxNetwork = nn_collection.at(xNetworkIndex).get();
469 auto *pyNetwork = nn_collection.at(yNetworkIndex).get();
470 //call the selected member function of the TTrainedNetwork
471 std::vector<double> errors1PX = (*pxNetwork.*m_calculateOutput)(inputDataNew);
472 std::vector<double> errors1PY = (*pyNetwork.*m_calculateOutput)(inputDataNew);
473 //
474 std::vector<Amg::MatrixX> errorMatrices1;
475 getErrorMatrixFromOutput(errors1PX,errors1PY,errorMatrices1,numberSubClusters);
// Appends to `errors`. errorMatrices1 holds numberSubClusters entries, which is
// assumed to be at least myPosition1.size().
476 allPositions.reserve( allPositions.size() + myPosition1.size());
477 errors.reserve( errors.size() + myPosition1.size());
478 for (unsigned int i=0;i<myPosition1.size();i++){
479 allPositions.push_back(myPosition1[i]);
480 errors.push_back(errorMatrices1[i]);
481 }
482 }
483 return allPositions;
484 }
485
486
// LWTNN position estimate. The signature and handle lines are not visible in this
// view. For each of the numberSubClusters sub-clusters:
// - evaluates the network keyed by numberSubClusters on the output node taken from
//   m_outputNodesPos1/2/3;
// - reads entries 1..4 of the output as mean_x, mean_y, prec_x, prec_y.
// The means are converted into positions by getPositionsFromOutput. Each precision
// is turned into an RMS (1/sqrt(prec)), converted into length units by
// correctedRMSX/correctedRMSY, and stored in a diagonal covariance matrix.
// `errors` is replaced, not appended to (the TTN path appends). Returns an empty
// vector on a missing collection, a missing network, or more than 3 sub-clusters.
487 std::vector<Amg::Vector2D>
489 NNinput& rawInput,
490 const InDet::PixelCluster& pCluster,
491 int numberSubClusters,
492 std::vector<Amg::MatrixX> & errors) const {
494 if (not lwtnn_collection.isValid()) {
495 ATH_MSG_FATAL( "Failed to get LWTNN network collection with key " << m_readKeyJSON.key() );
496 return {};
497 }
498 if (lwtnn_collection->empty()){
499 ATH_MSG_FATAL( "estimatePositionsLWTNN: LWTNN network collection with key " << m_readKeyJSON.key()<<" is empty." );
500 return {};
501 }
502 // Need to evaluate the correct network once per cluster we're interested in.
503 // Save the output
504 std::vector<double> positionValues{};
505 std::vector<Amg::MatrixX> errorMatrices;
506 errorMatrices.reserve(numberSubClusters);
507 positionValues.reserve(numberSubClusters * 2);
508 std::size_t outputNode(0);
509 for (int cluster = 1; cluster < numberSubClusters+1; cluster++) {
510 // Check that the network is defined.
511 // If not, we are outside an IOV and should fail
// NOTE(review): this lookup does not depend on `cluster` and could be done once
// before the loop.
512 const auto pNetwork = lwtnn_collection->find(numberSubClusters);
513 const bool validGraph = (pNetwork != lwtnn_collection->end()) and (pNetwork->second != nullptr);
514 if (not validGraph) {
515 std::string infoMsg ="Acceptable numbers of subclusters for the lwtnn collection:\n ";
516 for (const auto & pair: **lwtnn_collection){
517 infoMsg += std::to_string(pair.first) + "\n ";
518 }
519 infoMsg += "\nNumber of subclusters requested : "+ std::to_string(numberSubClusters);
520 ATH_MSG_DEBUG(infoMsg);
521 ATH_MSG_FATAL( "estimatePositionsLWTNN: No lwtnn network found for the number of clusters.\n"
522 <<" If you are outside the valid range for an lwtnn-based configuration, please run with useNNTTrainedNetworks instead.\n Key = "
523 << m_readKeyJSON.key() );
524 return {};
525 }
// Select the output node: one node for 1 particle, one node per sub-cluster for 2 or 3.
526 if(numberSubClusters==1) {
527 outputNode = m_outputNodesPos1;
528 } else if(numberSubClusters==2) {
529 outputNode = m_outputNodesPos2[cluster-1];
530 } else if(numberSubClusters==3) {
531 outputNode = m_outputNodesPos3[cluster-1];
532 } else {
533 ATH_MSG_FATAL( "Cannot evaluate LWTNN networks with " << numberSubClusters << " numberSubClusters" );
534 return {};
535 }
536
537 // Order of output matches order in JSON config in "outputs"
538 // "alpha", "mean_x", "mean_y", "prec_x", "prec_y"
539 // Assume here that 1 particle network is in position 1, 2 at 2, and 3 at 3.
540 Eigen::VectorXd position = lwtnn_collection->at(numberSubClusters)->compute(input, {}, outputNode);
541 ATH_MSG_DEBUG("Testing for numberSubClusters " << numberSubClusters << " and cluster " << cluster);
542 for (int i=0; i<position.rows(); i++) {
543 ATH_MSG_DEBUG(" position " << position[i]);
544 }
545 positionValues.push_back(position[1]); //mean_x
546 positionValues.push_back(position[2]); //mean_y
547 // Fill errors.
548 // Values returned by NN are inverse of variance, and we want variances.
// A non-positive precision would give an inf/NaN RMS here; this is not checked.
549 const float rawRmsX = std::sqrt(1.0/position[3]); //prec_x
550 const float rawRmsY = std::sqrt(1.0/position[4]); //prec_y
551 // Now convert to real space units
552 const double rmsX = correctedRMSX(rawRmsX);
553 const double rmsY = correctedRMSY(rawRmsY, rawInput.vectorOfPitchesY);
554 ATH_MSG_DEBUG(" Estimated RMS errors (1) x: " << rmsX << ", y: " << rmsY);
555 // Fill matrix
556 Amg::MatrixX erm(2,2);
557 erm.setZero();
558 erm(0,0)=rmsX*rmsX;
559 erm(1,1)=rmsY*rmsY;
560 errorMatrices.push_back(erm);
561 }
562 std::vector<Amg::Vector2D> myPositions = getPositionsFromOutput(positionValues,rawInput,pCluster);
// NOTE(review): myPositions is empty when numberSubClusters <= 0 (the loop does
// not run) or when getPositionsFromOutput fails. With DEBUG output enabled,
// myPositions[0] below is then out of range.
563 ATH_MSG_DEBUG(" Estimated myPositions (1) x: " << myPositions[0][Trk::locX] << " y: " << myPositions[0][Trk::locY]);
564 errors=std::move(errorMatrices);
565 return myPositions;
566 }
567
// Converts a value along x given in pixel units (posPixels) into length units by
// multiplying by a fixed 0.05 pitch; the units are presumably mm (TODO confirm).
// The LWTNN path uses it for the x RMS. The signature line is not visible in this
// view.
568 double
570 // Input is in pixel units; scaling by the x pitch gives a length
571 constexpr double pitch = 0.05;
572 const double corrected = posPixels * pitch;
573 return corrected;
574 }
575
// Converts an offset along y, given in pixel units relative to the cluster centre,
// into a distance using the per-column `pitches`. The signature line is not visible
// in this view.
// p = posPixels + (m_sizeY-1)/2 moves the offset into column-index coordinates.
// The result is |position at p - centre of the middle column|, with both positions
// accumulated from `pitches`.
// NOTE(review): if p lies outside [0, m_sizeY], no column matches and p_Y keeps its
// -100 sentinel, so a meaningless distance is returned. pitches.at() throws
// std::out_of_range if pitches has fewer than m_sizeY entries.
576 double
578 std::vector<float>& pitches) const{
579 double p = posPixels + (m_sizeY - 1) * 0.5;
580 double p_Y = -100;
581 double p_center = -100;
582 double p_actual = 0;
// p_actual is the accumulated width of the columns before column i.
583 for (unsigned int i = 0; i < m_sizeY; i++) {
584 if (p >= i and p <= (i + 1)) p_Y = p_actual + (p - i + 0.5) * pitches.at(i);
585 if (i == (m_sizeY - 1) / 2) p_center = p_actual + 0.5 * pitches.at(i);
586 p_actual += pitches.at(i);
587 }
588 return std::abs(p_Y - p_center);
589 }
590
// Converts the binned outputs of the x- and y-error networks into one diagonal 2x2
// covariance matrix per particle, appended to errorMatrix. The signature line is
// not visible in this view.
// outputX/outputY hold nParticles consecutive histograms of size/nParticles bins.
// Bin u is centred at minimum + (maximum-minimum)/(size-2) * (u - 1/2), so bins
// 1..size-2 cover [-errorHalfInterval, +errorHalfInterval]. Bins 0 and size-1 lie
// half a bin outside that range, presumably as under/overflow.
// For each particle:
// 1. The RMS is computed over all bins, weighting by bin content / total sum.
// 2. It is recomputed over only the bins within +-3*RMS of zero. This second pass
//    is still normalised by the full sum.
// NOTE(review): nParticles <= 0 divides by zero in integer arithmetic, a histogram
// with exactly 2 bins divides by zero in the bin width, and an all-zero histogram
// yields NaN.
591 void
593 std::vector<double>& outputY,
594 std::vector<Amg::MatrixX>& errorMatrix,
595 int nParticles) const{
596 int sizeOutputX=outputX.size()/nParticles;
597 int sizeOutputY=outputY.size()/nParticles;
598 double minimumX=-errorHalfIntervalX(nParticles);
599 double maximumX=errorHalfIntervalX(nParticles);
600 double minimumY=-errorHalfIntervalY(nParticles);
601 double maximumY=errorHalfIntervalY(nParticles);
602 //X=0...sizeOutput-1
603 //Y=minimum+(maximum-minimum)/(sizeOutput-2)*(X-1./2.)
604 errorMatrix.reserve( errorMatrix.size() + nParticles);
605 for (int i=0;i<nParticles;i++){
606 double sumValuesX=0;
607 for (int u=0;u<sizeOutputX;u++){
608 sumValuesX+=outputX[i*sizeOutputX+u];
609 }
610 double sumValuesY=0;
611 for (int u=0;u<sizeOutputY;u++){
612 sumValuesY+=outputY[i*sizeOutputY+u];
613 }
614 ATH_MSG_VERBOSE(" minimumX: " << minimumX << " maximumX: " << maximumX << " sizeOutputX " << sizeOutputX);
615 ATH_MSG_VERBOSE(" minimumY: " << minimumY << " maximumY: " << maximumY << " sizeOutputY " << sizeOutputY);
616 double RMSx=0;
617 for (int u=0;u<sizeOutputX;u++){
618 RMSx+=outputX[i*sizeOutputX+u]/sumValuesX*std::pow(minimumX+(maximumX-minimumX)/(double)(sizeOutputX-2)*(u-1./2.),2);
619 }
620 RMSx=std::sqrt(RMSx);//computed error!
621 ATH_MSG_VERBOSE(" first Iter RMSx: " << RMSx);
622 double intervalErrorX=3*RMSx;
623 //now recompute between -3*RMSx and +3*RMSx
// Invert the bin-centre formula to find the bins covering [-3*RMSx, +3*RMSx].
624 int minBinX=(int)(1+(-intervalErrorX-minimumX)/(maximumX-minimumX)*(double)(sizeOutputX-2));
625 int maxBinX=(int)(1+(intervalErrorX-minimumX)/(maximumX-minimumX)*(double)(sizeOutputX-2));
626 if (maxBinX>sizeOutputX-1) maxBinX=sizeOutputX-1;
627 if (minBinX<0) minBinX=0;
628 ATH_MSG_VERBOSE(" minBinX: " << minBinX << " maxBinX: " << maxBinX );
629 RMSx=0;
630 for (int u=minBinX;u<maxBinX+1;u++){
631 RMSx+=outputX[i*sizeOutputX+u]/sumValuesX*std::pow(minimumX+(maximumX-minimumX)/(double)(sizeOutputX-2)*(u-1./2.),2);
632 }
633 RMSx=std::sqrt(RMSx);//computed error!
634 double RMSy=0;
635 for (int u=0;u<sizeOutputY;u++){
636 RMSy+=outputY[i*sizeOutputY+u]/sumValuesY*std::pow(minimumY+(maximumY-minimumY)/(double)(sizeOutputY-2)*(u-1./2.),2);
637 }
638 RMSy=std::sqrt(RMSy);//computed error!
639 ATH_MSG_VERBOSE("first Iter RMSy: " << RMSy );
640 double intervalErrorY=3*RMSy;
641 //now recompute between -3*RMSy and +3*RMSy
642 int minBinY=(int)(1+(-intervalErrorY-minimumY)/(maximumY-minimumY)*(double)(sizeOutputY-2));
643 int maxBinY=(int)(1+(intervalErrorY-minimumY)/(maximumY-minimumY)*(double)(sizeOutputY-2));
644 if (maxBinY>sizeOutputY-1) maxBinY=sizeOutputY-1;
645 if (minBinY<0) minBinY=0;
646 ATH_MSG_VERBOSE("minBinY: " << minBinY << " maxBinY: " << maxBinY );
647 RMSy=0;
648 for (int u=minBinY;u<maxBinY+1;u++){
649 RMSy+=outputY[i*sizeOutputY+u]/sumValuesY*std::pow(minimumY+(maximumY-minimumY)/(double)(sizeOutputY-2)*(u-1./2.),2);
650 }
651 RMSy=std::sqrt(RMSy);//computed error!
652 ATH_MSG_VERBOSE("Computed error, sigma(X) " << RMSx << " sigma(Y) " << RMSy );
653 Amg::MatrixX erm(2,2);
654 erm.setZero();
655 erm(0,0)=RMSx*RMSx;
656 erm(1,1)=RMSy*RMSy;
657 errorMatrix.push_back(erm);
658 }//end nParticles
659 }//getErrorMatrixFromOutput
660
661
// Translates NN position outputs into local positions on the module. The output
// holds one (x,y) pair per particle, as offsets from the cluster's weighted
// row/column. The signature line is not visible in this view.
// With m_doRunI, the outputs are first de-normalised with back_posX/back_posY.
// For each particle:
// - the fractional cell index is rounded to a cell;
// - the cell is converted to a position via the module design;
// - the sub-cell remainder times the cell pitch is added, plus the Lorentz shift
//   in phi. Barrel-specific handling of the shift is on lines not visible here.
// Positions whose discrete phi lies outside +-width/2 are moved just inside the edge.
// Returns one position per (x,y) pair, or an empty vector if the element's design
// is not a PixelModuleDesign.
662 std::vector<Amg::Vector2D>
664 const NNinput & input,
665 const InDet::PixelCluster& pCluster) const{
666 ATH_MSG_VERBOSE(" Translating output back into a position " );
667 const InDetDD::SiDetectorElement* element=pCluster.detectorElement();//DEFINE
668 const InDetDD::PixelModuleDesign* design
669 (dynamic_cast<const InDetDD::PixelModuleDesign*>(&element->design()));
670 if (not design){
671 ATH_MSG_ERROR("Dynamic cast failed at line "<<__LINE__<<" of NnClusterizationFactory.cxx.");
672 return {};
673 }
674 int numParticles=output.size()/2;
675 int columnWeightedPosition=input.columnWeightedPosition;
676 int rowWeightedPosition=input.rowWeightedPosition;
677 ATH_MSG_VERBOSE(" REF POS columnWeightedPos: " << columnWeightedPosition << " rowWeightedPos: " << rowWeightedPosition );
678 bool applyRecentering=false;
679 if (m_useRecenteringNNWithouTracks and (not input.useTrackInfo)){
680 applyRecentering=true;
681 }
682 if (m_useRecenteringNNWithTracks and input.useTrackInfo){
683 applyRecentering=true;
684 }
685 std::vector<Amg::Vector2D> positions;
686 for (int u=0;u<numParticles;u++){
// posXid is a fractional row index (phi direction); posYid is a fractional column
// index (eta direction).
687 double posXid{};
688 double posYid{};
689 if(m_doRunI){
690 posXid=back_posX(output[2*u],applyRecentering)+rowWeightedPosition;
691 posYid=back_posY(output[2*u+1])+columnWeightedPosition;
692 }else{
693 posXid=output[2*u]+rowWeightedPosition;
694 posYid=output[2*u+1]+columnWeightedPosition;
695 }
696 ATH_MSG_VERBOSE(" N. particle: " << u << " idx posX " << posXid << " posY " << posYid );
697 //ATLASRECTS-7155 : Pixel Charge Calibration needs investigating
// Adding 0.5 and truncating rounds to the nearest cell for non-negative indices.
// NOTE(review): a NaN network output propagates into posXid/posYid. Converting NaN
// to int is undefined behaviour, so coerceToIntRange must guard against NaN.
698 const auto & [posXid_int, coercedX]=coerceToIntRange(posXid+0.5);
699 const auto & [posYid_int, coercedY]=coerceToIntRange(posYid+0.5);
700 if (coercedX or coercedY){
701 ATH_MSG_WARNING("X or Y position value has been limited in range; original values are (" << posXid<<", "<<posYid<<")");
702 //we cannot skip these values, it seems client code relies on the size of input vector and output vector being the same
703 }
704 ATH_MSG_VERBOSE(" N. particle: " << u << " TO INTEGER idx posX " << posXid_int << " posY " << posYid_int );
705 InDetDD::SiLocalPosition siLocalPositionDiscrete(design->positionFromColumnRow(posYid_int,posXid_int));
706 InDetDD::SiCellId cellIdOfPositionDiscrete=design->cellIdOfPosition(siLocalPositionDiscrete);
// NOTE(review): an invalid cell only triggers a warning. design->parameters() is
// still queried with the invalid cell id below.
707 if ( not cellIdOfPositionDiscrete.isValid()){
708 ATH_MSG_WARNING(" Cell is outside validity region with index Y: " << posYid_int << " and index X: " << posXid_int << ". Not foreseen... " );
709 }
710 InDetDD::SiDiodesParameters diodeParameters = design->parameters(cellIdOfPositionDiscrete);
711 double pitchY = diodeParameters.width().xEta();
712 double pitchX = diodeParameters.width().xPhi();
713 ATH_MSG_VERBOSE(" Translated weighted position : " << siLocalPositionDiscrete.xPhi()
714 << " Translated weighted position : " << siLocalPositionDiscrete.xEta() );
715 //FOR TEST
716 InDetDD::SiLocalPosition siLocalPositionDiscreteOneRowMoreOneColumnMore(design->positionFromColumnRow(posYid_int+1,posXid_int+1));
717 ATH_MSG_VERBOSE(" Translated weighted position +1col +1row phi: " << siLocalPositionDiscreteOneRowMoreOneColumnMore.xPhi()
718 << " Translated weighted position +1col +1row eta: " << siLocalPositionDiscreteOneRowMoreOneColumnMore.xEta() );
719 ATH_MSG_VERBOSE("PitchY: " << pitchY << " pitchX " << pitchX );
// NOTE(review): siLocalPositionAdd appears unused in the visible lines.
720 InDetDD::SiLocalPosition siLocalPositionAdd(pitchY*(posYid-(double)posYid_int),
721 pitchX*(posXid-(double)posXid_int));
// The shift depends only on the module, so this call could be hoisted out of the loop.
722 double lorentzShift=m_pixelLorentzAngleTool->getLorentzShift(element->identifyHash(), Gaudi::Hive::currentContext());
723 if (input.ClusterPixBarrelEC == 0){
724 if (not input.useTrackInfo){
726 } else {
728 }
729 }
730
732 siLocalPosition(siLocalPositionDiscrete.xEta()+pitchY*(posYid-(double)posYid_int),
733 siLocalPositionDiscrete.xPhi()+pitchX*(posXid-(double)posXid_int)+lorentzShift);
734 ATH_MSG_VERBOSE(" Translated final position phi: " << siLocalPosition.xPhi() << " eta: " << siLocalPosition.xEta() );
735 const auto halfWidth{design->width()*0.5};
// NOTE(review): the edge test uses siLocalPositionDiscrete.xPhi() (the rounded
// cell position) rather than the final siLocalPosition.xPhi(). The warnings print
// the continuous value without the Lorentz shift. Confirm this is intended.
736 if (siLocalPositionDiscrete.xPhi() > halfWidth){
737 siLocalPosition=InDetDD::SiLocalPosition(siLocalPositionDiscrete.xEta()+pitchY*(posYid-(double)posYid_int),
738 halfWidth-1e-6);
739 ATH_MSG_WARNING(" Corrected out of boundary cluster from x(phi): " << siLocalPositionDiscrete.xPhi()+pitchX*(posXid-(double)posXid_int)
740 << " to: " << halfWidth-1e-6);
741 } else if (siLocalPositionDiscrete.xPhi() < -halfWidth) {
742 siLocalPosition=InDetDD::SiLocalPosition(siLocalPositionDiscrete.xEta()+pitchY*(posYid-(double)posYid_int),
743 -halfWidth+1e-6);
744 ATH_MSG_WARNING(" Corrected out of boundary cluster from x(phi): " << siLocalPositionDiscrete.xPhi()+pitchX*(posXid-(double)posXid_int)
745 << " to: " << -halfWidth+1e-6);
746 }
747 positions.emplace_back(siLocalPosition);
748 }//iterate over all particles
749 return positions;
750 }
751
752
// Adds track-based incidence angles to `input` and sets input.useTrackInfo. The
// signature line is not visible in this view.
// The unit momentum direction is rotated into the surface's local frame and scaled
// by 0.250/cos(theta), which makes its local z component 0.250. That value is
// presumably the sensor depth in mm (TODO confirm). theta and phi are then the
// angles of the local y and x displacements over that depth.
// phi is finally corrected with tanl, computed by createInput and logged as the
// Lorentz correction: phi = atan(tan(phi) - tanl).
// NOTE(review): cos(theta) is zero for a direction parallel to the sensor plane,
// which makes the scale factor infinite.
753 void
755 const Trk::Surface& pixelSurface, // pixelSurface = pcot->associatedSurface();
756 const Trk::TrackParameters& trackParsAtSurface,
757 const double tanl) const {
758 input.useTrackInfo=true;
759 Amg::Vector3D particleDir = trackParsAtSurface.momentum().unit();
760 Amg::Vector3D localIntersection = pixelSurface.transform().inverse().linear() * particleDir;
761 localIntersection *= 0.250/cos(localIntersection.theta());
762 float trackDeltaX = (float)localIntersection.x();
763 float trackDeltaY = (float)localIntersection.y();
764 input.theta=std::atan2(trackDeltaY,0.250);
765 input.phi=std::atan2(trackDeltaX,0.250);
766 ATH_MSG_VERBOSE("Angle phi bef Lorentz corr: " << input.phi );
767 input.phi=std::atan(std::tan(input.phi)-tanl);
768 ATH_MSG_VERBOSE(" From track: angle phi: " << input.phi << " theta: " << input.theta );
769 }
770
771
772 NNinput
774 Amg::Vector3D & beamSpotPosition,
775 double & tanl) const{
776 NNinput input;
777 ATH_MSG_VERBOSE(" Starting creating input from cluster " );
778 const InDetDD::SiDetectorElement* element=pCluster.detectorElement();
779 if (not element) {
780 ATH_MSG_ERROR("Could not get detector element");
781 return input;
782 }
783 const AtlasDetectorID* aid = element->getIdHelper();
784 if (not aid){
785 ATH_MSG_ERROR("Could not get ATLASDetectorID");
786 return input;
787 }
788
790 ATH_MSG_ERROR("Could not get PixelID pointer");
791 return input;
792 }
793 const PixelID* pixelIDp=static_cast<const PixelID*>(aid);
794 const PixelID& pixelID = *pixelIDp;
795 const InDetDD::PixelModuleDesign* design
796 (dynamic_cast<const InDetDD::PixelModuleDesign*>(&element->design()));
797 if (not design){
798 ATH_MSG_ERROR("Dynamic cast failed at line "<<__LINE__<<" of NnClusterizationFactory.cxx.");
799 return input;
800 }
802 const PixelChargeCalibCondData *calibData = *calibDataHandle;
803 const std::vector<Identifier>& rdos = pCluster.rdoList();
804 const size_t rdoSize = rdos.size();
805 ATH_MSG_VERBOSE(" Number of RDOs: " << rdoSize );
806 const std::vector<float>& chList = pCluster.chargeList();
807 const std::vector<int>& totList = pCluster.totList();
808 std::vector<float> chListRecreated{};
809 chListRecreated.reserve(rdoSize);
810 ATH_MSG_VERBOSE(" Number of charges: " << chList.size() );
811 std::vector<int>::const_iterator tot = totList.begin();
812 std::vector<Identifier>::const_iterator rdosBegin = rdos.begin();
813 std::vector<Identifier>::const_iterator rdosEnd = rdos.end();
814 std::vector<int> totListRecreated{};
815 totListRecreated.reserve(rdoSize);
816 std::vector<int>::const_iterator totRecreated = totListRecreated.begin();
817 // Recreate both charge list and ToT list to correct for the IBL ToT overflow (and later for small hits):
818 ATH_MSG_VERBOSE("Charge list is not filled ... re-creating it.");
819 IdentifierHash moduleHash = element->identifyHash(); // wafer hash
820
821 for ( ; rdosBegin!= rdosEnd and tot != totList.end(); ++tot, ++rdosBegin, ++totRecreated ){
822 // recreate the charge: should be a method of the calibSvc
823 int tot0 = *tot;
824 Identifier pixid = *rdosBegin;
825 assert( element->identifyHash() == pixelID.wafer_hash(pixelID.wafer_id(pixid)));
826
827 std::array<InDetDD::PixelDiodeTree::CellIndexType,2> diode_idx
829 pixelID.eta_index(pixid));
830 InDetDD::PixelDiodeTree::DiodeProxy si_param ( design->diodeProxyFromIdx(diode_idx));
831 std::uint32_t feValue = design->getFE(si_param);
832 auto diode_type = design->getDiodeType(si_param);
834 && design->numberOfConnectedCells( design->readoutIdOfCell(InDetDD::SiCellId(diode_idx[0],diode_idx[1])))>1) {
836 }
837
838 float charge = calibData->getCharge(diode_type, moduleHash, feValue, tot0);
839 chListRecreated.push_back(charge);
840 totListRecreated.push_back(tot0);
841 }
842 // reset the rdo iterator
843 rdosBegin = rdos.begin();
844 rdosEnd = rdos.end();
845 // and the tot iterator
846 tot = totList.begin();
847 totRecreated = totListRecreated.begin();
848 // Always use recreated charge and ToT lists:
849 std::vector<float>::const_iterator charge = chListRecreated.begin();
850 std::vector<float>::const_iterator chargeEnd = chListRecreated.end();
851 tot = totListRecreated.begin();
852 std::vector<int>::const_iterator totEnd = totListRecreated.end();
853 InDetDD::SiLocalPosition sumOfWeightedPositions(0,0,0);
854 double sumOfTot=0;
855 int rowMin = 999;
856 int rowMax = 0;
857 int colMin = 999;
858 int colMax = 0;
859 for (; (rdosBegin!= rdosEnd) and (charge != chargeEnd) and (tot != totEnd); ++rdosBegin, ++charge, ++tot){
860 Identifier rId = *rdosBegin;
861 int row = pixelID.phi_index(rId);
862 int col = pixelID.eta_index(rId);
863 InDetDD::SiLocalPosition siLocalPosition (design->positionFromColumnRow(col,row));
864 if (not m_useToT){
865 sumOfWeightedPositions += (*charge)*siLocalPosition;
866 sumOfTot += (*charge);
867 } else {
868 sumOfWeightedPositions += ((double)(*tot))*siLocalPosition;
869 sumOfTot += (double)(*tot);
870 }
871 rowMin = std::min(row, rowMin);
872 rowMax = std::max(row, rowMax);
873 colMin = std::min(col, colMin);
874 colMax = std::max(col, colMax);
875
876 }
877 sumOfWeightedPositions /= sumOfTot;
878 //what you want to know is simple:
879 //just the row and column of this average position!
880 InDetDD::SiCellId cellIdWeightedPosition=design->cellIdOfPosition(sumOfWeightedPositions);
881
882 if (!cellIdWeightedPosition.isValid()){
883 ATH_MSG_WARNING(" Weighted position is on invalid CellID." );
884 }
885 int columnWeightedPosition=cellIdWeightedPosition.etaIndex();
886 int rowWeightedPosition=cellIdWeightedPosition.phiIndex();
887 ATH_MSG_VERBOSE(" weighted pos row: " << rowWeightedPosition << " col: " << columnWeightedPosition );
888 int centralIndexX=(m_sizeX-1)/2;
889 int centralIndexY=(m_sizeY-1)/2;
890 if (std::abs(rowWeightedPosition-rowMin)>centralIndexX or
891 std::abs(rowWeightedPosition-rowMax)>centralIndexX){
892 ATH_MSG_VERBOSE(" Cluster too large rowMin" << rowMin << " rowMax " << rowMax << " centralX " << centralIndexX);
893 return input;
894 }
895 if (std::abs(columnWeightedPosition-colMin)>centralIndexY or
896 std::abs(columnWeightedPosition-colMax)>centralIndexY){
897 ATH_MSG_VERBOSE(" Cluster too large colMin" << colMin << " colMax " << colMax << " centralY " << centralIndexY);
898 return input;
899 }
900 input.matrixOfToT.reserve(m_sizeX);
901 for (unsigned int a=0;a<m_sizeX;a++){
902 input.matrixOfToT.emplace_back(m_sizeY, 0.0);
903 }
904 input.vectorOfPitchesY.assign(m_sizeY, 0.4);
905 rdosBegin = rdos.begin();
906 charge = chListRecreated.begin();
907 chargeEnd = chListRecreated.end();
908 tot = totListRecreated.begin();
909 ATH_MSG_VERBOSE(" Putting together the n. " << rdos.size() << " rdos into a matrix." );
910 Identifier pixidentif=pCluster.identify();
911 input.etaModule=(int)pixelID.eta_module(pixidentif);
912 input.ClusterPixLayer=(int)pixelID.layer_disk(pixidentif);
913 input.ClusterPixBarrelEC=(int)pixelID.barrel_ec(pixidentif);
914 for (;( charge != chargeEnd) and (rdosBegin!= rdosEnd); ++rdosBegin, ++charge, ++tot){
915 Identifier rId = *rdosBegin;
916 unsigned int absrow = pixelID.phi_index(rId)-rowWeightedPosition+centralIndexX;
917 unsigned int abscol = pixelID.eta_index(rId)-columnWeightedPosition+centralIndexY;
918 if (absrow > m_sizeX){
919 ATH_MSG_WARNING(" problem with index: " << absrow << " min: " << 0 << " max: " << m_sizeX);
920 return input;
921 }
922 if (abscol > m_sizeY){
923 ATH_MSG_WARNING(" problem with index: " << abscol << " min: " << 0 << " max: " << m_sizeY);
924 return input;
925 }
926 InDetDD::SiCellId cellId = element->cellIdFromIdentifier(*rdosBegin);
927 InDetDD::SiDiodesParameters diodeParameters = design->parameters(cellId);
928 double pitchY = diodeParameters.width().xEta();
929 if (not m_useToT) {
930 input.matrixOfToT[absrow][abscol]=*charge;
931 } else {
932 input.matrixOfToT[absrow][abscol]=(double)(*tot);
933 // in case to RunI setup to make IBL studies
934 if(m_doRunI){
935 if (m_addIBL and (input.ClusterPixLayer==0) and (input.ClusterPixBarrelEC==0)){
936 input.matrixOfToT[absrow][abscol]*=3;
937 }
938 }else{
939 // for RunII IBL is always present
940 if ( (input.ClusterPixLayer==0) and (input.ClusterPixBarrelEC==0)){
941 input.matrixOfToT[absrow][abscol]*=3;
942 }
943 }
944
945 }
946 if (std::abs(pitchY-0.4)>1e-5){
947 input.vectorOfPitchesY[abscol]=pitchY;
948 }
949 }//end iteration on rdos
950 ATH_MSG_VERBOSE(" eta module: " << input.etaModule );
951 ATH_MSG_VERBOSE(" Layer number: " << input.ClusterPixLayer << " Barrel / endcap: " << input.ClusterPixBarrelEC );
952 input.useTrackInfo=false;
953 const Amg::Vector2D& prdLocPos = pCluster.localPosition();
954 InDetDD::SiLocalPosition centroid(prdLocPos);
955 Amg::Vector3D globalPos = element->globalPosition(centroid);
956 Amg::Vector3D my_track = globalPos-beamSpotPosition;
957 const Amg::Vector3D &my_normal = element->normal();
958 const Amg::Vector3D &my_phiax = element->phiAxis();
959 const Amg::Vector3D &my_etaax = element->etaAxis();
960 float trkphicomp = my_track.dot(my_phiax);
961 float trketacomp = my_track.dot(my_etaax);
962 float trknormcomp = my_track.dot(my_normal);
963 double bowphi = std::atan2(trkphicomp,trknormcomp);
964 double boweta = std::atan2(trketacomp,trknormcomp);
965 tanl = m_pixelLorentzAngleTool->getTanLorentzAngle(element->identifyHash(), Gaudi::Hive::currentContext());
966 if(bowphi > M_PI_2) bowphi -= M_PI;
967 if(bowphi < -M_PI_2) bowphi += M_PI;
968 int readoutside = design->readoutSide();
969 double angle = std::atan(std::tan(bowphi)-readoutside*tanl);
970 input.phi=angle;
971 ATH_MSG_VERBOSE(" Angle theta bef corr: " << boweta );
972 if (boweta>M_PI_2) boweta-=M_PI;
973 if (boweta<-M_PI_2) boweta+=M_PI;
974 input.theta=boweta;
975 ATH_MSG_VERBOSE(" Angle phi: " << angle << " theta: " << boweta );
976 input.rowWeightedPosition=rowWeightedPosition;
977 input.columnWeightedPosition=columnWeightedPosition;
978 ATH_MSG_VERBOSE(" RowWeightedPosition: " << rowWeightedPosition << " ColWeightedPosition: " << columnWeightedPosition );
979 return input;
980 }//end create NNinput function
981
// Presumably the length of the NN input vector: the m_sizeX x m_sizeY cluster matrix,
// m_sizeY pitches (the length of vectorOfPitchesY built in createInput), plus 4 further
// inputs when useTrackInfo is true and 5 otherwise.
982 size_t
984 return (m_sizeX * m_sizeY) + m_sizeY + (useTrackInfo ? 4 : 5);
985 }
986
987 // ======================================================================
988 // ONNX inference methods
989 // ======================================================================
990
// Run the ONNX particle-multiplicity network on an already assembled feature vector.
// The first three network outputs are normalised by their sum and returned as the
// probabilities for 1, 2 and 3 particles.
// @param input  feature vector; its length must equal dimension 1 of the model input
// @return three probabilities, or {0, 0, 0} on any failure: missing collection,
//         unexpected input rank or size, fewer than 3 outputs, or an output sum that is
//         not strictly positive (including NaN)
991 std::vector<double>
993 const Eigen::VectorXd& input) const {
994
995 std::vector<double> result(3, 0.0);
997 if (!onnxCollection.isValid()) {
998 ATH_MSG_FATAL("Failed to get ONNX network collection with key " << m_readKeyONNX.key());
999 return result;
1000 }
1001 Ort::Session& session = *onnxCollection->numberNetwork;
1002
1003 // Get expected input dimension from the model
1004 auto inputTypeInfo = session.GetInputTypeInfo(0);
1005 auto tensorInfo = inputTypeInfo.GetTensorTypeAndShapeInfo();
// The model input is expected to be [batch, features]. Check the rank before reading
// dimension 1, so a malformed model cannot cause an out-of-bounds read.
const std::vector<int64_t> modelInputShape = tensorInfo.GetShape();
if (modelInputShape.size() < 2) {
ATH_MSG_FATAL("ONNX number network input has rank " << modelInputShape.size()
<< ", expected [batch, features] — check model/configuration");
return result;
}
1006 const int64_t expectedDim = modelInputShape[1];
1007
1008 // Convert Eigen double vector to float
1009 if (static_cast<int64_t>(input.size()) != expectedDim) {
1010 ATH_MSG_FATAL("ONNX number network expects input dimension " << expectedDim
1011 << " but got " << input.size() << " — check model/configuration");
1012 return result;
1013 }
1014 std::vector<float> inputData(expectedDim);
1015 for (int i = 0; i < expectedDim; ++i) {
1016 inputData[i] = static_cast<float>(input[i]);
1017 }
1018
1019 // Create input tensor
1020 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
1021 std::vector<int64_t> inputShape = {1, expectedDim};
1022 Ort::Value inputTensor = Ort::Value::CreateTensor<float>(
1023 memInfo, inputData.data(), inputData.size(),
1024 inputShape.data(), inputShape.size());
1025 Ort::AllocatorWithDefaultOptions allocator;
1026 auto inputName = session.GetInputNameAllocated(0, allocator);
1027 auto outputName = session.GetOutputNameAllocated(0, allocator);
1028 const char* inputNames[] = {inputName.get()};
1029 const char* outputNames[] = {outputName.get()};
1030
1031 // Run inference
1032 auto outputTensors = session.Run(
1033 Ort::RunOptions{nullptr},
1034 inputNames, &inputTensor, 1,
1035 outputNames, 1);
1036
1037 // Extract output
// Three class scores are read below; refuse to read past the end of a smaller output.
const size_t nOutputs = outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
if (nOutputs < 3) {
ATH_MSG_FATAL("ONNX number network returned " << nOutputs
<< " output values, expected at least 3 — check model/configuration");
return result;
}
1038 const float* outputData = outputTensors[0].GetTensorData<float>();
1039 double num0 = outputData[0];
1040 double num1 = outputData[1];
1041 double num2 = outputData[2];
1042
1043 // Normalize
1044 const double sum = num0 + num1 + num2;
// Written as !(sum > 0) so a NaN sum is rejected too (NaN <= 0 evaluates to false).
1045 if (!(sum > 0.0)) {
1046 ATH_MSG_WARNING("ONNX number network output sum is non-positive: " << sum);
1047 return result;
1048 }
1049 const double inverseSum = 1.0 / sum;
1050 result[0] = num0 * inverseSum;
1051 result[1] = num1 * inverseSum;
1052 result[2] = num2 * inverseSum;
1053
1054 ATH_MSG_VERBOSE("ONNX Prob of n. particles (1): " << result[0]
1055 << " (2): " << result[1]
1056 << " (3): " << result[2]);
1057 return result;
1058 }
1059
// Run the ONNX position network for numberSubClusters (1, 2 or 3) sub-clusters.
// The network output is read as 5 values per sub-cluster:
// [alpha, mean_x, mean_y, prec_x, prec_y].
// - alpha is ignored.
// - The means are converted to local positions via getPositionsFromOutput.
// - Each precision p becomes an RMS sqrt(1/p), with a 0.01 fallback when p <= 0, and is
//   then corrected by correctedRMSX / correctedRMSY.
// One diagonal 2x2 covariance matrix per sub-cluster is appended to errors.
// @param input             feature vector; length must equal dimension 1 of the model input
// @param rawInput          NN input the features were built from; its column pitches
//                          (vectorOfPitchesY) feed correctedRMSY
// @param pCluster          cluster being split
// @param numberSubClusters number of sub-clusters to predict
// @param errors            [out] per-sub-cluster covariance matrices are appended here
// @return one position per sub-cluster, or an empty vector (errors untouched) if
//         numberSubClusters is outside [1, m_maxSubClusters] or any network / shape check fails
1060 std::vector<Amg::Vector2D>
1062 const Eigen::VectorXd& input,
1063 NNinput& rawInput,
1064 const InDet::PixelCluster& pCluster,
1065 int numberSubClusters,
1066 std::vector<Amg::MatrixX>& errors) const {
1067
1068 std::vector<Amg::Vector2D> allPositions;
1069 if (numberSubClusters < 1 || numberSubClusters > static_cast<int>(m_maxSubClusters)) {
1070 return allPositions;
1071 }
1072
1074 if (!onnxCollection.isValid()) {
1075 ATH_MSG_FATAL("Failed to get ONNX network collection with key " << m_readKeyONNX.key());
1076 return allPositions;
1077 }
1078 Ort::Session* posNet = nullptr;
1079 if (numberSubClusters == 1) posNet = onnxCollection->positionNetwork1.get();
1080 else if (numberSubClusters == 2) posNet = onnxCollection->positionNetwork2.get();
1081 else if (numberSubClusters == 3) posNet = onnxCollection->positionNetwork3.get();
1082
1083 if (!posNet) {
1084 ATH_MSG_FATAL("ONNX position network for " << numberSubClusters
1085 << " sub-clusters not found in collection");
1086 return allPositions;
1087 }
1088
1089 Ort::Session& session = *posNet;
1090
1091 // Get expected input dimension from the model
1092 auto inputTypeInfo = session.GetInputTypeInfo(0);
1093 auto tensorInfo = inputTypeInfo.GetTensorTypeAndShapeInfo();
// The model input is expected to be [batch, features]. Check the rank before reading
// dimension 1, so a malformed model cannot cause an out-of-bounds read.
const std::vector<int64_t> modelInputShape = tensorInfo.GetShape();
if (modelInputShape.size() < 2) {
ATH_MSG_FATAL("ONNX position network (" << numberSubClusters
<< " sub-clusters) input has rank " << modelInputShape.size()
<< ", expected [batch, features] — check model/configuration");
return allPositions;
}
1094 const int64_t expectedDim = modelInputShape[1];
1095
1096 // Convert input to float
1097 if (static_cast<int64_t>(input.size()) != expectedDim) {
1098 ATH_MSG_FATAL("ONNX position network (" << numberSubClusters
1099 << " sub-clusters) expects input dimension " << expectedDim
1100 << " but got " << input.size() << " — check model/configuration");
1101 return allPositions;
1102 }
1103 std::vector<float> inputData(expectedDim);
1104 for (int i = 0; i < expectedDim; ++i) {
1105 inputData[i] = static_cast<float>(input[i]);
1106 }
1107
1108 // Create input tensor
1109 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
1110 std::vector<int64_t> inputShape = {1, expectedDim};
1111 Ort::Value inputTensor = Ort::Value::CreateTensor<float>(
1112 memInfo, inputData.data(), inputData.size(),
1113 inputShape.data(), inputShape.size());
1114 Ort::AllocatorWithDefaultOptions allocator;
1115 auto inputName = session.GetInputNameAllocated(0, allocator);
1116 auto outputName = session.GetOutputNameAllocated(0, allocator);
1117 const char* inputNames[] = {inputName.get()};
1118 const char* outputNames[] = {outputName.get()};
1119
1120 // Run inference
1121 auto outputTensors = session.Run(
1122 Ort::RunOptions{nullptr},
1123 inputNames, &inputTensor, 1,
1124 outputNames, 1);
1125
1126 // Extract output: expect [1, 5*numberSubClusters]
1127 // Format per sub-cluster: [alpha, mean_x, mean_y, prec_x, prec_y]
// The loop below reads 5 values per sub-cluster. Verify the output is large enough
// before touching it, and do so before anything is appended to errors.
const size_t nOutputs = outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
const size_t nExpectedOutputs = 5 * static_cast<size_t>(numberSubClusters);
if (nOutputs < nExpectedOutputs) {
ATH_MSG_FATAL("ONNX position network (" << numberSubClusters
<< " sub-clusters) returned " << nOutputs << " output values, expected at least "
<< nExpectedOutputs << " — check model/configuration");
return allPositions;
}
1128 const float* outputData = outputTensors[0].GetTensorData<float>();
1129
1130 std::vector<double> positionValues;
1131 positionValues.reserve(numberSubClusters * 2);
1132
1133 for (int iSub = 0; iSub < numberSubClusters; ++iSub) {
1134 const int offset = iSub * 5;
1135 // outputData[offset+0] = alpha (unused)
1136 const double mean_x = outputData[offset + 1];
1137 const double mean_y = outputData[offset + 2];
1138 const double prec_x = outputData[offset + 3];
1139 const double prec_y = outputData[offset + 4];
1140
1141 positionValues.push_back(mean_x);
1142 positionValues.push_back(mean_y);
1143
1144 // Convert precision to RMS and build error matrix
1145 if (prec_x <= 0 || prec_y <= 0) {
1146 ATH_MSG_WARNING("ONNX position network returned non-positive precision for sub-cluster "
1147 << iSub << " (prec_x=" << prec_x << ", prec_y=" << prec_y
1148 << "); using fallback RMS of 0.01");
1149 }
1150 const float rawRmsX = (prec_x > 0) ? std::sqrt(1.0f / prec_x) : 0.01f;
1151 const float rawRmsY = (prec_y > 0) ? std::sqrt(1.0f / prec_y) : 0.01f;
1152 const double rmsX = correctedRMSX(rawRmsX);
1153 const double rmsY = correctedRMSY(rawRmsY, rawInput.vectorOfPitchesY);
1154
1155 Amg::MatrixX erm(2, 2);
1156 erm.setZero();
1157 erm(0, 0) = rmsX * rmsX;
1158 erm(1, 1) = rmsY * rmsY;
1159 errors.push_back(erm);
1160 }
1161
1162 // Convert raw position outputs to detector coordinates
1163 allPositions = getPositionsFromOutput(positionValues, rawInput, pCluster);
1164 return allPositions;
1165 }
1166
1167}//end InDet namespace
#define M_PI
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_FATAL(x)
#define ATH_MSG_VERBOSE(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
double charge(const T &p)
Definition AtlasPID.h:997
This file defines the class for a collection of AttributeLists where each one is associated with a ch...
static Double_t a
double norm_rawToT(const double input)
double norm_pitch(const double input, bool addIBL=false)
double errorHalfIntervalY(const int nParticles)
double norm_layerNumber(const double input)
double norm_thetaBS(const double input)
double norm_layerType(const double input)
double norm_ToT(const double input)
double back_posX(const double input, const bool recenter=false)
double back_posY(const double input)
double norm_phi(const double input)
double norm_phiBS(const double input)
double norm_theta(const double input)
double norm_etaModule(const double input)
double errorHalfIntervalX(const int nParticles)
This is an Identifier helper class for the Pixel subdetector.
size_t size() const
Number of registered mappings.
double angle(const GeoTrf::Vector2D &a, const GeoTrf::Vector2D &b)
static const Attributes_t empty
AthAlgTool(const std::string &type, const std::string &name, const IInterface *parent)
Constructor with parameters:
This class provides an interface to generate or decode an identifier for the upper levels of the dete...
virtual HelperType helper() const
Type of helper, defaulted to 'Unimplemented'.
This is a "hash" representation of an Identifier.
int readoutSide() const
ReadoutSide.
static constexpr std::array< PixelDiodeTree::CellIndexType, 2 > makeCellIndex(T local_x_idx, T local_y_idx)
Create a 2D cell index from the indices in local-x (phi, row) and local-y (eta, column) direction.
Class used to describe the design of a module (diode segmentation and readout scheme).
virtual SiDiodesParameters parameters(const SiCellId &cellId) const
readout or diode id -> position, size
virtual int numberOfConnectedCells(const SiReadoutCellId &readoutId) const
readout id -> id of connected diodes
PixelReadoutTechnology getReadoutTechnology() const
PixelDiodeTree::DiodeProxy diodeProxyFromIdx(const std::array< PixelDiodeTree::IndexType, 2 > &idx) const
SiLocalPosition positionFromColumnRow(const int column, const int row) const
Given row and column index of a diode, return position of diode center ALTERNATIVE/PREFERED way is to...
virtual SiReadoutCellId readoutIdOfCell(const SiCellId &cellId) const
diode id -> readout id
static InDetDD::PixelDiodeType getDiodeType(const PixelDiodeTree::DiodeProxy &diode_proxy)
virtual SiCellId cellIdOfPosition(const SiLocalPosition &localPos) const
position -> id
virtual double width() const
Method to calculate average width of a module.
static unsigned int getFE(const PixelDiodeTree::DiodeProxy &diode_proxy)
Identifier for the strip or pixel cell.
Definition SiCellId.h:29
int phiIndex() const
Get phi index. Equivalent to strip().
Definition SiCellId.h:122
bool isValid() const
Test if its in a valid state.
Definition SiCellId.h:136
int etaIndex() const
Get eta index.
Definition SiCellId.h:114
Class to hold geometrical description of a silicon detector element.
virtual SiCellId cellIdFromIdentifier(const Identifier &identifier) const override final
SiCellId from Identifier.
virtual const SiDetectorDesign & design() const override final
access to the local description (inline):
Class to handle the position of the centre and the width of a diode or a cluster of diodes Version 1....
const SiLocalPosition & width() const
width of the diodes:
Class to represent a position in the natural frame of a silicon sensor, for Pixel and SCT For Pixel: ...
double xPhi() const
position along phi direction:
double xEta() const
position along eta direction:
virtual const Amg::Vector3D & normal() const override final
Get reconstruction local normal axes in global frame.
virtual IdentifierHash identifyHash() const override final
identifier hash (inline)
HepGeom::Point3D< double > globalPosition(const HepGeom::Point3D< double > &localPos) const
transform a reconstruction local position into a global position (inline):
const AtlasDetectorID * getIdHelper() const
Returns the id helper (inline).
std::vector< double > assembleInputRunII(NNinput &input) const
void addTrackInfoToInput(NNinput &input, const Trk::Surface &pixelSurface, const Trk::TrackParameters &trackParsAtSurface, const double tanl) const
Gaudi::Property< unsigned int > m_maxSubClusters
SG::ReadCondHandleKey< PixelChargeCalibCondData > m_chargeDataKey
std::vector< double > estimateNumberOfParticlesLWTNN(NnClusterizationFactory::InputVector &input) const
std::vector< Amg::Vector2D > estimatePositionsONNX(const Eigen::VectorXd &input, NNinput &rawInput, const InDet::PixelCluster &pCluster, int numberSubClusters, std::vector< Amg::MatrixX > &errors) const
Gaudi::Property< unsigned int > m_sizeX
double correctedRMSY(double posPixels, std::vector< float > &pitches) const
ReturnType(::TTrainedNetwork::* m_calculateOutput)(const InputType &input) const
NNinput createInput(const InDet::PixelCluster &pCluster, Amg::Vector3D &beamSpotPosition, double &tanl) const
virtual StatusCode initialize() override
Gaudi::Property< double > m_correctLorShiftBarrelWithoutTracks
Gaudi::Property< std::size_t > m_outputNodesPos1
Gaudi::Property< std::vector< std::size_t > > m_outputNodesPos2
ToolHandle< ISiLorentzAngleTool > m_pixelLorentzAngleTool
Gaudi::Property< std::vector< std::size_t > > m_outputNodesPos3
std::vector< Amg::Vector2D > estimatePositionsLWTNN(NnClusterizationFactory::InputVector &input, NNinput &rawInput, const InDet::PixelCluster &pCluster, int numberSubClusters, std::vector< Amg::MatrixX > &errors) const
SG::ReadCondHandleKey< OnnxNNCollection > m_readKeyONNX
std::vector< double > assembleInputRunI(NNinput &input) const
SG::ReadCondHandleKey< LWTNNCollection > m_readKeyJSON
std::vector< Amg::Vector2D > estimatePositions(const InDet::PixelCluster &pCluster, Amg::Vector3D &beamSpotPosition, std::vector< Amg::MatrixX > &errors, int numberSubClusters) const
Gaudi::Property< std::vector< std::string > > m_nnOrder
Gaudi::Property< double > m_correctLorShiftBarrelWithTracks
Gaudi::Property< bool > m_useTTrainedNetworks
std::vector< double > estimateNumberOfParticlesTTN(const TTrainedNetworkCollection &nn_collection, const std::vector< double > &inputData) const
static constexpr std::array< unsigned int, kNNetworkTypes > m_nParticleGroup
SG::ReadCondHandleKey< TTrainedNetworkCollection > m_readKeyWithoutTrack
std::vector< double >(InDet::NnClusterizationFactory::* m_assembleInput)(NNinput &input) const
std::vector< Eigen::VectorXd > InputVector
NnClusterizationFactory(const std::string &name, const std::string &n, const IInterface *p)
std::vector< Amg::Vector2D > estimatePositionsTTN(const TTrainedNetworkCollection &nn_collection, const std::vector< double > &inputData, const NNinput &input, const InDet::PixelCluster &pCluster, int numberSubClusters, std::vector< Amg::MatrixX > &errors) const
static const std::array< std::regex, kNNetworkTypes > m_nnNames
Gaudi::Property< bool > m_useRecenteringNNWithouTracks
static constexpr std::array< std::string_view, kNNetworkTypes > s_nnTypeNames
InputVector eigenInput(NNinput &input) const
void getErrorMatrixFromOutput(std::vector< double > &outputX, std::vector< double > &outputY, std::vector< Amg::MatrixX > &errorMatrix, int nParticles) const
Gaudi::Property< bool > m_useRecenteringNNWithTracks
std::vector< std::vector< unsigned int > > m_NNId
std::vector< Amg::Vector2D > getPositionsFromOutput(std::vector< double > &output, const NNinput &input, const InDet::PixelCluster &pCluster) const
static double correctedRMSX(double posPixels)
size_t calculateVectorDimension(const bool useTrackInfo) const
Gaudi::Property< unsigned int > m_sizeY
SG::ReadCondHandleKey< TTrainedNetworkCollection > m_readKeyWithTrack
std::vector< double > estimateNumberOfParticlesONNX(const Eigen::VectorXd &input) const
std::vector< double > estimateNumberOfParticles(const InDet::PixelCluster &pCluster, Amg::Vector3D &beamSpotPosition) const
virtual const InDetDD::SiDetectorElement * detectorElement() const override final
return the detector element corresponding to this PRD The pointer will be zero if the det el is not d...
float getCharge(InDetDD::PixelDiodeType type, unsigned int moduleHash, unsigned int FE, float ToT) const
This is an Identifier helper class for the Pixel subdetector.
Definition PixelID.h:69
int eta_index(const Identifier &id) const
Definition PixelID.h:640
int layer_disk(const Identifier &id) const
Definition PixelID.h:602
Identifier wafer_id(int barrel_ec, int layer_disk, int phi_module, int eta_module) const
For a single crystal.
Definition PixelID.h:355
int barrel_ec(const Identifier &id) const
Values of different levels (failure returns 0).
Definition PixelID.h:595
IdentifierHash wafer_hash(Identifier wafer_id) const
wafer hash from id
Definition PixelID.h:378
int eta_module(const Identifier &id) const
Definition PixelID.h:627
int phi_index(const Identifier &id) const
Definition PixelID.h:634
std::vector< Double_t > calculateOutputValues(std::vector< Double_t > &input) const
const Amg::Vector3D & momentum() const
Access method for the momentum.
const Amg::Vector2D & localPosition() const
return the local position reference
Identifier identify() const
return the identifier
const std::vector< Identifier > & rdoList() const
return the List of rdo identifiers (pointers)
Abstract Base Class for tracking surfaces.
Definition Surface.h:79
const Amg::Transform3D & transform() const
Returns HepGeom::Transform3D by reference.
STL class.
Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic > MatrixX
Dynamic Matrix - dynamic allocation.
Eigen::Matrix< double, 2, 1 > Vector2D
Eigen::Matrix< double, 3, 1 > Vector3D
Primary Vertex Finder.
@ locY
local cartesian
Definition ParamDefs.h:38
@ locX
Definition ParamDefs.h:37
ParametersBase< TrackParametersDim, Charged > TrackParameters
Helper class to access parameters of a diode.
std::vector< float > vectorOfPitchesY