9 #include <tensorflow/core/framework/op_kernel.h>
10 #include <tensorflow/core/framework/shape_inference.h>
11 #include <tensorflow/core/framework/tensor.h>
12 #include <tensorflow/core/lib/core/errors.h>
14 #include "absl/status/status.h"
18 ::tensorflow::shape_inference::InferenceContext* c,
19 ::tensorflow::shape_inference::ShapeHandle shape_handle) {
21 if (!c->RankKnown(shape_handle)) {
22 return std::vector<DimValue>();
25 std::vector<DimValue> shape;
26 const int rank = c->Rank(shape_handle);
27 for (
int i = 0; i < rank; ++i) {
28 auto d = c->DimKnownRank(shape_handle, i);
29 if (c->ValueKnown(d)) {
30 shape.push_back(c->Value(d));
42 ::tensorflow::shape_inference::InferenceContext* c,
43 ::tensorflow::shape_inference::ShapeHandle shape_handle,
46 if (!c->RankKnown(shape_handle)) {
48 return std::make_tuple(
true, std::string());
50 return open3d::ml::op_util::CheckShape<Opt>(
GetShapeVector(c, shape_handle),
51 std::forward<TDimX>(dimex),
52 std::forward<TArgs>(args)...);
56 const tensorflow::Tensor& tensor) {
59 std::vector<DimValue> shape;
60 for (
int i = 0; i < tensor.dims(); ++i) {
61 shape.push_back(tensor.dim_size(i));
69 std::tuple<bool, std::string>
CheckShape(
const tensorflow::Tensor& tensor,
72 return open3d::ml::op_util::CheckShape<Opt>(
GetShapeVector(tensor),
73 std::forward<TDimX>(dimex),
74 std::forward<TArgs>(args)...);
105 template <
class TDimX,
class... TArgs>
107 ::tensorflow::shape_inference::InferenceContext* ctx,
110 using namespace tensorflow::shape_inference;
113 int64_t(InferenceContext::kUnknownDim), dimex, args...);
114 std::vector<DimensionHandle> dims;
115 for (int64_t d : shape) {
116 dims.push_back(ctx->MakeDim(d));
118 return ctx->MakeShape(dims);
144 #define CHECK_SHAPE_HANDLE(ctx, shape_handle, ...) \
147 std::string cs_errstr_; \
148 std::tie(cs_success_, cs_errstr_) = \
149 CheckShape(ctx, shape_handle, __VA_ARGS__); \
150 if (TF_PREDICT_FALSE(!cs_success_)) { \
151 return absl::InvalidArgumentError( \
152 "invalid shape for '" #shape_handle "', " + cs_errstr_); \
156 #define CHECK_SHAPE_HANDLE_COMBINE_FIRST_DIMS(ctx, shape_handle, ...) \
159 std::string cs_errstr_; \
160 std::tie(cs_success_, cs_errstr_) = \
161 CheckShape<CSOpt::COMBINE_FIRST_DIMS>(ctx, shape_handle, \
163 if (TF_PREDICT_FALSE(!cs_success_)) { \
164 return absl::InvalidArgumentError( \
165 "invalid shape for '" #shape_handle "', " + cs_errstr_); \
169 #define CHECK_SHAPE_HANDLE_IGNORE_FIRST_DIMS(ctx, shape_handle, ...) \
172 std::string cs_errstr_; \
173 std::tie(cs_success_, cs_errstr_) = \
174 CheckShape<CSOpt::IGNORE_FIRST_DIMS>(ctx, shape_handle, \
176 if (TF_PREDICT_FALSE(!cs_success_)) { \
177 return absl::InvalidArgumentError( \
178 "invalid shape for '" #shape_handle "', " + cs_errstr_); \
182 #define CHECK_SHAPE_HANDLE_COMBINE_LAST_DIMS(ctx, shape_handle, ...) \
185 std::string cs_errstr_; \
186 std::tie(cs_success_, cs_errstr_) = \
187 CheckShape<CSOpt::COMBINE_LAST_DIMS>(ctx, shape_handle, \
189 if (TF_PREDICT_FALSE(!cs_success_)) { \
190 return absl::InvalidArgumentError( \
191 "invalid shape for '" #shape_handle "', " + cs_errstr_); \
195 #define CHECK_SHAPE_HANDLE_IGNORE_LAST_DIMS(ctx, shape_handle, ...) \
198 std::string cs_errstr_; \
199 std::tie(cs_success_, cs_errstr_) = \
200 CheckShape<CSOpt::IGNORE_LAST_DIMS>(ctx, shape_handle, \
202 if (TF_PREDICT_FALSE(!cs_success_)) { \
203 return absl::InvalidArgumentError( \
204 "invalid shape for '" #shape_handle "', " + cs_errstr_); \
229 #define CHECK_SHAPE(ctx, tensor, ...) \
232 std::string cs_errstr_; \
233 std::tie(cs_success_, cs_errstr_) = CheckShape(tensor, __VA_ARGS__); \
236 absl::InvalidArgumentError( \
237 "invalid shape for '" #tensor "', " + cs_errstr_)); \
240 #define CHECK_SHAPE_COMBINE_FIRST_DIMS(ctx, tensor, ...) \
243 std::string cs_errstr_; \
244 std::tie(cs_success_, cs_errstr_) = \
245 CheckShape<CSOpt::COMBINE_FIRST_DIMS>(tensor, __VA_ARGS__); \
248 absl::InvalidArgumentError( \
249 "invalid shape for '" #tensor "', " + cs_errstr_)); \
252 #define CHECK_SHAPE_IGNORE_FIRST_DIMS(ctx, tensor, ...) \
255 std::string cs_errstr_; \
256 std::tie(cs_success_, cs_errstr_) = \
257 CheckShape<CSOpt::IGNORE_FIRST_DIMS>(tensor, __VA_ARGS__); \
260 absl::InvalidArgumentError( \
261 "invalid shape for '" #tensor "', " + cs_errstr_)); \
264 #define CHECK_SHAPE_COMBINE_LAST_DIMS(ctx, tensor, ...) \
267 std::string cs_errstr_; \
268 std::tie(cs_success_, cs_errstr_) = \
269 CheckShape<CSOpt::COMBINE_LAST_DIMS>(tensor, __VA_ARGS__); \
272 absl::InvalidArgumentError( \
273 "invalid shape for '" #tensor "', " + cs_errstr_)); \
276 #define CHECK_SHAPE_IGNORE_LAST_DIMS(ctx, tensor, ...) \
279 std::string cs_errstr_; \
280 std::tie(cs_success_, cs_errstr_) = \
281 CheckShape<CSOpt::IGNORE_LAST_DIMS>(tensor, __VA_ARGS__); \
284 absl::InvalidArgumentError( \
285 "invalid shape for '" #tensor "', " + cs_errstr_)); \
std::tuple< bool, std::string > CheckShape(::tensorflow::shape_inference::InferenceContext *c, ::tensorflow::shape_inference::ShapeHandle shape_handle, TDimX &&dimex, TArgs &&... args)
Definition: TensorFlowHelper.h:41
::tensorflow::shape_inference::ShapeHandle MakeShapeHandle(::tensorflow::shape_inference::InferenceContext *ctx, TDimX &&dimex, TArgs &&... args)
Definition: TensorFlowHelper.h:106
std::vector< open3d::ml::op_util::DimValue > GetShapeVector(::tensorflow::shape_inference::InferenceContext *c, ::tensorflow::shape_inference::ShapeHandle shape_handle)
Definition: TensorFlowHelper.h:17
Class for representing a possibly unknown dimension value.
Definition: ShapeChecking.h:19
Definition: ShapeChecking.h:16
CSOpt
Check shape options.
Definition: ShapeChecking.h:405
void CreateDimVector(std::vector< int64_t > &out, int64_t unknown_dim_value, TDimX dimex)
Definition: ShapeChecking.h:358