ATLAS Offline Software
FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxUtil.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 
4  This class acts as the interface to an ONNX model. It handles loading
5  the model, initializing the ORT session, and running inference. It is decoupled
6  from the ATLAS EDM as much as possible. The FlavorTagDiscriminants::GNN class
7  handles the interaction with the ATLAS EDM.
8 */
9 
10 #ifndef ONNXUTIL_H
11 #define ONNXUTIL_H
12 
#include <onnxruntime_cxx_api.h>

#include "nlohmann/json.hpp"
#include "lwtnn/parse_json.hh"

#include "FlavorTagDiscriminants/OnnxOutput.h"

#include <cstdint> // int64_t
#include <map>
#include <memory>
#include <string>
#include <utility> // std::pair
#include <vector>

25 namespace FlavorTagDiscriminants {
26 
27  // the first element is the input data, the second is the shape
28  using Inputs = std::pair<std::vector<float>, std::vector<int64_t>>;
29 
30  enum class OnnxModelVersion{UNKNOWN, V0, V1, V2};
31 
34  { OnnxModelVersion::V0, "v0" },
35  { OnnxModelVersion::V1, "v1" },
36  { OnnxModelVersion::V2, "v2" },
37  })
38 
39  //
40  // Utility class that loads the onnx model from the given path
41  // and runs inference based on the user given inputs
42 
43  class OnnxUtil final{
44 
45  public:
46  using OutputConfig = std::vector<OnnxOutput>;
47 
48  OnnxUtil(const std::string& path_to_onnx);
49 
50  void initialize();
51 
52  struct InferenceOutput {
53  std::map<std::string, float> singleFloat;
54  std::map<std::string, std::vector<char>> vecChar;
55  std::map<std::string, std::vector<float>> vecFloat;
56  };
57 
58  InferenceOutput runInference(std::map<std::string, Inputs>& gnn_inputs) const;
59 
60  const lwt::GraphConfig getLwtConfig() const;
61  const nlohmann::json& getMetadata() const;
62  const OutputConfig& getOutputConfig() const;
63  OnnxModelVersion getOnnxModelVersion() const;
64  const std::string& getModelName() const;
65 
66  private:
67  const nlohmann::json loadMetadata(const std::string& key) const;
68  const std::string determineModelName() const;
69 
70  nlohmann::json m_metadata;
71  std::string m_path_to_onnx;
72 
73  std::unique_ptr< Ort::Session > m_session;
74  std::unique_ptr< Ort::Env > m_env;
75 
76  size_t m_num_inputs;
77  size_t m_num_outputs;
78  std::string m_model_name;
79  std::vector<std::string> m_input_node_names;
80  OutputConfig m_output_nodes;
81 
82  OnnxModelVersion m_onnx_model_version = OnnxModelVersion::UNKNOWN;
83 
84  }; // Class OnnxUtil
85 } // end of FlavorTagDiscriminants namespace
86 #endif //ONNXUTIL_H
FlavorTagDiscriminants::OnnxModelVersion::V2
@ V2
FlavorTagDiscriminants::NLOHMANN_JSON_SERIALIZE_ENUM
NLOHMANN_JSON_SERIALIZE_ENUM(OnnxModelVersion, { { OnnxModelVersion::UNKNOWN, "" }, { OnnxModelVersion::V0, "v0" }, { OnnxModelVersion::V1, "v1" }, { OnnxModelVersion::V2, "v2" }, }) class OnnxUtil final
Definition: FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxUtil.h:32
FlavorTagDiscriminants
This file contains "getter" functions used for accessing tagger inputs from the EDM.
Definition: AssociationEnums.h:11
FlavorTagDiscriminants::Inputs
std::pair< std::vector< float >, std::vector< int64_t > > Inputs
Definition: FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxUtil.h:28
json
nlohmann::json json
Definition: HistogramDef.cxx:9
initialize
void initialize()
Definition: run_EoverP.cxx:894
python.HanMetadata.getMetadata
def getMetadata(f, key)
Definition: HanMetadata.py:12
OnnxUtil
Definition: JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/OnnxUtil.h:14
OnnxOutput.h
FlavorTagDiscriminants::OnnxModelVersion
OnnxModelVersion
Definition: FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxUtil.h:30
FlavorTagDiscriminants::OnnxModelVersion::UNKNOWN
@ UNKNOWN
FlavorTagDiscriminants::OnnxModelVersion::V0
@ V0
FlavorTagDiscriminants::OnnxModelVersion::V1
@ V1
mapkey::key
key
Definition: TElectronEfficiencyCorrectionTool.cxx:37