14#include "lwtnn/parse_json.hh"
21 : base_class(
type, name, parent) {
30 std::map<std::string, std::map<std::string, double> > networkInputs =
computeInputs(isfp, simulstate);
34 std::map<std::string, double> networkOutputs =
m_graph->compute(networkInputs);
38 return calibratedOutput;
49 ATH_MSG_ERROR(
"[ punchthroughclassifier ] unable to load punchthroughclassifier input scaler");
55 ATH_MSG_ERROR(
"[ punchthroughclassifier ] unable to load punchthroughclassifier network");
61 ATH_MSG_ERROR(
"[ punchthroughclassifier ] unable to load punchthroughclassifier calibrator");
64 return StatusCode::SUCCESS;
71 std::unique_ptr<XMLCoreNode> doc = p.parse (scalerConfigFile);
73 ATH_MSG_DEBUG(
"[ punchthroughclassifier ] Loading scaler: " << scalerConfigFile);
77 if (
node->get_name() ==
"ScalerValues") {
81 else if (
node->get_name() ==
"VarScales") {
82 std::string name =
node->get_attrib (
"name");
88 return StatusCode::SUCCESS;
93 ATH_MSG_INFO(
"[ punchthroughclassifier ] Loading classifier: " << networkConfigFile);
95 std::ifstream input(networkConfigFile);
97 ATH_MSG_ERROR(
"Could not find json file " << networkConfigFile );
98 return StatusCode::FAILURE;
101 m_graph = std::make_unique<lwt::LightweightGraph>(lwt::parse_json_graph(input));
103 ATH_MSG_ERROR(
"Could not parse graph json file " << networkConfigFile );
104 return StatusCode::FAILURE;
108 return StatusCode::SUCCESS;
115 std::unique_ptr<XMLCoreNode> doc = p.parse (calibratorConfigFile);
118 ATH_MSG_DEBUG(
"[ punchthroughclassifier ] Loading calibrator: " << calibratorConfigFile);
120 for (
const XMLCoreNode*
node : doc->get_children (
"Transformations/*"))
122 if (
node->get_name() ==
"LimitValues") {
126 else if (
node->get_name() ==
"LinearNorm") {
127 double orig =
node->get_double_attrib (
"orig");
128 double norm =
node->get_double_attrib (
"norm");
133 return StatusCode::SUCCESS;
140 std::map<std::string, std::map<std::string, double> > networkInputs;
143 networkInputs[
"node_0"] = {
144 {
"variable_0", isfp.
momentum().mag() },
145 {
"variable_1", std::abs(isfp.
position().eta()) },
146 {
"variable_2", isfp.
position().phi() },
147 {
"variable_3", simulstate.
E()},
151 for (
unsigned int i = 0; i < 24; i++){
152 networkInputs[
"node_0"].insert({
"variable_" + std::to_string(i + 4), simulstate.
Efrac(i)});
155 return networkInputs;
162 for (
auto& var : inputs[
"node_0"]) {
191 auto lower =
upper--;
194 double m = (
upper->second - lower->second)/(
upper->first - lower->first);
195 double c = lower->second - m * lower->first;
196 double calibrated = m * networkOutput + c;
#define ATH_MSG_VERBOSE(x)
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Simple DOM-like node structure to hold the result of XML parsing.
The generic ISF particle definition,.
const Amg::Vector3D & momentum() const
The current momentum vector of the ISFParticle.
const Amg::Vector3D & position() const
The current position of the ISFParticle.
PunchThroughClassifier(const std::string &, const std::string &, const IInterface *)
Constructor.
std::map< std::string, std::map< std::string, double > > scaleInputs(std::map< std::string, std::map< std::string, double > > &inputs) const
scale NN inputs using MinMaxScaler
std::string m_calibratorConfigFileName
virtual double computePunchThroughProbability(const ISF::ISFParticle &isfp, const TFCSSimulationState &simulstate) const override
interface method to return probability prediction of punch through
StatusCode initializeScaler(const std::string &scalerConfigFile)
input variable MinMaxScaler initialize method
std::map< std::string, double > m_scalerMinMap
double m_scalerMin
input variable MinMaxScaler members
std::map< double, double > m_calibrationMap
virtual StatusCode initialize() override final
AlgTool initialize method.
StatusCode initializeNetwork(const std::string &networkConfigFile)
neural network initialize method
double calibrateOutput(double &networkOutput) const
calibrate NN output using isotonic regressor
std::string m_scalerConfigFileName
static std::map< std::string, std::map< std::string, double > > computeInputs(const ISF::ISFParticle &isfp, const TFCSSimulationState &simulstate)
calcalate NN inputs based on isfp and simulstate
std::string m_networkConfigFileName
std::map< std::string, double > m_scalerMaxMap
StatusCode initializeCalibrator(const std::string &calibratorConfigFile)
isotonic regressor calibrator initialize method
std::unique_ptr< lwt::LightweightGraph > m_graph
NN graph.
double Efrac(int sample) const
Simple DOM-like node structure to hold the result of XML parsing.