ATLAS Offline Software
Loading...
Searching...
No Matches
TTrainedNetworkCondAlg.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration
3*/
4/*
5 * */
6
8
11
12#include "TFile.h"
13#include "TH1.h"
14#include "TH2.h"
15#include "TObject.h"
16
17// for error messages
18#include <typeinfo>
19
20namespace InDet {
21
22 TTrainedNetworkCondAlg::TTrainedNetworkCondAlg (const std::string& name, ISvcLocator* pSvcLocator)
23 : ::AthCondAlgorithm( name, pSvcLocator )
24 {}
25
// NOTE(review): the signature line of this function is missing from this listing (presumably initialize()).
// Retrieve the PoolSvc; execute() uses its file catalog to map the conditions file GUID to a physical file name.
27 ATH_CHECK( m_poolsvc.retrieve() );
28
29 // Condition Handles
// Initialize the keys of the input CondAttrListCollection and the output TTrainedNetworkCollection.
30 ATH_CHECK( m_readKey.initialize() );
31 ATH_CHECK( m_writeKey.initialize() );
32
33 return StatusCode::SUCCESS;
34 }
35
// NOTE(review): the signature line of this function is missing from this listing (presumably finalize()).
// Nothing to tear down: always reports success.
37 {
38 return StatusCode::SUCCESS;
39 }
40
41 namespace {
42 template<class T>
43 std::unique_ptr<T>
44 getObject(TFile& a_file, const std::string& path)
45 {
46 std::unique_ptr<T> obj(dynamic_cast<T*>(a_file.Get(path.c_str())));
47 if (!obj) {
48 std::stringstream msg;
49 msg << "Failed object " << path << " from File " << a_file.GetName()
50 << " with type " << typeid(T).name();
51 throw std::runtime_error(msg.str());
52 }
53 obj->SetDirectory(nullptr);
54 return obj;
55 }
56 }
57
58 TTrainedNetwork* TTrainedNetworkCondAlg::retrieveNetwork(TFile &input_file, const std::string& folder) const
59 {
60 //The following means that Histos we own get deleted at the end of this method
61 //when we are done with them.
62 std::vector<std::unique_ptr<TH1>> ownedRetrievedHistos;
63 const unsigned int layer_info = ownedRetrievedHistos.size();
64 // the information about the layers
65 ownedRetrievedHistos.push_back( getObject<TH1>(input_file, folder+m_layerInfoHistogram.value()) );
66
67 if(!m_getInputsInfo){
68 // the info about the input nodes
69 ownedRetrievedHistos.push_back( getObject<TH2>(input_file, folder+"InputsInfo") );
70 }
71
72 // retrieve the number of hidden layers from the LayerInfo histogram
73 unsigned int n_hidden = ownedRetrievedHistos.at(layer_info)->GetNbinsX()-2;
74 ATH_MSG_VERBOSE(" Retrieving calibration: " << folder << " for NN with: " << n_hidden << " hidden layers.");
75
76 ownedRetrievedHistos.reserve( ownedRetrievedHistos.size() + n_hidden*2 );
77 for (unsigned int i=0; i<=n_hidden; ++i) {
78 std::stringstream folder_name;
79 folder_name << folder << m_layerPrefix.value() << i;
80 ownedRetrievedHistos.push_back( getObject<TH2>(input_file, folder_name.str()+m_weightIndicator.value() ) );
81 ownedRetrievedHistos.push_back( getObject<TH1>(input_file, folder_name.str()+m_thresholdIndicator.value() ) );
82 }
83
84 //We need this in order to keep compatibility with legacy code
85 std::vector<const TH1*> retrievedHistos;
86 retrievedHistos.reserve(ownedRetrievedHistos.size());
87
88for(const auto & h: ownedRetrievedHistos){
89 retrievedHistos.push_back(h.get());
90 }
91
92 std::unique_ptr<TTrainedNetwork> a_nn(Trk::NeuralNetworkToHistoTool::fromHistoToTrainedNetwork(retrievedHistos));
93 if (!a_nn) {
94 ATH_MSG_ERROR( "Failed to create NN from " << retrievedHistos.size() << " histograms read from " << folder);
95 }
96 else {
97 ATH_MSG_VERBOSE( folder << " " << a_nn->getnInput() );
98 }
99
100 return a_nn.release();
101 }
102
103
// Load the trained neural networks for the current IOV.
// Reads the CondAttrListCollection through readHandle, takes the "fileGUID" of
// channel 1, resolves it to a physical file name via the PoolSvc catalog, opens
// that ROOT file, builds one TTrainedNetwork per folder in m_nnOrder (same
// order, names set from m_nnOrder) and records the TTrainedNetworkCollection
// through NnWriteHandle with the validity range of the input.
// Returns FAILURE if the read handle is invalid, has no validity range, lacks
// channel 1, the file cannot be opened, or the record fails.
// NOTE(review): a missing histogram makes getObject throw std::runtime_error,
// which is not caught here and so propagates out of execute(). A network that
// fails to convert is stored as nullptr without failing execute().
104 StatusCode TTrainedNetworkCondAlg::execute(const EventContext& ctx) const {
105
// NOTE(review): the declaration of NnWriteHandle is missing from this listing
// (presumably a write condition handle built from m_writeKey and ctx).
// If the output is already valid for this event, there is nothing to load.
107 if (NnWriteHandle.isValid()) {
108 ATH_MSG_DEBUG("Write CondHandle "<< NnWriteHandle.fullKey() << " is already valid");
109 return StatusCode::SUCCESS;
110 }
111
// NOTE(review): the declaration of readHandle is missing from this listing
// (presumably a read condition handle built from m_readKey and ctx).
113 if(!readHandle.isValid()) {
114 ATH_MSG_ERROR("Invalid read handle " << m_readKey.key());
115 return StatusCode::FAILURE;
116 }
117 const CondAttrListCollection* atrcol{*readHandle};
118 assert( atrcol != nullptr);
119
// The recorded collection gets the validity range of the input conditions.
120 EventIDRange cdo_iov;
121 if(!readHandle.range(cdo_iov)) {
122 ATH_MSG_ERROR("Failed to get valid validaty range from " << readHandle.key());
123 return StatusCode::FAILURE;
124 }
125
126 unsigned int channel=1; //Always 1 in old version with CoolHistSvc
127 CondAttrListCollection::const_iterator channel_iter = atrcol->chanAttrListPair(channel);
128 if (channel_iter==atrcol->end()) {
129 ATH_MSG_ERROR("Conditions data " << readHandle.key() << " misses channel " << channel);
130 return StatusCode::FAILURE;
131 }
132
133 // @TODO store NN parameters in a different way than as a set of histograms in a root file.
// GUID of the ROOT file that holds the network histograms.
134 const std::string coolguid=channel_iter->second["fileGUID"].data<std::string>();
135
136 std::unique_ptr<TTrainedNetworkCollection> writeCdo{std::make_unique<TTrainedNetworkCollection>()};
// Scoped so that a_file is destroyed (closing the input file) before the
// collection is recorded.
137 {
138 std::string pfname;
139 std::string tech;
// Resolve the GUID to a physical file name through the PoolSvc file catalog.
140 m_poolsvc->catalog()->getFirstPFN(coolguid, pfname, tech );
141 ATH_MSG_VERBOSE("Get NNs from file " << pfname.c_str() << " [" << coolguid << " <- " << readHandle.key() << "]." );
142 std::unique_ptr<TFile> a_file( TFile::Open(pfname.c_str(),"READ") );
143 if (!a_file || !a_file->IsOpen()) {
144 ATH_MSG_ERROR("Failed to open file " << pfname << " referenced by " << readHandle.key() << " GUID " << coolguid);
145 return StatusCode::FAILURE;
146 }
147
// One network per folder, stored in m_nnOrder order; retrieveNetwork may
// return nullptr, which is stored as-is.
148 writeCdo->reserve(m_nnOrder.size());
149 for (const std::string &folder: m_nnOrder) {
150 ATH_MSG_VERBOSE( "Retrieve NN " << writeCdo->size() << ": " << folder );
151 writeCdo->push_back( std::unique_ptr<TTrainedNetwork>( retrieveNetwork(*a_file, folder) ) );
152 }
153 writeCdo->setNames(m_nnOrder);
154 }
155
156 if(NnWriteHandle.record(cdo_iov,std::move(writeCdo)).isFailure()) {
157 ATH_MSG_ERROR("Failed to record Trained network collection to "
158 << NnWriteHandle.key()
159 << " with IOV " << cdo_iov );
160 return StatusCode::FAILURE;
161 }
162 return StatusCode::SUCCESS;
163 }
164
165}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_VERBOSE(x)
#define ATH_MSG_DEBUG(x)
This file defines the class for a collection of AttributeLists where each one is associated with a channel number.
Base class for conditions algorithms.
Header file for AthHistogramAlgorithm.
This class is a collection of AttributeLists where each one is associated with a channel number.
const_iterator end() const
ChanAttrListMap::const_iterator const_iterator
const_iterator chanAttrListPair(ChanNum chanNum) const
Access to Chan/AttributeList pairs via channel number: returns map iterator.
Gaudi::Property< std::string > m_thresholdIndicator
StatusCode finalize() override final
StatusCode execute(const EventContext &ctx) const override final
ServiceHandle< IPoolSvc > m_poolsvc
Gaudi::Property< std::string > m_weightIndicator
StatusCode initialize() override final
Gaudi::Property< std::string > m_layerInfoHistogram
Gaudi::Property< std::string > m_layerPrefix
TTrainedNetwork * retrieveNetwork(TFile &input_file, const std::string &folder) const
TTrainedNetworkCondAlg(const std::string &name, ISvcLocator *pSvcLocator)
Gaudi::Property< std::vector< std::string > > m_nnOrder
SG::ReadCondHandleKey< CondAttrListCollection > m_readKey
Gaudi::Property< bool > m_getInputsInfo
SG::WriteCondHandleKey< TTrainedNetworkCollection > m_writeKey
bool range(EventIDRange &r)
const std::string & key() const
const std::string & key() const
StatusCode record(const EventIDRange &range, T *t)
record handle, with explicit range DEPRECATED
const DataObjID & fullKey() const
static TTrainedNetwork * fromHistoToTrainedNetwork(const std::vector< TH1 * > &)
Primary Vertex Finder.
MsgStream & msg
Definition testRead.cxx:32