ATLAS Offline Software
histsampling.py
Go to the documentation of this file.
1 # Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration
2 
3 """
4 Tools for histogram sampling, in particular inverse transform sampling which is
5 missing from ROOT's TH2 classes.
6 """
7 
8 __author__ = "Andy Buckley <andy.buckley@cern.ch>"
9 
import bisect
import copy
import random

import ROOT
11 
12 
def _detached_clone(obj):
    "Clone a ROOT object and detach the clone from ROOT's directory bookkeeping."
    clone = obj.Clone()
    # Clone() registers histograms in the current ROOT directory; detaching makes
    # the copy survive the closing of that directory/file.
    if hasattr(clone, "SetDirectory"):
        clone.SetDirectory(0)
    return clone


def load_hist(*args):
    """
    Load a histogram and return an independent copy of it.

    Accepted argument forms:
     * load_hist(hist): hist must be a ROOT.TH1 (or subclass); it is cloned.
     * load_hist(filename, histname): open the file, clone the named object,
       then close the file again.
     * load_hist(tfile, histname): clone the named object from an already open
       ROOT.TFile (the caller's file is left open).

    The returned clone is detached from any ROOT directory, so it remains valid
    after the source file is closed.

    Raises Exception if the arguments match none of these forms, the file
    cannot be opened, or the named object is not found.
    """
    h = None
    if len(args) == 1 and isinstance(args[0], ROOT.TH1):
        h = _detached_clone(args[0])
    elif len(args) == 2 and isinstance(args[1], str):
        if isinstance(args[0], str):
            f = ROOT.TFile.Open(args[0])
            # PyROOT represents a failed Open as a null (falsy) object
            if f and not f.IsZombie():
                try:
                    obj = f.Get(args[1])
                    if obj:  # null (falsy) if the key does not exist
                        h = _detached_clone(obj)
                finally:
                    # Safe to close: the clone is no longer owned by the file
                    f.Close()
        elif isinstance(args[0], ROOT.TFile):
            obj = args[0].Get(args[1])
            if obj:
                h = _detached_clone(obj)
    if h is None:
        # str() is required: args is a tuple and cannot be added to a str
        raise Exception("Error in histogram loading from " + str(args))
    return h
31 
32 
def get_sampling_vars(h):
    """
    Build the lookup tables needed to sample from histogram h:
     * list of global bin IDs (ROOT global bin numbers also count the
       under/overflow bins, so for 2D the in-range IDs are not contiguous)
     * dict mapping global bin IDs to a tuple of axis bin IDs
       ((ix,) for 1D histograms, (ix, iy) for 2D ones)
     * list of nbins+1 cumulative bin values, in the same order as globalbins

    Only in-range bins are included, so under/overflow is never sampled.
    """
    globalbin_to_axisbin = {} # for reverse axis bin lookup to get edges
    globalbins = [] # because they aren't easily predicted, nor contiguous
    cheights = [0] # cumulative "histogram" from which to uniformly sample
    # TH2 derives from TH1, so it must be tested first. Otherwise a TH2 would
    # take the 1D branch, which only walks GetBin(ix) == GetBin(ix, 0) (the
    # y-underflow row) and stores 1-tuples that get_random_xy cannot use.
    if isinstance(h, ROOT.TH2):
        for ix in range(1, h.GetNbinsX()+1):
            for iy in range(1, h.GetNbinsY()+1):
                iglobal = h.GetBin(ix, iy)
                globalbins.append(iglobal)
                globalbin_to_axisbin[iglobal] = (ix, iy)
                cheights.append(cheights[-1] + h.GetBinContent(iglobal))
    elif isinstance(h, ROOT.TH1):
        for ix in range(1, h.GetNbinsX()+1):
            iglobal = h.GetBin(ix)
            globalbins.append(iglobal)
            globalbin_to_axisbin[iglobal] = (ix,)
            cheights.append(cheights[-1] + h.GetBinContent(iglobal))
    return globalbins, globalbin_to_axisbin, cheights
57 
58 
def get_random_bin(globalbins, cheights):
    """
    Choose a random bin from the cumulative distribution list of nbins+1 entries.

    Bin i (globalbins[i]) is chosen with probability proportional to
    cheights[i+1] - cheights[i]; zero-height bins are never chosen.
    cheights must be non-decreasing (i.e. non-negative bin contents), as
    inverse-transform sampling requires. The list is searched by bisection,
    so each call costs O(log nbins).

    Raises Exception if no bin can be chosen (e.g. all bins are empty).
    """
    assert len(cheights) == len(globalbins)+1
    nbins = len(globalbins)
    total = cheights[-1]
    randomheight = random.uniform(0, total)
    # Index i such that cheights[i] <= randomheight < cheights[i+1]
    i = bisect.bisect_right(cheights, randomheight) - 1
    if i >= nbins and randomheight == total:
        # random.uniform(a, b) may return b itself through float rounding;
        # attribute that point to the last bin with non-zero height instead
        # of failing.
        i = bisect.bisect_left(cheights, total) - 1
    if 0 <= i < nbins and cheights[i] < cheights[i+1]:
        return globalbins[i]
    raise Exception("Sample fell outside range of cumulative distribution?!?!")
72 
73 
def get_random_x(h, globalbins, cheights, globalbin_to_axisbin):
    """
    Sample an x value from histogram h.

    A bin is drawn with get_random_bin (probability proportional to its height),
    then x is drawn uniformly between that bin's low and high x edges; the
    distribution inside the bin is treated as flat.
    """
    chosen = get_random_bin(globalbins, cheights)
    axis_bins = globalbin_to_axisbin.get(chosen)
    assert axis_bins is not None
    xbin = axis_bins[0]
    xaxis = h.GetXaxis()
    lo_edge = xaxis.GetBinLowEdge(xbin)
    hi_edge = xaxis.GetBinUpEdge(xbin)
    return random.uniform(lo_edge, hi_edge)
84 
85 
def get_random_xy(h2, globalbins, cheights, globalbin_to_axisbin):
    """
    Sample an (x, y) point from 2D histogram h2.

    A bin is drawn with get_random_bin, then x and y are each drawn uniformly
    between that bin's edges on the corresponding axis (flat within the bin).
    Returns the tuple (x, y).
    """
    chosen = get_random_bin(globalbins, cheights)
    axis_bins = globalbin_to_axisbin.get(chosen)
    assert axis_bins is not None
    # x is drawn before y is looked up, keeping the random-number order fixed
    xbin = axis_bins[0]
    xaxis = h2.GetXaxis()
    xrand = random.uniform(xaxis.GetBinLowEdge(xbin), xaxis.GetBinUpEdge(xbin))
    ybin = axis_bins[1]
    yaxis = h2.GetYaxis()
    yrand = random.uniform(yaxis.GetBinLowEdge(ybin), yaxis.GetBinUpEdge(ybin))
    return xrand, yrand
97 
98 
class TH1(object):
    "Minimal wrapper for ROOT TH1, for sampling consistency and easy loading"

    def __init__(self, *args):
        """
        Wrap a copy of a histogram obtained via load_hist(*args) (a histogram
        object, a filename + histogram name, or a TFile + histogram name).
        """
        self.th1 = load_hist(*args)
        # Sampling tables, built lazily by get_sampling_vars on first GetRandom().
        # NOTE(review): they are never rebuilt, so changing the histogram
        # contents after the first GetRandom() leaves them stale.
        self.globalbins, self.globalbin_to_axisbin, self.cheights = None, None, None

    def GetRandom(self):
        "A GetRandom that works for TH1s and uses Python random numbers"
        if self.globalbins is None or self.globalbin_to_axisbin is None or self.cheights is None:
            self.globalbins, self.globalbin_to_axisbin, self.cheights = get_sampling_vars(self.th1)
        return get_random_x(self.th1, self.globalbins, self.cheights, self.globalbin_to_axisbin)

    def __getattr__(self, attr):
        "Forward all attributes to the contained TH1"
        # Only reached when normal lookup fails. Read the wrapped histogram from
        # the instance dict: going through self.th1 on an instance that was
        # created via __new__ without __init__ (as copy/pickle do) would
        # re-enter __getattr__ endlessly and raise RecursionError.
        if "th1" not in self.__dict__:
            raise AttributeError(attr)
        return getattr(self.__dict__["th1"], attr)
115 
116 
class TH2(object):
    "Minimal wrapper for ROOT TH2, for easy loading and to allow 2D sampling"

    def __init__(self, *args):
        """
        Wrap a copy of a histogram obtained via load_hist(*args) (a histogram
        object, a filename + histogram name, or a TFile + histogram name).
        """
        self.th2 = load_hist(*args)
        # Sampling tables, built lazily by get_sampling_vars on first GetRandom().
        # NOTE(review): they are never rebuilt, so changing the histogram
        # contents after the first GetRandom() leaves them stale.
        self.globalbins, self.globalbin_to_axisbin, self.cheights = None, None, None

    def GetRandom(self):
        "A GetRandom that works for TH2s; returns an (x, y) tuple"
        if self.globalbins is None or self.globalbin_to_axisbin is None or self.cheights is None:
            self.globalbins, self.globalbin_to_axisbin, self.cheights = get_sampling_vars(self.th2)
        return get_random_xy(self.th2, self.globalbins, self.cheights, self.globalbin_to_axisbin)

    def __getattr__(self, attr):
        "Forward other attributes to the contained TH2"
        # Only reached when normal lookup fails. Read the wrapped histogram from
        # the instance dict: going through self.th2 on an instance that was
        # created via __new__ without __init__ (as copy/pickle do) would
        # re-enter __getattr__ endlessly and raise RecursionError.
        if "th2" not in self.__dict__:
            raise AttributeError(attr)
        return getattr(self.__dict__["th2"], attr)
python.histsampling.TH1.cheights
cheights
Definition: histsampling.py:104
python.histsampling.TH2.__init__
def __init__(self, *args)
Definition: histsampling.py:120
python.histsampling.TH1.GetRandom
def GetRandom(self)
Definition: histsampling.py:106
python.histsampling.TH1.__init__
def __init__(self, *args)
Definition: histsampling.py:102
Get
T * Get(TFile &f, const std::string &n, const std::string &dir="", const chainmap_t *chainmap=0, std::vector< std::string > *saved=0)
get a histogram given a path, and an optional initial directory if histogram is not found,...
Definition: comparitor.cxx:178
python.histsampling.TH1.th1
th1
Definition: histsampling.py:103
python.histsampling.TH1.__getattr__
def __getattr__(self, attr)
Definition: histsampling.py:112
plotBeamSpotVxVal.range
range
Definition: plotBeamSpotVxVal.py:195
python.histsampling.TH2.GetRandom
def GetRandom(self)
Definition: histsampling.py:124
python.histsampling.TH2.cheights
cheights
Definition: histsampling.py:122
python.histsampling.get_random_x
def get_random_x(h, globalbins, cheights, globalbin_to_axisbin)
Definition: histsampling.py:74
python.histsampling.TH2.__getattr__
def __getattr__(self, attr)
Definition: histsampling.py:130
python.histsampling.get_random_bin
def get_random_bin(globalbins, cheights)
Definition: histsampling.py:59
python.histsampling.get_random_xy
def get_random_xy(h2, globalbins, cheights, globalbin_to_axisbin)
Definition: histsampling.py:86
python.histsampling.TH2
Definition: histsampling.py:117
python.histsampling.TH2.th2
th2
Definition: histsampling.py:121
python.histsampling.load_hist
def load_hist(*args)
Definition: histsampling.py:13
python.CaloScaleNoiseConfig.type
type
Definition: CaloScaleNoiseConfig.py:78
python.histsampling.get_sampling_vars
def get_sampling_vars(h)
Definition: histsampling.py:33
pickleTool.object
object
Definition: pickleTool.py:30
python.histsampling.TH1
Definition: histsampling.py:99