ATLAS Offline Software
Loading...
Searching...
No Matches
TFCSGANEtaSlice.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
6// TFCSGANEtaSlice.cxx, (c) ATLAS Detector software //
8
9// class header include
11
13
14#include "CLHEP/Random/RandGauss.h"
15
16#include "TFitResult.h"
17#include "TFile.h"
18#include "TTree.h"
19#include "TH1D.h"
20
21#include <iostream>
22#include <fstream>
23#include <string>
24#include <sstream>
25#include <cmath>
26
27
29
30TFCSGANEtaSlice::TFCSGANEtaSlice(int pid, int etaMin, int etaMax,
31 const TFCSGANXMLParameters &param)
32 : m_pid(pid), m_etaMin(etaMin), m_etaMax(etaMax), m_param(param) {}
33
35 // Deleting a nullptr is a noop
36 delete m_gan_all;
37 delete m_gan_low;
38 delete m_gan_high;
39}
40
42 if (m_net_all != nullptr)
43 return m_net_all.get();
44 return m_gan_all;
45}
47 if (m_net_low != nullptr)
48 return m_net_low.get();
49 return m_gan_low;
50}
52 if (m_net_high != nullptr)
53 return m_net_high.get();
54 return m_gan_high;
55}
56
58 if (m_pid == 211 || m_pid == 2212) {
59 if (GetNetAll() == nullptr) {
60 return false;
61 }
62 } else {
63 if (GetNetHigh() == nullptr || GetNetLow() == nullptr) {
64 return false;
65 }
66 }
67 return true;
68}
69
// NOTE(review): the signature of this function (original line 70) is not
// visible in this extract; from the body it returns bool.
  // Creates the networks this slice needs with TFCSNetworkFactory::create,
  // from files named
  //   <m_param.GetInputFolder()>/neural_net_<pid>_eta_<etaMin>_<etaMax>_<tag>.*
  //   pid 211   : tag "All"        -> m_net_all
  //   pid 2212  : tag "High10"     -> m_net_all
  //   otherwise : tag "High12"     -> m_net_high, and
  //               tag "UltraLow12" -> m_net_low
  // Returns false if any of these create() calls returned nullptr, true
  // otherwise; a failed load does not stop the remaining one.
  //
  // NOTE(review): the trailing ".*" is passed verbatim to
  // TFCSNetworkFactory::create; presumably the factory resolves the actual
  // file extension - confirm.
  // NOTE(review): pids are compared exactly here (no std::abs), whereas
  // GetNetworkOutputs uses std::abs(m_pid) when choosing the energy range;
  // confirm negative pids never reach this slice.
  // Now load new data
  std::string inputFileName;

  // NOTE(review): original lines 74-75 are not visible in this extract.

  bool success = true;

  if (m_pid == 211) {
    inputFileName = m_param.GetInputFolder() + "/neural_net_" +
                    std::to_string(m_pid) + "_eta_" + std::to_string(m_etaMin) +
                    "_" + std::to_string(m_etaMax) + "_All.*";
    ATH_MSG_DEBUG("Gan input file name " << inputFileName);
    m_net_all = TFCSNetworkFactory::create(std::move(inputFileName));
    if (m_net_all == nullptr)
      success = false;
  } else if (m_pid == 2212) {
    inputFileName = m_param.GetInputFolder() + "/neural_net_" +
                    std::to_string(m_pid) + "_eta_" + std::to_string(m_etaMin) +
                    "_" + std::to_string(m_etaMax) + "_High10.*";
    ATH_MSG_DEBUG("Gan input file name " << inputFileName);
    m_net_all = TFCSNetworkFactory::create(std::move(inputFileName));
    if (m_net_all == nullptr)
      success = false;
  } else {
    // Two networks, split by momentum (see GetNetworkOutputs).
    inputFileName = m_param.GetInputFolder() + "/neural_net_" +
                    std::to_string(m_pid) + "_eta_" + std::to_string(m_etaMin) +
                    "_" + std::to_string(m_etaMax) + "_High12.*";
    ATH_MSG_DEBUG("Gan input file name " << inputFileName);
    m_net_high = TFCSNetworkFactory::create(inputFileName);
    if (m_net_high == nullptr)
      success = false;

    // inputFileName is reassigned here, so reusing the variable is safe.
    inputFileName = m_param.GetInputFolder() + "/neural_net_" +
                    std::to_string(m_pid) + "_eta_" + std::to_string(m_etaMin) +
                    "_" + std::to_string(m_etaMax) + "_UltraLow12.*";
    m_net_low = TFCSNetworkFactory::create(std::move(inputFileName));
    if (m_net_low == nullptr)
      success = false;
  }
  return success;
}
113
115 std::string rootFileName = m_param.GetInputFolder() + "/rootFiles/pid" +
116 std::to_string(m_pid) + "_E1048576_eta_" +
117 std::to_string(m_etaMin) + "_" +
118 std::to_string(m_etaMin + 5) + ".root";
119 ATH_MSG_DEBUG("Opening file " << rootFileName);
120 TFile *file = TFile::Open(rootFileName.c_str(), "read");
121 for (int layer : m_param.GetRelevantLayers()) {
122 ATH_MSG_DEBUG("Layer " << layer);
123 TFCSGANXMLParameters::Binning binsInLayers = m_param.GetBinning();
124 TH2D *h2 = &binsInLayers[layer];
125
126 std::string histoName = "r" + std::to_string(layer) + "w";
127 TH1D *h1 = (TH1D *)file->Get(histoName.c_str());
128 if (std::isnan(h1->Integral())) {
129 histoName = "r" + std::to_string(layer);
130 h1 = (TH1D *)file->Get(histoName.c_str());
131 }
132
133 TAxis *x = (TAxis *)h2->GetXaxis();
134 for (int ix = 1; ix <= h2->GetNbinsX(); ++ix) {
135 ATH_MSG_DEBUG(ix);
136 h1->GetXaxis()->SetRangeUser(x->GetBinLowEdge(ix), x->GetBinUpEdge(ix));
137
138 double result = 0;
139 if (h1->Integral() > 0 && h1->GetNbinsX() > 2) {
140 TFitResultPtr res(0);
141
142 res = h1->Fit("expo", "SQ");
143 if (res >= 0 && !std::isnan(res->Parameter(0))) {
144 result = res->Parameter(1);
145 }
146 }
147 m_allFitResults[layer].push_back(result);
148 }
149 }
150 ATH_MSG_DEBUG("Done initialisaing fits");
151}
152
154 std::string rootFileName = m_param.GetInputFolder() + "/rootFiles/pid" +
155 std::to_string(m_pid) + "_E65536_eta_" +
156 std::to_string(m_etaMin) + "_" +
157 std::to_string(m_etaMin + 5) + "_validation.root";
158 ATH_MSG_DEBUG("Opening file " << rootFileName);
159 TFile *file = TFile::Open(rootFileName.c_str(), "read");
160 for (int layer : m_param.GetRelevantLayers()) {
161 std::string branchName = "extrapWeight_" + std::to_string(layer);
162 TH1D *h = new TH1D("h", "h", 100, 0.01, 1);
163 TTree *tree = (TTree *)file->Get("rootTree");
164 std::string command = branchName + ">>h";
165 tree->Draw(command.c_str());
166 m_extrapolatorWeights[layer] = h->GetMean();
167 ATH_MSG_DEBUG("Extrapolation: layer " << layer << " mean "
168 << m_extrapolatorWeights[layer]);
169 }
170}
171
// NOTE(review): the start of this signature (original lines 172-173) is not
// visible in this extract; the page's declaration list gives it as
// `NetworkOutputs GetNetworkOutputs(const TFCSTruthState *truth,
// const TFCSExtrapolationState *extrapol, TFCSSimulationState simulstate) const`.
    const TFCSExtrapolationState *extrapol,
    TFCSSimulationState simulstate) const {
  // Evaluates this slice's GAN for one particle and returns the network
  // outputs.
  //
  // Two input nodes are built:
  //   "Noise"  : variable_0 .. variable_<GetLatentSpaceSize()-1>, each drawn
  //              with CLHEP::RandGauss::shoot(engine, 0.5, 0.5), i.e. a
  //              Gaussian with mean 0.5 and standard deviation 0.5 (despite
  //              the name randUniformZ).
  //   "mycond" : variable_0 = log(Ekin / Ekin_min) / log(Ekin_max / Ekin_min),
  //              the logarithmic position of the truth kinetic energy between
  //              the bounds derived from the momentum range below; for GAN
  //              version >= 2 also variable_1 = 0.
  // Network choice: GetNetAll() for GAN version 1 and for pid 211/2212,
  // otherwise GetNetHigh() when truth->P() > 4096 and GetNetLow() below.
  //
  // NOTE(review): the chosen network pointer is dereferenced without a null
  // check; callers presumably verify IsGanCorrectlyLoaded() first.
  // NOTE(review): the network choice compares m_pid == 211 / 2212 exactly,
  // while the exponent choice below uses std::abs(m_pid).
  // NOTE(review): simulstate is taken by value, so every call copies it;
  // changing that requires a matching header change.
  double randUniformZ = 0.;
  NetworkInputs inputs;

  // Exponents of the momentum range [2^minExp, 2^maxExp]. For any pid not
  // handled below both stay 0, so Ekin_min == Ekin_max and the label below
  // divides by log(1) == 0.
  int maxExp = 0, minExp = 0;
  if (m_pid == 22 || std::abs(m_pid) == 11) {
    if (truth->P() >
        4096) { // This is the momentum, not the energy, because the split is
                // based on the samples which are produced with the momentum
      maxExp = 22;
      minExp = 12;
    } else {
      maxExp = 12;
      minExp = 6;
    }
  } else if (std::abs(m_pid) == 211) {
    maxExp = 22;
    minExp = 8;
  } else if (std::abs(m_pid) == 2212) {
    maxExp = 22;
    minExp = 10;
  }

  int p_min = std::pow(2, minExp);
  int p_max = std::pow(2, maxExp);
  // Keep min and max without mass offset as we do not train on antiparticles
  // Convert the momentum bounds to kinetic energy: sqrt(p^2 + m^2) - m.
  double Ekin_min =
      std::sqrt(std::pow(p_min, 2) + std::pow(truth->M(), 2)) - truth->M();
  double Ekin_max =
      std::sqrt(std::pow(p_max, 2) + std::pow(truth->M(), 2)) - truth->M();

  // Latent-space noise, one named variable per latent dimension.
  for (int i = 0; i < m_param.GetLatentSpaceSize(); i++) {
    randUniformZ = CLHEP::RandGauss::shoot(simulstate.randomEngine(), 0.5, 0.5);
    inputs["Noise"].insert(std::pair<std::string, double>(
        "variable_" + std::to_string(i), randUniformZ));
  }

  // double e = log(truth->Ekin()/Ekin_min)/log(Ekin_max/Ekin_min) ;
  // Could be uncommented , but would need the line above too
  // ATH_MSG_DEBUG( "Check label: " << e <<" Ekin:" << truth->Ekin() <<" p:" <<
  // truth->P() <<" mass:" << truth->M() <<" Ekin_off:" <<
  // truth->Ekin_off() << " Ekin_min:"<<Ekin_min<<"
  // Ekin_max:"<<Ekin_max);
  // inputs["mycond"].insert ( std::pair<std::string,double>("variable_0",
  // truth->Ekin()/(std::pow(2,maxExp))) ); //Old conditioning using linear
  // interpolation, now use logaritminc interpolation
  inputs["mycond"].insert(std::pair<std::string, double>(
      "variable_0", log(truth->Ekin() / Ekin_min) / log(Ekin_max / Ekin_min)));

  if (m_param.GetGANVersion() >= 2) {
    if (false) { // conditioning on eta, should only be needed in transition
                 // regions and added only to the GANs that use it, for now all
                 // GANs have 3 conditioning inputs so filling zeros
      inputs["mycond"].insert(std::pair<std::string, double>(
          "variable_1", std::abs(extrapol->IDCaloBoundary_eta())));
    } else {
      inputs["mycond"].insert(std::pair<std::string, double>("variable_1", 0));
    }
  }

  // NOTE(review): original line 235 is not visible in this extract; it
  // presumably declares `outputs`, which is assigned and returned below.
  if (m_param.GetGANVersion() == 1 || m_pid == 211 || m_pid == 2212) {
    outputs = GetNetAll()->compute(inputs);
  } else {
    if (truth->P() >
        4096) { // This is the momentum, not the energy, because the split is
                // based on the samples which are produced with the momentum
      ATH_MSG_DEBUG("Computing outputs given inputs for high");
      outputs = GetNetHigh()->compute(inputs);
    } else {
      outputs = GetNetLow()->compute(inputs);
    }
  }
  ATH_MSG_DEBUG("Start Network inputs ~~~~~~~~");
  // NOTE(review): original line 249 is not visible in this extract
  // (presumably it logs `inputs`).
  ATH_MSG_DEBUG("End Network inputs ~~~~~~~~");
  ATH_MSG_DEBUG("Start Network outputs ~~~~~~~~");
  // NOTE(review): original line 252 is not visible in this extract
  // (presumably it logs `outputs`).
  ATH_MSG_DEBUG("End Network outputs ~~~~~~~~");
  return outputs;
}
256
// NOTE(review): the signature of this function (original line 257) is not
// visible in this extract.
  // Logs this slice's particle id and eta bounds at INFO level, then calls
  // m_param.Print() for the shared GAN parameters.
  ATH_MSG_INFO("LWTNN Handler parameters");
  ATH_MSG_INFO(" pid: " << m_pid);
  ATH_MSG_INFO(" etaMin:" << m_etaMin);
  ATH_MSG_INFO(" etaMax: " << m_etaMax);
  m_param.Print();
}
#define ATH_MSG_INFO(x)
#define ATH_MSG_DEBUG(x)
std::pair< std::vector< unsigned int >, bool > res
#define x
Header file for AthHistogramAlgorithm.
VNetworkBase * GetNetAll() const
TFCSGANLWTNNHandler * m_gan_all
TFCSGANLWTNNHandler * m_gan_low
TFCSGANLWTNNHandler * m_gan_high
std::unique_ptr< VNetworkBase > m_net_high
std::unique_ptr< VNetworkBase > m_net_all
void CalculateMeanPointFromDistributionOfR()
virtual ~TFCSGANEtaSlice()
std::unique_ptr< VNetworkBase > m_net_low
ExtrapolatorWeights m_extrapolatorWeights
NetworkOutputs GetNetworkOutputs(const TFCSTruthState *truth, const TFCSExtrapolationState *extrapol, TFCSSimulationState simulstate) const
bool IsGanCorrectlyLoaded() const
VNetworkBase * GetNetLow() const
std::map< std::string, std::map< std::string, double > > NetworkInputs
VNetworkBase * GetNetHigh() const
void ExtractExtrapolatorMeansFromInputs()
FitResultsPerLayer m_allFitResults
TFCSGANXMLParameters m_param
std::map< int, TH2D > Binning
static std::unique_ptr< VNetworkBase > create(std::string input)
Given a string, make a network.
CLHEP::HepRandomEngine * randomEngine()
double Ekin() const
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration.
virtual NetworkOutputs compute(NetworkInputs const &inputs) const =0
Function to pass values to the network.
static std::string representNetworkOutputs(NetworkOutputs const &outputs, int maxValues=3)
String representation of network outputs.
std::map< std::string, double > NetworkOutputs
Format for network outputs.
static std::string representNetworkInputs(NetworkInputs const &inputs, int maxValues=3)
String representation of network inputs.
TChain * tree
TFile * file