ATLAS Offline Software
Loading...
Searching...
No Matches
LWTNNCondAlg.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3*/
4/*
5 * */
6
7#include "LWTNNCondAlg.h"
9
12#include "CoolKernel/IObject.h"
13
14// NN includes
15#include "lwtnn/parse_json.hh"
16#include "lwtnn/Exceptions.hh"
17#include "lwtnn/lightweight_nn_streamers.hh"
18#include "lwtnn/NanReplacer.hh"
19
20// JSON parsers
21#define BOOST_BIND_GLOBAL_PLACEHOLDERS // Needed to silence Boost pragma message
22#include <boost/property_tree/ptree.hpp>
23#include <boost/property_tree/json_parser.hpp>
24#include "boost/property_tree/exceptions.hpp"
25
26
27// for error messages
28#include <memory>
29
30#include <typeinfo>
31
32namespace InDet {
33
34 LWTNNCondAlg::LWTNNCondAlg (const std::string& name, ISvcLocator* pSvcLocator)
35 : ::AthCondAlgorithm( name, pSvcLocator )
36 {}
37
39
40 // Condition Handles
41 ATH_CHECK( m_readKey.initialize() );
42 ATH_CHECK( m_writeKey.initialize() );
43
44 return StatusCode::SUCCESS;
45 }
46
48 {
49 return StatusCode::SUCCESS;
50 }
51
52 StatusCode LWTNNCondAlg::configureLwtnn(std::unique_ptr<lwt::atlas::FastGraph> & thisNN,
53 const std::string& thisJson) const {
54
55 // Read DNN weights from input json config
56 lwt::GraphConfig config;
57 try {
58 std::istringstream input_cfg( thisJson );
59 config = lwt::parse_json_graph(input_cfg);
60 } catch (boost::property_tree::ptree_error& err) {
61 ATH_MSG_ERROR("NN file unreadable!");
62 return StatusCode::FAILURE;
63 }
64
65 // pass the input order for the FastGraph
67 order.scalar.emplace_back("NNinputs", m_variableOrder );
68 // sequence not needed for NN (more for RNN, but set anyway)
69 order.sequence.emplace_back("NNinputs", m_variableOrder );
70
71 // Build the network
72 try {
73 thisNN = std::make_unique<lwt::atlas::FastGraph>(config, order, "merge_1");
74 } catch (lwt::NNConfigurationException& exc) {
75 ATH_MSG_ERROR("NN configuration problem: " << exc.what());
76 return StatusCode::FAILURE;
77 }
78
79 return StatusCode::SUCCESS;
80
81 }
82
83 StatusCode LWTNNCondAlg::execute(const EventContext& ctx) const {
84
86 if (NnWriteHandle.isValid()) {
87 ATH_MSG_DEBUG("Write CondHandle "<< NnWriteHandle.fullKey() << " is already valid");
88 return StatusCode::SUCCESS;
89 }
90
92 if(!readHandle.isValid()) {
93 ATH_MSG_ERROR("Invalid read handle " << m_readKey.key());
94 return StatusCode::FAILURE;
95 }
96 const CondAttrListCollection* atrcol{*readHandle};
97 assert( atrcol != nullptr);
98
99 // So now we have the string containing the json. Access it.
100 // Retrieve channel 0 (only channel there is)
101 const coral::AttributeList& attrList=atrcol->attributeList(0);
102
103 // Check that it is filled as expected
104 if ((attrList["NNConfigurations"]).isNull()) {
105 ATH_MSG_ERROR( "NNConfigurations is NULL !" );
106 return StatusCode::FAILURE;
107 }
108
109 // Retrieve the string
110 // This is for a single LOB when it is all a giant block
111 const std::string megajson = attrList["NNConfigurations"].data<cool::String16M>();
112
113 // Parse the large json to extract the individual configurations for the NNs
114 std::istringstream initializerStream(megajson);
115 namespace pt = boost::property_tree;
116 pt::ptree parentTree;
117 pt::read_json(initializerStream, parentTree);
118 std::ostringstream configStream;
119
120 // This is for handling IOVs
121 EventIDRange cdo_iov;
122 if(!readHandle.range(cdo_iov)) {
123 ATH_MSG_ERROR("Failed to get valid validity range from " << readHandle.key());
124 return StatusCode::FAILURE;
125 }
126
127 // Here I create a pointer to the object I want to write
128 // And what I want to write is the map with my lwtnn networks.
129 std::unique_ptr<LWTNNCollection> writeCdo{std::make_unique<LWTNNCollection>()};
130
131 // First, extract configuration for the number network.
132 pt::ptree subtreeNumberNetwork = parentTree.get_child("NumberNetwork");
133 writeCdo->insert(std::make_pair(0,std::unique_ptr<lwt::atlas::FastGraph>(nullptr)));
134 // If this json is empty, just fill a null pointer.
135 if(subtreeNumberNetwork.empty()) {
136 ATH_MSG_ERROR("You are trying to use lwtnn for the number network but have an empty configuration file; this should never happen!");
137 return StatusCode::FAILURE;
138 }
139 // Otherwise, set up lwtnn.
140 else {
141 ATH_MSG_DEBUG("Setting up lwtnn for number network...");
142 pt::write_json(configStream, subtreeNumberNetwork);
143 std::string numberNetworkConfig = configStream.str();
144 if ((configureLwtnn(writeCdo->at(0), numberNetworkConfig)).isFailure())
145 return StatusCode::FAILURE;
146 }
147
148 // Now extract configuration for each position network.
149 // For simplicity, we'll require all three configurations
150 // in order to use lwtnn for positions.
151 for (int i=1; i<4; i++) {
152 const std::string key = "PositionNetwork_N"+std::to_string(i);
153 configStream.str("");
154 pt::ptree subtreePosNetwork = parentTree.get_child(key);
155 pt::write_json(configStream, subtreePosNetwork);
156 std::string posNetworkConfig = configStream.str();
157
158 // Put a lwt network into the map
159 writeCdo->insert(std::make_pair(i,std::unique_ptr<lwt::atlas::FastGraph>(nullptr)));
160
161 // Now do empty check: if any one of these is empty we won't use lwtnn
162 if(subtreePosNetwork.empty()) {
163 ATH_MSG_ERROR("You are trying to use lwtnn for the position networks but have an empty configuration file; this should never happen!");
164 return StatusCode::FAILURE;
165 } else {
166 // Otherwise, set up lwtnn
167 ATH_MSG_DEBUG("Setting up lwtnn for n = " << i << " position network...");
168 if ((configureLwtnn(writeCdo->at(i), posNetworkConfig)).isFailure())
169 return StatusCode::FAILURE;
170 }
171
172 }
173
174 // Write the networks to the store
175
176 if(NnWriteHandle.record(cdo_iov,std::move(writeCdo)).isFailure()) {
177 ATH_MSG_ERROR("Failed to record Trained network collection to "
178 << NnWriteHandle.key()
179 << " with IOV " << cdo_iov );
180 return StatusCode::FAILURE;
181 }
182
183 return StatusCode::SUCCESS;
184 }
185
186}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_DEBUG(x)
This file defines the class for a collection of AttributeLists where each one is associated with a channel number.
Base class for conditions algorithms.
This class is a collection of AttributeLists where each one is associated with a channel number.
const AttributeList & attributeList(ChanNum chanNum) const
attribute list for a given channel number
LWTNNCondAlg(const std::string &name, ISvcLocator *pSvcLocator)
StatusCode initialize() override final
StatusCode execute(const EventContext &ctx) const override final
Gaudi::Property< std::vector< std::string > > m_variableOrder
SG::ReadCondHandleKey< CondAttrListCollection > m_readKey
StatusCode finalize() override final
StatusCode configureLwtnn(std::unique_ptr< lwt::atlas::FastGraph > &thisNN, const std::string &thisJson) const
SG::WriteCondHandleKey< LWTNNCollection > m_writeKey
bool range(EventIDRange &r)
const std::string & key() const
const std::string & key() const
StatusCode record(const EventIDRange &range, T *t)
record handle, with explicit range DEPRECATED
const DataObjID & fullKey() const
Primary Vertex Finder.