forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FunctionsManual.h
1063 lines (1033 loc) · 30 KB
/
FunctionsManual.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#pragma once
// NB: Must be at the top of file to avoid including the deprecated "math.h".
// https://stackoverflow.com/questions/6563810/m-pi-works-with-math-h-but-not-with-cmath-in-visual-studio
#ifdef _MSC_VER
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif
#include <cmath>
#endif
#include <ATen/ATen.h>
#include <torch/csrc/autograd/generated/Functions.h>
namespace torch {
namespace autograd {
namespace generated {
namespace details {
// Declared here and defined outside this header. The name suggests message
// text for an unsupported cuDNN double-backward path — confirm at the
// definition before relying on that.
extern const char* kCudnnDoubleBackwardMsg;
// A simple way to imperatively compute index ranges for slots
// that have been flattened.
//
// Each call to range(n) hands out the next n consecutive slot indices and
// advances an internal cursor; size() reports how many slots have been
// handed out so far. IndexRange is declared outside this header (presumably
// a std::pair<size_t, size_t> — confirm).
struct IndexRangeGenerator {
  // Reserve the next `range_size` slots and return them as
  // {begin, begin + range_size}, where `begin` is the cursor before the call.
  IndexRange range(size_t range_size) {
    const size_t begin = i;
    i += range_size;
    return {begin, i};
  }

  // Total number of slots reserved so far (the end of the last range).
  // const-qualified: querying the count never modifies the generator, and
  // this lets it be called through a const reference.
  size_t size() const {
    return i;
  }

 private:
  // Cursor: index of the first slot not yet handed out.
  size_t i = 0;
};
// Utility declarations: conversions from c10::optional<Tensor> to Tensor,
// loss reduction, slot copying, not-implemented stubs and small reduction
// helpers. Definitions are not in this header.
//
// What the toNonOpt* helpers return for an empty optional is defined in the
// .cpp, not visible here.
Tensor toNonOptFwGrad(const c10::optional<Tensor>& t);
Tensor toNonOptPrimal(const c10::optional<Tensor>& t);
Tensor toNonOptTensor(const c10::optional<Tensor>& t);
// `reduction` is an int64_t code; its encoding is not defined in this header
// (presumably at::Reduction values — confirm).
Tensor apply_loss_reduction(const Tensor& unreduced, int64_t reduction);
bool any_variable_defined(const variable_list& variables);
// copy_range overloads take `out` by non-const reference plus an IndexRange;
// presumably they write `t` into the slots of `out` selected by `range` —
// confirm in the definition.
void copy_range(variable_list& out, IndexRange range, const at::Tensor& t);
void copy_range(
variable_list& out,
IndexRange range,
at::ArrayRef<at::Tensor> t);
at::Tensor copysign_tensor_self_backward(
const Tensor& grad,
const Tensor& self,
const Tensor& result);
// `reason` defaults to an empty string in both the single-tensor and the
// list-returning variant.
at::Tensor not_implemented(const char* name, const char* reason = "");
std::vector<Tensor> not_implemented_list(
const char* name,
const char* reason = "");
// Takes `gradient_result` by value.
at::Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result);
at::Tensor maybe_multiply(const at::Tensor& t, const at::Scalar& s);
int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim);
Tensor restore_reduced_dims(
const Tensor& output,
IntArrayRef dims,
bool keepdim);
Tensor scale_grad_by_count(
const Tensor& grad,
const Tensor& mask,
IntArrayRef dims);
// norm_backward overloads and the dim/keepdim variant of norm_jvp.
at::Tensor norm_backward(
const at::Tensor& grad,
const at::Tensor& self,
const optional<at::Scalar>& p_,
const at::Tensor& norm);
// Unlike the overload above, `grad` and `norm` are taken by value here,
// presumably so the definition can modify or move them — confirm in the .cpp.
at::Tensor norm_backward(
at::Tensor grad,
const at::Tensor& self,
const optional<at::Scalar>& p_,
at::Tensor norm,
at::IntArrayRef dim,
bool keepdim);
// The _p / _t suffixes (used throughout this header's *_jvp helpers)
// presumably denote primal / tangent inputs.
Tensor norm_jvp(
const Tensor& self_p,
const Tensor& self_t,
const optional<Scalar>& p_,
Tensor norm,
IntArrayRef dim,
bool keepdim);
// Overload of norm_jvp that omits the dim/keepdim arguments of the 6-argument
// overload; which reduction it applies is decided in the definition (not
// visible here).
// The parameters were previously declared as (grad, self). They are renamed
// to (self_p, self_t) to match the 6-argument overload and every other *_jvp
// in this header, which take primal/tangent inputs rather than a gradient.
// Parameter names in a declaration do not affect linkage or overload
// resolution, so this is a readability-only change.
Tensor norm_jvp(
    const Tensor& self_p,
    const Tensor& self_t,
    const optional<Scalar>& p_,
    Tensor norm);
// Declarations: nested-from-padded and linear double-backward helpers,
// vector-norm backward/JVP, and the pow/angle backward variants.
Tensor _nested_from_padded_backward(
const Tensor& grad,
const Tensor& input,
const bool do_transform_0213);
// Takes a list of gradients and returns three tensors; what each slot of the
// tuple corresponds to is defined in the .cpp.
std::tuple<Tensor, Tensor, Tensor> linear_double_backward(
const variable_list& grads,
const Tensor& self,
const Tensor& grad_output,
const Tensor& weight);
Tensor linalg_vector_norm_jvp(
const Tensor& self_p,
const Tensor& self_t,
const Scalar& scalar_ord,
Tensor norm,
const at::OptionalIntArrayRef& opt_dim,
bool keepdim);
at::Tensor linalg_vector_norm_backward(
at::Tensor grad,
const at::Tensor& self,
const at::Scalar& ord,
at::Tensor norm,
const at::OptionalIntArrayRef& opt_dim,
bool keepdim);
// pow backward helpers: a Scalar-exponent variant, a tensor-exponent variant
// for `self`, and two pow_backward_exponent overloads (tensor `self` vs.
// Scalar `base`).
at::Tensor pow_backward(
at::Tensor grad,
const at::Tensor& self,
const at::Scalar& exponent_);
at::Tensor pow_backward_self(
at::Tensor grad,
const at::Tensor& self,
const at::Tensor& exponent);
at::Tensor pow_backward_exponent(
at::Tensor grad,
const at::Tensor& self,
const at::Tensor& exponent,
at::Tensor result);
at::Tensor pow_backward_exponent(
at::Tensor grad,
const at::Scalar& base,
const at::Tensor& exponent,
at::Tensor result);
at::Tensor angle_backward(at::Tensor grad, const at::Tensor& self);
// mul/div backward helpers.
// NOTE(review): the templated helpers are only declared here; no definition
// is visible in this header. Presumably the .cpp provides explicit
// instantiations for the `T` types callers use — confirm before adding new
// `T` uses, or the link will fail.
template <typename T>
at::Tensor mul_tensor_backward(Tensor grad, T other, ScalarType self_st);
template <typename T>
at::Tensor div_tensor_self_backward(Tensor grad, T other, ScalarType self_st);
at::Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other);
// Overloads of the two div helpers that additionally take an optional
// `rounding_mode` string.
template <typename T>
at::Tensor div_tensor_self_backward(
Tensor grad,
T other,
ScalarType self_st,
const c10::optional<c10::string_view>& rounding_mode);
at::Tensor div_tensor_other_backward(
Tensor grad,
Tensor self,
Tensor other,
const c10::optional<c10::string_view>& rounding_mode);
// Declarations: mvlgamma, permute, rad2deg/deg2rad, unsqueeze and
// sum/nansum backward helpers, plus list/dim reversal utilities.
at::Tensor mvlgamma_backward(
at::Tensor grad,
const at::Tensor& self,
int64_t p);
at::Tensor permute_backwards(const at::Tensor& grad, at::IntArrayRef fwd_dims);
at::Tensor rad2deg_backward(const at::Tensor& grad);
at::Tensor deg2rad_backward(const at::Tensor& grad);
at::Tensor unsqueeze_multiple(
const at::Tensor& t,
at::OptionalIntArrayRef opt_dim,
size_t n_dims);
// Two sum_backward overloads: one takes optional dims (OptionalIntArrayRef),
// the other a concrete IntArrayRef of dims.
at::Tensor sum_backward(
const at::Tensor& grad,
at::SymIntArrayRef sizes,
at::OptionalIntArrayRef opt_dims,
bool keepdim);
at::Tensor sum_backward(
const at::Tensor& grad,
c10::SymIntArrayRef sizes,
c10::IntArrayRef dims,
bool keepdim);
at::Tensor nansum_backward(
const at::Tensor& grad,
const at::Tensor& self,
at::OptionalIntArrayRef dims,
bool keepdim);
// Both reversal helpers return an owning std::vector, not a view into `list`.
// The top-level `const` on the by-value ArrayRef parameter is not part of the
// function's signature.
std::vector<int64_t> reverse_list(const at::IntArrayRef list);
std::vector<c10::SymInt> reverse_list_symint(const c10::SymIntArrayRef list);
at::Tensor reverse_dim(const at::Tensor& t, int64_t dim);
// Declarations: prod, solve, cumsum, logsumexp and logcumsumexp
// backward/JVP helpers.
at::Tensor prod_safe_zeros_backward(
const at::Tensor& grad,
const at::Tensor& inp,
int64_t dim);
// Two prod_backward overloads: one without and one with dim/keepdim. The
// second takes `grad` and `result` by value.
at::Tensor prod_backward(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& result);
at::Tensor prod_backward(
at::Tensor grad,
const at::Tensor& input,
at::Tensor result,
int64_t dim,
bool keepdim);
at::Tensor solve_jvp(
const Tensor& X,
const Tensor& A,
const Tensor& dA,
const Tensor& dB);
at::Tensor solve_backward_self(
const at::Tensor& grad,
const at::Tensor& self,
const at::Tensor& A);
at::Tensor solve_backward_A(
const at::Tensor& grad,
const at::Tensor& self,
const at::Tensor& A,
const at::Tensor& solution);
at::Tensor cumsum_backward(const at::Tensor& grad, int64_t dim);
at::Tensor logsumexp_backward(
at::Tensor grad,
const at::Tensor& self,
at::Tensor result,
at::IntArrayRef dim,
bool keepdim);
at::Tensor logsumexp_jvp(
const at::Tensor& self_p,
const at::Tensor& self_t,
IntArrayRef dim,
bool keepdim);
at::Tensor logcumsumexp_backward(
at::Tensor grad,
const at::Tensor& self,
at::Tensor result,
int64_t dim);
at::Tensor logcumsumexp_jvp(
const at::Tensor& self_p,
const at::Tensor& self_t,
int64_t dim);
// Declarations: unbind backward, unsqueeze_to overloads, and the
// list-returning backward helpers for cat / stack / block_diag.
at::Tensor unbind_backward(const variable_list& grads, int64_t dim);
// unsqueeze_to overloads: without a dim, with a single int64_t dim, and with
// an IntArrayRef of dims.
at::Tensor unsqueeze_to(const at::Tensor& self, c10::SymIntArrayRef sym_sizes);
at::Tensor unsqueeze_to(
const at::Tensor& self,
int64_t dim,
c10::SymIntArrayRef sym_sizes);
at::Tensor unsqueeze_to(
const at::Tensor& self,
IntArrayRef dim,
c10::SymIntArrayRef sym_sizes);
// All three take per-input `dtypes`. cat takes per-input SymInt sizes,
// block_diag takes per-input int64_t sizes, and stack takes no sizes.
std::vector<at::Tensor> cat_tensors_backward(
const at::Tensor& grad,
const std::vector<std::vector<c10::SymInt>>& sizes,
const std::vector<ScalarType>& dtypes,
int64_t dim);
std::vector<at::Tensor> stack_tensors_backward(
const at::Tensor& grad,
int64_t dim,
const std::vector<ScalarType>& dtypes);
std::vector<at::Tensor> block_diag_backward(
const at::Tensor& grad,
const std::vector<std::vector<int64_t>>& sizes,
const std::vector<ScalarType>& dtypes);
// Declarations: clamp backward with optional-Scalar bounds and with Tensor
// bounds, the variant returning two tensors for min/max, and clamp's JVP.
at::Tensor clamp_backward(
const at::Tensor& grad,
const at::Tensor& self,
const optional<at::Scalar>& min,
const optional<at::Scalar>& max);
at::Tensor clamp_backward(
const at::Tensor& grad,
const at::Tensor& self,
const at::Tensor& min,
const at::Tensor& max);
// The trailing unnamed std::array<bool, 2> is a two-entry flag mask;
// presumably it selects which of the two returned gradients are needed —
// confirm in the definition.
std::tuple<at::Tensor, at::Tensor> clamp_backward_min_max(
const at::Tensor& grad,
const at::Tensor& self,
const at::Tensor& min,
const at::Tensor& max,
const std::array<bool, 2>&);
at::Tensor clamp_jvp(
const Tensor& self_p,
const Tensor& self_t,
const Tensor& min_p,
const Tensor& min_t,
const Tensor& max_p,
const Tensor& max_t);
// Declarations: stride lookup and mm / sparse-matmul backward helpers.
//
// Returns a non-owning SymIntArrayRef view. Presumably it views `input`'s
// stride metadata, so the result should not outlive `input` — confirm in the
// definition.
at::SymIntArrayRef strides_or_error(
const Tensor& input,
c10::string_view const& input_name);
// mm_mat1_backward / mm_mat2_backward take the other operand's sizes, strides
// and layout explicitly rather than the tensor itself.
at::Tensor mm_mat1_backward(
const Tensor& grad,
const Tensor& mat2,
at::SymIntArrayRef mat1_sizes,
at::SymIntArrayRef mat1_strides,
c10::Layout mat1_layout,
const Scalar& alpha);
at::Tensor mm_mat2_backward(
const at::Tensor& grad,
const at::Tensor& mat1,
at::SymIntArrayRef sizes,
at::SymIntArrayRef strides,
c10::Layout layout,
const at::Scalar& alpha);
at::Tensor mm_mat1_sparse_backward(
const at::Tensor& grad,
const at::Tensor& mat1,
const at::Tensor& mat2,
const at::Scalar& alpha);
// `grad_order` is an int64_t selector whose valid values are defined in the
// .cpp (not visible here).
at::Tensor sparse_sparse_matmul_backward(
const at::Tensor& grad,
const at::Tensor& mat1,
const at::Tensor& mat2,
int64_t grad_order);
// Declarations: renorm, repeat, dropout, evenly_distribute, sgn and
// masked_fill backward helpers.
at::Tensor renorm_backward(
const at::Tensor& grad,
const at::Tensor& self,
const at::Scalar& p,
int64_t dim,
const at::Scalar& maxnorm);
at::Tensor repeat_backward(
at::Tensor grad,
at::SymIntArrayRef repeats,
at::SymIntArrayRef input_shape);
// `p1m` is a double; the name presumably means "1 - p" — confirm.
at::Tensor _fused_dropout_backward(
at::Tensor grad,
at::Tensor mask,
double p1m);
at::Tensor infinitely_differentiable_native_dropout_backward(
const at::Tensor& grad,
const at::Tensor& mask,
double scale);
at::Tensor native_dropout_double_backward(
const at::Tensor& ggI,
const at::Tensor& grad,
const at::Tensor& mask,
double scale);
at::Tensor evenly_distribute_backward(
at::Tensor grad,
const at::Tensor& input,
const at::Tensor& value);
// Parameter order is (x, gx, sgn), i.e. input first, not grad first.
Tensor sgn_backward(const Tensor& x, const Tensor& gx, const Tensor& sgn);
Tensor masked_fill_backward(const Tensor& grad, const Tensor& mask);
// Declarations: variance / standard-deviation / mean backward and JVP
// helpers, plus masked_scatter backward.
// NOTE: the leading parameter orders differ. var_jvp takes (self_t, self_p),
// i.e. tangent first, and std_backward takes `result` before `grad`. Call
// sites must match these orders exactly.
at::Tensor var_backward(
at::Tensor grad,
const at::Tensor& self,
at::OptionalIntArrayRef dim,
const c10::optional<c10::Scalar>& correction,
bool keepdim);
at::Tensor var_jvp(
const at::Tensor& self_t,
const at::Tensor& self_p,
const at::Tensor& result,
at::OptionalIntArrayRef dim_opt,
const c10::optional<c10::Scalar>& correction,
bool keepdim);
at::Tensor std_backward(
const at::Tensor& result,
const at::Tensor& grad,
const at::Tensor& self,
at::OptionalIntArrayRef dim,
const c10::optional<c10::Scalar>& correction,
bool keepdim);
// Takes the input shape and element count explicitly (as SymInts) rather than
// the input tensor.
Tensor mean_backward(
const Tensor& grad,
c10::SymIntArrayRef shape,
at::OptionalIntArrayRef opt_dim,
c10::SymInt numel,
bool keepdim);
Tensor var_mean_backward(
const Tensor& gvar,
const Tensor& gmean,
const Tensor& self,
at::OptionalIntArrayRef dim_opt,
const c10::optional<c10::Scalar>& correction,
bool keepdim);
Tensor std_mean_backward(
const Tensor& gstd,
const Tensor& gmean,
const Tensor& self,
const Tensor& std,
at::OptionalIntArrayRef dim_opt,
const c10::optional<c10::Scalar>& correction,
bool keepdim);
at::Tensor masked_scatter_backward(
const at::Tensor& grad,
const at::Tensor& mask,
c10::SymIntArrayRef sizes);
// Declarations: Cholesky, Cholesky-inverse and pseudo-inverse backward/JVP
// helpers.
// NOTE: argument orders differ between siblings. cholesky_backward takes
// (grad, upper, L), while cholesky_jvp takes (input_tangent, L, upper).
at::Tensor cholesky_backward(
const at::Tensor& grad,
bool upper,
const at::Tensor& L);
at::Tensor cholesky_jvp(
const at::Tensor& input_tangent,
const at::Tensor& L,
bool upper);
at::Tensor cholesky_inverse_backward(
at::Tensor grad,
at::Tensor L,
bool upper,
at::Tensor inverse);
at::Tensor cholesky_inverse_jvp(
const at::Tensor& F,
const at::Tensor& dF,
const at::Tensor& X,
bool upper);
Tensor pinv_jvp(const Tensor& A, const Tensor& pinvA, const Tensor& dA);
Tensor pinv_backward(const Tensor& grad, const Tensor& pinvA, const Tensor& A);
// Declarations: split / split_with_sizes backward (dense and nested) and the
// max-pool double-backward helpers.
// The dense variants take the input sizes and TensorOptions explicitly; the
// nested variant takes the input tensor `self` instead.
at::Tensor split_with_sizes_backward(
const std::vector<torch::autograd::Variable>& grads,
c10::SymIntArrayRef split_sizes,
int64_t dim,
c10::SymIntArrayRef sizes,
const at::TensorOptions& options);
at::Tensor _nested_split_with_sizes_backward(
const std::vector<torch::autograd::Variable>& grads,
c10::SymIntArrayRef split_sizes,
int64_t dim,
const Tensor& self);
at::Tensor split_backward(
const std::vector<torch::autograd::Variable>& grads,
c10::SymInt split_size,
int64_t dim,
c10::SymIntArrayRef sizes,
const at::TensorOptions& options);
// `dim` is a plain int here, whereas most helpers in this header use int64_t.
at::Tensor max_pool_double_backward(
const at::Tensor& grad,
const at::Tensor& indices,
int dim);
// Takes no arguments; the name suggests it reports an error rather than
// computing a value — confirm in the definition.
at::Tensor error_for_max_pool2d_double_backward();
// Declarations: glu double-backward, the infinitely_differentiable_*
// activation backward helpers (silu, mish, logit), binary-cross-entropy
// target/with-logits helpers, and log_sigmoid double-backward.
at::Tensor glu_double_backward(
const at::Tensor& grad,
const at::Tensor& grad_output,
const at::Tensor& input,
int64_t dim);
at::Tensor glu_double_backward_grad_output(
const at::Tensor& grad,
const at::Tensor& input,
int64_t dim);
at::Tensor infinitely_differentiable_silu_backward(
const at::Tensor& grad_output,
const at::Tensor& input);
at::Tensor infinitely_differentiable_mish_backward(
const at::Tensor& grad_output,
const at::Tensor& input);
// `eps` is optional and passed by value.
Tensor infinitely_differentiable_logit_backward(
const Tensor& grad,
const Tensor& self,
c10::optional<double> eps);
// BCE helpers: `weight` (and `pos_weight` for the with-logits forms) are
// optional tensors; `reduction` is an int64_t code.
Tensor binary_cross_entropy_target_backward(
const Tensor& grad,
const Tensor& self,
const Tensor& target,
const c10::optional<Tensor>& weight,
int64_t reduction);
Tensor binary_cross_entropy_double_backward_target(
const Tensor& grad,
const Tensor& grad_output,
const Tensor& self,
const Tensor& target,
const c10::optional<Tensor>& weight,
int64_t reduction);
Tensor binary_cross_entropy_with_logits_backward(
const Tensor& grad,
const Tensor& input,
const Tensor& target,
const c10::optional<Tensor>& weight_opt,
const c10::optional<Tensor>& pos_weight_opt,
int64_t reduction);
at::Tensor binary_cross_entropy_with_logits_target_backward(
const at::Tensor& grad_output,
const at::Tensor& self,
const at::Tensor& target,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& pos_weight,
int64_t reduction);
at::Tensor log_sigmoid_double_backward(
const at::Tensor& grad,
const at::Tensor& input);
// Declarations: softmax double-backward and double-backward helpers for the
// binary-cross-entropy, smooth-L1, Huber, MSE, soft-margin and softplus
// losses/activations. `reduction` is an int64_t code throughout.
// `dim` is a plain int here, not int64_t.
at::Tensor softmax_double_backward(
const at::Tensor& grad,
const at::Tensor& grad_output,
int dim,
const at::Tensor& output);
at::Tensor binary_cross_entropy_double_backward(
const at::Tensor& grad_output,
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& target,
const c10::optional<at::Tensor>& weight,
int64_t reduction);
at::Tensor binary_cross_entropy_double_backward_grad_output(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& target,
const c10::optional<at::Tensor>& weight,
int64_t reduction);
at::Tensor smooth_l1_loss_double_backward(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& target,
int64_t reduction,
double beta);
at::Tensor huber_loss_double_backward(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& target,
int64_t reduction,
double delta);
at::Tensor huber_loss_double_backward_grad_output(
const at::Tensor& grad,
const at::Tensor& grad_output,
const at::Tensor& input,
const at::Tensor& target,
int64_t reduction,
double delta);
at::Tensor mse_loss_double_backward(
const at::Tensor& grad,
const at::Tensor& input,
int64_t reduction);
at::Tensor soft_margin_loss_double_backward(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& target,
int64_t reduction);
at::Tensor soft_margin_loss_double_backward_grad_output(
const at::Tensor& grad,
const at::Tensor& grad_output,
const at::Tensor& input,
const at::Tensor& target,
int64_t reduction);
at::Tensor softplus_double_backward(
const at::Tensor& grad,
const at::Tensor& input,
const at::Scalar& beta,
const at::Scalar& threshold);
// Declarations: slogdet JVP/backward (which take an LU factorization and
// pivots), log1p/sinc backward, sparse-constructor and embedding helpers,
// index backward, cuDNN CTC loss backward, and ELU double-backward.
std::tuple<at::Tensor, at::Tensor> slogdet_jvp(
const at::Tensor& LU,
const at::Tensor& pivots,
const at::Tensor& dA,
const at::Tensor& sign,
const bool use_A_T);
at::Tensor slogdet_backward(
const at::Tensor& grad_sign,
const at::Tensor& grad_logabsdet,
const at::Tensor& A,
const at::Tensor& signdet,
const at::Tensor& LU,
const at::Tensor& pivots);
at::Tensor log1p_backward(const at::Tensor& grad, const at::Tensor& self);
at::Tensor sinc_backward(const at::Tensor& grad, const at::Tensor& self);
at::Tensor sparse_constructor_values_backward(
const at::Tensor& sparse_grad_out,
const at::Tensor& indices);
at::Tensor embedding_dense_double_backward_symint(
const at::Tensor& grad,
const at::Tensor& indices,
c10::SymInt padding_idx);
// Takes `zeros_like_self` by value as the first argument.
at::Tensor index_backward(
at::Tensor zeros_like_self,
const torch::List<c10::optional<Tensor>>& indices,
const at::Tensor& grad);
at::Tensor _cudnn_ctc_loss_backward(
const at::Tensor& grad_out,
const at::Tensor& loss,
const at::Tensor& raw_grad,
bool zero_infinity);
// `is_result` indicates whether `self_or_result` holds the forward input or
// the forward result (reading of the parameter names — confirm in the .cpp).
at::Tensor elu_double_backward(
const Tensor& grad,
const Tensor& grad_output,
const Scalar& alpha,
const Scalar& scale,
const Scalar& input_scale,
bool is_result,
const Tensor& self_or_result);
// Declarations: linear-algebra backward/JVP helpers for SVD, slicing,
// eigendecomposition, least squares, triangular solves, trilinear, QR and the
// matrix-exponential differential.
Tensor svd_backward(
const Tensor& gU,
const Tensor& gS,
const Tensor& gVh,
const Tensor& U,
const Tensor& S,
const Tensor& Vh);
std::tuple<Tensor, Tensor, Tensor> linalg_svd_jvp(
const Tensor& dA,
const Tensor& U,
const Tensor& S,
const Tensor& Vh,
const bool full_matrices);
// `start` and `end` are optional SymInts; `input_sizes` is taken by const
// reference to an ArrayRef.
Tensor slice_backward_wrapper(
const at::Tensor& grad,
const c10::SymIntArrayRef& input_sizes,
int64_t dim,
c10::optional<c10::SymInt> start,
c10::optional<c10::SymInt> end,
c10::SymInt step);
std::tuple<Tensor, Tensor> linalg_eig_jvp(
const Tensor& dA,
const Tensor& L,
const Tensor& V,
const bool is_hermitian);
// `symeig_eigenvectors` defaults to true.
Tensor linalg_eig_backward(
const Tensor& gL,
const Tensor& gV,
const Tensor& L,
const Tensor& V,
const bool is_hermitian,
const bool symeig_eigenvectors = true);
Tensor linalg_lstsq_jvp(
const Tensor& A,
const Tensor& B,
const Tensor& dA,
const Tensor& dB);
// `output_mask` is a two-entry flag array, presumably selecting which of the
// two returned gradients to compute — confirm in the definition.
std::tuple<Tensor, Tensor> triangular_solve_backward(
const Tensor& grad_x,
const Tensor& grad_m,
const Tensor& b,
const Tensor& a,
const Tensor& x,
const bool upper,
const bool transpose,
const bool unitriangular,
std::array<bool, 2> output_mask);
Tensor triangular_solve_jvp(
const Tensor& X,
const Tensor& A,
const Tensor& dA,
const Tensor& dB,
const bool upper,
const bool transpose,
const bool unitriangular);
Tensor linalg_solve_triangular_forward_AD(
const Tensor& A_t,
const Tensor& B_t,
const Tensor& A,
const Tensor& X,
const bool upper,
const bool left,
const bool unitriangular);
std::tuple<Tensor, Tensor> linalg_solve_triangular_backward(
const Tensor& grad,
const Tensor& A,
const Tensor& X,
const bool upper,
const bool left,
const bool unitriangular,
std::array<bool, 2> output_mask);
std::tuple<Tensor, Tensor, Tensor> _trilinear_backward(
const Tensor& grad_out,
const Tensor& i1,
const Tensor& i2,
const Tensor& i3,
IntArrayRef expand1,
IntArrayRef expand2,
IntArrayRef expand3,
IntArrayRef sumdim,
std::array<bool, 3> grad_mask);
// `mode` is a string_view whose accepted values are checked in the .cpp.
std::tuple<Tensor, Tensor> linalg_qr_jvp(
const Tensor& dA,
const Tensor& Q,
const Tensor& R,
const c10::string_view mode);
Tensor linalg_qr_backward(
const Tensor& gQ,
const Tensor& gR,
const Tensor& Q,
const Tensor& R,
const c10::string_view mode);
// Parameter order is (self, grad, adjoint), i.e. input first, not grad first.
Tensor linalg_matrix_exp_differential(
const Tensor& self,
const Tensor& grad,
bool adjoint);
// Declarations: batch-norm double-backward, euclidean-distance backward, FFT
// backward helpers, constant_pad_nd backward, cholesky_solve backward/JVP,
// group-norm backward and GELU double-backward.
// Optional running/saved statistics are passed as c10::optional<Tensor>;
// `output_mask` is a three-entry flag array.
std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
const Tensor& input,
const c10::optional<Tensor>& gamma,
const Tensor& ggI,
const Tensor& ggG,
const Tensor& ggB,
const Tensor& gO,
const c10::optional<Tensor>& running_mean,
const c10::optional<Tensor>& running_var,
bool training,
double eps,
const c10::optional<Tensor>& save_mean,
const c10::optional<Tensor>& save_invstd,
std::array<bool, 3> output_mask);
std::tuple<Tensor, Tensor> _euclidean_dist_backward(
const Tensor& grad,
const Tensor& x1,
const Tensor& x2,
const Tensor& res);
// Parameter order is (self, grad, ...), i.e. input first.
Tensor fft_backward(
const Tensor& self,
const Tensor& grad,
int64_t signal_ndim,
bool complex_input,
bool complex_output,
bool inverse,
IntArrayRef checked_signal_sizes,
int64_t normalization,
bool onesided,
IntArrayRef output_sizes);
Tensor fft_r2c_backward(
const Tensor& grad,
at::IntArrayRef dim,
int64_t normalization,
bool onesided,
c10::SymInt last_dim_size);
Tensor fft_c2r_backward(
const Tensor& grad,
IntArrayRef dim,
int64_t normalization);
Tensor constant_pad_nd_backward(const Tensor& grad, c10::SymIntArrayRef pad);
std::tuple<Tensor, Tensor> cholesky_solve_backward(
const Tensor& grad_x,
const Tensor& self,
const Tensor& input2,
const Tensor& result,
const bool upper);
Tensor cholesky_solve_jvp(
const Tensor& X,
const Tensor& U,
const Tensor& dU,
const Tensor& dB,
const bool upper);
// N, C and HxW are SymInts; `group` is int64_t.
std::tuple<Tensor, Tensor, Tensor>
infinitely_differentiable_native_group_norm_backward(
const Tensor& dY,
const Tensor& dmean,
const Tensor& drstd,
const Tensor& X,
const Tensor& mean,
const Tensor& rstd,
const c10::optional<Tensor>& gamma,
c10::SymInt N,
c10::SymInt C,
c10::SymInt HxW,
int64_t group,
double eps,
std::array<bool, 3> grad_input_mask);
Tensor gelu_double_backward(
const Tensor& ggI,
const Tensor& gO,
const Tensor& input,
c10::string_view approximate);
// Declarations: as_strided (and scatter) backward, atan2 backward, amax/amin
// JVP, layer-norm double-backward, Householder-product and ormqr helpers,
// polar backward and i1/i1e backward.
// `grad` is taken by value; `storage_offset_` is an optional SymInt.
Tensor as_strided_backward(
Tensor grad,
const TensorGeometry& input_geometry,
c10::SymIntArrayRef sizes,
c10::SymIntArrayRef strides,
optional<c10::SymInt> storage_offset_);
// `src_geometry` is taken by value, unlike `input_geometry`.
Tensor as_strided_scatter_backward(
Tensor grad,
const TensorGeometry& input_geometry,
TensorGeometry src_geometry,
c10::SymIntArrayRef sizes,
c10::SymIntArrayRef strides,
optional<c10::SymInt> storage_offset);
std::tuple<Tensor, Tensor> atan2_backward(
const Tensor& grad,
const Tensor& self,
const Tensor& other,
std::array<bool, 2> output_mask);
// Single JVP helper shared by amax and amin (per its name).
Tensor amaxamin_jvp(
const Tensor& x,
const Tensor& dx,
const Tensor& result,
IntArrayRef dim,
bool keepdim);
std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
const Tensor& input,
const c10::optional<Tensor>& gamma,
const Tensor& ggI,
const Tensor& ggG,
const Tensor& ggB,
const Tensor& gO,
const Tensor& save_mean,
const Tensor& save_invstd,
c10::SymIntArrayRef normalized_shape,
std::array<bool, 3> output_mask);
// `flip_order` defaults to false.
std::tuple<Tensor, Tensor> householder_product_backward(
const Tensor& grad,
const Tensor& result,
const Tensor& input,
const Tensor& tau,
const bool flip_order = false);
Tensor householder_product_jvp(
const Tensor& dV,
const Tensor& dtau,
const Tensor& prod,
const Tensor& V,
const Tensor& tau);
std::tuple<Tensor, Tensor, Tensor> ormqr_backward(
const Tensor& grad,
const Tensor& result,
const Tensor& self,
const Tensor& tau,
const Tensor& other,
bool left,
bool transpose,
std::array<bool, 3> grad_output_mask);
std::tuple<Tensor, Tensor> polar_backward(
const Tensor& grad,
const Tensor& result);
Tensor i1_backward(
const Tensor& grad,
const Tensor& self,
const Tensor& result);
Tensor i1e_backward(
const Tensor& grad,
const Tensor& self,
const Tensor& result);
// Declarations: LU-based solve/determinant helpers and LU / lu_factor_ex
// backward/JVP. Most take an LU factorization (`LU` plus `pivots`/`pivs`)
// rather than the original matrix.
Tensor linalg_lu_solve_LU(
const Tensor& grad,
const Tensor& LU,
const Tensor& pivots,
const Tensor& X,
const bool left,
const bool adjoint);
Tensor linalg_lu_solve_jvp(
const Tensor& X,
const Tensor& LU,
const Tensor& pivots,
const Tensor& dLU,
const Tensor& dB,
const bool left,
const bool adjoint);
std::tuple<Tensor, Tensor> linalg_solve_backward(
const Tensor& gX,
const Tensor& X,
const Tensor& A,
const Tensor& LU,
const Tensor& pivots,
const bool left,
const bool B_requires_grad);
Tensor linalg_solve_jvp(
const Tensor& dA,
const Tensor& dB,
const Tensor& X,
const Tensor& LU,
const Tensor& pivots,
const bool left,
const bool use_A_T);
// `m` and `n` are const SymInts passed by value.
Tensor lu_unpack_backward(
const Tensor& L_grad,
const Tensor& U_grad,
const c10::SymInt m,
const c10::SymInt n);
Tensor linalg_det_backward(
const Tensor& grad,
const Tensor& det,
const Tensor& A,
const Tensor& LU,
const Tensor& pivots);
Tensor linalg_det_jvp(
const Tensor& dA,
const Tensor& det,
const Tensor& LU,
const Tensor& pivots,
const bool use_A_T);
// `grad_input_mask` is a two-entry flag array taken by const reference.
std::tuple<Tensor, Tensor> linalg_lstsq_backward(
const Tensor& grad,
const Tensor& A,
const Tensor& B_,
const std::array<bool, 2>& grad_input_mask);
Tensor linalg_lu_backward(
const Tensor& L_grad,
const Tensor& U_grad,
const Tensor& P,
const Tensor& L,
const Tensor& U,
const bool pivot);
std::tuple<Tensor, Tensor> linalg_lu_jvp(
const Tensor& dA,
const Tensor& P,
const Tensor& L,
const Tensor& U,
const bool pivot);
Tensor lu_factor_ex_backward(
const Tensor& grad,
const Tensor& LU,
const Tensor& pivs,
const bool pivot);
Tensor lu_factor_ex_jvp(
const Tensor& dX,
const Tensor& LU,
const Tensor& pivs,
const bool pivot);
// Declarations: JVP helpers for batch/layer/group normalization and
// convolution, plus the convolution-backward bias JVP. The _p / _t suffixes
// presumably denote primal / tangent inputs.
Tensor batch_norm_jvp(
const Tensor& input_p,
const Tensor& input_t,
const Tensor& weight_p,
const Tensor& weight_t,
const Tensor& bias_p,
const Tensor& bias_t,
const c10::optional<Tensor>& running_mean,
const c10::optional<Tensor>& running_var,
const Tensor& saved_mean,
const Tensor& saved_invstd,
bool train,
double eps);
Tensor layer_norm_jvp(
const Tensor& input_p,
const Tensor& input_t,
const Tensor& weight_p,
const Tensor& weight_t,
const Tensor& bias_p,
const Tensor& bias_t,
const Tensor& saved_mean,
const Tensor& saved_invstd,
c10::SymIntArrayRef normalized_shape);
Tensor group_norm_jvp(
const Tensor& input_p,
const Tensor& input_t,
const Tensor& weight_p,
const Tensor& weight_t,
const Tensor& bias_p,
const Tensor& bias_t,
const Tensor& saved_mean,
const Tensor& saved_invstd,
int64_t groups);
Tensor group_norm_mean_jvp(
const Tensor& input_t,
const Tensor& mean_p,
int64_t groups);
Tensor group_norm_invstd_jvp(
const Tensor& input_p,
const Tensor& input_t,
const Tensor& mean_p,
const Tensor& invstd_p,
int64_t groups);
// In both convolution JVPs, `stride` and `dilation` are IntArrayRef while
// `padding` and `output_padding` are SymIntArrayRef.
Tensor convolution_jvp(
const Tensor& input_p,
const Tensor& input_t,
const Tensor& weight_p,
const Tensor& weight_t,
const Tensor& bias_p,
const Tensor& bias_t,
IntArrayRef stride,
at::SymIntArrayRef padding,
IntArrayRef dilation,
bool transposed,
at::SymIntArrayRef output_padding,
int64_t groups);
// Same leading parameters as convolution_jvp, plus four backend flags
// (benchmark, deterministic, cudnn_enabled, allow_tf32).
Tensor _convolution_jvp(
const Tensor& input_p,
const Tensor& input_t,
const Tensor& weight_p,
const Tensor& weight_t,
const Tensor& bias_p,
const Tensor& bias_t,
IntArrayRef stride,
at::SymIntArrayRef padding,
IntArrayRef dilation,
bool transposed,
at::SymIntArrayRef output_padding,
int64_t groups,
bool benchmark,
bool deterministic,
bool cudnn_enabled,
bool allow_tf32);
Tensor convolution_backward_jvp_grad_bias(
const Tensor& grad_out_t,
const Tensor& grad_bias);
// Declarations: JVP helpers for cat / block_diag / stack / cumprod, gather
// and read helpers, a warning passthrough, cuDNN convolution backward and
// scatter_reduce JVP.
// cat_jvp takes an ITensorListRef, while block_diag_jvp and stack_jvp take a
// TensorList.
Tensor cat_jvp(at::ITensorListRef tensors, int64_t dim);
Tensor block_diag_jvp(at::TensorList tensors);
Tensor stack_jvp(at::TensorList tensors, int64_t dim);
// All tensors are taken by value here, with the tangent (self_t) listed
// before the primal (self_p); `dim` is a plain int.
Tensor cumprod_jvp(Tensor self_t, Tensor self_p, Tensor result, int dim);
Tensor gather_with_keepdimed_indices(
const Tensor& input,
int64_t dim,
const Tensor& indices,
bool keepdim);
Tensor evenly_read_jvp(
const Tensor& fw_grad,
const Tensor& input,
const Tensor& value);
Tensor warn_backwards(const Tensor& grad_output);
// `output_mask` is a two-entry flag array (spelled with a leading `::std`).
std::tuple<Tensor, Tensor> _cudnn_convolution_backward(
const at::Tensor& self,
const at::Tensor& grad_output,
const at::Tensor& weight,
at::IntArrayRef padding,
at::IntArrayRef output_padding,
at::IntArrayRef stride,
at::IntArrayRef dilation,
bool transposed,
int64_t groups,
::std::array<bool, 2> output_mask);
// `reduce` names the reduction as a string_view; accepted values are checked
// in the .cpp. `dim` is a plain int.
Tensor scatter_reduce_jvp(
const Tensor& self_p,
const Tensor& self_t,
int dim,
const Tensor& index,
const Tensor& src_p,
const Tensor& src_t,
c10::string_view reduce,
bool include_self,
const Tensor& result);
std::tuple<Tensor, Tensor> scatter_reduce_backward(