Loading [MathJax]/extensions/TeX/AMSsymbols.js
Open3D (C++ API)  0.14.1
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
TorchHelper.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // The MIT License (MIT)
5 //
6 // Copyright (c) 2018-2021 www.open3d.org
7 //
8 // Permission is hereby granted, free of charge, to any person obtaining a copy
9 // of this software and associated documentation files (the "Software"), to deal
10 // in the Software without restriction, including without limitation the rights
11 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 // copies of the Software, and to permit persons to whom the Software is
13 // furnished to do so, subject to the following conditions:
14 //
15 // The above copyright notice and this permission notice shall be included in
16 // all copies or substantial portions of the Software.
17 //
18 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24 // IN THE SOFTWARE.
25 // ----------------------------------------------------------------------------
26 
27 #pragma once
28 #include <torch/script.h>
29 
30 #include <sstream>
31 #include <type_traits>
32 
34 
35 // Macros for checking tensor properties
// Checks that tensor expression `x` is a CUDA tensor, failing via TORCH_CHECK
// otherwise. The argument is parenthesized so that arbitrary expressions
// (e.g. `cond ? a : b`) bind correctly to the member call; the trailing `;`
// keeps the body a complete statement regardless of how TORCH_CHECK expands.
#define CHECK_CUDA(x)                                            \
    do {                                                         \
        TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor"); \
    } while (0)
40 
// Checks that tensor expression `x` is contiguous, failing via TORCH_CHECK
// otherwise. The argument is parenthesized so that compound expressions bind
// correctly to the member call.
#define CHECK_CONTIGUOUS(x)                                          \
    do {                                                             \
        TORCH_CHECK((x).is_contiguous(), #x " must be contiguous"); \
    } while (0)
45 
// Checks that tensor expression `x` has dtype `torch::<type>` (e.g.
// CHECK_TYPE(t, kFloat32)), failing via TORCH_CHECK otherwise. `type` is
// pasted after `torch::`, so it must be an unqualified torch dtype name.
#define CHECK_TYPE(x, type)                                             \
    do {                                                                \
        TORCH_CHECK((x).dtype() == torch::type,                         \
                    #x " must have type " #type);                       \
    } while (0)
50 
// Checks that all given tensors have the same device type (see
// SameDeviceType). On failure the TORCH_CHECK message names the argument
// expressions and, via TensorInfoStr, lists the sizes/type/device of each
// tensor. The condition is passed to TORCH_CHECK directly rather than wrapping
// a TORCH_CHECK(false, ...) in an if.
#define CHECK_SAME_DEVICE_TYPE(...)                                  \
    do {                                                             \
        TORCH_CHECK(SameDeviceType({__VA_ARGS__}),                   \
                    #__VA_ARGS__                                     \
                    " must all have the same device type but got " + \
                            TensorInfoStr({__VA_ARGS__}));           \
    } while (0)
61 
// Checks that all given tensors have the same dtype (see SameDtype). On
// failure the TORCH_CHECK message names the argument expressions and, via
// TensorInfoStr, lists the sizes/type/device of each tensor. The condition is
// passed to TORCH_CHECK directly rather than wrapping a TORCH_CHECK(false, ...)
// in an if.
#define CHECK_SAME_DTYPE(...)                                   \
    do {                                                        \
        TORCH_CHECK(SameDtype({__VA_ARGS__}),                   \
                    #__VA_ARGS__                                \
                    " must all have the same dtype but got " +  \
                            TensorInfoStr({__VA_ARGS__}));      \
    } while (0)
71 
72 // Conversion from standard types to torch types
74 template <class T>
76  TORCH_CHECK(false, "Unsupported type");
77 }
78 template <>
80  return torch::kUInt8;
81 }
82 template <>
84  return torch::kInt8;
85 }
86 template <>
88  return torch::kInt16;
89 }
90 template <>
92  return torch::kInt32;
93 }
94 template <>
96  return torch::kInt64;
97 }
98 template <>
100  return torch::kFloat32;
101 }
102 template <>
104  return torch::kFloat64;
105 }
106 
107 // convenience function for comparing standard types with torch types
108 template <class T, class TDtype>
109 inline bool CompareTorchDtype(const TDtype& t) {
110  return ToTorchDtype<T>() == t;
111 }
112 
113 // convenience function to check if all tensors have the same device type
114 inline bool SameDeviceType(std::initializer_list<torch::Tensor> tensors) {
115  if (tensors.size()) {
116  auto device_type = tensors.begin()->device().type();
117  for (auto t : tensors) {
118  if (device_type != t.device().type()) {
119  return false;
120  }
121  }
122  }
123  return true;
124 }
125 
126 // convenience function to check if all tensors have the same dtype
127 inline bool SameDtype(std::initializer_list<torch::Tensor> tensors) {
128  if (tensors.size()) {
129  auto dtype = tensors.begin()->dtype();
130  for (auto t : tensors) {
131  if (dtype != t.dtype()) {
132  return false;
133  }
134  }
135  }
136  return true;
137 }
138 
139 inline std::string TensorInfoStr(std::initializer_list<torch::Tensor> tensors) {
140  std::stringstream sstr;
141  size_t count = 0;
142  for (const auto t : tensors) {
143  sstr << t.sizes() << " " << t.toString() << " " << t.device();
144  ++count;
145  if (count < tensors.size()) sstr << ", ";
146  }
147  return sstr.str();
148 }
149 
150 // convenience function for creating a tensor for temp memory
151 inline torch::Tensor CreateTempTensor(const int64_t size,
152  const torch::Device& device,
153  void** ptr = nullptr) {
154  torch::Tensor tensor = torch::empty(
155  {size}, torch::dtype(ToTorchDtype<uint8_t>()).device(device));
156  if (ptr) {
157  *ptr = tensor.data_ptr<uint8_t>();
158  }
159  return tensor;
160 }
161 
162 inline std::vector<open3d::ml::op_util::DimValue> GetShapeVector(
163  torch::Tensor tensor) {
164  using namespace open3d::ml::op_util;
165 
166  std::vector<DimValue> shape;
167  const int rank = tensor.dim();
168  for (int i = 0; i < rank; ++i) {
169  shape.push_back(tensor.size(i));
170  }
171  return shape;
172 }
173 
175  class TDimX,
176  class... TArgs>
177 std::tuple<bool, std::string> CheckShape(torch::Tensor tensor,
178  TDimX&& dimex,
179  TArgs&&... args) {
180  return open3d::ml::op_util::CheckShape<Opt>(GetShapeVector(tensor),
181  std::forward<TDimX>(dimex),
182  std::forward<TArgs>(args)...);
183 }
184 
185 //
186 // Macros for checking the shape of Tensors.
187 // Usage:
188 // {
189 // using namespace open3d::ml::op_util;
190 // Dim w("w");
191 // Dim h("h");
192 // CHECK_SHAPE(tensor1, 10, w, h); // checks if the first dim is 10
193 // // and assigns w and h based on
194 // // the shape of tensor1
195 //
196 // CHECK_SHAPE(tensor2, 10, 20, h); // this checks if the last dim
197 // // of tensor2 matches the last dim
198 // // of tensor1. The first two dims
199 // // must match 10, 20.
200 // }
201 //
202 //
203 // See "../ShapeChecking.h" for more info and limitations.
204 //
// Shared implementation of the CHECK_SHAPE* macros below; not intended to be
// used directly. `check_fn` is the CheckShape overload to call (optionally
// with an explicit CSOpt template argument). `tensor_str` is the tensor
// expression stringified by the public macro, so the error message shows the
// caller's original spelling rather than a macro-expanded form. On failure,
// TORCH_CHECK reports "invalid shape for '<tensor>', " followed by the error
// string returned by CheckShape.
#define CHECK_SHAPE_IMPL_(check_fn, tensor, tensor_str, ...)                \
    do {                                                                     \
        bool cs_success_ = false;                                            \
        std::string cs_errstr_;                                              \
        std::tie(cs_success_, cs_errstr_) = check_fn(tensor, __VA_ARGS__);  \
        TORCH_CHECK(cs_success_,                                             \
                    "invalid shape for '" tensor_str "', " + cs_errstr_);    \
    } while (0)

// Checks the full shape of `tensor` against the given dims (see usage above).
#define CHECK_SHAPE(tensor, ...) \
    CHECK_SHAPE_IMPL_(CheckShape, tensor, #tensor, __VA_ARGS__)

// Like CHECK_SHAPE, but with CSOpt::COMBINE_FIRST_DIMS.
#define CHECK_SHAPE_COMBINE_FIRST_DIMS(tensor, ...)                         \
    CHECK_SHAPE_IMPL_(CheckShape<CSOpt::COMBINE_FIRST_DIMS>, tensor, #tensor, \
                      __VA_ARGS__)

// Like CHECK_SHAPE, but with CSOpt::IGNORE_FIRST_DIMS.
#define CHECK_SHAPE_IGNORE_FIRST_DIMS(tensor, ...)                         \
    CHECK_SHAPE_IMPL_(CheckShape<CSOpt::IGNORE_FIRST_DIMS>, tensor, #tensor, \
                      __VA_ARGS__)

// Like CHECK_SHAPE, but with CSOpt::COMBINE_LAST_DIMS.
#define CHECK_SHAPE_COMBINE_LAST_DIMS(tensor, ...)                         \
    CHECK_SHAPE_IMPL_(CheckShape<CSOpt::COMBINE_LAST_DIMS>, tensor, #tensor, \
                      __VA_ARGS__)

// Like CHECK_SHAPE, but with CSOpt::IGNORE_LAST_DIMS.
#define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor, ...)                         \
    CHECK_SHAPE_IMPL_(CheckShape<CSOpt::IGNORE_LAST_DIMS>, tensor, #tensor, \
                      __VA_ARGS__)
std::vector< open3d::ml::op_util::DimValue > GetShapeVector(torch::Tensor tensor)
Definition: TorchHelper.h:162
TorchDtype_t ToTorchDtype< int16_t >()
Definition: TorchHelper.h:87
TorchDtype_t ToTorchDtype< uint8_t >()
Definition: TorchHelper.h:79
TorchDtype_t ToTorchDtype< float >()
Definition: TorchHelper.h:99
TorchDtype_t ToTorchDtype< int64_t >()
Definition: TorchHelper.h:95
torch::Tensor CreateTempTensor(const int64_t size, const torch::Device &device, void **ptr=nullptr)
Definition: TorchHelper.h:151
CSOpt
Check shape options.
Definition: ShapeChecking.h:424
bool SameDeviceType(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:114
TorchDtype_t ToTorchDtype< int8_t >()
Definition: TorchHelper.h:83
int count
Definition: FilePCD.cpp:61
std::string TensorInfoStr(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:139
char type
Definition: FilePCD.cpp:60
std::remove_const< decltype(torch::kInt32)>::type TorchDtype_t
Definition: TorchHelper.h:73
bool CompareTorchDtype(const TDtype &t)
Definition: TorchHelper.h:109
TorchDtype_t ToTorchDtype< double >()
Definition: TorchHelper.h:103
std::tuple< bool, std::string > CheckShape(const std::vector< DimValue > &shape, TDimX &&dimex, TArgs &&... args)
Definition: ShapeChecking.h:593
TorchDtype_t ToTorchDtype()
Definition: TorchHelper.h:75
TorchDtype_t ToTorchDtype< int32_t >()
Definition: TorchHelper.h:91
int size
Definition: FilePCD.cpp:59
Definition: ShapeChecking.h:35
bool SameDtype(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:127