Grangeat-based 2D/3D image registration
Loading...
Searching...
No Matches
ResampleSinogram3D.h
Go to the documentation of this file.
1#pragma once
2
3#include "Common.h"
4
5namespace reg23 {
6
/**
 * @brief Resample the given 3D sinogram at locations corresponding to the given 2D sinogram grid (phiValues, rValues).
 *
 * NOTE(review): the constraints below are the ones enforced by ResampleSinogram3D::Common; this implementation
 * presumably calls it with at::DeviceType::CPU — confirm in ResampleSinogram3DCPU.cpp.
 *
 * @param sinogram3d The 3D sinogram to sample from (shape/layout depends on sinogramType — TODO confirm).
 * @param sinogramType "classic" or "healpix"; any other string triggers a warning and is treated as "classic".
 * @param rSpacing Presumably the spacing of the r axis of the 3D sinogram — not used in the visible code; confirm.
 * @param projectionMatrix A (4, 4) float tensor.
 * @param phiValues Float tensor of 2D-sinogram phi values; must have the same size as rValues.
 * @param rValues Float tensor of 2D-sinogram r values; must have the same size as phiValues.
 * @param out Optional float tensor with the same size as phiValues that receives the result.
 * @return Presumably a float tensor of resampled values laid out like phiValues (or `out` itself) — confirm.
 */
33at::Tensor ResampleSinogram3D_CPU(const at::Tensor &sinogram3d, const std::string &sinogramType, double rSpacing,
34 const at::Tensor &projectionMatrix, const at::Tensor &phiValues,
35 const at::Tensor &rValues, c10::optional<at::Tensor> out);
36
/**
 * @brief An implementation of reg23::ResampleSinogram3D_CPU that uses CUDA parallelisation.
 *
 * Takes the same arguments as ResampleSinogram3D_CPU. NOTE(review): the input tensors are presumably required to be
 * on a CUDA device (Common() asserts the device type of phiValues, rValues and out) — confirm.
 */
41__host__ at::Tensor ResampleSinogram3D_CUDA(const at::Tensor &sinogram3d, const std::string &sinogramType,
42 double rSpacing, const at::Tensor &projectionMatrix,
43 const at::Tensor &phiValues, const at::Tensor &rValues,
44 c10::optional<at::Tensor> out);
45
53 const std::string &sinogramType, double rSpacing,
54 const at::Tensor &projectionMatrix, const at::Tensor &phiValues,
55 const at::Tensor &rValues, c10::optional<at::Tensor> out);
56
57
63
/// Layout of the 3D sinogram being resampled (stored in CommonData::sinogramType).
/// NOTE(review): presumably chosen in Common() from the strings "classic" / "healpix"; the assignment lines are not
/// visible in this listing — confirm.
64 enum class SinogramType { CLASSIC, HEALPIX };
65
71
77
/**
 * @brief Validate the resampling inputs and build the CommonData used by the resampling implementations.
 *
 * @param sinogramType "classic" or "healpix"; any other value emits a TORCH_WARN and falls back to "classic".
 * @param projectionMatrix Must be a (4, 4) float tensor.
 * @param phiValues Must be a float tensor of the same size as rValues, on `device`.
 * @param rValues Must be a float tensor of the same size as phiValues, on `device`.
 * @param device The device type all tensors are expected to live on.
 * @param out Optional float tensor of the same size as phiValues, on `device`.
 * @return CommonData holding:
 *   - geometry.originProjection: the 2D projection of the 3D origin (x/w, y/w of the projected homogeneous point);
 *   - geometry.squareRadius: 0.25 * |originProjection|^2, i.e. the squared radius of the circle whose diameter is the
 *     segment from the 2D origin to originProjection;
 *   - geometry.projectionMatrixTranspose: projectionMatrix transposed, as a Vec of Vecs;
 *   - flatOutput: a 1D view of `out`, or a new zero tensor with phiValues.numel() elements on `device`.
 *   NOTE(review): the sinogramType member is presumably set on a line not visible in this listing — confirm.
 */
78 __host__ static CommonData Common(const std::string &sinogramType, const at::Tensor &projectionMatrix,
79 const at::Tensor &phiValues, const at::Tensor &rValues, at::DeviceType device,
80 c10::optional<at::Tensor> out) {
81 // projectionMatrix should be of size (4, 4), contain floats and be on the chosen device
// NOTE(review): the device check for projectionMatrix is not visible in this listing — confirm it exists.
82 TORCH_CHECK(projectionMatrix.sizes() == at::IntArrayRef({4, 4}))
83 TORCH_CHECK(projectionMatrix.dtype() == at::kFloat)
85 // phiValues and rValues should be of the same size, contain floats and be on the chosen device
86 TORCH_CHECK(phiValues.sizes() == rValues.sizes())
87 TORCH_CHECK(phiValues.dtype() == at::kFloat)
88 TORCH_CHECK(rValues.dtype() == at::kFloat)
89 TORCH_INTERNAL_ASSERT(phiValues.device().type() == device)
90 TORCH_INTERNAL_ASSERT(rValues.device().type() == device);
// A caller-supplied output must match the 2D sinogram grid exactly, since it is written through a flat view below.
91 if (out) {
92 TORCH_CHECK(out.value().sizes() == phiValues.sizes())
93 TORCH_CHECK(out.value().dtype() == at::kFloat)
94 TORCH_INTERNAL_ASSERT(out.value().device().type() == device)
95 }
96
// Parse the sinogram type string; unknown values warn rather than fail and are treated as 'classic'.
98 if (sinogramType == "healpix") {
100 } else if (sinogramType != "classic") {
101 TORCH_WARN("Invalid sinogram type string given. Valid values are: 'classic', 'healpix'. Assuming default "
102 "value: 'classic'.")
103 }
104
// Project the homogeneous 3D origin (0, 0, 0, 1) as a column vector. NOTE(review): the first line of this statement is
// not visible; it presumably computes projectionMatrix @ [0, 0, 0, 1]^T into originProjectionHomogeneous — confirm.
106 torch::tensor(
107 {{0.f, 0.f, 0.f, 1.f}},
108 projectionMatrix.options()).t());
109
110 CommonData ret{};
// Dehomogenise: divide the x and y components by the w component (index 3).
112 ret.geometry.originProjection = Vec<float, 2>{originProjectionHomogeneous[0].item().toFloat(),
113 originProjectionHomogeneous[1].item().toFloat()} /
114 originProjectionHomogeneous[3].item().toFloat();
// (|originProjection| / 2)^2: squared radius of the circle centred at originProjection / 2 that passes through the
// 2D origin.
115 ret.geometry.squareRadius = .25f * ret.geometry.originProjection.Apply<float>(&Square<float>).Sum();
116 ret.geometry.projectionMatrixTranspose = Vec<Vec<float, 4>, 4>::FromTensor2D(projectionMatrix.t());
// NOTE(review): view({-1}) throws if `out` is not view-compatible (e.g. some non-contiguous tensors); confirm callers
// only pass contiguous outputs. The zeros() fallback uses the global default dtype (float32 unless changed), not an
// explicit at::kFloat.
117 ret.flatOutput = out
118 ? out.value().view({-1})
119 : torch::zeros(at::IntArrayRef({phiValues.numel()}), at::TensorOptions{device});
120 return ret;
121 }
122
/**
 * @brief Sample the 3D sinogram at the point corresponding to one 2D sinogram location (phi, r).
 *
 * The 2D line with parameters (phi, r) is mapped through geometry.projectionMatrixTranspose to a 3D plane.
 * The resulting point is converted to spherical coordinates (r, theta, phi) and passed to `sinogram.Sample`.
 * The sampled value is negated when the 2D point r * (cos phi, sin phi) lies strictly inside the circle described by
 * geometry.originProjection / geometry.squareRadius.
 *
 * @tparam sinogram_t Any type providing `float Sample(const Vec<float, 3> &) const` (as called below) — presumably a
 *         classic or HEALPix sinogram sampler; confirm.
 * @param sinogram The 3D sinogram to sample.
 * @param geometry Geometry precomputed by Common().
 * @param phi 2D sinogram angle.
 * @param r 2D sinogram signed distance.
 * @return The (possibly negated) sampled value.
 */
123 template <typename sinogram_t> __host__ __device__ static float ResamplePlane(
124 const sinogram_t &sinogram, const ConstantGeometry &geometry, float phi, float r) {
125 const float cp = std::cos(phi);
126 const float sp = std::sin(phi);
// P^T * (cos phi, sin phi, 0, -r): pulls the 2D line back to a 3D plane in homogeneous coordinates.
127 const Vec<float, 4> intermediate = MatMul(geometry.projectionMatrixTranspose, Vec<float, 4>{cp, sp, 0.f, -r});
130 &Square<float>).Sum();
// NOTE(review): posCartesian is computed on lines not visible in this listing; it is presumably the plane's point
// closest to the 3D origin, derived from intermediate.XYZ() / intermediate.W() — confirm.
131
// Convert posCartesian to spherical coordinates stored in the order (r, theta, phi):
// phi = azimuth in the XY plane, theta = elevation above the XY plane, r = Euclidean distance from the origin.
133 rThetaPhi[2] = std::atan2(posCartesian.Y(), posCartesian.X());
// NOTE: despite its name, magXY is the *squared* magnitude of the XY component.
134 const float magXY = posCartesian.X() * posCartesian.X() + posCartesian.Y() * posCartesian.Y();
135 rThetaPhi[1] = std::atan2(posCartesian.Z(), std::sqrt(magXY));
136 rThetaPhi[0] = std::sqrt(magXY + posCartesian.Z() * posCartesian.Z());
// NOTE(review): a line not visible here presumably normalises rThetaPhi with UnflipSphericalCoordinate (referenced by
// this file) so theta and phi lie in [-pi/2, pi/2] — confirm.
138
139 float ret = sinogram.Sample(rThetaPhi);
140
// Negate when r * (cos phi, sin phi) is inside the circle centred at originProjection / 2 with squared radius
// squareRadius. NOTE(review): presumably a sign correction for the orientation of the 3D plane normal relative to
// the 2D line — confirm against the derivation.
141 if ((r * Vec<float, 2>{cp, sp} - .5f * geometry.originProjection).Apply<float>(&Square<float>).Sum() < geometry.
142 squareRadius) {
143 ret *= -1.f;
144 }
145 return ret;
146 }
147};
148
149} // namespace reg23
General tools and structs.
#define __host__
Definition Global.h:17
#define __device__
Definition Global.h:22
A simple vector class derived from std::array<T, N>, providing overrides for all useful operators.
Definition Vec.h:21
__host__ __device__ constexpr Vec< newT, N > Apply(const std::function< newT(T)> &f) const
Map all elements with a common std::function mapping function.
Definition Vec.h:285
__host__ __device__ constexpr const T & W() const
Get a constant reference to the fourth element.
Definition Vec.h:459
__host__ __device__ constexpr Vec< T, 3 > XYZ() const
Construct a Vec from the first three elements.
Definition Vec.h:480
__host__ __device__ Vec< T, 3 > UnflipSphericalCoordinate(const Vec< T, 3 > &coordSph)
'Unflips' the given spherical coordinates so that theta and phi both lie between -pi/2 and pi/2
Definition Common.h:108
__host__ __device__ constexpr Vec< T, R > MatMul(const Vec< Vec< T, R >, C > &lhs, const Vec< T, C > &rhs)
Matrix-vector multiplication of the Vec struct.
Definition Vec.h:894
at::Tensor ResampleSinogram3D_CPU(const at::Tensor &sinogram3d, const std::string &sinogramType, const double rSpacing, const at::Tensor &projectionMatrix, const at::Tensor &phiValues, const at::Tensor &rValues, c10::optional< at::Tensor > out)
Resample the given 3D sinogram at locations corresponding to the given 2D sinogram grid (phiValues,...
Definition ResampleSinogram3DCPU.cpp:12
__host__ at::Tensor ResampleSinogram3DCUDATexture(int64_t sinogram3dTextureHandle, int64_t sinogramWidth, int64_t sinogramHeight, int64_t sinogramDepth, const std::string &sinogramType, double rSpacing, const at::Tensor &projectionMatrix, const at::Tensor &phiValues, const at::Tensor &rValues, c10::optional< at::Tensor > out)
An implementation of reg23::ResampleSinogram3D_CUDA that takes a handle to a pre-allocated CUDA textu...
__host__ at::Tensor ResampleSinogram3D_CUDA(const at::Tensor &sinogram3d, const std::string &sinogramType, double rSpacing, const at::Tensor &projectionMatrix, const at::Tensor &phiValues, const at::Tensor &rValues, c10::optional< at::Tensor > out)
An implementation of reg23::ResampleSinogram3D_CPU that uses CUDA parallelisation.
Vec< TextureAddressMode, DIMENSIONALITY > StringsToAddressModes(const std::array< std::string_view, DIMENSIONALITY > &strings)
Definition Texture.h:44
Definition GridSample3DCPU.cpp:6
Definition ResampleSinogram3D.h:72
ConstantGeometry geometry
Definition ResampleSinogram3D.h:74
SinogramType sinogramType
Definition ResampleSinogram3D.h:73
at::Tensor flatOutput
Definition ResampleSinogram3D.h:75
Definition ResampleSinogram3D.h:66
float squareRadius
Definition ResampleSinogram3D.h:68
Vec< Vec< float, 4 >, 4 > projectionMatrixTranspose
Definition ResampleSinogram3D.h:69
Vec< float, 2 > originProjection
Definition ResampleSinogram3D.h:67
Definition ResampleSinogram3D.h:62
SinogramType
Definition ResampleSinogram3D.h:64
static __host__ CommonData Common(const std::string &sinogramType, const at::Tensor &projectionMatrix, const at::Tensor &phiValues, const at::Tensor &rValues, at::DeviceType device, c10::optional< at::Tensor > out)
Definition ResampleSinogram3D.h:78
__host__ static __device__ float ResamplePlane(const sinogram_t &sinogram, const ConstantGeometry &geometry, float phi, float r)
Definition ResampleSinogram3D.h:123