Loading [MathJax]/extensions/TeX/AMSsymbols.js
Open3D (C++ API)  0.14.1
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
TensorFlowHelper.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // The MIT License (MIT)
5 //
6 // Copyright (c) 2018-2021 www.open3d.org
7 //
8 // Permission is hereby granted, free of charge, to any person obtaining a copy
9 // of this software and associated documentation files (the "Software"), to deal
10 // in the Software without restriction, including without limitation the rights
11 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 // copies of the Software, and to permit persons to whom the Software is
13 // furnished to do so, subject to the following conditions:
14 //
15 // The above copyright notice and this permission notice shall be included in
16 // all copies or substantial portions of the Software.
17 //
18 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24 // IN THE SOFTWARE.
25 // ----------------------------------------------------------------------------
26 
27 #pragma once
28 #include <tensorflow/core/framework/op_kernel.h>
29 #include <tensorflow/core/framework/shape_inference.h>
30 #include <tensorflow/core/framework/tensor.h>
31 #include <tensorflow/core/lib/core/errors.h>
32 
34 
35 inline std::vector<open3d::ml::op_util::DimValue> GetShapeVector(
36  ::tensorflow::shape_inference::InferenceContext* c,
37  ::tensorflow::shape_inference::ShapeHandle shape_handle) {
38  using namespace open3d::ml::op_util;
39  if (!c->RankKnown(shape_handle)) {
40  return std::vector<DimValue>();
41  }
42 
43  std::vector<DimValue> shape;
44  const int rank = c->Rank(shape_handle);
45  for (int i = 0; i < rank; ++i) {
46  auto d = c->DimKnownRank(shape_handle, i);
47  if (c->ValueKnown(d)) {
48  shape.push_back(c->Value(d));
49  } else {
50  shape.push_back(DimValue());
51  }
52  }
53  return shape;
54 }
55 
57  class TDimX,
58  class... TArgs>
59 std::tuple<bool, std::string> CheckShape(
60  ::tensorflow::shape_inference::InferenceContext* c,
61  ::tensorflow::shape_inference::ShapeHandle shape_handle,
62  TDimX&& dimex,
63  TArgs&&... args) {
64  if (!c->RankKnown(shape_handle)) {
65  // without rank we cannot check
66  return std::make_tuple(true, std::string());
67  }
68  return open3d::ml::op_util::CheckShape<Opt>(GetShapeVector(c, shape_handle),
69  std::forward<TDimX>(dimex),
70  std::forward<TArgs>(args)...);
71 }
72 
73 inline std::vector<open3d::ml::op_util::DimValue> GetShapeVector(
74  const tensorflow::Tensor& tensor) {
75  using namespace open3d::ml::op_util;
76 
77  std::vector<DimValue> shape;
78  for (int i = 0; i < tensor.dims(); ++i) {
79  shape.push_back(tensor.dim_size(i));
80  }
81  return shape;
82 }
83 
85  class TDimX,
86  class... TArgs>
87 std::tuple<bool, std::string> CheckShape(const tensorflow::Tensor& tensor,
88  TDimX&& dimex,
89  TArgs&&... args) {
90  return open3d::ml::op_util::CheckShape<Opt>(GetShapeVector(tensor),
91  std::forward<TDimX>(dimex),
92  std::forward<TArgs>(args)...);
93 }
94 
95 //
96 // Helper function for creating a ShapeHandle from dim expressions.
97 // Dim expressions which are not constant will translate to unknown dims in
98 // the returned shape handle.
99 //
100 // Usage:
101 // // ctx is of type tensorflow::shape_inference::InferenceContext*
102 // {
103 // using namespace open3d::ml::op_util;
104 // Dim w("w");
105 // Dim h("h");
106 // CHECK_SHAPE_HANDLE(ctx, handle1, 10, w, h); // checks if the first dim is
107 // // 10 and assigns w and h
108 // // based on the shape of
109 // // handle1
110 //
111 // CHECK_SHAPE_HANDLE(ctx, handle2, 10, 20, h); // this checks if the
112 // // last dim of handle2 matches the
113 // // last dim of handle1. The first
114 // // two dims must match 10, 20.
115 //
116 // ShapeHandle out_shape = MakeShapeHandle(ctx, Dim(), h, w);
117 // ctx->set_output(0, out_shape);
118 // }
119 //
120 //
121 // See "../ShapeChecking.h" for more info and limitations.
122 //
123 template <class TDimX, class... TArgs>
124 ::tensorflow::shape_inference::ShapeHandle MakeShapeHandle(
125  ::tensorflow::shape_inference::InferenceContext* ctx,
126  TDimX&& dimex,
127  TArgs&&... args) {
128  using namespace tensorflow::shape_inference;
129  using namespace open3d::ml::op_util;
130  std::vector<int64_t> shape = CreateDimVector(
131  int64_t(InferenceContext::kUnknownDim), dimex, args...);
132  std::vector<DimensionHandle> dims;
133  for (int64_t d : shape) {
134  dims.push_back(ctx->MakeDim(d));
135  }
136  return ctx->MakeShape(dims);
137 }
138 
139 //
140 // Macros for checking the shape of ShapeHandle during shape inference.
141 //
142 // Usage:
143 // // ctx is of type tensorflow::shape_inference::InferenceContext*
144 // {
145 // using namespace open3d::ml::op_util;
146 // Dim w("w");
147 // Dim h("h");
148 // CHECK_SHAPE_HANDLE(ctx, handle1, 10, w, h); // checks if the first dim is
149 // // 10 and assigns w and h
150 // // based on the shape of
151 // // handle1
152 //
153 // CHECK_SHAPE_HANDLE(ctx, handle2, 10, 20, h); // this checks if the the
154 // // last dim of handle2 matches the
155 // // last dim of handle1. The first
156 // // two dims must match 10, 20.
157 // }
158 //
159 //
160 // See "../ShapeChecking.h" for more info and limitations.
161 //
// Shared implementation of the CHECK_SHAPE_HANDLE* macros below; not intended
// to be used directly.
//
//   check_fn      the CheckShape overload to call, e.g. `CheckShape` or
//                 `CheckShape<CSOpt::COMBINE_FIRST_DIMS>`.
//   ctx           forwarded as the first argument of check_fn.
//   shape_handle  the ShapeHandle to check, forwarded to check_fn.
//   name_str      string literal naming the checked handle. The public macros
//                 produce it with `#shape_handle` themselves: stringifying
//                 here instead would see the already macro-expanded argument
//                 and could change the error text.
//   ...           the dim expressions forwarded to check_fn.
//
// On failure the enclosing function returns
// tensorflow::errors::InvalidArgument("invalid shape for '<name>', <details>").
#define OPEN3D_TF_CHECK_SHAPE_HANDLE_IMPL_(check_fn, ctx, shape_handle, \
                                           name_str, ...)               \
    do {                                                                \
        bool cs_success_;                                               \
        std::string cs_errstr_;                                         \
        std::tie(cs_success_, cs_errstr_) =                             \
                check_fn(ctx, shape_handle, __VA_ARGS__);               \
        if (TF_PREDICT_FALSE(!cs_success_)) {                           \
            return tensorflow::errors::InvalidArgument(                 \
                    "invalid shape for '" name_str "', " + cs_errstr_); \
        }                                                               \
    } while (0)

#define CHECK_SHAPE_HANDLE(ctx, shape_handle, ...)                    \
    OPEN3D_TF_CHECK_SHAPE_HANDLE_IMPL_(CheckShape, ctx, shape_handle, \
                                       #shape_handle, __VA_ARGS__)

#define CHECK_SHAPE_HANDLE_COMBINE_FIRST_DIMS(ctx, shape_handle, ...) \
    OPEN3D_TF_CHECK_SHAPE_HANDLE_IMPL_(                               \
            CheckShape<CSOpt::COMBINE_FIRST_DIMS>, ctx, shape_handle, \
            #shape_handle, __VA_ARGS__)

#define CHECK_SHAPE_HANDLE_IGNORE_FIRST_DIMS(ctx, shape_handle, ...) \
    OPEN3D_TF_CHECK_SHAPE_HANDLE_IMPL_(                              \
            CheckShape<CSOpt::IGNORE_FIRST_DIMS>, ctx, shape_handle, \
            #shape_handle, __VA_ARGS__)

#define CHECK_SHAPE_HANDLE_COMBINE_LAST_DIMS(ctx, shape_handle, ...) \
    OPEN3D_TF_CHECK_SHAPE_HANDLE_IMPL_(                              \
            CheckShape<CSOpt::COMBINE_LAST_DIMS>, ctx, shape_handle, \
            #shape_handle, __VA_ARGS__)

#define CHECK_SHAPE_HANDLE_IGNORE_LAST_DIMS(ctx, shape_handle, ...) \
    OPEN3D_TF_CHECK_SHAPE_HANDLE_IMPL_(                             \
            CheckShape<CSOpt::IGNORE_LAST_DIMS>, ctx, shape_handle, \
            #shape_handle, __VA_ARGS__)
225 
226 //
227 // Macros for checking the shape of Tensors.
228 // Usage:
229 // // ctx is of type tensorflow::OpKernelContext*
230 // {
231 // using namespace open3d::ml::op_util;
232 // Dim w("w");
233 // Dim h("h");
234 // CHECK_SHAPE(ctx, tensor1, 10, w, h); // checks if the first dim is 10
235 // // and assigns w and h based on
236 // // the shape of tensor1
237 //
238 // CHECK_SHAPE(ctx, tensor2, 10, 20, h); // this checks if the last dim
239 // // of tensor2 matches the last dim
240 // // of tensor1. The first two dims
241 // // must match 10, 20.
242 // }
243 //
244 //
245 // See "../ShapeChecking.h" for more info and limitations.
246 //
// Shared implementation of the CHECK_SHAPE* macros below; not intended to be
// used directly.
//
//   check_fn  the CheckShape overload to call, e.g. `CheckShape` or
//             `CheckShape<CSOpt::COMBINE_FIRST_DIMS>`.
//   ctx       the kernel context passed to OP_REQUIRES.
//   tensor    the tensor to check, forwarded to check_fn.
//   name_str  string literal naming the checked tensor. The public macros
//             produce it with `#tensor` themselves: stringifying here instead
//             would see the already macro-expanded argument and could change
//             the error text.
//   ...       the dim expressions forwarded to check_fn.
//
// On failure the error
// tensorflow::errors::InvalidArgument("invalid shape for '<name>', <details>")
// is reported via OP_REQUIRES(ctx, ...).
#define OPEN3D_TF_CHECK_SHAPE_IMPL_(check_fn, ctx, tensor, name_str, ...)  \
    do {                                                                   \
        bool cs_success_;                                                  \
        std::string cs_errstr_;                                            \
        std::tie(cs_success_, cs_errstr_) = check_fn(tensor, __VA_ARGS__); \
        OP_REQUIRES(ctx, cs_success_,                                      \
                    tensorflow::errors::InvalidArgument(                   \
                            "invalid shape for '" name_str "', " +         \
                            cs_errstr_));                                  \
    } while (0)

#define CHECK_SHAPE(ctx, tensor, ...) \
    OPEN3D_TF_CHECK_SHAPE_IMPL_(CheckShape, ctx, tensor, #tensor, __VA_ARGS__)

#define CHECK_SHAPE_COMBINE_FIRST_DIMS(ctx, tensor, ...)                   \
    OPEN3D_TF_CHECK_SHAPE_IMPL_(CheckShape<CSOpt::COMBINE_FIRST_DIMS>, ctx, \
                                tensor, #tensor, __VA_ARGS__)

#define CHECK_SHAPE_IGNORE_FIRST_DIMS(ctx, tensor, ...)                   \
    OPEN3D_TF_CHECK_SHAPE_IMPL_(CheckShape<CSOpt::IGNORE_FIRST_DIMS>, ctx, \
                                tensor, #tensor, __VA_ARGS__)

#define CHECK_SHAPE_COMBINE_LAST_DIMS(ctx, tensor, ...)                   \
    OPEN3D_TF_CHECK_SHAPE_IMPL_(CheckShape<CSOpt::COMBINE_LAST_DIMS>, ctx, \
                                tensor, #tensor, __VA_ARGS__)

#define CHECK_SHAPE_IGNORE_LAST_DIMS(ctx, tensor, ...)                   \
    OPEN3D_TF_CHECK_SHAPE_IMPL_(CheckShape<CSOpt::IGNORE_LAST_DIMS>, ctx, \
                                tensor, #tensor, __VA_ARGS__)
Class for representing a possibly unknown dimension value.
Definition: ShapeChecking.h:38
CSOpt
Check shape options.
Definition: ShapeChecking.h:424
void CreateDimVector(std::vector< int64_t > &out, int64_t unknown_dim_value, TDimX dimex)
Definition: ShapeChecking.h:377
std::vector< open3d::ml::op_util::DimValue > GetShapeVector(::tensorflow::shape_inference::InferenceContext *c, ::tensorflow::shape_inference::ShapeHandle shape_handle)
Definition: TensorFlowHelper.h:35
::tensorflow::shape_inference::ShapeHandle MakeShapeHandle(::tensorflow::shape_inference::InferenceContext *ctx, TDimX &&dimex, TArgs &&... args)
Definition: TensorFlowHelper.h:124
std::tuple< bool, std::string > CheckShape(const std::vector< DimValue > &shape, TDimX &&dimex, TArgs &&... args)
Definition: ShapeChecking.h:593
Definition: ShapeChecking.h:35