Loading [MathJax]/extensions/TeX/AMSsymbols.js
Open3D (C++ API)  0.14.1
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
TensorList.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // The MIT License (MIT)
5 //
6 // Copyright (c) 2018-2021 www.open3d.org
7 //
8 // Permission is hereby granted, free of charge, to any person obtaining a copy
9 // of this software and associated documentation files (the "Software"), to deal
10 // in the Software without restriction, including without limitation the rights
11 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 // copies of the Software, and to permit persons to whom the Software is
13 // furnished to do so, subject to the following conditions:
14 //
15 // The above copyright notice and this permission notice shall be included in
16 // all copies or substantial portions of the Software.
17 //
18 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24 // IN THE SOFTWARE.
25 // ----------------------------------------------------------------------------
26 
27 #pragma once
28 
29 #include <cstddef>
30 #include <memory>
31 #include <string>
32 
33 #include "open3d/core/Blob.h"
34 #include "open3d/core/Device.h"
35 #include "open3d/core/Dtype.h"
36 #include "open3d/core/ShapeUtil.h"
37 #include "open3d/core/SizeVector.h"
38 #include "open3d/core/Tensor.h"
39 #include "open3d/core/TensorKey.h"
40 
// NOTE(review): this is a Doxygen source listing; every code line carries the
// listing's own line number, and gaps in that numbering mark lines that are
// not present in this extract. Comments below describe only what is visible
// here or stated in the symbol index that follows the listing.
41 namespace open3d {
42 namespace core {
43 
// TensorList: a list of tensors that all share one element shape. Per the
// member briefs in the symbol index, element_shape_ is "the shape for each
// element tensor in the tensorlist" and internal_tensor_ is "the internal
// tensor for data storage".
58 class TensorList {
59 public:
// NOTE(review): listing line 61 is absent; the symbol index records a default
// constructor `TensorList()` there ("Useful to support operator[] in a map").
62 
// Constructs a list with the given element shape, dtype and device
// (device defaults to "CPU:0"); size_ starts at 0.
// NOTE(review): listing lines 74-75 are absent, so the initializers that
// consume `dtype` and `device` below are only partially visible.
69  TensorList(const SizeVector& element_shape,
70  Dtype dtype,
71  const Device& device = Device("CPU:0"))
72  : element_shape_(element_shape),
73  size_(0),
76  dtype,
77  device) {}
78 
// Constructs from a vector of tensors by delegating to the iterator-range
// constructor. Not marked explicit, so a std::vector<Tensor> converts
// implicitly.
84  TensorList(const std::vector<Tensor>& tensors)
85  : TensorList(tensors.begin(), tensors.end()) {}
86 
// Constructs a list whose size_ is initialized to `size`, with the given
// element shape, dtype and device (device defaults to "CPU:0").
// NOTE(review): listing lines 100-101 are absent (remaining initializers).
94  TensorList(int64_t size,
95  const SizeVector& element_shape,
96  Dtype dtype,
97  const Device& device = Device("CPU:0"))
98  : element_shape_(element_shape),
99  size_(size),
102  dtype,
103  device) {}
104 
// Constructs from a braced list of tensors by delegating to the
// iterator-range constructor.
110  TensorList(const std::initializer_list<Tensor>& tensors)
111  : TensorList(tensors.begin(), tensors.end()) {}
112 
// Constructs from the range [begin, end) of tensors.
// Visible steps: count the elements with std::distance; treat an empty range
// as an error; take shape, dtype and device from the first tensor and check
// every tensor in the range against them; then assign each tensor into
// internal_tensor_[i] in order.
// NOTE(review): listing lines 122, 134, 144 and 154 (the calls that receive
// the message strings below; presumably LogError, which the symbol index
// lists) and lines 128 and 161-162 (reserved_size_ setup and the start of the
// internal tensor construction) are absent from this extract.
118  template <class InputIterator>
119  TensorList(InputIterator begin, InputIterator end) {
120  int64_t size = std::distance(begin, end);
121  if (size == 0) {
123  "Empty input tensors cannot initialize a tensorlist.");
124  }
125 
126  // Set size_ and reserved_size_.
127  size_ = size;
129 
130  // Check shape consistency and set element_shape_.
131  element_shape_ = begin->GetShape();
132  std::for_each(begin, end, [&](const Tensor& tensor) -> void {
133  if (tensor.GetShape() != element_shape_) {
135  "Tensors must have the same shape {}, but got {}.",
136  element_shape_, tensor.GetShape());
137  }
138  });
139 
140  // Check dtype consistency.
141  Dtype dtype = begin->GetDtype();
142  std::for_each(begin, end, [&](const Tensor& tensor) -> void {
143  if (tensor.GetDtype() != dtype) {
145  "Tensors must have the same dtype {}, but got {}.",
146  dtype.ToString(), tensor.GetDtype().ToString());
147  }
148  });
149 
150  // Check device consistency.
151  Device device = begin->GetDevice();
152  std::for_each(begin, end, [&](const Tensor& tensor) -> void {
153  if (tensor.GetDevice() != device) {
155  "Tensors must have the same device {}, but got {}.",
156  device.ToString(), tensor.GetDevice().ToString());
157  }
158  });
159 
160  // Construct internal tensor.
163  dtype, device);
164  size_t i = 0;
165  for (auto iter = begin; iter != end; ++iter, ++i) {
166  internal_tensor_[i] = *iter;
167  }
168  }
169 
// Static factory building a TensorList from a single tensor; `inplace`
// defaults to false. Defined in TensorList.cpp; the effect of `inplace` is
// not visible in this header.
183  static TensorList FromTensor(const Tensor& tensor, bool inplace = false);
184 
// Compiler-generated copy constructor.
187  TensorList(const TensorList& other) = default;
188 
// Compiler-generated move constructor.
191  TensorList(TensorList&& other) = default;
192 
// Compiler-generated copy assignment; the trailing `&` ref-qualifier restricts
// it to lvalue targets.
195  TensorList& operator=(const TensorList& other) & = default;
196 
// Compiler-generated move assignment; likewise restricted to lvalue targets.
199  TensorList& operator=(TensorList&& other) & = default;
200 
// Declared here, defined in TensorList.cpp.
204  void CopyFrom(const TensorList& other);
205 
// Declared here, defined in TensorList.cpp.
208  TensorList Clone() const;
209 
// Symbol-index brief: "Return the reference of the contained valid tensors
// with shared memory."
211  Tensor AsTensor() const;
212 
// Declared here, defined in TensorList.cpp.
217  void Resize(int64_t new_size);
218 
// Declared here, defined in TensorList.cpp.
225  void PushBack(const Tensor& tensor);
226 
// Declared here, defined in TensorList.cpp.
231  void Extend(const TensorList& other);
232 
// Static; declared here, defined in TensorList.cpp.
236  static TensorList Concatenate(const TensorList& a, const TensorList& b);
237 
// Concatenates two tensorlists: returns Concatenate(*this, other) and leaves
// *this unmodified (const member).
239  TensorList operator+(const TensorList& other) const {
240  return Concatenate(*this, other);
241  }
242 
// NOTE(review): listing line 245 is absent; the symbol index records
// `TensorList & operator+=(const TensorList &other)` there. The visible body
// calls Extend(other) and returns *this.
246  Extend(other);
247  return *this;
248  }
249 
// Returns a Tensor by value for `index`. Defined in TensorList.cpp; bounds
// handling is not visible in this header.
252  Tensor operator[](int64_t index) const;
253 
// Declared here, defined in TensorList.cpp.
256  void Clear();
257 
258  std::string ToString() const;
259 
// NOTE(review): listing line 260 is absent; the symbol index records
// `SizeVector GetElementShape() const` there.
261 
// Checks that element_shape_ equals expected_element_shape; on mismatch the
// message below is reported.
// NOTE(review): listing line 264 (the call that receives the message) is
// absent from this extract.
262  void AssertElementShape(const SizeVector& expected_element_shape) const {
263  if (expected_element_shape != element_shape_) {
265  "TensorList has element shape {}, but is expected to have "
266  "element shape {}.",
267  element_shape_, expected_element_shape);
268  }
269  }
270 
// Checks that GetDevice() equals expected_device; on mismatch the message
// below is reported.
// NOTE(review): listing line 273 (the call that receives the message) is
// absent from this extract.
271  void AssertDevice(const Device& expected_device) const {
272  if (GetDevice() != expected_device) {
274  "TensorList has device {}, but is expected to be {}.",
275  GetDevice().ToString(), expected_device.ToString());
276  }
277  }
278 
// NOTE(review): listing line 279 is absent; the symbol index records
// `Device GetDevice() const` there.
280 
// Returns the dtype of internal_tensor_.
281  Dtype GetDtype() const { return internal_tensor_.GetDtype(); }
282 
// Returns size_.
283  int64_t GetSize() const { return size_; }
284 
// Returns reserved_size_.
285  int64_t GetReservedSize() const { return reserved_size_; }
286 
// Returns a const reference to internal_tensor_ (no copy).
287  const Tensor& GetInternalTensor() const { return internal_tensor_; }
288 
// Returns is_resizable_.
289  bool IsResizable() const { return is_resizable_; }
290 
291 protected:
// Fully specified constructor: initializes every data member directly from
// its arguments.
// NOTE(review): element_shape is taken as `const SizeVector` by value rather
// than by const reference, unlike the public constructors.
293  TensorList(const SizeVector element_shape,
294  int64_t size,
295  int64_t reserved_size,
296  const Tensor& internal_tensor,
297  bool is_resizable)
298  : element_shape_(element_shape),
299  size_(size),
300  reserved_size_(reserved_size),
301  internal_tensor_(internal_tensor),
302  is_resizable_(is_resizable) {}
303 
// Declared here, defined in TensorList.cpp.
312  void ResizeWithExpand(int64_t new_size);
313 
// Static; declared here, defined in TensorList.cpp.
316  static int64_t ComputeReserveSize(int64_t size);
317 
318 protected:
// NOTE(review): listing line 320 is absent; the symbol index records
// `SizeVector element_shape_` there ("The shape for each element tensor in
// the tensorlist").
321 
// Value returned by GetSize(); defaults to 0.
325  int64_t size_ = 0;
326 
// Value returned by GetReservedSize(); defaults to 0.
336  int64_t reserved_size_ = 0;
337 
// NOTE(review): listing line 339 is absent; the symbol index records
// `Tensor internal_tensor_` there ("The internal tensor for data storage").
340 
// Value returned by IsResizable(); defaults to true.
344  bool is_resizable_ = true;
345 };
346 } // namespace core
347 } // namespace open3d
static TensorList FromTensor(const Tensor &tensor, bool inplace=false)
Definition: TensorList.cpp:47
SizeVector GetElementShape() const
Definition: TensorList.h:260
TensorList(int64_t size, const SizeVector &element_shape, Dtype dtype, const Device &device=Device("CPU:0"))
Definition: TensorList.h:94
bool is_resizable_
Definition: TensorList.h:344
TensorList & operator=(const TensorList &other) &=default
int64_t reserved_size_
Definition: TensorList.h:336
int64_t GetSize() const
Definition: TensorList.h:283
const Tensor & GetInternalTensor() const
Definition: TensorList.h:287
Definition: Dtype.h:39
TensorList Clone() const
Definition: TensorList.cpp:75
TensorList()
Useful to support operator[] in a map.
Definition: TensorList.h:61
Tensor AsTensor() const
Return the reference of the contained valid tensors with shared memory.
Definition: TensorList.cpp:89
void CopyFrom(const TensorList &other)
Definition: TensorList.cpp:81
Dtype GetDtype() const
Definition: TensorList.h:281
const Dtype Float32
Definition: Dtype.cpp:61
TensorList(InputIterator begin, InputIterator end)
Definition: TensorList.h:119
std::string ToString() const
Definition: Dtype.h:83
Definition: TensorList.h:58
Device GetDevice() const
Definition: Tensor.cpp:1365
Definition: SizeVector.h:79
SizeVector Concat(const SizeVector &l_shape, const SizeVector &r_shape)
Concatenate two shapes.
Definition: ShapeUtil.cpp:218
TensorList(const SizeVector element_shape, int64_t size, int64_t reserved_size, const Tensor &internal_tensor, bool is_resizable)
Fully specified constructor.
Definition: TensorList.h:293
Dtype GetDtype() const
Definition: Tensor.h:1094
TensorList(const SizeVector &element_shape, Dtype dtype, const Device &device=Device("CPU:0"))
Definition: TensorList.h:69
static TensorList Concatenate(const TensorList &a, const TensorList &b)
Definition: TensorList.cpp:140
static int64_t ComputeReserveSize(int64_t size)
Definition: TensorList.cpp:175
Device GetDevice() const
Definition: TensorList.h:279
TensorList(const std::vector< Tensor > &tensors)
Definition: TensorList.h:84
TensorList operator+(const TensorList &other) const
Concatenate two tensorlists.
Definition: TensorList.h:239
Definition: Device.h:39
Tensor internal_tensor_
The internal tensor for data storage.
Definition: TensorList.h:339
TensorList(const std::initializer_list< Tensor > &tensors)
Definition: TensorList.h:110
bool IsResizable() const
Definition: TensorList.h:289
TensorList & operator+=(const TensorList &other)
Definition: TensorList.h:245
void AssertDevice(const Device &expected_device) const
Definition: TensorList.h:271
void AssertElementShape(const SizeVector &expected_element_shape) const
Definition: TensorList.h:262
SizeVector GetShape() const
Definition: Tensor.h:1057
void ResizeWithExpand(int64_t new_size)
Definition: TensorList.cpp:159
Definition: PinholeCameraIntrinsic.cpp:35
Definition: Tensor.h:50
void PushBack(const Tensor &tensor)
Definition: TensorList.cpp:102
SizeVector element_shape_
The shape for each element tensor in the tensorlist.
Definition: TensorList.h:320
int64_t size_
Definition: TensorList.h:325
void Extend(const TensorList &other)
Definition: TensorList.cpp:113
int64_t GetReservedSize() const
Definition: TensorList.h:285
void Clear()
Definition: TensorList.cpp:153
std::string ToString() const
Definition: Device.h:75
int size
Definition: FilePCD.cpp:59
std::string ToString() const
Definition: TensorList.cpp:204
Tensor operator[](int64_t index) const
Definition: TensorList.cpp:147
#define LogError(...)
Definition: Logging.h:72
void Resize(int64_t new_size)
Definition: TensorList.cpp:93