14 #include "lwtnn/parse_json.hh"
17 #include <libxml/xmlmemory.h>
18 #include <libxml/parser.h>
19 #include <libxml/tree.h>
20 #include <libxml/xmlreader.h>
21 #include <libxml/xpath.h>
22 #include <libxml/xpathInternals.h>
// Evaluates the punch-through classifier for one particle:
//   1. build the raw network inputs from the particle (isfp) and the
//      simulation state (computeInputs),
//   2. rescale them with the loaded input scaler (scaleInputs),
//   3. run the network graph (m_graph->compute; presumably an lwtnn graph,
//      given the lwtnn include — confirm),
//   4. map the raw "out_0" score through the output calibration (calibrateOutput).
// Returns the calibrated value.
34 std::map<std::string, std::map<std::string, double> > networkInputs = computeInputs(isfp, simulstate);
36 networkInputs = scaleInputs(networkInputs);
38 std::map<std::string, double> networkOutputs = m_graph->compute(networkInputs);
// NOTE(review): std::map::operator[] inserts a 0.0 entry if the graph produced
// no "out_0" output, so a misnamed output would silently be calibrated as 0.0;
// networkOutputs.at("out_0") would fail loudly instead.
40 double calibratedOutput = calibrateOutput(networkOutputs[
"out_0"]);
42 return calibratedOutput;
// Loads the three components of the classifier from their resolved file names:
// the input scaler, the network and the output calibrator. An error is logged
// for each component that fails to load.
// NOTE(review): after each failed check only the ATH_MSG_ERROR is visible here,
// and the final return is StatusCode::SUCCESS. Confirm that a failed load really
// returns StatusCode::FAILURE instead of continuing with a partially loaded tool.
51 if ( initializeScaler(resolvedScalerFileName) != StatusCode::SUCCESS)
53 ATH_MSG_ERROR(
"[ punchthroughclassifier ] unable to load punchthroughclassifier input scaler");
57 if ( initializeNetwork(resolvedNetworkFileName) != StatusCode::SUCCESS)
59 ATH_MSG_ERROR(
"[ punchthroughclassifier ] unable to load punchthroughclassifier network");
63 if ( initializeCalibrator(resolvedCalibratorFileName) != StatusCode::SUCCESS)
65 ATH_MSG_ERROR(
"[ punchthroughclassifier ] unable to load punchthroughclassifier calibrator");
68 return StatusCode::SUCCESS;
// Reads the min-max input-scaler configuration from an XML file.
// Inside each top-level "Transformations" element it reads:
//   ScalerValues (attributes min, max) -> m_scalerMin / m_scalerMax, the common
//       range that scaleInputs maps every variable onto;
//   VarScales (attributes name, min, max) -> the input range of one variable,
//       stored in m_scalerMinMap / m_scalerMaxMap keyed by the variable name.
// NOTE(review): xmlParseFile returns nullptr if the file is missing or malformed,
// but doc->children is dereferenced below without a check. The document is also
// never released with xmlFreeDoc in the lines visible here.
75 xmlDocPtr
doc = xmlParseFile( scalerConfigFile.c_str() );
77 ATH_MSG_INFO(
"[ punchthroughclassifier ] Loading scaler: " << scalerConfigFile);
79 for( xmlNodePtr nodeRoot =
doc->children; nodeRoot !=
nullptr; nodeRoot = nodeRoot->next) {
81 if (xmlStrEqual( nodeRoot->name, BAD_CAST
"Transformations" )) {
82 for( xmlNodePtr nodeTransform = nodeRoot->children; nodeTransform !=
nullptr; nodeTransform = nodeTransform->next ) {
85 if (xmlStrEqual( nodeTransform->name, BAD_CAST
"ScalerValues" )) {
// NOTE(review): xmlGetProp returns a newly allocated string that the caller must
// release with xmlFree, so every call in this function leaks. It also returns
// nullptr when the attribute is absent, and both atof(nullptr) and
// std::string(nullptr) are undefined behaviour.
86 m_scalerMin =
atof( (
const char*) xmlGetProp( nodeTransform, BAD_CAST
"min" ) );
87 m_scalerMax =
atof( (
const char*) xmlGetProp( nodeTransform, BAD_CAST
"max" ) );
91 if (xmlStrEqual( nodeTransform->name, BAD_CAST
"VarScales" )) {
92 std::string
name = (
const char*) xmlGetProp( nodeTransform, BAD_CAST
"name" );
93 double min =
atof( (
const char*) xmlGetProp( nodeTransform, BAD_CAST
"min" ) );
94 double max =
atof( (
const char*) xmlGetProp( nodeTransform, BAD_CAST
"max" ) );
// insert() does not overwrite an existing key: if a VarScales name appears
// twice, the first occurrence wins.
95 m_scalerMinMap.insert ( std::pair<std::string, double>(
name,
min) );
96 m_scalerMaxMap.insert ( std::pair<std::string, double>(
name,
max) );
102 return StatusCode::SUCCESS;
// Loads the network graph description from the JSON file networkConfigFile.
// The lwtnn/parse_json.hh include suggests the file is in lwtnn format; the
// parse call itself is not visible here, so confirm.
// Logs an error and returns StatusCode::FAILURE if the file cannot be found or
// the graph cannot be parsed; otherwise returns StatusCode::SUCCESS.
107 ATH_MSG_INFO(
"[ punchthroughclassifier ] Loading classifier: " << networkConfigFile);
109 std::ifstream
input(networkConfigFile);
// File could not be opened (per the message; the guarding condition is not visible here).
111 ATH_MSG_ERROR(
"Could not find json file " << networkConfigFile );
112 return StatusCode::FAILURE;
// Graph JSON could not be parsed (per the message; the guarding condition is not visible here).
117 ATH_MSG_ERROR(
"Could not parse graph json file " << networkConfigFile );
118 return StatusCode::FAILURE;
122 return StatusCode::SUCCESS;
// Reads the output-calibration configuration from an XML file.
// Inside each top-level "Transformations" element it reads:
//   LimitValues (attributes min, max) -> m_calibrationMin / m_calibrationMax,
//       the range calibrateOutput clamps the raw network score to;
//   LinearNorm (attributes orig, norm) -> one (raw score -> calibrated value)
//       point, added to m_calibrationMap. calibrateOutput interpolates linearly
//       between neighbouring points.
// NOTE(review): calibrateOutput dereferences the neighbours of
// m_calibrationMap.upper_bound(x), so the LinearNorm points must bracket the
// whole LimitValues range: a key <= min and a key > max.
// NOTE(review): xmlParseFile returns nullptr on a missing/malformed file, but
// doc->children is dereferenced without a check. No xmlFreeDoc appears in the
// visible lines.
129 ATH_MSG_INFO(
"[ punchthroughclassifier ] Loading calibrator: " << calibratorConfigFile);
131 xmlDocPtr
doc = xmlParseFile( calibratorConfigFile.c_str() );
133 for( xmlNodePtr nodeRoot =
doc->children; nodeRoot !=
nullptr; nodeRoot = nodeRoot->next) {
135 if (xmlStrEqual( nodeRoot->name, BAD_CAST
"Transformations" )) {
136 for( xmlNodePtr nodeTransform = nodeRoot->children; nodeTransform !=
nullptr; nodeTransform = nodeTransform->next ) {
139 if (xmlStrEqual( nodeTransform->name, BAD_CAST
"LimitValues" )) {
// NOTE(review): each xmlGetProp result is a newly allocated string that must be
// released with xmlFree, so every call in this function leaks. xmlGetProp also
// returns nullptr for a missing attribute, and atof(nullptr) is undefined behaviour.
140 m_calibrationMin =
atof( (
const char*) xmlGetProp( nodeTransform, BAD_CAST
"min" ) );
141 m_calibrationMax =
atof( (
const char*) xmlGetProp( nodeTransform, BAD_CAST
"max" ) );
145 if (xmlStrEqual( nodeTransform->name, BAD_CAST
"LinearNorm" )) {
146 double orig =
atof( (
const char*) xmlGetProp( nodeTransform, BAD_CAST
"orig" ) );
147 double norm =
atof( (
const char*) xmlGetProp( nodeTransform, BAD_CAST
"norm" ) );
// insert() keeps the first point if the same orig value appears twice.
148 m_calibrationMap.insert ( std::pair<double,double>(orig,
norm) );
154 return StatusCode::SUCCESS;
// Builds the raw (unscaled) network inputs for one particle. All inputs go into
// a single input node, "node_0":
//   variable_0 : magnitude of the particle momentum, isfp.momentum().mag()
//   variable_1 : |eta| of the particle position
//   variable_2 : phi of the particle position
//   variable_3 : simulstate.E() (presumably the total simulated energy — confirm)
// These are followed by a loop over 24 indices whose body is not visible here
// (presumably it adds further per-index variables to "node_0" — TODO confirm).
// NOTE(review): 24 is a magic number; a named constant saying what it counts
// would make the input layout self-documenting.
161 std::map<std::string, std::map<std::string, double> > networkInputs;
164 networkInputs[
"node_0"] = {
165 {
"variable_0", isfp.
momentum().mag() },
166 {
"variable_1", std::abs(isfp.
position().eta()) },
167 {
"variable_2", isfp.
position().phi() },
168 {
"variable_3", simulstate.
E()},
// Loop body not visible in this view.
172 for (
unsigned int i = 0;
i < 24;
i++){
176 return networkInputs;
// Min-max scaling of one input variable (var.first = name, var.second = value),
// using the per-variable range read from the scaler file:
//   x_std = (x - min_var) / (max_var - min_var)
//   x     = x_std * (m_scalerMax - m_scalerMin) + m_scalerMin
// .at() throws std::out_of_range if a network input has no VarScales entry.
// NOTE(review): var.second is assigned below, so the enclosing loop (not visible)
// must iterate by reference for the scaling to take effect — confirm.
186 if(m_scalerMaxMap.at(
var.first) != m_scalerMinMap.at(
var.first)){
187 x_std = (
var.second - m_scalerMinMap.at(
var.first)) / (m_scalerMaxMap.at(
var.first) - m_scalerMinMap.at(
var.first));
// Degenerate range (max_var == min_var): skip the division to avoid dividing by
// zero, which amounts to treating the range as 1.
190 x_std = (
var.second - m_scalerMinMap.at(
var.first));
// Map the standardised value onto the common [m_scalerMin, m_scalerMax] range.
192 var.second = x_std * (m_scalerMax - m_scalerMin) + m_scalerMin;
// Calibrates the raw network score.
// - Scores outside [m_calibrationMin, m_calibrationMax] are clamped to that range.
// - Scores inside the range are mapped by linear interpolation between the two
//   neighbouring (orig -> norm) points of m_calibrationMap.
203 if (networkOutput < m_calibrationMin){
204 return m_calibrationMin;
206 else if (networkOutput > m_calibrationMax){
207 return m_calibrationMax;
// upper_bound gives the first point whose key is > networkOutput. Because of the
// post-decrement, `lower` receives that iterator and `upper` is moved back to the
// preceding point, so the names are swapped relative to their keys. The line
// through the two points is the same either way, so the result is unaffected.
// NOTE(review): this assumes points exist on both sides of networkOutput.
//   - If networkOutput is below the first key, upper_bound returns begin() and
//     decrementing it is undefined behaviour.
//   - If networkOutput >= the last key (e.g. exactly m_calibrationMax when that
//     is also the last key), upper_bound returns end() and `lower` is
//     dereferenced past the end.
//   - An empty map hits the same problem.
211 auto upper = m_calibrationMap.upper_bound(networkOutput);
212 auto lower =
upper--;
// Slope and intercept of the straight line through the two neighbouring points.
215 double m = (
upper->second - lower->second)/(
upper->first - lower->first);
216 double c = lower->second -
m * lower->first;
217 double calibrated =
m * networkOutput +
c;