Skip to content

Commit

Permalink
FEAT: Fock state transition modelling (#21)
Browse files Browse the repository at this point in the history
* FEAT: Porting only to JAX

* FEAT: Finished porting only to JAX

* FEAT: Calculate subunitary working

* FEAT: Extracting fock state transition probability amplitudes

* FEAT: Get rid of autoapi

* FIX: Docs
  • Loading branch information
daquintero authored Jul 24, 2023
1 parent c5476b9 commit d17e7f6
Show file tree
Hide file tree
Showing 36 changed files with 1,035 additions and 467 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/code_coverage.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ repos:
rev: ad3ff374e97e29ca87c94b5dc7eccdd29adc6296
hooks:
- id: codespell
args: ["-L TE,TE/TM,te,ba,FPR,fpr_spacing,ro,nd,donot,schem,Synopsys"]
args: ["-L TE,TE/TM,te,ba,FPR,fpr_spacing,ro,nd,donot,schem,Synopsys,ket"]
additional_dependencies:
- tomli

Expand Down
10 changes: 1 addition & 9 deletions docs/autoapi/piel/config/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,11 @@

We create a set of parameters that can be used throughout the project for optimisation.

The numerical solver is normally delegated for as `numpy` but there are cases where a much faster solver is desired, and where different functioanlity is required. For example, `sax` uses `JAX` for its numerical solver. In this case, we will create a global numerical solver that we can use throughout the project, and that can be extended and solved accordingly for the particular project requirements.
The numerical solver is jax and is imported throughout the module.



Module Contents
---------------

.. py:data:: numerical_solver
.. py:data:: nso
.. py:data:: piel_path_types
245 changes: 107 additions & 138 deletions docs/autoapi/piel/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,10 @@ Functions
piel.construct_hdl21_module
piel.convert_connections_to_tuples
piel.gdsfactory_netlist_with_hdl21_generators
piel.sax_to_s_parameters_standard_matrix
piel.unitary_permanent
piel.sax_circuit_permanent
piel.unitary_permanent
piel.sax_to_ideal_qutip_unitary
piel.standard_s_parameters_to_ideal_qutip_unitary
piel.fock_transition_probability_amplitude
piel.single_parameter_sweep
piel.multi_parameter_sweep
piel.check_cocotb_testbench_exists
Expand Down Expand Up @@ -115,6 +114,11 @@ Functions
piel.convert_numeric_to_prefix
piel.get_sdense_ports_index
piel.sax_to_s_parameters_standard_matrix
piel.fock_state_nonzero_indexes
piel.fock_state_to_photon_number_factorial
piel.verify_matrix_is_unitary
piel.subunitary_selection_on_range
piel.subunitary_selection_on_index



Expand All @@ -123,27 +127,18 @@ Attributes

.. autoapisummary::

piel.numerical_solver
piel.nso
piel.piel_path_types
piel.test_spm_open_lane_configuration
piel.example_open_lane_configuration
piel.delete_simulation_output_files
piel.get_simulation_output_files
piel.snet
piel.standard_s_parameters_to_qutip_qobj
piel.__author__
piel.__email__
piel.__version__


.. py:data:: numerical_solver
.. py:data:: nso
.. py:data:: piel_path_types
Expand Down Expand Up @@ -417,106 +412,30 @@ Attributes
:returns: The ``GDSFactory`` netlist with the ``hdl21`` models dictionary.


.. py:function:: sax_to_s_parameters_standard_matrix(sax_input: sax.SType, input_ports_order: tuple | None = None) -> tuple
A ``sax`` S-parameter SDict is provided as a dictionary of tuples with (port0, port1) as the key. This
determines the direction of the scattering relationship. It means that the number of terms in an S-parameter
matrix is the number of ports squared.

In order to generalise, this function returns both the S-parameter matrices and the indexing ports based on the
amount provided. In terms of computational speed, we definitely would like this function to be algorithmically
very fast. For now, I will write a simple python implementation and optimise in the future.

It is possible to see the `sax` SDense notation equivalence here:
https://flaport.github.io/sax/nbs/08_backends.html

.. code-block:: python
import jax.numpy as jnp
from sax.core import SDense
# Directional coupler SDense representation
dc_sdense: SDense = (
jnp.array([[0, 0, τ, κ], [0, 0, κ, τ], [τ, κ, 0, 0], [κ, τ, 0, 0]]),
{"in0": 0, "in1": 1, "out0": 2, "out1": 3},
)
# Directional coupler SDict representation
# Taken from https://flaport.github.io/sax/nbs/05_models.html
def coupler(*, coupling: float = 0.5) -> SDict:
kappa = coupling**0.5
tau = (1 - coupling) ** 0.5
sdict = reciprocal(
{
("in0", "out0"): tau,
("in0", "out1"): 1j * kappa,
("in1", "out0"): 1j * kappa,
("in1", "out1"): tau,
}
)
return sdict
If we were to relate the mapping accordingly based on the ports indexes, a S-Parameter matrix in the form of
:math:`S_{(output,i),(input,i)}` would be:

.. math::
S = \begin{bmatrix}
S_{00} & S_{10} \\
S_{01} & S_{11} \\
\end{bmatrix} =
\begin{bmatrix}
\tau & j \kappa \\
j \kappa & \tau \\
\end{bmatrix}
Note that the standard S-parameter and hence unitary representation is in the form of:

.. math::
S = \begin{bmatrix}
S_{00} & S_{01} \\
S_{10} & S_{11} \\
\end{bmatrix}
.. py:function:: sax_circuit_permanent(sax_input: sax.SType) -> tuple
.. math::
The permanent of a unitary is used to determine the state probability of combinatorial Gaussian boson samping systems.

\begin{bmatrix}
b_{1} \\
\vdots \\
b_{n}
\end{bmatrix}
=
\begin{bmatrix}
S_{11} & \dots & S_{1n} \\
\vdots & \ddots & \vdots \\
S_{n1} & \dots & S_{nn}
\end{bmatrix}
\begin{bmatrix}
a_{1} \\
\vdots \\
a_{n}
\end{bmatrix}
``thewalrus`` Ryser's algorithm permananet implementation is described here: https://the-walrus.readthedocs.io/en/latest/gallery/permanent_tutorial.html

TODO check with Floris, does this mean we need to transpose the matrix?
# TODO maybe implement subroutine if computation is taking forever.

:param sax_input: The sax S-parameter dictionary.
:type sax_input: sax.SType
:param input_ports_order: The ports order tuple containing the names and order of the input ports.
:type input_ports_order: tuple

:returns: The S-parameter matrix and the input ports index tuple in the standard S-parameter notation.
:returns: The circuit permanent and the time it took to compute it.
:rtype: tuple


.. py:function:: unitary_permanent(unitary_matrix: numpy.ndarray) -> tuple
.. py:function:: unitary_permanent(unitary_matrix: jax.numpy.ndarray) -> tuple
The permanent of a unitary is used to determine the state probability of combinatorial Gaussian boson samping systems.

``thewalrus`` Ryser's algorithm permananet implementation is described here: https://the-walrus.readthedocs.io/en/latest/gallery/permanent_tutorial.html

Note that this function needs to be as optimised as possible, so we need to minimise our computational complexity of our operation.

# TODO implement validation
# TODO maybe implement subroutine if computation is taking forever.
# TODO why two outputs? Understand this properly later.

Expand All @@ -527,21 +446,6 @@ Attributes
:rtype: tuple


.. py:function:: sax_circuit_permanent(sax_input: sax.SType) -> tuple
The permanent of a unitary is used to determine the state probability of combinatorial Gaussian boson samping systems.

``thewalrus`` Ryser's algorithm permananet implementation is described here: https://the-walrus.readthedocs.io/en/latest/gallery/permanent_tutorial.html

# TODO maybe implement subroutine if computation is taking forever.

:param sax_input: The sax S-parameter dictionary.
:type sax_input: sax.SType

:returns: The circuit permanent and the time it took to compute it.
:rtype: tuple


.. py:function:: sax_to_ideal_qutip_unitary(sax_input: sax.SType)
This function converts the calculated S-parameters into a standard Unitary matrix topology so that the shape and
Expand Down Expand Up @@ -580,41 +484,39 @@ Attributes
:rtype: qobj_unitary (qutip.Qobj)


.. py:function:: standard_s_parameters_to_ideal_qutip_unitary(s_parameters_standard_matrix: piel.config.nso.ndarray)
.. py:function:: fock_transition_probability_amplitude(initial_fock_state: qutip.Qobj, final_fock_state: qutip.Qobj, unitary_matrix: jax.numpy.ndarray)
This function converts the calculated S-parameters into a standard Unitary matrix topology so that the shape and
dimensions of the matrix can be observed.
This function returns the transition probability amplitude between two Fock states when propagating in between
the unitary_matrix which represents a quantum state circuit.
I think this means we need to transpose the output of the filtered sax SDense matrix to map it to a QuTip matrix.
Note that the documentation and formatting of the standard `sax` mapping to a S-parameter standard notation is
already in described in piel/piel/sax/utils.py.
Note that based on (TODO cite Jeremy), the initial Fock state corresponds to the columns of the unitary and the
final Fock states corresponds to the rows of the unitary.
From this stage we can implement a ``QObj`` matrix accordingly and perform simulations accordingly. https://qutip.org/docs/latest/guide/qip/qip-basics.html#unitaries

For example, a ``qutip`` representation of an s-gate gate would be:
.. math ::
..code-block::
ewcommand{\ket}[1]{\left|{#1}
ight
angle}

import numpy as np
import qutip
# S-Gate
s_gate_matrix = np.array([[1., 0], [0., 1.j]])
s_gate = qutip.Qobj(mat, dims=[[2], [2]])
The subunitary :math:`U_{f_1}^{f_2}` is composed from the larger unitary by selecting the rows from the output state
Fock state occupation of :math:`\ket{f_2}`, and columns from the input :math:`\ket{f_1}`. In our case, we need to select the
columns indexes :math:`(0,3)` and rows indexes :math:`(1,2)`.

In mathematical notation, this S-gate would be written as:
If we consider a photon number of more than one for the transition Fock states, then the Permanent needs to be
normalised. The probability amplitude for the transition is described as:

..math::
.. math ::
a(\ket{f_1} o \ket{f_2}) =
rac{ ext{per}(U_{f_1}^{f_2})}{\sqrt{(j_1! j_2! ... j_N!)(j_1^{'}! j_2^{'}! ... j_N^{'}!)}}

S = \begin{bmatrix}
1 & 0 \\
0 & i \\
\end{bmatrix}
Args:
initial_fock_state (qutip.Qobj): A QuTip QObj representation of the initial Fock state.
final_fock_state (qutip.Qobj): A QuTip QObj representation of the final Fock state.
unitary_matrix (jnp.ndarray): A JAX NumPy array representation of the unitary matrix.

:param s_parameters_standard_matrix: A dictionary of S-parameters in the form of a SDict from `sax`.
:type s_parameters_standard_matrix: nso.ndarray
Returns:
float: The transition probability amplitude between the initial and final Fock states.

:returns: A QuTip QObj representation of the S-parameters in a unitary matrix.
:rtype: qobj_unitary (qutip.Qobj)


.. py:function:: single_parameter_sweep(base_design_configuration: dict, parameter_name: str, parameter_sweep_values: list)
Expand Down Expand Up @@ -1381,6 +1283,73 @@ Attributes
.. py:function:: fock_state_nonzero_indexes(fock_state: qutip.Qobj)
This function returns the indexes of the nonzero elements of a Fock state.

:param fock_state: A QuTip QObj representation of the Fock state.
:type fock_state: qutip.Qobj

:returns: The indexes of the nonzero elements of the Fock state.
:rtype: tuple


.. py:function:: fock_state_to_photon_number_factorial(fock_state: qutip.Qobj)
This function converts a Fock state defined as:
.. math::
ewcommand{\ket}[1]{\left|{#1}
ight
angle}
\ket{f_1} = \ket{j_1, j_2, ... j_N}$

and returns:

.. math::
j_1^{'}! j_2^{'}! ... j_N^{'}!
Args:
fock_state (qutip.Qobj): A QuTip QObj representation of the Fock state.

Returns:
float: The photon number factorial of the Fock state.



.. py:data:: standard_s_parameters_to_qutip_qobj
.. py:function:: verify_matrix_is_unitary(matrix: jax.numpy.ndarray) -> bool
Verify that the matrix is unitary.

:param matrix: The matrix to verify.
:type matrix: jnp.ndarray

:returns: True if the matrix is unitary, False otherwise.
:rtype: bool


.. py:function:: subunitary_selection_on_range(unitary_matrix: jax.numpy.ndarray, stop_index: tuple, start_index: Optional[tuple] = (0, 0))
This function returns a unitary between the indexes selected, and verifies the indexes are valid by checking that
the output matrix is also a unitary.

TODO implement validation of a 2D matrix.


.. py:function:: subunitary_selection_on_index(unitary_matrix: jax.numpy.ndarray, rows_index: jax.numpy.ndarray | tuple, columns_index: jax.numpy.ndarray | tuple)
This function returns a unitary between the indexes selected, and verifies the indexes are valid by checking that
the output matrix is also a unitary.

TODO implement validation of a 2D matrix.


.. py:data:: __author__
:value: 'Dario Quintero'

Expand Down
Loading

0 comments on commit d17e7f6

Please sign in to comment.