ThunderEgg  1.0.0
FFTWPatchSolver.h
Go to the documentation of this file.
1 /***************************************************************************
2  * ThunderEgg, a library for solvers on adaptively refined block-structured
3  * Cartesian grids.
4  *
5  * Copyright (c) 2020-2021 Scott Aiton
6  *
7  * This program is free software: you can redistribute it and/or modify
8  * it under the terms of the GNU General Public License as published by
9  * the Free Software Foundation, either version 3 of the License, or
10  * (at your option) any later version.
11  *
12  * This program is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15  * GNU General Public License for more details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * along with this program. If not, see <https://www.gnu.org/licenses/>.
19  ***************************************************************************/
20 
21 #ifndef THUNDEREGG_POISSON_SCHUR_FFTWPATCHSOLVER_H
22 #define THUNDEREGG_POISSON_SCHUR_FFTWPATCHSOLVER_H
23 
28 #include <ThunderEgg/PatchArray.h>
30 #include <ThunderEgg/PatchSolver.h>
31 #include <ThunderEgg/Vector.h>
32 #include <bitset>
33 #include <fftw3.h>
34 #include <map>
35 
36 namespace ThunderEgg::Poisson {
42 template<int D>
43 class FFTWPatchSolver : public PatchSolver<D>
44 {
45 private:
50  using CompareFunction = std::function<bool(const PatchInfo<D>&, const PatchInfo<D>&)>;
51 
55  std::shared_ptr<const PatchOperator<D>> op;
59  std::map<const PatchInfo<D>, std::shared_ptr<fftw_plan>, CompareFunction> plan1;
63  std::map<const PatchInfo<D>, std::shared_ptr<fftw_plan>, CompareFunction> plan2;
67  std::map<const PatchInfo<D>, PatchArray<D>, CompareFunction> eigen_vals;
71  std::bitset<Side<D>::number_of> neumann;
72 
81  bool patchIsNeumannOnSide(const PatchInfo<D>& pinfo, Side<D> s)
82  {
83  return !pinfo.hasNbr(s) && neumann[s.getIndex()];
84  }
92  std::array<fftw_r2r_kind, D> getTransformsForPatch(const PatchInfo<D>& pinfo)
93  {
94  // get transform types for each axis
95  std::array<fftw_r2r_kind, D> transforms;
96  for (size_t axis = 0; axis < D; axis++) {
97  if (patchIsNeumannOnSide(pinfo, LowerSideOnAxis<D>(axis)) &&
98  patchIsNeumannOnSide(pinfo, HigherSideOnAxis<D>(axis))) {
99  transforms[D - 1 - axis] = FFTW_REDFT10;
100  } else if (patchIsNeumannOnSide(pinfo, LowerSideOnAxis<D>(axis))) {
101  transforms[D - 1 - axis] = FFTW_REDFT11;
102  } else if (patchIsNeumannOnSide(pinfo, HigherSideOnAxis<D>(axis))) {
103  transforms[D - 1 - axis] = FFTW_RODFT11;
104  } else {
105  transforms[D - 1 - axis] = FFTW_RODFT10;
106  }
107  }
108  return transforms;
109  }
117  std::array<fftw_r2r_kind, D> getInverseTransformsForPatch(const PatchInfo<D>& pinfo)
118  {
119  // get transform types for each axis
120  std::array<fftw_r2r_kind, D> transforms_inv;
121  for (size_t axis = 0; axis < D; axis++) {
122  if (patchIsNeumannOnSide(pinfo, LowerSideOnAxis<D>(axis)) &&
123  patchIsNeumannOnSide(pinfo, HigherSideOnAxis<D>(axis))) {
124  transforms_inv[D - 1 - axis] = FFTW_REDFT01;
125  } else if (patchIsNeumannOnSide(pinfo, LowerSideOnAxis<D>(axis))) {
126  transforms_inv[D - 1 - axis] = FFTW_REDFT11;
127  } else if (patchIsNeumannOnSide(pinfo, HigherSideOnAxis<D>(axis))) {
128  transforms_inv[D - 1 - axis] = FFTW_RODFT11;
129  } else {
130  transforms_inv[D - 1 - axis] = FFTW_RODFT01;
131  }
132  }
133  return transforms_inv;
134  }
141  PatchArray<D> getEigenValues(const PatchInfo<D>& pinfo)
142  {
143  PatchArray<D> retval(this->getDomain().getNs(), 1, 0);
144 
145  std::valarray<size_t> all_strides(D);
146  size_t curr_stride = 1;
147  for (size_t i = 0; i < D; i++) {
148  all_strides[i] = curr_stride;
149  curr_stride *= pinfo.ns[i];
150  }
151 
152  for (size_t axis = 0; axis < D; axis++) {
153  int n = pinfo.ns[axis];
154  double h = pinfo.spacings[axis];
155 
156  if (patchIsNeumannOnSide(pinfo, LowerSideOnAxis<D>(axis)) &&
157  patchIsNeumannOnSide(pinfo, HigherSideOnAxis<D>(axis))) {
158  for (int xi = 0; xi < n; xi++) {
159  double val = 4 / (h * h) * pow(sin(xi * M_PI / (2 * n)), 2);
160  View<double, D> slice = retval.getSliceOn(Side<D>(2 * axis), { xi });
161  Loop::OverInteriorIndexes<D>(
162  slice, [&](const std::array<int, D>& coord) { slice[coord] -= val; });
163  }
164  } else if (patchIsNeumannOnSide(pinfo, LowerSideOnAxis<D>(axis)) ||
165  patchIsNeumannOnSide(pinfo, HigherSideOnAxis<D>(axis))) {
166  for (int xi = 0; xi < n; xi++) {
167  double val = 4 / (h * h) * pow(sin((xi + 0.5) * M_PI / (2 * n)), 2);
168  View<double, D> slice = retval.getSliceOn(Side<D>(2 * axis), { xi });
169  Loop::OverInteriorIndexes<D>(
170  slice, [&](const std::array<int, D>& coord) { slice[coord] -= val; });
171  }
172  } else {
173  for (int xi = 0; xi < n; xi++) {
174  double val = 4 / (h * h) * pow(sin((xi + 1) * M_PI / (2 * n)), 2);
175  View<double, D> slice = retval.getSliceOn(Side<D>(2 * axis), { xi });
176  Loop::OverInteriorIndexes<D>(
177  slice, [&](const std::array<int, D>& coord) { slice[coord] -= val; });
178  }
179  }
180  }
181  return retval;
182  }
183 
184 public:
191  FFTWPatchSolver(const PatchOperator<D>& op, std::bitset<Side<D>::number_of> neumann)
192  : PatchSolver<D>(op.getDomain(), op.getGhostFiller())
193  , op(op.clone())
194  , neumann(neumann)
195  {
196  CompareFunction compare = [&](const PatchInfo<D>& a, const PatchInfo<D>& b) {
197  std::bitset<Side<D>::number_of> a_neumann;
198  std::bitset<Side<D>::number_of> b_neumann;
199  for (Side<D> s : Side<D>::getValues()) {
200  a_neumann[s.getIndex()] = patchIsNeumannOnSide(a, s);
201  b_neumann[s.getIndex()] = patchIsNeumannOnSide(b, s);
202  }
203  return std::forward_as_tuple(a_neumann.to_ulong(), a.spacings[0]) <
204  std::forward_as_tuple(b_neumann.to_ulong(), b.spacings[0]);
205  };
206 
207  plan1 = std::map<const PatchInfo<D>, std::shared_ptr<fftw_plan>, CompareFunction>(compare);
208  plan2 = std::map<const PatchInfo<D>, std::shared_ptr<fftw_plan>, CompareFunction>(compare);
209  eigen_vals = std::map<const PatchInfo<D>, PatchArray<D>, CompareFunction>(compare);
210 
211  // process patches
212  for (auto pinfo : this->getDomain().getPatchInfoVector()) {
213  addPatch(pinfo);
214  }
215  }
221  FFTWPatchSolver<D>* clone() const override { return new FFTWPatchSolver<D>(*this); }
222  void solveSinglePatch(const PatchInfo<D>& pinfo,
223  const PatchView<const double, D>& f_view,
224  const PatchView<double, D>& u_view) const override
225  {
226  PatchArray<D> f_copy(pinfo.ns, 1, 0);
227  PatchArray<D> tmp(pinfo.ns, 1, 0);
228  PatchArray<D> sol(pinfo.ns, 1, 0);
229 
230  Loop::OverInteriorIndexes<D + 1>(
231  f_copy, [&](std::array<int, D + 1> coord) { f_copy[coord] = f_view[coord]; });
232 
233  op->modifyRHSForInternalBoundaryConditions(pinfo, u_view, f_copy.getView());
234 
235  fftw_execute_r2r(*plan1.at(pinfo), &f_copy[f_copy.getStart()], &tmp[tmp.getStart()]);
236 
237  const PatchArray<D>& eigen_vals_view = eigen_vals.at(pinfo);
238  Loop::OverInteriorIndexes<D + 1>(
239  tmp, [&](std::array<int, D + 1> coord) { tmp[coord] /= eigen_vals_view[coord]; });
240 
241  if (neumann.all() && !pinfo.hasNbr()) {
242  tmp[tmp.getStart()] = 0;
243  }
244 
245  fftw_execute_r2r(*plan2.at(pinfo), &tmp[tmp.getStart()], &sol[sol.getStart()]);
246 
247  double scale = 1;
248  for (size_t axis = 0; axis < D; axis++) {
249  scale *= 2.0 * this->getDomain().getNs()[axis];
250  }
251  Loop::OverInteriorIndexes<D + 1>(
252  u_view, [&](std::array<int, D + 1> coord) { u_view[coord] = sol[coord] / scale; });
253  }
261  void addPatch(const PatchInfo<D>& pinfo)
262  {
263  if (plan1.count(pinfo) == 0) {
264  // revers ns because FFTW is row major
265  std::array<int, D> ns_reversed;
266  for (size_t i = 0; i < D; i++) {
267  ns_reversed[D - 1 - i] = pinfo.ns[i];
268  }
269  std::array<fftw_r2r_kind, D> transforms = getTransformsForPatch(pinfo);
270  std::array<fftw_r2r_kind, D> transforms_inv = getInverseTransformsForPatch(pinfo);
271 
272  PatchArray<D> f_copy(pinfo.ns, 1, 0);
273  PatchArray<D> tmp(pinfo.ns, 1, 0);
274  PatchArray<D> sol(pinfo.ns, 1, 0);
275 
276  fftw_plan* fftw_plan1 = new fftw_plan();
277 
278  *fftw_plan1 = fftw_plan_r2r(D,
279  ns_reversed.data(),
280  &f_copy[f_copy.getStart()],
281  &tmp[tmp.getStart()],
282  transforms.data(),
283  FFTW_MEASURE | FFTW_DESTROY_INPUT | FFTW_UNALIGNED);
284 
285  plan1[pinfo] = std::shared_ptr<fftw_plan>(fftw_plan1, [](fftw_plan* plan) {
286  fftw_destroy_plan(*plan);
287  delete plan;
288  });
289 
290  fftw_plan* fftw_plan2 = new fftw_plan();
291 
292  *fftw_plan2 = fftw_plan_r2r(D,
293  ns_reversed.data(),
294  &tmp[tmp.getStart()],
295  &sol[sol.getStart()],
296  transforms_inv.data(),
297  FFTW_MEASURE | FFTW_DESTROY_INPUT | FFTW_UNALIGNED);
298 
299  plan2[pinfo] = std::shared_ptr<fftw_plan>(fftw_plan2, [](fftw_plan* plan) {
300  fftw_destroy_plan(*plan);
301  delete plan;
302  });
303 
304  eigen_vals.emplace(pinfo, getEigenValues(pinfo));
305  }
306  }
312  std::bitset<Side<D>::number_of> getNeumann() const { return neumann; }
313 };
314 extern template class FFTWPatchSolver<2>;
315 extern template class FFTWPatchSolver<3>;
316 } // namespace ThunderEgg::Poisson
317 #endif
ThunderEgg::Poisson::FFTWPatchSolver::getNeumann
std::bitset< Side< D >::number_of > getNeumann() const
Get the Neumann boundary conditions for this solver.
Definition: FFTWPatchSolver.h:312
ThunderEgg::PatchSolver::getDomain
const Domain< D > & getDomain() const
Get the Domain object.
Definition: PatchSolver.h:84
ThunderEgg::View< double, D >
ThunderEgg::Poisson::FFTWPatchSolver::FFTWPatchSolver
FFTWPatchSolver(const PatchOperator< D > &op, std::bitset< Side< D >::number_of > neumann)
Construct a new FFTWPatchSolver object.
Definition: FFTWPatchSolver.h:191
Vector.h
Vector class.
ThunderEgg::PatchSolver
Solves the problem on the patches using a specified interface value.
Definition: PatchSolver.h:42
ThunderEgg::PatchInfo::hasNbr
bool hasNbr(Face< D, M > s) const
Return whether the patch has a neighbor.
Definition: PatchInfo.h:284
ThunderEgg::Poisson::FFTWPatchSolver::solveSinglePatch
void solveSinglePatch(const PatchInfo< D > &pinfo, const PatchView< const double, D > &f_view, const PatchView< double, D > &u_view) const override
Perform a single solve over a patch.
Definition: FFTWPatchSolver.h:222
ThunderEgg::PatchInfo::ns
std::array< int, D > ns
The number of cells in each direction.
Definition: PatchInfo.h:117
ThunderEgg::PatchArray::getSliceOn
View< double, M+1 > getSliceOn(Face< D, M > f, const std::array< int, D - M > &offset)
Get the slice on a given face.
Definition: PatchArray.h:116
ThunderEgg::PatchArray::getView
const PatchView< double, D > & getView()
Get the View for the array.
Definition: PatchArray.h:196
PatchOperator.h
PatchOperator class.
ThunderEgg::Poisson
Classes specific to the Poisson equation.
Definition: DFTPatchSolver.h:40
ThunderEgg::PatchView
View for accessing data of a patch. It supports variable striding.
Definition: PatchView.h:37
PatchArray.h
PatchArray class.
ThunderEgg::Poisson::FFTWPatchSolver::clone
FFTWPatchSolver< D > * clone() const override
Clone this patch solver.
Definition: FFTWPatchSolver.h:221
ThunderEgg::Face::getIndex
size_t getIndex() const
Get the index for this Face.
Definition: Face.h:452
ThunderEgg::PatchArray
Array for accessing data of a patch. It supports variable striding.
Definition: PatchArray.h:36
ThunderEgg::PatchInfo
Contains metadata for a patch.
Definition: PatchInfo.h:51
ThunderEgg::PatchSolver::getGhostFiller
const GhostFiller< D > & getGhostFiller() const
Get the GhostFiller object.
Definition: PatchSolver.h:90
ThunderEgg::PatchOperator
This is an Operator where derived classes only have to implement the two virtual functions that opera...
Definition: PatchOperator.h:40
ThunderEgg::PatchInfo::spacings
std::array< double, D > spacings
The cell spacings in each direction.
Definition: PatchInfo.h:125
ThunderEgg::Poisson::FFTWPatchSolver
Use FFT transforms to solve for the Poisson equation.
Definition: FFTWPatchSolver.h:43
ThunderEgg::Face
Enum-style class for the faces of an n-dimensional cube.
Definition: Face.h:41
PatchSolver.h
PatchSolver class.
ThunderEgg::Poisson::FFTWPatchSolver::addPatch
void addPatch(const PatchInfo< D > &pinfo)
Add a patch to the solver.
Definition: FFTWPatchSolver.h:261
ThunderEgg::PatchArray::getStart
const std::array< int, D+1 > & getStart() const
Get the coordinate of the first element.
Definition: PatchArray.h:178