ATLAS Offline Software
Loading...
Searching...
No Matches
MPIClusterSvc.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4#include "MPIClusterSvc.h"
5
6#include "CxxUtils/XXH.h"
7#include "GaudiKernel/FileIncident.h"
8
9#include <boost/serialization/variant.hpp>
10
12 ATH_MSG_DEBUG("Initializing MPI");
13 m_env = std::make_unique<mpi3::environment>(mpi3::thread_level::single);
14 ATH_MSG_DEBUG("Created MPI environment");
15 m_world = m_env->world();
16 m_datacom =
17 m_world.duplicate(); // make a duplicate communicator for event data
18 ATH_MSG_DEBUG("Got MPI_COMM_WORLD");
19 m_rank = m_world.rank();
20 ATH_MSG_INFO("On MPI rank " << m_rank);
21
22 ATH_CHECK(m_mpiLog.retrieve());
23 m_mpiLog->createStatement("PRAGMA foreign_keys = ON").run();
24
26 ->createStatement(
27 "CREATE TABLE ranks (rank INTEGER PRIMARY KEY, "
28 "node TEXT, start_time FLOAT, end_time FLOAT)")
29 .run();
31 ->createStatement(
32 "INSERT INTO ranks (rank, node, start_time) "
33 "VALUES(?1, ?2, julianday('now'))")
34 .run(m_rank, m_env->processor_name());
35 m_mpiLog->createStatement(
36 "CREATE TABLE files (fileId INTEGER PRIMARY KEY, fileName TEXT)")
37 .run();
39 ->createStatement(
40 "CREATE TABLE event_log (rank INTEGER, id INTEGER UNIQUE,"
41 "inputFileId INTEGER,"
42 "runNumber INTEGER, eventNumber INTEGER, complete INTEGER,"
43 "status INTEGER, request_time_ns INTEGER, start_time FLOAT,"
44 "end_time FLOAT, PRIMARY KEY (runNumber, eventNumber), "
45 "FOREIGN KEY (rank) REFERENCES ranks(rank),"
46 "FOREIGN KEY (inputFileId) REFERENCES files(fileId))")
47 .run();
48 m_mpiLog_addEvent = m_mpiLog->createStatement(
49 "INSERT INTO event_log(id, rank, inputFileId, runNumber, eventNumber, complete, "
50 "start_time, request_time_ns) "
51 "VALUES(?1, ?4, ?6, ?2, ?3, 0, julianday('now'), ?5)");
52 m_mpiLog_completeEvent = m_mpiLog->createStatement(
53 "UPDATE event_log SET complete = 1, status = ?3, end_time = "
54 "julianday('now') WHERE runNumber = ?1 "
55 "AND "
56 "eventNumber = ?2");
57 m_mpiLog_addFile = m_mpiLog->createStatement(
58 "INSERT INTO files (fileId, fileName) VALUES(?1, ?2)");
59
60 // Set up incident listener
61 ServiceHandle<IIncidentSvc> incsvc("IncidentSvc", this->name());
62 if (!incsvc.retrieve().isSuccess()) {
63 ATH_MSG_FATAL("Cannot get IncidentSvc.");
64 return(StatusCode::FAILURE);
65 }
66 incsvc->addListener(this, IncidentType::BeginInputFile, 100);
67 incsvc->addListener(this, IncidentType::BeginProcessing, 100);
68
69 return StatusCode::SUCCESS;
70}
71
74 ->createStatement(
75 "UPDATE ranks SET end_time = julianday('now') WHERE rank = ?1")
76 .run(m_rank);
77 m_env.reset(nullptr);
78 return StatusCode::SUCCESS;
79}
80
82void MPIClusterSvc::handle(const Incident& inc) {
83 // Fill in slot map at start of every event
84 if (inc.type() == IncidentType::BeginProcessing) {
85 const std::size_t slot = Gaudi::Hive::currentContext().slot();
87 }
88
89 // Cache new input filename on start of every file
90 if (inc.type() == IncidentType::BeginInputFile) {
91 const FileIncident* fileInc = dynamic_cast<const FileIncident*>(&inc);
92 if (fileInc == nullptr) {
93 ATH_MSG_ERROR("BeginInputFile does not have a file name attached");
94 return;
95 }
96
97 const std::string fileName = fileInc->fileName();
98 // Convert the hash into a signed int64. Just a hash so this doesn't matter.
99 m_lastInputFileHash = static_cast<std::int64_t>(xxh3::hash64(fileName));
100 m_mpiLog_addFile.run(m_lastInputFileHash, std::move(fileName));
101 }
102 return;
103}
104
105
107 return m_world.size();
108}
109
111 return m_rank;
112}
113
115 ATH_MSG_DEBUG("Barrier on rank " << rank() << " of " << numRanks());
116 m_world.barrier();
117}
118
120 m_world.abort();
121}
122
124 ClusterComm communicator) {
125 ATH_MSG_DEBUG("Sending message from rank " << rank() << " to " << destRank);
126 // Don't send event request message if we're not the master *and* we have a
127 // message waiting.
128 // Probably an emergency stop message
129 if (m_rank != 0 && message.messageType == ClusterMessageType::RequestEvent &&
130 m_world.iprobe().has_value()) {
131 return;
132 }
133
134 // Select correct communicator
135 mpi3::communicator& comm =
136 (communicator == ClusterComm::EventData) ? m_datacom : m_world;
137 if (message.messageType == ClusterMessageType::Data &&
138 communicator != ClusterComm::EventData) {
140 "Event data should be sent with EventData communicator. "
141 "Dropping message");
142 return;
143 }
144
145 message.source = m_rank;
146 const auto& [header, body] = message.wire_msg();
147 comm.send_n(header.begin(), header.size(), destRank, 0);
148 if (body.has_value()) {
149 comm.send_n(body->begin(), body->size(), destRank, header[2]);
150 if (message.messageType == ClusterMessageType::Data) {
151 const ClusterMessage::WireMsgBody& bdy = *body;
152 // Decode the body to figure out what to send
153 char* ptr = reinterpret_cast<char*>((std::uint64_t(bdy[0]) << 32) +
154 std::uint64_t(bdy[1]));
155 std::size_t len = (std::uint64_t(bdy[2]) << 32) + std::uint64_t(bdy[3]);
156
157 // Offset the tag by 16384 to minimize chance of conflict
158 // (max tag in MPI spec is 32767)
159 constexpr int tag_offset = 16384;
160 comm.send_n(ptr, len, destRank, header[2] + tag_offset);
161 }
162 }
163}
164
166 // Same tag offset (16384) that sendMessage() adds when it sends the raw event-data payload
167 constexpr int tag_offset = 16384;
168 constexpr std::uint64_t thirtytwo_ones = 0xFFFFFFFF;
169
170 // Select correct communicator
171 mpi3::communicator& comm =
172 (communicator == ClusterComm::EventData) ? m_datacom : m_world;
174 auto&& [head, body] = msg;
175 comm.receive_n(head.begin(), head.size());
176 // Only time we need to figure out ourselves whether there's a body
177 if (head[0] == int(ClusterMessageType::FinalWorkerStatus) ||
178 head[0] == int(ClusterMessageType::WorkerError) ||
179 head[0] == int(ClusterMessageType::Data)) {
181 comm.receive_n(body->begin(), body->size(), head[1], head[2]);
182 if (head[0] == int(ClusterMessageType::Data)) {
183 ClusterMessage::WireMsgBody& bdy = *body;
184 // Decode the body to figure out what to receive
185 std::size_t len = (std::uint64_t(bdy[2]) << 32) + std::uint64_t(bdy[3]);
186 std::size_t align = (std::uint64_t(bdy[4]) << 32) + std::uint64_t(bdy[5]);
187
188 char* ptr = static_cast<char*>(std::aligned_alloc(align, len));
189 comm.receive_n(ptr, len, head[1], head[2] + tag_offset);
190
191 // update the pointer in the WireMsgBody
192 bdy[0] = int(std::uint64_t(ptr) >> 32);
193 bdy[1] = int(std::uint64_t(ptr) & thirtytwo_ones);
194 }
195 }
196 ClusterMessage message(msg);
197 ATH_MSG_DEBUG("Rank " << rank() << " received message from "
198 << message.source);
199 return message;
200}
201
202void MPIClusterSvc::log_addEvent(int eventIdx, std::int64_t run_number,
203 std::int64_t event_number,
204 std::int64_t request_time_ns,
205 std::size_t slot) {
206 m_mpiLog_addEvent.run(eventIdx, run_number, event_number, m_rank,
207 request_time_ns,
208 m_inputFileSlotMap[slot]);
209}
210
211void MPIClusterSvc::log_completeEvent(std::int64_t run_number,
212 std::int64_t event_number,
213 std::int64_t status) {
214 m_mpiLog_completeEvent.run(run_number, event_number, status);
215}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_FATAL(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
ClusterComm
C++ native wrapper for the C xxhash API.
std::int64_t m_lastInputFileHash
virtual ClusterMessage waitReceiveMessage(ClusterComm communicator=ClusterComm::Default) override final
Block until we receive an MPI message.
virtual void barrier() override final
Insert a barrier No rank will continue until all ranks reach this point.
virtual void log_addEvent(int eventIdx, std::int64_t run_number, std::int64_t event_number, std::int64_t request_time_ns, std::size_t slot) override final
Add (begin) an event in the log.
SQLite::Statement m_mpiLog_addFile
std::unique_ptr< mpi3::environment > m_env
SQLite::Statement m_mpiLog_addEvent
virtual void handle(const Incident &inc) override
IIncidentListener handle.
virtual void abort() override final
Abort the MPI run.
virtual int rank() const override final
Return our rank.
virtual void log_completeEvent(std::int64_t run_number, std::int64_t event_number, std::int64_t status) override final
Complete an event in the log.
mpi3::communicator m_world
virtual StatusCode initialize() override final
Initialize.
mpi3::communicator m_datacom
virtual StatusCode finalize() override final
Finalize.
SQLite::Statement m_mpiLog_completeEvent
virtual void sendMessage(int destRank, ClusterMessage message, ClusterComm communicator=ClusterComm::Default) override final
Send an MPI message.
ServiceHandle< ISQLiteDBSvc > m_mpiLog
virtual int numRanks() const override final
Return number of ranks.
std::map< std::size_t, std::int64_t > m_inputFileSlotMap
std::string head(std::string s, const std::string &pattern)
head of a string
std::uint64_t hash64(const void *data, std::size_t size)
Passthrough to XXH3_64bits.
Definition XXH.cxx:9
A class describing a message sent between nodes in a cluster.
std::array< int, 10 > WireMsgBody
std::tuple< WireMsgHdr, std::optional< WireMsgBody > > WireMsg
MsgStream & msg
Definition testRead.cxx:32