ATLAS Offline Software
JetJvtNN_testCfg.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
#include <algorithm>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <map>
#include <regex>
#include <string>
#include <vector>

#include "nlohmann/json.hpp"
#include "lwtnn/generic/FastGraph.hh"
#include "lwtnn/parse_json.hh"
#include "lwtnn/Stack.hh"
#include "PathResolver/PathResolver.h"
15 
16 int main() {
17 
18  // Use the Path Resolver to find the jvt file and retrieve the likelihood histogram
19  std::string NNConfigDir = "JetPileupTag/NNJvt/2022-03-22/";
20  std::string NNParamFileName = "NNJVT.Network.graph.Offline.Nonprompt_All_MaxWeight.json";
21  std::string NNCutFileName = "NNJVT.Cuts.FixedEffPt.Offline.Nonprompt_All_MaxW.json";
22 
23  std::string configPath = PathResolverFindCalibFile(NNConfigDir+"/"+NNParamFileName);
24  std::cout << "Reading JVT NN file from:\n " << NNConfigDir << "/" << NNParamFileName << std::endl;
25  std::cout << " resolved in :\n " << configPath << std::endl;
26 
27  std::ifstream fconfig( configPath.c_str() );
28 
29  if ( !fconfig.is_open() ) {
30  std::cerr << "Error opening config file: " << NNConfigDir << "/" << NNParamFileName << std::endl;
31  std::cerr << "Are you sure that the file exists at this path?" << std::endl;
32  return 1;
33  }
34 
35  std::string cutsPath = PathResolverFindCalibFile(NNConfigDir+"/"+NNCutFileName);
36  std::cout << " Reading JVT NN cut file from:\n " << NNConfigDir << "/" << NNCutFileName << std::endl;
37  std::cout << " resolved in :\n " << cutsPath << std::endl;
38  std::ifstream fcuts( cutsPath.c_str() );
39  if ( !fconfig.is_open() ) {
40  std::cerr << "Error opening cuts file: " << NNConfigDir << "/" << NNCutFileName << std::endl;
41  std::cerr << "Are you sure that the file exists at this path?" << std::endl;
42  return 2;
43  }
44 
45  nlohmann::json cut_j;
46  fcuts >> cut_j;
47  std::vector<float> ptbin_edges = cut_j["ptbin_edges"].get<std::vector<float> >();
48  std::vector<float> etabin_edges = cut_j["etabin_edges"].get<std::vector<float> >();
49  std::map<std::string,float> cut_map_raw = cut_j["cuts"].get<std::map<std::string,float> >();
50  // Initialise 2D vector with cuts per bin
51  // Edge vectors have size Nbins+1
52  std::vector<std::vector<float> > cut_map(ptbin_edges.size()-1);
53  for(std::vector<float>& cuts_vs_eta : cut_map) {
54  cuts_vs_eta.resize(etabin_edges.size()-1,0.);
55  }
56  std::regex binre("\\((\\d+),\\s*(\\d+)\\)");
57  for(const std::pair<const std::string,float>& bins_to_cut_str : cut_map_raw) {
58  std::cout << bins_to_cut_str.first << " --> " << bins_to_cut_str.second << std::endl;
59  std::smatch sm;
60  if(std::regex_match(bins_to_cut_str.first,sm,binre) && sm.size()==3) {
61  // First entry is full match, followed by sub-matches
62  size_t ptbin = std::stoi(sm[1]);
63  size_t etabin = std::stoi(sm[2]);
64  cut_map[ptbin][etabin] = bins_to_cut_str.second;
65  } else {
66  std::cerr << "Regex match of pt/eta bins failed! Received string " << bins_to_cut_str.first << std::endl;
67  std::cerr << "Match size " << sm.size() << std::endl;
68  return 3;
69  }
70  }
71 
72  lwt::GraphConfig cfg = lwt::parse_json_graph( fconfig );
73  lwt::InputOrder order;
74  lwt::order_t node_order;
75  std::vector<std::string> inputs = {"Rpt","JVFCorr","ptbin","etabin"};
76  // Single input block
77  node_order.emplace_back(cfg.inputs[0].name,inputs);
78  order.scalar = std::move(node_order);
79 
80  std::cout << "Reading JVT likelihood histogram from: " << configPath << std::endl;
81  std::cout << "Network NLayers: " << cfg.layers.size() << std::endl;
82  lwt::generic::FastGraph<double> lwnn(cfg, order);
83 
84  std::cout << "Computation for test values:" << std::endl;
85  // A few jet test cases
86  for(size_t ptbin=0; ptbin<5; ++ptbin) {
87  for(size_t etabin=0; etabin<5; ++etabin) {
88  std::cout << " pt bin[" << ptbin_edges[ptbin] << "," << ptbin_edges[ptbin+1] << "]: " << ptbin << ", eta bin [" << etabin_edges[etabin] << "," << etabin_edges[etabin+1] << "]: " << etabin <<std::endl;
89  lwt::VectorX<double> inputvals_HS = lwt::build_vector({1.0,1.0,static_cast<double>(ptbin),static_cast<double>(etabin)});
90  std::vector<lwt::VectorX<double> > scalars_HS{std::move(inputvals_HS)};
91  lwt::VectorX<double> output_HS = lwnn.compute(scalars_HS);
92  std::cout << " HS jet --> " << output_HS(0) << std::endl;
93 
94  lwt::VectorX<double> inputvals_AMB = lwt::build_vector({0.2,0.5,static_cast<double>(ptbin),static_cast<double>(etabin)});
95  std::vector<lwt::VectorX<double> > scalars_AMB{std::move(inputvals_AMB)};
96  lwt::VectorX<double> output_AMB = lwnn.compute(scalars_AMB);
97  std::cout << " Ambiguous jet --> " << output_AMB(0) << std::endl;
98 
99  lwt::VectorX<double> inputvals_PU = lwt::build_vector({0.0,-1.0,static_cast<double>(ptbin),static_cast<double>(etabin)});
100  std::vector<lwt::VectorX<double> > scalars_PU{std::move(inputvals_PU)};
101  lwt::VectorX<double> output_PU = lwnn.compute(scalars_PU);
102  std::cout << " PU jet --> " << output_PU(0) << std::endl;
103 
104  std::cout << " Cut for this bin: " << cut_map[ptbin][etabin] << std::endl;
105  }
106  }
107 
108  return 0;
109 }
json
nlohmann::json json
Definition: HistogramDef.cxx:9
FlavorTagInference::SaltModelGraphConfig::parse_json_graph
GraphConfig parse_json_graph(const nlohmann::json &metadata)
Definition: SaltModelGraphConfig.cxx:40
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
PrepareReferenceFile.regex
regex
Definition: PrepareReferenceFile.py:43
mc.order
order
Configure Herwig7.
Definition: mc.Herwig7_Dijet.py:12
lwt::atlas::order_t
std::vector< std::pair< std::string, std::vector< std::string > > > order_t
Definition: InputOrder.h:25
lwtDev::build_vector
VectorXd build_vector(const std::vector< double > &bias)
Definition: Stack.cxx:760
main
int main()
Definition: JetJvtNN_testCfg.cxx:16
PathResolver.h
WriteCaloSwCorrections.cfg
cfg
Definition: WriteCaloSwCorrections.py:23
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:321