Loading [MathJax]/extensions/tex2jax.js
ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
JSSMLTool.h
Go to the documentation of this file.
1 // Dear emacs, this is -*- c++ -*-
2 // Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 #ifndef BOOSTEDJETTAGGERS_JSSMLTOOL_H
4 #define BOOSTEDJETTAGGERS_JSSMLTOOL_H
5 
6 #include "IJSSMLTool.h"
7 #include "AsgTools/AsgTool.h"
8 
9 // ONNX Runtime include(s).
10 #include <onnxruntime_cxx_api.h>
11 
12 // xAOD
13 #include "xAODJet/JetContainer.h"
15 // System include(s).
16 #include <memory> //unique_ptr
17 #include <string>
18 #include <map>
19 #include <vector>
20 #include <cstdint>
21 
22 class TH2D;
23 
24 
25 namespace AthONNX {
26 
41 
45 
46 class JSSMLTool
47  : public asg::AsgTool,
48  virtual public IJSSMLTool {
50 
51  public:
52  JSSMLTool (const std::string& name);
53 
55  virtual StatusCode initialize() override;
57  virtual double retrieveConstituentsScore(std::vector<TH2D> Images) const override;
58  virtual double retrieveConstituentsScore(std::vector<std::vector<float>> constituents) const override;
59  virtual double retrieveConstituentsScore(std::vector<std::vector<float>> constituents, std::vector<std::vector<std::vector<float>>> interactions) const override;
60  virtual double retrieveHighLevelScore(std::map<std::string, double> JSSVars) const override;
61 
62  // basic tool functions
63  std::vector<float> ReadJetImagePixels( std::vector<TH2D> Images ) const;
64  std::vector<float> ReadJSSInputs(std::map<std::string, double> JSSVars) const;
65  std::vector<int> ReadOutputLabels() const;
66 
67  // extra methods
68  StatusCode SetScaler(std::map<std::string, std::vector<double>> scaler) override;
69 
71  std::unique_ptr< Ort::Session > m_session;
72  std::unique_ptr< Ort::Env > m_env;
73 
74  std::map<std::string, std::vector<double>> m_scaler;
75  std::map<int, std::string> m_JSSInputMap;
76 
77  private:
78 
80  std::string m_modelFileName;
81  std::string m_pixelFileName;
82  std::string m_labelFileName;
83 
84  // input node info
85  std::vector<int64_t> m_input_node_dims;
87  std::vector<const char*> m_input_node_names;
88 
89  // output node info
90  std::vector<int64_t> m_output_node_dims;
92  std::vector<const char*> m_output_node_names;
93 
94  // some configs
96 
97  int m_nvars{};
98 
99  }; // class JSSMLTool
100 
101 } // namespace AthONNX
102 
103 #endif // BOOSTEDJETTAGGERS_JSSMLTOOL_H
AthONNX::JSSMLTool::m_nPixelsX
int m_nPixelsX
Definition: JSSMLTool.h:95
AthONNX::JSSMLTool::SetScaler
StatusCode SetScaler(std::map< std::string, std::vector< double >> scaler) override
Definition: JSSMLTool.cxx:410
asg::AsgTool
Base class for the dual-use tool implementation classes.
Definition: AsgTool.h:47
AthONNX::JSSMLTool::m_nPixelsY
int m_nPixelsY
Definition: JSSMLTool.h:95
AthONNX::JSSMLTool::m_num_output_nodes
size_t m_num_output_nodes
Definition: JSSMLTool.h:91
AthONNX::JSSMLTool::m_modelFileName
std::string m_modelFileName
Name of the model file to load.
Definition: JSSMLTool.h:80
xAOD::scaler
setOverV setNumU setNumY setODFibSel setYDetCS setYLhcCS setXRPotCS setXStatCS setXBeamCS scaler
Definition: ALFAData_v1.cxx:66
AthONNX::JSSMLTool::JSSMLTool
JSSMLTool(const std::string &name)
Definition: JSSMLTool.cxx:73
AthONNX::JSSMLTool::m_env
std::unique_ptr< Ort::Env > m_env
Definition: JSSMLTool.h:72
TrackCaloClusterContainer.h
AthONNX::JSSMLTool::m_labelFileName
std::string m_labelFileName
Definition: JSSMLTool.h:82
IJSSMLTool.h
AthONNX::JSSMLTool::m_input_node_names
std::vector< const char * > m_input_node_names
Definition: JSSMLTool.h:87
AthONNX::JSSMLTool::ReadOutputLabels
std::vector< int > ReadOutputLabels() const
Definition: JSSMLTool.cxx:63
AthONNX::JSSMLTool::m_pixelFileName
std::string m_pixelFileName
Definition: JSSMLTool.h:81
AthONNX::JSSMLTool::m_output_node_names
std::vector< const char * > m_output_node_names
Definition: JSSMLTool.h:92
AthONNX::JSSMLTool::m_scaler
std::map< std::string, std::vector< double > > m_scaler
Definition: JSSMLTool.h:74
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
AthONNX::JSSMLTool::m_nvars
int m_nvars
Definition: JSSMLTool.h:97
AthONNX::JSSMLTool::m_input_node_dims
std::vector< int64_t > m_input_node_dims
Definition: JSSMLTool.h:85
AthONNX::JSSMLTool::m_nPixelsZ
int m_nPixelsZ
Definition: JSSMLTool.h:95
AthONNX::JSSMLTool::m_JSSInputMap
std::map< int, std::string > m_JSSInputMap
Definition: JSSMLTool.h:75
AthONNX::JSSMLTool::initialize
virtual StatusCode initialize() override
Function initialising the tool.
Definition: JSSMLTool.cxx:83
AthONNX::JSSMLTool::ReadJSSInputs
std::vector< float > ReadJSSInputs(std::map< std::string, double > JSSVars) const
Definition: JSSMLTool.cxx:39
AthONNX::JSSMLTool::m_output_node_dims
std::vector< int64_t > m_output_node_dims
Definition: JSSMLTool.h:90
AthONNX
Definition: IJSSMLTool.h:23
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
AthONNX::IJSSMLTool
Definition: IJSSMLTool.h:25
JetContainer.h
AthONNX::JSSMLTool::m_session
std::unique_ptr< Ort::Session > m_session
Definition: JSSMLTool.h:71
ASG_TOOL_CLASS
#define ASG_TOOL_CLASS(CLASSNAME, INT1)
Definition: AsgToolMacros.h:68
AthONNX::JSSMLTool
Tool using the ONNX Runtime C++ API to retrieve scores from constituent-based models for boson jet tagging.
Definition: JSSMLTool.h:48
AthONNX::JSSMLTool::m_num_input_nodes
size_t m_num_input_nodes
Definition: JSSMLTool.h:86
AthONNX::JSSMLTool::retrieveConstituentsScore
virtual double retrieveConstituentsScore(std::vector< TH2D > Images) const override
Function executing the tool for a single event.
Definition: JSSMLTool.cxx:154
AsgTool.h
AthONNX::JSSMLTool::retrieveHighLevelScore
virtual double retrieveHighLevelScore(std::map< std::string, double > JSSVars) const override
Definition: JSSMLTool.cxx:344
AthONNX::JSSMLTool::ReadJetImagePixels
std::vector< float > ReadJetImagePixels(std::vector< TH2D > Images) const
Definition: JSSMLTool.cxx:17