ATLAS Offline Software
EvaluateModel.cxx
// Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration

// Local include(s).
#include "EvaluateModel.h"

// Framework include(s).
#include "PathResolver/PathResolver.h"
#include "EvaluateUtils.h"
#include "AthOnnxUtils/OnnxUtils.h"

namespace AthOnnx {

   StatusCode EvaluateModel::initialize() {
      // Fetch tools
      ATH_CHECK( m_onnxTool.retrieve() );
      m_onnxTool->printModelInfo();

      /*****
       The combination of the number of batches and the batch size must not
       exceed the total sample size, which is 10000 for this example.
      *****/
      if (m_batchSize > 10000) {
         ATH_MSG_ERROR("The requested batch size (" << m_batchSize
                       << ") exceeds the number of available samples (10000).");
         return StatusCode::FAILURE;
      }
      // Read the input file, and the target file for comparison.
      std::string pixelFilePath = PathResolver::find_calib_file( m_pixelFileName.value() );
      ATH_MSG_INFO( "Using pixel file: " << pixelFilePath );

      m_input_tensor_values_notFlat = EvaluateUtils::read_mnist_pixel_notFlat(pixelFilePath);
      ATH_MSG_INFO("Total no. of samples: " << m_input_tensor_values_notFlat.size());

      return StatusCode::SUCCESS;
   }

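   // The container filled above is expected to hold one 28x28 MNIST image per
   // sample (shape [nSamples][28][28]); execute() below flattens m_batchSize
   // of them and classifies each into one of 10 digit classes.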
   StatusCode EvaluateModel::execute( const EventContext& /*ctx*/ ) const {

      // Prepare inputs: flatten each image and append it to one contiguous buffer.
      std::vector<float> inputData;
      for (int ibatch = 0; ibatch < m_batchSize; ibatch++) {
         const std::vector<std::vector<float> >& imageData = m_input_tensor_values_notFlat[ibatch];
         std::vector<float> flatten = AthOnnxUtils::flattenNestedVectors(imageData);
         inputData.insert(inputData.end(), flatten.begin(), flatten.end());
      }

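      // inputData is now a flat, row-major buffer of m_batchSize * 28 * 28
      // floats (assuming 28x28 images), the layout the tensor binding expects.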
      int64_t batchSize = m_onnxTool->getBatchSize(inputData.size());
      ATH_MSG_INFO("Batch size is " << batchSize << ".");
      assert(batchSize == m_batchSize);
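      // getBatchSize() presumably derives the sample count from the total
      // number of input elements, so the assert guards against any mismatch
      // with the configured m_batchSize.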

      // Bind the input data to the input tensor.
      std::vector<Ort::Value> inputTensors;
      ATH_CHECK( m_onnxTool->addInput(inputTensors, inputData, 0, batchSize) );

      // Reserve space for the output data and bind it to the output tensor.
      std::vector<float> outputScores;
      std::vector<Ort::Value> outputTensors;
      ATH_CHECK( m_onnxTool->addOutput(outputTensors, outputScores, 0, batchSize) );
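      // addOutput() is expected to size outputScores for batchSize samples
      // (10 class scores each, as read below) and bind that memory, so the
      // inference call can write the scores straight into it.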

      // Run the inference; the scores are written into outputScores.
      ATH_CHECK( m_onnxTool->inference(inputTensors, outputTensors) );

      ATH_MSG_INFO("Label for the input test data: ");
      for (int ibatch = 0; ibatch < m_batchSize; ibatch++) {
         float maxScore = -999;
         int maxClass = 0;
         for (int i = 0; i < 10; i++) {
            int index = i + ibatch * 10;
            ATH_MSG_DEBUG("Score for class " << i << " = " << outputScores[index]
                          << " in batch " << ibatch);
            if (maxScore < outputScores[index]) {
               maxScore = outputScores[index];
               maxClass = i;
            }
         }
         ATH_MSG_INFO("Class: " << maxClass << " has the highest score: "
                      << maxScore << " in batch " << ibatch);
      }

      return StatusCode::SUCCESS;
   }

} // namespace AthOnnx
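
For reference, here is a minimal standalone sketch of the two pieces of indexing logic this example relies on: flattening a nested [rows][cols] image into a row-major buffer (what AthOnnxUtils::flattenNestedVectors provides) and picking the highest-scoring class per sample from a flat [nSamples * nClasses] score buffer. The 2x2 "image" and the score values are made-up illustration data, not output of the real model, and the local flattenNested helper only mimics the AthOnnxUtils one.

// Standalone sketch: plain C++, no Athena or ONNX Runtime required.
#include <iostream>
#include <vector>

// Row-major flattening, mirroring what AthOnnxUtils::flattenNestedVectors does.
std::vector<float> flattenNested(const std::vector<std::vector<float> >& features) {
   std::vector<float> flat;
   for (const auto& row : features) {
      flat.insert(flat.end(), row.begin(), row.end());
   }
   return flat;
}

int main() {
   // A made-up 2x2 "image" standing in for a 28x28 MNIST sample.
   std::vector<std::vector<float> > image = { {0.1f, 0.9f}, {0.4f, 0.6f} };
   std::vector<float> flat = flattenNested(image);
   std::cout << "Flattened size: " << flat.size() << '\n';   // prints 4

   // Per-sample argmax over a flat score buffer, using the same
   // index = i + ibatch * nClasses scheme as EvaluateModel::execute().
   const int nClasses = 10;
   std::vector<float> scores(2 * nClasses, 0.f);   // two fake samples
   scores[3] = 0.8f;              // sample 0 peaks at class 3
   scores[nClasses + 7] = 0.9f;   // sample 1 peaks at class 7
   for (int ibatch = 0; ibatch < 2; ibatch++) {
      int best = 0;
      for (int i = 1; i < nClasses; i++) {
         if (scores[i + ibatch * nClasses] > scores[best + ibatch * nClasses]) {
            best = i;
         }
      }
      std::cout << "Sample " << ibatch << " -> class " << best << '\n';
   }
   return 0;
}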