21 #ifndef THUNDEREGG_MPIGHOSTFILLER_H
22 #define THUNDEREGG_MPIGHOSTFILLER_H
54 class RemoteCallPrototype
86 RemoteCallPrototype(
int id,
95 , local_index(local_index)
103 bool operator<(
const RemoteCallPrototype& other)
const
105 return std::forward_as_tuple(
id, face.
opposite(), nbr_type, orthant, local_index) <
106 std::forward_as_tuple(
107 other.id, other.face.opposite(), other.nbr_type, other.orthant, other.local_index);
117 class IncomingGhostPrototype
139 IncomingGhostPrototype(
int id,
Face<D, M> face,
int local_index)
142 , local_index(local_index)
150 bool operator<(
const IncomingGhostPrototype& other)
const
152 return std::forward_as_tuple(
id, face, local_index) <
153 std::forward_as_tuple(other.id, other.face, other.local_index);
192 RemoteCall(
const RemoteCallPrototype<M>& prototype,
size_t offset)
193 : face(prototype.face)
194 , nbr_type(prototype.nbr_type)
195 , orthant(prototype.orthant)
196 , local_index(prototype.local_index)
227 IncomingGhost(
const IncomingGhostPrototype<M>& prototype,
size_t offset)
228 : face(prototype.face)
229 , local_index(prototype.local_index)
275 size_t nbr_local_index)
279 , local_index(local_index)
280 , nbr_local_index(nbr_local_index)
296 size_t send_buffer_length = 0;
303 using RemoteCallDeque = std::deque<RemoteCall<M>>;
311 size_t recv_buffer_length = 0;
318 using IncomingGhostDeque = std::deque<IncomingGhost<M>>;
329 explicit RemoteCallSet(
int rank)
337 std::vector<RemoteCallSet> remote_call_sets;
345 using LocalCallDeque = std::deque<LocalCall<M>>;
383 std::array<size_t, Face<D, M>::number_of> sizes;
389 GhostViewInfo() =
default;
395 explicit GhostViewInfo(
const Domain<D>& domain)
397 std::array<int, D + 1> ns;
398 for (
int i = 0; i < D; i++) {
399 ns[i] = domain.
getNs()[i];
404 std::array<int, D + 1> ghost_ns = ns;
406 std::array<int, D + 1> face_ghost_start;
407 face_ghost_start.fill(0);
409 std::array<int, D + 1> face_start;
412 std::array<int, D + 1> face_end;
414 for (
int& v : face_end) {
418 std::array<int, D + 1> face_ghost_end;
420 for (
int& v : face_ghost_end) {
425 ghost_ns[side.getAxisIndex()] = num_ghost_cells;
426 if (side.isLowerOnAxis()) {
427 face_ghost_start[side.getAxisIndex()] = -num_ghost_cells;
428 face_ghost_end[side.getAxisIndex()] = -1;
429 face_end[side.getAxisIndex()] = -1;
431 face_ghost_start[side.getAxisIndex()] = ns[side.getAxisIndex()];
432 face_start[side.getAxisIndex()] = ns[side.getAxisIndex()];
433 face_ghost_end[side.getAxisIndex()] = ns[side.getAxisIndex()] + num_ghost_cells - 1;
436 strides[face.getIndex()][0] = 1;
437 for (
size_t i = 1; i < D + 1; i++) {
438 strides[face.getIndex()][i] = ghost_ns[i - 1] * strides[face.getIndex()][i - 1];
440 sizes[face.getIndex()] = ghost_ns[D - 1] * strides[face.getIndex()][D - 1];
443 if (side.isHigherOnAxis()) {
444 offset += -(ns[side.getAxisIndex()] + num_ghost_cells) *
445 strides[face.getIndex()][side.getAxisIndex()];
448 ghost_start[face.getIndex()] = face_ghost_start;
449 start[face.getIndex()] = face_start;
450 end[face.getIndex()] = face_end;
451 ghost_end[face.getIndex()] = face_ghost_end;
465 int num_components)
const
467 std::array<int, D + 1> my_end = end[face.
getIndex()];
468 std::array<int, D + 1> my_ghost_end = ghost_end[face.
getIndex()];
469 my_end[D] = num_components - 1;
470 my_ghost_end[D] = num_components - 1;
505 int num_components)
const
507 const GhostViewInfo<M>& gld_info = ghost_local_data_infos.template get<M>();
508 return gld_info.getPatchView(buffer_ptr, face, num_components);
516 std::vector<MPI_Request> postRecvs(std::vector<std::vector<double>>& buffers)
const
518 std::vector<MPI_Request> recv_requests(buffers.size());
519 for (
size_t i = 0; i < recv_requests.size(); i++) {
520 MPI_Irecv(buffers[i].data(),
523 remote_call_sets[i].rank,
528 return recv_requests;
540 void addRecvBufferToGhost(
const RemoteCallSet& remote_call_set,
541 std::vector<double>& buffer,
544 for (
const IncomingGhost<M>& incoming_ghost :
545 remote_call_set.incoming_ghosts.template get<M>()) {
546 int local_index = incoming_ghost.local_index;
548 size_t buffer_offset = incoming_ghost.offset;
551 double* buffer_ptr = buffer.data() + buffer_offset * u.getNumComponents();
553 getPatchViewForBuffer(buffer_ptr, face, u.getNumComponents());
554 std::array<size_t, D - M> start;
556 std::array<size_t, D - M> end;
558 Loop::Nested<D - M>(start, end, [&](
const std::array<size_t, D - M>& offset) {
561 Loop::OverInteriorIndexes<M + 1>(local_slice, [&](
const std::array<int, M + 1>& coord) {
562 local_slice[coord] += buffer_slice[coord];
574 void processRecvs(std::vector<MPI_Request>& requests,
575 std::vector<std::vector<double>>& buffers,
578 size_t num_requests = requests.size();
579 for (
size_t i = 0; i < num_requests; i++) {
581 MPI_Waitany(requests.size(), requests.data(), &finished_index, MPI_STATUS_IGNORE);
584 if constexpr (D >= 2) {
585 addRecvBufferToGhost<0>(remote_call_sets[finished_index], buffers[finished_index], u);
588 if constexpr (D == 3) {
589 addRecvBufferToGhost<1>(remote_call_sets[finished_index], buffers[finished_index], u);
592 addRecvBufferToGhost<D - 1>(remote_call_sets[finished_index], buffers[finished_index], u);
609 void fillSendBuffer(
const RemoteCallSet& remote_call_set,
610 std::vector<double>& buffer,
613 for (
const RemoteCall<M>& call : remote_call_set.remote_calls.template get<M>()) {
617 double* buffer_ptr = buffer.data() + call.offset * u.getNumComponents();
621 getPatchViewForBuffer(buffer_ptr, face, u.getNumComponents());
624 fillGhostCellsForNbrPatchPriv(
625 pinfo, local_view, buffer_view, call.face, call.nbr_type, call.orthant);
636 std::vector<MPI_Request> postSends(std::vector<std::vector<double>>& buffers,
639 std::vector<MPI_Request> send_requests(remote_call_sets.size());
640 for (
size_t i = 0; i < remote_call_sets.size(); i++) {
643 if constexpr (D >= 2) {
644 fillSendBuffer<0>(remote_call_sets[i], buffers[i], u);
647 if constexpr (D == 3) {
648 fillSendBuffer<1>(remote_call_sets[i], buffers[i], u);
651 fillSendBuffer<D - 1>(remote_call_sets[i], buffers[i], u);
656 MPI_Isend(buffers[i].data(),
659 remote_call_sets[i].rank,
664 return send_requests;
680 void fillGhostCellsForNbrPatchPriv(
const PatchInfo<D>& pinfo,
687 if constexpr (M == D - 1) {
689 }
else if constexpr (M > 0) {
705 void processLocalFills(
const Vector<D>& u)
const
707 for (
const LocalCall<M>& call : local_calls.template get<M>()) {
711 fillGhostCellsForNbrPatchPriv(
712 pinfo, local_view, nbr_view, call.face, call.nbr_type, call.orthant);
714 if constexpr (M > 0) {
715 processLocalFills<M - 1>(u);
730 static void addNormalNbrCalls(
733 std::deque<LocalCall<M>>& my_local_calls,
734 std::map<
int, std::set<RemoteCallPrototype<M>>>& rank_to_remote_call_prototypes,
735 std::map<
int, std::set<IncomingGhostPrototype<M>>>& rank_to_incoming_ghost_prototypes)
738 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
740 if (nbrinfo.rank == rank) {
741 my_local_calls.emplace_back(
744 rank_to_remote_call_prototypes[nbrinfo.rank].emplace(
746 rank_to_incoming_ghost_prototypes[nbrinfo.rank].emplace(pinfo.
id, f, pinfo.
local_index);
761 static void addFineNbrCalls(
764 std::deque<LocalCall<M>>& my_local_calls,
765 std::map<
int, std::set<RemoteCallPrototype<M>>>& rank_to_remote_call_prototypes,
766 std::map<
int, std::set<IncomingGhostPrototype<M>>>& rank_to_incoming_ghost_prototypes)
769 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
771 for (
size_t i = 0; i < Orthant<M>::num_orthants; i++) {
772 if (nbrinfo.ranks[i] == rank) {
773 my_local_calls.emplace_back(
776 rank_to_remote_call_prototypes[nbrinfo.ranks[i]].emplace(
778 rank_to_incoming_ghost_prototypes[nbrinfo.ranks[i]].emplace(pinfo.
id, f, pinfo.
local_index);
794 static void addCoarseNbrCalls(
797 std::deque<LocalCall<M>>& my_local_calls,
798 std::map<
int, std::set<RemoteCallPrototype<M>>>& rank_to_remote_call_prototypes,
799 std::map<
int, std::set<IncomingGhostPrototype<M>>>& rank_to_incoming_ghost_prototypes)
802 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
804 if (nbrinfo.rank == rank) {
805 my_local_calls.emplace_back(
808 rank_to_remote_call_prototypes[nbrinfo.rank].emplace(
810 rank_to_incoming_ghost_prototypes[nbrinfo.rank].emplace(pinfo.
id, f, pinfo.
local_index);
822 void enumerateCalls(std::deque<LocalCall<M>>& my_local_calls,
823 std::map<int, RemoteCallSet>& rank_to_remote_call_sets)
const
826 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
827 std::map<int, std::set<RemoteCallPrototype<M>>> rank_to_remote_call_prototypes;
828 std::map<int, std::set<IncomingGhostPrototype<M>>> rank_to_incoming_ghost_prototypes;
835 addNormalNbrCalls(pinfo,
838 rank_to_remote_call_prototypes,
839 rank_to_incoming_ghost_prototypes);
842 addFineNbrCalls(pinfo,
845 rank_to_remote_call_prototypes,
846 rank_to_incoming_ghost_prototypes);
849 addCoarseNbrCalls(pinfo,
852 rank_to_remote_call_prototypes,
853 rank_to_incoming_ghost_prototypes);
861 const GhostViewInfo<M>& ghost_local_data_info = ghost_local_data_infos.template get<M>();
862 for (
const auto& pair : rank_to_remote_call_prototypes) {
863 int rank = pair.first;
864 rank_to_remote_call_sets.emplace(rank, rank);
865 RemoteCallSet& remote_call_set = rank_to_remote_call_sets.at(rank);
866 std::tuple<int, Face<D, M>> prev_id_side;
867 for (
const RemoteCallPrototype<M>& call : pair.second) {
868 size_t offset = remote_call_set.send_buffer_length;
871 size_t length = ghost_local_data_info.getSize(call.face);
874 if (std::make_tuple(call.id, call.face) == prev_id_side) {
877 remote_call_set.send_buffer_length += length;
879 remote_call_set.remote_calls.template get<M>().emplace_back(call, offset);
880 prev_id_side = std::make_tuple(call.id, call.face);
883 for (
const auto& pair : rank_to_incoming_ghost_prototypes) {
884 RemoteCallSet& remote_call_set = rank_to_remote_call_sets.at(pair.first);
885 for (
const IncomingGhostPrototype<M>& prototype : pair.second) {
887 size_t length = ghost_local_data_info.getSize(prototype.face);
888 size_t offset = remote_call_set.recv_buffer_length;
889 remote_call_set.recv_buffer_length += length;
892 remote_call_set.incoming_ghosts.template get<M>().emplace_back(prototype, offset);
909 std::array<size_t, D - M> start;
911 std::array<size_t, D - M> end;
913 Loop::Nested<D - M>(start, end, [&](
const std::array<size_t, D - M>& offset) {
915 Loop::OverInteriorIndexes<M + 1>(
916 this_ghost, [&](
const std::array<int, M + 1>& coord) { this_ghost[coord] = 0; });
927 void zeroGhostCells(
const Vector<D>& u)
const
933 if constexpr (D >= 2) {
934 zeroGhostCellsOnAllFaces<0>(pinfo, this_patch);
937 if constexpr (D == 3) {
938 zeroGhostCellsOnAllFaces<1>(pinfo, this_patch);
941 zeroGhostCellsOnAllFaces<D - 1>(pinfo, this_patch);
958 , fill_type(fill_type)
960 std::map<int, RemoteCallSet> rank_to_remote_call_sets;
964 if constexpr (D >= 2) {
965 ghost_local_data_infos.template get<0>() = GhostViewInfo<0>(domain);
966 enumerateCalls<0>(local_calls.template get<0>(), rank_to_remote_call_sets);
969 if constexpr (D == 3) {
970 ghost_local_data_infos.template get<1>() = GhostViewInfo<1>(domain);
971 enumerateCalls<1>(local_calls.template get<1>(), rank_to_remote_call_sets);
974 ghost_local_data_infos.template get<D - 1>() = GhostViewInfo<D - 1>(domain);
975 enumerateCalls<D - 1>(local_calls.template get<D - 1>(), rank_to_remote_call_sets);
981 remote_call_sets.reserve(rank_to_remote_call_sets.size());
982 for (
const auto& pair : rank_to_remote_call_sets) {
983 remote_call_sets.push_back(pair.second);
1057 if constexpr (ENABLE_DEBUG) {
1059 throw RuntimeError(
"u vector is incorrect length. Expected Lenght of " +
1068 std::vector<std::vector<double>> recv_buffers(remote_call_sets.size());
1069 for (
size_t i = 0; i < remote_call_sets.size(); i++) {
1070 recv_buffers[i].resize(remote_call_sets[i].recv_buffer_length * u.getNumComponents());
1072 std::vector<MPI_Request> recv_requests = postRecvs(recv_buffers);
1075 std::vector<std::vector<double>> out_buffers(remote_call_sets.size());
1076 for (
size_t i = 0; i < remote_call_sets.size(); i++) {
1077 out_buffers[i].resize(remote_call_sets[i].send_buffer_length * u.getNumComponents());
1079 std::vector<MPI_Request> send_requests = postSends(out_buffers, u);
1086 processLocalFills<D - 1>(u);
1088 processRecvs(recv_requests, recv_buffers, u);
1091 MPI_Waitall(send_requests.size(), send_requests.data(), MPI_STATUS_IGNORE);
1108 extern template class MPIGhostFiller<1>;
1109 extern template class MPIGhostFiller<2>;
1110 extern template class MPIGhostFiller<3>;