28 #include <tensorflow/core/framework/op_kernel.h> 29 #include <tensorflow/core/framework/shape_inference.h> 30 #include <tensorflow/core/framework/tensor.h> 31 #include <tensorflow/core/lib/core/errors.h> 36 ::tensorflow::shape_inference::InferenceContext* c,
37 ::tensorflow::shape_inference::ShapeHandle shape_handle) {
39 if (!c->RankKnown(shape_handle)) {
40 return std::vector<DimValue>();
43 std::vector<DimValue> shape;
44 const int rank = c->Rank(shape_handle);
45 for (
int i = 0; i < rank; ++i) {
46 auto d = c->DimKnownRank(shape_handle, i);
47 if (c->ValueKnown(d)) {
48 shape.push_back(c->Value(d));
60 ::tensorflow::shape_inference::InferenceContext* c,
61 ::tensorflow::shape_inference::ShapeHandle shape_handle,
64 if (!c->RankKnown(shape_handle)) {
66 return std::make_tuple(
true, std::string());
68 return open3d::ml::op_util::CheckShape<Opt>(
GetShapeVector(c, shape_handle),
69 std::forward<TDimX>(dimex),
70 std::forward<TArgs>(args)...);
74 const tensorflow::Tensor& tensor) {
77 std::vector<DimValue> shape;
78 for (
int i = 0; i < tensor.dims(); ++i) {
79 shape.push_back(tensor.dim_size(i));
87 std::tuple<bool, std::string>
CheckShape(
const tensorflow::Tensor& tensor,
90 return open3d::ml::op_util::CheckShape<Opt>(
GetShapeVector(tensor),
91 std::forward<TDimX>(dimex),
92 std::forward<TArgs>(args)...);
123 template <
class TDimX,
class... TArgs>
125 ::tensorflow::shape_inference::InferenceContext* ctx,
128 using namespace tensorflow::shape_inference;
131 int64_t(InferenceContext::kUnknownDim), dimex, args...);
132 std::vector<DimensionHandle> dims;
133 for (int64_t d : shape) {
134 dims.push_back(ctx->MakeDim(d));
136 return ctx->MakeShape(dims);
162 #define CHECK_SHAPE_HANDLE(ctx, shape_handle, ...) \ 165 std::string cs_errstr_; \ 166 std::tie(cs_success_, cs_errstr_) = \ 167 CheckShape(ctx, shape_handle, __VA_ARGS__); \ 168 if (TF_PREDICT_FALSE(!cs_success_)) { \ 169 return tensorflow::errors::InvalidArgument( \ 170 "invalid shape for '" #shape_handle "', " + cs_errstr_); \ 174 #define CHECK_SHAPE_HANDLE_COMBINE_FIRST_DIMS(ctx, shape_handle, ...) \ 177 std::string cs_errstr_; \ 178 std::tie(cs_success_, cs_errstr_) = \ 179 CheckShape<CSOpt::COMBINE_FIRST_DIMS>(ctx, shape_handle, \ 181 if (TF_PREDICT_FALSE(!cs_success_)) { \ 182 return tensorflow::errors::InvalidArgument( \ 183 "invalid shape for '" #shape_handle "', " + cs_errstr_); \ 187 #define CHECK_SHAPE_HANDLE_IGNORE_FIRST_DIMS(ctx, shape_handle, ...) \ 190 std::string cs_errstr_; \ 191 std::tie(cs_success_, cs_errstr_) = \ 192 CheckShape<CSOpt::IGNORE_FIRST_DIMS>(ctx, shape_handle, \ 194 if (TF_PREDICT_FALSE(!cs_success_)) { \ 195 return tensorflow::errors::InvalidArgument( \ 196 "invalid shape for '" #shape_handle "', " + cs_errstr_); \ 200 #define CHECK_SHAPE_HANDLE_COMBINE_LAST_DIMS(ctx, shape_handle, ...) \ 203 std::string cs_errstr_; \ 204 std::tie(cs_success_, cs_errstr_) = \ 205 CheckShape<CSOpt::COMBINE_LAST_DIMS>(ctx, shape_handle, \ 207 if (TF_PREDICT_FALSE(!cs_success_)) { \ 208 return tensorflow::errors::InvalidArgument( \ 209 "invalid shape for '" #shape_handle "', " + cs_errstr_); \ 213 #define CHECK_SHAPE_HANDLE_IGNORE_LAST_DIMS(ctx, shape_handle, ...) \ 216 std::string cs_errstr_; \ 217 std::tie(cs_success_, cs_errstr_) = \ 218 CheckShape<CSOpt::IGNORE_LAST_DIMS>(ctx, shape_handle, \ 220 if (TF_PREDICT_FALSE(!cs_success_)) { \ 221 return tensorflow::errors::InvalidArgument( \ 222 "invalid shape for '" #shape_handle "', " + cs_errstr_); \ 247 #define CHECK_SHAPE(ctx, tensor, ...) \ 250 std::string cs_errstr_; \ 251 std::tie(cs_success_, cs_errstr_) = CheckShape(tensor, __VA_ARGS__); \ 254 tensorflow::errors::InvalidArgument( \ 255 "invalid shape for '" #tensor "', " + cs_errstr_)); \ 258 #define CHECK_SHAPE_COMBINE_FIRST_DIMS(ctx, tensor, ...) \ 261 std::string cs_errstr_; \ 262 std::tie(cs_success_, cs_errstr_) = \ 263 CheckShape<CSOpt::COMBINE_FIRST_DIMS>(tensor, __VA_ARGS__); \ 266 tensorflow::errors::InvalidArgument( \ 267 "invalid shape for '" #tensor "', " + cs_errstr_)); \ 270 #define CHECK_SHAPE_IGNORE_FIRST_DIMS(ctx, tensor, ...) \ 273 std::string cs_errstr_; \ 274 std::tie(cs_success_, cs_errstr_) = \ 275 CheckShape<CSOpt::IGNORE_FIRST_DIMS>(tensor, __VA_ARGS__); \ 278 tensorflow::errors::InvalidArgument( \ 279 "invalid shape for '" #tensor "', " + cs_errstr_)); \ 282 #define CHECK_SHAPE_COMBINE_LAST_DIMS(ctx, tensor, ...) \ 285 std::string cs_errstr_; \ 286 std::tie(cs_success_, cs_errstr_) = \ 287 CheckShape<CSOpt::COMBINE_LAST_DIMS>(tensor, __VA_ARGS__); \ 290 tensorflow::errors::InvalidArgument( \ 291 "invalid shape for '" #tensor "', " + cs_errstr_)); \ 294 #define CHECK_SHAPE_IGNORE_LAST_DIMS(ctx, tensor, ...) \ 297 std::string cs_errstr_; \ 298 std::tie(cs_success_, cs_errstr_) = \ 299 CheckShape<CSOpt::IGNORE_LAST_DIMS>(tensor, __VA_ARGS__); \ 302 tensorflow::errors::InvalidArgument( \ 303 "invalid shape for '" #tensor "', " + cs_errstr_)); \ Class for representing a possibly unknown dimension value.
Definition: ShapeChecking.h:38
CSOpt
Check shape options.
Definition: ShapeChecking.h:424
void CreateDimVector(std::vector< int64_t > &out, int64_t unknown_dim_value, TDimX dimex)
Definition: ShapeChecking.h:377
std::vector< open3d::ml::op_util::DimValue > GetShapeVector(::tensorflow::shape_inference::InferenceContext *c, ::tensorflow::shape_inference::ShapeHandle shape_handle)
Definition: TensorFlowHelper.h:35
::tensorflow::shape_inference::ShapeHandle MakeShapeHandle(::tensorflow::shape_inference::InferenceContext *ctx, TDimX &&dimex, TArgs &&... args)
Definition: TensorFlowHelper.h:124
std::tuple< bool, std::string > CheckShape(const std::vector< DimValue > &shape, TDimX &&dimex, TArgs &&... args)
Definition: ShapeChecking.h:593
Definition: ShapeChecking.h:35