Open3D (C++ API)  0.19.0
TorchHelper.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 // https://stackoverflow.com/q/77034039 : False Alarm warnings from PyTorch
10 // headers
11 #pragma GCC diagnostic ignored "-Warray-bounds"
12 #pragma GCC diagnostic ignored "-Wstringop-overflow"
13 #include <torch/script.h>
14 
15 #include <sstream>
16 #include <type_traits>
17 
19 
// Macros for checking tensor properties

// Raises a TORCH_CHECK error unless `x` is a CUDA tensor. The argument is
// parenthesized so that expressions such as `*ptr` are checked as a whole;
// the error message uses the argument's source spelling.
#define CHECK_CUDA(x)                                             \
    do {                                                          \
        TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor"); \
    } while (0)
25 
// Raises a TORCH_CHECK error unless `x` is contiguous. The argument is
// parenthesized so that expressions such as `*ptr` are checked as a whole.
#define CHECK_CONTIGUOUS(x)                                         \
    do {                                                            \
        TORCH_CHECK((x).is_contiguous(), #x " must be contiguous"); \
    } while (0)
30 
// Raises a TORCH_CHECK error unless `x` has dtype `torch::type`. `type` is the
// name of a dtype constant without the `torch::` prefix, e.g.
// CHECK_TYPE(t, kFloat32). The tensor argument is parenthesized so that
// expressions such as `*ptr` are checked as a whole.
#define CHECK_TYPE(x, type)                                                 \
    do {                                                                    \
        TORCH_CHECK((x).dtype() == torch::type, #x " must have type " #type); \
    } while (0)
35 
// Raises a TORCH_CHECK error unless all tensor arguments have the same device
// type (as decided by SameDeviceType). The message lists the argument names
// followed by TensorInfoStr() of every tensor. The TORCH_CHECK call is
// terminated with ';' as in its documented usage instead of relying on it
// expanding to a braced if-statement.
#define CHECK_SAME_DEVICE_TYPE(...)                                    \
    do {                                                               \
        if (!SameDeviceType({__VA_ARGS__})) {                          \
            TORCH_CHECK(false,                                         \
                        #__VA_ARGS__                                   \
                        " must all have the same device type but got " \
                                + TensorInfoStr({__VA_ARGS__}));       \
        }                                                              \
    } while (0)
46 
// Raises a TORCH_CHECK error unless all tensor arguments have the same dtype
// (as decided by SameDtype). The message lists the argument names followed by
// TensorInfoStr() of every tensor. The TORCH_CHECK call is terminated with ';'
// as in its documented usage instead of relying on it expanding to a braced
// if-statement.
#define CHECK_SAME_DTYPE(...)                                    \
    do {                                                         \
        if (!SameDtype({__VA_ARGS__})) {                         \
            TORCH_CHECK(false,                                   \
                        #__VA_ARGS__                             \
                        " must all have the same dtype but got " \
                                + TensorInfoStr({__VA_ARGS__})); \
        }                                                        \
    } while (0)
56 
57 // Conversion from standard types to torch types
58 typedef std::remove_const<decltype(torch::kInt32)>::type TorchDtype_t;
59 template <class T>
61  TORCH_CHECK(false, "Unsupported type");
62 }
63 template <>
65  return torch::kUInt8;
66 }
67 template <>
69  return torch::kInt8;
70 }
71 template <>
73  return torch::kInt16;
74 }
75 template <>
77  return torch::kInt32;
78 }
79 template <>
81  return torch::kInt64;
82 }
83 template <>
85  return torch::kFloat32;
86 }
87 template <>
89  return torch::kFloat64;
90 }
91 
92 // convenience function for comparing standard types with torch types
93 template <class T, class TDtype>
94 inline bool CompareTorchDtype(const TDtype& t) {
95  return ToTorchDtype<T>() == t;
96 }
97 
98 // convenience function to check if all tensors have the same device type
99 inline bool SameDeviceType(std::initializer_list<torch::Tensor> tensors) {
100  if (tensors.size()) {
101  auto device_type = tensors.begin()->device().type();
102  for (const auto& t : tensors) {
103  if (device_type != t.device().type()) {
104  return false;
105  }
106  }
107  }
108  return true;
109 }
110 
111 // convenience function to check if all tensors have the same dtype
112 inline bool SameDtype(std::initializer_list<torch::Tensor> tensors) {
113  if (tensors.size()) {
114  auto dtype = tensors.begin()->dtype();
115  for (const auto& t : tensors) {
116  if (dtype != t.dtype()) {
117  return false;
118  }
119  }
120  }
121  return true;
122 }
123 
124 inline std::string TensorInfoStr(std::initializer_list<torch::Tensor> tensors) {
125  std::stringstream sstr;
126  size_t count = 0;
127  for (const auto& t : tensors) {
128  sstr << t.sizes() << " " << t.toString() << " " << t.device();
129  ++count;
130  if (count < tensors.size()) sstr << ", ";
131  }
132  return sstr.str();
133 }
134 
135 // convenience function for creating a tensor for temp memory
136 inline torch::Tensor CreateTempTensor(const int64_t size,
137  const torch::Device& device,
138  void** ptr = nullptr) {
139  torch::Tensor tensor = torch::empty(
140  {size}, torch::dtype(ToTorchDtype<uint8_t>()).device(device));
141  if (ptr) {
142  *ptr = tensor.data_ptr<uint8_t>();
143  }
144  return tensor;
145 }
146 
147 inline std::vector<open3d::ml::op_util::DimValue> GetShapeVector(
148  torch::Tensor tensor) {
149  using namespace open3d::ml::op_util;
150 
151  std::vector<DimValue> shape;
152  const int rank = tensor.dim();
153  for (int i = 0; i < rank; ++i) {
154  shape.push_back(tensor.size(i));
155  }
156  return shape;
157 }
158 
160  class TDimX,
161  class... TArgs>
162 std::tuple<bool, std::string> CheckShape(torch::Tensor tensor,
163  TDimX&& dimex,
164  TArgs&&... args) {
165  return open3d::ml::op_util::CheckShape<Opt>(GetShapeVector(tensor),
166  std::forward<TDimX>(dimex),
167  std::forward<TArgs>(args)...);
168 }
169 
//
// Macros for checking the shape of Tensors.
// Usage:
// {
//     using namespace open3d::ml::op_util;
//     Dim w("w");
//     Dim h("h");
//     CHECK_SHAPE(tensor1, 10, w, h);   // checks if the first dim is 10
//                                       // and assigns w and h based on
//                                       // the shape of tensor1
//
//     CHECK_SHAPE(tensor2, 10, 20, h);  // checks if the last dim of
//                                       // tensor2 matches the last dim
//                                       // of tensor1. The first two dims
//                                       // must match 10, 20.
// }
//
//
// See "../ShapeChecking.h" for more info and limitations.
//
// Checks the shape of `tensor` against the given dimension arguments by
// calling CheckShape(tensor, ...) and raises a TORCH_CHECK error carrying the
// tensor's source spelling and CheckShape's error string on mismatch.
// The success flag is initialized and the TORCH_CHECK call is terminated with
// ';' as in its documented usage.
#define CHECK_SHAPE(tensor, ...)                                             \
    do {                                                                     \
        bool cs_success_ = false;                                            \
        std::string cs_errstr_;                                              \
        std::tie(cs_success_, cs_errstr_) = CheckShape(tensor, __VA_ARGS__); \
        TORCH_CHECK(cs_success_,                                             \
                    "invalid shape for '" #tensor "', " + cs_errstr_);       \
    } while (0)
198 
// Like CHECK_SHAPE, but calls CheckShape with CSOpt::COMBINE_FIRST_DIMS
// (see ShapeChecking.h for the option's semantics). The success flag is
// initialized and the TORCH_CHECK call is terminated with ';'.
#define CHECK_SHAPE_COMBINE_FIRST_DIMS(tensor, ...)                       \
    do {                                                                  \
        bool cs_success_ = false;                                         \
        std::string cs_errstr_;                                           \
        std::tie(cs_success_, cs_errstr_) =                               \
                CheckShape<CSOpt::COMBINE_FIRST_DIMS>(tensor, __VA_ARGS__); \
        TORCH_CHECK(cs_success_,                                          \
                    "invalid shape for '" #tensor "', " + cs_errstr_);    \
    } while (0)
208 
// Like CHECK_SHAPE, but calls CheckShape with CSOpt::IGNORE_FIRST_DIMS
// (see ShapeChecking.h for the option's semantics). The success flag is
// initialized and the TORCH_CHECK call is terminated with ';'.
#define CHECK_SHAPE_IGNORE_FIRST_DIMS(tensor, ...)                       \
    do {                                                                 \
        bool cs_success_ = false;                                        \
        std::string cs_errstr_;                                          \
        std::tie(cs_success_, cs_errstr_) =                              \
                CheckShape<CSOpt::IGNORE_FIRST_DIMS>(tensor, __VA_ARGS__); \
        TORCH_CHECK(cs_success_,                                         \
                    "invalid shape for '" #tensor "', " + cs_errstr_);   \
    } while (0)
218 
// Like CHECK_SHAPE, but calls CheckShape with CSOpt::COMBINE_LAST_DIMS
// (see ShapeChecking.h for the option's semantics). The success flag is
// initialized and the TORCH_CHECK call is terminated with ';'.
#define CHECK_SHAPE_COMBINE_LAST_DIMS(tensor, ...)                       \
    do {                                                                 \
        bool cs_success_ = false;                                        \
        std::string cs_errstr_;                                          \
        std::tie(cs_success_, cs_errstr_) =                              \
                CheckShape<CSOpt::COMBINE_LAST_DIMS>(tensor, __VA_ARGS__); \
        TORCH_CHECK(cs_success_,                                         \
                    "invalid shape for '" #tensor "', " + cs_errstr_);   \
    } while (0)
228 
// Like CHECK_SHAPE, but calls CheckShape with CSOpt::IGNORE_LAST_DIMS
// (see ShapeChecking.h for the option's semantics). The success flag is
// initialized and the TORCH_CHECK call is terminated with ';'.
#define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor, ...)                       \
    do {                                                                \
        bool cs_success_ = false;                                       \
        std::string cs_errstr_;                                         \
        std::tie(cs_success_, cs_errstr_) =                             \
                CheckShape<CSOpt::IGNORE_LAST_DIMS>(tensor, __VA_ARGS__); \
        TORCH_CHECK(cs_success_,                                        \
                    "invalid shape for '" #tensor "', " + cs_errstr_);  \
    } while (0)
double t
Definition: SurfaceReconstructionPoisson.cpp:172
std::tuple< bool, std::string > CheckShape(torch::Tensor tensor, TDimX &&dimex, TArgs &&... args)
Definition: TorchHelper.h:162
TorchDtype_t ToTorchDtype< int64_t >()
Definition: TorchHelper.h:80
TorchDtype_t ToTorchDtype< uint8_t >()
Definition: TorchHelper.h:64
std::string TensorInfoStr(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:124
std::vector< open3d::ml::op_util::DimValue > GetShapeVector(torch::Tensor tensor)
Definition: TorchHelper.h:147
TorchDtype_t ToTorchDtype< int16_t >()
Definition: TorchHelper.h:72
TorchDtype_t ToTorchDtype< int8_t >()
Definition: TorchHelper.h:68
TorchDtype_t ToTorchDtype< double >()
Definition: TorchHelper.h:88
bool SameDtype(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:112
bool SameDeviceType(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:99
std::remove_const< decltype(torch::kInt32)>::type TorchDtype_t
Definition: TorchHelper.h:58
TorchDtype_t ToTorchDtype()
Definition: TorchHelper.h:60
torch::Tensor CreateTempTensor(const int64_t size, const torch::Device &device, void **ptr=nullptr)
Definition: TorchHelper.h:136
TorchDtype_t ToTorchDtype< int32_t >()
Definition: TorchHelper.h:76
bool CompareTorchDtype(const TDtype &t)
Definition: TorchHelper.h:94
TorchDtype_t ToTorchDtype< float >()
Definition: TorchHelper.h:84
int size
Definition: FilePCD.cpp:40
int count
Definition: FilePCD.cpp:42
char type
Definition: FilePCD.cpp:41
Definition: ShapeChecking.h:16
CSOpt
Check shape options.
Definition: ShapeChecking.h:405