16 #include "G4FastTrack.hh"
17 #include "G4FastStep.hh"
20 #include "lwtnn/parse_json.hh"
23 #include <libxml/xmlmemory.h>
24 #include <libxml/parser.h>
25 #include <libxml/tree.h>
26 #include <libxml/xmlreader.h>
27 #include <libxml/xpath.h>
28 #include <libxml/xpathInternals.h>
36 ATH_MSG_DEBUG(
"[ punchthroughclassifier ] Initializing PunchThroughG4Classifier" );
47 return StatusCode::SUCCESS;
52 ATH_MSG_DEBUG(
"[punchthroughclassifier] finalize() successful" );
54 return StatusCode::SUCCESS;
63 doc = xmlParseFile( scalerConfigFile.c_str() );
65 ATH_MSG_DEBUG(
"[ punchthroughclassifier ] Loading scaler: " << scalerConfigFile);
67 for( xmlNodePtr nodeRoot =
doc->children; nodeRoot !=
nullptr; nodeRoot = nodeRoot->next) {
69 if (xmlStrEqual( nodeRoot->name, BAD_CAST
"Transformations" )) {
70 for( xmlNodePtr nodeTransform = nodeRoot->children; nodeTransform !=
nullptr; nodeTransform = nodeTransform->next ) {
73 if (xmlStrEqual( nodeTransform->name, BAD_CAST
"ScalerValues" )) {
74 if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST
"min")) !=
nullptr) {
77 if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST
"max")) !=
nullptr) {
83 if (xmlStrEqual( nodeTransform->name, BAD_CAST
"VarScales" )) {
84 std::string
name =
"";
87 if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST
"name")) !=
nullptr) {
88 name = (
const char*)xmlBuff;
90 if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST
"min")) !=
nullptr) {
91 min =
atof((
const char*)xmlBuff);
93 if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST
"max")) !=
nullptr) {
94 max =
atof((
const char*)xmlBuff);
108 return StatusCode::SUCCESS;
113 ATH_MSG_DEBUG(
"[ punchthroughclassifier ] Loading classifier: " << networkConfigFile);
115 std::ifstream
input(networkConfigFile);
117 ATH_MSG_ERROR(
"Could not find json file " << networkConfigFile );
118 return StatusCode::FAILURE;
123 ATH_MSG_ERROR(
"Could not parse graph json file " << networkConfigFile );
124 return StatusCode::FAILURE;
128 return StatusCode::SUCCESS;
137 ATH_MSG_DEBUG(
"[ punchthroughclassifier ] Loading calibrator: " << calibratorConfigFile);
139 doc = xmlParseFile( calibratorConfigFile.c_str() );
141 for( xmlNodePtr nodeRoot =
doc->children; nodeRoot !=
nullptr; nodeRoot = nodeRoot->next) {
143 if (xmlStrEqual( nodeRoot->name, BAD_CAST
"Transformations" )) {
144 for( xmlNodePtr nodeTransform = nodeRoot->children; nodeTransform !=
nullptr; nodeTransform = nodeTransform->next ) {
147 if (xmlStrEqual( nodeTransform->name, BAD_CAST
"LimitValues" )) {
148 if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST
"min")) !=
nullptr) {
151 if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST
"max")) !=
nullptr) {
157 if (xmlStrEqual( nodeTransform->name, BAD_CAST
"LinearNorm" )) {
160 if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST
"orig")) !=
nullptr) {
161 orig =
atof((
const char*)xmlBuff);
163 if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST
"norm")) !=
nullptr) {
177 return StatusCode::SUCCESS;
182 std::map<std::string, std::map<std::string, double> > networkInputs =
computeInputs(fastTrack, simE, simEfrac);
186 std::map<std::string, double> networkOutputs =
m_graph->compute(networkInputs);
190 return calibratedOutput;
197 std::map<std::string, std::map<std::string, double> > networkInputs;
200 networkInputs[
"node_0"] = {
201 {
"variable_0", fastTrack.GetPrimaryTrack()->GetMomentum().mag() },
202 {
"variable_1", std::abs(fastTrack.GetPrimaryTrack()->GetPosition().eta()) },
203 {
"variable_2", fastTrack.GetPrimaryTrack()->GetPosition().phi() },
204 {
"variable_3", simE},
208 for (
unsigned int i = 0;
i < simEfrac.size();
i++) {
209 networkInputs[
"node_0"].insert({
"variable_" +
std::to_string(
i + 4), simEfrac[
i]});
212 return networkInputs;
248 auto lower =
upper--;
251 double m = (
upper->second - lower->second)/(
upper->first - lower->first);
252 double c = lower->second -
m * lower->first;
253 double calibrated =
m * networkOutput +
c;