From e65c624b4d29ceeb8a997c3584160b4523dc8545 Mon Sep 17 00:00:00 2001 From: Dario Quintero Date: Wed, 19 Jul 2023 13:56:26 +0100 Subject: [PATCH] FEAT: Calculate subunitary working --- .../examples/05_quantum_integration_basics.py | 41 ++++++++++++------- piel/integration/sax_qutip.py | 10 +++-- piel/integration/sax_thewalrus.py | 16 +++++++- piel/tools/sax/utils.py | 21 +++++----- 4 files changed, 57 insertions(+), 31 deletions(-) diff --git a/docs/examples/05_quantum_integration_basics.py b/docs/examples/05_quantum_integration_basics.py index 09558d22..23102dc6 100644 --- a/docs/examples/05_quantum_integration_basics.py +++ b/docs/examples/05_quantum_integration_basics.py @@ -35,16 +35,20 @@ ) = piel.sax_to_s_parameters_standard_matrix(default_state_s_parameters) s_parameters_standard_matrix +# ```python +# Array([[ 0.40105772+0.49846345j, -0.45904815-0.197149j , +# 0.00180554+0.17483076j, 0.4000432 +0.38792986j], +# [-0.4590482 -0.197149j , -0.8361797 +0.13278401j, +# -0.03938162-0.03818914j, -0.17480364+0.00356933j], +# [ 0.00180554+0.17483076j, -0.03938162-0.03818914j, +# -0.8536251 +0.11586684j, 0.11507235-0.45943272j], +# [ 0.40004322+0.3879298j , -0.17480363+0.00356933j, +# 0.11507231-0.45943272j, -0.5810837 -0.31133226j]], dtype=complex64) # ``` -# array([[ 0.40117208+0.49838198j, -0.45906927-0.19706765j, -# 0.00184268+0.17482864j, 0.40013665+0.38783803j], -# [-0.45906927-0.19706765j, -0.83617032+0.13289595j, -# -0.03938958-0.0381789j , -0.17480098+0.0036143j ], -# [ 0.00184268+0.17482864j, -0.03938958-0.0381789j , -# -0.85361747+0.11598505j, 0.11497926-0.45944187j], -# [ 0.40013665+0.38783803j, -0.17480098+0.0036143j , -# 0.11497926-0.45944187j, -0.58117378-0.31118139j]]) -# ``` + +import numpy as np + +np.asarray(s_parameters_standard_matrix) # We can explore some properties of this matrix: @@ -84,13 +88,20 @@ # For, example, we might want to just calculate it for the first two input modes. This would be indexed when starting from the first row and column as `start_index` = (0,0) and `stop_index` = (`unitary_size`, `unitary_size`). Note that an error will be raised if a non-unitary matrix is inputted. Some examples are: -s_parameters_standard_matrix +our_subunitary = piel.subunitary_selection( + s_parameters_standard_matrix, start_index=(0, 0), stop_index=(1, 1) +) +our_subunitary -import jax.numpy as jnp +# ```python +# Array([[ 0.40105772+0.49846345j, -0.45904815-0.197149j ], +# [-0.4590482 -0.197149j , -0.8361797 +0.13278401j]], dtype=complex64) +# ``` -jax_array = jnp.array(s_parameters_standard_matrix) -jax_array +# We can now calculate the permanent of this submatrix: -jax_array.at[jnp.array([0, 1])].get() +piel.unitary_permanent(our_subunitary) -jnp.ndarray +# ```python +# ((-0.2296868-0.18254918j), 0.0) +# ``` diff --git a/piel/integration/sax_qutip.py b/piel/integration/sax_qutip.py index c9003f0d..5149fd79 100644 --- a/piel/integration/sax_qutip.py +++ b/piel/integration/sax_qutip.py @@ -1,5 +1,6 @@ import qutip # NOQA : F401 import sax +import numpy as np import jax.numpy as jnp from piel.tools.sax.utils import sax_to_s_parameters_standard_matrix @@ -49,7 +50,8 @@ def matrix_to_qutip_qobj( qobj_unitary (qutip.Qobj): A QuTip QObj representation of the S-parameters in a unitary matrix. """ - qobj_unitary = qutip.Qobj(s_parameters_standard_matrix) + s_parameter_standard_matrix_numpy = np.asarray(s_parameters_standard_matrix) + qobj_unitary = qutip.Qobj(s_parameter_standard_matrix_numpy) return qobj_unitary @@ -95,7 +97,8 @@ def sax_to_ideal_qutip_unitary(sax_input: sax.SType): s_parameters_standard_matrix, input_ports_index_tuple_order, ) = sax_to_s_parameters_standard_matrix(sax_input) - qobj_unitary = matrix_to_qutip_qobj(s_parameters_standard_matrix) + s_parameter_standard_matrix_numpy = np.asarray(s_parameters_standard_matrix) + qobj_unitary = matrix_to_qutip_qobj(s_parameter_standard_matrix_numpy) return qobj_unitary @@ -109,7 +112,8 @@ def verify_matrix_is_unitary(matrix: jnp.ndarray) -> bool: Returns: bool: True if the matrix is unitary, False otherwise. """ - qobj = matrix_to_qutip_qobj(matrix) + matrix_numpy = np.asarray(matrix) + qobj = matrix_to_qutip_qobj(matrix_numpy) return qobj.check_isunitary() diff --git a/piel/integration/sax_thewalrus.py b/piel/integration/sax_thewalrus.py index 4fa069b6..c10be524 100644 --- a/piel/integration/sax_thewalrus.py +++ b/piel/integration/sax_thewalrus.py @@ -4,6 +4,7 @@ import jax.numpy as jnp from ..tools.sax import sax_to_s_parameters_standard_matrix from typing import Optional +import numpy as np __all__ = ["sax_circuit_permanent", "subunitary_selection", "unitary_permanent"] @@ -43,7 +44,17 @@ def subunitary_selection( TODO implement validation of a 2D matrix. """ - pass + start_row = start_index[0] + end_row = stop_index[0] + start_column = start_index[1] + end_column = stop_index[1] + column_range = jnp.arange(start_column, end_column + 1) + row_range = jnp.arange(start_row, end_row + 1) + unitary_matrix_row_selection = unitary_matrix.at[row_range, :].get() + unitary_matrix_row_column_selection = unitary_matrix_row_selection.at[ + :, column_range + ].get() + return unitary_matrix_row_column_selection def unitary_permanent( @@ -68,7 +79,8 @@ def unitary_permanent( """ start_time = time.time() - circuit_permanent = thewalrus.perm(unitary_matrix) + unitary_matrix_numpy = np.asarray(unitary_matrix) + circuit_permanent = thewalrus.perm(unitary_matrix_numpy) end_time = time.time() computed_time = end_time - start_time return circuit_permanent, computed_time diff --git a/piel/tools/sax/utils.py b/piel/tools/sax/utils.py index 7171e127..a6d4cded 100644 --- a/piel/tools/sax/utils.py +++ b/piel/tools/sax/utils.py @@ -1,7 +1,7 @@ """ This file provides a set of utilities that allow much easier integration between `sax` and the relevant tools that we use. """ -import numpy as np +import jax.numpy as jnp import sax from ..gdsfactory.netlist import get_matched_ports_tuple_index from typing import Optional # NOQA : F401 @@ -180,18 +180,17 @@ def coupler(*, coupling: float = 0.5) -> SDict: ports_index=dense_s_parameter_index, prefix="out" ) + output_ports_index_tuple_order_jax = jnp.asarray(output_ports_index_tuple_order) + input_ports_index_tuple_order_jax = jnp.asarray(input_ports_index_tuple_order) # We now select the SDense columns that we care about. - dense_s_parameter_matrix = np.asarray( - dense_s_parameter_matrix - ) # TODO port to JAX multiindexing - s_parameters_standard_matrix = dense_s_parameter_matrix[ - [output_ports_index_tuple_order] - ][0] + s_parameters_standard_matrix = dense_s_parameter_matrix.at[ + output_ports_index_tuple_order_jax + ].get() + s_parameters_standard_matrix = s_parameters_standard_matrix.at[ + :, input_ports_index_tuple_order_jax + ].get() # Now we select the SDense rows that we care about after transposing the matrix. - s_parameters_standard_matrix = s_parameters_standard_matrix[ - :, [input_ports_index_tuple_order][0] - ] - # TODO verify matrix transpose for unitary match. + # TODO verify matrix transpose for unitary match. I think it is right. return s_parameters_standard_matrix.T, input_matched_ports_name_tuple_order