ATLAS Offline Software
Loading...
Searching...
No Matches
TFCSGANEtaSlice.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
6// TFCSGANEtaSlice.cxx, (c) ATLAS Detector software //
8
9// class header include
11
13
14#include "CLHEP/Random/RandGauss.h"
15
16#include "TFitResult.h"
17#include "TFile.h"
18#include "TTree.h"
19#include "TH1D.h"
20
21#include <iostream>
22#include <string>
23#include <cmath>
24#include <format>
25#include <stdexcept>
26
28
29TFCSGANEtaSlice::TFCSGANEtaSlice(int pid, int etaMin, int etaMax,
30 const TFCSGANXMLParameters &param)
31 : m_pid(pid), m_etaMin(etaMin), m_etaMax(etaMax), m_param(param) {}
32
34 // Deleting a nullptr is a noop
35 delete m_gan_all;
36 delete m_gan_low;
37 delete m_gan_high;
38}
39
41 if (m_net_all != nullptr)
42 return m_net_all.get();
43 return m_gan_all;
44}
46 if (m_net_low != nullptr)
47 return m_net_low.get();
48 return m_gan_low;
49}
51 if (m_net_high != nullptr)
52 return m_net_high.get();
53 return m_gan_high;
54}
55
57 if (m_pid == 211 || m_pid == 2212) {
58 if (GetNetAll() == nullptr) {
59 return false;
60 }
61 } else {
62 if (GetNetHigh() == nullptr || GetNetLow() == nullptr) {
63 return false;
64 }
65 }
66 return true;
67}
68
// Creates the network(s) for this eta slice with TFCSNetworkFactory::create.
// The file names are built from m_param.GetInputFolder(), m_pid and the eta
// bounds (see the format strings below). Returns false if any network that
// this pid needs could not be created.
// NOTE(review): in this listing the 211 and 2212 format calls show no
// m_etaMin argument (source lines 82 and 92 are absent), yet each format
// string has four placeholders — confirm those lines pass m_etaMin.
70 // Now load new data
71 std::string inputFileName;
72
75
76 bool success = true;
77
// pid 211 and pid 2212 each use a single network stored in m_net_all
// ("_All" and "_High10" files). Every other pid uses a "High12" network
// (m_net_high) and an "UltraLow12" network (m_net_low).
// NOTE(review): m_pid is compared without std::abs here, while
// GetNetworkOutputs uses std::abs(m_pid) when choosing the energy range.
78 if (m_pid == 211) {
79 inputFileName = std::format("{}/neural_net_{}_eta_{}_{}_All.*",
80 m_param.GetInputFolder(),
81 m_pid,
83 m_etaMax);
84 ATH_MSG_DEBUG("Gan input file name " << inputFileName);
85 m_net_all = TFCSNetworkFactory::create(std::move(inputFileName));
86 if (m_net_all == nullptr)
87 success = false;
88 } else if (m_pid == 2212) {
89 inputFileName = std::format("{}/neural_net_{}_eta_{}_{}_High10.*",
90 m_param.GetInputFolder(),
91 m_pid,
93 m_etaMax);
94 ATH_MSG_DEBUG("Gan input file name " << inputFileName);
95 m_net_all = TFCSNetworkFactory::create(std::move(inputFileName));
96 if (m_net_all == nullptr)
97 success = false;
98 } else {
99 inputFileName = std::format("{}/neural_net_{}_eta_{}_{}_High12.*",
100 m_param.GetInputFolder(),
101 m_pid,
102 m_etaMin,
103 m_etaMax);
104 ATH_MSG_DEBUG("Gan input file name " << inputFileName);
// Passed by copy (not moved); inputFileName is reassigned just below.
105 m_net_high = TFCSNetworkFactory::create(inputFileName);
106 if (m_net_high == nullptr)
107 success = false;
108
109 inputFileName = std::format("{}/neural_net_{}_eta_{}_{}_UltraLow12.*",
110 m_param.GetInputFolder(),
111 m_pid,
112 m_etaMin,
113 m_etaMax);
114 m_net_low = TFCSNetworkFactory::create(std::move(inputFileName));
115 if (m_net_low == nullptr)
116 success = false;
117 }
118 return success;
119}
120
/**
 * For every relevant layer, fits an exponential ("expo") to the layer's
 * r-histogram restricted to each x-bin of that layer's binning.
 * Parameter 1 of the fit (the exponent slope) is appended to
 * m_allFitResults[layer]. 0 is appended when the slice is empty, the
 * histogram has <= 2 bins, or the fit fails.
 *
 * Input: <input folder>/rootFiles/pid<pid>_E1048576_eta_<etaMin>_<etaMin+5>.root
 *
 * @throws std::runtime_error if the ROOT file cannot be opened, or if a layer
 *         has neither a usable weighted ("r<layer>w") nor an unweighted
 *         ("r<layer>") histogram.
 */
void TFCSGANEtaSlice::CalculateMeanPointFromDistributionOfR() {
  const std::string rootFileName =
      std::format("{}/rootFiles/pid{}_E1048576_eta_{}_{}.root",
                  m_param.GetInputFolder(),
                  m_pid,
                  m_etaMin,
                  m_etaMin + 5);
  ATH_MSG_DEBUG("Opening file " << rootFileName);
  std::unique_ptr<TFile> file(TFile::Open(rootFileName.c_str(), "read"));
  if (!file || file->IsZombie()) {
    throw std::runtime_error(std::format("Failed to open or initialize ROOT file: {}", rootFileName));
  }

  // Copy the binning once rather than once per layer. operator[] below may
  // insert a default entry into this local copy, but never into m_param.
  TFCSGANXMLParameters::Binning binsInLayers = m_param.GetBinning();

  for (int layer : m_param.GetRelevantLayers()) {
    ATH_MSG_DEBUG("Layer " << layer);
    const TH2D &h2 = binsInLayers[layer];

    // Prefer the weighted histogram. Fall back to the unweighted one when the
    // weighted histogram is absent or its integral is NaN. TH1 (not TH1D) is
    // used so that other 1D histogram types are handled correctly.
    TH1 *h1 = dynamic_cast<TH1 *>(file->Get(std::format("r{}w", layer).c_str()));
    if (h1 == nullptr || std::isnan(h1->Integral())) {
      h1 = dynamic_cast<TH1 *>(file->Get(std::format("r{}", layer).c_str()));
    }
    if (h1 == nullptr) {
      throw std::runtime_error(std::format(
          "Missing histogram r{}w/r{} in ROOT file: {}", layer, layer, rootFileName));
    }

    const TAxis *xAxis = h2.GetXaxis();
    for (int ix = 1; ix <= h2.GetNbinsX(); ++ix) {
      ATH_MSG_DEBUG(ix);
      // Restrict h1 to the current bin of the layer binning before fitting.
      h1->GetXaxis()->SetRangeUser(xAxis->GetBinLowEdge(ix), xAxis->GetBinUpEdge(ix));

      double slope = 0;
      if (h1->Integral() > 0 && h1->GetNbinsX() > 2) {
        // "S": return a TFitResultPtr; "Q": quiet.
        TFitResultPtr res = h1->Fit("expo", "SQ");
        if (res >= 0 && !std::isnan(res->Parameter(0))) {
          slope = res->Parameter(1);
        }
      }
      m_allFitResults[layer].push_back(slope);
    }
  }
  ATH_MSG_DEBUG("Done initialisaing fits");
}
163
/**
 * For every relevant layer, histograms the branch extrapWeight_<layer> of
 * the "rootTree" tree (100 bins in [0.01, 1]). The histogram mean is stored
 * in m_extrapolatorWeights[layer].
 *
 * Input: <input folder>/rootFiles/pid<pid>_E65536_eta_<etaMin>_<etaMin+5>_validation.root
 *
 * @throws std::runtime_error if the ROOT file cannot be opened or does not
 *         contain a TTree named "rootTree".
 */
void TFCSGANEtaSlice::ExtractExtrapolatorMeansFromInputs() {
  const std::string rootFileName =
      std::format("{}/rootFiles/pid{}_E65536_eta_{}_{}_validation.root",
                  m_param.GetInputFolder(),
                  m_pid,
                  m_etaMin,
                  m_etaMin + 5);
  ATH_MSG_DEBUG("Opening file " << rootFileName);
  std::unique_ptr<TFile> file(TFile::Open(rootFileName.c_str(), "read"));
  if (!file || file->IsZombie()) {
    throw std::runtime_error(std::format("Failed to open or initialize ROOT file: {}", rootFileName));
  }

  // The tree is the same for every layer, so look it up (and check it) once.
  // It is owned by the file.
  TTree *tree = dynamic_cast<TTree *>(file->Get("rootTree"));
  if (tree == nullptr) {
    throw std::runtime_error(std::format("Missing TTree rootTree in ROOT file: {}", rootFileName));
  }

  for (int layer : m_param.GetRelevantLayers()) {
    // The ">>h" redirection in TTree::Draw fills the histogram named "h" that
    // is registered in the current directory. Owning it with unique_ptr
    // deletes it (and deregisters it) after each layer. This avoids leaking
    // one histogram per layer and ROOT's "Replacing existing TH1" warnings.
    auto h = std::make_unique<TH1D>("h", "h", 100, 0.01, 1);
    const std::string command = std::format("extrapWeight_{}>>h", layer);
    tree->Draw(command.c_str());
    m_extrapolatorWeights[layer] = h->GetMean();
    ATH_MSG_DEBUG("Extrapolation: layer " << layer << " mean "
                                          << m_extrapolatorWeights[layer]);
  }
}
185
188 const TFCSExtrapolationState *extrapol,
189 TFCSSimulationState &simulstate) const {
// Builds the network inputs for one particle and evaluates the selected network:
//  - "Noise": GetLatentSpaceSize() random values from CLHEP::RandGauss;
//  - "mycond"/"variable_0": log(Ekin/Ekin_min) / log(Ekin_max/Ekin_min);
//  - "mycond"/"variable_1": always 0 for GAN version >= 2.
// The network is GetNetAll() for GAN version 1 or pid 211/2212. Otherwise it
// is GetNetHigh() or GetNetLow(), depending on truth->P() > 4096.
// NOTE(review): despite its name, randUniformZ holds Gaussian draws
// (RandGauss::shoot with mean 0.5 and sigma 0.5), not uniform ones.
190 double randUniformZ = 0.;
191 NetworkInputs inputs;
192
// Momentum range [2^minExp, 2^maxExp] used to normalise the energy label.
// NOTE(review): for a pid not matched below, both exponents stay 0, so
// Ekin_min == Ekin_max and the label divides by log(1) == 0 (inf/NaN);
// confirm such pids can never reach this function.
193 int maxExp = 0, minExp = 0;
194 if (m_pid == 22 || std::abs(m_pid) == 11) {
195 if (truth->P() >
196 4096) { // This is the momentum, not the energy, because the split is
197 // based on the samples which are produced with the momentum
198 maxExp = 22;
199 minExp = 12;
200 } else {
201 maxExp = 12;
202 minExp = 6;
203 }
204 } else if (std::abs(m_pid) == 211) {
205 maxExp = 22;
206 minExp = 8;
207 } else if (std::abs(m_pid) == 2212) {
208 maxExp = 22;
209 minExp = 10;
210 }
211
212 int p_min = std::pow(2, minExp);
213 int p_max = std::pow(2, maxExp);
214 // Keep min and max without mass offset as we do not train on antiparticles
215 double Ekin_min =
216 std::sqrt(std::pow(p_min, 2) + std::pow(truth->M(), 2)) - truth->M();
217 double Ekin_max =
218 std::sqrt(std::pow(p_max, 2) + std::pow(truth->M(), 2)) - truth->M();
// Latent-space inputs: "variable_0" ... "variable_<size-1>".
219 for (int i = 0; i < m_param.GetLatentSpaceSize(); i++) {
220 randUniformZ = CLHEP::RandGauss::shoot(simulstate.randomEngine(), 0.5, 0.5);
221 inputs["Noise"].insert(std::pair<std::string, double>(
222 "variable_" + std::to_string(i), randUniformZ));
223 }
224
225 // double e = log(truth->Ekin()/Ekin_min)/log(Ekin_max/Ekin_min) ;
226 // Could be uncommented , but would need the line above too
227 // ATH_MSG_DEBUG( "Check label: " << e <<" Ekin:" << truth->Ekin() <<" p:" <<
228 // truth->P() <<" mass:" << truth->M() <<" Ekin_off:" <<
229 // truth->Ekin_off() << " Ekin_min:"<<Ekin_min<<"
230 // Ekin_max:"<<Ekin_max);
231 // inputs["mycond"].insert ( std::pair<std::string,double>("variable_0",
232 // truth->Ekin()/(std::pow(2,maxExp))) ); //Old conditioning using linear
233 // interpolation, now use logarithmic interpolation
234 inputs["mycond"].insert(std::pair<std::string, double>(
235 "variable_0", log(truth->Ekin() / Ekin_min) / log(Ekin_max / Ekin_min)));
236
237 if (m_param.GetGANVersion() >= 2) {
238 if (false) { // conditioning on eta, should only be needed in transition
239 // regions and added only to the GANs that use it, for now all
240 // GANs have 3 conditioning inputs so filling zeros
241 inputs["mycond"].insert(std::pair<std::string, double>(
242 "variable_1", std::abs(extrapol->IDCaloBoundary_eta())));
243 } else {
244 inputs["mycond"].insert(std::pair<std::string, double>("variable_1", 0));
245 }
246 }
247
// NOTE(review): m_pid is compared without std::abs here, unlike the
// energy-range selection above — confirm negative pids are never used.
249 if (m_param.GetGANVersion() == 1 || m_pid == 211 || m_pid == 2212) {
250 outputs = GetNetAll()->compute(inputs);
251 } else {
252 if (truth->P() >
253 4096) { // This is the momentum, not the energy, because the split is
254 // based on the samples which are produced with the momentum
255 ATH_MSG_DEBUG("Computing outputs given inputs for high");
256 outputs = GetNetHigh()->compute(inputs);
257 } else {
258 outputs = GetNetLow()->compute(inputs);
259 }
260 }
261 ATH_MSG_DEBUG("Start Network inputs ~~~~~~~~");
263 ATH_MSG_DEBUG("End Network inputs ~~~~~~~~");
264 ATH_MSG_DEBUG("Start Network outputs ~~~~~~~~");
266 ATH_MSG_DEBUG("End Network outputs ~~~~~~~~");
267 return outputs;
268}
269
// Logs this slice's pid and eta bounds at INFO level, then delegates to
// m_param.Print() for the XML parameters.
271 ATH_MSG_INFO("LWTNN Handler parameters");
272 ATH_MSG_INFO(" pid: " << m_pid);
273 ATH_MSG_INFO(" etaMin:" << m_etaMin);
274 ATH_MSG_INFO(" etaMax: " << m_etaMax);
275 m_param.Print();
276}
#define ATH_MSG_INFO(x)
#define ATH_MSG_DEBUG(x)
std::pair< std::vector< unsigned int >, bool > res
#define x
Header file for AthHistogramAlgorithm.
VNetworkBase * GetNetAll() const
TFCSGANLWTNNHandler * m_gan_all
TFCSGANLWTNNHandler * m_gan_low
NetworkOutputs GetNetworkOutputs(const TFCSTruthState *truth, const TFCSExtrapolationState *extrapol, TFCSSimulationState &simulstate) const
TFCSGANLWTNNHandler * m_gan_high
std::unique_ptr< VNetworkBase > m_net_high
std::unique_ptr< VNetworkBase > m_net_all
void CalculateMeanPointFromDistributionOfR()
virtual ~TFCSGANEtaSlice()
std::unique_ptr< VNetworkBase > m_net_low
ExtrapolatorWeights m_extrapolatorWeights
bool IsGanCorrectlyLoaded() const
VNetworkBase * GetNetLow() const
std::map< std::string, std::map< std::string, double > > NetworkInputs
VNetworkBase * GetNetHigh() const
void ExtractExtrapolatorMeansFromInputs()
FitResultsPerLayer m_allFitResults
TFCSGANXMLParameters m_param
std::map< int, TH2D > Binning
static std::unique_ptr< VNetworkBase > create(std::string input)
Given a string, make a network.
CLHEP::HepRandomEngine * randomEngine()
double Ekin() const
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration.
virtual NetworkOutputs compute(NetworkInputs const &inputs) const =0
Function to pass values to the network.
static std::string representNetworkOutputs(NetworkOutputs const &outputs, int maxValues=3)
String representation of network outputs.
std::map< std::string, double > NetworkOutputs
Format for network outputs.
static std::string representNetworkInputs(NetworkInputs const &inputs, int maxValues=3)
String representation of network inputs.
TChain * tree
TFile * file