Skip to content

Commit 17a35ca

Browse files
author
Niklas Gustafsson
authored
Merge pull request #926 from NiklasGustafsson/missing
Missing methods implemented
2 parents 64b6999 + 35dbb1b commit 17a35ca

39 files changed

Lines changed: 3394 additions & 2040 deletions

RELEASENOTES.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ Adding allow_tf32<br/>
1111
Adding overloads of Module.save() and Module.load() taking a 'Stream' argument.<br/>
1212
Adding torch.softmax() and Tensor.softmax() as aliases for torch.special.softmax()<br/>
1313
Adding torch.from_file()<br/>
14+
Adding a number of missing pointwise Tensor operations.<br/>
15+
Adding select_scatter, diagonal_scatter, and slice_scatter<br/>
16+
Adding torch.set_printoptions<br/>
17+
Adding torch.cartesian_prod, combinations, and cov.<br/>
18+
Adding torch.cdist, diag_embed, rot90, triu_indices, tril_indices<br/>
1419

1520
__Fixed Bugs__:
1621

src/Native/LibTorchSharp/THSLinearAlgebra.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ Tensor THSLinalg_det(const Tensor tensor)
4747
CATCH_TENSOR(torch::linalg::det(*tensor));
4848
}
4949

50+
Tensor THSTensor_logdet(const Tensor tensor)
51+
{
52+
CATCH_TENSOR(torch::logdet(*tensor));
53+
}
54+
5055
Tensor THSLinalg_slogdet(const Tensor tensor, Tensor* logabsdet)
5156
{
5257
std::tuple<at::Tensor, at::Tensor> res;
@@ -63,6 +68,13 @@ Tensor THSLinalg_eig(const Tensor tensor, Tensor* eigenvectors)
6368
return ResultTensor(std::get<0>(res));
6469
}
6570

71+
Tensor THSTensor_geqrf(const Tensor tensor, Tensor* tau)
72+
{
73+
std::tuple<at::Tensor, at::Tensor> res;
74+
CATCH(res = torch::geqrf(*tensor);)
75+
*tau = ResultTensor(std::get<1>(res));
76+
return ResultTensor(std::get<0>(res));
77+
}
6678

6779
#if 0
6880
Tensor THSTensor_eig(const Tensor tensor, bool vectors, Tensor* eigenvectors)
@@ -98,6 +110,11 @@ Tensor THSLinalg_eigvalsh(const Tensor tensor, const char UPLO)
98110
CATCH_TENSOR(torch::linalg::eigvalsh(*tensor, _uplo));
99111
}
100112

113+
Tensor THSLinalg_householder_product(const Tensor tensor, const Tensor tau)
114+
{
115+
CATCH_TENSOR(torch::linalg::householder_product(*tensor, *tau));
116+
}
117+
101118
Tensor THSLinalg_inv(const Tensor tensor)
102119
{
103120
CATCH_TENSOR(torch::linalg::inv(*tensor));

src/Native/LibTorchSharp/THSTensor.cpp

Lines changed: 95 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ Tensor THSTensor_any_along_dimension(const Tensor tensor, const int64_t dim, boo
6666
{
6767
CATCH_TENSOR(tensor->any(dim, keepdim));
6868
}
69+
70+
Tensor THSTensor_adjoint(const Tensor tensor)
71+
{
72+
CATCH_TENSOR(tensor->adjoint());
73+
}
74+
6975
Tensor THSTensor_argmax(const Tensor tensor)
7076
{
7177
CATCH_TENSOR(tensor->argmax());
@@ -86,6 +92,11 @@ Tensor THSTensor_argmin_along_dimension(const Tensor tensor, const int64_t dim,
8692
CATCH_TENSOR(tensor->argmin(dim, keepdim));
8793
}
8894

95+
Tensor THSTensor_argwhere(const Tensor tensor)
96+
{
97+
CATCH_TENSOR(tensor->argwhere());
98+
}
99+
89100
Tensor THSTensor_atleast_1d(const Tensor tensor)
90101
{
91102
CATCH_TENSOR(torch::atleast_1d(*tensor));
@@ -159,6 +170,11 @@ void THSTensor_vector_to_parameters(const Tensor vec, const Tensor* tensors, con
159170
CATCH(torch::nn::utils::vector_to_parameters(*vec, toTensors<at::Tensor>((torch::Tensor**)tensors, length)););
160171
}
161172

173+
Tensor THSTensor_cartesian_prod(const Tensor* tensors, const int length)
174+
{
175+
CATCH_TENSOR(torch::cartesian_prod(toTensors<at::Tensor>((torch::Tensor**)tensors, length)));
176+
}
177+
162178
double THSTensor_clip_grad_norm_(const Tensor* tensors, const int length, const double max_norm, const double norm_type)
163179
{
164180
double res = 0.0;
@@ -258,6 +274,11 @@ Tensor THSTensor_clone(const Tensor tensor)
258274
CATCH_TENSOR(tensor->clone());
259275
}
260276

277+
Tensor THSTensor_combinations(const Tensor tensor, const int r, const bool with_replacement)
278+
{
279+
CATCH_TENSOR(torch::combinations(*tensor, r, with_replacement));
280+
}
281+
261282
Tensor THSTensor_copy_(const Tensor input, const Tensor other, const bool non_blocking)
262283
{
263284
CATCH_TENSOR(input->copy_(*other, non_blocking));
@@ -285,6 +306,13 @@ int THSTensor_is_contiguous(const Tensor tensor)
285306
return result;
286307
}
287308

309+
int64_t THSTensor_is_nonzero(const Tensor tensor)
310+
{
311+
bool result = false;
312+
CATCH(result = tensor->is_nonzero();)
313+
return result;
314+
}
315+
288316
Tensor THSTensor_copysign(const Tensor input, const Tensor other)
289317
{
290318
CATCH_TENSOR(input->copysign(*other));
@@ -295,13 +323,6 @@ Tensor THSTensor_corrcoef(const Tensor tensor)
295323
CATCH_TENSOR(tensor->corrcoef());
296324
}
297325

298-
Tensor THSTensor_cov(const Tensor input, int64_t correction, const Tensor fweights, const Tensor aweights)
299-
{
300-
c10::optional<at::Tensor> fw = (fweights == nullptr) ? c10::optional<at::Tensor>() : *fweights;
301-
c10::optional<at::Tensor> aw = (aweights == nullptr) ? c10::optional<at::Tensor>() : *aweights;
302-
CATCH_TENSOR(input->cov(correction, fw, aw));
303-
}
304-
305326
bool THSTensor_is_cpu(const Tensor tensor)
306327
{
307328
bool result = true;
@@ -402,6 +423,11 @@ int THSTensor_device_type(const Tensor tensor)
402423
return (int)device.type();
403424
}
404425

426+
Tensor THSTensor_diag_embed(const Tensor tensor, const int64_t offset, const int64_t dim1, const int64_t dim2)
427+
{
428+
CATCH_TENSOR(tensor->diag_embed(offset, dim1, dim2));
429+
}
430+
405431
Tensor THSTensor_diff(const Tensor tensor, const int64_t n, const int64_t dim, const Tensor prepend, const Tensor append)
406432
{
407433
c10::optional<at::Tensor> prep = prepend != nullptr ? *prepend : c10::optional<at::Tensor>(c10::nullopt);
@@ -473,6 +499,11 @@ Tensor THSTensor_repeat_interleave_int64(const Tensor tensor, const int64_t repe
473499
CATCH_TENSOR(tensor->repeat_interleave(repeats, _dim, _output_size));
474500
}
475501

502+
int THSTensor_result_type(const Tensor left, const Tensor right)
503+
{
504+
CATCH_RETURN_RES(int, -1, res = (int)torch::result_type(*left, *right));
505+
}
506+
476507
Tensor THSTensor_movedim(const Tensor tensor, const int64_t* src, const int src_len, const int64_t* dst, const int dst_len)
477508
{
478509
CATCH_TENSOR(tensor->movedim(at::ArrayRef<int64_t>(src, src_len), at::ArrayRef<int64_t>(dst, dst_len)));
@@ -1070,6 +1101,11 @@ Tensor THSTensor_outer(const Tensor left, const Tensor right)
10701101
CATCH_TENSOR(left->outer(*right));
10711102
}
10721103

1104+
Tensor THSTensor_ormqr(const Tensor input, const Tensor tau, const Tensor other, bool left, bool transpose)
1105+
{
1106+
CATCH_TENSOR(torch::ormqr(*input, *tau, *other, left, transpose));
1107+
}
1108+
10731109
Tensor THSTensor_mH(const Tensor tensor)
10741110
{
10751111
CATCH_TENSOR(tensor->mH());
@@ -1161,6 +1197,11 @@ Tensor THSTensor_reshape(const Tensor tensor, const int64_t* shape, const int le
11611197
CATCH_TENSOR(tensor->reshape(at::ArrayRef<int64_t>(shape, length)));
11621198
}
11631199

1200+
Tensor THSTensor_rot90(const Tensor tensor, const int64_t k, const int64_t dim1, const int64_t dim2)
1201+
{
1202+
CATCH_TENSOR(tensor->rot90(k, { dim1, dim2 }));
1203+
}
1204+
11641205
Tensor THSTensor_roll(const Tensor tensor, const int64_t* shifts, const int shLength, const int64_t* dims, const int dimLength)
11651206
{
11661207
CATCH_TENSOR(
@@ -1194,6 +1235,36 @@ Tensor THSTensor_scatter_(
11941235
CATCH_TENSOR(tensor->scatter_(dim, *index, *source));
11951236
}
11961237

1238+
Tensor THSTensor_select_scatter(
1239+
const Tensor tensor,
1240+
const Tensor source,
1241+
const int64_t dim,
1242+
const int64_t index)
1243+
{
1244+
CATCH_TENSOR(torch::select_scatter(*tensor, *source, dim, index));
1245+
}
1246+
1247+
Tensor THSTensor_diagonal_scatter(
1248+
const Tensor tensor,
1249+
const Tensor source,
1250+
const int64_t offset,
1251+
const int64_t dim1,
1252+
const int64_t dim2)
1253+
{
1254+
CATCH_TENSOR(torch::diagonal_scatter(*tensor, *source, offset, dim1, dim2));
1255+
}
1256+
1257+
Tensor THSTensor_slice_scatter(
1258+
const Tensor tensor,
1259+
const Tensor source,
1260+
const int64_t dim,
1261+
const int64_t *start,
1262+
const int64_t *end,
1263+
const int64_t step)
1264+
{
1265+
CATCH_TENSOR(torch::slice_scatter(*tensor, *source, dim, start == nullptr ? c10::optional<int64_t>() : c10::optional<int64_t>(*start), end == nullptr ? c10::optional<int64_t>() : c10::optional<int64_t>(*end), step));
1266+
}
1267+
11971268
Tensor THSTensor_scatter_add(
11981269
const Tensor tensor,
11991270
const int64_t dim,
@@ -1762,6 +1833,23 @@ Tensor THSTensor_tril(const Tensor tensor, const int64_t diagonal)
17621833
CATCH_TENSOR(tensor->tril(diagonal));
17631834
}
17641835

1836+
Tensor THSTensor_tril_indices(const int64_t row, const int64_t col, const int64_t offset, const int8_t scalar_type, const int device_type, const int device_index)
1837+
{
1838+
auto options = at::TensorOptions()
1839+
.dtype(at::ScalarType(scalar_type))
1840+
.device(c10::Device((c10::DeviceType)device_type, (c10::DeviceIndex)device_index));
1841+
CATCH_TENSOR(torch::tril_indices(row, col, offset, options));
1842+
}
1843+
1844+
Tensor THSTensor_triu_indices(const int64_t row, const int64_t col, const int64_t offset, const int8_t scalar_type, const int device_type, const int device_index)
1845+
{
1846+
auto options = at::TensorOptions()
1847+
.dtype(at::ScalarType(scalar_type))
1848+
.device(c10::Device((c10::DeviceType)device_type, (c10::DeviceIndex)device_index));
1849+
CATCH_TENSOR(torch::triu_indices(row, col, offset, options));
1850+
}
1851+
1852+
17651853
Tensor THSTensor_transpose(const Tensor tensor, const int64_t dim1, const int64_t dim2)
17661854
{
17671855
CATCH_TENSOR(tensor->transpose(dim1, dim2));

0 commit comments

Comments
 (0)