ATLAS Offline Software
Loading...
Searching...
No Matches
JSSMLTool.h
Go to the documentation of this file.
1// Dear emacs, this is -*- c++ -*-
2// Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3#ifndef BOOSTEDJETTAGGERS_JSSMLTOOL_H
4#define BOOSTEDJETTAGGERS_JSSMLTOOL_H
5
6#include "IJSSMLTool.h"
7#include "AsgTools/AsgTool.h"
8
9// ONNX Runtime include(s).
10#include <onnxruntime_cxx_api.h>
11
12// xAOD
15// System include(s).
16#include <memory> //unique_ptr
17#include <string>
18#include <map>
19#include <vector>
20#include <cstdint>
21
22class TH2D;
23
24
25namespace AthONNX {
26
41
45
47 : public asg::AsgTool,
48 virtual public IJSSMLTool {
50
51 public:
52 JSSMLTool (const std::string& name);
53
55 virtual StatusCode initialize() override;
57 virtual double retrieveConstituentsScore(std::vector<TH2D> Images) const override;
58 virtual double retrieveConstituentsScore(std::vector<std::vector<float>> constituents) const override;
59 virtual double retrieveConstituentsScore(std::vector<std::vector<float>> constituents, std::vector<std::vector<std::vector<float>>> interactions) const override;
60 virtual double retrieveConstituentsScore(std::vector<std::vector<float>> constituents, std::vector<std::vector<std::vector<float>>> interactions, std::vector<std::vector<float>> mask) const override;
61 virtual double retrieveHighLevelScore(std::map<std::string, double> JSSVars) const override;
62
63 virtual std::vector<float> retrieveConstituentsScoreMultiClass(const std::vector<std::vector<float>>& constituents, const std::vector<std::vector<std::vector<float>>>& interactions, const std::vector<std::vector<float>>& mask) const override;
64
65 // basic tool functions
66 std::vector<float> ReadJetImagePixels( std::vector<TH2D> Images ) const;
67 std::vector<float> ReadJSSInputs(std::map<std::string, double> JSSVars) const;
68 std::vector<int> ReadOutputLabels() const;
69
70 // extra methods
71 StatusCode SetScaler(std::map<std::string, std::vector<double>> scaler) override;
72
74 std::unique_ptr< Ort::Session > m_session;
75 std::unique_ptr< Ort::Env > m_env;
76
77 std::map<std::string, std::vector<double>> m_scaler;
78 std::map<int, std::string> m_JSSInputMap;
79
80 private:
81
83 std::string m_modelFileName;
84 std::string m_pixelFileName;
85 std::string m_labelFileName;
86
87 // input node info
88 std::vector<int64_t> m_input_node_dims;
90 std::vector<const char*> m_input_node_names;
91
92 // output node info
93 std::vector<int64_t> m_output_node_dims;
95 std::vector<const char*> m_output_node_names;
96
97 // some configs
99
100 int m_nvars{};
101
102 }; // class JSSMLTool
103
104} // namespace AthONNX
105
106#endif // BOOSTEDJETTAGGERS_JSSMLTOOL_H
#define ASG_TOOL_CLASS(CLASSNAME, INT1)
std::string m_modelFileName
Name of the model file to load.
Definition JSSMLTool.h:83
std::map< std::string, std::vector< double > > m_scaler
Definition JSSMLTool.h:77
StatusCode SetScaler(std::map< std::string, std::vector< double > > scaler) override
virtual double retrieveHighLevelScore(std::map< std::string, double > JSSVars) const override
virtual StatusCode initialize() override
Function initialising the tool.
Definition JSSMLTool.cxx:83
std::vector< float > ReadJetImagePixels(std::vector< TH2D > Images) const
Definition JSSMLTool.cxx:17
std::string m_pixelFileName
Definition JSSMLTool.h:84
JSSMLTool(const std::string &name)
Definition JSSMLTool.cxx:73
virtual std::vector< float > retrieveConstituentsScoreMultiClass(const std::vector< std::vector< float > > &constituents, const std::vector< std::vector< std::vector< float > > > &interactions, const std::vector< std::vector< float > > &mask) const override
std::vector< int > ReadOutputLabels() const
Definition JSSMLTool.cxx:63
std::unique_ptr< Ort::Env > m_env
Definition JSSMLTool.h:75
std::vector< int64_t > m_output_node_dims
Definition JSSMLTool.h:93
virtual double retrieveConstituentsScore(std::vector< TH2D > Images) const override
Function executing the tool for a single event.
size_t m_num_output_nodes
Definition JSSMLTool.h:94
size_t m_num_input_nodes
Definition JSSMLTool.h:89
std::string m_labelFileName
Definition JSSMLTool.h:85
std::vector< float > ReadJSSInputs(std::map< std::string, double > JSSVars) const
Definition JSSMLTool.cxx:39
std::vector< const char * > m_output_node_names
Definition JSSMLTool.h:95
std::vector< int64_t > m_input_node_dims
Definition JSSMLTool.h:88
std::map< int, std::string > m_JSSInputMap
Definition JSSMLTool.h:78
std::vector< const char * > m_input_node_names
Definition JSSMLTool.h:90
std::unique_ptr< Ort::Session > m_session
Definition JSSMLTool.h:74
Base class for the dual-use tool implementation classes.
Definition AsgTool.h:47