ATLAS Offline Software
EvaluateModel.cxx
Go to the documentation of this file.
1 // Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
2 
3 // Local include(s).
4 #include "EvaluateModel.h"
5 
6 // Framework include(s).
8 #include "EvaluateUtils.h"
10 
11 namespace AthOnnx {
12 
14  // Fetch tools
15  ATH_CHECK( m_onnxTool.retrieve() );
16  m_onnxTool->printModelInfo();
17 
18  /*****
19  The combination of no. of batches and batch size shouldn't cross
 20  the total sample size which is 10000 for this example
21  *****/
22  if(m_batchSize > 10000){
23  ATH_MSG_INFO("The total no. of sample crossed the no. of available sample ....");
24  return StatusCode::FAILURE;
25  }
26  // read input file, and the target file for comparison.
27  std::string pixelFilePath = PathResolver::find_file(m_pixelFileName.value(), "CALIBPATH", PathResolver::RecursiveSearch);
28  ATH_MSG_INFO( "Using pixel file: " << pixelFilePath );
29 
31  ATH_MSG_INFO("Total no. of samples: "<<m_input_tensor_values_notFlat.size());
32 
33  return StatusCode::SUCCESS;
34 }
35 
36  StatusCode EvaluateModel::execute( const EventContext& /*ctx*/ ) const {
37 
38  // prepare inputs
39  std::vector<float> inputData;
40  for (int ibatch = 0; ibatch < m_batchSize; ibatch++){
41  const std::vector<std::vector<float> >& imageData = m_input_tensor_values_notFlat[ibatch];
42  std::vector<float> flatten = AthOnnxUtils::flattenNestedVectors(imageData);
43  inputData.insert(inputData.end(), flatten.begin(), flatten.end());
44  }
45 
46  int64_t batchSize = m_onnxTool->getBatchSize(inputData.size());
47  ATH_MSG_INFO("Batch size is " << batchSize << ".");
48  assert(batchSize == m_batchSize);
49 
50  // bind the input data to the input tensor
51  std::vector<Ort::Value> inputTensors;
52  ATH_CHECK( m_onnxTool->addInput(inputTensors, inputData, 0, batchSize) );
53 
54  // reserve space for output data and bind it to the output tensor
55  std::vector<float> outputScores;
56  std::vector<Ort::Value> outputTensors;
57  ATH_CHECK( m_onnxTool->addOutput(outputTensors, outputScores, 0, batchSize) );
58 
59  // run the inference
60  // the output will be filled to the outputScores.
61  ATH_CHECK( m_onnxTool->inference(inputTensors, outputTensors) );
62 
63  ATH_MSG_INFO("Label for the input test data: ");
64  for(int ibatch = 0; ibatch < m_batchSize; ibatch++){
65  float max = -999;
66  int max_index = 0;
67  for (int i = 0; i < 10; i++){
68  ATH_MSG_DEBUG("Score for class "<< i <<" = "<<outputScores[i] << " in batch " << ibatch);
69  int index = i + ibatch * 10;
70  if (max < outputScores[index]){
71  max = outputScores[index];
72  max_index = index;
73  }
74  }
75  ATH_MSG_INFO("Class: "<<max_index<<" has the highest score: "<<outputScores[max_index] << " in batch " << ibatch);
76  }
77 
78  return StatusCode::SUCCESS;
79  }
80 
81 } // namespace AthOnnx
AthOnnx::EvaluateModel::initialize
virtual StatusCode initialize() override
Function initialising the algorithm.
Definition: EvaluateModel.cxx:13
PathResolver::RecursiveSearch
@ RecursiveSearch
Definition: PathResolver.h:28
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
PathResolver::find_file
static std::string find_file(const std::string &logical_file_name, const std::string &search_path, SearchType search_type=LocalSearch)
Definition: PathResolver.cxx:251
index
Definition: index.py:1
max
constexpr double max()
Definition: ap_fixedTest.cxx:33
EvaluateUtils.h
lumiFormat.i
int i
Definition: lumiFormat.py:85
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
AthOnnx::EvaluateModel::execute
virtual StatusCode execute(const EventContext &ctx) const override
Function executing the algorithm for a single event.
Definition: EvaluateModel.cxx:36
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
PathResolver.h
AthOnnx::EvaluateModel::m_onnxTool
ToolHandle< IOnnxRuntimeInferenceTool > m_onnxTool
Tool handler for onnx inference session.
Definition: EvaluateModel.h:59
EvaluateUtils::read_mnist_pixel_notFlat
std::vector< std::vector< std::vector< float > > > read_mnist_pixel_notFlat(const std::string &full_path)
Definition: EvaluateUtils.cxx:11
DeMoScan.index
string index
Definition: DeMoScan.py:364
AthOnnx::EvaluateModel::m_batchSize
Gaudi::Property< int > m_batchSize
Following properties need to be considered if the .onnx model is evaluated in batch mode.
Definition: EvaluateModel.h:56
AthOnnxUtils::flattenNestedVectors
std::vector< T > flattenNestedVectors(const std::vector< std::vector< T >> &features)
Definition: OnnxUtils.h:20
OnnxUtils.h
AthOnnx::EvaluateModel::m_pixelFileName
Gaudi::Property< std::string > m_pixelFileName
Name of the model file to load.
Definition: EvaluateModel.h:51
EvaluateModel.h
AthOnnx::EvaluateModel::m_input_tensor_values_notFlat
std::vector< std::vector< std::vector< float > > > m_input_tensor_values_notFlat
Definition: EvaluateModel.h:63
AthOnnx
Namespace holding all of the Onnx Runtime example code.
Definition: EvaluateModel.cxx:11