ATLAS Offline Software
Loading...
Searching...
No Matches
JetJvtNN_testCfg.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4#include <iostream>
5#include <fstream>
6#include <string>
7#include <vector>
8#include <regex>
9
10#include "nlohmann/json.hpp"
11#include "lwtnn/generic/FastGraph.hh"
12#include "lwtnn/parse_json.hh"
13#include "lwtnn/Stack.hh"
15
16int main() {
17
18 // Use the Path Resolver to find the jvt file and retrieve the likelihood histogram
19 std::string NNConfigDir = "JetPileupTag/NNJvt/2022-03-22/";
20 std::string NNParamFileName = "NNJVT.Network.graph.Offline.Nonprompt_All_MaxWeight.json";
21 std::string NNCutFileName = "NNJVT.Cuts.FixedEffPt.Offline.Nonprompt_All_MaxW.json";
22
23 std::string configPath = PathResolverFindCalibFile(NNConfigDir+"/"+NNParamFileName);
24 std::cout << "Reading JVT NN file from:\n " << NNConfigDir << "/" << NNParamFileName << std::endl;
25 std::cout << " resolved in :\n " << configPath << std::endl;
26
27 std::ifstream fconfig( configPath.c_str() );
28
29 if ( !fconfig.is_open() ) {
30 std::cerr << "Error opening config file: " << NNConfigDir << "/" << NNParamFileName << std::endl;
31 std::cerr << "Are you sure that the file exists at this path?" << std::endl;
32 return 1;
33 }
34
35 std::string cutsPath = PathResolverFindCalibFile(NNConfigDir+"/"+NNCutFileName);
36 std::cout << " Reading JVT NN cut file from:\n " << NNConfigDir << "/" << NNCutFileName << std::endl;
37 std::cout << " resolved in :\n " << cutsPath << std::endl;
38 std::ifstream fcuts( cutsPath.c_str() );
39 if ( !fconfig.is_open() ) {
40 std::cerr << "Error opening cuts file: " << NNConfigDir << "/" << NNCutFileName << std::endl;
41 std::cerr << "Are you sure that the file exists at this path?" << std::endl;
42 return 2;
43 }
44
45 nlohmann::json cut_j;
46 fcuts >> cut_j;
47 std::vector<float> ptbin_edges = cut_j["ptbin_edges"].get<std::vector<float> >();
48 std::vector<float> etabin_edges = cut_j["etabin_edges"].get<std::vector<float> >();
49 std::map<std::string,float> cut_map_raw = cut_j["cuts"].get<std::map<std::string,float> >();
50 // Initialise 2D vector with cuts per bin
51 // Edge vectors have size Nbins+1
52 std::vector<std::vector<float> > cut_map(ptbin_edges.size()-1);
53 for(std::vector<float>& cuts_vs_eta : cut_map) {
54 cuts_vs_eta.resize(etabin_edges.size()-1,0.);
55 }
56 std::regex binre("\\((\\d+),\\s*(\\d+)\\)");
57 for(const std::pair<const std::string,float>& bins_to_cut_str : cut_map_raw) {
58 std::cout << bins_to_cut_str.first << " --> " << bins_to_cut_str.second << std::endl;
59 std::smatch sm;
60 if(std::regex_match(bins_to_cut_str.first,sm,binre) && sm.size()==3) {
61 // First entry is full match, followed by sub-matches
62 size_t ptbin = std::stoi(sm[1]);
63 size_t etabin = std::stoi(sm[2]);
64 cut_map[ptbin][etabin] = bins_to_cut_str.second;
65 } else {
66 std::cerr << "Regex match of pt/eta bins failed! Received string " << bins_to_cut_str.first << std::endl;
67 std::cerr << "Match size " << sm.size() << std::endl;
68 return 3;
69 }
70 }
71
72 lwt::GraphConfig cfg = lwt::parse_json_graph( fconfig );
73 lwt::InputOrder order;
74 lwt::order_t node_order;
75 std::vector<std::string> inputs = {"Rpt","JVFCorr","ptbin","etabin"};
76 // Single input block
77 node_order.emplace_back(cfg.inputs[0].name,inputs);
78 order.scalar = std::move(node_order);
79
80 std::cout << "Reading JVT likelihood histogram from: " << configPath << std::endl;
81 std::cout << "Network NLayers: " << cfg.layers.size() << std::endl;
82 lwt::generic::FastGraph<double> lwnn(cfg, order);
83
84 std::cout << "Computation for test values:" << std::endl;
85 // A few jet test cases
86 for(size_t ptbin=0; ptbin<5; ++ptbin) {
87 for(size_t etabin=0; etabin<5; ++etabin) {
88 std::cout << " pt bin[" << ptbin_edges[ptbin] << "," << ptbin_edges[ptbin+1] << "]: " << ptbin << ", eta bin [" << etabin_edges[etabin] << "," << etabin_edges[etabin+1] << "]: " << etabin <<std::endl;
89 lwt::VectorX<double> inputvals_HS = lwt::build_vector({1.0,1.0,static_cast<double>(ptbin),static_cast<double>(etabin)});
90 std::vector<lwt::VectorX<double> > scalars_HS{std::move(inputvals_HS)};
91 lwt::VectorX<double> output_HS = lwnn.compute(scalars_HS);
92 std::cout << " HS jet --> " << output_HS(0) << std::endl;
93
94 lwt::VectorX<double> inputvals_AMB = lwt::build_vector({0.2,0.5,static_cast<double>(ptbin),static_cast<double>(etabin)});
95 std::vector<lwt::VectorX<double> > scalars_AMB{std::move(inputvals_AMB)};
96 lwt::VectorX<double> output_AMB = lwnn.compute(scalars_AMB);
97 std::cout << " Ambiguous jet --> " << output_AMB(0) << std::endl;
98
99 lwt::VectorX<double> inputvals_PU = lwt::build_vector({0.0,-1.0,static_cast<double>(ptbin),static_cast<double>(etabin)});
100 std::vector<lwt::VectorX<double> > scalars_PU{std::move(inputvals_PU)};
101 lwt::VectorX<double> output_PU = lwnn.compute(scalars_PU);
102 std::cout << " PU jet --> " << output_PU(0) << std::endl;
103
104 std::cout << " Cut for this bin: " << cut_map[ptbin][etabin] << std::endl;
105 }
106 }
107
108 return 0;
109}
int main()
std::string PathResolverFindCalibFile(const std::string &logical_file_name)