23 const std::vector<Tensor>& index_tensors)
44 const std::vector<Tensor>& index_tensors);
49 const Tensor& tensor,
const std::vector<Tensor>& index_tensors);
53 static std::pair<std::vector<Tensor>,
SizeVector>
86 const std::vector<Tensor>& index_tensors);
122 const std::vector<Tensor>& index_tensors,
127 if (indexed_shape.
size() != indexed_strides.
size()) {
129 "Internal error: indexed_shape's ndim {} does not equal to "
130 "indexed_strides' ndim {}",
131 indexed_shape.
size(), indexed_strides.
size());
136 std::vector<Tensor> inputs;
137 inputs.push_back(src);
138 for (
const Tensor& index_tensor : index_tensors) {
139 if (index_tensor.NumDims() != 0) {
140 inputs.push_back(index_tensor);
148 "Internal error: indexed_shape's ndim {} does not equal to "
149 "indexd_strides' ndim {}",
160 "src's dtype {} is not the same as dst's dtype {}.",
184 int64_t index = *(
reinterpret_cast<int64_t*
>(
#define OPEN3D_HOST_DEVICE
Definition: CUDAUtils.h:44
#define LogError(...)
Definition: Logging.h:51
#define OPEN3D_ASSERT(...)
Definition: Macro.h:51
This class is based on PyTorch's aten/src/ATen/native/Indexing.cpp.
Definition: AdvancedIndexing.h:20
void RunPreprocess()
Preprocess tensor and index tensors.
Definition: AdvancedIndexing.cpp:110
static Tensor RestrideTensor(const Tensor &tensor, int64_t dims_before, int64_t dims_indexed, SizeVector replacement_shape)
Definition: AdvancedIndexing.cpp:85
static bool IsIndexSplittedBySlice(const std::vector< Tensor > &index_tensors)
Definition: AdvancedIndexing.cpp:17
std::vector< Tensor > GetIndexTensors() const
Definition: AdvancedIndexing.h:30
static std::pair< Tensor, std::vector< Tensor > > ShuffleIndexedDimsToFront(const Tensor &tensor, const std::vector< Tensor > &index_tensors)
Definition: AdvancedIndexing.cpp:41
SizeVector output_shape_
Output shape.
Definition: AdvancedIndexing.h:96
SizeVector GetIndexedStrides() const
Definition: AdvancedIndexing.h:38
std::vector< Tensor > index_tensors_
The processed index tensors.
Definition: AdvancedIndexing.h:93
Tensor tensor_
Definition: AdvancedIndexing.h:90
static std::vector< Tensor > ExpandBoolTensors(const std::vector< Tensor > &index_tensors)
Expand boolean tensor to integer index.
Definition: AdvancedIndexing.cpp:230
static std::pair< std::vector< Tensor >, SizeVector > ExpandToCommonShapeExceptZeroDim(const std::vector< Tensor > &index_tensors)
Definition: AdvancedIndexing.cpp:63
AdvancedIndexPreprocessor(const Tensor &tensor, const std::vector< Tensor > &index_tensors)
Definition: AdvancedIndexing.h:22
static Tensor RestrideIndexTensor(const Tensor &index_tensor, int64_t dims_before, int64_t dims_after)
Definition: AdvancedIndexing.cpp:100
SizeVector indexed_shape_
Definition: AdvancedIndexing.h:100
SizeVector GetIndexedShape() const
Definition: AdvancedIndexing.h:36
SizeVector GetOutputShape() const
Definition: AdvancedIndexing.h:34
Tensor GetTensor() const
Definition: AdvancedIndexing.h:28
SizeVector indexed_strides_
Definition: AdvancedIndexing.h:104
Definition: AdvancedIndexing.h:116
Indexer indexer_
Definition: AdvancedIndexing.h:197
AdvancedIndexer(const Tensor &src, const Tensor &dst, const std::vector< Tensor > &index_tensors, const SizeVector &indexed_shape, const SizeVector &indexed_strides, AdvancedIndexerMode mode)
Definition: AdvancedIndexing.h:120
int64_t element_byte_size_
Definition: AdvancedIndexing.h:200
int64_t NumWorkloads() const
Definition: AdvancedIndexing.h:194
int64_t num_indices_
Definition: AdvancedIndexing.h:199
OPEN3D_HOST_DEVICE char * GetOutputPtr(int64_t workload_idx) const
Definition: AdvancedIndexing.h:173
int64_t indexed_strides_[MAX_DIMS]
Definition: AdvancedIndexing.h:202
AdvancedIndexerMode mode_
Definition: AdvancedIndexing.h:198
OPEN3D_HOST_DEVICE char * GetInputPtr(int64_t workload_idx) const
Definition: AdvancedIndexing.h:166
int64_t indexed_shape_[MAX_DIMS]
Definition: AdvancedIndexing.h:201
OPEN3D_HOST_DEVICE int64_t GetIndexedOffset(int64_t workload_idx) const
Definition: AdvancedIndexing.h:181
AdvancedIndexerMode
Definition: AdvancedIndexing.h:118
std::string ToString() const
Definition: Dtype.h:64
int64_t ByteSize() const
Definition: Dtype.h:58
Definition: Indexer.h:261
OPEN3D_HOST_DEVICE char * GetOutputPtr(int64_t workload_idx) const
Definition: Indexer.h:437
int64_t NumWorkloads() const
Definition: Indexer.cpp:406
OPEN3D_HOST_DEVICE char * GetInputPtr(int64_t input_idx, int64_t workload_idx) const
Definition: Indexer.h:405
Definition: SizeVector.h:69
size_t size() const
Definition: SmallVector.h:119
Dtype GetDtype() const
Definition: Tensor.h:1163
const char const char value recording_handle imu_sample recording_handle uint8_t size_t data_size k4a_record_configuration_t config target_format k4a_capture_t capture_handle k4a_imu_sample_t imu_sample playback_handle k4a_logging_message_cb_t void min_level device_handle k4a_imu_sample_t timeout_in_ms capture_handle capture_handle capture_handle image_handle temperature_c k4a_image_t image_handle uint8_t image_handle image_handle image_handle image_handle image_handle timestamp_usec white_balance image_handle k4a_device_configuration_t config device_handle char size_t serial_number_size bool int32_t int32_t int32_t int32_t k4a_color_control_mode_t default_mode mode
Definition: K4aPlugin.cpp:678
Definition: PinholeCameraIntrinsic.cpp:16