Loading [MathJax]/extensions/tex2jax.js
ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
L1TopoRatesCalculator_submatrix_plotter.py
Go to the documentation of this file.
1 #!/usr/bin/env athena.py
2 # Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 import ROOT
4 import numpy as np
5 import matplotlib.pyplot as plt
6 from matplotlib.colors import Normalize
7 import argparse
8 
def _extract_submatrix(hist, selected_vars):
    """Copy the selected rows/columns of a 2D histogram into numpy arrays.

    Labels are looked up on the histogram's x axis only, and the same bin
    numbers are reused for the y axis (assumes identical labels on both
    axes -- TODO confirm for non-square inputs).

    Args:
        hist: 2D ROOT histogram providing ``GetNbinsX``, ``GetXaxis``,
            ``GetBinContent``, ``GetBinError`` and ``GetName``.
        selected_vars: Bin labels to keep, in the order they should appear.

    Returns:
        ``(contents, errors)``: two square float arrays of side
        ``len(selected_vars)``. Element ``[i, j]`` holds the bin at
        x = bin of label ``i``, y = bin of label ``j``.

    Raises:
        ValueError: If any label in ``selected_vars`` is not on the x axis.
    """
    x_axis = hist.GetXaxis()
    # ROOT bin numbers start at 1. setdefault keeps the first bin carrying a
    # given label, matching a list.index() lookup when labels repeat.
    bin_of_label = {}
    for bin_number in range(1, hist.GetNbinsX() + 1):
        bin_of_label.setdefault(x_axis.GetBinLabel(bin_number), bin_number)

    missing = [var for var in selected_vars if var not in bin_of_label]
    if missing:
        raise ValueError(
            f"Label(s) not found on the x axis of '{hist.GetName()}': "
            + ", ".join(missing)
        )

    bins = [bin_of_label[var] for var in selected_vars]
    size = len(bins)
    contents = np.zeros((size, size))
    errors = np.zeros((size, size))
    for i, x_bin in enumerate(bins):
        for j, y_bin in enumerate(bins):
            contents[i, j] = hist.GetBinContent(x_bin, y_bin)
            errors[i, j] = hist.GetBinError(x_bin, y_bin)
    return contents, errors


def plot_correlation_matrix(file_name, matrix_name, selected_vars, output_name):
    """Plot a labelled sub-matrix of a 2D rates histogram stored in a ROOT file.

    Reads the 2D histogram ``matrix_name`` from ``file_name`` and keeps only
    the rows/columns whose bin labels appear in ``selected_vars`` (in that
    order). It then saves a colour-coded matrix to ``output_name``.

    Each cell shows the bin content and its error as ``value ± error`` with
    two decimals. The colour scale spans the min..max of the displayed cells,
    and the colour bar is labelled in Hz. Column labels are drawn on the x
    axis; rows are not labelled.

    Args:
        file_name: Path to the ROOT file.
        matrix_name: Name of the 2D histogram inside the file.
        selected_vars: Bin labels to plot; their order defines the plot order.
        output_name: Path of the output image file.

    Raises:
        ValueError: If ``selected_vars`` is empty or names a label that is
            not on the histogram's x axis.
        OSError: If the ROOT file cannot be opened.
        KeyError: If ``matrix_name`` is not found in the file.
    """
    if not selected_vars:
        # Checked up front: an empty selection would otherwise fail later
        # with an opaque "zero-size array" error from numpy.
        raise ValueError("selected_vars must contain at least one label")

    root_file = ROOT.TFile.Open(file_name)
    # A failed open yields a null object or a zombie file rather than raising.
    if not root_file or root_file.IsZombie():
        raise OSError(f"Could not open ROOT file '{file_name}'")
    try:
        rates_matrix = root_file.Get(matrix_name)
        if not rates_matrix:
            raise KeyError(f"Object '{matrix_name}' not found in '{file_name}'")
        # Copy all values out while the file is still open. Objects read from
        # a TFile are normally owned by it (NOTE(review): confirm for the ROOT
        # version in use), so they must not be used after Close().
        selected_matrix, errors_matrix = _extract_submatrix(rates_matrix, selected_vars)
    finally:
        root_file.Close()

    fig, ax = plt.subplots(figsize=(24, 22))
    # Colour scale spans exactly the range of the displayed cells.
    image = ax.matshow(
        selected_matrix,
        cmap='coolwarm',
        norm=Normalize(vmin=selected_matrix.min(), vmax=selected_matrix.max()),
    )

    # Annotate every cell with "content ± error". Text x is the column index;
    # text y is the row index.
    for (row, col), value in np.ndenumerate(selected_matrix):
        ax.text(col, row, f'{value:.2f}\n±{errors_matrix[row, col]:.2f}',
                ha='center', va='center', color='black', fontsize=25)

    # Rotate the column labels by 45 degrees and left-align them so each
    # label starts at its bin. Rows are intentionally left unlabelled.
    ax.set_xticks(np.arange(len(selected_vars)))
    ax.set_xticklabels(selected_vars, rotation=45, ha='left', fontsize=25)
    ax.set_yticks([])

    colorbar = fig.colorbar(image)
    colorbar.ax.tick_params(labelsize=25)
    colorbar.set_label('Hz', fontsize=25)

    fig.savefig(output_name, bbox_inches='tight')
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)
62 
if __name__ == '__main__':
    # Command-line interface. All arguments are positional, in the order:
    #   file_name matrix_name selected_vars [selected_vars ...] output_name
    # Every token between matrix_name and the final one goes to selected_vars.
    cli = argparse.ArgumentParser(description='Plot a correlation matrix from a ROOT file.')
    cli.add_argument('file_name', type=str, help='Path to the ROOT file.')
    cli.add_argument('matrix_name', type=str, help='Name of the matrix in the ROOT file.')
    cli.add_argument('selected_vars', type=str, nargs='+', help='List of selected variables.')
    cli.add_argument('output_name', type=str, help='Name of the output PNG file.')

    # The argument destinations are exactly the function's parameter names,
    # so the parsed namespace can be forwarded as keyword arguments.
    plot_correlation_matrix(**vars(cli.parse_args()))
76 
77 # Example usage
78 # python L1TopoRatesCalculator_submatrix_plotter.py RatesHistograms.root rates_matrix L1_eTAU12 L1_jTAU20 L1_gXEJWOJ70 L1_JPSI-1M5-eEM15 L1_jJ40p30ETA49 correlation_matrix.png
79 #file_name = 'RatesHistograms.root' # Path to the ROOT file
80 #matrix_name = 'rates_matrix' # Name of the matrix in the ROOT file
81 #selected_vars = ["L1_eTAU12","L1_jTAU20","L1_gXEJWOJ70","L1_JPSI-1M5-eEM15","L1_jJ40p30ETA49"]
82 #output_name = 'correlation_matrix.png' # Name of the output file
83 
84 #plot_correlation_matrix(file_name, matrix_name, selected_vars, output_name)
85 
L1TopoRatesCalculator_submatrix_plotter.plot_correlation_matrix
def plot_correlation_matrix(file_name, matrix_name, selected_vars, output_name)
Definition: L1TopoRatesCalculator_submatrix_plotter.py:9
plotBeamSpotVxVal.range
range
Definition: plotBeamSpotVxVal.py:195