ATLAS Offline Software
JetJvtNN_testCfg.cxx
Go to the documentation of this file.
#include <iostream>
#include <fstream>
#include <map>
#include <string>
#include <vector>
#include <regex>

#include "nlohmann/json.hpp"
#include "lwtnn/generic/FastGraph.hh"
#include "lwtnn/parse_json.hh"
#include "lwtnn/Stack.hh"
#include "PathResolver/PathResolver.h"

13 int main() {
14 
15  // Use the Path Resolver to find the jvt file and retrieve the likelihood histogram
16  std::string NNConfigDir = "JetPileupTag/NNJvt/2022-03-22/";
17  std::string NNParamFileName = "NNJVT.Network.graph.Offline.Nonprompt_All_MaxWeight.json";
18  std::string NNCutFileName = "NNJVT.Cuts.FixedEffPt.Offline.Nonprompt_All_MaxW.json";
19 
20  std::string configPath = PathResolverFindCalibFile(NNConfigDir+"/"+NNParamFileName);
21  std::cout << "Reading JVT NN file from:\n " << NNConfigDir << "/" << NNParamFileName << std::endl;
22  std::cout << " resolved in :\n " << configPath << std::endl;
23 
24  std::ifstream fconfig( configPath.c_str() );
25 
26  if ( !fconfig.is_open() ) {
27  std::cerr << "Error opening config file: " << NNConfigDir << "/" << NNParamFileName << std::endl;
28  std::cerr << "Are you sure that the file exists at this path?" << std::endl;
29  return 1;
30  }
31 
32  std::string cutsPath = PathResolverFindCalibFile(NNConfigDir+"/"+NNCutFileName);
33  std::cout << " Reading JVT NN cut file from:\n " << NNConfigDir << "/" << NNCutFileName << std::endl;
34  std::cout << " resolved in :\n " << cutsPath << std::endl;
35  std::ifstream fcuts( cutsPath.c_str() );
36  if ( !fconfig.is_open() ) {
37  std::cerr << "Error opening cuts file: " << NNConfigDir << "/" << NNCutFileName << std::endl;
38  std::cerr << "Are you sure that the file exists at this path?" << std::endl;
39  return 2;
40  }
41 
42  nlohmann::json cut_j;
43  fcuts >> cut_j;
44  std::vector<float> ptbin_edges = cut_j["ptbin_edges"].get<std::vector<float> >();
45  std::vector<float> etabin_edges = cut_j["etabin_edges"].get<std::vector<float> >();
46  std::map<std::string,float> cut_map_raw = cut_j["cuts"].get<std::map<std::string,float> >();
47  // Initialise 2D vector with cuts per bin
48  // Edge vectors have size Nbins+1
49  std::vector<std::vector<float> > cut_map(ptbin_edges.size()-1);
50  for(std::vector<float>& cuts_vs_eta : cut_map) {
51  cuts_vs_eta.resize(etabin_edges.size()-1,0.);
52  }
53  std::regex binre("\\((\\d+),\\s*(\\d+)\\)");
54  for(const std::pair<const std::string,float>& bins_to_cut_str : cut_map_raw) {
55  std::cout << bins_to_cut_str.first << " --> " << bins_to_cut_str.second << std::endl;
56  std::smatch sm;
57  if(std::regex_match(bins_to_cut_str.first,sm,binre) && sm.size()==3) {
58  // First entry is full match, followed by sub-matches
59  size_t ptbin = std::stoi(sm[1]);
60  size_t etabin = std::stoi(sm[2]);
61  cut_map[ptbin][etabin] = bins_to_cut_str.second;
62  } else {
63  std::cerr << "Regex match of pt/eta bins failed! Received string " << bins_to_cut_str.first << std::endl;
64  std::cerr << "Match size " << sm.size() << std::endl;
65  return 3;
66  }
67  }
68 
69  lwt::GraphConfig cfg = lwt::parse_json_graph( fconfig );
70  lwt::InputOrder order;
71  lwt::order_t node_order;
72  std::vector<std::string> inputs = {"Rpt","JVFCorr","ptbin","etabin"};
73  // Single input block
74  node_order.emplace_back(cfg.inputs[0].name,inputs);
75  order.scalar = node_order;
76 
77  std::cout << "Reading JVT likelihood histogram from: " << configPath << std::endl;
78  std::cout << "Network NLayers: " << cfg.layers.size() << std::endl;
79  lwt::generic::FastGraph<double> lwnn(cfg, order);
80 
81  std::cout << "Computation for test values:" << std::endl;
82  // A few jet test cases
83  for(size_t ptbin=0; ptbin<5; ++ptbin) {
84  for(size_t etabin=0; etabin<5; ++etabin) {
85  std::cout << " pt bin[" << ptbin_edges[ptbin] << "," << ptbin_edges[ptbin+1] << "]: " << ptbin << ", eta bin [" << etabin_edges[etabin] << "," << etabin_edges[etabin+1] << "]: " << etabin <<std::endl;
86  lwt::VectorX<double> inputvals_HS = lwt::build_vector({1.0,1.0,static_cast<double>(ptbin),static_cast<double>(etabin)});
87  std::vector<lwt::VectorX<double> > scalars_HS{inputvals_HS};
88  lwt::VectorX<double> output_HS = lwnn.compute(scalars_HS);
89  std::cout << " HS jet --> " << output_HS(0) << std::endl;
90 
91  lwt::VectorX<double> inputvals_AMB = lwt::build_vector({0.2,0.5,static_cast<double>(ptbin),static_cast<double>(etabin)});
92  std::vector<lwt::VectorX<double> > scalars_AMB{inputvals_AMB};
93  lwt::VectorX<double> output_AMB = lwnn.compute(scalars_AMB);
94  std::cout << " Ambiguous jet --> " << output_AMB(0) << std::endl;
95 
96  lwt::VectorX<double> inputvals_PU = lwt::build_vector({0.0,-1.0,static_cast<double>(ptbin),static_cast<double>(etabin)});
97  std::vector<lwt::VectorX<double> > scalars_PU{inputvals_PU};
98  lwt::VectorX<double> output_PU = lwnn.compute(scalars_PU);
99  std::cout << " PU jet --> " << output_PU(0) << std::endl;
100 
101  std::cout << " Cut for this bin: " << cut_map[ptbin][etabin] << std::endl;
102  }
103  }
104 
105  return 0;
106 }
json
nlohmann::json json
Definition: HistogramDef.cxx:9
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
PrepareReferenceFile.regex
regex
Definition: PrepareReferenceFile.py:43
mc.order
order
Configure Herwig7.
Definition: mc.Herwig7_Dijet.py:12
lwt::atlas::order_t
std::vector< std::pair< std::string, std::vector< std::string > > > order_t
Definition: InputOrder.h:25
lwtDev::build_vector
VectorXd build_vector(const std::vector< double > &bias)
Definition: Stack.cxx:760
main
int main()
Definition: JetJvtNN_testCfg.cxx:13
PathResolver.h
WriteCaloSwCorrections.cfg
cfg
Definition: WriteCaloSwCorrections.py:23
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:431
lwtDev::parse_json_graph
GraphConfig parse_json_graph(std::istream &json)
Definition: parse_json.cxx:71