14#include "lwtnn/parse_json.hh"
17#include <libxml/xmlmemory.h>
18#include <libxml/parser.h>
19#include <libxml/tree.h>
20#include <libxml/xmlreader.h>
21#include <libxml/xpath.h>
22#include <libxml/xpathInternals.h>
25 : base_class(
type, name, parent) {
34 std::map<std::string, std::map<std::string, double> > networkInputs =
computeInputs(isfp, simulstate);
38 std::map<std::string, double> networkOutputs =
m_graph->compute(networkInputs);
42 return calibratedOutput;
53 ATH_MSG_ERROR(
"[ punchthroughclassifier ] unable to load punchthroughclassifier input scaler");
59 ATH_MSG_ERROR(
"[ punchthroughclassifier ] unable to load punchthroughclassifier network");
65 ATH_MSG_ERROR(
"[ punchthroughclassifier ] unable to load punchthroughclassifier calibrator");
68 return StatusCode::SUCCESS;
75 xmlDocPtr doc = xmlParseFile( scalerConfigFile.c_str() );
77 ATH_MSG_INFO(
"[ punchthroughclassifier ] Loading scaler: " << scalerConfigFile);
79 for( xmlNodePtr nodeRoot = doc->children; nodeRoot !=
nullptr; nodeRoot = nodeRoot->next) {
81 if (xmlStrEqual( nodeRoot->name, BAD_CAST
"Transformations" )) {
82 for( xmlNodePtr nodeTransform = nodeRoot->children; nodeTransform !=
nullptr; nodeTransform = nodeTransform->next ) {
85 if (xmlStrEqual( nodeTransform->name, BAD_CAST
"ScalerValues" )) {
86 m_scalerMin = atof( (
const char*) xmlGetProp( nodeTransform, BAD_CAST
"min" ) );
87 m_scalerMax = atof( (
const char*) xmlGetProp( nodeTransform, BAD_CAST
"max" ) );
91 if (xmlStrEqual( nodeTransform->name, BAD_CAST
"VarScales" )) {
92 std::string name = (
const char*) xmlGetProp( nodeTransform, BAD_CAST
"name" );
93 double min = atof( (
const char*) xmlGetProp( nodeTransform, BAD_CAST
"min" ) );
94 double max = atof( (
const char*) xmlGetProp( nodeTransform, BAD_CAST
"max" ) );
102 return StatusCode::SUCCESS;
107 ATH_MSG_INFO(
"[ punchthroughclassifier ] Loading classifier: " << networkConfigFile);
109 std::ifstream input(networkConfigFile);
111 ATH_MSG_ERROR(
"Could not find json file " << networkConfigFile );
112 return StatusCode::FAILURE;
115 m_graph = std::make_unique<lwt::LightweightGraph>(lwt::parse_json_graph(input));
117 ATH_MSG_ERROR(
"Could not parse graph json file " << networkConfigFile );
118 return StatusCode::FAILURE;
122 return StatusCode::SUCCESS;
129 ATH_MSG_INFO(
"[ punchthroughclassifier ] Loading calibrator: " << calibratorConfigFile);
131 xmlDocPtr doc = xmlParseFile( calibratorConfigFile.c_str() );
133 for( xmlNodePtr nodeRoot = doc->children; nodeRoot !=
nullptr; nodeRoot = nodeRoot->next) {
135 if (xmlStrEqual( nodeRoot->name, BAD_CAST
"Transformations" )) {
136 for( xmlNodePtr nodeTransform = nodeRoot->children; nodeTransform !=
nullptr; nodeTransform = nodeTransform->next ) {
139 if (xmlStrEqual( nodeTransform->name, BAD_CAST
"LimitValues" )) {
140 m_calibrationMin = atof( (
const char*) xmlGetProp( nodeTransform, BAD_CAST
"min" ) );
141 m_calibrationMax = atof( (
const char*) xmlGetProp( nodeTransform, BAD_CAST
"max" ) );
145 if (xmlStrEqual( nodeTransform->name, BAD_CAST
"LinearNorm" )) {
146 double orig = atof( (
const char*) xmlGetProp( nodeTransform, BAD_CAST
"orig" ) );
147 double norm = atof( (
const char*) xmlGetProp( nodeTransform, BAD_CAST
"norm" ) );
154 return StatusCode::SUCCESS;
161 std::map<std::string, std::map<std::string, double> > networkInputs;
164 networkInputs[
"node_0"] = {
165 {
"variable_0", isfp.
momentum().mag() },
166 {
"variable_1", std::abs(isfp.
position().eta()) },
167 {
"variable_2", isfp.
position().phi() },
168 {
"variable_3", simulstate.
E()},
172 for (
unsigned int i = 0; i < 24; i++){
173 networkInputs[
"node_0"].insert({
"variable_" + std::to_string(i + 4), simulstate.
Efrac(i)});
176 return networkInputs;
183 for (
auto& var : inputs[
"node_0"]) {
212 auto lower =
upper--;
215 double m = (
upper->second - lower->second)/(
upper->first - lower->first);
216 double c = lower->second - m * lower->first;
217 double calibrated = m * networkOutput + c;
#define ATH_MSG_VERBOSE(x)
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
The generic ISF particle definition.
const Amg::Vector3D & momentum() const
The current momentum vector of the ISFParticle.
const Amg::Vector3D & position() const
The current position of the ISFParticle.
PunchThroughClassifier(const std::string &, const std::string &, const IInterface *)
Constructor.
std::map< std::string, std::map< std::string, double > > scaleInputs(std::map< std::string, std::map< std::string, double > > &inputs) const
scale NN inputs using MinMaxScaler
std::string m_calibratorConfigFileName
virtual double computePunchThroughProbability(const ISF::ISFParticle &isfp, const TFCSSimulationState &simulstate) const override
interface method to return probability prediction of punch through
StatusCode initializeScaler(const std::string &scalerConfigFile)
input variable MinMaxScaler initialize method
std::map< std::string, double > m_scalerMinMap
double m_scalerMin
input variable MinMaxScaler members
std::map< double, double > m_calibrationMap
virtual StatusCode initialize() override final
AlgTool initialize method.
StatusCode initializeNetwork(const std::string &networkConfigFile)
neural network initialize method
double calibrateOutput(double &networkOutput) const
calibrate NN output using isotonic regressor
std::string m_scalerConfigFileName
static std::map< std::string, std::map< std::string, double > > computeInputs(const ISF::ISFParticle &isfp, const TFCSSimulationState &simulstate)
calculate NN inputs based on isfp and simulstate
std::string m_networkConfigFileName
std::map< std::string, double > m_scalerMaxMap
StatusCode initializeCalibrator(const std::string &calibratorConfigFile)
isotonic regressor calibrator initialize method
std::unique_ptr< lwt::LightweightGraph > m_graph
NN graph.
double Efrac(int sample) const