ATLAS Offline Software
Loading...
Searching...
No Matches
MVATrackVertexAssociationTool.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3*/
4
5// Includes from this package
7
8// FrameWork includes
9#include "AthLinks/ElementLink.h"
15
16// EDM includes
19#include "xAODTracking/Vertex.h"
22
23// lwtnn includes
24#include "lwtnn/NNLayerConfig.hh"
25#include "lwtnn/parse_json.hh"
26
27// STL includes
28#include <fstream>
29#include <iterator>
30#include <memory>
31
32#include <stdexcept>
33
34namespace CP {
35
38
40
// Presumably the body of initialize(); its signature line is not present in this view.
// What it does:
// - initialises the EventInfo read handle key and the hard-scatter decoration key;
// - loads the network via initializeNetwork();
// - maps the "WorkingPoint" property onto m_cut, the threshold applied to the network output.
// Returns the failing StatusCode from initializeNetwork(), or FAILURE when the
// working point is neither "Tight" nor "Custom".
41 // Init EventInfo and hardscatter vertex link deco
42 ATH_CHECK(m_eventInfo.initialize());
// NOTE(review): a line between these two ATH_CHECKs appears to be missing from this view.
44 ATH_CHECK(m_hardScatterDecoKey.initialize());
45
46 // Init network
47 StatusCode initNetworkStatus = initializeNetwork();
48 if (initNetworkStatus != StatusCode::SUCCESS) {
49 return initNetworkStatus;
50 }
51
52 // Map our working point to a cut on the MVA output discriminant
// "Tight": a pair is associated when the MVA output is >= 0.85 (see isMatch).
53 if (m_wp == "Tight") {
54 m_cut = 0.85;
55 }
56 else if (m_wp == "Custom") {
57 // Nothing to do here: m_cut keeps its property value (default or user-set)
58 }
59 else {
60 ATH_MSG_ERROR("Invalid TVA working point \"" << m_wp << "\" - for a custom configuration, please provide \"Custom\" for the \"WorkingPoint\" property.");
61 return StatusCode::FAILURE;
62 }
63
64 // Some extra printout for Custom
65 if (m_wp == "Custom") {
66 ATH_MSG_INFO("TVA working point \"Custom\" provided - tool properties are initialized to default values unless explicitly set by the user.");
67 }
68 else {
69 ATH_MSG_INFO("TVA working point \"" << m_wp << "\" provided - tool properties have been configured accordingly.");
70 }
71
72 ATH_MSG_DEBUG("Cut on MVA output discriminant: " << m_cut);
73
74 return StatusCode::SUCCESS;
75}
76
// Presumably the body of isCompatible(trk, vx); the signature line is not present in this view.
// Returns the pass/fail TVA decision for `trk` against `vx`.
// The MVA score that isMatch writes into mvaOutput is discarded.
78 float mvaOutput = -1.;
79 return isMatch(trk, vx, mvaOutput);
80}
81
// Presumably the body of the single-argument isCompatible(trk); the signature line
// is not present in this view. It tests `trk` against the hard-scatter vertex,
// found through the ElementLink decoration on xAOD::EventInfo.
// Throws std::runtime_error if EventInfo cannot be retrieved or if the vertex
// link is invalid. The EventInfo pointer is forwarded to isMatch, so isMatch
// does not retrieve it a second time.
83 const EventContext& ctx = Gaudi::Hive::currentContext();
// NOTE(review): the declaration of the EventInfo handle `evt` (presumably built
// from m_eventInfo and ctx) is missing from this view.
85 if (!evt.isValid()) {
86 throw std::runtime_error("ERROR in CP::MVATrackVertexAssociationTool::isCompatible : could not retrieve xAOD::EventInfo!");
87 }
// NOTE(review): the declaration of the decoration handle `hardScatterDeco`
// (presumably built from m_hardScatterDecoKey) is missing from this view.
89 const ElementLink<xAOD::VertexContainer>& vtxLink = hardScatterDeco(*evt);
90 if (!vtxLink.isValid()) {
91 throw std::runtime_error("ERROR in CP::MVATrackVertexAssociationTool::isCompatible : hardscatter vertex link is not valid!");
92 }
93 float mvaOutput = -1.;
94 return isMatch(trk, **vtxLink, mvaOutput, evt.get());
95}
96
97xAOD::TrackVertexAssociationMap MVATrackVertexAssociationTool::getMatchMap(std::vector<const xAOD::TrackParticle*>& trk_list, std::vector<const xAOD::Vertex*>& vx_list) const {
98 return getMatchMapInternal(trk_list, vx_list);
99}
100
104
105const xAOD::Vertex* MVATrackVertexAssociationTool::getUniqueMatchVertex(const xAOD::TrackParticle& trk, std::vector<const xAOD::Vertex*>& vx_list) const {
106 return getUniqueMatchVertexInternal(trk, vx_list);
107}
108
// Presumably the body of getUniqueMatchVertexLink(trk, vxCont); the signature line
// is not present in this view. It wraps the best-matching vertex of `vxCont` in an
// ElementLink. If no vertex passes the cut, vx_link_tmp is returned without ever
// being pointed at an element (presumably a default, invalid link — confirm).
// NOTE(review): the declaration of `vx_link_tmp` is missing from this view.
111 const xAOD::Vertex* vx_tmp = getUniqueMatchVertexInternal(trk, vxCont);
112 if (vx_tmp) {
113 vx_link_tmp.toContainedElement(vxCont, vx_tmp);
114 }
115 return vx_link_tmp;
116}
117
118xAOD::TrackVertexAssociationMap MVATrackVertexAssociationTool::getUniqueMatchMap(std::vector<const xAOD::TrackParticle*>& trk_list, std::vector<const xAOD::Vertex*>& vx_list) const {
119 return getUniqueMatchMapInternal(trk_list, vx_list);
120}
121
125
126// --------------- //
127// Private methods //
128// --------------- //
129
// Core TVA decision: evaluates the network for (trk, vx) and associates the pair
// when the network output is >= m_cut.
//
// @param trk       track to test
// @param vx        candidate vertex. A vertex of type xAOD::VxType::NoVtx is never
//                  matched, and mvaOutput is left untouched in that case.
// @param mvaOutput receives the network output whenever the network is evaluated
// @param evtInfo   optional EventInfo. When null, EventInfo is retrieved for the
//                  current context on every call.
// @return true if the pair passes the cut
// @throws std::runtime_error if EventInfo has to be retrieved and is not valid
130bool MVATrackVertexAssociationTool::isMatch(const xAOD::TrackParticle& trk, const xAOD::Vertex& vx, float& mvaOutput, const xAOD::EventInfo* evtInfo) const {
131
// NOTE(review): ctx is only needed when evtInfo is null; it could move into that branch.
132 const EventContext& ctx = Gaudi::Hive::currentContext();
133
134 // Fake vertex, return false
135 if (vx.vertexType() == xAOD::VxType::NoVtx) {
136 return false;
137 }
138
139 // Retrieve our EventInfo
140 const xAOD::EventInfo* evt = nullptr;
141 if (!evtInfo) {
// NOTE(review): the declaration of `evttmp` (presumably a read handle on
// m_eventInfo using ctx) is missing from this view.
143 if (!evttmp.isValid()) {
144 throw std::runtime_error("ERROR in CP::MVATrackVertexAssociationTool::isMatch : could not retrieve xAOD::EventInfo!");
145 }
146 evt = evttmp.get();
147 }
148 else {
149 evt = evtInfo;
150 }
151
152 // Evaluate our network and compare against our TVA cut (">= cut" := associated)
153 mvaOutput = this->evaluateNetwork(trk, vx, *evt);
154 return (mvaOutput >= m_cut);
155}
156
// Presumably the definition of getMatchMapInternal(trk_list, vx_list). Its signature
// line and the declarations of `trktovxmap` and `trktovxlist` are missing from this view.
// Builds a map from each vertex of vx_list to the tracks of trk_list that pass
// isCompatible() with it. A track can therefore appear under several vertices.
// NOTE(review): isCompatible(trk, vx) calls isMatch without an EventInfo pointer, so
// EventInfo is looked up again for every (track, vertex) pair. Retrieving it once
// and passing it to isMatch would avoid that.
157template <typename T, typename V>
159
161
162 for (const auto *vertex : vx_list) {
164 trktovxlist.clear();
165 trktovxlist.reserve(trk_list.size());
166 for (const auto *track : trk_list) {
167 if (isCompatible(*track, *vertex)) {
168 trktovxlist.push_back(track);
169 }
170 }
// NOTE(review): this copies the list. std::move would avoid the copy, and the
// clear() at the top of the next iteration makes the moved-from list safe to reuse.
171 trktovxmap[vertex] = trktovxlist;
172 }
173
174 return trktovxmap;
175}
176
// Presumably the definition of getUniqueMatchVertexInternal(trk, vx_list); its
// signature line is missing from this view.
// Returns the vertex of vx_list that passes isMatch() with the highest MVA output,
// or nullptr if none passes. On ties the earliest such vertex is kept, because the
// comparison is strict.
// mvaOutput is only read after isMatch() returned true, so leaving it uninitialised
// is safe (the NoVtx early return in isMatch does not set it).
// NOTE(review): isMatch is called without an EventInfo pointer, so EventInfo is
// looked up again for every vertex.
177template <typename T>
179
180 bool match;
181 float mvaOutput;
182 float maxValue = -1.0; // MVA output ranges between 0 and 1
183 const xAOD::Vertex* bestMatchVertex = nullptr;
184
185 for (const auto *vertex : vx_list) {
186 match = isMatch(trk, *vertex, mvaOutput);
187 if (match && (maxValue < mvaOutput)) {
188 maxValue = mvaOutput;
189 bestMatchVertex = vertex;
190 }
191 }
192
193 // No vertex passed the cut for this track (e.g. a track not used in any vertex fit)
194 if (!bestMatchVertex) {
195 ATH_MSG_DEBUG("Could not find any matched vertex for this track.");
196 }
197
198 return bestMatchVertex;
199}
200
// Presumably the definition of getUniqueMatchMapInternal(trk_list, vx_list). Its
// signature line and the declarations of `trktovxmap` and `trktovxlist` are missing
// from this view.
// Every vertex of vx_list first gets an empty entry. Each track is then attached to
// its best-matching vertex (getUniqueMatchVertexInternal), if it has one.
// A track therefore appears under at most one vertex.
201template <typename T, typename V>
203
205
206 // Initialize map
207 for (const auto *vertex : vx_list) {
209 trktovxlist.clear();
// NOTE(review): copy-assigning into the map does not generally carry this reserved
// capacity over, so the reserve() is likely wasted.
210 trktovxlist.reserve(trk_list.size());
211 trktovxmap[vertex] = trktovxlist;
212 }
213
214 // Perform matching
215 for (const auto *track : trk_list) {
216 const xAOD::Vertex* vx_match = getUniqueMatchVertexInternal(*track, vx_list);
217 if (vx_match) {
218 // Found matched vertex
219 trktovxmap[vx_match].push_back(track);
220 }
221 }
222
223 return trktovxmap;
224}
225
227
// Presumably the body of initializeNetwork(); its signature line is missing from this view.
// What it does:
// - checks that the input-variable name and type lists have equal length;
// - resolves and opens the network file;
// - builds either a sequential lwtnn network (m_network) or a single-input
//   functional graph (m_graph, with m_inputNodeName set to that node's name).
// Returns FAILURE on any configuration or file error.
228 // Load our input evaluator
229 if (m_inputNames.size() != m_inputTypes.size()) {
230 ATH_MSG_ERROR("Size of input variable names (" + std::to_string(m_inputNames.size()) + ") does not equal size of input variable types (" + std::to_string(m_inputTypes.size()) + ").");
231 return StatusCode::FAILURE;
232 }
233 m_inputMap.clear();
// NOTE(review): the loop body (presumably filling m_inputMap from names/types) and the
// following statement (presumably loading m_inputEval) are missing from this view.
234 for (std::size_t i = 0; i < m_inputNames.size(); i++) {
236 }
238
239 // Load our input file
240 std::string fileName;
241 if (m_usePathResolver) {
// NOTE(review): the assignment of fileName (presumably from PathResolverFindCalibFile)
// is missing from this view.
243 if (fileName.empty()) {
244 ATH_MSG_ERROR("Could not find input network file: " + m_fileName);
245 return StatusCode::FAILURE;
246 }
247 }
248 else {
249 fileName = m_fileName;
250 }
251 std::ifstream netFile(fileName);
252 if (!netFile) {
253 ATH_MSG_ERROR("Could not properly open input network file: " + fileName);
254 return StatusCode::FAILURE;
255 }
256
257 // For sequential:
258 if (m_isSequential) {
259 lwt::JSONConfig netDef = lwt::parse_json(netFile);
260 m_network = std::make_unique<lwt::LightweightNeuralNetwork>(netDef.inputs, netDef.layers, netDef.outputs);
261 }
262 // For functional:
263 else {
264 lwt::GraphConfig netDef = lwt::parse_json_graph(netFile);
// Exactly one input node is required, because evaluateNetwork wraps all inputs
// under m_inputNodeName.
// NOTE(review): this check also fires for 0 input nodes, where the message
// "more than 1 input node" is misleading.
265 if (netDef.inputs.size() != 1) {
266 ATH_MSG_ERROR("Network in file \"" + fileName + "\" has more than 1 input node: # of input nodes = " + std::to_string(netDef.inputs.size()));
267 return StatusCode::FAILURE;
268 }
269 m_inputNodeName = netDef.inputs[0].name;
270 m_graph = std::make_unique<lwt::LightweightGraph>(netDef);
271 }
272
273 return StatusCode::SUCCESS;
274}
275
277
// Presumably the body of evaluateNetwork(trk, vx, evt); its signature line is missing
// from this view.
// Computes the input variables for (trk, vx, evt) through m_inputEval and runs them
// through the sequential network or the functional graph. Returns the value of the
// output node named m_outputName, narrowed from double to the function's return type.
278 // Evaluate our inputs
279 std::map<std::string, double> input;
280 m_inputEval.eval(trk, vx, evt, input);
281
282 // Evaluate our network
283 std::map<std::string, double> output;
284 // For sequential:
285 if (m_isSequential) {
286 output = m_network->compute(input);
287 }
288 // For functional:
289 else {
// Graph networks expect inputs grouped by input node; initializeNetwork guarantees
// there is exactly one node.
290 std::map<std::string, std::map<std::string, double>> wrappedInput;
291 wrappedInput[m_inputNodeName] = input;
292 output = m_graph->compute(wrappedInput);
293 }
294
// NOTE(review): std::map::operator[] default-inserts 0.0 when m_outputName is not an
// output of the network, so a misconfigured OutputName silently yields 0 for every
// pair. Consider find() plus an explicit error.
295 return output[m_outputName];
296}
297
298} // namespace CP
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_DEBUG(x)
Handle class for reading a decoration on an object.
Handle class for reading from StoreGate.
#define maxValue(current, test)
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
virtual bool isCompatible(const xAOD::TrackParticle &trk, const xAOD::Vertex &vx) const override
This function just returns the decision of whether the track is matched to the vertex. Not sure whether...
const xAOD::Vertex * getUniqueMatchVertexInternal(const xAOD::TrackParticle &trk, const T &vx_list) const
Gaudi::Property< std::string > m_wp
TVA working point.
SG::ReadHandleKey< xAOD::EventInfo > m_eventInfo
EventInfo key.
virtual StatusCode initialize() override
Dummy implementation of the initialisation function.
std::unique_ptr< lwt::LightweightGraph > m_graph
Gaudi::Property< std::vector< int > > m_inputTypes
Vector of input variable types.
Gaudi::Property< bool > m_usePathResolver
Use the PathResolver to find our input file.
float evaluateNetwork(const xAOD::TrackParticle &trk, const xAOD::Vertex &vx, const xAOD::EventInfo &evt) const
Gaudi::Property< std::string > m_outputName
Name of the output node to cut on.
std::string m_inputNodeName
Name of the input node (for functional modes)
virtual xAOD::TrackVertexAssociationMap getMatchMap(std::vector< const xAOD::TrackParticle * > &trk_list, std::vector< const xAOD::Vertex * > &vx_list) const override
virtual const xAOD::Vertex * getUniqueMatchVertex(const xAOD::TrackParticle &trk, std::vector< const xAOD::Vertex * > &vx_list) const override
Gaudi::Property< std::string > m_hardScatterDeco
The decoration name of the ElementLink to the hardscatter vertex (found on xAOD::EventInfo)
xAOD::TrackVertexAssociationMap getMatchMapInternal(const T &trk_list, const V &vx_list) const
Gaudi::Property< std::vector< std::string > > m_inputNames
Vector of input variable names.
Gaudi::Property< std::string > m_fileName
Input lwtnn network file.
virtual xAOD::TrackVertexAssociationMap getUniqueMatchMap(std::vector< const xAOD::TrackParticle * > &trk_list, std::vector< const xAOD::Vertex * > &vx_list) const override
This function, related to the previous functions, will return a 2D vector to store the best matched t...
Gaudi::Property< bool > m_isSequential
Is the network sequential or functional.
MVAInputEvaluator m_inputEval
Input variable evaluator.
std::unique_ptr< lwt::LightweightNeuralNetwork > m_network
Network as implemented using lwtnn.
virtual ElementLink< xAOD::VertexContainer > getUniqueMatchVertexLink(const xAOD::TrackParticle &trk, const xAOD::VertexContainer &vx_cont) const override
This function will return the best matched vertex.
Gaudi::Property< float > m_cut
TVA cut value on the output discriminant.
MVAInputEvaluator::InputSelectionMap m_inputMap
Input variable name/type map.
SG::ReadDecorHandleKey< xAOD::EventInfo > m_hardScatterDecoKey
Hardscatter vertex link key.
xAOD::TrackVertexAssociationMap getUniqueMatchMapInternal(const T &trk_list, const V &vx_list) const
bool isMatch(const xAOD::TrackParticle &trk, const xAOD::Vertex &vx, float &mvaOutput, const xAOD::EventInfo *evtInfo=nullptr) const
Handle class for reading a decoration on an object.
virtual bool isValid() override final
Can the handle be successfully dereferenced?
const_pointer_type get() const
Dereference the pointer, but don't cache anything.
AsgTool(const std::string &name)
Constructor specifying the tool instance's name.
Definition AsgTool.cxx:58
VxType::VertexType vertexType() const
The type of the vertex.
bool match(std::string s1, std::string s2)
match the individual directories of two strings
Definition hcg.cxx:357
Select isolated Photons, Electrons and Muons.
@ NoVtx
Dummy vertex. TrackParticle was not used in vertex fit.
std::vector< const xAOD::TrackParticle * > TrackVertexAssociationList
EventInfo_v1 EventInfo
Definition of the latest event info version.
std::map< const xAOD::Vertex *, xAOD::TrackVertexAssociationList > TrackVertexAssociationMap
TrackParticle_v1 TrackParticle
Reference the current persistent version:
VertexContainer_v1 VertexContainer
Definition of the current "Vertex container version".
Vertex_v1 Vertex
Define the latest version of the vertex class.
TrackParticleContainer_v1 TrackParticleContainer
Definition of the current "TrackParticle container version".