21 #ifndef THUNDEREGG_SCHUR_PATCHIFACESCATTER_H
22 #define THUNDEREGG_SCHUR_PATCHIFACESCATTER_H
// Buffers and MPI request handles for one scatter operation. The constructor
// below sizes them from send_buffers_size / recv_buffers_size.
48 std::vector<std::vector<double>> send_buffers;
49 std::vector<std::vector<double>> recv_buffers;
// MPI request handles, index-aligned with send_buffers / recv_buffers.
50 std::vector<MPI_Request> send_requests;
51 std::vector<MPI_Request> recv_requests;
// Initialized to true; scatterFinish sets it to false once all sends complete.
52 bool communicating =
true;
// Global vector passed to scatterStart; scatterFinish checks it is given the same one.
56 const Vector<D - 1>* curr_global_vector =
nullptr;
// Local patch iface vector passed to scatterStart; also checked by scatterFinish.
60 const Vector<D - 1>* curr_local_vector =
nullptr;
// Creates send_buffers_size send buffers/requests and recv_buffers_size
// receive buffers/requests. Each buffer starts empty; presumably it is
// resized to the message length elsewhere — confirm.
62 StatePrivate(
size_t send_buffers_size,
size_t recv_buffers_size)
63 : send_buffers(send_buffers_size)
64 , recv_buffers(recv_buffers_size)
65 , send_requests(send_buffers_size)
66 , recv_requests(recv_buffers_size)
// Blocks until every outstanding receive and send has completed. This is
// presumably the destructor body; its signature is not shown here.
// NOTE(review): value-initialized MPI_Request entries are not guaranteed to
// equal MPI_REQUEST_NULL. This must only run after the requests were actually
// posted — confirm it is guarded (e.g. by `communicating`).
71 MPI_Waitall(recv_requests.size(), recv_requests.data(), MPI_STATUSES_IGNORE);
72 MPI_Waitall(send_requests.size(), send_requests.data(), MPI_STATUSES_IGNORE);
// Handle to the state of one scatter. Copies of a State share the same
// StatePrivate through this shared_ptr.
80 std::shared_ptr<StatePrivate> ptr;
// Allocates buffers/requests for num_send sends and num_recv receives.
// std::make_shared (used elsewhere in this file) performs a single allocation
// for the control block and the object, instead of passing a raw `new`.
81 State(
size_t num_send,
size_t num_recv)
82 : ptr(
std::make_shared<StatePrivate>(num_send, num_recv))
// This process's rank in MPI_COMM_WORLD.
133 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
// Off-rank neighbor ifaces grouped by owning rank. Each entry is
// (iface id, patch-local index); std::set keeps the pairs sorted and unique.
135 std::map<int, std::set<std::pair<int, int>>> incoming_ranks_to_id_local_index_pairs;
// Record every neighboring iface that is owned by a different rank.
139 if (piinfo->pinfo.hasNbr(s)) {
140 auto iface_info = piinfo->getIfaceInfo(s);
141 if (iface_info->rank != rank) {
142 incoming_ranks_to_id_local_index_pairs[iface_info->rank].emplace(
143 iface_info->id, iface_info->patch_local_index);
// Total number of (id, local index) entries across all ranks. Iterate by
// const reference so each map entry (including its std::set) is not copied.
148 int local_vector_size = 0;
149 for (
const auto& rank_to_id_local_index_pairs : incoming_ranks_to_id_local_index_pairs) {
150 local_vector_size += rank_to_id_local_index_pairs.second.size();
// One receive per distinct source rank.
155 num_recvs = incoming_ranks_to_id_local_index_pairs.size();
160 for (
const auto& rank_to_id_local_index_pairs : incoming_ranks_to_id_local_index_pairs) {
161 recv_ranks[recv_index] = rank_to_id_local_index_pairs.first;
164 local_index_vector.reserve(rank_to_id_local_index_pairs.second.size());
// Local indexes are appended in the set's sorted (id, local index) order.
166 for (
const auto& id_local_index_pair : rank_to_id_local_index_pairs.second) {
167 local_index_vector.push_back(id_local_index_pair.second);
// This process's rank in MPI_COMM_WORLD.
180 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
// Destination ranks, each with the set of (iface id, iface local index)
// pairs recorded for it; std::set keeps the pairs sorted and unique.
182 std::map<int, std::set<std::pair<int, int>>> outgoing_ranks_to_id_local_index_pairs;
// The loop only reads each patch entry, so bind by const reference instead of
// copying it (and its piinfo handle) on every iteration.
184 for (
const auto& patch : iface->patches) {
// Record the iface for each off-rank patch whose type is normal,
// fine-to-fine or coarse-to-coarse.
185 if ((patch.type.isNormal() || patch.type.isFineToFine() || patch.type.isCoarseToCoarse()) &&
186 patch.piinfo->pinfo.rank != rank) {
187 outgoing_ranks_to_id_local_index_pairs[patch.piinfo->pinfo.rank].emplace(
188 iface->id, iface->local_index);
// One send per distinct destination rank.
193 num_sends = outgoing_ranks_to_id_local_index_pairs.size();
198 for (
const auto& rank_to_id_local_index_pairs : outgoing_ranks_to_id_local_index_pairs) {
199 send_ranks[send_index] = rank_to_id_local_index_pairs.first;
202 local_index_vector.reserve(rank_to_id_local_index_pairs.second.size());
// Local indexes are appended in the set's sorted (id, local index) order.
204 for (
const auto& id_local_index_pair : rank_to_id_local_index_pairs.second) {
205 local_index_vector.push_back(id_local_index_pair.second);
// Loop over every send slot (0 .. num_sends-1); the body is not shown here.
217 for (
int send_index = 0; send_index <
num_sends; send_index++) {
// Loop over every receive slot (0 .. num_recvs-1); the body is not shown here.
222 for (
int recv_index = 0; recv_index <
num_recvs; recv_index++) {
// Forming the Schur complement vector requires square patches: every
// entry of getNs() must equal the one for dimension 0.
238 std::array<int, D> ns = iface_domain.
getDomain().getNs();
239 for (
int i = 1; i < D; i++) {
240 if (ns[0] != ns[i]) {
242 "Cannot form Schur complement vector for Domain with non-square patches");
// Returns a newly allocated Vector<D - 1>; its constructor arguments are not shown here.
258 return std::make_shared<
Vector<D - 1>>(
// For every coordinate from local_data.getStart() to local_data.getEnd(),
// copy the global value into the local view.
280 local_data.getStart(), local_data.getEnd(), [&](
const std::array<int, D - 1>& coord) {
281 local_data[coord] = global_data[coord];
// Post a non-blocking receive into each receive buffer; the requests are
// completed later in scatterFinish.
287 for (
int recv_index = 0; recv_index <
num_recvs; recv_index++) {
288 MPI_Irecv(state.ptr->recv_buffers[recv_index].data(),
289 state.ptr->recv_buffers[recv_index].size(),
294 &state.ptr->recv_requests[recv_index]);
// Pack local values into each send buffer sequentially (buffer_index), then
// post a non-blocking send for that buffer.
297 for (
int send_index = 0; send_index <
num_sends; send_index++) {
298 std::vector<double>& buffer = state.ptr->send_buffers[send_index];
300 int buffer_index = 0;
304 buffer[buffer_index] = local_data[coord];
309 MPI_Isend(buffer.data(),
315 &state.ptr->send_requests[send_index]);
// While messages are in flight, copy values that are available locally
// from the global vector into the local patch iface vector.
318 for (
int local_iface = 0; local_iface < global_vector.
getNumLocalPatches(); local_iface++) {
320 auto local_data = local_patch_iface_vector.
getComponentView(0, local_iface);
322 local_data[coord] = global_data[coord];
// Remember the vectors used, so scatterFinish can verify it is given the same ones.
326 state.ptr->curr_global_vector = &global_vector;
327 state.ptr->curr_local_vector = &local_patch_iface_vector;
// The vectors must be the same objects that were passed to scatterStart.
344 if (&global_vector != state.ptr->curr_global_vector ||
345 &local_patch_iface_vector != state.ptr->curr_local_vector) {
347 "Different vectors were passed to scatterFinish than were passed to scatterStart");
// Wait for whichever receive completes next, then unpack its buffer
// sequentially (buffer_index) into the matching local patch iface components.
353 MPI_Waitany(
num_recvs, state.ptr->recv_requests.data(), &recv_index, &status);
355 std::vector<double>& buffer = state.ptr->recv_buffers[recv_index];
357 int buffer_index = 0;
359 auto local_data = local_patch_iface_vector.
getComponentView(0, local_index);
361 local_data[coord] = buffer[buffer_index];
// Send buffers may only be released after every send has completed.
367 MPI_Waitall(
num_sends, state.ptr->send_requests.data(), MPI_STATUSES_IGNORE);
369 state.ptr->communicating =
false;
// Suppress implicit instantiation in including translation units; the 2D and
// 3D instantiations are presumably provided explicitly in the matching .cpp — confirm.
372 extern template class PatchIfaceScatter<2>;
373 extern template class PatchIfaceScatter<3>;