ATLAS Offline Software
AnomDetVAE.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 /*********************************
5  * AnomDetVAE.cxx
6  * Created by Sagar Addepalli on 25/11/2024.
7  *
8  * @brief algorithm uses a variational auto-encoder based anomaly detection network
9  * Uses 6 jets, 4 muons, 4 taus, and MET to calculate the anomaly score
10  * based on the KL-divergence in a lower dimension latent space
11  *
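12  * @param ScaleSqr1 per-bit scale applied to the square of the first network output
13  * @param ScaleSqr2 per-bit scale applied to the square of the second network output
14  * @param ScaleSqr3 per-bit scale applied to the square of the third network output
15  * @param AnomalyScoreThresh per-bit threshold; a bit fires when the anomaly score exceeds it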
16 **********************************/
17 
18 #include <cmath>
19 
21 #include "L1TopoCommon/Exception.h"
24 #include <VAENetwork.h>
25 
26 REGISTER_ALG_TCS(ADVAE_2A)
27 
28 
30 {
31  defineParameter("InputWidth1", 6);
32  defineParameter("InputWidth2", 6);
33  defineParameter("InputWidth3", 6);
34  defineParameter("InputWidth4", 1);
35  defineParameter("MaxTob1", 6);
36  defineParameter("MaxTob2", 4);
37  defineParameter("MaxTob3", 4);
38  defineParameter("MaxTob4", 1);
39  defineParameter("NumResultBits", 2);
40 
41  // Version parameter, used for L1TopoFW book-keeping, no practical application
42  defineParameter("ADVAEVersion", 1);
43  // minEt cuts, one per input list (TOBs failing these are set to ET = eta = phi = 0)
44  defineParameter("MinET1",0);
45  defineParameter("MinET2",0);
46  defineParameter("MinET3",0);
47  defineParameter("MinET4",0);
48  // Configurable scale for the mu parameters in the latent space and AD score threshold
49  // The value used is sum of squares of the NN result vector elements, bit-shifted by 8 to get all decimal bits
50  defineParameter("ScaleSqr1",128,0);
51  defineParameter("ScaleSqr2",128,0);
52  defineParameter("ScaleSqr3",128,0);
53  defineParameter("AnomalyScoreThresh", 1000000, 0);
54  defineParameter("ScaleSqr1",128,1);
55  defineParameter("ScaleSqr2",128,1);
56  defineParameter("ScaleSqr3",128,1);
57  defineParameter("AnomalyScoreThresh", 1000000, 1);
58 
59  setNumberOutputBits(2);
60 }
61 
63 
64 
67  p_NumberLeading1 = parameter("InputWidth1").value();
68  p_NumberLeading2 = parameter("InputWidth2").value();
69  p_NumberLeading3 = parameter("InputWidth3").value();
70  p_NumberLeading4 = parameter("InputWidth4").value();
71 
72  if(parameter("MaxTob1").value() > 0) p_NumberLeading1 = parameter("MaxTob1").value();
73  if(parameter("MaxTob2").value() > 0) p_NumberLeading2 = parameter("MaxTob2").value();
74  if(parameter("MaxTob3").value() > 0) p_NumberLeading3 = parameter("MaxTob3").value();
75  if(parameter("MaxTob4").value() > 0) p_NumberLeading4 = parameter("MaxTob4").value();
76 
77  p_minEt1 = parameter("MinET1").value();
78  p_minEt2 = parameter("MinET2").value();
79  p_minEt3 = parameter("MinET3").value();
80  p_minEt4 = parameter("MinET4").value();
81 
82  for(unsigned int i=0; i<numberOutputBits(); ++i) {
83  p_ScaleSqr1[i] = parameter("ScaleSqr1", i).value();
84  p_ScaleSqr2[i] = parameter("ScaleSqr2", i).value();
85  p_ScaleSqr3[i] = parameter("ScaleSqr3", i).value();
86  p_AnomalyScoreThresh[i] = parameter("AnomalyScoreThresh", i).value();
87  }
88 
89  TRG_MSG_INFO("number output : " << numberOutputBits());
90 
91  // book histograms
92  for(unsigned int i=0; i<numberOutputBits(); ++i) {
93  std::string hname_accept = "hAnomalyScore_accept_bit"+std::to_string((int)i);
94  std::string hname_reject = "hAnomalyScore_reject_bit"+std::to_string((int)i);
95  // score
96  bookHist(m_histAccept, hname_accept, "ADScore", 2000, 0, 2000000);
97  bookHist(m_histReject, hname_reject, "ADScore", 2000, 0, 2000000);
98  }
99 
100  return StatusCode::SUCCESS;
101 }
102 
103 
105 TCS::ADVAE_2A::processBitCorrect( const std::vector<TCS::TOBArray const *> & input,
106  const std::vector<TCS::TOBArray *> & output,
107  Decision & decision )
108 {
109 
110 
111  if( input.size() == 4) {
112 
113  TCS::TOBArray const* jets = input[0];
114  TCS::TOBArray const* taus = input[1];
115  TCS::TOBArray const* mus = input[2];
116  TCS::TOBArray const* met = input[3];
117  TRG_MSG_DEBUG("Number of jets are " << (*jets).size());
118  TRG_MSG_DEBUG("Number of taus are " << (*taus).size());
119  TRG_MSG_DEBUG("Number of mus are " << (*mus).size());
120  TRG_MSG_DEBUG("Number of met are " << (*met).size());
121 
122  //check for ambiguous sorting and set corresponding flag if an ambiguity is found
123  bool hasAmbiguousInputs = TSU::isAmbiguousTruncation(jets, p_NumberLeading1, p_minEt1)
124  || TSU::isAmbiguousTruncation(taus, p_NumberLeading2, p_minEt2)
125  || TSU::isAmbiguousTruncation(mus, p_NumberLeading3, p_minEt3)
126  || TSU::isAmbiguousTruncation(met, p_NumberLeading4, p_minEt4)
127  || TSU::isAmbiguousAnywhere(jets, p_NumberLeading1, p_minEt1)
128  || TSU::isAmbiguousAnywhere(taus, p_NumberLeading2, p_minEt2)
129  || TSU::isAmbiguousAnywhere(mus, p_NumberLeading3, p_minEt3)
130  || TSU::isAmbiguousAnywhere(met, p_NumberLeading4, p_minEt4);
131 
132  std::vector<u_int> jet_pt(6,0), tau_pt(4,0), mu_pt(4,0), met_pt(1,0);
133  std::vector<int> jet_eta(6,0), tau_eta(4,0), mu_eta(4,0); //no met_eta
134  std::vector<int> jet_phi(6,0), tau_phi(4,0), mu_phi(4,0), met_phi(1,0);
135 
136  for (u_int i = 0; i<(*jets).size() && i<6; ++i) {
137  if ( parType_t( (*jets)[i].Et() ) <= p_minEt1 ) continue; //ET cut, leave NN inputs at default values (0)
138  jet_pt[i] = (*jets)[i].Et();
139  jet_eta[i] = (*jets)[i].eta();
140  jet_phi[i] = (*jets)[i].phi();
141  }
142  for (u_int i = 0; i < (*taus).size() && i<4; ++i) {
143  if ( parType_t( (*taus)[i].Et() ) <= p_minEt2 ) continue; //ET cut, leave NN inputs at default values (0)
144  tau_pt[i] = (*taus)[i].Et();
145  tau_eta[i] = (*taus)[i].eta();
146  tau_phi[i] = (*taus)[i].phi();
147  }
148  for (u_int i = 0; i < (*mus).size() && i<4; ++i) {
149  if ( parType_t( (*mus)[i].Et() ) <= p_minEt3 ) continue; //ET cut, leave NN inputs at default values (0)
150  mu_pt[i] = (*mus)[i].Et();
151  mu_eta[i] = (*mus)[i].eta();
152  mu_phi[i] = (*mus)[i].phi();
153  }
154  for (u_int i = 0; i < (*met).size() && i<1; ++i) {
155  if ( parType_t( (*met)[i].Et() ) <= p_minEt4 ) continue; //ET cut, leave NN inputs at default values (0)
156  met_pt[i] = (*met)[i].Et();
157  met_phi[i] = (*met)[i].phi();
158  }
159 
160 
161  TRG_MSG_DEBUG("Jet0: " << jet_pt[0] << ", " << jet_eta[0] << ", " << jet_phi[0] );
162  TRG_MSG_DEBUG("Jet1: " << jet_pt[1] << ", " << jet_eta[1] << ", " << jet_phi[1] );
163  TRG_MSG_DEBUG("Jet2: " << jet_pt[2] << ", " << jet_eta[2] << ", " << jet_phi[2] );
164  TRG_MSG_DEBUG("Jet3: " << jet_pt[3] << ", " << jet_eta[3] << ", " << jet_phi[3] );
165  TRG_MSG_DEBUG("Jet4: " << jet_pt[4] << ", " << jet_eta[4] << ", " << jet_phi[4] );
166  TRG_MSG_DEBUG("Jet5: " << jet_pt[5] << ", " << jet_eta[5] << ", " << jet_phi[5] );
167 
168  TRG_MSG_DEBUG("Tau0: " << tau_pt[0] << ", " << tau_eta[0] << ", " << tau_phi[0] );
169  TRG_MSG_DEBUG("Tau1: " << tau_pt[1] << ", " << tau_eta[1] << ", " << tau_phi[1] );
170  TRG_MSG_DEBUG("Tau2: " << tau_pt[2] << ", " << tau_eta[2] << ", " << tau_phi[2] );
171  TRG_MSG_DEBUG("Tau3: " << tau_pt[3] << ", " << tau_eta[3] << ", " << tau_phi[3] );
172 
173  TRG_MSG_DEBUG("Mu0: " << mu_pt[0] << ", " << mu_eta[0] << ", " << mu_phi[0] );
174  TRG_MSG_DEBUG("Mu1: " << mu_pt[1] << ", " << mu_eta[1] << ", " << mu_phi[1] );
175  TRG_MSG_DEBUG("Mu2: " << mu_pt[2] << ", " << mu_eta[2] << ", " << mu_phi[2] );
176  TRG_MSG_DEBUG("Mu3: " << mu_pt[3] << ", " << mu_eta[3] << ", " << mu_phi[3] );
177 
178  TRG_MSG_DEBUG("MET: " << met_pt[0] << ", " << met_phi[0] << std::endl);
179 
180  ADVAE2A::VAENetwork AD_Network( jet_pt[0], jet_eta[0], jet_phi[0],
181  jet_pt[1], jet_eta[1], jet_phi[1],
182  jet_pt[2], jet_eta[2], jet_phi[2],
183  jet_pt[3], jet_eta[3], jet_phi[3],
184  jet_pt[4], jet_eta[4], jet_phi[4],
185  jet_pt[5], jet_eta[5], jet_phi[5],
186  tau_pt[0], tau_eta[0], tau_phi[0],
187  tau_pt[1], tau_eta[1], tau_phi[1],
188  tau_pt[2], tau_eta[2], tau_phi[2],
189  tau_pt[3], tau_eta[3], tau_phi[3],
190  mu_pt [0], mu_eta [0], mu_phi [0],
191  mu_pt [1], mu_eta [1], mu_phi [1],
192  mu_pt [2], mu_eta [2], mu_phi [2],
193  mu_pt [3], mu_eta [3], mu_phi [3],
194  met_pt[0], met_phi[0] );
195  std::vector<int64_t> anomScoreInt64Vec = AD_Network.getAnomalyScoreInt64Vec();
196 
197  for(u_int i=0; i<numberOutputBits(); ++i) {
198  bool accept = false;
199  // Retrieve threshold
200  int32_t threshold = int32_t ( p_AnomalyScoreThresh[i] );
201  // Calculate event score
202  int64_t anomScoreInt64 = 0;
203  anomScoreInt64 = (p_ScaleSqr1[i] * anomScoreInt64Vec.at(0)*anomScoreInt64Vec.at(0) >> p_ScaleSqr_DropBits) +
204  (p_ScaleSqr2[i] * anomScoreInt64Vec.at(1)*anomScoreInt64Vec.at(1) >> p_ScaleSqr_DropBits) +
205  (p_ScaleSqr3[i] * anomScoreInt64Vec.at(2)*anomScoreInt64Vec.at(2) >> p_ScaleSqr_DropBits);
206  if ( anomScoreInt64 > threshold ) {
207  accept = true;
208  decision.setBit(i, true);
209  for ( u_int j = 0; j<6 && j<(*jets).size(); ++j ) output[i]->push_back((*jets)[j]);
210  for ( u_int j = 0; j<4 && j<(*taus).size(); ++j ) output[i]->push_back((*taus)[j]);
211  for ( u_int j = 0; j<4 && j<(*mus).size() ; ++j ) output[i]->push_back((*mus) [j]);
212  output[i]->push_back((*met)[0]);
213  }
214  output[i]->setAmbiguityFlag(hasAmbiguousInputs);
215 
216  if(fillHistos() and accept) {
217  fillHist1D(m_histAccept[i],anomScoreInt64);
218  } else if(fillHistos() && !accept) {
219  fillHist1D(m_histReject[i],anomScoreInt64);
220  }
221 
222  TRG_MSG_DEBUG("Decision for bit" << i << ": " << (accept?"pass":"fail") << " anomaly score = " << anomScoreInt64 << std::endl);
223  }
224  } else {
225  TCS_EXCEPTION("ADVAE_2A alg must have 4 inputs, but got " << input.size());
226  }
227 
229 }
230 
232 TCS::ADVAE_2A::process( const std::vector<TCS::TOBArray const *> & input,
233  const std::vector<TCS::TOBArray *> & output,
234  Decision & decision )
235 {
236 
237 
238  if( input.size() == 4) {
239 
240  TCS::TOBArray const* jets = input[0];
241  TCS::TOBArray const* taus = input[1];
242  TCS::TOBArray const* mus = input[2];
243  TCS::TOBArray const* met = input[3];
244  TRG_MSG_DEBUG("Number of jets are " << (*jets).size());
245  TRG_MSG_DEBUG("Number of taus are " << (*taus).size());
246  TRG_MSG_DEBUG("Number of mus are " << (*mus).size());
247  TRG_MSG_DEBUG("Number of met are " << (*met).size());
248 
249  std::vector<u_int> jet_pt(6,0), tau_pt(4,0), mu_pt(4,0), met_pt(1,0);
250  std::vector<int> jet_eta(6,0), tau_eta(4,0), mu_eta(4,0); //no met_eta
251  std::vector<int> jet_phi(6,0), tau_phi(4,0), mu_phi(4,0), met_phi(1,0);
252 
253  for (u_int i = 0; i<(*jets).size() && i<6; ++i) {
254  if ( parType_t( (*jets)[i].Et() ) <= p_minEt1 ) continue; //ET cut, leave NN inputs at default values (0)
255  jet_pt[i] = (*jets)[i].Et();
256  jet_eta[i] = (*jets)[i].eta();
257  jet_phi[i] = (*jets)[i].phi();
258  }
259  for (u_int i = 0; i < (*taus).size() && i<4; ++i) {
260  if ( parType_t( (*taus)[i].Et() ) <= p_minEt2 ) continue; //ET cut, leave NN inputs at default values (0)
261  tau_pt[i] = (*taus)[i].Et();
262  tau_eta[i] = (*taus)[i].eta();
263  tau_phi[i] = (*taus)[i].phi();
264  }
265  for (u_int i = 0; i < (*mus).size() && i<4; ++i) {
266  if ( parType_t( (*mus)[i].Et() ) <= p_minEt3 ) continue; //ET cut, leave NN inputs at default values (0)
267  mu_pt[i] = (*mus)[i].Et();
268  mu_eta[i] = (*mus)[i].eta();
269  mu_phi[i] = (*mus)[i].phi();
270  }
271  for (u_int i = 0; i < (*met).size() && i<1; ++i) {
272  if ( parType_t( (*met)[i].Et() ) <= p_minEt4 ) continue; //ET cut, leave NN inputs at default values (0)
273  met_pt[i] = (*met)[i].Et();
274  met_phi[i] = (*met)[i].phi();
275  }
276 
277  ADVAE2A::VAENetwork AD_Network( jet_pt[0], jet_eta[0], jet_phi[0],
278  jet_pt[1], jet_eta[1], jet_phi[1],
279  jet_pt[2], jet_eta[2], jet_phi[2],
280  jet_pt[3], jet_eta[3], jet_phi[3],
281  jet_pt[4], jet_eta[4], jet_phi[4],
282  jet_pt[5], jet_eta[5], jet_phi[5],
283  tau_pt[0], tau_eta[0], tau_phi[0],
284  tau_pt[1], tau_eta[1], tau_phi[1],
285  tau_pt[2], tau_eta[2], tau_phi[2],
286  tau_pt[3], tau_eta[3], tau_phi[3],
287  mu_pt [0], mu_eta [0], mu_phi [0],
288  mu_pt [1], mu_eta [1], mu_phi [1],
289  mu_pt [2], mu_eta [2], mu_phi [2],
290  mu_pt [3], mu_eta [3], mu_phi [3],
291  met_pt[0], met_phi[0] );
292  std::vector<int64_t> anomScoreInt64Vec = AD_Network.getAnomalyScoreInt64Vec();
293 
294  for(u_int i=0; i<numberOutputBits(); ++i) {
295  bool accept = false;
296  // Retrieve threshold
297  int32_t threshold = int32_t ( p_AnomalyScoreThresh[i] );
298  // Calculate event score
299  int64_t anomScoreInt64 = 0;
300  anomScoreInt64 = p_ScaleSqr1[i]/std::pow(2,p_ScaleSqr_DropBits) * anomScoreInt64Vec.at(0)*anomScoreInt64Vec.at(0) +
301  p_ScaleSqr2[i]/std::pow(2,p_ScaleSqr_DropBits) * anomScoreInt64Vec.at(1)*anomScoreInt64Vec.at(1) +
302  p_ScaleSqr3[i]/std::pow(2,p_ScaleSqr_DropBits) * anomScoreInt64Vec.at(2)*anomScoreInt64Vec.at(2);
303  if ( anomScoreInt64 > threshold ) {
304  accept = true;
305  decision.setBit(i, true);
306  for ( u_int j = 0; j<6 && j<(*jets).size(); ++j ) output[i]->push_back((*jets)[j]);
307  for ( u_int j = 0; j<4 && j<(*taus).size(); ++j ) output[i]->push_back((*taus)[j]);
308  for ( u_int j = 0; j<4 && j<(*mus).size() ; ++j ) output[i]->push_back((*mus) [j]);
309  output[i]->push_back((*met)[0]);
310  }
311 
312  if(fillHistos() and accept) {
313  fillHist1D(m_histAccept[i],anomScoreInt64);
314  } else if(fillHistos() && !accept) {
315  fillHist1D(m_histReject[i],anomScoreInt64);
316  }
317 
318  TRG_MSG_DEBUG("Decision for bit" << i << ": " << (accept?"pass":"fail") << " anomaly score = " << anomScoreInt64 << std::endl);
319  }
320  } else {
321  TCS_EXCEPTION("ADVAE_2A alg must have 4 inputs, but got " << input.size());
322  }
323 
325 }
TCS::ADVAE_2A::initialize
virtual StatusCode initialize()
Definition: AnomDetVAE.cxx:66
TCS::StatusCode::SUCCESS
@ SUCCESS
Definition: Trigger/TrigT1/L1Topo/L1TopoCommon/L1TopoCommon/StatusCode.h:17
TCS::parType_t
uint32_t parType_t
Definition: Parameter.h:22
CutsMETMaker::accept
StatusCode accept(const xAOD::Muon *mu)
Definition: CutsMETMaker.cxx:18
TSU::isAmbiguousTruncation
bool isAmbiguousTruncation(TCS::TOBArray const *tobs, size_t pos, unsigned minEt=0)
Definition: Trigger/TrigT1/L1Topo/L1TopoSimulationUtils/Root/Helpers.cxx:23
defineDB.jets
jets
Definition: JetTagCalibration/share/defineDB.py:24
athena.value
value
Definition: athena.py:124
TCS::ADVAE_2A
Definition: AnomDetVAE.h:16
Decision.h
const
bool const RAWDATA *ch2 const
Definition: LArRodBlockPhysicsV0.cxx:560
TCS::DecisionAlg
Definition: DecisionAlg.h:22
TCS::Decision::setBit
void setBit(unsigned int index, bool value)
Definition: L1Topo/L1TopoInterfaces/Root/Decision.cxx:12
met
Definition: IMETSignificance.h:24
lumiFormat.i
int i
Definition: lumiFormat.py:85
TCS_EXCEPTION
#define TCS_EXCEPTION(MSG)
Definition: Trigger/TrigT1/L1Topo/L1TopoCommon/L1TopoCommon/Exception.h:14
TCS::TOBArray
Definition: TOBArray.h:24
plotIsoValidation.mu_phi
mu_phi
Definition: plotIsoValidation.py:152
TCS::Decision
Definition: L1Topo/L1TopoInterfaces/L1TopoInterfaces/Decision.h:19
TRG_MSG_INFO
#define TRG_MSG_INFO(x)
Definition: Trigger/TrigConfiguration/TrigConfBase/TrigConfBase/MsgStreamMacros.h:27
TCS::ADVAE_2A::processBitCorrect
virtual StatusCode processBitCorrect(const std::vector< TCS::TOBArray const * > &input, const std::vector< TCS::TOBArray * > &output, Decision &decison)
Definition: AnomDetVAE.cxx:105
TCS::ADVAE_2A::process
virtual StatusCode process(const std::vector< TCS::TOBArray const * > &input, const std::vector< TCS::TOBArray * > &output, Decision &decison)
Definition: AnomDetVAE.cxx:232
REGISTER_ALG_TCS
#define REGISTER_ALG_TCS(CLASS)
Definition: AlgFactory.h:62
AnomDetVAE.h
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
threshold
Definition: chainparser.cxx:74
TCS
Definition: Global/GlobalSimulation/src/IO/Decision.h:18
TSU::isAmbiguousAnywhere
bool isAmbiguousAnywhere(TCS::TOBArray const *tobs, size_t pos, unsigned minEt=0)
Definition: Trigger/TrigT1/L1Topo/L1TopoSimulationUtils/Root/Helpers.cxx:37
plotIsoValidation.mu_eta
mu_eta
Definition: plotIsoValidation.py:151
Trk::jet_phi
@ jet_phi
Definition: JetVtxParamDefs.h:28
Helpers.h
pow
constexpr int pow(int base, int exp) noexcept
Definition: ap_fixedTest.cxx:15
Exception.h
TCS::ADVAE_2A::~ADVAE_2A
virtual ~ADVAE_2A()
Definition: AnomDetVAE.cxx:62
TRG_MSG_DEBUG
#define TRG_MSG_DEBUG(x)
Definition: Trigger/TrigConfiguration/TrigConfBase/TrigConfBase/MsgStreamMacros.h:25
TCS::StatusCode
Definition: Trigger/TrigT1/L1Topo/L1TopoCommon/L1TopoCommon/StatusCode.h:15
plotIsoValidation.mu_pt
mu_pt
Definition: plotIsoValidation.py:150