ATLAS Offline Software
LWTNNCondAlg.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 /*
5  * */
6 
7 #include "LWTNNCondAlg.h"
9 
12 #include "CoolKernel/IObject.h"
14 
15 // NN includes
16 #include "lwtnn/parse_json.hh"
17 #include "lwtnn/Exceptions.hh"
18 #include "lwtnn/lightweight_nn_streamers.hh"
19 #include "lwtnn/NanReplacer.hh"
20 
21 // JSON parsers
22 #define BOOST_BIND_GLOBAL_PLACEHOLDERS // Needed to silence Boost pragma message
23 #include <boost/property_tree/ptree.hpp>
24 #include <boost/property_tree/json_parser.hpp>
25 #include "boost/property_tree/exceptions.hpp"
26 
27 
28 // for error messages
29 #include <memory>
30 
31 #include <typeinfo>
32 
33 namespace InDet {
34 
35  LWTNNCondAlg::LWTNNCondAlg (const std::string& name, ISvcLocator* pSvcLocator)
36  : ::AthReentrantAlgorithm( name, pSvcLocator )
37  {}
38 
40 
41  // Condition Handles
44 
45  return StatusCode::SUCCESS;
46  }
47 
49  {
50  return StatusCode::SUCCESS;
51  }
52 
53  StatusCode LWTNNCondAlg::configureLwtnn(std::unique_ptr<lwt::atlas::FastGraph> & thisNN,
54  const std::string& thisJson) const {
55 
56  // Read DNN weights from input json config
57  lwt::GraphConfig config;
58  try {
59  std::istringstream input_cfg( thisJson );
60  config = lwt::parse_json_graph(input_cfg);
61  } catch (boost::property_tree::ptree_error& err) {
62  ATH_MSG_ERROR("NN file unreadable!");
63  return StatusCode::FAILURE;
64  }
65 
66  // pass the input order for the FastGraph
68  order.scalar.emplace_back("NNinputs", m_variableOrder );
69  // sequence not needed for NN (more for RNN, but set anyway)
70  order.sequence.emplace_back("NNinputs", m_variableOrder );
71 
72  // Build the network
73  try {
74  thisNN = std::make_unique<lwt::atlas::FastGraph>(config, order, "merge_1");
75  } catch (lwt::NNConfigurationException& exc) {
76  ATH_MSG_ERROR("NN configuration problem: " << exc.what());
77  return StatusCode::FAILURE;
78  }
79 
80  return StatusCode::SUCCESS;
81 
82  }
83 
84  StatusCode LWTNNCondAlg::execute(const EventContext& ctx) const {
85 
87  if (NnWriteHandle.isValid()) {
88  ATH_MSG_DEBUG("Write CondHandle "<< NnWriteHandle.fullKey() << " is already valid");
89  return StatusCode::SUCCESS;
90  }
91 
93  if(!readHandle.isValid()) {
94  ATH_MSG_ERROR("Invalid read handle " << m_readKey.key());
95  return StatusCode::FAILURE;
96  }
97  const CondAttrListCollection* atrcol{*readHandle};
98  assert( atrcol != nullptr);
99 
100  // So now we have the string containing the json. Access it.
101  // Retrieve channel 0 (only channel there is)
102  const coral::AttributeList& attrList=atrcol->attributeList(0);
103 
104  // Check that it is filled as expected
105  if ((attrList["NNConfigurations"]).isNull()) {
106  ATH_MSG_ERROR( "NNConfigurations is NULL !" );
107  return StatusCode::FAILURE;
108  }
109 
110  // Retrieve the string
111  // This is for a single LOB when it is all a giant block
112  const std::string megajson = attrList["NNConfigurations"].data<cool::String16M>();
113 
114  // Parse the large json to extract the individual configurations for the NNs
115  std::istringstream initializerStream(megajson);
116  namespace pt = boost::property_tree;
117  pt::ptree parentTree;
118  pt::read_json(initializerStream, parentTree);
119  std::ostringstream configStream;
120 
121  // This is for handling IOVs
122  EventIDRange cdo_iov;
123  if(!readHandle.range(cdo_iov)) {
124  ATH_MSG_ERROR("Failed to get valid validity range from " << readHandle.key());
125  return StatusCode::FAILURE;
126  }
127 
128  // Here I create a pointer to the object I want to write
129  // And what I want to write is the map with my lwtnn networks.
130  std::unique_ptr<LWTNNCollection> writeCdo{std::make_unique<LWTNNCollection>()};
131 
132  // First, extract configuration for the number network.
133  pt::ptree subtreeNumberNetwork = parentTree.get_child("NumberNetwork");
134  writeCdo->insert(std::make_pair(0,std::unique_ptr<lwt::atlas::FastGraph>(nullptr)));
135  // If this json is empty, just fill a null pointer.
136  if(subtreeNumberNetwork.empty()) {
137  ATH_MSG_ERROR("You are trying to use lwtnn for the number network but have an empty configuration file; this should never happen!");
138  return StatusCode::FAILURE;
139  }
140  // Otherwise, set up lwtnn.
141  else {
142  ATH_MSG_DEBUG("Setting up lwtnn for number network...");
143  pt::write_json(configStream, subtreeNumberNetwork);
144  std::string numberNetworkConfig = configStream.str();
145  if ((configureLwtnn(writeCdo->at(0), numberNetworkConfig)).isFailure())
146  return StatusCode::FAILURE;
147  }
148 
149  // Now extract configuration for each position network.
150  // For simplicity, we'll require all three configurations
151  // in order to use lwtnn for positions.
152  for (int i=1; i<4; i++) {
153  const std::string key = "PositionNetwork_N"+std::to_string(i);
154  configStream.str("");
155  pt::ptree subtreePosNetwork = parentTree.get_child(key);
156  pt::write_json(configStream, subtreePosNetwork);
157  std::string posNetworkConfig = configStream.str();
158 
159  // Put a lwt network into the map
160  writeCdo->insert(std::make_pair(i,std::unique_ptr<lwt::atlas::FastGraph>(nullptr)));
161 
162  // Now do empty check: if any one of these is empty we won't use lwtnn
163  if(subtreePosNetwork.empty()) {
164  ATH_MSG_ERROR("You are trying to use lwtnn for the position networks but have an empty configuration file; this should never happen!");
165  return StatusCode::FAILURE;
166  } else {
167  // Otherwise, set up lwtnn
168  ATH_MSG_DEBUG("Setting up lwtnn for n = " << i << " position network...");
169  if ((configureLwtnn(writeCdo->at(i), posNetworkConfig)).isFailure())
170  return StatusCode::FAILURE;
171  }
172 
173  }
174 
175  // Write the networks to the store
176 
177  if(NnWriteHandle.record(cdo_iov,std::move(writeCdo)).isFailure()) {
178  ATH_MSG_ERROR("Failed to record Trained network collection to "
179  << NnWriteHandle.key()
180  << " with IOV " << cdo_iov );
181  return StatusCode::FAILURE;
182  }
183 
184  return StatusCode::SUCCESS;
185  }
186 
187 }
CondAttrListCollection.h
This file defines the class for a collection of AttributeLists where each one is associated with a ch...
SG::ReadCondHandle
Definition: ReadCondHandle.h:44
InDet::LWTNNCondAlg::m_readKey
SG::ReadCondHandleKey< CondAttrListCollection > m_readKey
Definition: LWTNNCondAlg.h:50
InDet
Primary Vertex Finder.
Definition: VP1ErrorUtils.h:36
python.base_data.config
config
Definition: base_data.py:21
IFileCatalog.h
test_pyathena.pt
pt
Definition: test_pyathena.py:11
python.PyKernel.AttributeList
AttributeList
Definition: PyKernel.py:36
SG::VarHandleKey::key
const std::string & key() const
Return the StoreGate ID for the referenced object.
Definition: AthToolSupport/AsgDataHandles/Root/VarHandleKey.cxx:141
CondAttrListCollection
This class is a collection of AttributeLists where each one is associated with a channel number....
Definition: CondAttrListCollection.h:52
AthenaAttributeList.h
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
AthReentrantAlgorithm
An algorithm that can be simultaneously executed in multiple threads.
Definition: AthReentrantAlgorithm.h:83
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
dqt_zlumi_pandas.err
err
Definition: dqt_zlumi_pandas.py:182
lumiFormat.i
int i
Definition: lumiFormat.py:85
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
mc.order
order
Configure Herwig7.
Definition: mc.Herwig7_Dijet.py:12
lwt::atlas::InputOrder
Definition: InputOrder.h:28
InDet::LWTNNCondAlg::execute
StatusCode execute(const EventContext &ctx) const override final
Definition: LWTNNCondAlg.cxx:84
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
InDet::LWTNNCondAlg::LWTNNCondAlg
LWTNNCondAlg(const std::string &name, ISvcLocator *pSvcLocator)
Definition: LWTNNCondAlg.cxx:35
InDet::LWTNNCondAlg::configureLwtnn
StatusCode configureLwtnn(std::unique_ptr< lwt::atlas::FastGraph > &thisNN, const std::string &thisJson) const
Definition: LWTNNCondAlg.cxx:53
ptree
boost::property_tree::ptree ptree
Definition: JsonFileLoader.cxx:16
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:228
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
InDet::LWTNNCondAlg::finalize
StatusCode finalize() override final
Definition: LWTNNCondAlg.cxx:48
SG::CondHandleKey::initialize
StatusCode initialize(bool used=true)
InDet::LWTNNCondAlg::m_variableOrder
Gaudi::Property< std::vector< std::string > > m_variableOrder
Definition: LWTNNCondAlg.h:58
InputOrder.h
lwtDev::parse_json_graph
GraphConfig parse_json_graph(std::istream &json)
Definition: parse_json.cxx:71
SG::WriteCondHandle
Definition: WriteCondHandle.h:26
InDet::LWTNNCondAlg::m_writeKey
SG::WriteCondHandleKey< LWTNNCollection > m_writeKey
Definition: LWTNNCondAlg.h:53
InDet::LWTNNCondAlg::initialize
StatusCode initialize() override final
Definition: LWTNNCondAlg.cxx:39
LWTNNCondAlg.h
mapkey::key
key
Definition: TElectronEfficiencyCorrectionTool.cxx:37