ATLAS Offline Software
Loading...
Searching...
No Matches
NodeFeatureFactory.cxx
Go to the documentation of this file.
/*
  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
*/
7
#include <GaudiKernel/SystemOfUnits.h>

#include <algorithm>
#include <memory>
#include <set>
#include <string>
10
11namespace MuonML {
12
16
17 bool operator<(const std::string& a, const Feature_t & b) {
18 return a < b->name();
19 }
20 bool operator<( const Feature_t & a, const std::string& b) {
21 return a->name() < b;
22 }
23 bool operator<(const Feature_t& a, const Feature_t & b) {
24 return a->name() < b->name();
25 }
26
27 bool operator<(const std::string& a, const Connector_t & b) {
28 return a < b->name();
29 }
30 bool operator<( const Connector_t & a, const std::string& b) {
31 return a->name() < b;
32 }
33 bool operator<(const Connector_t& a, const Connector_t & b) {
34 return a->name() < b->name();
35 }
36
37 namespace Factory {
38 Feature_t makeFeature(const std::string& featName, MsgStream& log) {
39 using CovIdx = MuonR4::SpacePoint::CovIdx;
40 static const std::set<Feature_t, std::less<>> featurePool{
41 std::make_unique<NodeFeature>("localX",
42 [](const Bucket_t& bucket, size_t index) {
43 return bucket[index]->localPosition().x();
44 }),
45 std::make_unique<NodeFeature>("localY",
46 [](const Bucket_t& bucket, size_t index) {
47 return bucket[index]->localPosition().y();
48 bucket[index]->localPosition().y();
49 }),
50 std::make_unique<NodeFeature>("localZ",
51 [](const Bucket_t& bucket, size_t index) {
52 return bucket[index]->localPosition().z();
53 }),
54 std::make_unique<NodeFeature>("stationIndex",
55 [](const Bucket_t& bucket, size_t index) {
56 return bucket[index]->msSector()->idHelperSvc()->stationName(bucket[index]->identify());
57 }),
58 std::make_unique<NodeFeature>("stationPhi",
59 [](const Bucket_t& bucket, size_t index) {
60 return bucket[index]->msSector()->idHelperSvc()->stationPhi(bucket[index]->identify());
61 }),
62 std::make_unique<NodeFeature>("stationEta",
63 [](const Bucket_t& bucket, size_t index) {
64 return bucket[index]->msSector()->idHelperSvc()->stationEta(bucket[index]->identify());
65 }),
66 std::make_unique<NodeFeature>("driftR",
67 [](const Bucket_t& bucket, size_t index) {
68 return bucket[index]->driftRadius();
69 }),
70 std::make_unique<NodeFeature>("relative_layer",
71 [](const Bucket_t& bucket, size_t index) {
72 const double relLayNum = (1 + 1.*bucket.layerNum(index)) / (bucket.nStripLayers() + bucket.nMdtLayers());
73 return relLayNum;
74 }),
75 std::make_unique<NodeFeature>("neighbors",
76 [](const Bucket_t& bucket, size_t index) {
77 constexpr double radCut2 = (50.*Gaudi::Units::cm * 50.*Gaudi::Units::cm);
78 unsigned int n =0;
79 for (size_t other =0 ; other < bucket.size(); ++ other){
80 n+= index != other && (bucket[index]->localPosition() - bucket[other]->localPosition()).perp2() < radCut2;
81 }
82 return n;
83 }),
84 std::make_unique<NodeFeature>("bucket_density",
85 [](const Bucket_t& bucket, size_t /*index*/) {
86
87 return 1.*bucket.size() / std::max(bucket.coveredMax() - bucket.coveredMin(), 1. * Gaudi::Units::cm);
88 }),
89 std::make_unique<NodeFeature>("isolation",
90 [](const Bucket_t& bucket, size_t index) {
91 unsigned int neighbors = 0;
92 constexpr double radCut2 = (50.*Gaudi::Units::cm * 50.*Gaudi::Units::cm);
93 for (size_t other =0 ; other < bucket.size(); ++ other){
94 neighbors+= index != other && (bucket[index]->localPosition() - bucket[other]->localPosition()).perp2() < radCut2;
95 }
96
97 float bucket_density = 1.f*bucket.size() / std::max(bucket.coveredMax() - bucket.coveredMin(), 1. * Gaudi::Units::cm);
98 return neighbors / bucket_density;
99 }),
100
101 std::make_unique<NodeFeature>("covX",
102 [](const Bucket_t& bucket, size_t index) {
103 return bucket[index]->covariance()[Acts::toUnderlying(CovIdx::phiCov)];
104 }),
105 std::make_unique<NodeFeature>("covY",
106 [](const Bucket_t& bucket, size_t index) {
107 return bucket[index]->covariance()[Acts::toUnderlying(CovIdx::etaCov)];
108 }),
109 };
110 const auto feat_itr = featurePool.find(featName);
111 if(feat_itr != featurePool.end()){
112 if (log.level() <= MSG::DEBUG) {
113 log<<MSG::DEBUG<<"Found graph feature "<<featName<<"."<<endmsg;
114 }
115 return *feat_itr;
116 }
117 std::stringstream available{};
118 for (const Feature_t& known : featurePool) {
119 available<<known->name()<<", ";
120 }
121 log<<MSG::ERROR<<"The feature "<<featName<<" is unknown to the feature factory. "
122 <<" Please check for typos w.r.t "<<available.str()<<". Otherwise augment "
123 <<__FILE__<<" with your desired feature "<<endmsg;
124 return nullptr;
125 }
126 NodeFeatureList::Connector_t makeConnector(const std::string& connName, MsgStream& log) {
127
128 static const std::set<Connector_t, std::less<>> connectorPool{
129 std::make_unique<NodeConnector>("fullyConnected",
130 [](const Bucket_t& , size_t , size_t ) {
131 return true;
132 }),
133 };
134 const auto feat_itr = connectorPool.find(connName);
135 if(feat_itr != connectorPool.end()){
136 if (log.level() <= MSG::DEBUG) {
137 log<<MSG::DEBUG<<"Found graph connector "<<connName<<"."<<endmsg;
138 }
139 return *feat_itr;
140 }
141 std::stringstream available{};
142 for (const Connector_t& known : connectorPool) {
143 available<<known->name()<<", ";
144 }
145 log<<MSG::ERROR<<"The graph connector "<<connName<<" is unknown to the factory. "
146 <<" Please check for typos w.r.t "<<available.str()<<". Otherwise augment "
147 <<__FILE__<<" with your desired connection function. "<<endmsg;
148 return nullptr;
149 }
150 }
151
152}
#define endmsg
static Double_t a
uint8_t nMdtLayers() const
Returns how many Mdt layers are inside the bucket.
Definition LayerBucket.h:19
uint8_t layerNum(const size_t i) const
Returns the associated layer number of the i-th space point inside the bucket.
Definition LayerBucket.h:27
double coveredMax() const
Returns the max covered position of the bucket.
Definition LayerBucket.h:31
double coveredMin() const
Returns the min covered position of the bucket.
Definition LayerBucket.h:35
uint8_t nStripLayers() const
Returns how many Strip layers are inside the bucket.
Definition LayerBucket.h:23
std::shared_ptr< const NodeFeature > Feature_t
std::shared_ptr< const NodeConnector > Connector_t
LayerSpBucket Bucket_t
Abbreviation of the space-point bucket type.
Definition NodeFeature.h:19
NodeFeatureList::Connector_t makeConnector(const std::string &connName, MsgStream &log)
Factory function that builds a connector relation between two nodes (space points) in the bucket.
NodeFeatureList::Feature_t makeFeature(const std::string &featName, MsgStream &log)
Factory function that builds a NodeFeature from a predefined list of features.
NodeFeature::Bucket_t Bucket_t
bool operator<(const std::string &a, const Feature_t &b)
NodeFeatureList::Feature_t Feature_t
NodeFeatureList::Connector_t Connector_t
Definition index.py:1