Skip to content

Commit

Permalink
FEAT: Finished porting only to JAX
Browse files Browse the repository at this point in the history
  • Loading branch information
daquintero committed Jul 19, 2023
1 parent 5a4d655 commit dad4325
Show file tree
Hide file tree
Showing 12 changed files with 31 additions and 46 deletions.
4 changes: 2 additions & 2 deletions docs/examples/05_quantum_integration_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@

# The way this works is straightforward:

piel.unitary_permanent(s_parameters_standard_matrix)

# We might want to calculate the permanent of subsections of the larger unitary to calculate certain operations probability:

s_parameters_standard_matrix.shape
Expand All @@ -94,3 +92,5 @@
jax_array

jax_array.at[jnp.array([0, 1])].get()

jnp.ndarray
20 changes: 1 addition & 19 deletions piel/config.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,13 @@
"""
We create a set of parameters that can be used throughout the project for optimisation.
The numerical solver is normally delegated for as `numpy` but there are cases where a much faster solver is desired, and where different functioanlity is required. For example, `sax` uses `JAX` for its numerical solver. In this case, we will create a global numerical solver that we can use throughout the project, and that can be extended and solved accordingly for the particular project requirements.
The numerical solver is jax and is imported throughout the module.
"""
import pathlib
import sys
import types

__all__ = [
"numerical_solver",
"nso",
"piel_path_types",
]


if "jax" in sys.modules:
import jax.numpy as jnp

numerical_solver = jnp
elif "numpy" in sys.modules:
import numpy

numerical_solver = numpy
else:
import numpy

numerical_solver = numpy

nso = numerical_solver
piel_path_types = str | pathlib.Path | types.ModuleType
3 changes: 2 additions & 1 deletion piel/file_conversion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pandas as pd
from pyDigitalWaveTools.vcd.parser import VcdParser
from .file_system import return_path
from .config import piel_path_types

Expand All @@ -19,6 +18,8 @@ def read_csv_to_pandas(file_path: piel_path_types):


def read_vcd_to_json(file_path: piel_path_types):
from pyDigitalWaveTools.vcd.parser import VcdParser

file_path = return_path(file_path)
with open(str(file_path.resolve())) as vcd_file:
vcd = VcdParser()
Expand Down
8 changes: 4 additions & 4 deletions piel/integration/sax_qutip.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import qutip # NOQA : F401
import sax
from ..config import nso
import jax.numpy as jnp
from piel.tools.sax.utils import sax_to_s_parameters_standard_matrix

__all__ = [
Expand All @@ -11,7 +11,7 @@


def matrix_to_qutip_qobj(
s_parameters_standard_matrix: nso.ndarray,
s_parameters_standard_matrix: jnp.ndarray,
):
"""
This function converts the calculated S-parameters into a standard Unitary matrix topology so that the shape and
Expand Down Expand Up @@ -99,12 +99,12 @@ def sax_to_ideal_qutip_unitary(sax_input: sax.SType):
return qobj_unitary


def verify_matrix_is_unitary(matrix: nso.ndarray) -> bool:
def verify_matrix_is_unitary(matrix: jnp.ndarray) -> bool:
"""
Verify that the matrix is unitary.
Args:
matrix (nso.ndarray): The matrix to verify.
matrix (jnp.ndarray): The matrix to verify.
Returns:
bool: True if the matrix is unitary, False otherwise.
Expand Down
6 changes: 3 additions & 3 deletions piel/integration/sax_thewalrus.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sax
import time
import thewalrus
from ..config import nso
import jax.numpy as jnp
from ..tools.sax import sax_to_s_parameters_standard_matrix
from typing import Optional

Expand Down Expand Up @@ -33,7 +33,7 @@ def sax_circuit_permanent(


def subunitary_selection(
unitary_matrix: nso.ndarray,
unitary_matrix: jnp.ndarray,
stop_index: tuple,
start_index: Optional[tuple] = (0, 0),
):
Expand All @@ -47,7 +47,7 @@ def subunitary_selection(


def unitary_permanent(
unitary_matrix: nso.ndarray,
unitary_matrix: jnp.ndarray,
) -> tuple:
"""
The permanent of a unitary is used to determine the state probability of combinatorial Gaussian boson samping systems.
Expand Down
6 changes: 3 additions & 3 deletions piel/models/frequency/photonic/directional_coupler_length.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
Translated from https://github.com/flaport/sax or https://github.com/flaport/photontorch/tree/master
"""
import jax.numpy as jnp
import sax
from ....config import nso

__all__ = ["directional_coupler_with_length"]

Expand All @@ -13,8 +13,8 @@ def directional_coupler_with_length(
kappa = coupling**0.5
tau = (1 - coupling) ** 0.5
loss = 10 ** (-loss * length / 20) # factor 20 bc amplitudes, not intensities.
cos_phase = nso.cos(phase)
sin_phase = nso.sin(phase)
cos_phase = jnp.cos(phase)
sin_phase = jnp.sin(phase)
sdict = sax.reciprocal(
{
("port0", "port1"): tau * loss * cos_phase,
Expand Down
6 changes: 3 additions & 3 deletions piel/models/frequency/photonic/grating_coupler.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""
Translated from https://github.com/flaport/sax or https://github.com/flaport/photontorch/tree/master
"""
from ....config import nso
import jax.numpy as jnp

__all__ = ["grating_coupler_simple"]


def grating_coupler_simple(R=0.0, R_in=0.0, Tmax=1.0, bandwidth=0.06e-6, wl0=1.55e-6):
# Constants
fwhm2sigma = 1.0 / (2 * nso.sqrt(2 * nso.log(2)))
fwhm2sigma = 1.0 / (2 * jnp.sqrt(2 * jnp.log(2)))

# Compute sigma
sigma = fwhm2sigma * bandwidth
Expand All @@ -17,7 +17,7 @@ def grating_coupler_simple(R=0.0, R_in=0.0, Tmax=1.0, bandwidth=0.06e-6, wl0=1.5
wls = wl0

# Compute loss
loss = nso.sqrt(Tmax * nso.exp(-((wl0 - wls) ** 2) / (2 * sigma**2)))
loss = jnp.sqrt(Tmax * jnp.exp(-((wl0 - wls) ** 2) / (2 * sigma**2)))

# Create scattering dictionary
sdict = {
Expand Down
14 changes: 7 additions & 7 deletions piel/models/frequency/photonic/straight_waveguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Translated from https://github.com/flaport/sax or https://github.com/flaport/photontorch/tree/master
"""
import sax
from ....config import nso
import jax.numpy as jnp

__all__ = ["ideal_active_waveguide", "waveguide", "simple_straight"]

Expand All @@ -11,9 +11,9 @@ def waveguide(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0):
dwl = wl - wl0
dneff_dwl = (ng - neff) / wl0
neff = neff - dwl * dneff_dwl
phase = 2 * nso.pi * neff * length / wl
amplitude = nso.asarray(10 ** (-loss * length / 20), dtype=complex)
transmission = amplitude * nso.exp(1j * phase)
phase = 2 * jnp.pi * neff * length / wl
amplitude = jnp.asarray(10 ** (-loss * length / 20), dtype=complex)
transmission = amplitude * jnp.exp(1j * phase)
sdict = sax.reciprocal({("o1", "o2"): transmission})
return sdict

Expand All @@ -24,9 +24,9 @@ def ideal_active_waveguide(
dwl = wl - wl0
dneff_dwl = (ng - neff) / wl0
neff = neff - dwl * dneff_dwl
phase = (2 * nso.pi * neff * length / wl) + active_phase_rad
amplitude = nso.asarray(10 ** (-loss * length / 20), dtype=complex)
transmission = amplitude * nso.exp(1j * phase)
phase = (2 * jnp.pi * neff * length / wl) + active_phase_rad
amplitude = jnp.asarray(10 ** (-loss * length / 20), dtype=complex)
transmission = amplitude * jnp.exp(1j * phase)
sdict = sax.reciprocal({("o1", "o2"): transmission})
return sdict

Expand Down
4 changes: 2 additions & 2 deletions piel/models/physical/geometry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ...config import nso
import jax.numpy as jnp

__all__ = ["calculate_cross_sectional_area_m2"]

Expand All @@ -15,4 +15,4 @@ def calculate_cross_sectional_area_m2(
Returns:
float: Cross sectional area in meters squared.
"""
return nso.pi * (diameter_m**2) / 4
return jnp.pi * (diameter_m**2) / 4
4 changes: 2 additions & 2 deletions piel/models/physical/thermal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ...config import nso
import jax.numpy as jnp

__all__ = [
"heat_transfer_1d_W",
Expand All @@ -8,7 +8,7 @@
def heat_transfer_1d_W(
thermal_conductivity_fit, temperature_range_K, cross_sectional_area_m2, length_m
) -> float:
thermal_conductivity_integral_area = nso.trapz(
thermal_conductivity_integral_area = jnp.trapz(
thermal_conductivity_fit, temperature_range_K
)
return cross_sectional_area_m2 * thermal_conductivity_integral_area / length_m
1 change: 1 addition & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ hdl21
ipyevents
ipytree
ipywidgets>=7.6.0,<9
jax
jupyter_bokeh
jupyter_packaging~=0.7.9
jupyterlab~=3.0
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"Click>=7.0",
"cocotb",
"hdl21",
"jax",
"gdsfactory",
"networkx",
"openlane",
Expand Down

0 comments on commit dad4325

Please sign in to comment.