ATLAS Offline Software
Loading...
Searching...
No Matches
PunchThroughClassifier.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3*/
4
6
#include <exception>
#include <fstream>
#include <memory>
9
10// PathResolver
12
13//LWTNN
14#include "lwtnn/parse_json.hh"
15
16//libXML
17#include <libxml/xmlmemory.h>
18#include <libxml/parser.h>
19#include <libxml/tree.h>
20#include <libxml/xmlreader.h>
21#include <libxml/xpath.h>
22#include <libxml/xpathInternals.h>
23
24ISF::PunchThroughClassifier::PunchThroughClassifier(const std::string& type, const std::string& name, const IInterface* parent)
25 : base_class(type, name, parent) {
26
27 declareProperty( "ScalerConfigFileName", m_scalerConfigFileName );
28 declareProperty( "NetworkConfigFileName", m_networkConfigFileName );
29 declareProperty( "CalibratorConfigFileName", m_calibratorConfigFileName );
30}
31
33
34 std::map<std::string, std::map<std::string, double> > networkInputs = computeInputs(isfp, simulstate); //compute inputs
35
36 networkInputs = scaleInputs(networkInputs); //scale inputs
37
38 std::map<std::string, double> networkOutputs = m_graph->compute(networkInputs); //call neural network on inputs
39
40 double calibratedOutput = calibrateOutput(networkOutputs["out_0"]); //calibrate neural network output
41
42 return calibratedOutput;
43}
44
45
47
48 ATH_MSG_VERBOSE( "[ punchthroughclassifier ] initialize()" );
49
50 std::string resolvedScalerFileName = PathResolverFindCalibFile (m_scalerConfigFileName);
51 if ( initializeScaler(resolvedScalerFileName) != StatusCode::SUCCESS)
52 {
53 ATH_MSG_ERROR("[ punchthroughclassifier ] unable to load punchthroughclassifier input scaler");
54 }
55
56 std::string resolvedNetworkFileName = PathResolverFindCalibFile (m_networkConfigFileName);
57 if ( initializeNetwork(resolvedNetworkFileName) != StatusCode::SUCCESS)
58 {
59 ATH_MSG_ERROR("[ punchthroughclassifier ] unable to load punchthroughclassifier network");
60 }
61
62 std::string resolvedCalibratorFileName = PathResolverFindCalibFile (m_calibratorConfigFileName);
63 if ( initializeCalibrator(resolvedCalibratorFileName) != StatusCode::SUCCESS)
64 {
65 ATH_MSG_ERROR("[ punchthroughclassifier ] unable to load punchthroughclassifier calibrator");
66 }
67
68 return StatusCode::SUCCESS;
69}
70
71StatusCode ISF::PunchThroughClassifier::initializeScaler(const std::string & scalerConfigFile){
72
73 //parse xml that contains config for MinMaxScaler for each of the network inputs
74
75 xmlDocPtr doc = xmlParseFile( scalerConfigFile.c_str() );
76
77 ATH_MSG_INFO( "[ punchthroughclassifier ] Loading scaler: " << scalerConfigFile);
78
79 for( xmlNodePtr nodeRoot = doc->children; nodeRoot != nullptr; nodeRoot = nodeRoot->next) {
80
81 if (xmlStrEqual( nodeRoot->name, BAD_CAST "Transformations" )) {
82 for( xmlNodePtr nodeTransform = nodeRoot->children; nodeTransform != nullptr; nodeTransform = nodeTransform->next ) {
83
84 //Get min and max values that we normalise values to
85 if (xmlStrEqual( nodeTransform->name, BAD_CAST "ScalerValues" )) {
86 m_scalerMin = atof( (const char*) xmlGetProp( nodeTransform, BAD_CAST "min" ) );
87 m_scalerMax = atof( (const char*) xmlGetProp( nodeTransform, BAD_CAST "max" ) );
88 }
89
90 //Get values necessary to normalise each input variable
91 if (xmlStrEqual( nodeTransform->name, BAD_CAST "VarScales" )) {
92 std::string name = (const char*) xmlGetProp( nodeTransform, BAD_CAST "name" );
93 double min = atof( (const char*) xmlGetProp( nodeTransform, BAD_CAST "min" ) );
94 double max = atof( (const char*) xmlGetProp( nodeTransform, BAD_CAST "max" ) );
95 m_scalerMinMap.insert ( std::pair<std::string, double>(name, min) );
96 m_scalerMaxMap.insert ( std::pair<std::string, double>(name, max) );
97 }
98 }
99 }
100 }
101
102 return StatusCode::SUCCESS;
103}
104
105StatusCode ISF::PunchThroughClassifier::initializeNetwork(const std::string & networkConfigFile){
106
107 ATH_MSG_INFO( "[ punchthroughclassifier ] Loading classifier: " << networkConfigFile);
108
109 std::ifstream input(networkConfigFile);
110 if(!input){
111 ATH_MSG_ERROR("Could not find json file " << networkConfigFile );
112 return StatusCode::FAILURE;
113 }
114
115 m_graph = std::make_unique<lwt::LightweightGraph>(lwt::parse_json_graph(input));
116 if(!m_graph){
117 ATH_MSG_ERROR("Could not parse graph json file " << networkConfigFile );
118 return StatusCode::FAILURE;
119 }
120
121
122 return StatusCode::SUCCESS;
123}
124
125
126StatusCode ISF::PunchThroughClassifier::initializeCalibrator(const std::string & calibratorConfigFile){
127
128 //parse xml that contains config for isotonic regressor used to calibrate the network output
129 ATH_MSG_INFO( "[ punchthroughclassifier ] Loading calibrator: " << calibratorConfigFile);
130
131 xmlDocPtr doc = xmlParseFile( calibratorConfigFile.c_str() );
132
133 for( xmlNodePtr nodeRoot = doc->children; nodeRoot != nullptr; nodeRoot = nodeRoot->next) {
134
135 if (xmlStrEqual( nodeRoot->name, BAD_CAST "Transformations" )) {
136 for( xmlNodePtr nodeTransform = nodeRoot->children; nodeTransform != nullptr; nodeTransform = nodeTransform->next ) {
137
138 //get lower and upper bounds of isotonic regressor
139 if (xmlStrEqual( nodeTransform->name, BAD_CAST "LimitValues" )) {
140 m_calibrationMin = atof( (const char*) xmlGetProp( nodeTransform, BAD_CAST "min" ) );
141 m_calibrationMax = atof( (const char*) xmlGetProp( nodeTransform, BAD_CAST "max" ) );
142 }
143
144 //get defined points where isotonic regressor knows transform
145 if (xmlStrEqual( nodeTransform->name, BAD_CAST "LinearNorm" )) {
146 double orig = atof( (const char*) xmlGetProp( nodeTransform, BAD_CAST "orig" ) );
147 double norm = atof( (const char*) xmlGetProp( nodeTransform, BAD_CAST "norm" ) );
148 m_calibrationMap.insert ( std::pair<double,double>(orig, norm) );
149 }
150 }
151 }
152 }
153
154 return StatusCode::SUCCESS;
155}
156
157std::map<std::string, std::map<std::string, double> > ISF::PunchThroughClassifier::computeInputs(const ISF::ISFParticle &isfp, const TFCSSimulationState& simulstate) {
158
159 //calculate inputs for NN
160
161 std::map<std::string, std::map<std::string, double> > networkInputs;
162
163 //add initial particle and total energy variables
164 networkInputs["node_0"] = {
165 {"variable_0", isfp.momentum().mag() },
166 {"variable_1", std::abs(isfp.position().eta()) },
167 {"variable_2", isfp.position().phi() },
168 {"variable_3", simulstate.E()},
169 };
170
171 //add energy fraction variables
172 for (unsigned int i = 0; i < 24; i++){
173 networkInputs["node_0"].insert({"variable_" + std::to_string(i + 4), simulstate.Efrac(i)});
174 }
175
176 return networkInputs;
177}
178
179std::map<std::string, std::map<std::string, double> > ISF::PunchThroughClassifier::scaleInputs(std::map<std::string, std::map<std::string, double> >& inputs) const{
180
181 //apply MinMaxScaler to network inputs
182
183 for (auto& var : inputs["node_0"]) {
184
185 double x_std;
186 if(m_scalerMaxMap.at(var.first) != m_scalerMinMap.at(var.first)){
187 x_std = (var.second - m_scalerMinMap.at(var.first)) / (m_scalerMaxMap.at(var.first) - m_scalerMinMap.at(var.first));
188 }
189 else{
190 x_std = (var.second - m_scalerMinMap.at(var.first));
191 }
192 var.second = x_std * (m_scalerMax - m_scalerMin) + m_scalerMin;
193 }
194
195 return inputs;
196}
197
198double ISF::PunchThroughClassifier::calibrateOutput(double& networkOutput) const {
199
200 //calibrate output of network using isotonic regressor model
201
202 //if network output is outside of the range of isotonic regressor then return min and max values
203 if (networkOutput < m_calibrationMin){
204 return m_calibrationMin;
205 }
206 else if (networkOutput > m_calibrationMax){
207 return m_calibrationMax;
208 }
209
210 //otherwise find neighbouring points in isotonic regressor
211 auto upper = m_calibrationMap.upper_bound(networkOutput);
212 auto lower = upper--;
213
214 //Perform linear interpolation between points
215 double m = (upper->second - lower->second)/(upper->first - lower->first);
216 double c = lower->second - m * lower->first;
217 double calibrated = m * networkOutput + c;
218
219 return calibrated;
220}
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_VERBOSE(x)
int upper(int c)
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
#define min(a, b)
Definition cfImp.cxx:40
#define max(a, b)
Definition cfImp.cxx:41
The generic ISF particle definition,.
Definition ISFParticle.h:42
const Amg::Vector3D & momentum() const
The current momentum vector of the ISFParticle.
const Amg::Vector3D & position() const
The current position of the ISFParticle.
PunchThroughClassifier(const std::string &, const std::string &, const IInterface *)
Constructor.
std::map< std::string, std::map< std::string, double > > scaleInputs(std::map< std::string, std::map< std::string, double > > &inputs) const
scale NN inputs using MinMaxScaler
virtual double computePunchThroughProbability(const ISF::ISFParticle &isfp, const TFCSSimulationState &simulstate) const override
interface method to return probability prediction of punch through
StatusCode initializeScaler(const std::string &scalerConfigFile)
input variable MinMaxScaler initialize method
std::map< std::string, double > m_scalerMinMap
double m_scalerMin
input variable MinMaxScaler members
std::map< double, double > m_calibrationMap
virtual StatusCode initialize() override final
AlgTool initialize method.
StatusCode initializeNetwork(const std::string &networkConfigFile)
neural network initialize method
double calibrateOutput(double &networkOutput) const
calibrate NN output using isotonic regressor
static std::map< std::string, std::map< std::string, double > > computeInputs(const ISF::ISFParticle &isfp, const TFCSSimulationState &simulstate)
calculate NN inputs based on isfp and simulstate
std::map< std::string, double > m_scalerMaxMap
StatusCode initializeCalibrator(const std::string &calibratorConfigFile)
isotonic regressor calibrator initialize method
std::unique_ptr< lwt::LightweightGraph > m_graph
NN graph.
double Efrac(int sample) const