23 #include "TMVA/Reader.h"
35 float largestWeight = 0;
37 for (
const auto *vtx : *vertices) {
40 const size_t nTrackLinks=trkLinks.size();
41 for (
unsigned i=0;
i<nTrackLinks;++
i) {
43 if( vtx->trackWeights()[
i] > largestWeight ){
44 vtxWithLargestWeight = vtx;
51 return vtxWithLargestWeight;
90 float PhotonVertexSelectionTool::getScore(
int nVars,
const std::vector<std::vector<float>>& input_data,
const std::shared_ptr<Ort::Session> sessionHandle, std::vector<int64_t> input_node_dims, std::vector<const char*> input_node_names, std::vector<const char*> output_node_names)
const{
94 std::vector<std::vector<float>> input_tensor_values_ = input_data;
97 size_t input_tensor_size = nVars;
98 std::vector<float> input_tensor_values(nVars);
99 input_tensor_values = input_tensor_values_[0];
102 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
104 Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_size, input_node_dims.data(), input_node_dims.size());
107 assert(input_tensor.IsTensor());
110 auto output_tensors = sessionHandle->Run(Ort::RunOptions{
nullptr}, input_node_names.data(), &input_tensor, input_node_names.size(), output_node_names.data(), output_node_names.size());
113 assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());
117 float* floatarr = output_tensors[0].GetTensorMutableData<
float>();
119 int arrSize =
sizeof(*floatarr)/
sizeof(floatarr[0]);
128 std::vector<int64_t> input_node_dims;
129 size_t num_input_nodes = sessionHandle->GetInputCount();
130 std::vector<const char*> input_node_names(num_input_nodes);
133 for( std::size_t
i = 0;
i < num_input_nodes;
i++ ) {
135 char* input_name = sessionHandle->GetInputNameAllocated(
i, allocator).release();
137 input_node_names[
i] = input_name;
140 Ort::TypeInfo type_info = sessionHandle->GetInputTypeInfo(
i);
141 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
142 ONNXTensorElementDataType
type = tensor_info.GetElementType();
146 input_node_dims = tensor_info.GetShape();
147 ATH_MSG_DEBUG(
"Input "<<
i<<
" : num_dims= "<<input_node_dims.size());
148 for (std::size_t j = 0; j < input_node_dims.size(); j++){
149 if(input_node_dims[j]<0){input_node_dims[j] =1;}
150 ATH_MSG_DEBUG(
"Input"<<
i<<
" : dim "<<j<<
"= "<<input_node_dims[j]);
153 return std::make_tuple(input_node_dims, input_node_names);
159 std::vector<int64_t> output_node_dims;
160 size_t num_output_nodes = sessionHandle->GetOutputCount();
161 std::vector<const char*> output_node_names(num_output_nodes);
164 for( std::size_t
i = 0;
i < num_output_nodes;
i++ ) {
166 char* output_name = sessionHandle->GetOutputNameAllocated(
i, allocator).release();
168 output_node_names[
i] = output_name;
170 Ort::TypeInfo type_info = sessionHandle->GetOutputTypeInfo(
i);
171 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
172 ONNXTensorElementDataType
type = tensor_info.GetElementType();
176 output_node_dims = tensor_info.GetShape();
177 ATH_MSG_DEBUG(
"Output "<<
i<<
" : num_dims= "<<output_node_dims.size());
178 for (std::size_t j = 0; j < output_node_dims.size(); j++){
179 if(output_node_dims[j]<0){output_node_dims[j] =1;}
180 ATH_MSG_DEBUG(
"Output"<<
i<<
" : dim "<<j<<
"= "<<output_node_dims[j]);
183 return std::make_tuple(output_node_dims, output_node_names);
193 Ort::SessionOptions sessionOptions;
194 sessionOptions.SetIntraOpNumThreads( 1 );
195 sessionOptions.SetGraphOptimizationLevel( ORT_ENABLE_BASIC );
197 Ort::AllocatorWithDefaultOptions allocator;
199 std::shared_ptr<Ort::Session> sessionHandle = std::make_shared<Ort::Session>(
env, modelFileName.c_str(), sessionOptions );
201 ATH_MSG_INFO(
"Created the ONNX Runtime session for model file = " << modelFileName);
202 return std::make_tuple(sessionHandle, allocator);
208 ATH_MSG_INFO(
"Initializing PhotonVertexSelectionTool...");
215 std::vector<std::string> var_names = {
216 "deltaZ := TMath::Min(abs(PrimaryVerticesAuxDyn.z-zCommon)/zCommonError,20)",
217 "deltaPhi := abs(deltaPhi(PrimaryVerticesAuxDyn.phi,egamma_phi))" ,
218 "logSumpt := log10(PrimaryVerticesAuxDyn.sumPt)" ,
219 "logSumpt2 := log10(PrimaryVerticesAuxDyn.sumPt2)"
221 auto mva1 =
new TMVA::Reader(var_names,
"!Silent:Color");
223 m_mva1 = std::unique_ptr<TMVA::Reader>( std::move(mva1) );
225 auto mva2 = std::make_unique<TMVA::Reader>(var_names,
"!Silent:Color");
227 m_mva2 = std::unique_ptr<TMVA::Reader>( std::move(mva2) );
255 #ifndef XAOD_STANDALONE
260 return StatusCode::SUCCESS;
265 auto fail = FailType::NoFail;
267 const EventContext& ctx = Gaudi::Hive::currentContext();
305 deltaPhi(*vertex) = (
fail != FailType::FailEgamVect) ? std::abs(vmom.DeltaPhi(vegamma)) : -999.;
306 deltaZ(*vertex) = std::abs((zCommon.first - vertex->z())/zCommon.second);
311 if(failType!=
nullptr)
313 return StatusCode::SUCCESS;
317 std::vector<std::pair<const xAOD::Vertex*, float> >
321 std::vector<std::pair<const xAOD::Vertex*, float> > vertexMLP;
323 FailType failType = FailType::NoFail;
324 if (
getVertexImp( egammas, vertex, ignoreConv, noDecorate, vertexMLP, vtxCase, failType ).isSuccess()) {
325 std::sort(vertexMLP.begin(), vertexMLP.end(),
sortMLP);
327 if(vtxCasePtr!=
nullptr)
328 *vtxCasePtr = vtxCase;
329 if(failTypePtr!=
nullptr)
330 *failTypePtr = failType;
338 bool ignoreConv)
const
340 std::vector<std::pair<const xAOD::Vertex*, float> > vertexMLP;
342 FailType failType = FailType::NoFail;
343 return getVertexImp( egammas, prime_vertex, ignoreConv,
false, vertexMLP, vtxcase, failType );
350 std::vector<std::pair<const xAOD::Vertex*, float> >& vertexMLP,
yyVtxType& vtxCase,
FailType&
fail)
const
360 return StatusCode::FAILURE;
364 if (!ignoreConv && photons) {
366 if (prime_vertex !=
nullptr) {
367 vtxCase = yyVtxType::ConvTrack;
368 fail = FailType::MatchedTrack;
369 vertexMLP.emplace_back(prime_vertex, 0.);
370 return StatusCode::SUCCESS;
374 if (
fail != FailType::NoFail){
378 vertexMLP.emplace_back(prime_vertex, 10.);
379 return StatusCode::SUCCESS;
388 bool isConverted =
false;
391 vtxCase = yyVtxType::NoSiTracks;
392 if (!ignoreConv && photons) {
393 for (
const auto *
photon: *photons) {
397 return StatusCode::FAILURE;
403 vtxCase = yyVtxType::SiConvTrack;
413 tmva_reader =
m_mva1.get();
417 tmva_reader =
m_mva2.get();
431 std::vector<float> ONNXInputVector;
432 std::vector<std::vector<float>> onnx_input_tensor_values;
433 std::vector<float> TMVAInputVector;
435 float mlp = 0.0, mlp_max = -99999.0;
436 float doSkipByZSigmaScore = -9999.0;
438 float thresGoodVtxScore;
440 else{thresGoodVtxScore = mlp_max;}
446 onnx_input_tensor_values.clear();
450 float log10_sumPt, log10_sumPt2;
452 sumPt = (sumPtA)(*vertex);
453 sumPt2 = (sumPt2A)(*vertex);
455 deltaZ = (deltaZA)(*vertex);
457 " sumPt2: " << sumPt2 <<
464 TMVAMethod =
"MLP method";
465 log10_sumPt =
static_cast<float>(log10(sumPt));
466 log10_sumPt2 =
static_cast<float>(log10(sumPt2));
474 for (
long unsigned int i = 0;
i < ONNXInputVector.size();
i++) {
479 if (ONNXInputVector[
i] != 0 && std::isinf(ONNXInputVector[
i]) !=
true && std::isnan(ONNXInputVector[
i]) !=
true){
480 ONNXInputVector[
i] =
log(std::abs(ONNXInputVector[
i]));
483 ONNXInputVector[
i] =
log(std::abs(0.00000001));
486 onnx_input_tensor_values.push_back(ONNXInputVector);
491 mlp = tmva_reader->EvaluateMVA(TMVAInputVector, TMVAMethod);
506 " log(abs(sumPt2)): " << sumPt2 <<
508 " log(abs(deltaZ)): " <<
deltaZ);
509 ATH_MSG_VERBOSE(
"ONNX output, isConverted = " << isConverted <<
", mlp=" << mlp);
515 if ((isConverted &&
deltaZ > 15) || (!isConverted &&
deltaZ > 10)) {
516 mlp = doSkipByZSigmaScore;
521 vertexMLP.emplace_back(vertex, mlp);
526 prime_vertex = vertex;
533 if (mlp_max <= thresGoodVtxScore) {
534 ATH_MSG_DEBUG(
"No good vertex candidates from pointing, returning hardest vertex.");
536 fail = FailType::NoGdCandidate;
542 return StatusCode::SUCCESS;
547 const std::pair<const xAOD::Vertex*, float> &
b)
548 {
return a.second >
b.second; }
553 if (photons ==
nullptr) {
554 ATH_MSG_WARNING(
"Passed nullptr photon container, returning nullptr vertex from getPrimaryVertexFromConv");
558 std::vector<const xAOD::Vertex*> vertices;
561 size_t NumberOfTracks = 0;
567 for (
const auto *
photon: *photons) {
568 conversionVertex =
photon->vertex();
569 if (conversionVertex ==
nullptr)
continue;
572 for (
size_t i = 0;
i < NumberOfTracks; ++
i) {
575 if (gsfTp ==
nullptr)
continue;
580 if (
tp ==
nullptr)
continue;
583 if (
primary ==
nullptr)
continue;
587 if (
std::find(vertices.begin(), vertices.end(),
primary) == vertices.end()) {
595 if (!vertices.empty()) {
596 if (vertices.size() > 1)
597 ATH_MSG_WARNING(
"Photons associated to different vertices! Returning lead photon association.");
607 TLorentzVector
v, v1;
612 failType = FailType::FailEgamVect;
615 cluster =
egamma->caloCluster();
616 if (cluster ==
nullptr) {
617 ATH_MSG_WARNING(
"No cluster associated to egamma, not adding to 4-vector.");
621 v1.SetPtEtaPhiM(
egamma->
e()/cosh(cluster->etaBE(2)),