12 #include "CoolKernel/IObject.h"
16 #include "lwtnn/parse_json.hh"
17 #include "lwtnn/Exceptions.hh"
18 #include "lwtnn/lightweight_nn_streamers.hh"
19 #include "lwtnn/NanReplacer.hh"
22 #define BOOST_BIND_GLOBAL_PLACEHOLDERS // Needed to silence Boost pragma message
23 #include <boost/property_tree/ptree.hpp>
24 #include <boost/property_tree/json_parser.hpp>
25 #include "boost/property_tree/exceptions.hpp"
45 return StatusCode::SUCCESS;
50 return StatusCode::SUCCESS;
54 const std::string& thisJson)
const {
// NOTE(review): only a sparse subset of this function's lines is visible in
// this listing; the statements between the numbered fragments are not shown.
// Presumably this is configureLwtnn (it matches the call made with a JSON
// string elsewhere in this file) -- confirm against the full source.

// Expose the JSON text passed in as `thisJson` through an input string stream.
59 std::istringstream input_cfg( thisJson );
61 }
// A Boost property-tree error raised by the preceding block ends the call
// with FAILURE.
catch (boost::property_tree::ptree_error&
err) {
63 return StatusCode::FAILURE;
// Build the FastGraph from `config` and `order`, passing "merge_1" as the third
// constructor argument (presumably the output node name -- confirm against
// lwtnn), and store it in `thisNN`.
74 thisNN = std::make_unique<lwt::atlas::FastGraph>(
config,
order,
"merge_1");
75 }
// An lwtnn configuration exception likewise ends the call with FAILURE.
catch (lwt::NNConfigurationException& exc) {
77 return StatusCode::FAILURE;
// Neither handler returned: report success.
80 return StatusCode::SUCCESS;
// NOTE(review): only a sparse subset of this function's lines is visible in
// this listing (its header and several statements are missing), so the
// comments below describe the visible statements only.

// Nothing to do when the write handle already reports a valid object.
87 if (NnWriteHandle.isValid()) {
88 ATH_MSG_DEBUG(
"Write CondHandle "<< NnWriteHandle.fullKey() <<
" is already valid");
89 return StatusCode::SUCCESS;
// An invalid read handle is treated as a failure.
93 if(!readHandle.isValid()) {
95 return StatusCode::FAILURE;
// Debug-build check only: assert() compiles away when NDEBUG is defined.
98 assert( atrcol !=
nullptr);
// The "NNConfigurations" attribute must not be null.
105 if ((attrList[
"NNConfigurations"]).isNull()) {
107 return StatusCode::FAILURE;
// Copy the attribute payload, read as cool::String16M, into a std::string.
112 const std::string megajson = attrList[
"NNConfigurations"].data<cool::String16M>();
// Parse the combined JSON text into `parentTree`. `configStream` is reused
// below to re-serialise each subtree before it is handed to configureLwtnn.
// NOTE(review): read_json throws on malformed input and no handler is visible
// in this excerpt -- confirm whether one exists in the full source.
115 std::istringstream initializerStream(megajson);
116 namespace pt = boost::property_tree;
118 pt::read_json(initializerStream, parentTree);
119 std::ostringstream configStream;
// Validity range taken from the read handle; it is passed to record() at the
// end. Failing to obtain it is an error.
122 EventIDRange cdo_iov;
123 if(!readHandle.range(cdo_iov)) {
124 ATH_MSG_ERROR(
"Failed to get valid validity range from " << readHandle.key());
125 return StatusCode::FAILURE;
// Collection that is filled below and then moved into the write handle.
130 std::unique_ptr<LWTNNCollection> writeCdo{std::make_unique<LWTNNCollection>()};
// Subtree stored under the "NumberNetwork" key.
// NOTE(review): get_child throws if the key is absent; no handler is visible
// in this excerpt.
133 pt::ptree subtreeNumberNetwork = parentTree.get_child(
"NumberNetwork");
// Create entry 0 holding a null FastGraph pointer; that entry is passed to
// configureLwtnn below.
134 writeCdo->insert(std::make_pair(0,std::unique_ptr<lwt::atlas::FastGraph>(
nullptr)));
// An empty subtree is reported as an error.
136 if(subtreeNumberNetwork.empty()) {
137 ATH_MSG_ERROR(
"You are trying to use lwtnn for the number network but have an empty configuration file; this should never happen!");
138 return StatusCode::FAILURE;
// Serialise the subtree back to JSON text and configure entry 0 from it.
143 pt::write_json(configStream, subtreeNumberNetwork);
144 std::string numberNetworkConfig = configStream.str();
145 if ((
configureLwtnn(writeCdo->at(0), numberNetworkConfig)).isFailure())
146 return StatusCode::FAILURE;
// Entries 1, 2 and 3 are filled from the subtree found under `key`, which is
// defined on a line not visible in this excerpt.
152 for (
int i=1;
i<4;
i++) {
// Reset the reused stream so it holds only the current subtree.
154 configStream.str(
"");
155 pt::ptree subtreePosNetwork = parentTree.get_child(
key);
156 pt::write_json(configStream, subtreePosNetwork);
157 std::string posNetworkConfig = configStream.str();
// Create entry i holding a null FastGraph pointer.
160 writeCdo->insert(std::make_pair(
i,std::unique_ptr<lwt::atlas::FastGraph>(
nullptr)));
// An empty subtree is reported as an error (checked after serialisation).
163 if(subtreePosNetwork.empty()) {
164 ATH_MSG_ERROR(
"You are trying to use lwtnn for the position networks but have an empty configuration file; this should never happen!");
165 return StatusCode::FAILURE;
168 ATH_MSG_DEBUG(
"Setting up lwtnn for n = " <<
i <<
" position network...");
// NOTE(review): the statement guarding this return is not visible in this
// excerpt; presumably it is the configureLwtnn call for entry i -- confirm.
170 return StatusCode::FAILURE;
// Hand the filled collection to the write handle together with the validity
// range; a failed record() is reported and returned as FAILURE.
177 if(NnWriteHandle.record(cdo_iov,std::move(writeCdo)).isFailure()) {
178 ATH_MSG_ERROR(
"Failed to record Trained network collection to "
179 << NnWriteHandle.key()
180 <<
" with IOV " << cdo_iov );
181 return StatusCode::FAILURE;
184 return StatusCode::SUCCESS;