Loading [MathJax]/extensions/TeX/AMSsymbols.js
Open3D (C++ API)  0.14.1
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
RaggedTensor.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // The MIT License (MIT)
5 //
6 // Copyright (c) 2018-2021 www.open3d.org
7 //
8 // Permission is hereby granted, free of charge, to any person obtaining a copy
9 // of this software and associated documentation files (the "Software"), to deal
10 // in the Software without restriction, including without limitation the rights
11 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 // copies of the Software, and to permit persons to whom the Software is
13 // furnished to do so, subject to the following conditions:
14 //
15 // The above copyright notice and this permission notice shall be included in
16 // all copies or substantial portions of the Software.
17 //
18 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24 // IN THE SOFTWARE.
25 // ----------------------------------------------------------------------------
26 
#include <torch/custom_class.h>
#include <torch/script.h>

#include <limits>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
31 
33 
38 struct RaggedTensor : torch::CustomClassHolder {
39 public:
41 
43  RaggedTensor(torch::Tensor values, torch::Tensor row_splits)
44  : _values(values), _row_splits(row_splits) {}
45 
53  c10::intrusive_ptr<RaggedTensor> FromRowSplits(torch::Tensor values,
54  torch::Tensor row_splits,
55  bool validate = true) const;
56 
58  torch::Tensor GetValues() const;
59 
61  torch::Tensor GetRowSplits() const;
62 
64  std::string ToString() const;
65 
71  torch::Tensor GetItem(int key) const;
72 
76  int64_t Len() const;
77 
79  c10::intrusive_ptr<RaggedTensor> Clone() const;
80 
81  c10::intrusive_ptr<RaggedTensor> Concat(
82  c10::intrusive_ptr<RaggedTensor> r_tensor, int64_t axis) const;
83 
84  template <typename T>
85  c10::intrusive_ptr<RaggedTensor> Add(T value) const {
86  return FromRowSplits(_values + value, _row_splits, false);
87  }
88 
89  template <typename T>
90  c10::intrusive_ptr<RaggedTensor> Add_(T value) {
91  _values += value;
92  return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
93  }
94 
95  template <typename T>
96  c10::intrusive_ptr<RaggedTensor> Sub(T value) const {
97  return FromRowSplits(_values - value, _row_splits, false);
98  }
99 
100  template <typename T>
101  c10::intrusive_ptr<RaggedTensor> Sub_(T value) {
102  _values -= value;
103  return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
104  }
105 
106  template <typename T>
107  c10::intrusive_ptr<RaggedTensor> Mul(T value) const {
108  return FromRowSplits(_values * value, _row_splits, false);
109  }
110 
111  template <typename T>
112  c10::intrusive_ptr<RaggedTensor> Mul_(T value) {
113  _values *= value;
114  return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
115  }
116 
117  template <typename T>
118  c10::intrusive_ptr<RaggedTensor> Div(T value) const {
119  return FromRowSplits(_values / value, _row_splits, false);
120  }
121 
122  template <typename T>
123  c10::intrusive_ptr<RaggedTensor> Div_(T value) {
124  _values /= value;
125  return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
126  }
127 
128  template <typename T>
129  c10::intrusive_ptr<RaggedTensor> FloorDiv(T value) const {
130  return FromRowSplits(_values.floor_divide(value), _row_splits, false);
131  }
132 
133  template <typename T>
134  c10::intrusive_ptr<RaggedTensor> FloorDiv_(T value) {
135  _values.floor_divide_(value);
136  return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
137  }
138 
139 private:
140  torch::Tensor _values, _row_splits;
141 };
142 
143 static auto registry =
144  torch::class_<RaggedTensor>("my_classes", "RaggedTensor")
145  .def(torch::init<>())
146  .def("from_row_splits", &RaggedTensor::FromRowSplits)
147  .def("get_values", &RaggedTensor::GetValues)
148  .def("get_row_splits", &RaggedTensor::GetRowSplits)
149  .def("__repr__",
150  [](const c10::intrusive_ptr<RaggedTensor>& self) {
151  return self->ToString();
152  })
153  .def("__str__",
154  [](const c10::intrusive_ptr<RaggedTensor>& self) {
155  return self->ToString();
156  })
157  .def("__getitem__",
158  [](const c10::intrusive_ptr<RaggedTensor>& self,
159  int64_t key) { return self->GetItem(key); })
160  .def("__len__", &RaggedTensor::Len)
161  .def("clone", &RaggedTensor::Clone)
162  .def("concat", &RaggedTensor::Concat)
163 
164  .def("add",
165  [](const c10::intrusive_ptr<RaggedTensor>& self,
166  torch::Tensor value) { return self->Add(value); })
167  .def("add_",
168  [](const c10::intrusive_ptr<RaggedTensor>& self,
169  torch::Tensor value) { return self->Add_(value); })
170  .def("__add__",
171  [](const c10::intrusive_ptr<RaggedTensor>& self,
172  torch::Tensor value) { return self->Add(value); })
173  .def("__iadd__",
174  [](const c10::intrusive_ptr<RaggedTensor>& self,
175  torch::Tensor value) { return self->Add_(value); })
176 
177  .def("sub",
178  [](const c10::intrusive_ptr<RaggedTensor>& self,
179  torch::Tensor value) { return self->Sub(value); })
180  .def("sub_",
181  [](const c10::intrusive_ptr<RaggedTensor>& self,
182  torch::Tensor value) { return self->Sub_(value); })
183  .def("__sub__",
184  [](const c10::intrusive_ptr<RaggedTensor>& self,
185  torch::Tensor value) { return self->Sub(value); })
186  .def("__isub__",
187  [](const c10::intrusive_ptr<RaggedTensor>& self,
188  torch::Tensor value) { return self->Sub_(value); })
189 
190  .def("mul",
191  [](const c10::intrusive_ptr<RaggedTensor>& self,
192  torch::Tensor value) { return self->Mul(value); })
193  .def("mul_",
194  [](const c10::intrusive_ptr<RaggedTensor>& self,
195  torch::Tensor value) { return self->Mul_(value); })
196  .def("__mul__",
197  [](const c10::intrusive_ptr<RaggedTensor>& self,
198  torch::Tensor value) { return self->Mul(value); })
199  .def("__imul__",
200  [](const c10::intrusive_ptr<RaggedTensor>& self,
201  torch::Tensor value) { return self->Mul_(value); })
202 
203  .def("div",
204  [](const c10::intrusive_ptr<RaggedTensor>& self,
205  torch::Tensor value) { return self->Div(value); })
206  .def("div_",
207  [](const c10::intrusive_ptr<RaggedTensor>& self,
208  torch::Tensor value) { return self->Div_(value); })
209  .def("__truediv__",
210  [](const c10::intrusive_ptr<RaggedTensor>& self,
211  torch::Tensor value) { return self->Div(value); })
212  .def("__itruediv__",
213  [](const c10::intrusive_ptr<RaggedTensor>& self,
214  torch::Tensor value) { return self->Div_(value); })
215  .def("__floordiv__",
216  [](const c10::intrusive_ptr<RaggedTensor>& self,
217  torch::Tensor value) { return self->FloorDiv(value); })
218  .def("__ifloordiv__",
219  [](const c10::intrusive_ptr<RaggedTensor>& self,
220  torch::Tensor value) {
221  return self->FloorDiv_(value);
222  });
c10::intrusive_ptr< RaggedTensor > Clone() const
Returns a copy of this RaggedTensor on the same device.
Definition: RaggedTensor.cpp:75
std::string ToString() const
Returns string representation.
Definition: RaggedTensor.cpp:61
c10::intrusive_ptr< RaggedTensor > Sub(T value) const
Definition: RaggedTensor.h:96
c10::intrusive_ptr< RaggedTensor > Mul_(T value)
Definition: RaggedTensor.h:112
c10::intrusive_ptr< RaggedTensor > Add(T value) const
Definition: RaggedTensor.h:85
c10::intrusive_ptr< RaggedTensor > Div(T value) const
Definition: RaggedTensor.h:118
c10::intrusive_ptr< RaggedTensor > Add_(T value)
Definition: RaggedTensor.h:90
int64_t Len() const
Definition: RaggedTensor.cpp:73
c10::intrusive_ptr< RaggedTensor > Concat(c10::intrusive_ptr< RaggedTensor > r_tensor, int64_t axis) const
Definition: RaggedTensor.cpp:79
Definition: RaggedTensor.h:38
c10::intrusive_ptr< RaggedTensor > FloorDiv_(T value)
Definition: RaggedTensor.h:134
RaggedTensor(torch::Tensor values, torch::Tensor row_splits)
Constructor for creating RaggedTensor with values and row_splits.
Definition: RaggedTensor.h:43
c10::intrusive_ptr< RaggedTensor > Div_(T value)
Definition: RaggedTensor.h:123
c10::intrusive_ptr< RaggedTensor > Sub_(T value)
Definition: RaggedTensor.h:101
torch::Tensor GetItem(int key) const
Definition: RaggedTensor.cpp:68
torch::Tensor GetValues() const
Returns _values tensor.
Definition: RaggedTensor.cpp:58
c10::intrusive_ptr< RaggedTensor > FloorDiv(T value) const
Definition: RaggedTensor.h:129
c10::intrusive_ptr< RaggedTensor > FromRowSplits(torch::Tensor values, torch::Tensor row_splits, bool validate=true) const
Definition: RaggedTensor.cpp:31
c10::intrusive_ptr< RaggedTensor > Mul(T value) const
Definition: RaggedTensor.h:107
RaggedTensor()
Definition: RaggedTensor.h:40
torch::Tensor GetRowSplits() const
Returns _row_splits tensor.
Definition: RaggedTensor.cpp:59