27 #include <torch/custom_class.h> 28 #include <torch/script.h> 44 : _values(values), _row_splits(row_splits) {}
53 c10::intrusive_ptr<RaggedTensor>
FromRowSplits(torch::Tensor values,
54 torch::Tensor row_splits,
55 bool validate =
true)
const;
71 torch::Tensor
GetItem(
int key)
const;
79 c10::intrusive_ptr<RaggedTensor>
Clone()
const;
81 c10::intrusive_ptr<RaggedTensor>
Concat(
82 c10::intrusive_ptr<RaggedTensor> r_tensor, int64_t axis)
const;
85 c10::intrusive_ptr<RaggedTensor>
Add(T value)
const {
90 c10::intrusive_ptr<RaggedTensor>
Add_(T value) {
92 return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
96 c10::intrusive_ptr<RaggedTensor>
Sub(T value)
const {
100 template <
typename T>
101 c10::intrusive_ptr<RaggedTensor>
Sub_(T value) {
103 return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
106 template <
typename T>
107 c10::intrusive_ptr<RaggedTensor>
Mul(T value)
const {
111 template <
typename T>
112 c10::intrusive_ptr<RaggedTensor>
Mul_(T value) {
114 return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
117 template <
typename T>
118 c10::intrusive_ptr<RaggedTensor>
Div(T value)
const {
122 template <
typename T>
123 c10::intrusive_ptr<RaggedTensor>
Div_(T value) {
125 return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
128 template <
typename T>
129 c10::intrusive_ptr<RaggedTensor>
FloorDiv(T value)
const {
130 return FromRowSplits(_values.floor_divide(value), _row_splits,
false);
133 template <
typename T>
135 _values.floor_divide_(value);
136 return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
140 torch::Tensor _values, _row_splits;
143 static auto registry =
144 torch::class_<RaggedTensor>(
"my_classes",
"RaggedTensor")
145 .def(torch::init<>())
150 [](
const c10::intrusive_ptr<RaggedTensor>&
self) {
151 return self->ToString();
154 [](
const c10::intrusive_ptr<RaggedTensor>&
self) {
155 return self->ToString();
158 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
159 int64_t key) {
return self->GetItem(key); })
165 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
166 torch::Tensor value) {
return self->Add(value); })
168 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
169 torch::Tensor value) {
return self->Add_(value); })
171 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
172 torch::Tensor value) {
return self->Add(value); })
174 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
175 torch::Tensor value) {
return self->Add_(value); })
178 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
179 torch::Tensor value) {
return self->Sub(value); })
181 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
182 torch::Tensor value) {
return self->Sub_(value); })
184 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
185 torch::Tensor value) {
return self->Sub(value); })
187 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
188 torch::Tensor value) {
return self->Sub_(value); })
191 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
192 torch::Tensor value) {
return self->Mul(value); })
194 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
195 torch::Tensor value) {
return self->Mul_(value); })
197 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
198 torch::Tensor value) {
return self->Mul(value); })
200 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
201 torch::Tensor value) {
return self->Mul_(value); })
204 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
205 torch::Tensor value) {
return self->Div(value); })
207 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
208 torch::Tensor value) {
return self->Div_(value); })
210 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
211 torch::Tensor value) {
return self->Div(value); })
213 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
214 torch::Tensor value) {
return self->Div_(value); })
216 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
217 torch::Tensor value) {
return self->FloorDiv(value); })
218 .def(
"__ifloordiv__",
219 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
220 torch::Tensor value) {
221 return self->FloorDiv_(value);
c10::intrusive_ptr< RaggedTensor > Clone() const
Copy Tensor to the same device.
Definition: RaggedTensor.cpp:75
std::string ToString() const
Returns string representation.
Definition: RaggedTensor.cpp:61
c10::intrusive_ptr< RaggedTensor > Sub(T value) const
Definition: RaggedTensor.h:96
c10::intrusive_ptr< RaggedTensor > Mul_(T value)
Definition: RaggedTensor.h:112
c10::intrusive_ptr< RaggedTensor > Add(T value) const
Definition: RaggedTensor.h:85
c10::intrusive_ptr< RaggedTensor > Div(T value) const
Definition: RaggedTensor.h:118
c10::intrusive_ptr< RaggedTensor > Add_(T value)
Definition: RaggedTensor.h:90
int64_t Len() const
Definition: RaggedTensor.cpp:73
c10::intrusive_ptr< RaggedTensor > Concat(c10::intrusive_ptr< RaggedTensor > r_tensor, int64_t axis) const
Definition: RaggedTensor.cpp:79
Definition: RaggedTensor.h:38
c10::intrusive_ptr< RaggedTensor > FloorDiv_(T value)
Definition: RaggedTensor.h:134
RaggedTensor(torch::Tensor values, torch::Tensor row_splits)
Constructor for creating RaggedTensor with values and row_splits.
Definition: RaggedTensor.h:43
c10::intrusive_ptr< RaggedTensor > Div_(T value)
Definition: RaggedTensor.h:123
c10::intrusive_ptr< RaggedTensor > Sub_(T value)
Definition: RaggedTensor.h:101
torch::Tensor GetItem(int key) const
Definition: RaggedTensor.cpp:68
torch::Tensor GetValues() const
Returns _values tensor.
Definition: RaggedTensor.cpp:58
c10::intrusive_ptr< RaggedTensor > FloorDiv(T value) const
Definition: RaggedTensor.h:129
c10::intrusive_ptr< RaggedTensor > FromRowSplits(torch::Tensor values, torch::Tensor row_splits, bool validate=true) const
Definition: RaggedTensor.cpp:31
c10::intrusive_ptr< RaggedTensor > Mul(T value) const
Definition: RaggedTensor.h:107
RaggedTensor()
Definition: RaggedTensor.h:40
torch::Tensor GetRowSplits() const
Returns _row_splits tensor.
Definition: RaggedTensor.cpp:59