16#include "G4FastTrack.hh"
17#include "G4FastStep.hh"
20#include "lwtnn/parse_json.hh"
23#include <libxml/xmlmemory.h>
24#include <libxml/parser.h>
25#include <libxml/tree.h>
26#include <libxml/xmlreader.h>
27#include <libxml/xpath.h>
28#include <libxml/xpathInternals.h>
35 : base_class(
type, name, parent) {
40 ATH_MSG_DEBUG(
"[ punchthroughclassifier ] Initializing PunchThroughG4Classifier" );
51 return StatusCode::SUCCESS;
56 ATH_MSG_DEBUG(
"[punchthroughclassifier] finalize() successful" );
58 return StatusCode::SUCCESS;
66 doc = xmlParseFile( scalerConfigFile.c_str() );
68 ATH_MSG_DEBUG(
"[ punchthroughclassifier ] Loading scaler: " << scalerConfigFile);
70 for( xmlNodePtr nodeRoot = doc->children; nodeRoot !=
nullptr; nodeRoot = nodeRoot->next) {
72 if (xmlStrEqual( nodeRoot->name, BAD_CAST
"Transformations" )) {
73 for( xmlNodePtr nodeTransform = nodeRoot->children; nodeTransform !=
nullptr; nodeTransform = nodeTransform->next ) {
76 if (xmlStrEqual( nodeTransform->name, BAD_CAST
"ScalerValues" )) {
82 if (xmlStrEqual( nodeTransform->name, BAD_CAST
"VarScales" )) {
83 std::string name =
"";
101 return StatusCode::SUCCESS;
106 ATH_MSG_DEBUG(
"[ punchthroughclassifier ] Loading classifier: " << networkConfigFile);
108 std::ifstream input(networkConfigFile);
110 ATH_MSG_ERROR(
"Could not find json file " << networkConfigFile );
111 return StatusCode::FAILURE;
114 m_graph = std::make_unique<lwt::LightweightGraph>(lwt::parse_json_graph(input));
116 ATH_MSG_ERROR(
"Could not parse graph json file " << networkConfigFile );
117 return StatusCode::FAILURE;
121 return StatusCode::SUCCESS;
129 ATH_MSG_DEBUG(
"[ punchthroughclassifier ] Loading calibrator: " << calibratorConfigFile);
131 doc = xmlParseFile( calibratorConfigFile.c_str() );
133 for( xmlNodePtr nodeRoot = doc->children; nodeRoot !=
nullptr; nodeRoot = nodeRoot->next) {
135 if (xmlStrEqual( nodeRoot->name, BAD_CAST
"Transformations" )) {
136 for( xmlNodePtr nodeTransform = nodeRoot->children; nodeTransform !=
nullptr; nodeTransform = nodeTransform->next ) {
139 if (xmlStrEqual( nodeTransform->name, BAD_CAST
"LimitValues" )) {
145 if (xmlStrEqual( nodeTransform->name, BAD_CAST
"LinearNorm" )) {
161 return StatusCode::SUCCESS;
166 std::map<std::string, std::map<std::string, double> > networkInputs =
computeInputs(fastTrack, simE, simEfrac);
170 std::map<std::string, double> networkOutputs =
m_graph->compute(networkInputs);
174 return calibratedOutput;
181 std::map<std::string, std::map<std::string, double> > networkInputs;
184 networkInputs[
"node_0"] = {
185 {
"variable_0", fastTrack.GetPrimaryTrack()->GetMomentum().mag() },
186 {
"variable_1", std::abs(fastTrack.GetPrimaryTrack()->GetPosition().eta()) },
187 {
"variable_2", fastTrack.GetPrimaryTrack()->GetPosition().phi() },
188 {
"variable_3", simE},
192 for (
unsigned int i = 0; i < simEfrac.size(); i++) {
193 networkInputs[
"node_0"].insert({
"variable_" + std::to_string(i + 4), simEfrac[i]});
196 return networkInputs;
203 for (
auto& var : inputs[
"node_0"]) {
232 auto lower =
upper--;
235 double m = (
upper->second - lower->second)/(
upper->first - lower->first);
236 double c = lower->second - m * lower->first;
237 double calibrated = m * networkOutput + c;
#define ATH_CHECK
Evaluate an expression and check for errors.
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
std::unique_ptr< lwt::LightweightGraph > m_graph
NN graph.
std::map< std::string, double > m_scalerMaxMap
virtual StatusCode initialize() override
AlgTool initialize method.
double m_scalerMin
input variable MinMaxScaler members
StringProperty m_scalerConfigFileName
std::map< std::string, std::map< std::string, double > > scaleInputs(std::map< std::string, std::map< std::string, double > > &inputs) const
scale NN inputs using MinMaxScaler
StringProperty m_networkConfigFileName
std::map< double, double > m_calibrationMap
PunchThroughG4Classifier(const std::string &, const std::string &, const IInterface *)
Constructor.
std::map< std::string, double > m_scalerMinMap
StatusCode initializeCalibrator(const std::string &calibratorConfigFile)
isotonic regressor calibrator initialize method
virtual StatusCode finalize() override
AlgTool finalize method.
StatusCode initializeScaler(const std::string &scalerConfigFile)
input variable MinMaxScaler initialize method
StatusCode initializeNetwork(const std::string &networkConfigFile)
neural network initialize method
double calibrateOutput(double &networkOutput) const
calibrate NN output using isotonic regressor
StringProperty m_calibratorConfigFileName
static std::map< std::string, std::map< std::string, double > > computeInputs(const G4FastTrack &fastTrack, const double simE, const std::vector< double > &simEfrac)
calculate NN inputs based on G4FastTrack and simulstate
virtual double computePunchThroughProbability(const G4FastTrack &fastTrack, const double simE, const std::vector< double > &simEfrac) const override
interface method to return probability prediction of punch through
void GetXmlAttrIfThere(xmlNodePtr node, const char *attrName, T &value)