DiFfRG
Loading...
Searching...
No Matches
linear_interpolation_2D.hh
Go to the documentation of this file.
1#pragma once
2
3// DiFfRG
7
8// standard library
9#include <memory>
10
11// external libraries
12#include <autodiff/forward/real.hpp>
13
14namespace DiFfRG
15{
22 template <typename NT, typename Coordinates> class LinearInterpolator2D
23 {
24 static_assert(Coordinates::dim == 2, "LinearInterpolator2D requires 2D coordinates");
25
26 public:
33 LinearInterpolator2D(const std::vector<NT> &data, const Coordinates &coordinates);
41 LinearInterpolator2D(const NT *data, const Coordinates &coordinates);
48 LinearInterpolator2D(const Coordinates &coordinates);
55
56 template <typename NT2> void update(const NT2 *data)
57 {
58 if (!owner) throw std::runtime_error("Cannot update data of non-owner interpolator");
59
60 for (uint i = 0; i < size; ++i)
61 m_data[i] = NT(data[i]);
62
63 update();
64 }
65 void update();
66
67 NT *data() const;
68
75 __device__ __host__ NT operator()(const typename Coordinates::ctype x, const typename Coordinates::ctype y) const
76 {
77#ifndef __CUDA_ARCH__
78 using std::ceil;
79 using std::floor;
80 using std::max;
81 using std::min;
82#endif
83
84 auto [idx_x, idx_y] = coordinates.backward(x, y);
85 idx_x = max(static_cast<decltype(idx_x)>(0), min(idx_x, static_cast<decltype(idx_x)>(shape[0] - 1)));
86 idx_y = max(static_cast<decltype(idx_y)>(0), min(idx_y, static_cast<decltype(idx_y)>(shape[1] - 1)));
87
88#ifndef __CUDA_ARCH__
89 const auto *d_ptr = m_data.get();
90#else
91 const auto *d_ptr = device_data_ptr;
92#endif
93
94 uint x1 = min(ceil(idx_x + static_cast<decltype(idx_x)>(1e-16)), static_cast<decltype(idx_x)>(shape[0] - 1));
95 const auto x0 = x1 - 1;
96 uint y1 = min(ceil(idx_y + static_cast<decltype(idx_y)>(1e-16)), static_cast<decltype(idx_y)>(shape[1] - 1));
97 const auto y0 = y1 - 1;
98
99 const auto corner00 = d_ptr[x0 * shape[1] + y0];
100 const auto corner01 = d_ptr[x0 * shape[1] + y1];
101 const auto corner10 = d_ptr[x1 * shape[1] + y0];
102 const auto corner11 = d_ptr[x1 * shape[1] + y1];
103
104 return corner00 * (x1 - idx_x) * (y1 - idx_y) + corner01 * (x1 - idx_x) * (idx_y - y0) +
105 corner10 * (idx_x - x0) * (y1 - idx_y) + corner11 * (idx_x - x0) * (idx_y - y0);
106 }
107
108 NT &operator[](const uint i);
109 const NT &operator[](const uint i) const;
110
116 const Coordinates &get_coordinates() const { return coordinates; }
117
118 private:
119 const uint size;
120 const std::array<uint, 2> shape;
121 const Coordinates coordinates;
122
123 std::shared_ptr<NT[]> m_data;
124 std::shared_ptr<thrust::device_vector<NT>> device_data;
126
127 const bool owner;
128 };
129} // namespace DiFfRG
A linear interpolator for 2D data, both on GPU and CPU.
Definition linear_interpolation_2D.hh:23
const uint size
Definition linear_interpolation_2D.hh:119
void update(const NT2 *data)
Definition linear_interpolation_2D.hh:56
LinearInterpolator2D(const std::vector< NT > &data, const Coordinates &coordinates)
Construct a LinearInterpolator2D object from a vector of data and a coordinate system.
__device__ __host__ NT operator()(const typename Coordinates::ctype x, const typename Coordinates::ctype y) const
Interpolate the data at a given point.
Definition linear_interpolation_2D.hh:75
std::shared_ptr< thrust::device_vector< NT > > device_data
Definition linear_interpolation_2D.hh:124
const NT & operator[](const uint i) const
const Coordinates & get_coordinates() const
Get the coordinate system of the data.
Definition linear_interpolation_2D.hh:116
const std::array< uint, 2 > shape
Definition linear_interpolation_2D.hh:120
LinearInterpolator2D(const LinearInterpolator2D &other)
Construct a copy of a LinearInterpolator2D object.
const NT * device_data_ptr
Definition linear_interpolation_2D.hh:125
LinearInterpolator2D(const NT *data, const Coordinates &coordinates)
Construct a LinearInterpolator2D object from a pointer to data and a coordinate system.
const Coordinates coordinates
Definition linear_interpolation_2D.hh:121
NT & operator[](const uint i)
LinearInterpolator2D(const Coordinates &coordinates)
Construct a LinearInterpolator2D with internal, zeroed data and a coordinate system.
const bool owner
Definition linear_interpolation_2D.hh:127
std::shared_ptr< NT[]> m_data
Definition linear_interpolation_2D.hh:123
Definition complex_math.hh:14
unsigned int uint
Definition utils.hh:22