ATLAS Offline Software
Loading...
Searching...
No Matches
PunchThroughG4Classifier.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
6// PunchThroughG4Classifier.cxx, (c) ATLAS Detector software
9
#include <cmath>
#include <exception>
#include <fstream>
#include <iterator>
#include <memory>
11
12// PathResolver
14
15// Geant4
16#include "G4FastTrack.hh"
17#include "G4FastStep.hh"
18
19//LWTNN
20#include "lwtnn/parse_json.hh"
21
22//libXML
23#include <libxml/xmlmemory.h>
24#include <libxml/parser.h>
25#include <libxml/tree.h>
26#include <libxml/xmlreader.h>
27#include <libxml/xpath.h>
28#include <libxml/xpathInternals.h>
29
31
32using namespace CxxUtils;
33
34PunchThroughG4Classifier::PunchThroughG4Classifier(const std::string& type, const std::string& name, const IInterface* parent)
35 : base_class(type, name, parent) {
36}
37
39
// initialize(): resolve each configuration file through PathResolverFindCalibFile and load,
// in order, the input MinMaxScaler, the lwtnn graph and the isotonic-regression calibrator.
// Each step is wrapped in ATH_CHECK, which evaluates it and checks it for errors.
// NOTE(review): the signature line (doc line 38) is missing from this listing; per the
// declaration it should be `StatusCode PunchThroughG4Classifier::initialize()` — confirm.
40 ATH_MSG_DEBUG( "[ punchthroughclassifier ] Initializing PunchThroughG4Classifier" );
41
// Scaler configuration (XML) for the network inputs
42 std::string resolvedScalerFileName = PathResolverFindCalibFile (m_scalerConfigFileName);
43 ATH_CHECK ( initializeScaler(resolvedScalerFileName) );
44
// Neural-network graph (JSON, lwtnn format)
45 std::string resolvedNetworkFileName = PathResolverFindCalibFile (m_networkConfigFileName);
46 ATH_CHECK ( initializeNetwork(resolvedNetworkFileName) );
47
// Output calibrator (XML) for the isotonic regressor
48 std::string resolvedCalibratorFileName = PathResolverFindCalibFile (m_calibratorConfigFileName);
49 ATH_CHECK ( initializeCalibrator(resolvedCalibratorFileName) );
50
51 return StatusCode::SUCCESS;
52}
53
55
// finalize(): nothing is released explicitly here; it only logs at DEBUG level and reports success.
// NOTE(review): the signature line (doc line 54) is missing from this listing; per the
// declaration it should be `StatusCode PunchThroughG4Classifier::finalize()` — confirm.
56 ATH_MSG_DEBUG( "[punchthroughclassifier] finalize() successful" );
57
58 return StatusCode::SUCCESS;
59}
60
61StatusCode PunchThroughG4Classifier::initializeScaler(const std::string & scalerConfigFile){
62 // Initialize pointers
63 xmlDocPtr doc;
64
65 // Parse xml that contains config for MinMaxScaler for each of the network inputs
66 doc = xmlParseFile( scalerConfigFile.c_str() );
67
68 ATH_MSG_DEBUG( "[ punchthroughclassifier ] Loading scaler: " << scalerConfigFile);
69
70 for( xmlNodePtr nodeRoot = doc->children; nodeRoot != nullptr; nodeRoot = nodeRoot->next) {
71
72 if (xmlStrEqual( nodeRoot->name, BAD_CAST "Transformations" )) {
73 for( xmlNodePtr nodeTransform = nodeRoot->children; nodeTransform != nullptr; nodeTransform = nodeTransform->next ) {
74
75 //Get min and max values that we normalise values to
76 if (xmlStrEqual( nodeTransform->name, BAD_CAST "ScalerValues" )) {
77 GetXmlAttrIfThere(nodeTransform, "min", m_scalerMin);
78 GetXmlAttrIfThere(nodeTransform, "max", m_scalerMax);
79 }
80
81 //Get values necessary to normalise each input variable
82 if (xmlStrEqual( nodeTransform->name, BAD_CAST "VarScales" )) {
83 std::string name = "";
84 double min=-1, max=-1;
85
86 GetXmlAttrIfThere(nodeTransform, "name", name);
87 GetXmlAttrIfThere(nodeTransform, "min", min);
88 GetXmlAttrIfThere(nodeTransform, "max", max);
89
90 // Insert into maps
91 m_scalerMinMap.emplace ( name, min );
92 m_scalerMaxMap.emplace ( name, max );
93 }
94 }
95 }
96 }
97
98 // free memory when done
99 xmlFreeDoc(doc);
100
101 return StatusCode::SUCCESS;
102}
103
104StatusCode PunchThroughG4Classifier::initializeNetwork(const std::string & networkConfigFile){
105
106 ATH_MSG_DEBUG( "[ punchthroughclassifier ] Loading classifier: " << networkConfigFile);
107
108 std::ifstream input(networkConfigFile);
109 if(!input){
110 ATH_MSG_ERROR("Could not find json file " << networkConfigFile );
111 return StatusCode::FAILURE;
112 }
113
114 m_graph = std::make_unique<lwt::LightweightGraph>(lwt::parse_json_graph(input));
115 if(!m_graph){
116 ATH_MSG_ERROR("Could not parse graph json file " << networkConfigFile );
117 return StatusCode::FAILURE;
118 }
119
120
121 return StatusCode::SUCCESS;
122}
123
124StatusCode PunchThroughG4Classifier::initializeCalibrator(const std::string & calibratorConfigFile){
125 // Initialize pointers
126 xmlDocPtr doc;
127
128 //parse xml that contains config for isotonic regressor used to calibrate the network output
129 ATH_MSG_DEBUG( "[ punchthroughclassifier ] Loading calibrator: " << calibratorConfigFile);
130
131 doc = xmlParseFile( calibratorConfigFile.c_str() );
132
133 for( xmlNodePtr nodeRoot = doc->children; nodeRoot != nullptr; nodeRoot = nodeRoot->next) {
134
135 if (xmlStrEqual( nodeRoot->name, BAD_CAST "Transformations" )) {
136 for( xmlNodePtr nodeTransform = nodeRoot->children; nodeTransform != nullptr; nodeTransform = nodeTransform->next ) {
137
138 //get lower and upper bounds of isotonic regressor
139 if (xmlStrEqual( nodeTransform->name, BAD_CAST "LimitValues" )) {
140 GetXmlAttrIfThere(nodeTransform, "min", m_calibrationMin);
141 GetXmlAttrIfThere(nodeTransform, "max", m_calibrationMax);
142 }
143
144 //get defined points where isotonic regressor knows transform
145 if (xmlStrEqual( nodeTransform->name, BAD_CAST "LinearNorm" )) {
146 double orig = -1;
147 double norm = -1;
148 GetXmlAttrIfThere(nodeTransform, "orig", orig);
149 GetXmlAttrIfThere(nodeTransform, "norm", norm);
150
151 // Insert into maps
152 m_calibrationMap.emplace ( orig, norm );
153 }
154 }
155 }
156 }
157
158 // free memory when done
159 xmlFreeDoc(doc);
160
161 return StatusCode::SUCCESS;
162}
163
164double PunchThroughG4Classifier::computePunchThroughProbability(const G4FastTrack& fastTrack, const double simE, const std::vector<double> & simEfrac) const {
165
166 std::map<std::string, std::map<std::string, double> > networkInputs = computeInputs(fastTrack, simE, simEfrac); //compute inputs
167
168 networkInputs = scaleInputs(networkInputs); //scale inputs
169
170 std::map<std::string, double> networkOutputs = m_graph->compute(networkInputs); //call neural network on inputs
171
172 double calibratedOutput = calibrateOutput(networkOutputs["out_0"]); //calibrate neural network output
173
174 return calibratedOutput;
175}
176
177std::map<std::string, std::map<std::string, double> > PunchThroughG4Classifier::computeInputs(const G4FastTrack& fastTrack, const double simE, const std::vector<double> & simEfrac) {
178
179 //calculate inputs for NN
180
181 std::map<std::string, std::map<std::string, double> > networkInputs;
182
183 //add initial particle and total energy variables
184 networkInputs["node_0"] = {
185 {"variable_0", fastTrack.GetPrimaryTrack()->GetMomentum().mag() },
186 {"variable_1", std::abs(fastTrack.GetPrimaryTrack()->GetPosition().eta()) },
187 {"variable_2", fastTrack.GetPrimaryTrack()->GetPosition().phi() },
188 {"variable_3", simE},
189 };
190
191 //add energy fraction variables
192 for (unsigned int i = 0; i < simEfrac.size(); i++) { //from 0 to 23, 24 layers
193 networkInputs["node_0"].insert({"variable_" + std::to_string(i + 4), simEfrac[i]});
194 }
195
196 return networkInputs;
197}
198
199std::map<std::string, std::map<std::string, double> > PunchThroughG4Classifier::scaleInputs(std::map<std::string, std::map<std::string, double> >& inputs) const{
200
201 //apply MinMaxScaler to network inputs
202
203 for (auto& var : inputs["node_0"]) {
204
205 double x_std;
206 if(m_scalerMaxMap.at(var.first) != m_scalerMinMap.at(var.first)){
207 x_std = (var.second - m_scalerMinMap.at(var.first)) / (m_scalerMaxMap.at(var.first) - m_scalerMinMap.at(var.first));
208 }
209 else{
210 x_std = (var.second - m_scalerMinMap.at(var.first));
211 }
212 var.second = x_std * (m_scalerMax - m_scalerMin) + m_scalerMin;
213 }
214
215 return inputs;
216}
217
218double PunchThroughG4Classifier::calibrateOutput(double& networkOutput) const {
219
220 //calibrate output of network using isotonic regressor model
221
222 //if network output is outside of the range of isotonic regressor then return min and max values
223 if (networkOutput < m_calibrationMin){
224 return m_calibrationMin;
225 }
226 else if (networkOutput > m_calibrationMax){
227 return m_calibrationMax;
228 }
229
230 //otherwise find neighbouring points in isotonic regressor
231 auto upper = m_calibrationMap.upper_bound(networkOutput);
232 auto lower = upper--;
233
234 //Perform linear interpolation between points
235 double m = (upper->second - lower->second)/(upper->first - lower->first);
236 double c = lower->second - m * lower->first;
237 double calibrated = m * networkOutput + c;
238
239 return calibrated;
240}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_DEBUG(x)
int upper(int c)
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
#define min(a, b)
Definition cfImp.cxx:40
#define max(a, b)
Definition cfImp.cxx:41
std::unique_ptr< lwt::LightweightGraph > m_graph
NN graph.
std::map< std::string, double > m_scalerMaxMap
virtual StatusCode initialize() override
AlgTool initialize method.
double m_scalerMin
input variable MinMaxScaler members
std::map< std::string, std::map< std::string, double > > scaleInputs(std::map< std::string, std::map< std::string, double > > &inputs) const
scale NN inputs using MinMaxScaler
std::map< double, double > m_calibrationMap
PunchThroughG4Classifier(const std::string &, const std::string &, const IInterface *)
Constructor.
std::map< std::string, double > m_scalerMinMap
StatusCode initializeCalibrator(const std::string &calibratorConfigFile)
isotonic regressor calibrator initialize method
virtual StatusCode finalize() override
AlgTool finalize method.
StatusCode initializeScaler(const std::string &scalerConfigFile)
input variable MinMaxScaler initialize method
StatusCode initializeNetwork(const std::string &networkConfigFile)
neural network initialize method
double calibrateOutput(double &networkOutput) const
calibrate NN output using isotonic regressor
static std::map< std::string, std::map< std::string, double > > computeInputs(const G4FastTrack &fastTrack, const double simE, const std::vector< double > &simEfrac)
calculate NN inputs based on G4FastTrack and simulstate
virtual double computePunchThroughProbability(const G4FastTrack &fastTrack, const double simE, const std::vector< double > &simEfrac) const override
interface method to return probability prediction of punch through
void GetXmlAttrIfThere(xmlNodePtr node, const char *attrName, T &value)