Open3D (C++ API)  0.19.0
TensorFlowHelper.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 #include <tensorflow/core/framework/op_kernel.h>
10 #include <tensorflow/core/framework/shape_inference.h>
11 #include <tensorflow/core/framework/tensor.h>
12 #include <tensorflow/core/lib/core/errors.h>
13 
14 #include "absl/status/status.h"
16 
17 inline std::vector<open3d::ml::op_util::DimValue> GetShapeVector(
18  ::tensorflow::shape_inference::InferenceContext* c,
19  ::tensorflow::shape_inference::ShapeHandle shape_handle) {
20  using namespace open3d::ml::op_util;
21  if (!c->RankKnown(shape_handle)) {
22  return std::vector<DimValue>();
23  }
24 
25  std::vector<DimValue> shape;
26  const int rank = c->Rank(shape_handle);
27  for (int i = 0; i < rank; ++i) {
28  auto d = c->DimKnownRank(shape_handle, i);
29  if (c->ValueKnown(d)) {
30  shape.push_back(c->Value(d));
31  } else {
32  shape.push_back(DimValue());
33  }
34  }
35  return shape;
36 }
37 
// NOTE(review): the first line of this template header (it declares the
// `Opt` option used below as `CheckShape<Opt>`) is not visible in this view.
          class TDimX,
          class... TArgs>
std::tuple<bool, std::string> CheckShape(
        ::tensorflow::shape_inference::InferenceContext* c,
        ::tensorflow::shape_inference::ShapeHandle shape_handle,
        TDimX&& dimex,
        TArgs&&... args) {
    // Shape-inference variant of CheckShape: compares the dims of
    // `shape_handle` against the dim expressions `dimex, args...` and returns
    // a (success, error message) tuple.
    if (!c->RankKnown(shape_handle)) {
        // without rank we cannot check
        // -> report success with an empty message.
        return std::make_tuple(true, std::string());
    }
    // Dims whose value is unknown are passed as default DimValue entries
    // (GetShapeVector); the comparison itself is delegated to
    // open3d::ml::op_util::CheckShape with the same option `Opt`, forwarding
    // the dim expressions unchanged.
    return open3d::ml::op_util::CheckShape<Opt>(GetShapeVector(c, shape_handle),
                                                std::forward<TDimX>(dimex),
                                                std::forward<TArgs>(args)...);
}
54 
55 inline std::vector<open3d::ml::op_util::DimValue> GetShapeVector(
56  const tensorflow::Tensor& tensor) {
57  using namespace open3d::ml::op_util;
58 
59  std::vector<DimValue> shape;
60  for (int i = 0; i < tensor.dims(); ++i) {
61  shape.push_back(tensor.dim_size(i));
62  }
63  return shape;
64 }
65 
// NOTE(review): the first line of this template header (it declares the
// `Opt` option used below as `CheckShape<Opt>`) is not visible in this view.
          class TDimX,
          class... TArgs>
std::tuple<bool, std::string> CheckShape(const tensorflow::Tensor& tensor,
                                         TDimX&& dimex,
                                         TArgs&&... args) {
    // Tensor variant of CheckShape: reads the dims of `tensor` through
    // GetShapeVector and delegates the comparison against `dimex, args...`
    // to open3d::ml::op_util::CheckShape with the same option `Opt`.
    // Returns the (success, error message) tuple produced by that call.
    return open3d::ml::op_util::CheckShape<Opt>(GetShapeVector(tensor),
                                                std::forward<TDimX>(dimex),
                                                std::forward<TArgs>(args)...);
}
76 
77 //
78 // Helper function for creating a ShapeHandle from dim expressions.
79 // Dim expressions which are not constant will translate to unknown dims in
80 // the returned shape handle.
81 //
82 // Usage:
83 // // ctx is of type tensorflow::shape_inference::InferenceContext*
84 // {
85 // using namespace open3d::ml::op_util;
86 // Dim w("w");
87 // Dim h("h");
88 // CHECK_SHAPE_HANDLE(ctx, handle1, 10, w, h); // checks if the first dim is
89 // // 10 and assigns w and h
90 // // based on the shape of
91 // // handle1
92 //
93 // CHECK_SHAPE_HANDLE(ctx, handle2, 10, 20, h); // this checks if the the
94 // // last dim of handle2 matches the
95 // // last dim of handle1. The first
96 // // two dims must match 10, 20.
97 //
98 // ShapeHandle out_shape = MakeShapeHandle(ctx, Dim(), h, w);
99 // ctx->set_output(0, out_shape);
100 // }
101 //
102 //
103 // See "../ShapeChecking.h" for more info and limitations.
104 //
105 template <class TDimX, class... TArgs>
106 ::tensorflow::shape_inference::ShapeHandle MakeShapeHandle(
107  ::tensorflow::shape_inference::InferenceContext* ctx,
108  TDimX&& dimex,
109  TArgs&&... args) {
110  using namespace tensorflow::shape_inference;
111  using namespace open3d::ml::op_util;
112  std::vector<int64_t> shape = CreateDimVector(
113  int64_t(InferenceContext::kUnknownDim), dimex, args...);
114  std::vector<DimensionHandle> dims;
115  for (int64_t d : shape) {
116  dims.push_back(ctx->MakeDim(d));
117  }
118  return ctx->MakeShape(dims);
119 }
120 
//
// Macros for checking the shape of a ShapeHandle during shape inference.
//
// On mismatch, the enclosing shape function returns
// absl::InvalidArgumentError("invalid shape for '<shape_handle>', <details>").
// If the rank of the handle is unknown, the check is skipped.
//
// Usage:
//   // ctx is of type tensorflow::shape_inference::InferenceContext*
//   {
//       using namespace open3d::ml::op_util;
//       Dim w("w");
//       Dim h("h");
//       CHECK_SHAPE_HANDLE(ctx, handle1, 10, w, h);   // checks that the first
//                                                     // dim is 10 and assigns
//                                                     // w and h from the
//                                                     // shape of handle1
//
//       CHECK_SHAPE_HANDLE(ctx, handle2, 10, 20, h);  // checks that the last
//                                                     // dim of handle2 matches
//                                                     // the last dim of
//                                                     // handle1; the first two
//                                                     // dims must be 10 and 20
//   }
//
// See "../ShapeChecking.h" for more info and limitations.
//

// Internal helper shared by all CHECK_SHAPE_HANDLE* macros; not for direct
// use. `check_expr` must evaluate to a std::tuple<bool, std::string> and
// `shape_handle_str` is the stringified name of the checked handle.
#define OPEN3D_TF_CHECK_SHAPE_HANDLE_IMPL(check_expr, shape_handle_str)      \
    do {                                                                     \
        bool cs_success_ = false;                                            \
        std::string cs_errstr_;                                              \
        std::tie(cs_success_, cs_errstr_) = (check_expr);                    \
        if (TF_PREDICT_FALSE(!cs_success_)) {                                \
            return absl::InvalidArgumentError(                               \
                    "invalid shape for '" shape_handle_str "', " +           \
                    cs_errstr_);                                             \
        }                                                                    \
    } while (0)

#define CHECK_SHAPE_HANDLE(ctx, shape_handle, ...)                           \
    OPEN3D_TF_CHECK_SHAPE_HANDLE_IMPL(                                       \
            CheckShape(ctx, shape_handle, __VA_ARGS__), #shape_handle)

// The option variants qualify CSOpt fully so they also work without a
// `using namespace open3d::ml::op_util` at the call site.
#define CHECK_SHAPE_HANDLE_COMBINE_FIRST_DIMS(ctx, shape_handle, ...)        \
    OPEN3D_TF_CHECK_SHAPE_HANDLE_IMPL(                                       \
            CheckShape<open3d::ml::op_util::CSOpt::COMBINE_FIRST_DIMS>(      \
                    ctx, shape_handle, __VA_ARGS__),                         \
            #shape_handle)

#define CHECK_SHAPE_HANDLE_IGNORE_FIRST_DIMS(ctx, shape_handle, ...)         \
    OPEN3D_TF_CHECK_SHAPE_HANDLE_IMPL(                                       \
            CheckShape<open3d::ml::op_util::CSOpt::IGNORE_FIRST_DIMS>(       \
                    ctx, shape_handle, __VA_ARGS__),                         \
            #shape_handle)

#define CHECK_SHAPE_HANDLE_COMBINE_LAST_DIMS(ctx, shape_handle, ...)         \
    OPEN3D_TF_CHECK_SHAPE_HANDLE_IMPL(                                       \
            CheckShape<open3d::ml::op_util::CSOpt::COMBINE_LAST_DIMS>(       \
                    ctx, shape_handle, __VA_ARGS__),                         \
            #shape_handle)

#define CHECK_SHAPE_HANDLE_IGNORE_LAST_DIMS(ctx, shape_handle, ...)          \
    OPEN3D_TF_CHECK_SHAPE_HANDLE_IMPL(                                       \
            CheckShape<open3d::ml::op_util::CSOpt::IGNORE_LAST_DIMS>(        \
                    ctx, shape_handle, __VA_ARGS__),                         \
            #shape_handle)
207 
//
// Macros for checking the shape of Tensors.
//
// On mismatch, an InvalidArgument status with the message
// "invalid shape for '<tensor>', <details>" is reported through
// OP_REQUIRES on `ctx`.
//
// Usage:
//   // ctx is of type tensorflow::OpKernelContext*
//   {
//       using namespace open3d::ml::op_util;
//       Dim w("w");
//       Dim h("h");
//       CHECK_SHAPE(ctx, tensor1, 10, w, h);   // checks that the first dim is
//                                              // 10 and assigns w and h from
//                                              // the shape of tensor1
//
//       CHECK_SHAPE(ctx, tensor2, 10, 20, h);  // checks that the last dim of
//                                              // tensor2 matches the last dim
//                                              // of tensor1; the first two
//                                              // dims must be 10 and 20
//   }
//
// See "../ShapeChecking.h" for more info and limitations.
//

// Internal helper shared by all CHECK_SHAPE* macros; not for direct use.
// `check_expr` must evaluate to a std::tuple<bool, std::string> and
// `tensor_str` is the stringified name of the checked tensor.
#define OPEN3D_TF_CHECK_SHAPE_IMPL(ctx, check_expr, tensor_str)              \
    do {                                                                     \
        bool cs_success_ = false;                                            \
        std::string cs_errstr_;                                              \
        std::tie(cs_success_, cs_errstr_) = (check_expr);                    \
        OP_REQUIRES(ctx, cs_success_,                                        \
                    absl::InvalidArgumentError("invalid shape for '"         \
                                               tensor_str "', " +            \
                                               cs_errstr_));                 \
    } while (0)

#define CHECK_SHAPE(ctx, tensor, ...)                                        \
    OPEN3D_TF_CHECK_SHAPE_IMPL(ctx, CheckShape(tensor, __VA_ARGS__), #tensor)

// The option variants qualify CSOpt fully so they also work without a
// `using namespace open3d::ml::op_util` at the call site.
#define CHECK_SHAPE_COMBINE_FIRST_DIMS(ctx, tensor, ...)                     \
    OPEN3D_TF_CHECK_SHAPE_IMPL(                                              \
            ctx,                                                             \
            CheckShape<open3d::ml::op_util::CSOpt::COMBINE_FIRST_DIMS>(      \
                    tensor, __VA_ARGS__),                                    \
            #tensor)

#define CHECK_SHAPE_IGNORE_FIRST_DIMS(ctx, tensor, ...)                      \
    OPEN3D_TF_CHECK_SHAPE_IMPL(                                              \
            ctx,                                                             \
            CheckShape<open3d::ml::op_util::CSOpt::IGNORE_FIRST_DIMS>(       \
                    tensor, __VA_ARGS__),                                    \
            #tensor)

#define CHECK_SHAPE_COMBINE_LAST_DIMS(ctx, tensor, ...)                      \
    OPEN3D_TF_CHECK_SHAPE_IMPL(                                              \
            ctx,                                                             \
            CheckShape<open3d::ml::op_util::CSOpt::COMBINE_LAST_DIMS>(       \
                    tensor, __VA_ARGS__),                                    \
            #tensor)

#define CHECK_SHAPE_IGNORE_LAST_DIMS(ctx, tensor, ...)                       \
    OPEN3D_TF_CHECK_SHAPE_IMPL(                                              \
            ctx,                                                             \
            CheckShape<open3d::ml::op_util::CSOpt::IGNORE_LAST_DIMS>(        \
                    tensor, __VA_ARGS__),                                    \
            #tensor)
std::tuple< bool, std::string > CheckShape(::tensorflow::shape_inference::InferenceContext *c, ::tensorflow::shape_inference::ShapeHandle shape_handle, TDimX &&dimex, TArgs &&... args)
Definition: TensorFlowHelper.h:41
::tensorflow::shape_inference::ShapeHandle MakeShapeHandle(::tensorflow::shape_inference::InferenceContext *ctx, TDimX &&dimex, TArgs &&... args)
Definition: TensorFlowHelper.h:106
std::vector< open3d::ml::op_util::DimValue > GetShapeVector(::tensorflow::shape_inference::InferenceContext *c, ::tensorflow::shape_inference::ShapeHandle shape_handle)
Definition: TensorFlowHelper.h:17
Class for representing a possibly unknown dimension value.
Definition: ShapeChecking.h:19
Definition: ShapeChecking.h:16
CSOpt
Check shape options.
Definition: ShapeChecking.h:405
void CreateDimVector(std::vector< int64_t > &out, int64_t unknown_dim_value, TDimX dimex)
Definition: ShapeChecking.h:358