ATLAS Offline Software
Loading...
Searching...
No Matches
PFEnergyPredictorTool.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef PFENERFYPREDICTORTOOL_H
6#define PFENERFYPREDICTORTOOL_H
7
9#include "GaudiKernel/ServiceHandle.h"
11#include <fstream> // std::fstream
12
// Interface identifier object for this tool, constructed with the name
// "PFEnergyPredictorTool" and the numbers 1 and 0.
// NOTE(review): the two numbers are presumably the major/minor interface
// version - confirm against the Gaudi InterfaceID constructor.
13static const InterfaceID IID_PFEnergyPredictorTool("PFEnergyPredictorTool", 1, 0);
// Forward declaration: this header only uses eflowRecTrack through a pointer
// (see nnEnergyPrediction), so the full definition is not needed here.
14class eflowRecTrack;
15
// Body of class PFEnergyPredictorTool: a tool that holds an ONNX Runtime
// session (Ort::Session) plus the configuration needed to feed it, and exposes
// an energy prediction for an eflowRecTrack.
// NOTE(review): the class-head line is not present in this listing; the base
// class is presumably AthAlgTool - confirm against the original header.
17{
18public:
// Tool constructor taking the type string, the instance name and the parent
// interface.
19 PFEnergyPredictorTool(const std::string& type, const std::string& name, const IInterface* parent);
// Life-cycle hooks overridden from the base class; both report a StatusCode.
20 virtual StatusCode initialize() override;
21 virtual StatusCode finalize() override;
22
// Takes the input tensor by non-const reference and returns a single float.
// NOTE(review): presumably runs m_session on the tensor - confirm in the .cxx.
23 float runOnnxInference(std::vector<float> &tensor) const;
// Returns a reference to an InterfaceID.
// NOTE(review): presumably IID_PFEnergyPredictorTool - confirm in the .cxx.
24 static const InterfaceID& interfaceID();
25
// Returns a float energy prediction for the track passed by pointer.
// NOTE(review): null-pointer handling is not visible here - confirm in the .cxx.
26 float nnEnergyPrediction(const eflowRecTrack *ptr) const;
// Modifies the tensor in place (non-const reference, void return).
// NOTE(review): `limit` presumably bounds how many entries are normalized, and
// the m_cell* properties below are presumably the constants used - confirm.
27 void NormalizeTensor(std::vector<float> &tensor, size_t limit) const;
28
29private:
30 //mark as thread safe because we need to call the run function of Session, which is not const
31 //the onnx documentation states that this is thread safe
32 std::unique_ptr<Ort::Session> m_session ATLAS_THREAD_SAFE;
33
// Names of the model input nodes, stored as raw C strings.
// NOTE(review): ownership/lifetime of the pointed-to strings is not visible
// in this header - confirm where they are allocated.
34 std::vector<const char *> m_input_node_names;
35
// Names of the model output nodes, stored as raw C strings (same lifetime
// caveat as for the input names).
36 std::vector<const char *> m_output_node_names;
37
// Dimensions of the model input node.
38 std::vector<int64_t> m_input_node_dims;
// Handle to the ONNX runtime service, configurable via the "ONNXRuntimeSvc"
// property; defaults to "AthOnnx::OnnxRuntimeSvc".
// NOTE(review): the description string mentions CaloMuonScoreTool, which looks
// like a copy-paste leftover from another tool - confirm and correct upstream.
39 ServiceHandle<AthOnnx::IOnnxRuntimeSvc> m_svc{this, "ONNXRuntimeSvc", "AthOnnx::OnnxRuntimeSvc", "CaloMuonScoreTool ONNXRuntimeSvc"};
// Path to the model file, set through the "ModelPath" property.
// NOTE(review): the default "////" is presumably a placeholder that must be
// overridden in the job configuration - confirm.
40 Gaudi::Property<std::string> m_model_filepath{this, "ModelPath", "////"};
41
// Normalization constants for the inputs to the onnx model.
// The defaults are written as double-precision literals but stored as float.
43 Gaudi::Property<float> m_cellE_mean{this,"cellE_mean",-2.2852574689444385};
44 Gaudi::Property<float> m_cellE_std{this,"cellE_std",2.0100506557174946};
45 Gaudi::Property<float> m_cellPhi_std{this,"cellPhi_std",0.6916977411859621};
46
47};
48
50
51
52#endif
53
static const InterfaceID IID_PFEnergyPredictorTool("PFEnergyPredictorTool", 1, 0)
AthAlgTool(const std::string &type, const std::string &name, const IInterface *parent)
Constructor with parameters:
std::unique_ptr< Ort::Session > m_session ATLAS_THREAD_SAFE
void NormalizeTensor(std::vector< float > &tensor, size_t limit) const
float runOnnxInference(std::vector< float > &tensor) const
std::vector< const char * > m_input_node_names
Gaudi::Property< float > m_cellE_std
Gaudi::Property< std::string > m_model_filepath
virtual StatusCode finalize() override
std::vector< const char * > m_output_node_names
PFEnergyPredictorTool(const std::string &type, const std::string &name, const IInterface *parent)
float nnEnergyPrediction(const eflowRecTrack *ptr) const
virtual StatusCode initialize() override
Gaudi::Property< float > m_cellPhi_std
Gaudi::Property< float > m_cellE_mean
Normalization constants for the inputs to the onnx model.
static const InterfaceID & interfaceID()
std::vector< int64_t > m_input_node_dims
ServiceHandle< AthOnnx::IOnnxRuntimeSvc > m_svc
This class extends the information about an xAOD::Track.