5#ifndef TRACKOVERLAYREC_TRACKOVERLAYDECISIONALG_H
6#define TRACKOVERLAYREC_TRACKOVERLAYDECISIONALG_H
12#include "GaudiKernel/ToolHandle.h"
18#include "GaudiKernel/ToolHandle.h"
24#include <onnxruntime_cxx_api.h>
/// Algorithm initialize() hook; overrides the base-class method and is
/// marked final so it cannot be overridden further.
53 virtual StatusCode
initialize() override final;
/// Athena algorithm's interface method finalize(); overrides the base-class
/// method and is marked final.
56 virtual StatusCode
finalize() override final;
/// Job option "useTrackSelection" (default: false): when enabled, only tracks
/// accepted by the selection tool are used.
63 BooleanProperty
m_useTrackSelection {
this,
"useTrackSelection",
false,
"plot only tracks accepted by selection tool"};
/// Job option "PileupSwitch" (default: "HardScatter"): pileup truth strategy,
/// one of "All", "HardScatter" or "PileUp" per the property description.
64 StringProperty
m_pileupSwitch {
this,
"PileupSwitch",
"HardScatter",
"Pileup truth strategy to use. May be \"All\", \"HardScatter\", or \"PileUp\""};
/// Job option "LowProb" (default: 0.5): truth-match probability cutoff, used as
/// the lower bound for efficiency and the upper bound for fake classification.
65 FloatProperty
m_lowProb{
this,
"LowProb",0.5,
"Truth match prob. cutoff for efficiency (lower bound) and fake (upper bound) classification."};
/// Returns a vector of const pointers to truth particles (const member function).
/// NOTE(review): which particles are returned presumably depends on
/// m_pileupSwitch — confirm in the implementation file.
/// NOTE(review): the const-qualified by-value return type prevents
/// move-assignment from the result (`v = getTruthParticles();` copies);
/// dropping the const would require updating the out-of-line definition too.
73 const std::vector<const xAOD::TruthParticle *> getTruthParticles()
const;
/// Job option "InvertFilter" (default: false): invert the final filter decision.
75 Gaudi::Property<bool>
m_invertfilter{
this,
"InvertFilter",
false,
"Invert filter decision."};
/// Job option "MLThreshold" (default: 0.74201): ML scores strictly larger than
/// this threshold classify a track as bad.
76 Gaudi::Property<float>
m_MLthreshold{
this,
"MLThreshold", 0.74201,
"ML threshold for bad/good tracks decision. ML scores larger than this threshold are considered as bad tracks."};
/// Input node information as (shape, names), matching the tuple layout
/// returned by GetInputNodeInfo().
// NOTE(review): presumably filled from GetInputNodeInfo(m_session) during
// initialize(); the char* names are then owned by this member — confirm they
// are freed somewhere, since the destructor is defaulted.
79 std::tuple<std::vector<int64_t>, std::vector<char*>>
m_inputInfo;
// Query an ONNX Runtime session for its input nodes.
// Returns (input_node_dims, input_node_names):
//   - input_node_names: one name per input node, in index order;
//   - input_node_dims: the tensor shape reported by the session (empty if the
//     session has no inputs).
// NOTE(review): input_node_dims is overwritten on every loop iteration, so only
// the shape of the LAST input node is returned. This is only correct for
// single-input models (or inputs that share a shape) — confirm for the model used.
// NOTE(review): each name is detached from its Ort::AllocatedStringPtr via
// release(), so ownership of the ORT-allocated buffer passes to the caller and
// nothing here frees it. Presumably the caller must free it through an ORT
// allocator, otherwise it leaks — confirm.
82 inline std::tuple<std::vector<int64_t>, std::vector<char*> >
GetInputNodeInfo(
const std::unique_ptr< Ort::Session >& session) {
83 std::vector<int64_t> input_node_dims;
84 size_t num_input_nodes = session->GetInputCount();
85 std::vector<char*> input_node_names(num_input_nodes);
86 Ort::AllocatorWithDefaultOptions allocator;
87 for( std::size_t i = 0; i < num_input_nodes; i++ ) {
// release() keeps the name buffer alive past the temporary smart pointer,
// so the raw pointer stays valid after this function returns.
88 char* input_name = session->GetInputNameAllocated(i, allocator).release();
89 input_node_names[i] = input_name;
90 Ort::TypeInfo type_info = session->GetInputTypeInfo(i);
91 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
// Overwritten each iteration: only the last input's shape survives.
92 input_node_dims = tensor_info.GetShape();
94 return std::make_tuple(input_node_dims, input_node_names);
// Query an ONNX Runtime session for its output nodes.
// Returns (output_node_dims, output_node_names):
//   - output_node_names: one name per output node, in index order;
//   - output_node_dims: the tensor shape reported by the session (empty if the
//     session has no outputs).
// NOTE(review): output_node_dims is overwritten on every loop iteration, so only
// the shape of the LAST output node is returned. This is only correct for
// single-output models (or outputs that share a shape) — confirm for the model used.
// NOTE(review): each name is detached from its Ort::AllocatedStringPtr via
// release(), so ownership of the ORT-allocated buffer passes to the caller and
// nothing here frees it. Presumably the caller must free it through an ORT
// allocator, otherwise it leaks — confirm.
97 inline std::tuple<std::vector<int64_t>, std::vector<char*> >
GetOutputNodeInfo(
const std::unique_ptr< Ort::Session >& session){
98 std::vector<int64_t> output_node_dims;
99 size_t num_output_nodes = session->GetOutputCount();
100 std::vector<char*> output_node_names(num_output_nodes);
101 Ort::AllocatorWithDefaultOptions allocator;
102 for( std::size_t i = 0; i < num_output_nodes; i++ ) {
// release() keeps the name buffer alive past the temporary smart pointer,
// so the raw pointer stays valid after this function returns.
103 char* output_name = session->GetOutputNameAllocated(i, allocator).release();
104 output_node_names[i] = output_name;
105 Ort::TypeInfo type_info = session->GetOutputTypeInfo(i);
106 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
// Overwritten each iteration: only the last output's shape survives.
107 output_node_dims = tensor_info.GetShape();
109 return std::make_tuple(output_node_dims, output_node_names);
Handle class for reading from StoreGate.
An algorithm that can be simultaneously executed in multiple threads.
A handle for applying algorithm filter decisions.
SG::Decorator< T, ALLOC > Decorator
Property holding a SG store/key/clid from which a ReadHandle is made.
Gaudi::Property< float > m_MLthreshold
FilterReporterParams m_filterParams
SG::AuxElement::Decorator< bool > m_dec_selectedByPileupSwitch
SG::ReadHandleKey< xAOD::EventInfo > m_eventInfoContainerName
std::tuple< std::vector< int64_t >, std::vector< char * > > GetInputNodeInfo(const std::unique_ptr< Ort::Session > &session)
Gaudi::Property< bool > m_invertfilter
Invert the filter decision at the end.
std::unique_ptr< Ort::Session > m_session
virtual StatusCode execute(const EventContext &ctx) const override final
Athena algorithm's interface method execute()
virtual ~TrackOverlayDecisionAlg()=default
Destructor.
std::tuple< std::vector< int64_t >, std::vector< char * > > m_outputInfo
SG::ReadHandleKey< xAOD::TruthEventContainer > m_truthEventName
std::tuple< std::vector< int64_t >, std::vector< char * > > GetOutputNodeInfo(const std::unique_ptr< Ort::Session > &session)
std::tuple< std::vector< int64_t >, std::vector< char * > > m_inputInfo
virtual StatusCode finalize() override final
Athena algorithm's interface method finalize()
bool m_usingSpecialPileupSwitch
SG::ReadHandleKey< xAOD::TruthParticleContainer > m_truthParticleName
SG::ReadHandleKey< xAOD::TruthPileupEventContainer > m_truthPileUpEventName
ServiceHandle< AthOnnx::IOnnxRuntimeSvc > m_svc
ToolHandle< IAthSelectionTool > m_truthSelectionTool
void markSelectedByPileupSwitch(const std::vector< const xAOD::TruthParticle * > &truthParticles) const
StringProperty m_pileupSwitch
BooleanProperty m_useTrackSelection