ATLAS Offline Software
Loading...
Searching...
No Matches
LWTNNCondAlg.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3*/
4/*
5 * */
6
7#include "LWTNNCondAlg.h"
9
12#include "CoolKernel/IObject.h"
14
15// NN includes
16#include "lwtnn/parse_json.hh"
17#include "lwtnn/Exceptions.hh"
18#include "lwtnn/lightweight_nn_streamers.hh"
19#include "lwtnn/NanReplacer.hh"
20
21// JSON parsers
22#define BOOST_BIND_GLOBAL_PLACEHOLDERS // Needed to silence Boost pragma message
23#include <boost/property_tree/ptree.hpp>
24#include <boost/property_tree/json_parser.hpp>
25#include "boost/property_tree/exceptions.hpp"
26
27
28// for error messages
29#include <memory>
30
31#include <typeinfo>
32
33namespace InDet {
34
35 LWTNNCondAlg::LWTNNCondAlg (const std::string& name, ISvcLocator* pSvcLocator)
36 : ::AthCondAlgorithm( name, pSvcLocator )
37 {}
38
40
41 // Condition Handles
42 ATH_CHECK( m_readKey.initialize() );
43 ATH_CHECK( m_writeKey.initialize() );
44
45 return StatusCode::SUCCESS;
46 }
47
49 {
50 return StatusCode::SUCCESS;
51 }
52
53 StatusCode LWTNNCondAlg::configureLwtnn(std::unique_ptr<lwt::atlas::FastGraph> & thisNN,
54 const std::string& thisJson) const {
55
56 // Read DNN weights from input json config
57 lwt::GraphConfig config;
58 try {
59 std::istringstream input_cfg( thisJson );
60 config = lwt::parse_json_graph(input_cfg);
61 } catch (boost::property_tree::ptree_error& err) {
62 ATH_MSG_ERROR("NN file unreadable!");
63 return StatusCode::FAILURE;
64 }
65
66 // pass the input order for the FastGraph
68 order.scalar.emplace_back("NNinputs", m_variableOrder );
69 // sequence not needed for NN (more for RNN, but set anyway)
70 order.sequence.emplace_back("NNinputs", m_variableOrder );
71
72 // Build the network
73 try {
74 thisNN = std::make_unique<lwt::atlas::FastGraph>(config, order, "merge_1");
75 } catch (lwt::NNConfigurationException& exc) {
76 ATH_MSG_ERROR("NN configuration problem: " << exc.what());
77 return StatusCode::FAILURE;
78 }
79
80 return StatusCode::SUCCESS;
81
82 }
83
84 StatusCode LWTNNCondAlg::execute(const EventContext& ctx) const {
85
87 if (NnWriteHandle.isValid()) {
88 ATH_MSG_DEBUG("Write CondHandle "<< NnWriteHandle.fullKey() << " is already valid");
89 return StatusCode::SUCCESS;
90 }
91
93 if(!readHandle.isValid()) {
94 ATH_MSG_ERROR("Invalid read handle " << m_readKey.key());
95 return StatusCode::FAILURE;
96 }
97 const CondAttrListCollection* atrcol{*readHandle};
98 assert( atrcol != nullptr);
99
100 // So now we have the string containing the json. Access it.
101 // Retrieve channel 0 (only channel there is)
102 const coral::AttributeList& attrList=atrcol->attributeList(0);
103
104 // Check that it is filled as expected
105 if ((attrList["NNConfigurations"]).isNull()) {
106 ATH_MSG_ERROR( "NNConfigurations is NULL !" );
107 return StatusCode::FAILURE;
108 }
109
110 // Retrieve the string
111 // This is for a single LOB when it is all a giant block
112 const std::string megajson = attrList["NNConfigurations"].data<cool::String16M>();
113
114 // Parse the large json to extract the individual configurations for the NNs
115 std::istringstream initializerStream(megajson);
116 namespace pt = boost::property_tree;
117 pt::ptree parentTree;
118 pt::read_json(initializerStream, parentTree);
119 std::ostringstream configStream;
120
121 // This is for handling IOVs
122 EventIDRange cdo_iov;
123 if(!readHandle.range(cdo_iov)) {
124 ATH_MSG_ERROR("Failed to get valid validity range from " << readHandle.key());
125 return StatusCode::FAILURE;
126 }
127
128 // Here I create a pointer to the object I want to write
129 // And what I want to write is the map with my lwtnn networks.
130 std::unique_ptr<LWTNNCollection> writeCdo{std::make_unique<LWTNNCollection>()};
131
132 // First, extract configuration for the number network.
133 pt::ptree subtreeNumberNetwork = parentTree.get_child("NumberNetwork");
134 writeCdo->insert(std::make_pair(0,std::unique_ptr<lwt::atlas::FastGraph>(nullptr)));
135 // If this json is empty, just fill a null pointer.
136 if(subtreeNumberNetwork.empty()) {
137 ATH_MSG_ERROR("You are trying to use lwtnn for the number network but have an empty configuration file; this should never happen!");
138 return StatusCode::FAILURE;
139 }
140 // Otherwise, set up lwtnn.
141 else {
142 ATH_MSG_DEBUG("Setting up lwtnn for number network...");
143 pt::write_json(configStream, subtreeNumberNetwork);
144 std::string numberNetworkConfig = configStream.str();
145 if ((configureLwtnn(writeCdo->at(0), numberNetworkConfig)).isFailure())
146 return StatusCode::FAILURE;
147 }
148
149 // Now extract configuration for each position network.
150 // For simplicity, we'll require all three configurations
151 // in order to use lwtnn for positions.
152 for (int i=1; i<4; i++) {
153 const std::string key = "PositionNetwork_N"+std::to_string(i);
154 configStream.str("");
155 pt::ptree subtreePosNetwork = parentTree.get_child(key);
156 pt::write_json(configStream, subtreePosNetwork);
157 std::string posNetworkConfig = configStream.str();
158
159 // Put a lwt network into the map
160 writeCdo->insert(std::make_pair(i,std::unique_ptr<lwt::atlas::FastGraph>(nullptr)));
161
162 // Now do empty check: if any one of these is empty we won't use lwtnn
163 if(subtreePosNetwork.empty()) {
164 ATH_MSG_ERROR("You are trying to use lwtnn for the position networks but have an empty configuration file; this should never happen!");
165 return StatusCode::FAILURE;
166 } else {
167 // Otherwise, set up lwtnn
168 ATH_MSG_DEBUG("Setting up lwtnn for n = " << i << " position network...");
169 if ((configureLwtnn(writeCdo->at(i), posNetworkConfig)).isFailure())
170 return StatusCode::FAILURE;
171 }
172
173 }
174
175 // Write the networks to the store
176
177 if(NnWriteHandle.record(cdo_iov,std::move(writeCdo)).isFailure()) {
178 ATH_MSG_ERROR("Failed to record Trained network collection to "
179 << NnWriteHandle.key()
180 << " with IOV " << cdo_iov );
181 return StatusCode::FAILURE;
182 }
183
184 return StatusCode::SUCCESS;
185 }
186
187}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_DEBUG(x)
This file defines the class for a collection of AttributeLists where each one is associated with a ch...
Base class for conditions algorithms.
This class is a collection of AttributeLists where each one is associated with a channel number.
const AttributeList & attributeList(ChanNum chanNum) const
attribute list for a given channel number
LWTNNCondAlg(const std::string &name, ISvcLocator *pSvcLocator)
StatusCode initialize() override final
StatusCode execute(const EventContext &ctx) const override final
Gaudi::Property< std::vector< std::string > > m_variableOrder
SG::ReadCondHandleKey< CondAttrListCollection > m_readKey
StatusCode finalize() override final
StatusCode configureLwtnn(std::unique_ptr< lwt::atlas::FastGraph > &thisNN, const std::string &thisJson) const
SG::WriteCondHandleKey< LWTNNCollection > m_writeKey
bool range(EventIDRange &r)
const std::string & key() const
const std::string & key() const
StatusCode record(const EventIDRange &range, T *t)
record handle, with explicit range DEPRECATED
const DataObjID & fullKey() const
Primary Vertex Finder.