10 #include "nlohmann/json.hpp"
11 #include "lwtnn/generic/FastGraph.hh"
12 #include "lwtnn/parse_json.hh"
13 #include "lwtnn/Stack.hh"
// NOTE(review): this excerpt is an extracted listing: each statement keeps its
// original line number, and several original lines (e.g. where `configPath`
// is computed) are not shown here.
// Logical locations of the NNJvt network-graph JSON and the per-bin cut JSON,
// both under NNConfigDir.
19 std::string NNConfigDir =
"JetPileupTag/NNJvt/2022-03-22/";
20 std::string NNParamFileName =
"NNJVT.Network.graph.Offline.Nonprompt_All_MaxWeight.json";
21 std::string NNCutFileName =
"NNJVT.Cuts.FixedEffPt.Offline.Nonprompt_All_MaxW.json";
// Log both the logical name and the resolved path `configPath`.
// NOTE(review): NNConfigDir already ends in '/', so the logged name contains "//".
24 std::cout <<
"Reading JVT NN file from:\n " << NNConfigDir <<
"/" << NNParamFileName << std::endl;
25 std::cout <<
" resolved in :\n " << configPath << std::endl;
// Open the network-graph JSON from the resolved path.
27 std::ifstream fconfig( configPath.c_str() );
// On failure, report the logical file name plus a hint. NOTE(review): the rest
// of this branch (presumably an early return) is outside the visible excerpt.
29 if ( !fconfig.is_open() ) {
30 std::cerr <<
"Error opening config file: " << NNConfigDir <<
"/" << NNParamFileName << std::endl;
31 std::cerr <<
"Are you sure that the file exists at this path?" << std::endl;
36 std::cout <<
" Reading JVT NN cut file from:\n " << NNConfigDir <<
"/" << NNCutFileName << std::endl;
37 std::cout <<
" resolved in :\n " << cutsPath << std::endl;
38 std::ifstream fcuts( cutsPath.c_str() );
39 if ( !fconfig.is_open() ) {
40 std::cerr <<
"Error opening cuts file: " << NNConfigDir <<
"/" << NNCutFileName << std::endl;
41 std::cerr <<
"Are you sure that the file exists at this path?" << std::endl;
// Binning and per-bin cuts from the cut-file JSON. NOTE(review): `cut_j` is
// declared in lines not shown here; presumably it is parsed from `fcuts`.
// Each *_edges array holds bin boundaries, so N edges define N-1 bins.
47 std::vector<float> ptbin_edges = cut_j[
"ptbin_edges"].get<std::vector<float> >();
48 std::vector<float> etabin_edges = cut_j[
"etabin_edges"].get<std::vector<float> >();
// Raw cuts keyed by a string that encodes a (pt bin, eta bin) pair.
49 std::map<std::string,float> cut_map_raw = cut_j[
"cuts"].get<std::map<std::string,float> >();
// Dense cut table indexed [ptbin][etabin]; every entry starts at 0 until a raw
// cut fills it. NOTE(review): `size()-1` is unsigned, so an empty edge array
// would wrap around to a huge size here.
52 std::vector<std::vector<float> > cut_map(ptbin_edges.size()-1);
53 for(std::vector<float>& cuts_vs_eta : cut_map) {
54 cuts_vs_eta.resize(etabin_edges.size()-1,0.);
// Echo each raw entry, then decode its key into bin indices. `sm` and `binre`
// are declared in lines not shown here. A usable match must have exactly two
// capture groups: sm.size()==3 counts the whole match plus groups 1 and 2.
57 for(
const std::pair<const std::string,float>& bins_to_cut_str : cut_map_raw) {
58 std::cout << bins_to_cut_str.first <<
" --> " << bins_to_cut_str.second << std::endl;
60 if(std::regex_match(bins_to_cut_str.first,sm,binre) && sm.size()==3) {
62 size_t ptbin = std::stoi(sm[1]);
63 size_t etabin = std::stoi(sm[2]);
64 cut_map[ptbin][etabin] = bins_to_cut_str.second;
// Diagnostic for a cut key that did not decode into two bin indices.
// NOTE(review): the `else` that guards these lines is outside the visible excerpt.
66 std::cerr <<
"Regex match of pt/eta bins failed! Received string " << bins_to_cut_str.first << std::endl;
67 std::cerr <<
"Match size " << sm.size() << std::endl;
// Input ordering for the graph's first input node. The network is fed the
// scalar variables in this order: Rpt, JVFCorr, ptbin, etabin.
73 lwt::InputOrder
order;
75 std::vector<std::string>
inputs = {
"Rpt",
"JVFCorr",
"ptbin",
"etabin"};
// NOTE(review): `node_order` and `cfg` are declared in lines not shown here;
// presumably `cfg` is the graph configuration parsed from `fconfig`. This
// indexes cfg.inputs[0], so it assumes at least one input node.
77 node_order.emplace_back(
cfg.inputs[0].name,
inputs);
78 order.scalar = std::move(node_order);
// NOTE(review): the "likelihood histogram" wording looks stale; configPath is
// the NN graph JSON opened as `fconfig`.
80 std::cout <<
"Reading JVT likelihood histogram from: " << configPath << std::endl;
81 std::cout <<
"Network NLayers: " <<
cfg.layers.size() << std::endl;
// Build the evaluable network (double precision) from the config and input order.
82 lwt::generic::FastGraph<double> lwnn(
cfg,
order);
84 std::cout <<
"Computation for test values:" << std::endl;
86 for(
size_t ptbin=0; ptbin<5; ++ptbin) {
87 for(
size_t etabin=0; etabin<5; ++etabin) {
// Describe the current bin: its [low, high] edges and its index, in pt and eta.
88 std::cout <<
" pt bin[" << ptbin_edges[ptbin] <<
"," << ptbin_edges[ptbin+1] <<
"]: " << ptbin <<
", eta bin [" << etabin_edges[etabin] <<
"," << etabin_edges[etabin+1] <<
"]: " << etabin <<std::endl;
// Evaluate the network on three fixed (Rpt, JVFCorr) points for this bin. Each
// input vector follows the declared order (Rpt, JVFCorr, ptbin, etabin), with
// the bin indices passed as doubles; output element 0 is printed.
// The "HS" sample point: Rpt = 1.0, JVFCorr = 1.0.
89 lwt::VectorX<double> inputvals_HS =
lwt::build_vector({1.0,1.0,
static_cast<double>(ptbin),
static_cast<double>(etabin)});
90 std::vector<lwt::VectorX<double> > scalars_HS{std::move(inputvals_HS)};
91 lwt::VectorX<double> output_HS = lwnn.compute(scalars_HS);
92 std::cout <<
" HS jet --> " << output_HS(0) << std::endl;
// The "Ambiguous" sample point: Rpt = 0.2, JVFCorr = 0.5.
94 lwt::VectorX<double> inputvals_AMB =
lwt::build_vector({0.2,0.5,
static_cast<double>(ptbin),
static_cast<double>(etabin)});
95 std::vector<lwt::VectorX<double> > scalars_AMB{std::move(inputvals_AMB)};
96 lwt::VectorX<double> output_AMB = lwnn.compute(scalars_AMB);
97 std::cout <<
" Ambiguous jet --> " << output_AMB(0) << std::endl;
// The "PU" sample point: Rpt = 0.0, JVFCorr = -1.0.
99 lwt::VectorX<double> inputvals_PU =
lwt::build_vector({0.0,-1.0,
static_cast<double>(ptbin),
static_cast<double>(etabin)});
100 std::vector<lwt::VectorX<double> > scalars_PU{std::move(inputvals_PU)};
101 lwt::VectorX<double> output_PU = lwnn.compute(scalars_PU);
102 std::cout <<
" PU jet --> " << output_PU(0) << std::endl;
// The cut stored for this (pt, eta) bin; bins absent from the cut file keep the
// 0 they were initialised with when cut_map was sized.
104 std::cout <<
" Cut for this bin: " << cut_map[ptbin][etabin] << std::endl;