Loading [MathJax]/extensions/tex2jax.js
ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
PunchThroughG4Classifier.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
6 // PunchThroughG4Classifier.cxx, (c) ATLAS Detector software
9 
10 #include <fstream>
11 
12 // PathResolver
14 
15 // Geant4
16 #include "G4FastTrack.hh"
17 #include "G4FastStep.hh"
18 
19 //LWTNN
20 #include "lwtnn/parse_json.hh"
21 
22 //libXML
23 #include <libxml/xmlmemory.h>
24 #include <libxml/parser.h>
25 #include <libxml/tree.h>
26 #include <libxml/xmlreader.h>
27 #include <libxml/xpath.h>
28 #include <libxml/xpathInternals.h>
29 
30 PunchThroughG4Classifier::PunchThroughG4Classifier(const std::string& type, const std::string& name, const IInterface* parent)
31  : base_class(type, name, parent) {
32 }
33 
35 
36  ATH_MSG_DEBUG( "[ punchthroughclassifier ] Initializing PunchThroughG4Classifier" );
37 
38  std::string resolvedScalerFileName = PathResolverFindCalibFile (m_scalerConfigFileName);
39  ATH_CHECK ( initializeScaler(resolvedScalerFileName) );
40 
41  std::string resolvedNetworkFileName = PathResolverFindCalibFile (m_networkConfigFileName);
42  ATH_CHECK ( initializeNetwork(resolvedNetworkFileName) );
43 
44  std::string resolvedCalibratorFileName = PathResolverFindCalibFile (m_calibratorConfigFileName);
45  ATH_CHECK ( initializeCalibrator(resolvedCalibratorFileName) );
46 
47  return StatusCode::SUCCESS;
48 }
49 
51 
52  ATH_MSG_DEBUG( "[punchthroughclassifier] finalize() successful" );
53 
54  return StatusCode::SUCCESS;
55 }
56 
57 StatusCode PunchThroughG4Classifier::initializeScaler(const std::string & scalerConfigFile){
58  // Initialize pointers
59  xmlDocPtr doc;
60  xmlChar* xmlBuff = nullptr;
61 
62  // Parse xml that contains config for MinMaxScaler for each of the network inputs
63  doc = xmlParseFile( scalerConfigFile.c_str() );
64 
65  ATH_MSG_DEBUG( "[ punchthroughclassifier ] Loading scaler: " << scalerConfigFile);
66 
67  for( xmlNodePtr nodeRoot = doc->children; nodeRoot != nullptr; nodeRoot = nodeRoot->next) {
68 
69  if (xmlStrEqual( nodeRoot->name, BAD_CAST "Transformations" )) {
70  for( xmlNodePtr nodeTransform = nodeRoot->children; nodeTransform != nullptr; nodeTransform = nodeTransform->next ) {
71 
72  //Get min and max values that we normalise values to
73  if (xmlStrEqual( nodeTransform->name, BAD_CAST "ScalerValues" )) {
74  if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST "min")) != nullptr) {
75  m_scalerMin = atof((const char*)xmlBuff);
76  }
77  if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST "max")) != nullptr) {
78  m_scalerMax = atof((const char*)xmlBuff);
79  }
80  }
81 
82  //Get values necessary to normalise each input variable
83  if (xmlStrEqual( nodeTransform->name, BAD_CAST "VarScales" )) {
84  std::string name = "";
85  double min=-1, max=-1;
86 
87  if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST "name")) != nullptr) {
88  name = (const char*)xmlBuff;
89  }
90  if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST "min")) != nullptr) {
91  min = atof((const char*)xmlBuff);
92  }
93  if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST "max")) != nullptr) {
94  max = atof((const char*)xmlBuff);
95  }
96 
97  // Insert into maps
98  m_scalerMinMap.insert ( std::pair<std::string, double>(name, min) );
99  m_scalerMaxMap.insert ( std::pair<std::string, double>(name, max) );
100  }
101  }
102  }
103  }
104 
105  // free memory when done
106  xmlFreeDoc(doc);
107 
108  return StatusCode::SUCCESS;
109 }
110 
111 StatusCode PunchThroughG4Classifier::initializeNetwork(const std::string & networkConfigFile){
112 
113  ATH_MSG_DEBUG( "[ punchthroughclassifier ] Loading classifier: " << networkConfigFile);
114 
115  std::ifstream input(networkConfigFile);
116  if(!input){
117  ATH_MSG_ERROR("Could not find json file " << networkConfigFile );
118  return StatusCode::FAILURE;
119  }
120 
121  m_graph = std::make_unique<lwt::LightweightGraph>(lwt::parse_json_graph(input));
122  if(!m_graph){
123  ATH_MSG_ERROR("Could not parse graph json file " << networkConfigFile );
124  return StatusCode::FAILURE;
125  }
126 
127 
128  return StatusCode::SUCCESS;
129 }
130 
131 StatusCode PunchThroughG4Classifier::initializeCalibrator(const std::string & calibratorConfigFile){
132  // Initialize pointers
133  xmlDocPtr doc;
134  xmlChar* xmlBuff = nullptr;
135 
136  //parse xml that contains config for isotonic regressor used to calibrate the network output
137  ATH_MSG_DEBUG( "[ punchthroughclassifier ] Loading calibrator: " << calibratorConfigFile);
138 
139  doc = xmlParseFile( calibratorConfigFile.c_str() );
140 
141  for( xmlNodePtr nodeRoot = doc->children; nodeRoot != nullptr; nodeRoot = nodeRoot->next) {
142 
143  if (xmlStrEqual( nodeRoot->name, BAD_CAST "Transformations" )) {
144  for( xmlNodePtr nodeTransform = nodeRoot->children; nodeTransform != nullptr; nodeTransform = nodeTransform->next ) {
145 
146  //get lower and upper bounds of isotonic regressor
147  if (xmlStrEqual( nodeTransform->name, BAD_CAST "LimitValues" )) {
148  if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST "min")) != nullptr) {
149  m_calibrationMin = atof((const char*)xmlBuff);
150  }
151  if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST "max")) != nullptr) {
152  m_calibrationMax = atof((const char*)xmlBuff);
153  }
154  }
155 
156  //get defined points where isotonic regressor knows transform
157  if (xmlStrEqual( nodeTransform->name, BAD_CAST "LinearNorm" )) {
158  double orig = -1;
159  double norm = -1;
160  if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST "orig")) != nullptr) {
161  orig = atof((const char*)xmlBuff);
162  }
163  if ((xmlBuff = xmlGetProp(nodeTransform, BAD_CAST "norm")) != nullptr) {
164  norm = atof((const char*)xmlBuff);
165  }
166 
167  // Insert into maps
168  m_calibrationMap.insert ( std::pair<double,double>(orig, norm) );
169  }
170  }
171  }
172  }
173 
174  // free memory when done
175  xmlFreeDoc(doc);
176 
177  return StatusCode::SUCCESS;
178 }
179 
180 double PunchThroughG4Classifier::computePunchThroughProbability(const G4FastTrack& fastTrack, const double simE, const std::vector<double> & simEfrac) const {
181 
182  std::map<std::string, std::map<std::string, double> > networkInputs = computeInputs(fastTrack, simE, simEfrac); //compute inputs
183 
184  networkInputs = scaleInputs(networkInputs); //scale inputs
185 
186  std::map<std::string, double> networkOutputs = m_graph->compute(networkInputs); //call neural network on inputs
187 
188  double calibratedOutput = calibrateOutput(networkOutputs["out_0"]); //calibrate neural network output
189 
190  return calibratedOutput;
191 }
192 
193 std::map<std::string, std::map<std::string, double> > PunchThroughG4Classifier::computeInputs(const G4FastTrack& fastTrack, const double simE, const std::vector<double> & simEfrac) {
194 
195  //calculate inputs for NN
196 
197  std::map<std::string, std::map<std::string, double> > networkInputs;
198 
199  //add initial particle and total energy variables
200  networkInputs["node_0"] = {
201  {"variable_0", fastTrack.GetPrimaryTrack()->GetMomentum().mag() },
202  {"variable_1", std::abs(fastTrack.GetPrimaryTrack()->GetPosition().eta()) },
203  {"variable_2", fastTrack.GetPrimaryTrack()->GetPosition().phi() },
204  {"variable_3", simE},
205  };
206 
207  //add energy fraction variables
208  for (unsigned int i = 0; i < simEfrac.size(); i++) { //from 0 to 23, 24 layers
209  networkInputs["node_0"].insert({"variable_" + std::to_string(i + 4), simEfrac[i]});
210  }
211 
212  return networkInputs;
213 }
214 
215 std::map<std::string, std::map<std::string, double> > PunchThroughG4Classifier::scaleInputs(std::map<std::string, std::map<std::string, double> >& inputs) const{
216 
217  //apply MinMaxScaler to network inputs
218 
219  for (auto& var : inputs["node_0"]) {
220 
221  double x_std;
222  if(m_scalerMaxMap.at(var.first) != m_scalerMinMap.at(var.first)){
223  x_std = (var.second - m_scalerMinMap.at(var.first)) / (m_scalerMaxMap.at(var.first) - m_scalerMinMap.at(var.first));
224  }
225  else{
226  x_std = (var.second - m_scalerMinMap.at(var.first));
227  }
228  var.second = x_std * (m_scalerMax - m_scalerMin) + m_scalerMin;
229  }
230 
231  return inputs;
232 }
233 
234 double PunchThroughG4Classifier::calibrateOutput(double& networkOutput) const {
235 
236  //calibrate output of network using isotonic regressor model
237 
238  //if network output is outside of the range of isotonic regressor then return min and max values
239  if (networkOutput < m_calibrationMin){
240  return m_calibrationMin;
241  }
242  else if (networkOutput > m_calibrationMax){
243  return m_calibrationMax;
244  }
245 
246  //otherwise find neighbouring points in isotonic regressor
247  auto upper = m_calibrationMap.upper_bound(networkOutput);
248  auto lower = upper--;
249 
250  //Perform linear interpolation between points
251  double m = (upper->second - lower->second)/(upper->first - lower->first);
252  double c = lower->second - m * lower->first;
253  double calibrated = m * networkOutput + c;
254 
255  return calibrated;
256 }
xmlChar
unsigned char xmlChar
Definition: TGoodRunsListWriter.h:28
PunchThroughG4Classifier::m_calibrationMin
double m_calibrationMin
Definition: PunchThroughG4Classifier.h:69
beamspotnt.var
var
Definition: bin/beamspotnt.py:1394
PunchThroughG4Classifier::computePunchThroughProbability
virtual double computePunchThroughProbability(const G4FastTrack &fastTrack, const double simE, const std::vector< double > &simEfrac) const override
interface method to return probability prediction of punch through
Definition: PunchThroughG4Classifier.cxx:180
PlotCalibFromCool.norm
norm
Definition: PlotCalibFromCool.py:100
python.SystemOfUnits.m
int m
Definition: SystemOfUnits.py:91
max
constexpr double max()
Definition: ap_fixedTest.cxx:33
min
constexpr double min()
Definition: ap_fixedTest.cxx:26
PunchThroughG4Classifier::initialize
virtual StatusCode initialize() override
AlgTool initialize method.
Definition: PunchThroughG4Classifier.cxx:34
upper
int upper(int c)
Definition: LArBadChannelParser.cxx:49
PunchThroughG4Classifier::m_graph
std::unique_ptr< lwt::LightweightGraph > m_graph
NN graph.
Definition: PunchThroughG4Classifier.h:59
PunchThroughG4Classifier::calibrateOutput
double calibrateOutput(double &networkOutput) const
calibrate NN output using isotonic regressor
Definition: PunchThroughG4Classifier.cxx:234
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
python.CaloAddPedShiftConfig.type
type
Definition: CaloAddPedShiftConfig.py:42
PunchThroughG4Classifier::finalize
virtual StatusCode finalize() override
AlgTool finalize method.
Definition: PunchThroughG4Classifier.cxx:50
PunchThroughG4Classifier::m_scalerConfigFileName
StringProperty m_scalerConfigFileName
Definition: PunchThroughG4Classifier.h:76
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
lumiFormat.i
int i
Definition: lumiFormat.py:85
PunchThroughG4Classifier::computeInputs
static std::map< std::string, std::map< std::string, double > > computeInputs(const G4FastTrack &fastTrack, const double simE, const std::vector< double > &simEfrac)
calculate NN inputs based on G4FastTrack and simulstate
Definition: PunchThroughG4Classifier.cxx:193
PunchThroughG4Classifier::m_scalerMax
double m_scalerMax
Definition: PunchThroughG4Classifier.h:63
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
PlotPulseshapeFromCool.input
input
Definition: PlotPulseshapeFromCool.py:106
PunchThroughG4Classifier.h
test_pyathena.parent
parent
Definition: test_pyathena.py:15
PunchThroughG4Classifier::initializeCalibrator
StatusCode initializeCalibrator(const std::string &calibratorConfigFile)
isotonic regressor calibrator initialize method
Definition: PunchThroughG4Classifier.cxx:131
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
PunchThroughG4Classifier::PunchThroughG4Classifier
PunchThroughG4Classifier(const std::string &, const std::string &, const IInterface *)
Constructor.
Definition: PunchThroughG4Classifier.cxx:30
PunchThroughG4Classifier::m_calibrationMap
std::map< double, double > m_calibrationMap
Definition: PunchThroughG4Classifier.h:71
CxxUtils::atof
double atof(std::string_view str)
Converts a string into a double / float.
Definition: Control/CxxUtils/Root/StringUtils.cxx:91
PunchThroughG4Classifier::m_networkConfigFileName
StringProperty m_networkConfigFileName
Definition: PunchThroughG4Classifier.h:77
PunchThroughG4Classifier::initializeScaler
StatusCode initializeScaler(const std::string &scalerConfigFile)
input variable MinMaxScaler initialize method
Definition: PunchThroughG4Classifier.cxx:57
PathResolver.h
merge_scale_histograms.doc
string doc
Definition: merge_scale_histograms.py:9
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:228
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
PunchThroughG4Classifier::m_calibratorConfigFileName
StringProperty m_calibratorConfigFileName
Definition: PunchThroughG4Classifier.h:78
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:431
lwtDev::parse_json_graph
GraphConfig parse_json_graph(std::istream &json)
Definition: parse_json.cxx:71
PunchThroughG4Classifier::m_scalerMin
double m_scalerMin
input variable MinMaxScaler members
Definition: PunchThroughG4Classifier.h:62
PunchThroughG4Classifier::m_scalerMinMap
std::map< std::string, double > m_scalerMinMap
Definition: PunchThroughG4Classifier.h:64
PunchThroughG4Classifier::m_calibrationMax
double m_calibrationMax
Definition: PunchThroughG4Classifier.h:70
PunchThroughG4Classifier::m_scalerMaxMap
std::map< std::string, double > m_scalerMaxMap
Definition: PunchThroughG4Classifier.h:65
PunchThroughG4Classifier::initializeNetwork
StatusCode initializeNetwork(const std::string &networkConfigFile)
neural network initialize method
Definition: PunchThroughG4Classifier.cxx:111
python.compressB64.c
def c
Definition: compressB64.py:93
PunchThroughG4Classifier::scaleInputs
std::map< std::string, std::map< std::string, double > > scaleInputs(std::map< std::string, std::map< std::string, double > > &inputs) const
scale NN inputs using MinMaxScaler
Definition: PunchThroughG4Classifier.cxx:215