ATLAS Offline Software
JSSMLTool.h
Go to the documentation of this file.
1 // Dear emacs, this is -*- c++ -*-
2 // Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 #ifndef BOOSTEDJETTAGGERS_JSSMLTOOL_H
4 #define BOOSTEDJETTAGGERS_JSSMLTOOL_H
5 
6 #include "IJSSMLTool.h"
7 #include "AsgTools/AsgTool.h"
8 
9 // ONNX Runtime include(s).
10 #include <onnxruntime_cxx_api.h>
11 
12 // xAOD
13 #include "xAODJet/JetContainer.h"
15 // System include(s).
16 #include <memory> //unique_ptr
17 #include <string>
18 #include <map>
19 #include <vector>
20 #include <cstdint>
21 
22 class TH2D;
23 
24 
25 namespace AthONNX {
26 
41 
45 
46 class JSSMLTool
47  : public asg::AsgTool,
48  virtual public IJSSMLTool {
50 
51  public:
52  JSSMLTool (const std::string& name);
53 
55  virtual StatusCode initialize() override;
57  virtual double retrieveConstituentsScore(std::vector<TH2D> Images) const override;
58  virtual double retrieveConstituentsScore(std::vector<std::vector<float>> constituents) const override;
59  virtual double retrieveConstituentsScore(std::vector<std::vector<float>> constituents, std::vector<std::vector<std::vector<float>>> interactions) const override;
60  virtual double retrieveConstituentsScore(std::vector<std::vector<float>> constituents, std::vector<std::vector<std::vector<float>>> interactions, std::vector<std::vector<float>> mask) const override;
61  virtual double retrieveHighLevelScore(std::map<std::string, double> JSSVars) const override;
62 
63  // basic tool functions
64  std::vector<float> ReadJetImagePixels( std::vector<TH2D> Images ) const;
65  std::vector<float> ReadJSSInputs(std::map<std::string, double> JSSVars) const;
66  std::vector<int> ReadOutputLabels() const;
67 
68  // extra methods
69  StatusCode SetScaler(std::map<std::string, std::vector<double>> scaler) override;
70 
72  std::unique_ptr< Ort::Session > m_session;
73  std::unique_ptr< Ort::Env > m_env;
74 
75  std::map<std::string, std::vector<double>> m_scaler;
76  std::map<int, std::string> m_JSSInputMap;
77 
78  private:
79 
81  std::string m_modelFileName;
82  std::string m_pixelFileName;
83  std::string m_labelFileName;
84 
85  // input node info
86  std::vector<int64_t> m_input_node_dims;
88  std::vector<const char*> m_input_node_names;
89 
90  // output node info
91  std::vector<int64_t> m_output_node_dims;
93  std::vector<const char*> m_output_node_names;
94 
95  // some configs
97 
98  int m_nvars{};
99 
100  }; // class JSSMLTool
101 
102 } // namespace AthONNX
103 
104 #endif // BOOSTEDJETTAGGERS_JSSMLTOOL_H
AthONNX::JSSMLTool::m_nPixelsX
int m_nPixelsX
Definition: JSSMLTool.h:96
AthONNX::JSSMLTool::SetScaler
StatusCode SetScaler(std::map< std::string, std::vector< double >> scaler) override
Definition: JSSMLTool.cxx:513
asg::AsgTool
Base class for the dual-use tool implementation classes.
Definition: AsgTool.h:47
AthONNX::JSSMLTool::m_nPixelsY
int m_nPixelsY
Definition: JSSMLTool.h:96
AthONNX::JSSMLTool::m_num_output_nodes
size_t m_num_output_nodes
Definition: JSSMLTool.h:92
AthONNX::JSSMLTool::m_modelFileName
std::string m_modelFileName
Name of the model file to load.
Definition: JSSMLTool.h:81
xAOD::scaler
setOverV setNumU setNumY setODFibSel setYDetCS setYLhcCS setXRPotCS setXStatCS setXBeamCS scaler
Definition: ALFAData_v1.cxx:66
AthONNX::JSSMLTool::JSSMLTool
JSSMLTool(const std::string &name)
Definition: JSSMLTool.cxx:73
AthONNX::JSSMLTool::m_env
std::unique_ptr< Ort::Env > m_env
Definition: JSSMLTool.h:73
TrackCaloClusterContainer.h
AthONNX::JSSMLTool::m_labelFileName
std::string m_labelFileName
Definition: JSSMLTool.h:83
IJSSMLTool.h
AthONNX::JSSMLTool::m_input_node_names
std::vector< const char * > m_input_node_names
Definition: JSSMLTool.h:88
AthONNX::JSSMLTool::ReadOutputLabels
std::vector< int > ReadOutputLabels() const
Definition: JSSMLTool.cxx:63
python.utils.AtlRunQueryLookup.mask
string mask
Definition: AtlRunQueryLookup.py:459
AthONNX::JSSMLTool::m_pixelFileName
std::string m_pixelFileName
Definition: JSSMLTool.h:82
AthONNX::JSSMLTool::m_output_node_names
std::vector< const char * > m_output_node_names
Definition: JSSMLTool.h:93
AthONNX::JSSMLTool::m_scaler
std::map< std::string, std::vector< double > > m_scaler
Definition: JSSMLTool.h:75
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
AthONNX::JSSMLTool::m_nvars
int m_nvars
Definition: JSSMLTool.h:98
AthONNX::JSSMLTool::m_input_node_dims
std::vector< int64_t > m_input_node_dims
Definition: JSSMLTool.h:86
AthONNX::JSSMLTool::m_nPixelsZ
int m_nPixelsZ
Definition: JSSMLTool.h:96
AthONNX::JSSMLTool::m_JSSInputMap
std::map< int, std::string > m_JSSInputMap
Definition: JSSMLTool.h:76
AthONNX::JSSMLTool::initialize
virtual StatusCode initialize() override
Function initialising the tool.
Definition: JSSMLTool.cxx:83
AthONNX::JSSMLTool::ReadJSSInputs
std::vector< float > ReadJSSInputs(std::map< std::string, double > JSSVars) const
Definition: JSSMLTool.cxx:39
AthONNX::JSSMLTool::m_output_node_dims
std::vector< int64_t > m_output_node_dims
Definition: JSSMLTool.h:91
AthONNX
Definition: IJSSMLTool.h:23
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
AthONNX::IJSSMLTool
Definition: IJSSMLTool.h:25
JetContainer.h
AthONNX::JSSMLTool::m_session
std::unique_ptr< Ort::Session > m_session
Definition: JSSMLTool.h:72
ASG_TOOL_CLASS
#define ASG_TOOL_CLASS(CLASSNAME, INT1)
Definition: AsgToolMacros.h:68
AthONNX::JSSMLTool
Tool using the ONNX Runtime C++ API to retrieve scores from constituent-based models for boson jet tagging.
Definition: JSSMLTool.h:48
AthONNX::JSSMLTool::m_num_input_nodes
size_t m_num_input_nodes
Definition: JSSMLTool.h:87
AthONNX::JSSMLTool::retrieveConstituentsScore
virtual double retrieveConstituentsScore(std::vector< TH2D > Images) const override
Function executing the tool for a single event.
Definition: JSSMLTool.cxx:162
AsgTool.h
AthONNX::JSSMLTool::retrieveHighLevelScore
virtual double retrieveHighLevelScore(std::map< std::string, double > JSSVars) const override
Definition: JSSMLTool.cxx:447
AthONNX::JSSMLTool::ReadJetImagePixels
std::vector< float > ReadJetImagePixels(std::vector< TH2D > Images) const
Definition: JSSMLTool.cxx:17