22 #include "TMVA/Reader.h"
34 float largestWeight = 0;
36 for (
const auto *vtx : *vertices) {
39 const size_t nTrackLinks=trkLinks.size();
40 for (
unsigned i=0;
i<nTrackLinks;++
i) {
42 if( vtx->trackWeights()[
i] > largestWeight ){
43 vtxWithLargestWeight = vtx;
50 return vtxWithLargestWeight;
90 int nVars,
const std::vector<std::vector<float>>& input_data,
91 const std::shared_ptr<Ort::Session>& sessionHandle,
92 std::vector<int64_t> input_node_dims,
93 std::vector<const char*> input_node_names,
94 std::vector<const char*> output_node_names)
const {
98 std::vector<std::vector<float>> input_tensor_values_ = input_data;
101 size_t input_tensor_size = nVars;
102 std::vector<float> input_tensor_values(nVars);
103 input_tensor_values = input_tensor_values_[0];
107 Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
109 Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
110 memory_info, input_tensor_values.data(), input_tensor_size,
111 input_node_dims.data(), input_node_dims.size());
114 assert(input_tensor.IsTensor());
117 auto output_tensors =
118 sessionHandle->Run(Ort::RunOptions{
nullptr}, input_node_names.data(),
119 &input_tensor, input_node_names.size(),
120 output_node_names.data(), output_node_names.size());
123 assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());
127 float* floatarr = output_tensors[0].GetTensorMutableData<
float>();
129 int arrSize =
sizeof(*floatarr) /
sizeof(floatarr[0]);
136 std::tuple<std::vector<int64_t>, std::vector<const char*>>
138 const std::shared_ptr<Ort::Session>& sessionHandle,
139 Ort::AllocatorWithDefaultOptions& allocator) {
141 std::vector<int64_t> input_node_dims;
142 size_t num_input_nodes = sessionHandle->GetInputCount();
143 std::vector<const char*> input_node_names(num_input_nodes);
146 for( std::size_t
i = 0;
i < num_input_nodes;
i++ ) {
148 char* input_name = sessionHandle->GetInputNameAllocated(
i, allocator).release();
150 input_node_names[
i] = input_name;
153 Ort::TypeInfo type_info = sessionHandle->GetInputTypeInfo(
i);
154 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
155 ONNXTensorElementDataType
type = tensor_info.GetElementType();
159 input_node_dims = tensor_info.GetShape();
160 ATH_MSG_DEBUG(
"Input "<<
i<<
" : num_dims= "<<input_node_dims.size());
161 for (std::size_t j = 0; j < input_node_dims.size(); j++){
162 if(input_node_dims[j]<0){input_node_dims[j] =1;}
163 ATH_MSG_DEBUG(
"Input"<<
i<<
" : dim "<<j<<
"= "<<input_node_dims[j]);
166 return std::make_tuple(input_node_dims, input_node_names);
170 std::tuple<std::vector<int64_t>, std::vector<const char*>>
172 const std::shared_ptr<Ort::Session>& sessionHandle,
173 Ort::AllocatorWithDefaultOptions& allocator) {
175 std::vector<int64_t> output_node_dims;
176 size_t num_output_nodes = sessionHandle->GetOutputCount();
177 std::vector<const char*> output_node_names(num_output_nodes);
180 for( std::size_t
i = 0;
i < num_output_nodes;
i++ ) {
182 char* output_name = sessionHandle->GetOutputNameAllocated(
i, allocator).release();
184 output_node_names[
i] = output_name;
186 Ort::TypeInfo type_info = sessionHandle->GetOutputTypeInfo(
i);
187 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
188 ONNXTensorElementDataType
type = tensor_info.GetElementType();
192 output_node_dims = tensor_info.GetShape();
193 ATH_MSG_DEBUG(
"Output "<<
i<<
" : num_dims= "<<output_node_dims.size());
194 for (std::size_t j = 0; j < output_node_dims.size(); j++){
195 if(output_node_dims[j]<0){output_node_dims[j] =1;}
196 ATH_MSG_DEBUG(
"Output"<<
i<<
" : dim "<<j<<
"= "<<output_node_dims[j]);
199 return std::make_tuple(output_node_dims, output_node_names);
203 std::tuple<std::shared_ptr<Ort::Session>, Ort::AllocatorWithDefaultOptions>
205 const std::string& modelFilePath) {
211 Ort::SessionOptions sessionOptions;
212 sessionOptions.SetIntraOpNumThreads( 1 );
213 sessionOptions.SetGraphOptimizationLevel( ORT_ENABLE_BASIC );
215 Ort::AllocatorWithDefaultOptions allocator;
217 std::shared_ptr<Ort::Session> sessionHandle = std::make_shared<Ort::Session>(
env, modelFileName.c_str(), sessionOptions );
219 ATH_MSG_INFO(
"Created the ONNX Runtime session for model file = " << modelFileName);
220 return std::make_tuple(sessionHandle, allocator);
226 ATH_MSG_INFO(
"Initializing PhotonVertexSelectionTool...");
233 std::vector<std::string> var_names = {
234 "deltaZ := TMath::Min(abs(PrimaryVerticesAuxDyn.z-zCommon)/zCommonError,20)",
235 "deltaPhi := abs(deltaPhi(PrimaryVerticesAuxDyn.phi,egamma_phi))" ,
236 "logSumpt := log10(PrimaryVerticesAuxDyn.sumPt)" ,
237 "logSumpt2 := log10(PrimaryVerticesAuxDyn.sumPt2)"
239 auto *mva1 =
new TMVA::Reader(var_names,
"!Silent:Color");
241 m_mva1 = std::unique_ptr<TMVA::Reader>( mva1 );
243 auto mva2 = std::make_unique<TMVA::Reader>(var_names,
"!Silent:Color");
245 m_mva2 = std::unique_ptr<TMVA::Reader>( std::move(mva2) );
273 #ifndef XAOD_STANDALONE
278 return StatusCode::SUCCESS;
283 auto fail = FailType::NoFail;
285 const EventContext& ctx = Gaudi::Hive::currentContext();
323 deltaPhi(*
vertex) = (
fail != FailType::FailEgamVect) ? std::abs(vmom.DeltaPhi(vegamma)) : -999.;
329 if(failType!=
nullptr)
331 return StatusCode::SUCCESS;
335 std::vector<std::pair<const xAOD::Vertex*, float>>
342 std::vector<std::pair<const xAOD::Vertex*, float> > vertexMLP;
344 FailType failType = FailType::NoFail;
345 if (
getVertexImp( egammas,
vertex, ignoreConv, noDecorate, vertexMLP, vtxCase, failType ).isSuccess()) {
346 std::sort(vertexMLP.begin(), vertexMLP.end(),
sortMLP);
348 if(vtxCasePtr!=
nullptr)
349 *vtxCasePtr = vtxCase;
350 if(failTypePtr!=
nullptr)
351 *failTypePtr = failType;
359 bool ignoreConv)
const
361 std::vector<std::pair<const xAOD::Vertex*, float> > vertexMLP;
363 FailType failType = FailType::NoFail;
364 return getVertexImp( egammas, prime_vertex, ignoreConv,
false, vertexMLP, vtxcase, failType );
372 std::vector<std::pair<const xAOD::Vertex*, float>>& vertexMLP,
383 return StatusCode::FAILURE;
387 if (!ignoreConv && photons) {
389 if (prime_vertex !=
nullptr) {
390 vtxCase = yyVtxType::ConvTrack;
391 fail = FailType::MatchedTrack;
392 vertexMLP.emplace_back(prime_vertex, 0.);
393 return StatusCode::SUCCESS;
397 if (
fail != FailType::NoFail){
401 vertexMLP.emplace_back(prime_vertex, 10.);
402 return StatusCode::SUCCESS;
411 bool isConverted =
false;
414 vtxCase = yyVtxType::NoSiTracks;
415 if (!ignoreConv && photons) {
416 for (
const auto *
photon: *photons) {
420 return StatusCode::FAILURE;
426 vtxCase = yyVtxType::SiConvTrack;
436 tmva_reader =
m_mva1.get();
440 tmva_reader =
m_mva2.get();
454 std::vector<float> ONNXInputVector;
455 std::vector<std::vector<float>> onnx_input_tensor_values;
456 std::vector<float> TMVAInputVector;
458 float mlp = 0.0, mlp_max = -99999.0;
459 float doSkipByZSigmaScore = -9999.0;
461 float thresGoodVtxScore;
463 else{thresGoodVtxScore = mlp_max;}
469 onnx_input_tensor_values.clear();
473 float log10_sumPt, log10_sumPt2;
475 sumPt = (sumPtA)(*
vertex);
476 sumPt2 = (sumPt2A)(*
vertex);
480 " sumPt2: " << sumPt2 <<
487 TMVAMethod =
"MLP method";
488 log10_sumPt =
static_cast<float>(log10(sumPt));
489 log10_sumPt2 =
static_cast<float>(log10(sumPt2));
497 for (
long unsigned int i = 0;
i < ONNXInputVector.size();
i++) {
502 if (ONNXInputVector[
i] != 0 && std::isinf(ONNXInputVector[
i]) !=
true && std::isnan(ONNXInputVector[
i]) !=
true){
503 ONNXInputVector[
i] =
log(std::abs(ONNXInputVector[
i]));
506 ONNXInputVector[
i] =
log(std::abs(0.00000001));
509 onnx_input_tensor_values.push_back(ONNXInputVector);
514 mlp = tmva_reader->EvaluateMVA(TMVAInputVector, TMVAMethod);
529 " log(abs(sumPt2)): " << sumPt2 <<
531 " log(abs(deltaZ)): " <<
deltaZ);
532 ATH_MSG_VERBOSE(
"ONNX output, isConverted = " << isConverted <<
", mlp=" << mlp);
538 if ((isConverted &&
deltaZ > 15) || (!isConverted &&
deltaZ > 10)) {
539 mlp = doSkipByZSigmaScore;
544 vertexMLP.emplace_back(
vertex, mlp);
556 if (mlp_max <= thresGoodVtxScore) {
557 ATH_MSG_DEBUG(
"No good vertex candidates from pointing, returning hardest vertex.");
559 fail = FailType::NoGdCandidate;
565 return StatusCode::SUCCESS;
570 const std::pair<const xAOD::Vertex*, float> &
b)
571 {
return a.second >
b.second; }
576 if (photons ==
nullptr) {
577 ATH_MSG_WARNING(
"Passed nullptr photon container, returning nullptr vertex from getPrimaryVertexFromConv");
581 std::vector<const xAOD::Vertex*> vertices;
584 size_t NumberOfTracks = 0;
590 for (
const auto *
photon: *photons) {
591 conversionVertex =
photon->vertex();
592 if (conversionVertex ==
nullptr)
continue;
595 for (
size_t i = 0;
i < NumberOfTracks; ++
i) {
598 if (gsfTp ==
nullptr)
continue;
603 if (
tp ==
nullptr)
continue;
606 if (
primary ==
nullptr)
continue;
610 if (
std::find(vertices.begin(), vertices.end(),
primary) == vertices.end()) {
618 if (!vertices.empty()) {
619 if (vertices.size() > 1)
620 ATH_MSG_WARNING(
"Photons associated to different vertices! Returning lead photon association.");
630 TLorentzVector
v, v1;
635 failType = FailType::FailEgamVect;
638 cluster =
egamma->caloCluster();
639 if (cluster ==
nullptr) {
640 ATH_MSG_WARNING(
"No cluster associated to egamma, not adding to 4-vector.");