DiFfRG
Loading...
Searching...
No Matches
linear_interpolation_3D.hh
Go to the documentation of this file.
1#pragma once
2
3// DiFfRG
7
8// standard library
9#include <memory>
10
11// external libraries
12#include <autodiff/forward/real.hpp>
13
14namespace DiFfRG
15{
22 template <typename NT, typename Coordinates> class LinearInterpolator3D
23 {
24 static_assert(Coordinates::dim == 3, "LinearInterpolator3D requires 3D coordinates");
25
26 public:
33 LinearInterpolator3D(const std::vector<NT> &data, const Coordinates &coordinates);
41 LinearInterpolator3D(const NT *data, const Coordinates &coordinates);
48 LinearInterpolator3D(const Coordinates &coordinates);
55
56 template <typename NT2> void update(const NT2 *data)
57 {
58 if (!owner) throw std::runtime_error("Cannot update data of non-owner interpolator");
59
60 for (uint i = 0; i < size; ++i)
61 m_data[i] = NT(data[i]);
62
63 update();
64 }
65 void update();
66
67 NT *data() const;
68
75 __device__ __host__ NT operator()(const typename Coordinates::ctype x, const typename Coordinates::ctype y,
76 const typename Coordinates::ctype z) const
77 {
78#ifndef __CUDA_ARCH__
79 using std::ceil;
80 using std::floor;
81 using std::max;
82 using std::min;
83#endif
84
85 auto [idx_x, idx_y, idx_z] = coordinates.backward(x, y, z);
86 idx_x = max(static_cast<decltype(idx_x)>(0), min(idx_x, static_cast<decltype(idx_x)>(shape[0] - 1)));
87 idx_y = max(static_cast<decltype(idx_y)>(0), min(idx_y, static_cast<decltype(idx_y)>(shape[1] - 1)));
88 idx_z = max(static_cast<decltype(idx_z)>(0), min(idx_z, static_cast<decltype(idx_z)>(shape[2] - 1)));
89
90#ifndef __CUDA_ARCH__
91 const auto *d_ptr = m_data.get();
92#else
93 const auto *d_ptr = device_data_ptr;
94#endif
95
96 uint x1 = min(ceil(idx_x + static_cast<decltype(idx_x)>(1e-16)), static_cast<decltype(idx_x)>(shape[0] - 1));
97 const auto x0 = x1 - 1;
98 uint y1 = min(ceil(idx_y + static_cast<decltype(idx_y)>(1e-16)), static_cast<decltype(idx_y)>(shape[1] - 1));
99 const auto y0 = y1 - 1;
100 uint z1 = min(ceil(idx_z + static_cast<decltype(idx_z)>(1e-16)), static_cast<decltype(idx_z)>(shape[2] - 1));
101 const auto z0 = z1 - 1;
102
103 const auto corner000 = d_ptr[x0 * shape[1] * shape[2] + y0 * shape[2] + z0];
104 const auto corner001 = d_ptr[x0 * shape[1] * shape[2] + y0 * shape[2] + z1];
105 const auto corner010 = d_ptr[x0 * shape[1] * shape[2] + y1 * shape[2] + z0];
106 const auto corner011 = d_ptr[x0 * shape[1] * shape[2] + y1 * shape[2] + z1];
107 const auto corner100 = d_ptr[x1 * shape[1] * shape[2] + y0 * shape[2] + z0];
108 const auto corner101 = d_ptr[x1 * shape[1] * shape[2] + y0 * shape[2] + z1];
109 const auto corner110 = d_ptr[x1 * shape[1] * shape[2] + y1 * shape[2] + z0];
110 const auto corner111 = d_ptr[x1 * shape[1] * shape[2] + y1 * shape[2] + z1];
111
112 return corner000 * (x1 - idx_x) * (y1 - idx_y) * (z1 - idx_z) +
113 corner001 * (x1 - idx_x) * (y1 - idx_y) * (idx_z - z0) +
114 corner010 * (x1 - idx_x) * (idx_y - y0) * (z1 - idx_z) +
115 corner011 * (x1 - idx_x) * (idx_y - y0) * (idx_z - z0) +
116 corner100 * (idx_x - x0) * (y1 - idx_y) * (z1 - idx_z) +
117 corner101 * (idx_x - x0) * (y1 - idx_y) * (idx_z - z0) +
118 corner110 * (idx_x - x0) * (idx_y - y0) * (z1 - idx_z) +
119 corner111 * (idx_x - x0) * (idx_y - y0) * (idx_z - z0);
120 }
121
122 NT &operator[](const uint i);
123 const NT &operator[](const uint i) const;
124
130 const Coordinates &get_coordinates() const { return coordinates; }
131
132 private:
133 const uint size;
134 const Coordinates coordinates;
135 const std::array<uint, 3> shape;
136
137 std::shared_ptr<NT[]> m_data;
138 std::shared_ptr<thrust::device_vector<NT>> device_data;
140
141 const bool owner;
142 };
143} // namespace DiFfRG
A linear interpolator for 3D data, both on GPU and CPU.
Definition linear_interpolation_3D.hh:23
LinearInterpolator3D(const LinearInterpolator3D &other)
Construct a copy of a LinearInterpolator3D object.
__device__ __host__ NT operator()(const typename Coordinates::ctype x, const typename Coordinates::ctype y, const typename Coordinates::ctype z) const
Interpolate the data at a given point.
Definition linear_interpolation_3D.hh:75
const Coordinates coordinates
Definition linear_interpolation_3D.hh:134
const bool owner
Definition linear_interpolation_3D.hh:141
void update(const NT2 *data)
Definition linear_interpolation_3D.hh:56
const uint size
Definition linear_interpolation_3D.hh:133
std::shared_ptr< NT[]> m_data
Definition linear_interpolation_3D.hh:137
LinearInterpolator3D(const NT *data, const Coordinates &coordinates)
Construct a LinearInterpolator3D object from a pointer to data and a coordinate system.
const std::array< uint, 3 > shape
Definition linear_interpolation_3D.hh:135
std::shared_ptr< thrust::device_vector< NT > > device_data
Definition linear_interpolation_3D.hh:138
const Coordinates & get_coordinates() const
Get the coordinate system of the data.
Definition linear_interpolation_3D.hh:130
LinearInterpolator3D(const Coordinates &coordinates)
Construct a LinearInterpolator3D with internal, zeroed data and a coordinate system.
NT & operator[](const uint i)
const NT & operator[](const uint i) const
LinearInterpolator3D(const std::vector< NT > &data, const Coordinates &coordinates)
Construct a LinearInterpolator3D object from a vector of data and a coordinate system.
const NT * device_data_ptr
Definition linear_interpolation_3D.hh:139
Definition complex_math.hh:14
unsigned int uint
Definition utils.hh:22