ATLAS Offline Software
Loading...
Searching...
No Matches
PunchThroughClassifier.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3*/
4
6
#include <exception>
#include <fstream>
#include <memory>
9
10// PathResolver
12
13//LWTNN
14#include "lwtnn/parse_json.hh"
15
16//libXML
17#include <libxml/xmlmemory.h>
18#include <libxml/parser.h>
19#include <libxml/tree.h>
20#include <libxml/xmlreader.h>
21#include <libxml/xpath.h>
22#include <libxml/xpathInternals.h>
23
25
26using namespace CxxUtils;
27
28ISF::PunchThroughClassifier::PunchThroughClassifier(const std::string& type, const std::string& name, const IInterface* parent)
29 : base_class(type, name, parent) {
30
31 declareProperty( "ScalerConfigFileName", m_scalerConfigFileName );
32 declareProperty( "NetworkConfigFileName", m_networkConfigFileName );
33 declareProperty( "CalibratorConfigFileName", m_calibratorConfigFileName );
34}
35
37
38 std::map<std::string, std::map<std::string, double> > networkInputs = computeInputs(isfp, simulstate); //compute inputs
39
40 networkInputs = scaleInputs(networkInputs); //scale inputs
41
42 std::map<std::string, double> networkOutputs = m_graph->compute(networkInputs); //call neural network on inputs
43
44 double calibratedOutput = calibrateOutput(networkOutputs["out_0"]); //calibrate neural network output
45
46 return calibratedOutput;
47}
48
49
51
52 ATH_MSG_VERBOSE( "[ punchthroughclassifier ] initialize()" );
53
54 std::string resolvedScalerFileName = PathResolverFindCalibFile (m_scalerConfigFileName);
55 if ( initializeScaler(resolvedScalerFileName) != StatusCode::SUCCESS)
56 {
57 ATH_MSG_ERROR("[ punchthroughclassifier ] unable to load punchthroughclassifier input scaler");
58 }
59
60 std::string resolvedNetworkFileName = PathResolverFindCalibFile (m_networkConfigFileName);
61 if ( initializeNetwork(resolvedNetworkFileName) != StatusCode::SUCCESS)
62 {
63 ATH_MSG_ERROR("[ punchthroughclassifier ] unable to load punchthroughclassifier network");
64 }
65
66 std::string resolvedCalibratorFileName = PathResolverFindCalibFile (m_calibratorConfigFileName);
67 if ( initializeCalibrator(resolvedCalibratorFileName) != StatusCode::SUCCESS)
68 {
69 ATH_MSG_ERROR("[ punchthroughclassifier ] unable to load punchthroughclassifier calibrator");
70 }
71
72 return StatusCode::SUCCESS;
73}
74
75StatusCode ISF::PunchThroughClassifier::initializeScaler(const std::string & scalerConfigFile){
76
77 //parse xml that contains config for MinMaxScaler for each of the network inputs
78
79 xmlDocPtr doc = xmlParseFile( scalerConfigFile.c_str() );
80
81 ATH_MSG_INFO( "[ punchthroughclassifier ] Loading scaler: " << scalerConfigFile);
82
83 for( xmlNodePtr nodeRoot = doc->children; nodeRoot != nullptr; nodeRoot = nodeRoot->next) {
84
85 if (xmlStrEqual( nodeRoot->name, BAD_CAST "Transformations" )) {
86 for( xmlNodePtr nodeTransform = nodeRoot->children; nodeTransform != nullptr; nodeTransform = nodeTransform->next ) {
87
88 //Get min and max values that we normalise values to
89 if (xmlStrEqual( nodeTransform->name, BAD_CAST "ScalerValues" )) {
90 m_scalerMin = GetXmlAttr<double>( nodeTransform, "min" );
91 m_scalerMax = GetXmlAttr<double>( nodeTransform, "max" );
92 }
93
94 //Get values necessary to normalise each input variable
95 if (xmlStrEqual( nodeTransform->name, BAD_CAST "VarScales" )) {
96 std::string name = GetXmlAttr<std::string>( nodeTransform, "name" );
97 double min = GetXmlAttr<double>( nodeTransform, "min" );
98 double max = GetXmlAttr<double>( nodeTransform, "max" );
99 m_scalerMinMap.emplace ( name, min );
100 m_scalerMaxMap.emplace ( name, max );
101 }
102 }
103 }
104 }
105
106 return StatusCode::SUCCESS;
107}
108
109StatusCode ISF::PunchThroughClassifier::initializeNetwork(const std::string & networkConfigFile){
110
111 ATH_MSG_INFO( "[ punchthroughclassifier ] Loading classifier: " << networkConfigFile);
112
113 std::ifstream input(networkConfigFile);
114 if(!input){
115 ATH_MSG_ERROR("Could not find json file " << networkConfigFile );
116 return StatusCode::FAILURE;
117 }
118
119 m_graph = std::make_unique<lwt::LightweightGraph>(lwt::parse_json_graph(input));
120 if(!m_graph){
121 ATH_MSG_ERROR("Could not parse graph json file " << networkConfigFile );
122 return StatusCode::FAILURE;
123 }
124
125
126 return StatusCode::SUCCESS;
127}
128
129
130StatusCode ISF::PunchThroughClassifier::initializeCalibrator(const std::string & calibratorConfigFile){
131
132 //parse xml that contains config for isotonic regressor used to calibrate the network output
133 ATH_MSG_INFO( "[ punchthroughclassifier ] Loading calibrator: " << calibratorConfigFile);
134
135 xmlDocPtr doc = xmlParseFile( calibratorConfigFile.c_str() );
136
137 for( xmlNodePtr nodeRoot = doc->children; nodeRoot != nullptr; nodeRoot = nodeRoot->next) {
138
139 if (xmlStrEqual( nodeRoot->name, BAD_CAST "Transformations" )) {
140 for( xmlNodePtr nodeTransform = nodeRoot->children; nodeTransform != nullptr; nodeTransform = nodeTransform->next ) {
141
142 //get lower and upper bounds of isotonic regressor
143 if (xmlStrEqual( nodeTransform->name, BAD_CAST "LimitValues" )) {
144 m_calibrationMin = GetXmlAttr<double>( nodeTransform, "min" );
145 m_calibrationMax = GetXmlAttr<double>( nodeTransform, "max" );
146 }
147
148 //get defined points where isotonic regressor knows transform
149 if (xmlStrEqual( nodeTransform->name, BAD_CAST "LinearNorm" )) {
150 double orig = GetXmlAttr<double>( nodeTransform, "orig" );
151 double norm = GetXmlAttr<double>( nodeTransform, "norm" );
152 m_calibrationMap.emplace ( orig, norm );
153 }
154 }
155 }
156 }
157
158 return StatusCode::SUCCESS;
159}
160
161std::map<std::string, std::map<std::string, double> > ISF::PunchThroughClassifier::computeInputs(const ISF::ISFParticle &isfp, const TFCSSimulationState& simulstate) {
162
163 //calculate inputs for NN
164
165 std::map<std::string, std::map<std::string, double> > networkInputs;
166
167 //add initial particle and total energy variables
168 networkInputs["node_0"] = {
169 {"variable_0", isfp.momentum().mag() },
170 {"variable_1", std::abs(isfp.position().eta()) },
171 {"variable_2", isfp.position().phi() },
172 {"variable_3", simulstate.E()},
173 };
174
175 //add energy fraction variables
176 for (unsigned int i = 0; i < 24; i++){
177 networkInputs["node_0"].insert({"variable_" + std::to_string(i + 4), simulstate.Efrac(i)});
178 }
179
180 return networkInputs;
181}
182
183std::map<std::string, std::map<std::string, double> > ISF::PunchThroughClassifier::scaleInputs(std::map<std::string, std::map<std::string, double> >& inputs) const{
184
185 //apply MinMaxScaler to network inputs
186
187 for (auto& var : inputs["node_0"]) {
188
189 double x_std;
190 if(m_scalerMaxMap.at(var.first) != m_scalerMinMap.at(var.first)){
191 x_std = (var.second - m_scalerMinMap.at(var.first)) / (m_scalerMaxMap.at(var.first) - m_scalerMinMap.at(var.first));
192 }
193 else{
194 x_std = (var.second - m_scalerMinMap.at(var.first));
195 }
196 var.second = x_std * (m_scalerMax - m_scalerMin) + m_scalerMin;
197 }
198
199 return inputs;
200}
201
202double ISF::PunchThroughClassifier::calibrateOutput(double& networkOutput) const {
203
204 //calibrate output of network using isotonic regressor model
205
206 //if network output is outside of the range of isotonic regressor then return min and max values
207 if (networkOutput < m_calibrationMin){
208 return m_calibrationMin;
209 }
210 else if (networkOutput > m_calibrationMax){
211 return m_calibrationMax;
212 }
213
214 //otherwise find neighbouring points in isotonic regressor
215 auto upper = m_calibrationMap.upper_bound(networkOutput);
216 auto lower = upper--;
217
218 //Perform linear interpolation between points
219 double m = (upper->second - lower->second)/(upper->first - lower->first);
220 double c = lower->second - m * lower->first;
221 double calibrated = m * networkOutput + c;
222
223 return calibrated;
224}
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_VERBOSE(x)
int upper(int c)
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
#define min(a, b)
Definition cfImp.cxx:40
#define max(a, b)
Definition cfImp.cxx:41
The generic ISF particle definition.
Definition ISFParticle.h:42
const Amg::Vector3D & momentum() const
The current momentum vector of the ISFParticle.
const Amg::Vector3D & position() const
The current position of the ISFParticle.
PunchThroughClassifier(const std::string &, const std::string &, const IInterface *)
Constructor.
std::map< std::string, std::map< std::string, double > > scaleInputs(std::map< std::string, std::map< std::string, double > > &inputs) const
scale NN inputs using MinMaxScaler
virtual double computePunchThroughProbability(const ISF::ISFParticle &isfp, const TFCSSimulationState &simulstate) const override
interface method to return probability prediction of punch through
StatusCode initializeScaler(const std::string &scalerConfigFile)
input variable MinMaxScaler initialize method
std::map< std::string, double > m_scalerMinMap
double m_scalerMin
input variable MinMaxScaler members
std::map< double, double > m_calibrationMap
virtual StatusCode initialize() override final
AlgTool initialize method.
StatusCode initializeNetwork(const std::string &networkConfigFile)
neural network initialize method
double calibrateOutput(double &networkOutput) const
calibrate NN output using isotonic regressor
static std::map< std::string, std::map< std::string, double > > computeInputs(const ISF::ISFParticle &isfp, const TFCSSimulationState &simulstate)
calculate NN inputs based on isfp and simulstate
std::map< std::string, double > m_scalerMaxMap
StatusCode initializeCalibrator(const std::string &calibratorConfigFile)
isotonic regressor calibrator initialize method
std::unique_ptr< lwt::LightweightGraph > m_graph
NN graph.
double Efrac(int sample) const
T GetXmlAttr(xmlNodePtr node, const char *attrName, const T &defaultValue=T{}) noexcept(std::is_arithmetic_v< T >)