From ec23c1f96a1fe4a9ff2c27e115f7d705ccb55799 Mon Sep 17 00:00:00 2001 From: colcarroll Date: Thu, 21 Dec 2023 08:10:49 -0800 Subject: [PATCH] Update numpy rewrite for linear_operator_circulant. PiperOrigin-RevId: 592864112 --- .../numpy/gen/linear_operator_circulant.py | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_circulant.py b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_circulant.py index 6b7c6f0196..a8dc5fbacd 100644 --- a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_circulant.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_circulant.py @@ -1042,6 +1042,32 @@ def _linop_inverse(self) -> "LinearOperatorCirculant": is_square=True, input_output_dtype=self.dtype) + def _linop_matmul( + self, + left_operator: "LinearOperatorCirculant", + right_operator: linear_operator.LinearOperator, + ) -> linear_operator.LinearOperator: + if not isinstance( + right_operator, LinearOperatorCirculant + ) or not isinstance(left_operator, type(right_operator)): + return super()._linop_matmul(left_operator, right_operator) + + return LinearOperatorCirculant( + spectrum=left_operator.spectrum * right_operator.spectrum, + is_non_singular=property_hint_util.combined_non_singular_hint( + left_operator, right_operator + ), + is_self_adjoint=property_hint_util.combined_commuting_self_adjoint_hint( + left_operator, right_operator + ), + is_positive_definite=( + property_hint_util.combined_commuting_positive_definite_hint( + left_operator, right_operator + ) + ), + is_square=True, + ) + def _linop_solve( self, left_operator: "LinearOperatorCirculant", @@ -1271,6 +1297,32 @@ def _linop_inverse(self) -> "LinearOperatorCirculant2D": is_square=True, input_output_dtype=self.dtype) + def _linop_matmul( + self, + left_operator: "LinearOperatorCirculant2D", + right_operator: linear_operator.LinearOperator, + ) -> linear_operator.LinearOperator: + if not isinstance( + right_operator, LinearOperatorCirculant2D + ) or not isinstance(left_operator, type(right_operator)): + return super()._linop_matmul(left_operator, right_operator) + + return LinearOperatorCirculant2D( + spectrum=left_operator.spectrum * right_operator.spectrum, + is_non_singular=property_hint_util.combined_non_singular_hint( + left_operator, right_operator + ), + is_self_adjoint=property_hint_util.combined_commuting_self_adjoint_hint( + left_operator, right_operator + ), + is_positive_definite=( + property_hint_util.combined_commuting_positive_definite_hint( + left_operator, right_operator + ) + ), + is_square=True, + ) + def _linop_solve( self, left_operator: "LinearOperatorCirculant2D", @@ -1473,6 +1525,32 @@ def _linop_inverse(self) -> "LinearOperatorCirculant3D": is_square=True, input_output_dtype=self.dtype) + def _linop_matmul( + self, + left_operator: "LinearOperatorCirculant3D", + right_operator: linear_operator.LinearOperator, + ) -> linear_operator.LinearOperator: + if not isinstance( + right_operator, LinearOperatorCirculant3D + ) or not isinstance(left_operator, type(right_operator)): + return super()._linop_matmul(left_operator, right_operator) + + return LinearOperatorCirculant3D( + spectrum=left_operator.spectrum * right_operator.spectrum, + is_non_singular=property_hint_util.combined_non_singular_hint( + left_operator, right_operator + ), + is_self_adjoint=property_hint_util.combined_commuting_self_adjoint_hint( + left_operator, right_operator + ), + is_positive_definite=( + property_hint_util.combined_commuting_positive_definite_hint( + left_operator, right_operator + ) + ), + is_square=True, + ) + def _linop_solve( self, left_operator: "LinearOperatorCirculant3D",