ATLAS Offline Software
Loading...
Searching...
No Matches
TTrainedNetworkCondAlg.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4/*
5 * */
6
8
10
11#include "TFile.h"
12#include "TH1.h"
13#include "TH2.h"
14#include "TObject.h"
15
16// for error messages
17#include <typeinfo>
18
19namespace InDet {
20
21 TTrainedNetworkCondAlg::TTrainedNetworkCondAlg (const std::string& name, ISvcLocator* pSvcLocator)
22 : ::AthCondAlgorithm( name, pSvcLocator )
23 {}
24
// Body of TTrainedNetworkCondAlg::initialize() (its signature line is not shown in
// this listing). Retrieves the PoolSvc, which execute() uses to resolve the file
// GUID to a physical file name, then initializes the conditions read/write keys.
// Any failure is propagated by ATH_CHECK.
26 ATH_CHECK( m_poolsvc.retrieve() );
27
28 // Condition Handles
29 ATH_CHECK( m_readKey.initialize() );
30 ATH_CHECK( m_writeKey.initialize() );
31
32 return StatusCode::SUCCESS;
33 }
34
// Body of TTrainedNetworkCondAlg::finalize() (its signature line is not shown in
// this listing): performs no work and always reports success.
36 {
37 return StatusCode::SUCCESS;
38 }
39
40 namespace {
41 template<class T>
42 std::unique_ptr<T>
43 getObject(TFile& a_file, const std::string& path)
44 {
45 std::unique_ptr<T> obj(dynamic_cast<T*>(a_file.Get(path.c_str())));
46 if (!obj) {
47 std::stringstream msg;
48 msg << "Failed object " << path << " from File " << a_file.GetName()
49 << " with type " << typeid(T).name();
50 throw std::runtime_error(msg.str());
51 }
52 obj->SetDirectory(nullptr);
53 return obj;
54 }
55 }
56
57 TTrainedNetwork* TTrainedNetworkCondAlg::retrieveNetwork(TFile &input_file, const std::string& folder) const
58 {
59 //The following means that Histos we own get deleted at the end of this method
60 //when we are done with them.
61 std::vector<std::unique_ptr<TH1>> ownedRetrievedHistos;
62 const unsigned int layer_info = ownedRetrievedHistos.size();
63 // the information about the layers
64 ownedRetrievedHistos.push_back( getObject<TH1>(input_file, folder+m_layerInfoHistogram.value()) );
65
66 if(!m_getInputsInfo){
67 // the info about the input nodes
68 ownedRetrievedHistos.push_back( getObject<TH2>(input_file, folder+"InputsInfo") );
69 }
70
71 // retrieve the number of hidden layers from the LayerInfo histogram
72 unsigned int n_hidden = ownedRetrievedHistos.at(layer_info)->GetNbinsX()-2;
73 ATH_MSG_VERBOSE(" Retrieving calibration: " << folder << " for NN with: " << n_hidden << " hidden layers.");
74 //coverity[INEFFICIENT_RESERVE:FALSE]
75 ownedRetrievedHistos.reserve( ownedRetrievedHistos.size() + n_hidden*2 );
76 for (unsigned int i=0; i<=n_hidden; ++i) {
77 std::stringstream folder_name;
78 folder_name << folder << m_layerPrefix.value() << i;
79 ownedRetrievedHistos.push_back( getObject<TH2>(input_file, folder_name.str()+m_weightIndicator.value() ) );
80 ownedRetrievedHistos.push_back( getObject<TH1>(input_file, folder_name.str()+m_thresholdIndicator.value() ) );
81 }
82
83 //We need this in order to keep compatibility with legacy code
84 std::vector<const TH1*> retrievedHistos;
85 retrievedHistos.reserve(ownedRetrievedHistos.size());
86
87for(const auto & h: ownedRetrievedHistos){
88 retrievedHistos.push_back(h.get());
89 }
90
91 std::unique_ptr<TTrainedNetwork> a_nn(Trk::NeuralNetworkToHistoTool::fromHistoToTrainedNetwork(retrievedHistos));
92 if (!a_nn) {
93 ATH_MSG_ERROR( "Failed to create NN from " << retrievedHistos.size() << " histograms read from " << folder);
94 }
95 else {
96 ATH_MSG_VERBOSE( folder << " " << a_nn->getnInput() );
97 }
98
99 return a_nn.release();
100 }
101
102
// Conditions callback. If the write handle is already valid, there is nothing to do.
// Otherwise it does the following:
//  - reads the "fileGUID" attribute of channel 1 from the CondAttrListCollection;
//  - resolves the GUID to a physical file name through the PoolSvc;
//  - opens that ROOT file and builds one network per entry of m_nnOrder;
//  - records the resulting TTrainedNetworkCollection with the IOV of the read handle.
// Exceptions thrown by retrieveNetwork (missing/mistyped histograms) are not caught here.
// NOTE(review): NnWriteHandle and readHandle are declared on lines not shown in this
// listing (the numbering skips 105 and 111). Presumably they are the write/read condition
// handles built from m_writeKey / m_readKey and ctx; confirm against the full file.
103 StatusCode TTrainedNetworkCondAlg::execute(const EventContext& ctx) const {
104
// Early exit: the conditions object for this event is already available.
106 if (NnWriteHandle.isValid()) {
107 ATH_MSG_DEBUG("Write CondHandle "<< NnWriteHandle.fullKey() << " is already valid");
108 return StatusCode::SUCCESS;
109 }
110
112 if(!readHandle.isValid()) {
113 ATH_MSG_ERROR("Invalid read handle " << m_readKey.key());
114 return StatusCode::FAILURE;
115 }
116 const CondAttrListCollection* atrcol{*readHandle};
117 assert( atrcol != nullptr);
118
// The output object inherits the validity range of the input attribute list.
// NOTE(review): "validaty" in the message below is a typo for "validity".
119 EventIDRange cdo_iov;
120 if(!readHandle.range(cdo_iov)) {
121 ATH_MSG_ERROR("Failed to get valid validaty range from " << readHandle.key());
122 return StatusCode::FAILURE;
123 }
124
125 unsigned int channel=1; //Always 1 in old version with CoolHistSvc
126 CondAttrListCollection::const_iterator channel_iter = atrcol->chanAttrListPair(channel);
127 if (channel_iter==atrcol->end()) {
128 ATH_MSG_ERROR("Conditions data " << readHandle.key() << " misses channel " << channel);
129 return StatusCode::FAILURE;
130 }
131
132 // @TODO store NN parameters in a different way than as a set of histograms in a root file.
133 const std::string coolguid=channel_iter->second["fileGUID"].data<std::string>();
134
135 std::unique_ptr<TTrainedNetworkCollection> writeCdo{std::make_unique<TTrainedNetworkCollection>()};
// Inner scope: the unique_ptr<TFile> below is destroyed, closing the ROOT file,
// before the collection is recorded.
136 {
137 std::string pfname;
138 std::string tech;
// NOTE(review): the return value of lookupBestPfn is not checked. A failed lookup
// presumably leaves pfname unusable and then shows up as the TFile::Open error below.
139 m_poolsvc->lookupBestPfn(coolguid, pfname, tech );
140 ATH_MSG_VERBOSE("Get NNs from file " << pfname.c_str() << " [" << coolguid << " <- " << readHandle.key() << "]." );
141 std::unique_ptr<TFile> a_file( TFile::Open(pfname.c_str(),"READ") );
142 if (!a_file || !a_file->IsOpen()) {
143 ATH_MSG_ERROR("Failed to open file " << pfname << " referenced by " << readHandle.key() << " GUID " << coolguid);
144 return StatusCode::FAILURE;
145 }
146
// One network per configured folder, stored in m_nnOrder order. retrieveNetwork may
// return nullptr (after logging an error); such an entry is stored as-is.
147 writeCdo->reserve(m_nnOrder.size());
148 for (const std::string &folder: m_nnOrder) {
149 ATH_MSG_VERBOSE( "Retrieve NN " << writeCdo->size() << ": " << folder );
150 writeCdo->push_back( std::unique_ptr<TTrainedNetwork>( retrieveNetwork(*a_file, folder) ) );
151 }
152 writeCdo->setNames(m_nnOrder);
153 }
154
// Hand the collection over to the conditions store for the IOV read above.
155 if(NnWriteHandle.record(cdo_iov,std::move(writeCdo)).isFailure()) {
156 ATH_MSG_ERROR("Failed to record Trained network collection to "
157 << NnWriteHandle.key()
158 << " with IOV " << cdo_iov );
159 return StatusCode::FAILURE;
160 }
161 return StatusCode::SUCCESS;
162 }
163
164}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_VERBOSE(x)
#define ATH_MSG_DEBUG(x)
This file defines the class for a collection of AttributeLists where each one is associated with a ch...
Base class for conditions algorithms.
Header file for AthHistogramAlgorithm.
This class is a collection of AttributeLists where each one is associated with a channel number.
const_iterator end() const
ChanAttrListMap::const_iterator const_iterator
const_iterator chanAttrListPair(ChanNum chanNum) const
Access to Chan/AttributeList pairs via channel number: returns map iterator.
Gaudi::Property< std::string > m_thresholdIndicator
StatusCode finalize() override final
StatusCode execute(const EventContext &ctx) const override final
ServiceHandle< IPoolSvc > m_poolsvc
Gaudi::Property< std::string > m_weightIndicator
StatusCode initialize() override final
Gaudi::Property< std::string > m_layerInfoHistogram
Gaudi::Property< std::string > m_layerPrefix
TTrainedNetwork * retrieveNetwork(TFile &input_file, const std::string &folder) const
TTrainedNetworkCondAlg(const std::string &name, ISvcLocator *pSvcLocator)
Gaudi::Property< std::vector< std::string > > m_nnOrder
SG::ReadCondHandleKey< CondAttrListCollection > m_readKey
Gaudi::Property< bool > m_getInputsInfo
SG::WriteCondHandleKey< TTrainedNetworkCollection > m_writeKey
bool range(EventIDRange &r)
const std::string & key() const
const std::string & key() const
StatusCode record(const EventIDRange &range, T *t)
record handle, with explicit range DEPRECATED
const DataObjID & fullKey() const
static TTrainedNetwork * fromHistoToTrainedNetwork(const std::vector< TH1 * > &)
Primary Vertex Finder.
MsgStream & msg
Definition testRead.cxx:32