ATLAS Offline Software
MVATrackVertexAssociationTool.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 // Includes from this package
7 
8 // FrameWork includes
9 #include "AthLinks/ElementLink.h"
15 
16 // EDM includes
19 #include "xAODTracking/Vertex.h"
22 
23 // lwtnn includes
24 #include "lwtnn/NNLayerConfig.hh"
25 #include "lwtnn/parse_json.hh"
26 
27 // STL includes
28 #include <fstream>
29 #include <iterator>
30 #include <memory>
31 
32 #include <stdexcept>
33 
34 namespace CP {
35 
37  AsgTool(name) {}
38 
40 
41  // Init EventInfo and hardscatter vertex link deco
44  ATH_CHECK(m_hardScatterDecoKey.initialize());
45 
46  // Init network
47  StatusCode initNetworkStatus = initializeNetwork();
48  if (initNetworkStatus != StatusCode::SUCCESS) {
49  return initNetworkStatus;
50  }
51 
52  // Map our working point to a cut on the MVA output discriminant
53  if (m_wp == "Tight") {
54  m_cut = 0.85;
55  }
56  else if (m_wp == "Custom") {
57  // Nothing to do here
58  }
59  else {
60  ATH_MSG_ERROR("Invalid TVA working point \"" << m_wp << "\" - for a custom configuration, please provide \"Custom\" for the \"WorkingPoint\" property.");
61  return StatusCode::FAILURE;
62  }
63 
64  // Some extra printout for Custom
65  if (m_wp == "Custom") {
66  ATH_MSG_INFO("TVA working point \"Custom\" provided - tool properties are initialized to default values unless explicitly set by the user.");
67  }
68  else {
69  ATH_MSG_INFO("TVA working point \"" << m_wp << "\" provided - tool properties have been configured accordingly.");
70  }
71 
72  ATH_MSG_DEBUG("Cut on MVA output discriminant: " << m_cut);
73 
74  return StatusCode::SUCCESS;
75 }
76 
78  float mvaOutput = -1.;
79  return isMatch(trk, vx, mvaOutput);
80 }
81 
83  const EventContext& ctx = Gaudi::Hive::currentContext();
85  if (!evt.isValid()) {
86  throw std::runtime_error("ERROR in CP::MVATrackVertexAssociationTool::isCompatible : could not retrieve xAOD::EventInfo!");
87  }
89  const ElementLink<xAOD::VertexContainer>& vtxLink = hardScatterDeco(*evt);
90  if (!vtxLink.isValid()) {
91  throw std::runtime_error("ERROR in CP::MVATrackVertexAssociationTool::isCompatible : hardscatter vertex link is not valid!");
92  }
93  float mvaOutput = -1.;
94  return isMatch(trk, **vtxLink, mvaOutput, evt.get());
95 }
96 
97 xAOD::TrackVertexAssociationMap MVATrackVertexAssociationTool::getMatchMap(std::vector<const xAOD::TrackParticle*>& trk_list, std::vector<const xAOD::Vertex*>& vx_list) const {
98  return getMatchMapInternal(trk_list, vx_list);
99 }
100 
102  return getMatchMapInternal(trkCont, vxCont);
103 }
104 
105 const xAOD::Vertex* MVATrackVertexAssociationTool::getUniqueMatchVertex(const xAOD::TrackParticle& trk, std::vector<const xAOD::Vertex*>& vx_list) const {
106  return getUniqueMatchVertexInternal(trk, vx_list);
107 }
108 
111  const xAOD::Vertex* vx_tmp = getUniqueMatchVertexInternal(trk, vxCont);
112  if (vx_tmp) {
113  vx_link_tmp.toContainedElement(vxCont, vx_tmp);
114  }
115  return vx_link_tmp;
116 }
117 
118 xAOD::TrackVertexAssociationMap MVATrackVertexAssociationTool::getUniqueMatchMap(std::vector<const xAOD::TrackParticle*>& trk_list, std::vector<const xAOD::Vertex*>& vx_list) const {
119  return getUniqueMatchMapInternal(trk_list, vx_list);
120 }
121 
123  return getUniqueMatchMapInternal(trkCont, vxCont);
124 }
125 
126 // --------------- //
127 // Private methods //
128 // --------------- //
129 
130 bool MVATrackVertexAssociationTool::isMatch(const xAOD::TrackParticle& trk, const xAOD::Vertex& vx, float& mvaOutput, const xAOD::EventInfo* evtInfo) const {
131 
132  const EventContext& ctx = Gaudi::Hive::currentContext();
133 
134  // Fake vertex, return false
135  if (vx.vertexType() == xAOD::VxType::NoVtx) {
136  return false;
137  }
138 
139  // Retrieve our EventInfo
140  const xAOD::EventInfo* evt = nullptr;
141  if (!evtInfo) {
143  if (!evttmp.isValid()) {
144  throw std::runtime_error("ERROR in CP::MVATrackVertexAssociationTool::isMatch : could not retrieve xAOD::EventInfo!");
145  }
146  evt = evttmp.get();
147  }
148  else {
149  evt = evtInfo;
150  }
151 
152  // Evaluate our network and compare against our TVA cut (">= cut" := associated)
153  mvaOutput = this->evaluateNetwork(trk, vx, *evt);
154  return (mvaOutput >= m_cut);
155 }
156 
157 template <typename T, typename V>
159 
161 
162  for (const auto *vertex : vx_list) {
164  trktovxlist.clear();
165  trktovxlist.reserve(trk_list.size());
166  for (const auto *track : trk_list) {
167  if (isCompatible(*track, *vertex)) {
168  trktovxlist.push_back(track);
169  }
170  }
171  trktovxmap[vertex] = trktovxlist;
172  }
173 
174  return trktovxmap;
175 }
176 
177 template <typename T>
179 
180  bool match;
181  float mvaOutput;
182  float maxValue = -1.0; // MVA output ranges between 0 and 1
183  const xAOD::Vertex* bestMatchVertex = nullptr;
184 
185  for (const auto *vertex : vx_list) {
186  match = isMatch(trk, *vertex, mvaOutput);
187  if (match && (maxValue < mvaOutput)) {
188  maxValue = mvaOutput;
189  bestMatchVertex = vertex;
190  }
191  }
192 
193  // check if get the matched Vertex, for the tracks not used in vertex fit
194  if (!bestMatchVertex) {
195  ATH_MSG_DEBUG("Could not find any matched vertex for this track.");
196  }
197 
198  return bestMatchVertex;
199 }
200 
201 template <typename T, typename V>
203 
205 
206  // Initialize map
207  for (const auto *vertex : vx_list) {
209  trktovxlist.clear();
210  trktovxlist.reserve(trk_list.size());
211  trktovxmap[vertex] = trktovxlist;
212  }
213 
214  // Perform matching
215  for (const auto *track : trk_list) {
216  const xAOD::Vertex* vx_match = getUniqueMatchVertexInternal(*track, vx_list);
217  if (vx_match) {
218  // Found matched vertex
219  trktovxmap[vx_match].push_back(track);
220  }
221  }
222 
223  return trktovxmap;
224 }
225 
227 
228  // Load our input evaluator
229  if (m_inputNames.size() != m_inputTypes.size()) {
230  ATH_MSG_ERROR("Size of input variable names (" + std::to_string(m_inputNames.size()) + ") does not equal size of input variable types (" + std::to_string(m_inputTypes.size()) + ").");
231  return StatusCode::FAILURE;
232  }
233  m_inputMap.clear();
234  for (std::size_t i = 0; i < m_inputNames.size(); i++) {
236  }
238 
239  // Load our input file
240  std::string fileName;
241  if (m_usePathResolver) {
243  if (fileName.empty()) {
244  ATH_MSG_ERROR("Could not find input network file: " + m_fileName);
245  return StatusCode::FAILURE;
246  }
247  }
248  else {
250  }
251  std::ifstream netFile(fileName);
252  if (!netFile) {
253  ATH_MSG_ERROR("Could not properly open input network file: " + fileName);
254  return StatusCode::FAILURE;
255  }
256 
257  // For sequential:
258  if (m_isSequential) {
259  lwt::JSONConfig netDef = lwt::parse_json(netFile);
260  m_network = std::make_unique<lwt::LightweightNeuralNetwork>(netDef.inputs, netDef.layers, netDef.outputs);
261  }
262  // For functional:
263  else {
264  lwt::GraphConfig netDef = lwt::parse_json_graph(netFile);
265  if (netDef.inputs.size() != 1) {
266  ATH_MSG_ERROR("Network in file \"" + fileName + "\" has more than 1 input node: # of input nodes = " + std::to_string(netDef.inputs.size()));
267  return StatusCode::FAILURE;
268  }
269  m_inputNodeName = netDef.inputs[0].name;
270  m_graph = std::make_unique<lwt::LightweightGraph>(netDef);
271  }
272 
273  return StatusCode::SUCCESS;
274 }
275 
277 
278  // Evaluate our inputs
279  std::map<std::string, double> input;
280  m_inputEval.eval(trk, vx, evt, input);
281 
282  // Evaluate our network
283  std::map<std::string, double> output;
284  // For sequential:
285  if (m_isSequential) {
286  output = m_network->compute(input);
287  }
288  // For functional:
289  else {
290  std::map<std::string, std::map<std::string, double>> wrappedInput;
291  wrappedInput[m_inputNodeName] = input;
292  output = m_graph->compute(wrappedInput);
293  }
294 
295  return output[m_outputName];
296 }
297 
298 } // namespace CP
CP::MVATrackVertexAssociationTool::m_eventInfo
SG::ReadHandleKey< xAOD::EventInfo > m_eventInfo
EventInfo key.
Definition: MVATrackVertexAssociationTool.h:136
CP::MVATrackVertexAssociationTool::m_inputMap
MVAInputEvaluator::InputSelectionMap m_inputMap
Input variable name/type map.
Definition: MVATrackVertexAssociationTool.h:130
PropertyWrapper.h
maxValue
#define maxValue(current, test)
Definition: CompoundLayerMaterialCreator.h:22
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
CP::MVATrackVertexAssociationTool::m_isSequential
Gaudi::Property< bool > m_isSequential
Is the network sequential or functional.
Definition: MVATrackVertexAssociationTool.h:104
CurrentContext.h
SG::ReadHandle
Definition: StoreGate/StoreGate/ReadHandle.h:70
CP::MVATrackVertexAssociationTool::initializeNetwork
StatusCode initializeNetwork()
Definition: MVATrackVertexAssociationTool.cxx:226
MVATrackVertexAssociationTool.h
LArG4FSStartPointFilter.evt
evt
Definition: LArG4FSStartPointFilter.py:42
xAOD::TrackVertexAssociationMap
std::map< const xAOD::Vertex *, xAOD::TrackVertexAssociationList > TrackVertexAssociationMap
Definition: TrackVertexAssociationMap.h:19
CP::MVAInputEvaluator::eval
void eval(const xAOD::TrackParticle &trk, const xAOD::Vertex &vx, const xAOD::EventInfo &evt, std::map< std::string, double > &input) const
Definition: MVAInputEvaluator.cxx:435
SG::VarHandleKey::key
const std::string & key() const
Return the StoreGate ID for the referenced object.
Definition: AthToolSupport/AsgDataHandles/Root/VarHandleKey.cxx:141
CP::MVATrackVertexAssociationTool::m_inputNames
Gaudi::Property< std::vector< std::string > > m_inputNames
Vector of input variable names.
Definition: MVATrackVertexAssociationTool.h:92
CP
Select isolated Photons, Electrons and Muons.
Definition: Control/xAODRootAccess/xAODRootAccess/TEvent.h:48
xAOD::Vertex_v1::vertexType
VxType::VertexType vertexType() const
The type of the vertex.
xAOD::VxType::NoVtx
@ NoVtx
Dummy vertex. TrackParticle was not used in vertex fit.
Definition: TrackingPrimitives.h:570
lwtDev::parse_json
JSONConfig parse_json(std::istream &json)
Definition: parse_json.cxx:42
CP::MVATrackVertexAssociationTool::m_outputName
Gaudi::Property< std::string > m_outputName
Name of the output node to cut on.
Definition: MVATrackVertexAssociationTool.h:100
CP::MVATrackVertexAssociationTool::getMatchMapInternal
xAOD::TrackVertexAssociationMap getMatchMapInternal(const T &trk_list, const V &vx_list) const
Definition: MVATrackVertexAssociationTool.cxx:158
CP::MVATrackVertexAssociationTool::initialize
virtual StatusCode initialize() override
Dummy implementation of the initialisation function.
Definition: MVATrackVertexAssociationTool.cxx:39
CP::MVATrackVertexAssociationTool::getUniqueMatchMap
virtual xAOD::TrackVertexAssociationMap getUniqueMatchMap(std::vector< const xAOD::TrackParticle * > &trk_list, std::vector< const xAOD::Vertex * > &vx_list) const override
This functions related to the previous functions, will return a 2D vector to store the best matched t...
Definition: MVATrackVertexAssociationTool.cxx:118
CP::MVATrackVertexAssociationTool::getUniqueMatchVertex
virtual const xAOD::Vertex * getUniqueMatchVertex(const xAOD::TrackParticle &trk, std::vector< const xAOD::Vertex * > &vx_list) const override
Definition: MVATrackVertexAssociationTool.cxx:105
FortranAlgorithmOptions.fileName
fileName
Definition: FortranAlgorithmOptions.py:13
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
SG::ReadDecorHandle
Handle class for reading a decoration on an object.
Definition: StoreGate/StoreGate/ReadDecorHandle.h:94
CP::MVATrackVertexAssociationTool::m_cut
Gaudi::Property< float > m_cut
TVA cut value on the output discriminant.
Definition: MVATrackVertexAssociationTool.h:111
lumiFormat.i
int i
Definition: lumiFormat.py:92
CP::MVATrackVertexAssociationTool::m_network
std::unique_ptr< lwt::LightweightNeuralNetwork > m_network
Network as implemented using lwtnn.
Definition: MVATrackVertexAssociationTool.h:146
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
SG::ReadHandle::get
const_pointer_type get() const
Dereference the pointer, but don't cache anything.
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
CP::MVAInputEvaluator::load
void load(const MVAInputEvaluator::InputSelectionMap &selection)
Definition: MVAInputEvaluator.cxx:428
PlotPulseshapeFromCool.input
input
Definition: PlotPulseshapeFromCool.py:106
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
CP::MVATrackVertexAssociationTool::getUniqueMatchMapInternal
xAOD::TrackVertexAssociationMap getUniqueMatchMapInternal(const T &trk_list, const V &vx_list) const
Definition: MVATrackVertexAssociationTool.cxx:202
SG::VarHandleKey::initialize
StatusCode initialize(bool used=true)
If this object is used as a property, then this should be called during the initialize phase.
Definition: AthToolSupport/AsgDataHandles/Root/VarHandleKey.cxx:103
DataVector< xAOD::TrackParticle_v1 >
Vertex.h
CP::MVATrackVertexAssociationTool::m_wp
Gaudi::Property< std::string > m_wp
TVA working point.
Definition: MVATrackVertexAssociationTool.h:108
SG::ReadHandle::isValid
virtual bool isValid() override final
Can the handle be successfully dereferenced?
CP::MVATrackVertexAssociationTool::m_inputTypes
Gaudi::Property< std::vector< int > > m_inputTypes
Vector of input variable types.
Definition: MVATrackVertexAssociationTool.h:96
CP::MVATrackVertexAssociationTool::m_fileName
Gaudi::Property< std::string > m_fileName
Input lwtnn network file.
Definition: MVATrackVertexAssociationTool.h:89
merge.output
output
Definition: merge.py:17
PathResolver.h
CP::MVATrackVertexAssociationTool::m_inputNodeName
std::string m_inputNodeName
Name of the input node (for functional modes)
Definition: MVATrackVertexAssociationTool.h:143
xAOD::TrackVertexAssociationList
std::vector< const xAOD::TrackParticle * > TrackVertexAssociationList
Definition: TrackVertexAssociationMap.h:18
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:195
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
ReadHandle.h
Handle class for reading from StoreGate.
CP::MVATrackVertexAssociationTool::m_hardScatterDecoKey
SG::ReadDecorHandleKey< xAOD::EventInfo > m_hardScatterDecoKey
Hardscatter vertex link key.
Definition: MVATrackVertexAssociationTool.h:139
CP::MVATrackVertexAssociationTool::getMatchMap
virtual xAOD::TrackVertexAssociationMap getMatchMap(std::vector< const xAOD::TrackParticle * > &trk_list, std::vector< const xAOD::Vertex * > &vx_list) const override
Definition: MVATrackVertexAssociationTool.cxx:97
xAOD::EventInfo_v1
Class describing the basic event information.
Definition: EventInfo_v1.h:43
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:431
CP::MVATrackVertexAssociationTool::m_inputEval
MVAInputEvaluator m_inputEval
Input variable evaluator.
Definition: MVATrackVertexAssociationTool.h:133
TrackParticle.h
lwtDev::parse_json_graph
GraphConfig parse_json_graph(std::istream &json)
Definition: parse_json.cxx:71
Trk::vertex
@ vertex
Definition: MeasurementType.h:21
VertexContainer.h
xAOD::Vertex_v1
Class describing a Vertex.
Definition: Vertex_v1.h:42
CP::MVATrackVertexAssociationTool::MVATrackVertexAssociationTool
MVATrackVertexAssociationTool(const std::string &name)
Definition: MVATrackVertexAssociationTool.cxx:36
CP::MVATrackVertexAssociationTool::isMatch
bool isMatch(const xAOD::TrackParticle &trk, const xAOD::Vertex &vx, float &mvaOutput, const xAOD::EventInfo *evtInfo=nullptr) const
Definition: MVATrackVertexAssociationTool.cxx:130
CP::MVATrackVertexAssociationTool::evaluateNetwork
float evaluateNetwork(const xAOD::TrackParticle &trk, const xAOD::Vertex &vx, const xAOD::EventInfo &evt) const
Definition: MVATrackVertexAssociationTool.cxx:276
CP::MVATrackVertexAssociationTool::isCompatible
virtual bool isCompatible(const xAOD::TrackParticle &trk, const xAOD::Vertex &vx) const override
This function just returns the decision of whether the track is matched to the vertex. Not sure whether...
Definition: MVATrackVertexAssociationTool.cxx:77
ReadDecorHandle.h
Handle class for reading a decoration on an object.
CP::MVAInputEvaluator::Input
Input
Definition: MVAInputEvaluator.h:33
CP::MVATrackVertexAssociationTool::getUniqueMatchVertexInternal
const xAOD::Vertex * getUniqueMatchVertexInternal(const xAOD::TrackParticle &trk, const T &vx_list) const
Definition: MVATrackVertexAssociationTool.cxx:178
xAOD::track
@ track
Definition: TrackingPrimitives.h:512
xAOD::TrackParticle_v1
Class describing a TrackParticle.
Definition: TrackParticle_v1.h:43
CP::MVATrackVertexAssociationTool::getUniqueMatchVertexLink
virtual ElementLink< xAOD::VertexContainer > getUniqueMatchVertexLink(const xAOD::TrackParticle &trk, const xAOD::VertexContainer &vx_cont) const override
This function will return the best matched vertex.
Definition: MVATrackVertexAssociationTool.cxx:109
CP::MVATrackVertexAssociationTool::m_usePathResolver
Gaudi::Property< bool > m_usePathResolver
Use the PathResolver to find our input file.
Definition: MVATrackVertexAssociationTool.h:115
TrackingPrimitives.h
CP::MVATrackVertexAssociationTool::m_graph
std::unique_ptr< lwt::LightweightGraph > m_graph
Definition: MVATrackVertexAssociationTool.h:147
TrackParticleContainer.h
match
bool match(std::string s1, std::string s2)
match the individual directories of two strings
Definition: hcg.cxx:356
CP::MVATrackVertexAssociationTool::m_hardScatterDeco
Gaudi::Property< std::string > m_hardScatterDeco
The decoration name of the ElementLink to the hardscatter vertex (found on xAOD::EventInfo)
Definition: MVATrackVertexAssociationTool.h:119