Skip to content

Commit

Permalink
FEAT: Calculate subunitary working
Browse files Browse the repository at this point in the history
  • Loading branch information
daquintero committed Jul 19, 2023
1 parent dad4325 commit e65c624
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 31 deletions.
41 changes: 26 additions & 15 deletions docs/examples/05_quantum_integration_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,20 @@
) = piel.sax_to_s_parameters_standard_matrix(default_state_s_parameters)
s_parameters_standard_matrix

# ```python
# Array([[ 0.40105772+0.49846345j, -0.45904815-0.197149j ,
# 0.00180554+0.17483076j, 0.4000432 +0.38792986j],
# [-0.4590482 -0.197149j , -0.8361797 +0.13278401j,
# -0.03938162-0.03818914j, -0.17480364+0.00356933j],
# [ 0.00180554+0.17483076j, -0.03938162-0.03818914j,
# -0.8536251 +0.11586684j, 0.11507235-0.45943272j],
# [ 0.40004322+0.3879298j , -0.17480363+0.00356933j,
# 0.11507231-0.45943272j, -0.5810837 -0.31133226j]], dtype=complex64)
# ```
# array([[ 0.40117208+0.49838198j, -0.45906927-0.19706765j,
# 0.00184268+0.17482864j, 0.40013665+0.38783803j],
# [-0.45906927-0.19706765j, -0.83617032+0.13289595j,
# -0.03938958-0.0381789j , -0.17480098+0.0036143j ],
# [ 0.00184268+0.17482864j, -0.03938958-0.0381789j ,
# -0.85361747+0.11598505j, 0.11497926-0.45944187j],
# [ 0.40013665+0.38783803j, -0.17480098+0.0036143j ,
# 0.11497926-0.45944187j, -0.58117378-0.31118139j]])
# ```

import numpy as np

np.asarray(s_parameters_standard_matrix)

# We can explore some properties of this matrix:

Expand Down Expand Up @@ -84,13 +88,20 @@

# For, example, we might want to just calculate it for the first two input modes. This would be indexed when starting from the first row and column as `start_index` = (0,0) and `stop_index` = (`unitary_size`, `unitary_size`). Note that an error will be raised if a non-unitary matrix is inputted. Some examples are:

s_parameters_standard_matrix
our_subunitary = piel.subunitary_selection(
s_parameters_standard_matrix, start_index=(0, 0), stop_index=(1, 1)
)
our_subunitary

import jax.numpy as jnp
# ```python
# Array([[ 0.40105772+0.49846345j, -0.45904815-0.197149j ],
# [-0.4590482 -0.197149j , -0.8361797 +0.13278401j]], dtype=complex64)
# ```

jax_array = jnp.array(s_parameters_standard_matrix)
jax_array
# We can now calculate the permanent of this submatrix:

jax_array.at[jnp.array([0, 1])].get()
piel.unitary_permanent(our_subunitary)

jnp.ndarray
# ```python
# ((-0.2296868-0.18254918j), 0.0)
# ```
10 changes: 7 additions & 3 deletions piel/integration/sax_qutip.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import qutip # NOQA : F401
import sax
import numpy as np
import jax.numpy as jnp
from piel.tools.sax.utils import sax_to_s_parameters_standard_matrix

Expand Down Expand Up @@ -49,7 +50,8 @@ def matrix_to_qutip_qobj(
qobj_unitary (qutip.Qobj): A QuTip QObj representation of the S-parameters in a unitary matrix.
"""
qobj_unitary = qutip.Qobj(s_parameters_standard_matrix)
s_parameter_standard_matrix_numpy = np.asarray(s_parameters_standard_matrix)
qobj_unitary = qutip.Qobj(s_parameter_standard_matrix_numpy)
return qobj_unitary


Expand Down Expand Up @@ -95,7 +97,8 @@ def sax_to_ideal_qutip_unitary(sax_input: sax.SType):
s_parameters_standard_matrix,
input_ports_index_tuple_order,
) = sax_to_s_parameters_standard_matrix(sax_input)
qobj_unitary = matrix_to_qutip_qobj(s_parameters_standard_matrix)
s_parameter_standard_matrix_numpy = np.asarray(s_parameters_standard_matrix)
qobj_unitary = matrix_to_qutip_qobj(s_parameter_standard_matrix_numpy)
return qobj_unitary


Expand All @@ -109,7 +112,8 @@ def verify_matrix_is_unitary(matrix: jnp.ndarray) -> bool:
Returns:
bool: True if the matrix is unitary, False otherwise.
"""
qobj = matrix_to_qutip_qobj(matrix)
matrix_numpy = np.asarray(matrix)
qobj = matrix_to_qutip_qobj(matrix_numpy)
return qobj.check_isunitary()


Expand Down
16 changes: 14 additions & 2 deletions piel/integration/sax_thewalrus.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import jax.numpy as jnp
from ..tools.sax import sax_to_s_parameters_standard_matrix
from typing import Optional
import numpy as np

__all__ = ["sax_circuit_permanent", "subunitary_selection", "unitary_permanent"]

Expand Down Expand Up @@ -43,7 +44,17 @@ def subunitary_selection(
TODO implement validation of a 2D matrix.
"""
pass
start_row = start_index[0]
end_row = stop_index[0]
start_column = start_index[1]
end_column = stop_index[1]
column_range = jnp.arange(start_column, end_column + 1)
row_range = jnp.arange(start_row, end_row + 1)
unitary_matrix_row_selection = unitary_matrix.at[row_range, :].get()
unitary_matrix_row_column_selection = unitary_matrix_row_selection.at[
:, column_range
].get()
return unitary_matrix_row_column_selection


def unitary_permanent(
Expand All @@ -68,7 +79,8 @@ def unitary_permanent(
"""
start_time = time.time()
circuit_permanent = thewalrus.perm(unitary_matrix)
unitary_matrix_numpy = np.asarray(unitary_matrix)
circuit_permanent = thewalrus.perm(unitary_matrix_numpy)
end_time = time.time()
computed_time = end_time - start_time
return circuit_permanent, computed_time
21 changes: 10 additions & 11 deletions piel/tools/sax/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
This file provides a set of utilities that allow much easier integration between `sax` and the relevant tools that we use.
"""
import numpy as np
import jax.numpy as jnp
import sax
from ..gdsfactory.netlist import get_matched_ports_tuple_index
from typing import Optional # NOQA : F401
Expand Down Expand Up @@ -180,18 +180,17 @@ def coupler(*, coupling: float = 0.5) -> SDict:
ports_index=dense_s_parameter_index, prefix="out"
)

output_ports_index_tuple_order_jax = jnp.asarray(output_ports_index_tuple_order)
input_ports_index_tuple_order_jax = jnp.asarray(input_ports_index_tuple_order)
# We now select the SDense columns that we care about.
dense_s_parameter_matrix = np.asarray(
dense_s_parameter_matrix
) # TODO port to JAX multiindexing
s_parameters_standard_matrix = dense_s_parameter_matrix[
[output_ports_index_tuple_order]
][0]
s_parameters_standard_matrix = dense_s_parameter_matrix.at[
output_ports_index_tuple_order_jax
].get()
s_parameters_standard_matrix = s_parameters_standard_matrix.at[
:, input_ports_index_tuple_order_jax
].get()
# Now we select the SDense rows that we care about after transposing the matrix.
s_parameters_standard_matrix = s_parameters_standard_matrix[
:, [input_ports_index_tuple_order][0]
]
# TODO verify matrix transpose for unitary match.
# TODO verify matrix transpose for unitary match. I think it is right.
return s_parameters_standard_matrix.T, input_matched_ports_name_tuple_order


Expand Down

0 comments on commit e65c624

Please sign in to comment.