ATLAS Offline Software
Loading...
Searching...
No Matches
PunchThroughG4Classifier.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
6// PunchThroughG4Classifier.cxx, (c) ATLAS Detector software
9
10#include <fstream>
11
12// PathResolver
14
15// Geant4
16#include "G4FastTrack.hh"
17#include "G4FastStep.hh"
18
19//LWTNN
20#include "lwtnn/parse_json.hh"
21
22//libXML
23#include <libxml/xmlmemory.h>
24#include <libxml/parser.h>
25#include <libxml/tree.h>
26#include <libxml/xmlreader.h>
27#include <libxml/xpath.h>
28#include <libxml/xpathInternals.h>
29
30PunchThroughG4Classifier::PunchThroughG4Classifier(const std::string& type, const std::string& name, const IInterface* parent)
31 : base_class(type, name, parent) {
32}
33
35
36 ATH_MSG_DEBUG( "[ punchthroughclassifier ] Initializing PunchThroughG4Classifier" );
37
38 std::string resolvedScalerFileName = PathResolverFindCalibFile (m_scalerConfigFileName);
39 ATH_CHECK ( initializeScaler(resolvedScalerFileName) );
40
41 std::string resolvedNetworkFileName = PathResolverFindCalibFile (m_networkConfigFileName);
42 ATH_CHECK ( initializeNetwork(resolvedNetworkFileName) );
43
44 std::string resolvedCalibratorFileName = PathResolverFindCalibFile (m_calibratorConfigFileName);
45 ATH_CHECK ( initializeCalibrator(resolvedCalibratorFileName) );
46
47 return StatusCode::SUCCESS;
48}
49
51
52 ATH_MSG_DEBUG( "[punchthroughclassifier] finalize() successful" );
53
54 return StatusCode::SUCCESS;
55}
56
57StatusCode PunchThroughG4Classifier::initializeScaler(const std::string & scalerConfigFile){
58 // Initialize pointers
59 xmlDocPtr doc;
60 xmlChar* xmlBuff = nullptr;
61
62 // Parse xml that contains config for MinMaxScaler for each of the network inputs
63 doc = xmlParseFile( scalerConfigFile.c_str() );
64
65 ATH_MSG_DEBUG( "[ punchthroughclassifier ] Loading scaler: " << scalerConfigFile);
66
67 for( xmlNodePtr nodeRoot = doc->children; nodeRoot != nullptr; nodeRoot = nodeRoot->next) {
68
69 if (xmlStrEqual( nodeRoot->name, BAD_CAST "Transformations" )) {
70 for( xmlNodePtr nodeTransform = nodeRoot->children; nodeTransform != nullptr; nodeTransform = nodeTransform->next ) {
71
72 //Get min and max values that we normalise values to
73 if (xmlStrEqual( nodeTransform->name, BAD_CAST "ScalerValues" )) {
74 if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST "min")) != nullptr) {
75 m_scalerMin = atof(reinterpret_cast<const char*>(xmlBuff));
76 }
77 if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST "max")) != nullptr) {
78 m_scalerMax = atof(reinterpret_cast<const char*>(xmlBuff));
79 }
80 }
81
82 //Get values necessary to normalise each input variable
83 if (xmlStrEqual( nodeTransform->name, BAD_CAST "VarScales" )) {
84 std::string name = "";
85 double min=-1, max=-1;
86
87 if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST "name")) != nullptr) {
88 name = reinterpret_cast<const char*>(xmlBuff);
89 }
90 if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST "min")) != nullptr) {
91 min = atof(reinterpret_cast<const char*>(xmlBuff));
92 }
93 if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST "max")) != nullptr) {
94 max = atof(reinterpret_cast<const char*>(xmlBuff));
95 }
96
97 // Insert into maps
98 m_scalerMinMap.insert ( std::pair<std::string, double>(name, min) );
99 m_scalerMaxMap.insert ( std::pair<std::string, double>(name, max) );
100 }
101 }
102 }
103 }
104
105 // free memory when done
106 xmlFreeDoc(doc);
107
108 return StatusCode::SUCCESS;
109}
110
111StatusCode PunchThroughG4Classifier::initializeNetwork(const std::string & networkConfigFile){
112
113 ATH_MSG_DEBUG( "[ punchthroughclassifier ] Loading classifier: " << networkConfigFile);
114
115 std::ifstream input(networkConfigFile);
116 if(!input){
117 ATH_MSG_ERROR("Could not find json file " << networkConfigFile );
118 return StatusCode::FAILURE;
119 }
120
121 m_graph = std::make_unique<lwt::LightweightGraph>(lwt::parse_json_graph(input));
122 if(!m_graph){
123 ATH_MSG_ERROR("Could not parse graph json file " << networkConfigFile );
124 return StatusCode::FAILURE;
125 }
126
127
128 return StatusCode::SUCCESS;
129}
130
131StatusCode PunchThroughG4Classifier::initializeCalibrator(const std::string & calibratorConfigFile){
132 // Initialize pointers
133 xmlDocPtr doc;
134 xmlChar* xmlBuff = nullptr;
135
136 //parse xml that contains config for isotonic regressor used to calibrate the network output
137 ATH_MSG_DEBUG( "[ punchthroughclassifier ] Loading calibrator: " << calibratorConfigFile);
138
139 doc = xmlParseFile( calibratorConfigFile.c_str() );
140
141 for( xmlNodePtr nodeRoot = doc->children; nodeRoot != nullptr; nodeRoot = nodeRoot->next) {
142
143 if (xmlStrEqual( nodeRoot->name, BAD_CAST "Transformations" )) {
144 for( xmlNodePtr nodeTransform = nodeRoot->children; nodeTransform != nullptr; nodeTransform = nodeTransform->next ) {
145
146 //get lower and upper bounds of isotonic regressor
147 if (xmlStrEqual( nodeTransform->name, BAD_CAST "LimitValues" )) {
148 if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST "min")) != nullptr) {
149 m_calibrationMin = atof(reinterpret_cast<const char*>(xmlBuff));
150 }
151 if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST "max")) != nullptr) {
152 m_calibrationMax = atof(reinterpret_cast<const char*>(xmlBuff));
153 }
154 }
155
156 //get defined points where isotonic regressor knows transform
157 if (xmlStrEqual( nodeTransform->name, BAD_CAST "LinearNorm" )) {
158 double orig = -1;
159 double norm = -1;
160 if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST "orig")) != nullptr) {
161 orig = atof(reinterpret_cast<const char*>(xmlBuff));
162 }
163 if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST "norm")) != nullptr) {
164 norm = atof(reinterpret_cast<const char*>(xmlBuff));
165 }
166
167 // Insert into maps
168 m_calibrationMap.insert ( std::pair<double,double>(orig, norm) );
169 }
170 }
171 }
172 }
173
174 // free memory when done
175 xmlFreeDoc(doc);
176
177 return StatusCode::SUCCESS;
178}
179
180double PunchThroughG4Classifier::computePunchThroughProbability(const G4FastTrack& fastTrack, const double simE, const std::vector<double> & simEfrac) const {
181
182 std::map<std::string, std::map<std::string, double> > networkInputs = computeInputs(fastTrack, simE, simEfrac); //compute inputs
183
184 networkInputs = scaleInputs(networkInputs); //scale inputs
185
186 std::map<std::string, double> networkOutputs = m_graph->compute(networkInputs); //call neural network on inputs
187
188 double calibratedOutput = calibrateOutput(networkOutputs["out_0"]); //calibrate neural network output
189
190 return calibratedOutput;
191}
192
193std::map<std::string, std::map<std::string, double> > PunchThroughG4Classifier::computeInputs(const G4FastTrack& fastTrack, const double simE, const std::vector<double> & simEfrac) {
194
195 //calculate inputs for NN
196
197 std::map<std::string, std::map<std::string, double> > networkInputs;
198
199 //add initial particle and total energy variables
200 networkInputs["node_0"] = {
201 {"variable_0", fastTrack.GetPrimaryTrack()->GetMomentum().mag() },
202 {"variable_1", std::abs(fastTrack.GetPrimaryTrack()->GetPosition().eta()) },
203 {"variable_2", fastTrack.GetPrimaryTrack()->GetPosition().phi() },
204 {"variable_3", simE},
205 };
206
207 //add energy fraction variables
208 for (unsigned int i = 0; i < simEfrac.size(); i++) { //from 0 to 23, 24 layers
209 networkInputs["node_0"].insert({"variable_" + std::to_string(i + 4), simEfrac[i]});
210 }
211
212 return networkInputs;
213}
214
215std::map<std::string, std::map<std::string, double> > PunchThroughG4Classifier::scaleInputs(std::map<std::string, std::map<std::string, double> >& inputs) const{
216
217 //apply MinMaxScaler to network inputs
218
219 for (auto& var : inputs["node_0"]) {
220
221 double x_std;
222 if(m_scalerMaxMap.at(var.first) != m_scalerMinMap.at(var.first)){
223 x_std = (var.second - m_scalerMinMap.at(var.first)) / (m_scalerMaxMap.at(var.first) - m_scalerMinMap.at(var.first));
224 }
225 else{
226 x_std = (var.second - m_scalerMinMap.at(var.first));
227 }
228 var.second = x_std * (m_scalerMax - m_scalerMin) + m_scalerMin;
229 }
230
231 return inputs;
232}
233
234double PunchThroughG4Classifier::calibrateOutput(double& networkOutput) const {
235
236 //calibrate output of network using isotonic regressor model
237
238 //if network output is outside of the range of isotonic regressor then return min and max values
239 if (networkOutput < m_calibrationMin){
240 return m_calibrationMin;
241 }
242 else if (networkOutput > m_calibrationMax){
243 return m_calibrationMax;
244 }
245
246 //otherwise find neighbouring points in isotonic regressor
247 auto upper = m_calibrationMap.upper_bound(networkOutput);
248 auto lower = upper--;
249
250 //Perform linear interpolation between points
251 double m = (upper->second - lower->second)/(upper->first - lower->first);
252 double c = lower->second - m * lower->first;
253 double calibrated = m * networkOutput + c;
254
255 return calibrated;
256}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_DEBUG(x)
int upper(int c)
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
unsigned char xmlChar
#define min(a, b)
Definition cfImp.cxx:40
#define max(a, b)
Definition cfImp.cxx:41
std::unique_ptr< lwt::LightweightGraph > m_graph
NN graph.
std::map< std::string, double > m_scalerMaxMap
virtual StatusCode initialize() override
AlgTool initialize method.
double m_scalerMin
input variable MinMaxScaler members
std::map< std::string, std::map< std::string, double > > scaleInputs(std::map< std::string, std::map< std::string, double > > &inputs) const
scale NN inputs using MinMaxScaler
std::map< double, double > m_calibrationMap
PunchThroughG4Classifier(const std::string &, const std::string &, const IInterface *)
Constructor.
std::map< std::string, double > m_scalerMinMap
StatusCode initializeCalibrator(const std::string &calibratorConfigFile)
isotonic regressor calibrator initialize method
virtual StatusCode finalize() override
AlgTool finalize method.
StatusCode initializeScaler(const std::string &scalerConfigFile)
input variable MinMaxScaler initialize method
StatusCode initializeNetwork(const std::string &networkConfigFile)
neural network initialize method
double calibrateOutput(double &networkOutput) const
calibrate NN output using isotonic regressor
static std::map< std::string, std::map< std::string, double > > computeInputs(const G4FastTrack &fastTrack, const double simE, const std::vector< double > &simEfrac)
calculate NN inputs based on G4FastTrack and simulstate
virtual double computePunchThroughProbability(const G4FastTrack &fastTrack, const double simE, const std::vector< double > &simEfrac) const override
interface method to return probability prediction of punch through