ATLAS Offline Software
TTrainedNetworkCondAlg.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration
3 */
4 /*
5  * */
6 
8 
11 
12 #include "TFile.h"
13 #include "TH1.h"
14 #include "TH2.h"
15 #include "TObject.h"
16 
17 // for error messages
18 #include <typeinfo>
19 
20 namespace InDet {
21 
22  TTrainedNetworkCondAlg::TTrainedNetworkCondAlg (const std::string& name, ISvcLocator* pSvcLocator)
23  : ::AthReentrantAlgorithm( name, pSvcLocator )
24  {}
25 
// NOTE(review): the signature line of initialize() (source line 26; per the
// cross-reference `StatusCode initialize() override final`) is not present
// in this listing.
// Retrieve the PoolSvc; execute() uses its file catalog
// (m_poolsvc->catalog()->getFirstPFN) to resolve the conditions-file GUID
// to a physical file name.
27  ATH_CHECK( m_poolsvc.retrieve() );
28 
29  // Condition Handles
// NOTE(review): source lines 30-31 are not present in this listing; the
// condition-handle keys (m_readKey, m_writeKey) are presumably initialized
// there — confirm against the real file.
32 
33  return StatusCode::SUCCESS;
34  }
35 
// finalize(): performs no work and always reports success.
// NOTE(review): the signature line (source line 36; per the cross-reference
// `StatusCode finalize() override final`) is not present in this listing.
37  {
38  return StatusCode::SUCCESS;
39  }
40 
41  namespace {
42  template<class T>
43  std::unique_ptr<T>
44  getObject(TFile& a_file, const std::string& path)
45  {
46  std::unique_ptr<T> obj(dynamic_cast<T*>(a_file.Get(path.c_str())));
47  if (!obj) {
48  std::stringstream msg;
49  msg << "Failed object " << path << " from File " << a_file.GetName()
50  << " with type " << typeid(T).name();
51  throw std::runtime_error(msg.str());
52  }
53  obj->SetDirectory(nullptr);
54  return obj;
55  }
56  }
57 
59  {
60  //The following means that Histos we own get deleted at the end of this method
61  //when we are done with them.
62  std::vector<std::unique_ptr<TH1>> ownedRetrievedHistos;
63  const unsigned int layer_info = ownedRetrievedHistos.size();
64  // the information about the layers
65  ownedRetrievedHistos.push_back( getObject<TH1>(input_file, folder+m_layerInfoHistogram.value()) );
66 
67  if(!m_getInputsInfo){
68  // the info about the input nodes
69  ownedRetrievedHistos.push_back( getObject<TH2>(input_file, folder+"InputsInfo") );
70  }
71 
72  // retrieve the number of hidden layers from the LayerInfo histogram
73  unsigned int n_hidden = ownedRetrievedHistos.at(layer_info)->GetNbinsX()-2;
74  ATH_MSG_VERBOSE(" Retrieving calibration: " << folder << " for NN with: " << n_hidden << " hidden layers.");
75 
76  ownedRetrievedHistos.reserve( ownedRetrievedHistos.size() + n_hidden*2 );
77  for (unsigned int i=0; i<=n_hidden; ++i) {
78  std::stringstream folder_name;
79  folder_name << folder << m_layerPrefix.value() << i;
80  ownedRetrievedHistos.push_back( getObject<TH2>(input_file, folder_name.str()+m_weightIndicator.value() ) );
81  ownedRetrievedHistos.push_back( getObject<TH1>(input_file, folder_name.str()+m_thresholdIndicator.value() ) );
82  }
83 
84  //We need this in order to keep compatibility with legacy code
85  std::vector<const TH1*> retrievedHistos;
86  retrievedHistos.reserve(ownedRetrievedHistos.size());
87 
88 for(const auto & h: ownedRetrievedHistos){
89  retrievedHistos.push_back(h.get());
90  }
91 
92  std::unique_ptr<TTrainedNetwork> a_nn(m_networkToHistoTool->fromHistoToTrainedNetwork(retrievedHistos));
93  if (!a_nn) {
94  ATH_MSG_ERROR( "Failed to create NN from " << retrievedHistos.size() << " histograms read from " << folder);
95  }
96  else {
97  ATH_MSG_VERBOSE( folder << " " << a_nn->getnInput() );
98  }
99 
100  return a_nn.release();
101  }
102 
103 
104  StatusCode TTrainedNetworkCondAlg::execute(const EventContext& ctx) const {
105 
107  if (NnWriteHandle.isValid()) {
108  ATH_MSG_DEBUG("Write CondHandle "<< NnWriteHandle.fullKey() << " is already valid");
109  return StatusCode::SUCCESS;
110  }
111 
113  if(!readHandle.isValid()) {
114  ATH_MSG_ERROR("Invalid read handle " << m_readKey.key());
115  return StatusCode::FAILURE;
116  }
117  const CondAttrListCollection* atrcol{*readHandle};
118  assert( atrcol != nullptr);
119 
120  EventIDRange cdo_iov;
121  if(!readHandle.range(cdo_iov)) {
122  ATH_MSG_ERROR("Failed to get valid validaty range from " << readHandle.key());
123  return StatusCode::FAILURE;
124  }
125 
126  unsigned int channel=1; //Always 1 in old version with CoolHistSvc
127  CondAttrListCollection::const_iterator channel_iter = atrcol->chanAttrListPair(channel);
128  if (channel_iter==atrcol->end()) {
129  ATH_MSG_ERROR("Conditions data " << readHandle.key() << " misses channel " << channel);
130  return StatusCode::FAILURE;
131  }
132 
133  // @TODO store NN parameters in a different way than as a set of histograms in a root file.
134  const std::string coolguid=channel_iter->second["fileGUID"].data<std::string>();
135 
136  std::unique_ptr<TTrainedNetworkCollection> writeCdo{std::make_unique<TTrainedNetworkCollection>()};
137  {
138  std::string pfname;
139  std::string tech;
140  m_poolsvc->catalog()->getFirstPFN(coolguid, pfname, tech );
141  ATH_MSG_VERBOSE("Get NNs from file " << pfname.c_str() << " [" << coolguid << " <- " << readHandle.key() << "]." );
142  std::unique_ptr<TFile> a_file( TFile::Open(pfname.c_str(),"READ") );
143  if (!a_file || !a_file->IsOpen()) {
144  ATH_MSG_ERROR("Failed to open file " << pfname << " referenced by " << readHandle.key() << " GUID " << coolguid);
145  return StatusCode::FAILURE;
146  }
147 
148  writeCdo->reserve(m_nnOrder.size());
149  for (const std::string &folder: m_nnOrder) {
150  ATH_MSG_VERBOSE( "Retrieve NN " << writeCdo->size() << ": " << folder );
151  writeCdo->push_back( std::unique_ptr<TTrainedNetwork>( retrieveNetwork(*a_file, folder) ) );
152  }
153  writeCdo->setNames(m_nnOrder);
154  }
155 
156  if(NnWriteHandle.record(cdo_iov,std::move(writeCdo)).isFailure()) {
157  ATH_MSG_ERROR("Failed to record Trained network collection to "
158  << NnWriteHandle.key()
159  << " with IOV " << cdo_iov );
160  return StatusCode::FAILURE;
161  }
162  return StatusCode::SUCCESS;
163  }
164 
165 }
TTrainedNetwork::getnInput
Int_t getnInput() const
Definition: InnerDetector/InDetCalibAlgs/PixelCalibAlgs/NNClusteringCalibration_RunI/TTrainedNetwork.h:46
InDet::TTrainedNetworkCondAlg::m_weightIndicator
Gaudi::Property< std::string > m_weightIndicator
Definition: TTrainedNetworkCondAlg.h:76
athena.path
path
python interpreter configuration --------------------------------------—
Definition: athena.py:126
plotting.yearwise_efficiency.channel
channel
Definition: yearwise_efficiency.py:28
CondAttrListCollection.h
This file defines the class for a collection of AttributeLists where each one is associated with a ch...
SG::ReadCondHandle
Definition: ReadCondHandle.h:44
InDet::TTrainedNetworkCondAlg::m_layerPrefix
Gaudi::Property< std::string > m_layerPrefix
Definition: TTrainedNetworkCondAlg.h:73
InDet::TTrainedNetworkCondAlg::TTrainedNetworkCondAlg
TTrainedNetworkCondAlg(const std::string &name, ISvcLocator *pSvcLocator)
Definition: TTrainedNetworkCondAlg.cxx:22
InDet::TTrainedNetworkCondAlg::retrieveNetwork
TTrainedNetwork * retrieveNetwork(TFile &input_file, const std::string &folder) const
Definition: TTrainedNetworkCondAlg.cxx:58
InDet
DUMMY Primary Vertex Finder.
Definition: VP1ErrorUtils.h:36
InDet::TTrainedNetworkCondAlg::m_poolsvc
ServiceHandle< IPoolSvc > m_poolsvc
Definition: TTrainedNetworkCondAlg.h:45
IFileCatalog.h
python.resample_meson.input_file
input_file
Definition: resample_meson.py:164
InDet::TTrainedNetworkCondAlg::m_networkToHistoTool
ToolHandle< Trk::NeuralNetworkToHistoTool > m_networkToHistoTool
Definition: TTrainedNetworkCondAlg.h:47
ATH_MSG_VERBOSE
#define ATH_MSG_VERBOSE(x)
Definition: AthMsgStreamMacros.h:28
SG::VarHandleKey::key
const std::string & key() const
Return the StoreGate ID for the referenced object.
Definition: AthToolSupport/AsgDataHandles/Root/VarHandleKey.cxx:141
CondAttrListCollection
This class is a collection of AttributeLists where each one is associated with a channel number....
Definition: CondAttrListCollection.h:52
AthReentrantAlgorithm
An algorithm that can be simultaneously executed in multiple threads.
Definition: AthReentrantAlgorithm.h:83
InDet::TTrainedNetworkCondAlg::m_writeKey
SG::WriteCondHandleKey< TTrainedNetworkCollection > m_writeKey
Definition: TTrainedNetworkCondAlg.h:53
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
InDet::TTrainedNetworkCondAlg::m_thresholdIndicator
Gaudi::Property< std::string > m_thresholdIndicator
Definition: TTrainedNetworkCondAlg.h:79
lumiFormat.i
int i
Definition: lumiFormat.py:92
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
InDet::TTrainedNetworkCondAlg::initialize
StatusCode initialize() override final
Definition: TTrainedNetworkCondAlg.cxx:26
InDet::TTrainedNetworkCondAlg::m_getInputsInfo
Gaudi::Property< bool > m_getInputsInfo
Definition: TTrainedNetworkCondAlg.h:82
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
InDet::TTrainedNetworkCondAlg::m_readKey
SG::ReadCondHandleKey< CondAttrListCollection > m_readKey
Definition: TTrainedNetworkCondAlg.h:50
TTrainedNetwork
Definition: InnerDetector/InDetCalibAlgs/PixelCalibAlgs/NNClusteringCalibration_RunI/TTrainedNetwork.h:21
TTrainedNetworkCondAlg.h
InDet::TTrainedNetworkCondAlg::finalize
StatusCode finalize() override final
Definition: TTrainedNetworkCondAlg.cxx:36
InDet::TTrainedNetworkCondAlg::m_layerInfoHistogram
Gaudi::Property< std::string > m_layerInfoHistogram
Definition: TTrainedNetworkCondAlg.h:70
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:192
InDet::TTrainedNetworkCondAlg::m_nnOrder
Gaudi::Property< std::vector< std::string > > m_nnOrder
Definition: TTrainedNetworkCondAlg.h:56
InDet::TTrainedNetworkCondAlg::execute
StatusCode execute(const EventContext &ctx) const override final
Definition: TTrainedNetworkCondAlg.cxx:104
SG::CondHandleKey::initialize
StatusCode initialize(bool used=true)
h
CondAttrListCollection::const_iterator
ChanAttrListMap::const_iterator const_iterator
Definition: CondAttrListCollection.h:63
CaloCondBlobAlgs_fillNoiseFromASCII.folder
folder
Definition: CaloCondBlobAlgs_fillNoiseFromASCII.py:56
python.PyAthena.obj
obj
Definition: PyAthena.py:135
SG::WriteCondHandle
Definition: WriteCondHandle.h:26
python.AutoConfigFlags.msg
msg
Definition: AutoConfigFlags.py:7