34 #include "CaloDetDescr/CaloDetDescrElement.h"
43 #include "lwtnn/parse_json.hh"
47 #include "CLHEP/Random/RandGaussZiggurat.h"
48 #include "CLHEP/Random/RandFlat.h"
49 #include "CLHEP/Units/SystemOfUnits.h"
58 m_theContainer(nullptr),
59 m_rndGenSvc(
"AtRndmGenSvc",
name),
60 m_randomEngine(nullptr),
61 m_caloDetDescrManager(nullptr),
64 declareProperty(
"ParamsInputFilename" ,
m_paramsFilename=
"DNNCaloSim/DNNCaloSim_GAN_nn_v0.json",
" lwtnn json output describing the trained network");
65 declareProperty(
"ParamsInputArchitecture" ,
m_paramsInputArchitecture=
"GANv0",
"tag describing additional parameters necessary for the network");
82 ATH_MSG_INFO(m_screenOutputPrefix <<
"Initializing ...");
85 m_randomEngine = m_rndGenSvc->GetEngine( m_randomEngineName);
88 ATH_MSG_ERROR(
"Could not get random number engine from RandomNumberService. Abort.");
89 return StatusCode::FAILURE;
93 if (!m_caloDetDescrManager) {
98 if (!m_caloDetDescrManager) {
100 return StatusCode::FAILURE;
106 const CaloIdManager* caloId_mgr = m_caloDetDescrManager->getCalo_Mgr();
109 m_caloGeo = std::make_unique<CaloGeometryFromCaloDDM>();
110 m_caloGeo->LoadGeometryFromCaloDDM(m_caloDetDescrManager);
111 if(!m_caloGeo->LoadFCalChannelMapFromFCalDDM(fcalManager) )
ATH_MSG_FATAL(
"Found inconsistency between FCal_Channel map and GEO file. Please, check if they are configured properly.");
115 if (initializeNetwork().isFailure())
118 return StatusCode::FAILURE;
124 if(m_FastCaloSimCaloExtrapolation.retrieve().isFailure())
127 return StatusCode::FAILURE;
130 m_windowCells.reserve(m_numberOfCellsForDNN);
134 return StatusCode::SUCCESS;
146 ATH_MSG_ERROR(
"Could not find json file " << m_paramsFilename );
147 return StatusCode::FAILURE;
156 if (m_graph==
nullptr){
157 ATH_MSG_ERROR(
"Could not create LightWeightGraph from " << m_paramsFilename );
158 return StatusCode::FAILURE;
166 ATH_MSG_INFO(
"Using ParamsInputArchitecture: " << m_paramsInputArchitecture );
167 if (m_paramsInputArchitecture==
"GANv0")
169 m_GANLatentSize = 300;
170 m_logTrueEnergyMean = 9.70406053;
171 m_logTrueEnergyScale = 1.76099569;
172 m_riImpactEtaMean = 3.47603256e-05;
173 m_riImpactEtaScale = 0.00722316;
174 m_riImpactPhiMean = -5.42153684e-05;
175 m_riImpactPhiScale = 0.00708241;
178 if (m_GANLatentSize==0){
179 ATH_MSG_ERROR(
"m_GANLatentSize uninitialized!.ParamsInputArchitecture= " << m_paramsInputArchitecture );
180 return StatusCode::FAILURE;
183 return StatusCode::SUCCESS;
190 return StatusCode::SUCCESS;
195 const EventContext& ctx = Gaudi::Hive::currentContext();
196 ATH_MSG_INFO(m_screenOutputPrefix <<
"setupEvent NEW EVENT! ");
200 StatusCode sc = evtStore()->record(m_theContainer, m_caloCellsOutputName);
203 ATH_MSG_FATAL( m_screenOutputPrefix <<
"cannot record CaloCellContainer " << m_caloCellsOutputName );
204 return StatusCode::FAILURE;
208 CHECK( m_caloCellMakerToolsSetup.retrieve() );
209 ATH_MSG_DEBUG(
"Successfully retrieve CaloCellMakerTools: " << m_caloCellMakerToolsSetup );
212 for (; itrTool != endTool; ++itrTool)
214 std::string chronoName=this->
name()+
"_"+ itrTool->name();
215 if (m_chrono) m_chrono->chronoStart(chronoName);
216 StatusCode sc = (*itrTool)->process(m_theContainer, ctx);
218 m_chrono->chronoStop(chronoName);
224 ATH_MSG_ERROR( m_screenOutputPrefix <<
"Error executing tool " << itrTool->name() );
225 return StatusCode::FAILURE;
237 const int ntrial=100;
238 ATH_MSG_INFO (
"Trial window building on " << ntrial <<
" dummy eta phi " );
239 for (
int i=0 ;
i< ntrial ;
i++){
240 const double eta = CLHEP::RandFlat::shoot(testsimulstate.
randomEngine(), 0.2, 0.25);
244 if (fillWindowCells(eta,phi,testImpactCellDDE).isFailure()){
245 ATH_MSG_WARNING(
"Could not build trial window cells vector with eta " << eta <<
" phi " << phi);
248 ATH_MSG_INFO (
"End of trial window building on " << ntrial <<
" dummy eta phi " );
255 return StatusCode::SUCCESS;
260 const EventContext& ctx = Gaudi::Hive::currentContext();
263 CHECK( m_caloCellMakerToolsRelease.retrieve() );
268 for (; itrTool != endTool; ++itrTool)
270 ATH_MSG_VERBOSE( m_screenOutputPrefix <<
"Calling tool " << itrTool->name() );
272 StatusCode sc = (*itrTool)->process(m_theContainer, ctx);
276 ATH_MSG_ERROR( m_screenOutputPrefix <<
"Error executing tool " << itrTool->name() );
280 return StatusCode::SUCCESS;
288 float aEtaRaw =
a->caloDDE()->eta_raw();
289 float bEtaRaw =
b->caloDDE()->eta_raw();
291 float aPhiRaw =
a->caloDDE()->phi_raw();
292 float bPhiRaw =
b->caloDDE()->phi_raw();
294 if ((aSampling) < (bSampling))
296 else if ((aSampling) > (bSampling))
299 if ((aEtaRaw) < (bEtaRaw))
301 else if ((aEtaRaw) > (bEtaRaw))
312 float aEtaRaw =
a->caloDDE()->eta_raw();
313 float bEtaRaw =
b->caloDDE()->eta_raw();
315 float aPhiRaw =
a->caloDDE()->phi_raw();
316 float bPhiRaw =
b->caloDDE()->phi_raw();
318 if ((aSampling) < (bSampling))
320 else if ((aSampling) > (bSampling))
323 if ((aEtaRaw) < (bEtaRaw))
325 else if ((aEtaRaw) > (bEtaRaw))
335 ATH_MSG_VERBOSE(
"NEW PARTICLE! DNNCaloSimSvc called with ISFParticle: " << isfp);
339 if(isfp.
ekin() < 10) {
340 ATH_MSG_VERBOSE(
"Skipping particle with Ekin: " << isfp.
ekin() <<
" MeV. Below the 10 MeV threshold.");
341 return StatusCode::SUCCESS;
348 if (fillNetworkInputs(isfp,
inputs,trueEnergy).isFailure()) {
351 return StatusCode::SUCCESS;
363 for (
auto & windowCell : m_windowCells) {
368 <<
" phi_raw " << windowCell->caloDDE()->phi_raw()
369 <<
" sampling " << windowCell->caloDDE()->getSampling()
370 <<
" energy " << windowCell->energy());
377 return StatusCode::SUCCESS;
392 trueEnergy = isfp.
ekin();
398 m_FastCaloSimCaloExtrapolation->extrapolate(
extrapol,&truth);
407 for (
int isubpos=0; isubpos< 3 ; isubpos++){
410 " isubpos=" << isubpos <<
411 " OK=" <<
extrapol.OK(isam,isubpos) <<
412 " eta=" <<
extrapol.eta(isam,isubpos) <<
413 " phi=" <<
extrapol.phi(isam,isubpos) <<
414 " r=" <<
extrapol.r(isam,isubpos) );
423 double etaExtrap=-999.;
424 double phiExtrap=-999.;
426 etaExtrap=
extrapol.eta(isam,isubpos);
427 phiExtrap=
extrapol.phi(isam,isubpos);
431 " isubpos=" << isubpos <<
432 " eta=" << etaExtrap <<
433 " phi=" << phiExtrap );
439 if (fillWindowCells(etaExtrap,phiExtrap,impactCellDDE).isFailure()){
441 return StatusCode::FAILURE;
450 double randGaussz = 0.;
453 int impactEtaIndex = m_emID->eta(impactCellDDE->
identify());
454 int impactPhiIndex = m_emID->phi(impactCellDDE->
identify());
455 double etaRawImpactCell=impactCellDDE->
eta_raw();
456 double phiRawImpactCell=impactCellDDE->
phi_raw();
459 <<
" phi_index " << impactPhiIndex
460 <<
" sampling " << m_emID->sampling(impactCellDDE->
identify()));
462 int pconf = impactPhiIndex % 4 ;
463 int econf = (impactEtaIndex + 1) % 2 ;
465 double riImpactEta = (etaExtrap - etaRawImpactCell);
472 riImpactEta = (riImpactEta - m_riImpactEtaMean)/m_riImpactEtaScale;
473 double riImpactPhi = (
CaloPhiRange::diff(phiExtrap, phiRawImpactCell) - m_riImpactPhiMean)/m_riImpactPhiScale;
480 for (
int i = 0;
i< m_GANLatentSize;
i ++)
482 randGaussz = CLHEP::RandGaussZiggurat::shoot(simulstate.
randomEngine(), 0., 1.);
488 inputs[
"E_true"].insert ( std::pair<std::string,double>(
"0", (
std::log(trueEnergy) - m_logTrueEnergyMean)/m_logTrueEnergyScale) );
490 for (
int i = 0;
i< 4;
i ++)
499 for (
int i = 0;
i< 2;
i ++){
508 inputs[
"ripos"].insert ( std::pair<std::string,double>(
"0", riImpactEta) );
509 inputs[
"ripos"].insert ( std::pair<std::string,double>(
"1", riImpactPhi ) );
511 return StatusCode::SUCCESS;
520 impactCellDDE=m_caloDetDescrManager->get_element(
CaloCell_ID::EMB2,etaExtrap,phiExtrap);
521 if (impactCellDDE==
nullptr){
522 ATH_MSG_WARNING(
"No cell found for this eta " << etaExtrap <<
" phi " << phiExtrap);
523 return StatusCode::FAILURE;
529 const int caloHashImpactCell=impactCellDDE->
calo_hash();
530 const double etaImpactCell=impactCellDDE->
eta();
531 const double phiImpactCell=impactCellDDE->
phi();
532 const double etaRawImpactCell=impactCellDDE->
eta_raw();
533 const double phiRawImpactCell=impactCellDDE->
phi_raw();
537 " eta=" << etaImpactCell <<
538 " phi=" << phiImpactCell <<
539 " eta raw=" << etaRawImpactCell <<
540 " phi raw=" << phiRawImpactCell );
550 m_windowCells.clear();
554 for(
CaloCell* theCell : * m_theContainer) {
555 sampling = theCell->caloDDE()->getSampling();
556 eta_raw = theCell->caloDDE()->eta_raw();
557 phi_raw = theCell->caloDDE()->phi_raw();
558 if (( eta_raw < etaRawImpactCell + m_etaRawBackCut) && (eta_raw > etaRawImpactCell - m_etaRawBackCut)) {
570 if ((sampling == 0) || (sampling == 1) ){
571 if ((eta_raw < etaRawImpactCell + m_etaRawMiddleCut) && (eta_raw > etaRawImpactCell - m_etaRawMiddleCut)) {
574 m_windowCells.push_back(theCell);
578 else if(sampling == 2) {
580 if ((eta_raw < etaRawImpactCell + m_etaRawMiddleCut) && (eta_raw > etaRawImpactCell - m_etaRawMiddleCut)) {
582 m_windowCells.push_back(theCell);
587 else if(sampling == 3){
590 m_windowCells.push_back(theCell);
596 if (nSqCuts != m_numberOfCellsForDNN){
597 ATH_MSG_WARNING(
"Total cells passing DNN selection is " << nSqCuts <<
" but should be " << m_numberOfCellsForDNN );
599 return StatusCode::FAILURE;
604 if (etaRawImpactCell < 0){
611 return StatusCode::SUCCESS;