13 #include "GaudiKernel/SystemOfUnits.h"
14 #include "Gaudi/Property.h"
15 #include "EventInfo/EventInfo.h"
23 #include <onnxruntime_cxx_api.h>
36 ATH_CHECK(m_filterParams.initialize(
false));
37 ATH_CHECK( m_eventInfoContainerName.initialize() );
38 ATH_CHECK( m_truthParticleName.initialize( (m_pileupSwitch ==
"HardScatter" or m_pileupSwitch ==
"All") and not m_truthParticleName.key().empty() ) );
39 ATH_CHECK(m_truthSelectionTool.retrieve(EnableTool {not m_truthParticleName.key().empty()} ));
40 ATH_CHECK( m_truthEventName.initialize( (m_pileupSwitch ==
"HardScatter" or m_pileupSwitch ==
"All") and not m_truthEventName.key().empty() ) );
41 ATH_CHECK( m_truthPileUpEventName.initialize( (m_pileupSwitch ==
"PileUp" or m_pileupSwitch ==
"All") and not m_truthPileUpEventName.key().empty() ) );
43 std::string this_file = __FILE__;
45 Ort::SessionOptions session_options;
47 m_session = std::make_unique<Ort::Session>(m_svc->env(), model_path.c_str(), session_options);
48 m_inputInfo = TrackOverlayDecisionAlg::GetInputNodeInfo(m_session);
49 m_outputInfo = TrackOverlayDecisionAlg::GetOutputNodeInfo(m_session);
50 return StatusCode::SUCCESS;
56 ATH_MSG_VERBOSE(
"-----------------------------------------------------------------");
57 ATH_MSG_VERBOSE(
"m_filterParams.summary()" << m_filterParams.summary());
58 ATH_MSG_VERBOSE(
"-----------------------------------------------------------------");
60 ATH_MSG_VERBOSE(
" =====================================================================");
62 return StatusCode::SUCCESS;
66 std::vector<const xAOD::TruthParticle*> tempVec {};
67 if (m_pileupSwitch ==
"All") {
68 if (m_truthParticleName.key().empty()) {
72 if (not truthParticleContainer.
isValid()) {
75 tempVec.insert(tempVec.begin(), truthParticleContainer->
begin(), truthParticleContainer->
end());
77 if (m_pileupSwitch ==
"HardScatter") {
78 if (not m_truthEventName.key().empty()) {
85 const auto&
links =
event->truthParticleLinks();
86 tempVec.reserve(
event->nTruthParticles());
87 for (
const auto& link :
links) {
89 tempVec.push_back(*link);
93 }
else if (m_pileupSwitch ==
"PileUp") {
94 if (not m_truthPileUpEventName.key().empty()) {
98 if (truthPileupEventContainer.
isValid()) {
99 const unsigned int nPileup = truthPileupEventContainer->
size();
100 tempVec.reserve(nPileup * 200);
101 for (
unsigned int i(0);
i != nPileup; ++
i) {
102 const auto *eventPileup = truthPileupEventContainer->
at(
i);
104 int ntruth = eventPileup->nTruthParticles();
105 ATH_MSG_VERBOSE(
"Adding " << ntruth <<
" truth particles from TruthPileupEvents container");
106 const auto&
links = eventPileup->truthParticleLinks();
107 for (
const auto& link :
links) {
109 tempVec.push_back(*link);
131 float eventPxSum = 0.0;
132 float eventPySum = 0.0;
134 float puEvents = 0.0;
136 std::vector<float> pxValues, pyValues, pzValues, eValues, etaValues, phiValues, ptValues;
137 float truthMultiplicity = 0.0;
138 const int truthParticles = truthParticlesVec.size();
139 for (
int itruth = 0; itruth < truthParticles; itruth++) {
143 pxValues.push_back((thisTruth->
px()*0.001-1.46988000e+03)*
px_diff);
144 pyValues.push_back((thisTruth->
py()*0.001-1.35142000e+03)*
py_diff);
145 pzValues.push_back((thisTruth->
pz()*0.001-1.50464000e+03)*
pz_diff);
146 ptValues.push_back((thisTruth->
pt()*0.001-5.00006000e-01)*
pt_diff);
148 etaValues.push_back(thisTruth->
eta());
149 phiValues.push_back(thisTruth->
phi());
150 eValues.push_back((thisTruth->
e()*0.001-5.08307000e-01)*
e_diff);
152 eventPxSum += thisTruth->
px();
153 eventPySum += thisTruth->
py();
159 if (!m_truthPileUpEventName.key().empty()) {
163 eventPt = std::sqrt(eventPxSum*eventPxSum + eventPySum*eventPySum)*0.001;
165 std::vector<float> puEventsVec(pxValues.size(), (puEvents-1.55000000e+01)*
pu_diff);
166 std::vector<float> truthMultiplicityVec(pxValues.size(), (truthMultiplicity-1.80000000e+01)*
multi_diff);
167 std::vector<float> eventPtVec(pxValues.size(), (eventPt-3.42359395e-01)*
eventPt_diff);
168 std::vector<float> predictions;
171 Eigen::VectorXf ptEigen = Eigen::VectorXf::Map(ptValues.data(), ptValues.size());
172 Eigen::VectorXf phiEigen = Eigen::VectorXf::Map(phiValues.data(), phiValues.size());
173 Eigen::VectorXf etaEigen = Eigen::VectorXf::Map(etaValues.data(), etaValues.size());
174 for (std::size_t
i = 0;
i < truthMultiplicity; ++
i) {
175 float multiplicity_0p05 = 0.0, multiplicity_0p2 = 0.0;
176 float sum_0p05 = 0.0, sum_0p2 = 0.0;
177 float pt_0p05 = 0.0, pt_0p2 = 0.0;
178 float deltaEtaI = etaEigen[
i];
179 float phiI = phiEigen[
i];
180 for (std::size_t j = 0; j < truthMultiplicity; ++j) {
181 if (
i == j)
continue;
182 float deltaEta = deltaEtaI - etaEigen[j];
183 float deltaPhi = phiI - phiEigen[j];
188 if (distances < 0.05){
190 sum_0p05 += distances;
191 pt_0p05 += ptEigen[j];
193 if (distances < 0.2){
195 sum_0p2 += distances;
196 pt_0p2 += ptEigen[j];
200 std::vector<float> featData;
201 featData.push_back(pxValues[
i]);
202 featData.push_back(pyValues[
i]);
203 featData.push_back(pzValues[
i]);
204 featData.push_back(eValues[
i]);
205 featData.push_back(ptValues[
i]);
213 featData.push_back(puEventsVec[
i]);
214 featData.push_back(truthMultiplicityVec[
i]);
215 featData.push_back(eventPtVec[
i]);
217 std::vector<int64_t> input_node_dims;
218 std::vector<char*> input_node_names;
219 input_node_dims = std::get<0>(m_inputInfo);
220 input_node_names = std::get<1>(m_inputInfo);
222 std::vector<int64_t> output_node_dims;
223 std::vector<char*> output_node_names;
224 output_node_dims = std::get<0>(m_outputInfo);
225 output_node_names = std::get<1>(m_outputInfo);
227 Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);
228 input_node_dims[0]=1;
229 Ort::Value input_data = Ort::Value::CreateTensor(memoryInfo, featData.data(), featData.size(), input_node_dims.data(), input_node_dims.size());
230 Ort::RunOptions run_options(
nullptr);
233 auto output_values = mysession.Run(run_options, input_node_names.data(), &input_data, input_node_names.size(), output_node_names.data(), output_node_names.size());
234 float* predictionData = output_values[0].GetTensorMutableData<
float>();
235 float prediction = predictionData[0];
237 predictions.push_back(prediction);
242 for (
float prediction : predictions) {
247 float rouletteScore =
static_cast<float>(badTracks) /
static_cast<float>(truthMultiplicity);
251 int decision = rouletteScore == 0;
259 if (m_invertfilter) {
263 ATH_MSG_ALWAYS(
"End TrackOverlayDecisionAlg, difference in filters: "<<(pass ?
"found" :
"not found")<<
"="<<pass<<
", invert="<<m_invertfilter);
264 return StatusCode::SUCCESS;