11 #pragma GCC diagnostic ignored "-Warray-bounds"
12 #pragma GCC diagnostic ignored "-Wstringop-overflow"
13 #include <torch/script.h>
16 #include <type_traits>
21 #define CHECK_CUDA(x) \
23 TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") \
26 #define CHECK_CONTIGUOUS(x) \
28 TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") \
31 #define CHECK_TYPE(x, type) \
33 TORCH_CHECK(x.dtype() == torch::type, #x " must have type " #type) \
36 #define CHECK_SAME_DEVICE_TYPE(...) \
38 if (!SameDeviceType({__VA_ARGS__})) { \
42 " must all have the same device type but got " + \
43 TensorInfoStr({__VA_ARGS__})) \
47 #define CHECK_SAME_DTYPE(...) \
49 if (!SameDtype({__VA_ARGS__})) { \
52 " must all have the same dtype but got " + \
53 TensorInfoStr({__VA_ARGS__})) \
61 TORCH_CHECK(
false,
"Unsupported type");
85 return torch::kFloat32;
89 return torch::kFloat64;
93 template <
class T,
class TDtype>
95 return ToTorchDtype<T>() ==
t;
100 if (tensors.size()) {
101 auto device_type = tensors.begin()->device().type();
102 for (
const auto&
t : tensors) {
103 if (device_type !=
t.device().type()) {
112 inline bool SameDtype(std::initializer_list<torch::Tensor> tensors) {
113 if (tensors.size()) {
114 auto dtype = tensors.begin()->dtype();
115 for (
const auto&
t : tensors) {
116 if (dtype !=
t.dtype()) {
124 inline std::string
TensorInfoStr(std::initializer_list<torch::Tensor> tensors) {
125 std::stringstream sstr;
127 for (
const auto&
t : tensors) {
128 sstr <<
t.sizes() <<
" " <<
t.toString() <<
" " <<
t.device();
130 if (
count < tensors.size()) sstr <<
", ";
137 const torch::Device& device,
138 void** ptr =
nullptr) {
139 torch::Tensor tensor = torch::empty(
142 *ptr = tensor.data_ptr<uint8_t>();
148 torch::Tensor tensor) {
151 std::vector<DimValue> shape;
152 const int rank = tensor.dim();
153 for (
int i = 0; i < rank; ++i) {
154 shape.push_back(tensor.size(i));
162 std::tuple<bool, std::string>
CheckShape(torch::Tensor tensor,
165 return open3d::ml::op_util::CheckShape<Opt>(
GetShapeVector(tensor),
166 std::forward<TDimX>(dimex),
167 std::forward<TArgs>(args)...);
190 #define CHECK_SHAPE(tensor, ...) \
193 std::string cs_errstr_; \
194 std::tie(cs_success_, cs_errstr_) = CheckShape(tensor, __VA_ARGS__); \
195 TORCH_CHECK(cs_success_, \
196 "invalid shape for '" #tensor "', " + cs_errstr_) \
199 #define CHECK_SHAPE_COMBINE_FIRST_DIMS(tensor, ...) \
202 std::string cs_errstr_; \
203 std::tie(cs_success_, cs_errstr_) = \
204 CheckShape<CSOpt::COMBINE_FIRST_DIMS>(tensor, __VA_ARGS__); \
205 TORCH_CHECK(cs_success_, \
206 "invalid shape for '" #tensor "', " + cs_errstr_) \
209 #define CHECK_SHAPE_IGNORE_FIRST_DIMS(tensor, ...) \
212 std::string cs_errstr_; \
213 std::tie(cs_success_, cs_errstr_) = \
214 CheckShape<CSOpt::IGNORE_FIRST_DIMS>(tensor, __VA_ARGS__); \
215 TORCH_CHECK(cs_success_, \
216 "invalid shape for '" #tensor "', " + cs_errstr_) \
219 #define CHECK_SHAPE_COMBINE_LAST_DIMS(tensor, ...) \
222 std::string cs_errstr_; \
223 std::tie(cs_success_, cs_errstr_) = \
224 CheckShape<CSOpt::COMBINE_LAST_DIMS>(tensor, __VA_ARGS__); \
225 TORCH_CHECK(cs_success_, \
226 "invalid shape for '" #tensor "', " + cs_errstr_) \
229 #define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor, ...) \
232 std::string cs_errstr_; \
233 std::tie(cs_success_, cs_errstr_) = \
234 CheckShape<CSOpt::IGNORE_LAST_DIMS>(tensor, __VA_ARGS__); \
235 TORCH_CHECK(cs_success_, \
236 "invalid shape for '" #tensor "', " + cs_errstr_) \
double t
Definition: SurfaceReconstructionPoisson.cpp:172
std::tuple< bool, std::string > CheckShape(torch::Tensor tensor, TDimX &&dimex, TArgs &&... args)
Definition: TorchHelper.h:162
TorchDtype_t ToTorchDtype< int64_t >()
Definition: TorchHelper.h:80
TorchDtype_t ToTorchDtype< uint8_t >()
Definition: TorchHelper.h:64
std::string TensorInfoStr(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:124
std::vector< open3d::ml::op_util::DimValue > GetShapeVector(torch::Tensor tensor)
Definition: TorchHelper.h:147
TorchDtype_t ToTorchDtype< int16_t >()
Definition: TorchHelper.h:72
TorchDtype_t ToTorchDtype< int8_t >()
Definition: TorchHelper.h:68
TorchDtype_t ToTorchDtype< double >()
Definition: TorchHelper.h:88
bool SameDtype(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:112
bool SameDeviceType(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:99
std::remove_const< decltype(torch::kInt32)>::type TorchDtype_t
Definition: TorchHelper.h:58
TorchDtype_t ToTorchDtype()
Definition: TorchHelper.h:60
torch::Tensor CreateTempTensor(const int64_t size, const torch::Device &device, void **ptr=nullptr)
Definition: TorchHelper.h:136
TorchDtype_t ToTorchDtype< int32_t >()
Definition: TorchHelper.h:76
bool CompareTorchDtype(const TDtype &t)
Definition: TorchHelper.h:94
TorchDtype_t ToTorchDtype< float >()
Definition: TorchHelper.h:84
Definition: ShapeChecking.h:16
CSOpt
Check shape options.
Definition: ShapeChecking.h:405