28at::Tensor
GridSample3D_CPU(
const at::Tensor &input,
const at::Tensor &grid,
const std::string &addressModeX,
29 const std::string &addressModeY,
const std::string &addressModeZ,
30 c10::optional<at::Tensor> out);
37 const std::string &addressModeY,
const std::string &addressModeZ,
38 c10::optional<at::Tensor> out);
48 static_assert(texture_t::DIMENSIONALITY == 3);
50 using IntType =
typename texture_t::IntType;
62 const std::string &addressModeY,
const std::string &addressModeZ,
63 at::DeviceType device, c10::optional<at::Tensor> out) {
65 TORCH_CHECK(input.sizes().size() == 3)
66 TORCH_CHECK(input.dtype() == at::kFloat)
67 TORCH_INTERNAL_ASSERT(input.device().type() == device)
69 TORCH_CHECK(grid.sizes().back() == 3);
70 TORCH_CHECK(grid.dtype() == at::kFloat);
71 TORCH_INTERNAL_ASSERT(grid.device().type() == device)
74 TORCH_CHECK(out.value().sizes() == grid.sizes().slice(0, grid.sizes().size() - 1))
75 TORCH_CHECK(out.value().dtype() == at::kFloat)
76 TORCH_INTERNAL_ASSERT(out.value().device().type() == device)
85 ret.inputTexture = texture_t::FromTensor(
input,
inputSpacing, VectorType::Full(0.0), std::move(addressModes));
87 ?
out.value().view({-1})
88 : torch::zeros(at::IntArrayRef({
grid.numel() / 3}),
input.contiguous().options());
General tools and structs.
#define __host__
Definition Global.h:17
at::Tensor GridSample3D_CPU(const at::Tensor &input, const at::Tensor &grid, const std::string &addressModeX, const std::string &addressModeY, const std::string &addressModeZ, c10::optional< at::Tensor > out)
Sample the given 3D input tensor at the positions given in grid according to the given address mode u...
Definition GridSample3DCPU.cpp:10
__host__ at::Tensor GridSample3D_CUDA(const at::Tensor &input, const at::Tensor &grid, const std::string &addressModeX, const std::string &addressModeY, const std::string &addressModeZ, c10::optional< at::Tensor > out)
An implementation of reg23::GridSample3D_CPU that uses CUDA parallelisation.
Vec< TextureAddressMode, DIMENSIONALITY > StringsToAddressModes(const std::array< std::string_view, DIMENSIONALITY > &strings)
Definition Texture.h:44
Definition GridSample3DCPU.cpp:6
Definition GridSample3D.h:56
at::Tensor flatOutput
Definition GridSample3D.h:58
texture_t inputTexture
Definition GridSample3D.h:57
Definition GridSample3D.h:46
typename texture_t::AddressModeType AddressModeType
Definition GridSample3D.h:54
typename texture_t::FloatType FloatType
Definition GridSample3D.h:51
typename texture_t::IntType IntType
Definition GridSample3D.h:50
typename texture_t::VectorType VectorType
Definition GridSample3D.h:53
static __host__ CommonData Common(const at::Tensor &input, const at::Tensor &grid, const std::string &addressModeX, const std::string &addressModeY, const std::string &addressModeZ, at::DeviceType device, c10::optional< at::Tensor > out)
Definition GridSample3D.h:61
typename texture_t::SizeType SizeType
Definition GridSample3D.h:52