Grangeat-based 2D/3D image registration
Loading...
Searching...
No Matches
GridSample3D.h
Go to the documentation of this file.
1
6#pragma once
7
8#include "Common.h"
9#include "Texture.h"
10
11namespace reg23 {
12
28at::Tensor GridSample3D_CPU(const at::Tensor &input, const at::Tensor &grid, const std::string &addressModeX,
29 const std::string &addressModeY, const std::string &addressModeZ,
30 c10::optional<at::Tensor> out);
31
36__host__ at::Tensor GridSample3D_CUDA(const at::Tensor &input, const at::Tensor &grid, const std::string &addressModeX,
37 const std::string &addressModeY, const std::string &addressModeZ,
38 c10::optional<at::Tensor> out);
39
/**
 * @brief Host-side argument validation and set-up for 3D grid sampling, parameterised on the texture type.
 *
 * Presumably shared by the CPU and CUDA implementations (Common() takes the target device type as a parameter).
 *
 * @tparam texture_t A texture type with DIMENSIONALITY == 3 (enforced by the static_assert below). It must expose
 * IntType, FloatType, SizeType, VectorType and AddressModeType, and a static FromTensor factory.
 */
46template <typename texture_t> struct GridSample3D {
47
48 static_assert(texture_t::DIMENSIONALITY == 3);
49
    // Convenience aliases for the texture's associated types.
50 using IntType = typename texture_t::IntType;
51 using FloatType = typename texture_t::FloatType;
52 using SizeType = typename texture_t::SizeType;
53 using VectorType = typename texture_t::VectorType;
54 using AddressModeType = typename texture_t::AddressModeType;
55
    /// Result of Common().
56 struct CommonData {
        /// The input tensor wrapped as a texture_t, built via texture_t::FromTensor.
57 texture_t inputTexture{};
        /// 1D output tensor with grid.numel() / 3 elements. It is either a view of the caller's `out`, or a fresh
        /// zero-filled tensor.
58 at::Tensor flatOutput{};
59 };
60
    /**
     * @brief Validate the arguments, wrap `input` as a texture, and prepare a flat output tensor.
     *
     * @param input Must be a 3-dimensional float tensor on `device`.
     * @param grid Must be a float tensor whose last dimension has size 3, on `device`.
     * @param addressModeX Address mode name for the X dimension.
     * @param addressModeY Address mode name for the Y dimension.
     * @param addressModeZ Address mode name for the Z dimension.
     *        NOTE(review): the lines that validate and convert these three strings into `addressModes` are not
     *        visible in this listing.
     * @param device Device type that `input`, `grid` and `out` must be on. This is checked with
     *        TORCH_INTERNAL_ASSERT, so a mismatch is treated as an internal error rather than a user error.
     * @param out Optional output tensor. If given, it must be a float tensor on `device` whose sizes equal `grid`'s
     *        sizes without the last dimension. The result's flatOutput is then `out.value().view({-1})`.
     * @return CommonData holding the input texture and a 1D output tensor of grid.numel() / 3 elements.
     */
61 __host__ static CommonData Common(const at::Tensor &input, const at::Tensor &grid, const std::string &addressModeX,
62 const std::string &addressModeY, const std::string &addressModeZ,
63 at::DeviceType device, c10::optional<at::Tensor> out) {
        // NOTE(review): several checks below omit the trailing ';' while others have it. Presumably this still
        // compiles because the macros expand to complete statements, but the style is inconsistent.
64 // input should be a 3D tensor of floats on the chosen device
65 TORCH_CHECK(input.sizes().size() == 3)
66 TORCH_CHECK(input.dtype() == at::kFloat)
67 TORCH_INTERNAL_ASSERT(input.device().type() == device)
68 // grid should be a tensor of floats with a final dimension of 3 on the chosen device
69 TORCH_CHECK(grid.sizes().back() == 3);
70 TORCH_CHECK(grid.dtype() == at::kFloat);
71 TORCH_INTERNAL_ASSERT(grid.device().type() == device)
72 if (out) {
73 // out should be a tensor of floats matching all but the last dimension of grid in size, on the chosen device
74 TORCH_CHECK(out.value().sizes() == grid.sizes().slice(0, grid.sizes().size() - 1))
75 TORCH_CHECK(out.value().dtype() == at::kFloat)
76 TORCH_INTERNAL_ASSERT(out.value().device().type() == device)
77 }
78
79 // All addressMode<dim>s should be one of the valid values:
        // NOTE(review): `addressModes`, `ret` (presumably a CommonData) and `inputSpacing` are used below but not
        // declared anywhere in this listing. Original header lines 80, 82 and 84 appear to be missing; restore them
        // from the real header.
81
        // Input sizes re-ordered via Flipped(). Presumably this reverses the tensor's dimension order into the
        // texture's X, Y, Z order; confirm. It is not used in the visible lines (likely consumed by a missing line).
83 const SizeType inputSize = SizeType::FromIntArrayRef(input.sizes()).Flipped();
85 ret.inputTexture = texture_t::FromTensor(input, inputSpacing, VectorType::Full(0.0), std::move(addressModes));
        // Reuse the caller's buffer if one was given, otherwise allocate zeros with one element per grid point.
        // NOTE(review): `view({-1})` only succeeds if `out`'s strides allow a flat view, and no contiguity check is
        // made above.
        // NOTE(review): `input.contiguous()` may materialise a full copy of the input volume just to read its
        // options. `input.options()` would give the same dtype and device without the copy.
86 ret.flatOutput = out
87 ? out.value().view({-1})
88 : torch::zeros(at::IntArrayRef({grid.numel() / 3}), input.contiguous().options());
89 return ret;
90 }
91};
92
93} // namespace reg23
General tools and structs.
#define __host__
Definition Global.h:17
at::Tensor GridSample3D_CPU(const at::Tensor &input, const at::Tensor &grid, const std::string &addressModeX, const std::string &addressModeY, const std::string &addressModeZ, c10::optional< at::Tensor > out)
Sample the given 3D input tensor at the positions given in grid according to the given address mode u...
Definition GridSample3DCPU.cpp:10
__host__ at::Tensor GridSample3D_CUDA(const at::Tensor &input, const at::Tensor &grid, const std::string &addressModeX, const std::string &addressModeY, const std::string &addressModeZ, c10::optional< at::Tensor > out)
An implementation of reg23::GridSample3D_CPU that uses CUDA parallelisation.
Vec< TextureAddressMode, DIMENSIONALITY > StringsToAddressModes(const std::array< std::string_view, DIMENSIONALITY > &strings)
Definition Texture.h:44
Definition GridSample3DCPU.cpp:6
Definition GridSample3D.h:56
at::Tensor flatOutput
Definition GridSample3D.h:58
texture_t inputTexture
Definition GridSample3D.h:57
Definition GridSample3D.h:46
typename texture_t::AddressModeType AddressModeType
Definition GridSample3D.h:54
typename texture_t::FloatType FloatType
Definition GridSample3D.h:51
typename texture_t::IntType IntType
Definition GridSample3D.h:50
typename texture_t::VectorType VectorType
Definition GridSample3D.h:53
static __host__ CommonData Common(const at::Tensor &input, const at::Tensor &grid, const std::string &addressModeX, const std::string &addressModeY, const std::string &addressModeZ, at::DeviceType device, c10::optional< at::Tensor > out)
Definition GridSample3D.h:61
typename texture_t::SizeType SizeType
Definition GridSample3D.h:52