diff --git a/.bumpversion.cfg b/.bumpversion.cfg new file mode 100644 index 0000000..4eb4564 --- /dev/null +++ b/.bumpversion.cfg @@ -0,0 +1,15 @@ +[bumpversion] +current_version = 0.6.0 +commit = True +tag = True +tag_name = {new_version} + +[bumpversion:file:setup.cfg] + +[bumpversion:file:settings.ini] + +[bumpversion:file:environment.yml] + +[bumpversion:file:nbs/99_init.ipynb] + +[bumpversion:file:sax/__init__.py] diff --git a/.gitignore b/.gitignore index cde6ff9..7ddd2fa 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,23 @@ +# IDE settings +.vscode/ +.idea/ + +# nbdev +docs/ +.jekyll-cache/ +Gemfile.lock +*.bak +.gitattributes +.last_checked +.gitconfig +*.bak +*.log +*~ +~* +_tmp* +tmp* +tags + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -128,3 +148,7 @@ dmypy.json # Pyre type checker .pyre/ + +# Pyright +node_modules +package-lock.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100755 index 0000000..6504654 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,27 @@ +repos: + - repo: local + hooks: + - id: nbdev_clean_nbs + name: nbdev_clean_nbs + entry: nbdev_clean_nbs + language: python + pass_filenames: false + additional_dependencies: + - nbdev + - repo: local + hooks: + - id: nbstripout + name: nbstripout + entry: nbstripout + language: python + pass_filenames: true + types: [jupyter] + - repo: local + hooks: + - id: nbdev_diff_nbs + name: nbdev_diff_nbs + entry: bash -c 'if [ ! -z "$(nbdev_diff_nbs)" ]; then nbdev_build_lib; fi' + language: python + pass_filenames: false + additional_dependencies: + - nbdev diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..67747b1 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,52 @@ +# How to contribute + +## How to get started + +Before anything else, please install the git hooks that run automatic scripts during +each commit and merge to strip the notebooks of superfluous metadata (and avoid merge +conflicts). After cloning the repository, run the following commands inside the sax +conda environment: + +``` +nbdev_install_git_hooks +pre-commit install +``` + +## Did you find a bug? + +- Ensure the bug was not already reported by searching on GitHub under Issues. +- If you're unable to find an open issue addressing the problem, open a new one. Be sure + to include a title and clear description, as much relevant information as possible, and + a code sample or an executable test case demonstrating the expected behavior that is not + occurring. +- Be sure to add the complete error messages. + +#### Did you write a patch that fixes a bug? + +- Open a new GitHub pull request with the patch. +- Ensure that your PR includes a test that fails without your patch, and pass with it. +- Ensure the PR description clearly describes the problem and solution. Include the + relevant issue number if applicable. + +## PR submission guidelines + +- Keep each PR focused. While it's more convenient, do not combine several unrelated + fixes together. Create as many branches as needing to keep each PR focused. +- Do not mix style changes/fixes with "functional" changes. It's very difficult to + review such PRs and it most likely get rejected. +- Do not add/remove vertical whitespace. Preserve the original style of the file you + edit as much as you can. +- Do not turn an already submitted PR into your development playground. If after you + submitted PR, you discovered that more work is needed - close the PR, do the required + work and then submit a new PR. 
Otherwise each of your commits requires attention from + maintainers of the project. +- If, however, you submitted a PR and received a request for changes, you should proceed + with commits inside that PR, so that the maintainer can see the incremental fixes and + won't need to review the whole PR again. In the exception case where you realize it'll + take many many commits to complete the requests, then it's probably best to close the + PR, do the work and then submit it again. Use common sense where you'd choose one way + over another. + +## Do you want to contribute to the documentation? + +- Docs are automatically created from the notebooks in the nbs folder. diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..188c5ad --- /dev/null +++ b/Dockerfile @@ -0,0 +1,17 @@ +FROM condaforge/mambaforge:4.11.0-0 + +COPY environment.yml /environment.yml +RUN mamba env update -n base -f /environment.yml +RUN conda run -n base python -m ipykernel install --user --name base --display-name base +RUN rm -rf /environment.yml + +COPY docs/nbdev_showdoc.patch /nbdev_showdoc.patch +RUN patch -R $(python -c "from nbdev import showdoc; print(showdoc.__file__)") < /nbdev_showdoc.patch +RUN rm -rf /nbdev_showdoc.patch + +ADD . /sax +RUN pip install /sax +RUN rm -rf /sax + +RUN conda create -n sax --clone base +RUN conda run -n sax python -m ipykernel install --user --name sax --display-name sax diff --git a/LICENSE b/LICENSE index 261eeb9..168bf32 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright [yyyy] [name of copyright owner] + Copyright 2022 Floris Laporte Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..1c94b66 --- /dev/null +++ b/Makefile @@ -0,0 +1,67 @@ +.ONESHELL: +SHELL := /bin/bash +SRC = $(wildcard nbs/*.ipynb) + +all: sax docs + +.SILENT: docker +docker: + sed -i "s/^[ ]*-[ ]\bsax\b.*//g" environment.yml + sed -i "s/\$${RPYPI_USER}/${RPYPI_USER}/g" environment.yml + sed -i "s/\$${RPYPI_TOKEN}/${RPYPI_TOKEN}/g" environment.yml + -docker build . -t sax + git checkout -- environment.yml + +sax: $(SRC) + nbdev_build_lib + +lib: + nbdev_build_lib + +sync: + nbdev_update_lib + +serve: + cd docs && bundle exec jekyll serve + +.PHONY: docs +docs: + jupyter nbconvert --execute --inplace index.ipynb + nbdev_build_docs + +run: + find . -name "*.ipynb" | grep -v ipynb_checkpoints | xargs -I {} papermill {} {} + +test: + nbdev_test_nbs + +release: pypi conda_release + nbdev_bump_version + +conda_release: + fastrelease_conda_package + +pypi: dist + twine upload --repository pypi dist/* + +dist: clean + python -m build --sdist --wheel + +clean: + nbdev_clean_nbs + find . -name "*.ipynb" | xargs nbstripout + find . -name "dist" | xargs rm -rf + find . -name "build" | xargs rm -rf + find . -name "builds" | xargs rm -rf + find . -name "__pycache__" | xargs rm -rf + find . -name "*.so" | xargs rm -rf + find . -name "*.egg-info" | xargs rm -rf + find . -name ".ipynb_checkpoints" | xargs rm -rf + find . 
-name ".pytest_cache" | xargs rm -rf + +reset: + rm -rf sax + rm -rf docs + git checkout -- docs + nbdev_build_lib + make clean diff --git a/README.md b/README.md index bb48b5c..dce2c6e 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ # SAX +> S + Autograd + XLA -![Docs](https://readthedocs.org/projects/sax/badge/?version=latest) + +![](docs/images/sax.svg) Autograd and XLA for S-parameters - a scatter parameter circuit simulator and optimizer for the frequency domain based on [JAX](https://github.com/google/jax). @@ -15,9 +17,8 @@ dive in... ## Quick Start -[Full Quick Start page](https://sax.readthedocs.io/en/latest/examples/01_quick_start.html) - -[Examples](https://sax.readthedocs.io/en/latest/examples.html) - -[Full Docs](https://sax.readthedocs.io/en/latest/index.html). +[Full Quick Start page](https://flaport.github.io/sax/quick_start) - +[Documentation](https://flaport.github.io/sax). Let's first import the SAX library, along with JAX and the JAX-version of numpy: @@ -27,105 +28,120 @@ import jax import jax.numpy as jnp ``` -Define a model -- which is just a port combination -> function dictionary -- for your -component. For example a directional coupler: +Define a model function for your component. A SAX model is just a function that returns +an 'S-dictionary'. For example a directional coupler: ```python -directional_coupler = { - ("p0", "p1"): lambda params: (1 - params["coupling"]) ** 0.5, - ("p1", "p0"): lambda params: (1 - params["coupling"]) ** 0.5, - ("p2", "p3"): lambda params: (1 - params["coupling"]) ** 0.5, - ("p3", "p2"): lambda params: (1 - params["coupling"]) ** 0.5, - ("p0", "p2"): lambda params: 1j * params["coupling"] ** 0.5, - ("p2", "p0"): lambda params: 1j * params["coupling"] ** 0.5, - ("p1", "p3"): lambda params: 1j * params["coupling"] ** 0.5, - ("p3", "p1"): lambda params: 1j * params["coupling"] ** 0.5, - "default_params": { - "coupling": 0.5 - }, -} +def coupler(coupling=0.5): + kappa = coupling**0.5 + tau = (1-coupling)**0.5 + sdict = sax.reciprocal({ + ("in0", "out0"): tau, + ("in0", "out1"): 1j*kappa, + ("in1", "out0"): 1j*kappa, + ("in1", "out1"): tau, + }) + return sdict + +coupler(coupling=0.3) ``` + + + + {('in0', 'out0'): 0.8366600265340756, + ('in0', 'out1'): 0.5477225575051661j, + ('in1', 'out0'): 0.5477225575051661j, + ('in1', 'out1'): 0.8366600265340756, + ('out0', 'in0'): 0.8366600265340756, + ('out1', 'in0'): 0.5477225575051661j, + ('out0', 'in1'): 0.5477225575051661j, + ('out1', 'in1'): 0.8366600265340756} + + + Or a waveguide: ```python -def model_waveguide_transmission(params): - neff = params["neff"] - dwl = params["wl"] - params["wl0"] - dneff_dwl = (params["ng"] - params["neff"]) / params["wl0"] +def waveguide(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0): + dwl = wl - wl0 + dneff_dwl = (ng - neff) / wl0 neff = neff - dwl * dneff_dwl - phase = jnp.exp( - jnp.log(2 * jnp.pi * neff * params["length"]) - jnp.log(params["wl"]) - ) - return 10 ** (-params["loss"] * params["length"] / 20) * jnp.exp(1j * phase) - -waveguide = { - ("in", "out"): model_waveguide_transmission, - ("out", "in"): model_waveguide_transmission, - "default_params": { - "length": 25e-6, - "wl": 1.55e-6, - "wl0": 1.55e-6, - "neff": 2.34, - "ng": 3.4, - "loss": 0.0, - }, -} + phase = 2 * jnp.pi * neff * length / wl + amplitude = jnp.asarray(10 ** (-loss * length / 20), dtype=complex) + transmission = amplitude * jnp.exp(1j * phase) + sdict = sax.reciprocal({("in0", "out0"): transmission}) + return sdict + +waveguide(length=100.0) ``` -These component 
model dictionaries can be combined into a circuit model dictionary: + + + + {('in0', 'out0'): 0.97953-0.2013j, ('out0', 'in0'): 0.97953-0.2013j} + + + +These component models can then be combined into a circuit: ```python mzi = sax.circuit( - models = { - "dc1": directional_coupler, + instances = { + "lft": coupler, "top": waveguide, - "dc2": directional_coupler, - "btm": waveguide, + "rgt": coupler, }, connections={ - "dc1:p2": "top:in", - "dc1:p1": "btm:in", - "top:out": "dc2:p3", - "btm:out": "dc2:p0", + "lft,out0": "rgt,in0", + "lft,out1": "top,in0", + "top,out0": "rgt,in1", }, ports={ - "dc1:p3": "in2", - "dc1:p0": "in1", - "dc2:p2": "out2", - "dc2:p1": "out1", + "in0": "lft,in0", + "in1": "lft,in1", + "out0": "rgt,out0", + "out1": "rgt,out1", }, ) + +type(mzi) ``` -Simulating this is as simple as modifying the default parameters: + + + + function + + + +As you can see, the mzi we just created is just another component model function! To simulate it, call the mzi function with the (possibly nested) settings of its subcomponents. Global settings can be added to the 'root' of the circuit call and will be distributed over all subcomponents which have a parameter with the same name (e.g. 'wl'): ```python -params = sax.copy_params(mzi["default_params"]) -params["top"]["length"] = 2.5e-5 -params["btm"]["length"] = 1.5e-5 -mzi["in1", "out1"](params) +wl = jnp.linspace(1.53, 1.57, 1000) +result = mzi(wl=wl, lft={'coupling': 0.3}, top={'length': 200.0}, rgt={'coupling': 0.8}) + +plt.plot(1e3*wl, jnp.abs(result['in0', 'out0'])**2, label="in0->out0") +plt.plot(1e3*wl, jnp.abs(result['in0', 'out1'])**2, label="in0->out1", ls="--") +plt.xlabel("λ [nm]") +plt.ylabel("T") +plt.grid(True) +plt.figlegend(ncol=2, loc="upper center") +plt.show() ``` -``` -DeviceArray(-0.280701+0.10398856j, dtype=complex64) -``` + +![png](docs/images/output_10_0.png) + Those are the basics. For more info, check out the **full** -[SAX Quick Start page](https://sax.readthedocs.io/en/latest/examples/01_quick_start.html), -the [Examples](https://sax.readthedocs.io/en/latest/examples.html) -or the -[Documentation](https://sax.readthedocs.io/en/latest/index.html). +[SAX Quick Start page](https://flaport.github.io/sax/quick_start) or the rest of the [Documentation](https://flaport.github.io/sax). ## Installation ### Dependencies -- [JAX & JAXLIB](https://github.com/google/jax). Please read the JAX install - instructions [here](https://github.com/google/jax/#installation). Alternatively, you can - try running [jaxinstall.sh](jaxinstall.sh) to automatically pip-install the correct - `jax` and `jaxlib` package for your python and cuda version (if that exact combination - exists). +- [JAX & JAXLIB](https://github.com/google/jax). Please read the JAX install instructions [here](https://github.com/google/jax/#installation). ### Installation diff --git a/docs/Gemfile b/docs/Gemfile new file mode 100644 index 0000000..6bb45f4 --- /dev/null +++ b/docs/Gemfile @@ -0,0 +1,5 @@ +source "https://rubygems.org" +gem 'github-pages', "= 222", group: :jekyll_plugins +gem "jekyll", "= 3.9" +gem "kramdown", "= 2.3.1" +gem "jekyll-remote-theme", "= 0.4.3" diff --git a/docs/Makefile b/docs/Makefile deleted file mode 100644 index d0c3cbf..0000000 --- a/docs/Makefile +++ /dev/null @@ -1,20 +0,0 @@ -# Minimal makefile for Sphinx documentation -# - -# You can set these variables from the command line, and also -# from the environment for the first two. 
-SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build -SOURCEDIR = source -BUILDDIR = build - -# Put it first so that "make" without argument is like "make help". -help: - @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -.PHONY: help Makefile - -# Catch-all target: route all unknown targets to Sphinx using the new -# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_config.yml b/docs/_config.yml new file mode 100644 index 0000000..7f0a191 --- /dev/null +++ b/docs/_config.yml @@ -0,0 +1,66 @@ +repository: flaport/sax +output: web +topnav_title: sax +site_title: sax +company_name: Floris Laporte (Apache 2.0) +description: S + Autograd + XLA +# Set to false to disable KaTeX math +use_math: true +# Add Google analytics id if you have one and want to use it here +google_analytics: +# See http://nbdev.fast.ai/search for help with adding Search +google_search: + +host: 127.0.0.1 +# the preview server used. Leave as is. +port: 4000 +# the port where the preview is rendered. + +exclude: + - .idea/ + - .gitignore + - vendor + +exclude: [vendor] + +highlighter: rouge +markdown: kramdown +kramdown: + input: GFM + auto_ids: true + hard_wrap: false + syntax_highlighter: rouge + +collections: + tooltips: + output: false + +defaults: + - + scope: + path: "" + type: "pages" + values: + layout: "page" + comments: true + search: true + sidebar: home_sidebar + topnav: topnav + - + scope: + path: "" + type: "tooltips" + values: + layout: "page" + comments: true + search: true + tooltip: true + +sidebars: +- home_sidebar + +plugins: + - jekyll-remote-theme + +remote_theme: fastai/nbdev-jekyll-theme +baseurl: /sax/ \ No newline at end of file diff --git a/docs/_data/topnav.yml b/docs/_data/topnav.yml new file mode 100644 index 0000000..f0bd0e2 --- /dev/null +++ b/docs/_data/topnav.yml @@ -0,0 +1,10 @@ +topnav: +- title: Topnav + items: + - title: GitHub + external_url: https://github.com/flaport/sax/ + +#Topnav dropdowns +topnav_dropdowns: +- title: Topnav dropdowns + folders: diff --git a/docs/assets/css/theme-blue.css b/docs/assets/css/theme-blue.css new file mode 100755 index 0000000..b6f9230 --- /dev/null +++ b/docs/assets/css/theme-blue.css @@ -0,0 +1,121 @@ +.summary { + color: #808080; + border-left: 5px solid #5119ed; + font-size:16px; +} + + +h3 {color: #000000; } +h4 {color: #000000; } + +.nav-tabs > li.active > a, .nav-tabs > li.active > a:hover, .nav-tabs > li.active > a:focus { + background-color: #7070df; + color: white; +} + +.nav > li.active > a { + background-color: #8080ff; +} + +.nav > li > a:hover { + background-color: #7070df; +} + +div.navbar-collapse .dropdown-menu > li > a:hover { + background-color: #8080ff; +} + +.nav li.thirdlevel > a { + background-color: #FAFAFA !important; + color: #7070df; + font-weight: bold; +} + +a[data-toggle="tooltip"] { + color: #649345; + font-style: italic; + cursor: default; +} + +.navbar-inverse { + background-color: #8080ff; + border-color: #404080; +} +.navbar-inverse .navbar-nav>li>a, .navbar-inverse .navbar-brand { + color: white; +} + +.navbar-inverse .navbar-nav>li>a:hover, a.fa.fa-home.fa-lg.navbar-brand:hover { + color: #f0f0f0; +} + +a.navbar-brand:hover { + color: #f0f0f0; +} + +.navbar-inverse .navbar-nav > .open > a, .navbar-inverse .navbar-nav > .open > a:hover, .navbar-inverse .navbar-nav > .open > a:focus { + color: #404080; +} + +.navbar-inverse .navbar-nav > .open > a, .navbar-inverse .navbar-nav > .open 
> a:hover, .navbar-inverse .navbar-nav > .open > a:focus { + background-color: #404080; + color: #ffffff; +} + +.navbar-inverse .navbar-collapse, .navbar-inverse .navbar-form { + border-color: #7070df !important; +} + +.btn-primary { + color: #ffffff; + background-color: #8080ff; + border-color: #8080ff; +} + +.navbar-inverse .navbar-nav > .active > a, .navbar-inverse .navbar-nav > .active > a:hover, .navbar-inverse .navbar-nav > .active > a:focus { + background-color: #8080ff; +} + +.btn-primary:hover, +.btn-primary:focus, +.btn-primary:active, +.btn-primary.active, +.open .dropdown-toggle.btn-primary { + background-color: #7070df; + border-color: #8080ff; +} + +.printTitle { + color: #404080 !important; +} + +body.print h1 {color: #404080 !important; font-size:28px !important;} +body.print h2 {color: #595959 !important; font-size:20px !important;} +body.print h3 {color: #E50E51 !important; font-size:14px !important;} +body.print h4 {color: #679DCE !important; font-size:14px; font-style: italic !important;} + +.anchorjs-link:hover { + color: #404080; +} + +div.sidebarTitle { + color: #404080; +} + +li.sidebarTitle { + margin-top:20px; + font-weight:normal; + font-size:130%; + color: #5119ed; + margin-bottom:10px; + margin-left: 5px; + +} + +.navbar-inverse .navbar-toggle:focus, .navbar-inverse .navbar-toggle:hover { + background-color: #404080; +} + +.navbar-inverse .navbar-toggle { + border-color: #404080; +} \ No newline at end of file diff --git a/docs/assets/images/company_logo.png b/docs/assets/images/company_logo.png new file mode 100755 index 0000000..7b7abd5 Binary files /dev/null and b/docs/assets/images/company_logo.png differ diff --git a/docs/assets/images/favicon.ico b/docs/assets/images/favicon.ico new file mode 100755 index 0000000..908ac5b Binary files /dev/null and b/docs/assets/images/favicon.ico differ diff --git a/docs/images/output_10_0.png b/docs/images/output_10_0.png new file mode 100644 index 0000000..9b266bd Binary files /dev/null and b/docs/images/output_10_0.png differ diff --git a/docs/images/sax.svg b/docs/images/sax.svg new file mode 100644 index 0000000..f837e53 --- /dev/null +++ b/docs/images/sax.svg @@ -0,0 +1,372 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/make.bat b/docs/make.bat deleted file mode 100644 index 6247f7e..0000000 --- a/docs/make.bat +++ /dev/null @@ -1,35 +0,0 @@ -@ECHO OFF - -pushd %~dp0 - -REM Command file for Sphinx documentation - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set SOURCEDIR=source -set BUILDDIR=build - -if "%1" == "" goto help - -%SPHINXBUILD% >NUL 2>NUL -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. 
- echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ - exit /b 1 -) - -%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% -goto end - -:help -%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% - -:end -popd diff --git a/docs/nbdev_showdoc.patch b/docs/nbdev_showdoc.patch new file mode 100644 index 0000000..440a6c2 --- /dev/null +++ b/docs/nbdev_showdoc.patch @@ -0,0 +1,20 @@ +224c224 +< return s.replace('_', r'\_') +--- +> return s.replace('_', '\_') +232,235c232 +< return_anno = f"->{type_repr(sig.return_annotation).replace('_empty', 'None')}" +< except: +< fmt_params = [] +< return_anno = '' +--- +> except: fmt_params = [] +239c236 +< return f'{f_name}',f'{name}{arg_str}{return_anno}' +--- +> return f'{f_name}',f'{name}{arg_str}' +305c302 +< except: display(Markdown(md)) +--- +> except: display(Markdown(md)) +\ No newline at end of file diff --git a/docs/sidebar.json b/docs/sidebar.json new file mode 100644 index 0000000..68c9b95 --- /dev/null +++ b/docs/sidebar.json @@ -0,0 +1,35 @@ +{ + "welcome": { + "welcome": "/" + }, + "examples": { + "1. quick start": "quick_start", + "2. all pass filter": "all_pass_filter", + "3. circuit from yaml": "circuit_from_yaml", + "4. multimode simulations": "multimode_simulations", + "5. thin film optimizatioon": "thinfilm", + "6. pathlengths and group delays": "additive_backend" + }, + "core api": { + "1. sax types": "typing", + "2. utilities": "utils", + "3. multimode": "multimode", + "4. models": "models", + "5. netlist": "netlist", + "6. circuit": "circuit", + "7. caching": "caching" + }, + "backends": { + "1. overview": "backends", + "2. default backend": "backends_default", + "3. klu backend": "backends_klu", + "4. additive backend": "backends_additive" + }, + "neural networks": { + "1. overview": "nn", + "2. loss functions": "nn_loss", + "3. utilities": "nn_utils", + "4. neural networks": "nn_core", + "5. i/o": "nn_io" + } +} diff --git a/docs/source/conf.py b/docs/source/conf.py deleted file mode 100644 index e04d959..0000000 --- a/docs/source/conf.py +++ /dev/null @@ -1,70 +0,0 @@ -# Configuration file for the Sphinx documentation builder. -# -# This file only contains a selection of the most common options. For a full -# list see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -import os -import sys -import shutil -# sys.path.insert(0, os.path.abspath('.')) - - -# -- Project information ----------------------------------------------------- - -project = 'SAX' -copyright = '2021, Floris Laporte' -author = 'Floris Laporte' - -# The full version, including alpha/beta/rc tags -release = '0.0.0' - - -# -- General configuration --------------------------------------------------- - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -extensions = [ - "sphinx.ext.autodoc", - "sphinx.ext.napoleon", - "sphinx.ext.mathjax", - "sphinx.ext.viewcode", - "nbsphinx", - "sphinx_rtd_theme", -] - -# Add any paths that contain templates here, relative to this directory. 
-templates_path = ['_templates'] - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. - - -# -- Options for HTML output ------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = 'sphinx_rtd_theme' - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] - -# -- Examples Folder --------------------------------------------------------- - -sourcedir = os.path.dirname(__file__) -examples_src = os.path.abspath(os.path.join(sourcedir, "..", "..", "examples")) -examples_dst = os.path.abspath(os.path.join(sourcedir, "examples")) -shutil.rmtree(examples_dst, ignore_errors=True) -shutil.copytree(examples_src, examples_dst) - diff --git a/docs/source/examples.rst b/docs/source/examples.rst deleted file mode 100644 index c1d4984..0000000 --- a/docs/source/examples.rst +++ /dev/null @@ -1,10 +0,0 @@ -.. _examples: - -Examples -======== - -.. toctree:: - :maxdepth: 1 - - examples/01_quick_start.ipynb - examples/02_thinfilm.ipynb diff --git a/docs/source/index.rst b/docs/source/index.rst deleted file mode 100644 index 6c531a2..0000000 --- a/docs/source/index.rst +++ /dev/null @@ -1,54 +0,0 @@ -SAX -==== - -:ref:`genindex` · :ref:`modindex` · :ref:`search` - -Autograd and XLA for S-parameters - a scatter parameter circuit simulator and -optimizer for the frequency domain based on `JAX `_ - -The simulator was developed for simulating Photonic Integrated Circuits but in fact is -able to perform any S-parameter based circuit simulation. The goal of SAX is to be a -thin wrapper around JAX with some basic tools for S-parameter based circuit simulation -and optimization. Therefore, SAX does not define any special datastructures and tries to -stay as close as possible to the functional nature of JAX. This makes it very easy to -get started with SAX as you only need functions and standard python dictionaries. Let's -dive in... - -Table Of Contents ------------------ - -.. toctree:: - :maxdepth: 2 - - examples - sax - - -Installation ------------- - - -Dependencies -~~~~~~~~~~~~ - -- `JAX & JAXLIB `__. Please read the JAX - install instructions - `here `__. - Alternatively, you can try running `jaxinstall.sh `__ - to automatically pip-install the correct ``jax`` and ``jaxlib`` - package for your python and cuda version (if that exact combination - exists). - - -Installation -~~~~~~~~~~~~ - -:: - - pip install sax - - -License -------- - -Copyright © 2021, Floris Laporte, `Apache-2.0 License `__ diff --git a/docs/source/sax.rst b/docs/source/sax.rst deleted file mode 100644 index 9dfb6e8..0000000 --- a/docs/source/sax.rst +++ /dev/null @@ -1,39 +0,0 @@ -SAX module docs -================ - -.. automodule:: sax - :members: - :undoc-members: - :show-inheritance: - -sax.models module ------------------- - -.. automodule:: sax.models - :members: - :undoc-members: - :show-inheritance: - -sax.core module ----------------- - -.. automodule:: sax.core - :members: - :undoc-members: - :show-inheritance: - -sax.utils module ------------------ - -.. 
automodule:: sax.utils - :members: - :undoc-members: - :show-inheritance: - -sax.constants module ---------------------- - -.. automodule:: sax.constants - :members: - :undoc-members: - :show-inheritance: diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..31e6f77 --- /dev/null +++ b/environment.yml @@ -0,0 +1,42 @@ +name: sax +channels: + - conda-forge + - fastai +dependencies: + - make==4.3 + - patch==2.7.6 + - suitesparse==5.10.1 + - python==3.9 + - pip==22.0.2 + - black==22.1.0 + - bokeh==2.4.2 + - build==0.7.0 + - bump2version==1.0.1 + - fastcore==1.3.27 + - h5py==3.6.0 + - ipykernel==6.7.0 + - ipympl==0.8.7 + - ipywidgets==7.6.5 + - jupyterlab==3.2.8 + - matplotlib==3.5.1 + - natsort==8.1.0 + - nbdev==1.1.23 + - nbstripout==0.5.0 + - networkx==2.6.3 + - numpy==1.22.1 + - pandas==1.4.0 + - papermill==2.3.4 + - pre-commit==2.17.0 + - pyright==0.0.13 + - pytables==3.7.0 + - pyyaml==6.0 + - scipy==1.7.3 + - shapely==1.8.0 + - tqdm==4.62.3 + - twine==0.0.0 + - pip: + - flax==0.4.0 + - jax==0.2.27 + - jaxlib==0.1.76 + - tmm==0.1.8 + - sax==0.6.0 diff --git a/examples/01_quick_start.ipynb b/examples/01_quick_start.ipynb index 4748a46..4bf0aa9 100644 --- a/examples/01_quick_start.ipynb +++ b/examples/01_quick_start.ipynb @@ -2,6 +2,7 @@ "cells": [ { "cell_type": "markdown", + "id": "ab894499", "metadata": {}, "source": [ "# SAX Quick Start" @@ -9,6 +10,7 @@ }, { "cell_type": "markdown", + "id": "ac603874", "metadata": {}, "source": [ "Let's go over the core functionality of SAX." @@ -16,467 +18,380 @@ }, { "cell_type": "markdown", + "id": "3c91396d", "metadata": {}, "source": [ - "## Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import tqdm\n", - "import matplotlib.pyplot as plt\n", - "\n", - "import jax\n", - "import jax.numpy as jnp\n", - "import jax.experimental.optimizers as opt\n", - "\n", - "# sax circuit simulator\n", - "import sax" + "## Environment variables" ] }, { "cell_type": "markdown", + "id": "79240f32", "metadata": {}, "source": [ - "## Models" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Models are simply defined by a single function per S-matrix combination. This function takes a dictionary of parameters as single argument. For example a directional coupler:" + "SAX is based on JAX... 
here are some useful environment variables for working with JAX:" ] }, { "cell_type": "code", "execution_count": null, + "id": "40ce30d2", "metadata": {}, "outputs": [], "source": [ - "def model_directional_coupler_coupling(params):\n", - " return 1j * params[\"coupling\"] ** 0.5\n", + "# select float32 or float64 as default dtype\n", + "%env JAX_ENABLE_X64=0\n", + "\n", + "# select cpu or gpu\n", + "%env JAX_PLATFORM_NAME=cpu\n", "\n", - "def model_directional_coupler_transmission(params):\n", - " return (1 - params[\"coupling\"]) ** 0.5" + "# set custom CUDA location for gpu:\n", + "%env XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/lib/cuda\n", + "\n", + "# Using GPU?\n", + "from jax.lib import xla_bridge\n", + "print(xla_bridge.get_backend().platform)" ] }, { "cell_type": "markdown", + "id": "4f9a078b", "metadata": {}, "source": [ - "These model functions can then be combined into a dictionary, which basically defines the full S-matrix for a directional coupler which is defined as follows:\n", - "\n", - "```\n", - " p3 p2\n", - " \\ /\n", - " ========\n", - " / \\\n", - " p0 p1\n", - "```" + "## Imports" ] }, { "cell_type": "code", "execution_count": null, + "id": "e0450fde", "metadata": {}, "outputs": [], "source": [ - "directional_coupler = {\n", - " (\"p0\", \"p1\"): model_directional_coupler_transmission,\n", - " (\"p1\", \"p0\"): model_directional_coupler_transmission,\n", - " (\"p2\", \"p3\"): model_directional_coupler_transmission,\n", - " (\"p3\", \"p2\"): model_directional_coupler_transmission,\n", - " (\"p0\", \"p2\"): model_directional_coupler_coupling,\n", - " (\"p2\", \"p0\"): model_directional_coupler_coupling,\n", - " (\"p1\", \"p3\"): model_directional_coupler_coupling,\n", - " (\"p3\", \"p1\"): model_directional_coupler_coupling,\n", - " \"default_params\": {\n", - " \"coupling\": 0.5\n", - " },\n", - "}" + "import jax\n", + "import jax.example_libraries.optimizers as opt\n", + "import jax.numpy as jnp\n", + "import matplotlib.pyplot as plt # plotting\n", + "import sax\n", + "import tqdm # progress bars" ] }, { "cell_type": "markdown", + "id": "8c4d4c4e", "metadata": {}, "source": [ - "Any non-existing S-matrix combination (for example `(\"p0\", \"p3\")`) is considered to be zero. Moreover, default parameters can be defined for the full component by specifying the `\"default_params\"` key in the dictionary. Also note that ALL parameters in the parameter dictionary should be floats!" + "## Scatter *dictionaries*\n", + "The core datastructure for specifying scatter parameters in SAX is a dictionary... more specifically a dictionary which maps a port combination (2-tuple) to a scatter parameter (or an array of scatter parameters when considering multiple wavelengths for example). Such a specific dictionary mapping is called ann `SDict` in SAX (`SDict ≈ Dict[Tuple[str,str], float]`).\n", + "\n", + "Dictionaries are in fact much better suited for characterizing S-parameters than, say, (jax-)numpy arrays due to the inherent sparse nature of scatter parameters. Moreover, dictonaries allow for string indexing, which makes them much more pleasant to use in this context. 
Let’s for example create an `SDict` for a 50/50 coupler:" ] }, { "cell_type": "markdown", + "id": "a7f90205", "metadata": {}, "source": [ - "We can do the same for a waveguide:\n", - "\n", "```\n", - " in -------- out\n", + "in1 out1\n", + " \\ /\n", + " ========\n", + " / \\\n", + "in0 out0\n", "```" ] }, { "cell_type": "code", "execution_count": null, + "id": "89f215e8", "metadata": {}, "outputs": [], "source": [ - "def model_waveguide_transmission(params):\n", - " neff = params[\"neff\"]\n", - " dwl = params[\"wl\"] - params[\"wl0\"]\n", - " dneff_dwl = (params[\"ng\"] - params[\"neff\"]) / params[\"wl0\"]\n", - " neff = neff - dwl * dneff_dwl\n", - " phase = jnp.exp(\n", - " jnp.log(2 * jnp.pi * neff * params[\"length\"]) - jnp.log(params[\"wl\"])\n", - " )\n", - " return 10 ** (-params[\"loss\"] * params[\"length\"] / 20) * jnp.exp(1j * phase)\n", - "\n", - "waveguide = {\n", - " (\"in\", \"out\"): model_waveguide_transmission,\n", - " (\"out\", \"in\"): model_waveguide_transmission,\n", - " \"default_params\": { # remember that ALL params should be floats!\n", - " \"length\": 25e-6,\n", - " \"wl\": 1.55e-6,\n", - " \"wl0\": 1.55e-6,\n", - " \"neff\": 2.34,\n", - " \"ng\": 3.4,\n", - " \"loss\": 0.0,\n", - " },\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "That's all you have to do to define a component! Also note that all ports of a component can be obtained with `sax.get_ports`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "sax.get_ports(directional_coupler)" + "coupling = 0.5\n", + "kappa = coupling ** 0.5\n", + "tau = (1 - coupling) ** 0.5\n", + "coupler_dict = {\n", + " (\"in0\", \"out0\"): tau,\n", + " (\"out0\", \"in0\"): tau,\n", + " (\"in0\", \"out1\"): 1j * kappa,\n", + " (\"out1\", \"in0\"): 1j * kappa,\n", + " (\"in1\", \"out0\"): 1j * kappa,\n", + " (\"out0\", \"in1\"): 1j * kappa,\n", + " (\"in1\", \"out1\"): tau,\n", + " (\"out1\", \"in1\"): tau,\n", + "}\n", + "coupler_dict" ] }, { "cell_type": "markdown", + "id": "bb9cb2b6", "metadata": {}, "source": [ - "And ports can be renamed with `sax.rename_ports`:" + "Only the non-zero port combinations need to be specified. Any non-existent port-combination (for example `(\"in0\", \"in1\")`) is considered to be zero by SAX.\n", + "\n", + "Obviously, it can still be tedious to specify every port in the circuit manually. SAX therefore offers the `reciprocal` function, which auto-fills the reverse connection if the forward connection exist. For example:" ] }, { "cell_type": "code", "execution_count": null, + "id": "494f445d", "metadata": {}, "outputs": [], "source": [ - "directional_coupler2 = sax.rename_ports(\n", - " model=directional_coupler,\n", - " ports={\n", - " \"p0\": \"in1\", \n", - " \"p1\": \"out1\", \n", - " \"p2\": \"out2\", \n", - " \"p3\": \"in2\"\n", + "coupler_dict = sax.reciprocal(\n", + " {\n", + " (\"in0\", \"out0\"): tau,\n", + " (\"in0\", \"out1\"): 1j * kappa,\n", + " (\"in1\", \"out0\"): 1j * kappa,\n", + " (\"in1\", \"out1\"): tau,\n", " }\n", ")\n", - "directional_coupler2" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Note that this NEVER changes anything inplace. 
The original directional coupler dictionary is still intact:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "directional_coupler" + "\n", + "coupler_dict" ] }, { "cell_type": "markdown", + "id": "67e730f6", "metadata": {}, "source": [ - "## Circuits" + "## Parametrized Models" ] }, { "cell_type": "markdown", + "id": "ec1e1214", "metadata": {}, "source": [ - "Circuits can be created with `sax.circuit`. This function takes three required arguments: `models`, `connections` and `ports`. These are all supposed to be dictionaries. The `models` dictionary describes the individual models and their name in the circuit. Note that a circuit is itself also a model, which allows you to define hierarchical circuits. The `connections` dictionary describes the connections between individual model ports. The model ports are defined as `\"{modelname}:{portname}\"`. Finally, the ports dictionary defines a mapping from the unused ports in the `\"{modelname}:{portname}\"` format back onto a single `\"{portname}\"`.\n", - "\n", - "```\n", - " top\n", - " in ----- out\n", - " in2 <- p3 p2 p3 p2 -> out2\n", - " \\ dc1 / \\ dc2 /\n", - " ======= =======\n", - " / \\ / \\\n", - " in1 <- p0 p1 btm p0 p1 -> out1\n", - " in ----- out\n", - "```" + "Constructing such an `SDict` is easy, however, usually we're more interested in having parametrized models for our components. To parametrize the coupler `SDict`, just wrap it in a function to obtain a SAX `Model`, which is a keyword-only function mapping to an `SDict`:" ] }, { "cell_type": "code", "execution_count": null, + "id": "33efb99f", "metadata": {}, "outputs": [], "source": [ - "mzi = sax.circuit(\n", - " models = {\n", - " \"dc1\": directional_coupler,\n", - " \"top\": waveguide,\n", - " \"dc2\": directional_coupler,\n", - " \"btm\": waveguide,\n", - " },\n", - " connections={\n", - " \"dc1:p2\": \"top:in\",\n", - " \"dc1:p1\": \"btm:in\",\n", - " \"top:out\": \"dc2:p3\",\n", - " \"btm:out\": \"dc2:p0\",\n", - " },\n", - " ports={\n", - " \"dc1:p3\": \"in2\",\n", - " \"dc1:p0\": \"in1\",\n", - " \"dc2:p2\": \"out2\",\n", - " \"dc2:p1\": \"out1\",\n", - " },\n", - ")" + "def coupler(coupling=0.5) -> sax.SDict:\n", + " kappa = coupling ** 0.5\n", + " tau = (1 - coupling) ** 0.5\n", + " coupler_dict = sax.reciprocal(\n", + " {\n", + " (\"in0\", \"out0\"): tau,\n", + " (\"in0\", \"out1\"): 1j * kappa,\n", + " (\"in1\", \"out0\"): 1j * kappa,\n", + " (\"in1\", \"out1\"): tau,\n", + " }\n", + " )\n", + " return coupler_dict\n", + "\n", + "\n", + "coupler(coupling=0.3)" ] }, { "cell_type": "markdown", + "id": "4ef93f65", "metadata": {}, "source": [ - "As you can see, the `mzi` circuit is just a dictionary of individual functions as well:" + "We can define a waveguide in the same way:" ] }, { "cell_type": "code", "execution_count": null, + "id": "5eac9176", "metadata": {}, "outputs": [], "source": [ - "mzi" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As you can see, as for the individual components it's only defined for nonzero connections!" 
+ "def waveguide(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0) -> sax.SDict:\n", + " dwl = wl - wl0\n", + " dneff_dwl = (ng - neff) / wl0\n", + " neff = neff - dwl * dneff_dwl\n", + " phase = 2 * jnp.pi * neff * length / wl\n", + " transmission = 10 ** (-loss * length / 20) * jnp.exp(1j * phase)\n", + " sdict = sax.reciprocal(\n", + " {\n", + " (\"in0\", \"out0\"): transmission,\n", + " }\n", + " )\n", + " return sdict" ] }, { "cell_type": "markdown", + "id": "67713c18", "metadata": {}, "source": [ - "It also has default parameters for each of its subcomponents:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "params = mzi[\"default_params\"]\n", - "params" + "That's pretty straightforward. Let's now move on to parametrized circuits:" ] }, { "cell_type": "markdown", + "id": "579a9ce1", "metadata": {}, "source": [ - "## Simulating the MZI" + "## Circuit Models" ] }, { "cell_type": "markdown", + "id": "cc3c78a0", "metadata": {}, "source": [ - "To simulate the MZI, we first need to update the parameters. To do this, we first copy the params dictionary after which we can update it inplace:" + "Existing models can now be combined into a circuit using `sax.circuit`, which basically creates a new `Model` function:" ] }, { "cell_type": "code", "execution_count": null, + "id": "d6885e1d", "metadata": {}, "outputs": [], "source": [ - "params = sax.copy_params(params)\n", - "params[\"btm\"][\"length\"] = 1.5e-5 # make the bottom length shorter" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Moreover, we want to simulate over a range of wavelengths. To set the wavelength globally for all subcomponents of the circuit, we use `sax.set_global_params`:" + "mzi = sax.circuit(\n", + " instances={\n", + " \"lft\": coupler,\n", + " \"top\": waveguide,\n", + " \"btm\": waveguide,\n", + " \"rgt\": coupler,\n", + " },\n", + " connections={\n", + " \"lft,out0\": \"btm,in0\",\n", + " \"btm,out0\": \"rgt,in0\",\n", + " \"lft,out1\": \"top,in0\",\n", + " \"top,out0\": \"rgt,in1\",\n", + " },\n", + " ports={\n", + " \"in0\": \"lft,in0\",\n", + " \"in1\": \"lft,in1\",\n", + " \"out0\": \"rgt,out0\",\n", + " \"out1\": \"rgt,out1\",\n", + " },\n", + ")" ] }, { "cell_type": "code", "execution_count": null, + "id": "e95325c6", "metadata": {}, "outputs": [], "source": [ - "params = sax.set_global_params(params, wl=1e-6*jnp.linspace(1.51, 1.59, 500))" + "mzi?" ] }, { "cell_type": "markdown", + "id": "d2c74cff", "metadata": {}, "source": [ - "This sets the wavelength `wl` parameter for all subcomponents in the circuit." + "The `circuit` function just creates a similar function as we created for the waveguide and the coupler, but in stead of taking parameters directly it takes parameter *dictionaries* for each of the instances in the circuit. The keys in these parameter dictionaries should correspond to the keyword arguments of each individual subcomponent. " ] }, { "cell_type": "markdown", + "id": "ba1202d1", "metadata": {}, "source": [ - "Assume we're interested in simulating the `in1 -> out1` transmission. 
In this case our function of interest is given by the following:" + "Let's now do a simulation for the MZI we just constructed:" ] }, { "cell_type": "code", "execution_count": null, + "id": "f2108f6a", "metadata": {}, "outputs": [], "source": [ - "mzi_in1_out1 = mzi[\"in1\",\"out1\"]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can just-in-time (jit) compile this function for better performance:" + "%time mzi()" ] }, { "cell_type": "code", "execution_count": null, + "id": "cb43b835", "metadata": {}, "outputs": [], "source": [ - "mzi_in1_out1 = jax.jit(mzi[\"in1\", \"out1\"])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The first time you simulate, the function will be jitted and the simulation will be a bit slower:" + "mzi2 = jax.jit(mzi)" ] }, { "cell_type": "code", "execution_count": null, + "id": "61f2bb6e", "metadata": {}, "outputs": [], "source": [ - "%time detected = mzi_in1_out1(params)" + "%time mzi2()" ] }, { "cell_type": "markdown", + "id": "1e0d41ea", "metadata": {}, "source": [ - "The second time you simulate the simulation is really fast:" + "Or in the case we want an MZI with different arm lengths:" ] }, { "cell_type": "code", "execution_count": null, + "id": "4f53bdc3", "metadata": {}, "outputs": [], "source": [ - "%time detected = mzi_in1_out1(params)" + "mzi(top={\"length\": 25.0}, btm={\"length\": 15.0})" ] }, { "cell_type": "markdown", + "id": "a331252a", "metadata": {}, "source": [ - "Even if you change the parameters:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "params = sax.set_global_params(params, wl=1e-6*jnp.linspace(1.5, 1.6, 500))\n", - "%time detected = mzi_in1_out1(params)" + "## Simulating the parametrized MZI" ] }, { "cell_type": "markdown", + "id": "4ba8626b", "metadata": {}, "source": [ - "**Unless the shape of one of the parameters changes**, then the model needs to be jit-compiled again" + "We can simulate the above mzi for multiple wavelengths as well by specifying the wavelength at the top level of the circuit call. 
Each setting specified at the top level of the circuit call will be propagated to all subcomponents of the circuit which have that setting:" ] }, { "cell_type": "code", "execution_count": null, + "id": "d6c8b5c9", "metadata": {}, "outputs": [], "source": [ - "params = sax.set_global_params(params, wl=1e-6*jnp.linspace(1.5, 1.6, 1000))\n", - "%time detected = mzi_in1_out1(params)" + "wl = jnp.linspace(1.51, 1.59, 1000)\n", + "%time S = mzi(wl=wl, top={\"length\": 25.0}, btm={\"length\": 15.0})" ] }, { "cell_type": "markdown", + "id": "28d8cc75", "metadata": {}, "source": [ - "Luckily, now both shapes yield fast computations (we don't lose the old jit-compiled model):" + "Let's see what this gives:" ] }, { "cell_type": "code", "execution_count": null, + "id": "f7de73dc", "metadata": {}, "outputs": [], "source": [ - "params = sax.set_global_params(params, wl=1e-6*jnp.linspace(1.5, 1.6, 500))\n", - "%time detected = mzi_in1_out1(params)\n", - "params = sax.set_global_params(params, wl=1e-6*jnp.linspace(1.5, 1.6, 1000))\n", - "%time detected = mzi_in1_out1(params)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Anyway, let's see what this gives:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt.plot(params[\"top\"][\"wl\"], abs(detected)**2)\n", + "plt.plot(wl * 1e3, abs(S[\"in0\", \"out0\"]) ** 2)\n", "plt.ylim(-0.05, 1.05)\n", "plt.xlabel(\"λ [nm]\")\n", "plt.ylabel(\"T\")\n", @@ -486,6 +401,7 @@ }, { "cell_type": "markdown", + "id": "eb9f0074", "metadata": {}, "source": [ "## Optimization" @@ -493,6 +409,7 @@ }, { "cell_type": "markdown", + "id": "431aee51", "metadata": {}, "source": [ "We'd like to optimize an MZI such that one of the minima is at 1550nm. To do this, we need to define a loss function for the circuit at 1550nm. 
This function should take the parameters that you want to optimize as positional arguments:" @@ -501,29 +418,29 @@ { "cell_type": "code", "execution_count": null, + "id": "7dabe735", "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def loss(delta_length):\n", - " params = sax.set_global_params(mzi[\"default_params\"], wl=1.55e-6)\n", - " params[\"top\"][\"length\"] = 1.5e-6 + delta_length\n", - " params[\"btm\"][\"length\"] = 1.5e-6\n", - " detected = mzi[\"in1\", \"out1\"](params)\n", - " return (abs(detected)**2).mean()" + " S = mzi(wl=1.55, top={\"length\": 15.0 + delta_length}, btm={\"length\": 15.0})\n", + " return (abs(S[\"in0\", \"out0\"]) ** 2).mean()" ] }, { "cell_type": "code", "execution_count": null, + "id": "a60a8eab", "metadata": {}, "outputs": [], "source": [ - "%time loss(10e-6)" + "%time loss(10.0)" ] }, { "cell_type": "markdown", + "id": "af1eb06b", "metadata": {}, "source": [ "We can use this loss function to define a grad function which works on the parameters of the loss function:" @@ -532,23 +449,21 @@ { "cell_type": "code", "execution_count": null, + "id": "16369cba", "metadata": {}, "outputs": [], "source": [ - "grad = jax.jit(jax.grad(loss))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%time grad(10e-6)" + "grad = jax.jit(\n", + " jax.grad(\n", + " loss,\n", + " argnums=0, # JAX gradient function for the first positional argument, jitted\n", + " )\n", + ")" ] }, { "cell_type": "markdown", + "id": "2cd2f2cf", "metadata": {}, "source": [ "Next, we need to define a JAX optimizer, which on its own is nothing more than three more functions: an initialization function with which to initialize the optimizer state, an update function which will update the optimizer state (and with it the model parameters). The third function that's being returned will give the model parameters given the optimizer state." 
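For reference, here is a minimal sketch of that three-function optimizer pattern, assuming the `jax.example_libraries.optimizers` module imported as `opt` above; the step size and the initial value are illustrative only, not values from the notebook:

```python
import jax
import jax.example_libraries.optimizers as opt

# adam returns the three functions described above: init, update, and get-params.
optim_init, optim_update, optim_params = opt.adam(step_size=0.1)

optim_state = optim_init(10.0)                     # wrap the initial parameters in an optimizer state
params = optim_params(optim_state)                 # recover the current parameters from that state
grads = jax.grad(lambda x: x ** 2)(params)         # any gradient with the same structure as the parameters
optim_state = optim_update(0, grads, optim_state)  # (step index, gradients, state) -> new state
```

The notebook cells that follow use exactly this triple (with the circuit loss and its gradient) to drive the training loop.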
@@ -557,16 +472,18 @@ { "cell_type": "code", "execution_count": null, + "id": "e553d69e", "metadata": {}, "outputs": [], "source": [ - "initial_delta_length = 10e-6\n", - "optim_init, optim_update, optim_params = opt.adam(step_size=1e-7)\n", + "initial_delta_length = 10.0\n", + "optim_init, optim_update, optim_params = opt.adam(step_size=0.1)\n", "optim_state = optim_init(initial_delta_length)" ] }, { "cell_type": "markdown", + "id": "e0a03fc2", "metadata": {}, "source": [ "Given all this, a single training step can be defined:" @@ -575,20 +492,21 @@ { "cell_type": "code", "execution_count": null, + "id": "eea4f7b8", "metadata": {}, "outputs": [], "source": [ - "@jax.jit\n", "def train_step(step, optim_state):\n", - " params = optim_params(optim_state)\n", - " lossvalue = loss(params)\n", - " gradvalue = grad(params)\n", + " settings = optim_params(optim_state)\n", + " lossvalue = loss(settings)\n", + " gradvalue = grad(settings)\n", " optim_state = optim_update(step, gradvalue, optim_state)\n", " return lossvalue, optim_state" ] }, { "cell_type": "markdown", + "id": "ba1612df", "metadata": {}, "source": [ "And we can use this step function to start the training of the MZI:" @@ -597,10 +515,11 @@ { "cell_type": "code", "execution_count": null, + "id": "61ef4479", "metadata": {}, "outputs": [], "source": [ - "range_ = tqdm.trange(1000)\n", + "range_ = tqdm.trange(300)\n", "for step in range_:\n", " lossvalue, optim_state = train_step(step, optim_state)\n", " range_.set_postfix(loss=f\"{lossvalue:.6f}\")" @@ -609,6 +528,7 @@ { "cell_type": "code", "execution_count": null, + "id": "eb9a1e20", "metadata": {}, "outputs": [], "source": [ @@ -618,6 +538,7 @@ }, { "cell_type": "markdown", + "id": "15d8754b", "metadata": {}, "source": [ "Let's see what we've got over a range of wavelengths:" @@ -626,23 +547,22 @@ { "cell_type": "code", "execution_count": null, + "id": "b36d10a3", "metadata": {}, "outputs": [], "source": [ - "params = sax.set_global_params(mzi[\"default_params\"], wl=1e-6*jnp.linspace(1.5, 1.6, 1000))\n", - "params[\"top\"][\"length\"] = 1.5e-5 + delta_length\n", - "params[\"btm\"][\"length\"] = 1.5e-5\n", - "detected = mzi[\"in1\", \"out1\"](params)\n", - "plt.plot(params[\"top\"][\"wl\"]*1e9, abs(detected)**2)\n", + "S = mzi(wl=wl, top={\"length\": 15.0 + delta_length}, btm={\"length\": 15.0})\n", + "plt.plot(wl * 1e3, abs(S[\"in1\", \"out1\"]) ** 2)\n", "plt.xlabel(\"λ [nm]\")\n", "plt.ylabel(\"T\")\n", "plt.ylim(-0.05, 1.05)\n", - "plt.plot([1550, 1550], [0,1])\n", + "plt.plot([1550, 1550], [0, 1])\n", "plt.show()" ] }, { "cell_type": "markdown", + "id": "a990da96", "metadata": {}, "source": [ "The minimum of the MZI is perfectly located at 1550nm." @@ -650,6 +570,7 @@ }, { "cell_type": "markdown", + "id": "7d93aa85", "metadata": {}, "source": [ "## MZI Chain" @@ -657,52 +578,53 @@ }, { "cell_type": "markdown", + "id": "9c312e13", "metadata": {}, "source": [ "Let's now create a chain of MZIs. 
For this, we first create a subcomponent: a directional coupler with arms:\n", "\n", "\n", "```\n", - " top\n", - " in ----- out -> out2\n", - " in2 <- p3 p2 \n", - " \\ dc / \n", - " ====== \n", - " / \\ \n", - " in1 <- p0 p1 btm \n", - " in ----- out -> out1\n", + " top\n", + " in0 ----- out0 -> out1\n", + " in1 <- in1 out1 \n", + " \\ dc / \n", + " ====== \n", + " / \\ \n", + " in0 <- in0 out0 btm \n", + " in0 ----- out0 -> out0\n", "```" ] }, { "cell_type": "code", "execution_count": null, + "id": "1ebe80c2", "metadata": {}, "outputs": [], "source": [ - "from sax.models.pic import directional_coupler, waveguide\n", - "\n", - "directional_coupler_with_arms = sax.circuit(\n", - " models = {\n", - " \"dc\": directional_coupler,\n", + "dc_with_arms = sax.circuit(\n", + " instances={\n", + " \"lft\": coupler,\n", " \"top\": waveguide,\n", " \"btm\": waveguide,\n", " },\n", " connections={\n", - " \"dc:p2\": \"top:in\",\n", - " \"dc:p1\": \"btm:in\",\n", + " \"lft,out0\": \"btm,in0\",\n", + " \"lft,out1\": \"top,in0\",\n", " },\n", " ports={\n", - " \"dc:p3\": \"in2\",\n", - " \"dc:p0\": \"in1\",\n", - " \"top:out\": \"out2\",\n", - " \"btm:out\": \"out1\",\n", + " \"in0\": \"lft,in0\",\n", + " \"in1\": \"lft,in1\",\n", + " \"out0\": \"btm,out0\",\n", + " \"out1\": \"top,out0\",\n", " },\n", ")" ] }, { "cell_type": "markdown", + "id": "1eda468f", "metadata": {}, "source": [ "An MZI chain can now be created by cascading these directional couplers with arms:\n", @@ -711,27 +633,30 @@ " _ _ _ _ _ _ \n", " \\/ \\/ \\/ \\/ ... \\/ \\/ \n", " /\\_ /\\_ /\\_ /\\_ /\\_ /\\_ \n", - "```" + "```\n", + "\n", + "Let's create a *model factory* (`ModelFactory`) for this. In SAX, a *model factory* is any keyword-only function that generates a `Model`:" ] }, { "cell_type": "code", "execution_count": null, + "id": "11c4ab07", "metadata": {}, "outputs": [], "source": [ - "def mzi_chain(num_mzis=1):\n", + "def mzi_chain(num_mzis=1) -> sax.Model:\n", " chain = sax.circuit(\n", - " models = {f\"dc{i}\": directional_coupler_with_arms for i in range(num_mzis+1)},\n", - " connections = {\n", - " **{f\"dc{i}:out1\":f\"dc{i+1}:in1\" for i in range(num_mzis)},\n", - " **{f\"dc{i}:out2\":f\"dc{i+1}:in2\" for i in range(num_mzis)},\n", + " instances={f\"dc{i}\": dc_with_arms for i in range(num_mzis + 1)},\n", + " connections={\n", + " **{f\"dc{i},out0\": f\"dc{i+1},in0\" for i in range(num_mzis)},\n", + " **{f\"dc{i},out1\": f\"dc{i+1},in1\" for i in range(num_mzis)},\n", " },\n", - " ports = {\n", - " \"dc0:in1\": \"in1\",\n", - " \"dc0:in2\": \"in2\",\n", - " f\"dc{num_mzis}:out1\": \"out1\",\n", - " f\"dc{num_mzis}:out2\": \"out2\",\n", + " ports={\n", + " \"in0\": f\"dc0,in0\",\n", + " \"in1\": f\"dc0,in1\",\n", + " \"out0\": f\"dc{num_mzis},out0\",\n", + " \"out1\": f\"dc{num_mzis},out1\",\n", " },\n", " )\n", " return chain" @@ -739,95 +664,66 @@ }, { "cell_type": "markdown", + "id": "03f225e9", "metadata": {}, "source": [ - "Let's for example create a chain with 15 MZIs:" + "Let's for example create a chain with 15 MZIs. 
We can also update the settings dictionary as follows:" ] }, { "cell_type": "code", "execution_count": null, + "id": "576d4de8", "metadata": {}, "outputs": [], "source": [ "chain = mzi_chain(num_mzis=15)\n", - "params = sax.copy_params(chain[\"default_params\"])\n", - "for dc in params:\n", - " params[dc][\"btm\"][\"length\"] = 1.5e-5\n", - "params = sax.set_global_params(params, wl=1e-6*jnp.linspace(1.5, 1.6, 1000))" + "settings = sax.get_settings(chain)\n", + "for dc in settings:\n", + " settings[dc][\"top\"][\"length\"] = 25.0\n", + " settings[dc][\"btm\"][\"length\"] = 15.0\n", + "settings = sax.update_settings(settings, wl=jnp.linspace(1.5, 1.6, 1000))" ] }, { "cell_type": "markdown", + "id": "11e3a10f", "metadata": {}, "source": [ - "We can simulate this again:" + "We can simulate this:" ] }, { "cell_type": "code", "execution_count": null, + "id": "b6e1155a", "metadata": {}, "outputs": [], "source": [ - "%time detected = chain[\"in1\", \"out1\"](params)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This takes a few seconds to simulate, so maybe it's worth jitting:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "chain_in1_out1 = jax.jit(chain[\"in1\", \"out1\"])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%time detected = chain_in1_out1(params)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Jit-compiling the function took even longer! However, after the jit-operation the simulation of the MZI chain becomes really fast:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%time detected = chain_in1_out1(params)" + "%time S = chain(**settings) # time to evaluate the MZI\n", + "func = jax.jit(chain)\n", + "%time S = func(**settings) # time to jit the MZI\n", + "%time S = func(**settings) # time to evaluate the MZI after jitting" ] }, { "cell_type": "markdown", + "id": "70db7179", "metadata": {}, "source": [ + "Where we see that the unjitted evaluation of the MZI chain takes about a second, while the jitting of the MZI chain takes about two minutes (on a CPU). 
However, after the MZI chain has been jitted, the evaluation is in the order of about a millisecond!\n", + "\n", "Anyway, let's see what this gives:" ] }, { "cell_type": "code", "execution_count": null, + "id": "1e4a2aab", "metadata": {}, "outputs": [], "source": [ - "plt.plot(1e9*params[\"dc0\"][\"top\"][\"wl\"], abs(detected)**2)\n", + "plt.plot(1e3 * settings[\"dc0\"][\"top\"][\"wl\"], jnp.abs(S[\"in0\", \"out0\"]) ** 2)\n", "plt.ylim(-0.05, 1.05)\n", "plt.xlabel(\"λ [nm]\")\n", "plt.ylabel(\"T\")\n", @@ -838,23 +734,11 @@ ], "metadata": { "kernelspec": { - "display_name": "", + "display_name": "sax", "language": "python", - "name": "" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.9" + "name": "sax" } }, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 5 } diff --git a/examples/02_all_pass_filter.ipynb b/examples/02_all_pass_filter.ipynb new file mode 100644 index 0000000..9ed42ea --- /dev/null +++ b/examples/02_all_pass_filter.ipynb @@ -0,0 +1,198 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a234aed9", + "metadata": {}, + "source": [ + "# Simulating an All-Pass Filter" + ] + }, + { + "cell_type": "markdown", + "id": "0b359fd3", + "metadata": {}, + "source": [ + "A simple comparison between an analytical evaluation of an all pass filter and using SAX." + ] + }, + { + "cell_type": "markdown", + "id": "4ca237fd", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f8a6a944", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.example_libraries.optimizers as opt\n", + "import jax.numpy as jnp\n", + "import matplotlib.pyplot as plt\n", + "import sax\n", + "import tqdm" + ] + }, + { + "cell_type": "markdown", + "id": "d9d87eef", + "metadata": {}, + "source": [ + "## Schematic\n", + "```\n", + "\n", + " in0---out0\n", + " in1 out1\n", + " \\ /\n", + " ========\n", + " / \\\n", + " in0 <- in0 out0 -> out0\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "3847a061", + "metadata": {}, + "source": [ + "## Simulation & Design Parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "03d546d8", + "metadata": {}, + "outputs": [], + "source": [ + "loss = 0.1 # [dB/μm] (alpha) waveguide loss\n", + "neff = 2.34 # Effective index of the waveguides\n", + "ng = 3.4 # Group index of the waveguides\n", + "wl0 = 1.55 # [μm] the wavelength at which neff and ng are defined\n", + "ring_length = 10.0 # [μm] Length of the ring\n", + "coupling = 0.5 # [] coupling of the coupler\n", + "wl = jnp.linspace(1.5, 1.6, 1000) # [μm] Wavelengths to sweep over" + ] + }, + { + "cell_type": "markdown", + "id": "8c463584", + "metadata": {}, + "source": [ + "## Frequency Domain Analytically" + ] + }, + { + "cell_type": "markdown", + "id": "9e171871", + "metadata": {}, + "source": [ + "As a comparison, we first calculate the frequency domain response for the all-pass filter analytically:\n", + "\\begin{align*}\n", + "o = \\frac{t-10^{-\\alpha L/20}\\exp(2\\pi j n_{\\rm eff}(\\lambda) L / \\lambda)}{1-t10^{-\\alpha L/20}\\exp(2\\pi j n_{\\rm eff}(\\lambda) L / \\lambda)}s\n", + "\\end{align*}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48031239", + "metadata": {}, + "outputs": [], + "source": [ + "def all_pass_analytical():\n", + " \"\"\" Analytic 
Frequency Domain Response of an all pass filter \"\"\"\n",
+    "    detected = jnp.zeros_like(wl)\n",
+    "    transmission = 1 - coupling\n",
+    "    neff_wl = neff + (wl0 - wl) * (ng - neff) / wl0  # we expect a linear behavior with respect to wavelength\n",
+    "    out = jnp.sqrt(transmission) - 10 ** (-loss * ring_length / 20.0) * jnp.exp(2j * jnp.pi * neff_wl * ring_length / wl)\n",
+    "    out /= 1 - jnp.sqrt(transmission) * 10 ** (-loss * ring_length / 20.0) * jnp.exp(2j * jnp.pi * neff_wl * ring_length / wl)\n",
+    "    detected = abs(out) ** 2\n",
+    "    return detected"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "cb39b703",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%time detected = all_pass_analytical() # non-jitted evaluation time\n",
+    "all_pass_analytical_jitted = jax.jit(all_pass_analytical)\n",
+    "%time detected = all_pass_analytical_jitted() # time to jit\n",
+    "%time detected = all_pass_analytical_jitted() # evaluation time after jitting\n",
+    "\n",
+    "plt.plot(wl * 1e3, detected)\n",
+    "plt.xlabel(\"λ [nm]\")\n",
+    "plt.ylabel(\"T\")\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "11219941",
+   "metadata": {},
+   "source": [
+    "## Scatter Dictionaries"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5957b3f3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "all_pass_sax = sax.circuit(\n",
+    "    instances={\n",
+    "        \"dc\": sax.partial(sax.models.coupler, coupling=coupling),\n",
+    "        \"top\": sax.partial(sax.models.straight, length=ring_length, loss=loss, neff=neff, ng=ng, wl0=wl0, wl=wl),\n",
+    "    },\n",
+    "    connections={\n",
+    "        \"dc,out1\": \"top,in0\",\n",
+    "        \"top,out0\": \"dc,in1\",\n",
+    "    },\n",
+    "    ports={\n",
+    "        \"in0\": \"dc,in0\",\n",
+    "        \"out0\": \"dc,out0\",\n",
+    "    },\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1702182a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%time detected_sax = all_pass_sax() # non-jitted evaluation time\n",
+    "all_pass_sax_jitted = jax.jit(all_pass_sax)\n",
+    "%time detected_sax = all_pass_sax_jitted() # time to jit\n",
+    "%time detected_sax = all_pass_sax_jitted() # time after jitting\n",
+    "\n",
+    "plt.plot(wl * 1e3, detected, label=\"analytical\")\n",
+    "plt.plot(wl * 1e3, detected_sax, label=\"sax\", ls=\"--\", lw=3)\n",
+    "plt.xlabel(\"λ [nm]\")\n",
+    "plt.ylabel(\"T\")\n",
+    "plt.show()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "sax",
+   "language": "python",
+   "name": "sax"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/02_thinfilm.ipynb b/examples/02_thinfilm.ipynb
deleted file mode 100644
index e33dc63..0000000
--- a/examples/02_thinfilm.ipynb
+++ /dev/null
@@ -1,942 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# Thin film optimization and wavelength-dependent parameters\n",
-    "\n",
-    "In this notebook, we apply SAX to thin-film optimization and show how it can be used for wavelength-dependent parameter optimization.\n",
-    "\n",
-    "The language of transfer/scatter matrices is commonly used to calculate optical properties of thin-films. Many specialized methods exist for their optimization. However, SAX can be useful to cut down on developer time by circumventing the need to manually take gradients of complicated or often-changed objective functions, and by generating efficient code from simple syntax. 
" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import tqdm\n", - "import matplotlib.pyplot as plt\n", - "\n", - "# GPU setup\n", - "#%env XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda\n", - "\n", - "import jax\n", - "import jax.numpy as jnp\n", - "import jax.experimental.optimizers as opt\n", - "\n", - "# sax circuit simulator\n", - "import sax" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Using GPU?\n", - "from jax.lib import xla_bridge \n", - "print(xla_bridge.get_backend().platform)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Dielectric mirror Fabry-Pérot\n", - "\n", - "Consider a stack composed of only two materials, $n_A$ and $n_B$. Two types of transfer matrices characterize wave propagation in the system : interfaces described by Fresnel's equations, and propagation." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Reflection at i-->j interface\n", - "def r_fresnel_ij(params):\n", - " \"\"\"\n", - " Normal incidence amplitude reflection from Fresnel's equations\n", - " ni : refractive index of the initial medium\n", - " nf : refractive index of the final\n", - " \"\"\"\n", - " return (params[\"ni\"] - params[\"nj\"]) / (params[\"ni\"] + params[\"nj\"])\n", - "\n", - "# Transmission at i-->j interface\n", - "def t_fresnel_ij(params):\n", - " \"\"\"\n", - " Normal incidence amplitude transmission from Fresnel's equations\n", - " \"\"\"\n", - " return 2 * params[\"ni\"] / (params[\"ni\"] + params[\"nj\"])\n", - "\n", - "# Propagation through medium A\n", - "def prop_i(params):\n", - " \"\"\"\n", - " Phase shift acquired as a wave propagates through medium A\n", - " wl : wavelength (arb. units)\n", - " ni : refractive index of medium (at wavelength wl)\n", - " di : thickness of layer (same arb. unit as wl)\n", - " \"\"\"\n", - " return jnp.exp(1j * 2*jnp.pi * params[\"ni\"] / params[\"wl\"] * params[\"di\"])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "For the two-material stack, this leads to 4 scatter matrices coefficients. 
Through reciprocity they can be constructed out of two independent ones :" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Arbitrary default parameters\n", - "fresnel_mirror_ij = {\n", - " (\"in\", \"in\"): r_fresnel_ij,\n", - " (\"in\", \"out\"): t_fresnel_ij,\n", - " (\"out\", \"in\"): lambda params: (1 - r_fresnel_ij(params)**2)/t_fresnel_ij(params), # t_ji,\n", - " (\"out\", \"out\"): lambda params: -1*r_fresnel_ij(params), # r_ji,\n", - " \"default_params\": {\n", - " \"ni\": 1.,\n", - " \"nj\": 1.,\n", - " \"wl\": 532.,\n", - " }\n", - "}\n", - "\n", - "propagation_i = {\n", - " (\"in\", \"out\"): prop_i,\n", - " (\"out\", \"in\"): prop_i,\n", - " \"default_params\": {\n", - " \"ni\": 1.,\n", - " \"di\": 500.,\n", - " \"wl\": 532.,\n", - " }\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "A resonant cavity can be formed when a high index region is surrounded by low-index region :" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dielectric_fabry_perot = sax.circuit(\n", - " models = {\n", - " \"air-B\": fresnel_mirror_ij,\n", - " \"B\": propagation_i,\n", - " \"B-air\": fresnel_mirror_ij,\n", - " },\n", - " connections={\n", - " \"air-B:out\": \"B:in\",\n", - " \"B:out\": \"B-air:in\",\n", - " },\n", - " ports={\n", - " \"air-B:in\": \"in\",\n", - " \"B-air:out\": \"out\",\n", - " },\n", - ")\n", - "\n", - "params = dielectric_fabry_perot[\"default_params\"]\n", - "params" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's choose $n_A = 1$, $n_B = 2$, $d_B = 1000$ nm, and compute over the visible spectrum :" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "params = sax.copy_params(params)\n", - "params[\"air-B\"][\"nj\"] = 2.\n", - "params[\"B\"][\"ni\"] = 2.\n", - "params[\"B-air\"][\"ni\"] = 2.\n", - "\n", - "wls = jnp.linspace(380, 750, 200)\n", - "params = sax.set_global_params(params, wl=wls)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Compute transmission and reflection, and compare to another package's results (https://github.com/sbyrnes321/tmm) :" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fabry_perot_transmitted = dielectric_fabry_perot[\"in\",\"out\"]\n", - "fabry_perot_transmitted = jax.jit(dielectric_fabry_perot[\"in\", \"out\"])\n", - "transmitted = fabry_perot_transmitted(params)\n", - "\n", - "fabry_perot_reflected = dielectric_fabry_perot[\"in\",\"in\"]\n", - "fabry_perot_reflected = jax.jit(dielectric_fabry_perot[\"in\", \"in\"])\n", - "reflected = fabry_perot_reflected(params)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# import sys\n", - "# !$sys.executable -m pip install tmm\n", - "from tmm import coh_tmm\n", - "\n", - "# tmm syntax (https://github.com/sbyrnes321/tmm)\n", - "d_list = [jnp.inf,500,jnp.inf]\n", - "n_list = [1,2,1]\n", - "# initialize lists of y-values to plot\n", - "rnorm=[]\n", - "tnorm = []\n", - "Tnorm = []\n", - "Rnorm = []\n", - "for l in wls:\n", - " rnorm.append(coh_tmm('s',n_list, d_list, 0, l)['r'])\n", - " tnorm.append(coh_tmm('s',n_list, d_list, 0, l)['t'])\n", - " Tnorm.append(coh_tmm('s',n_list, d_list, 0, l)['T'])\n", - " Rnorm.append(coh_tmm('s',n_list, d_list, 0, l)['R'])" - ] - }, - { - 
"cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt.scatter(wls, jnp.real(transmitted), label='t SAX')\n", - "plt.plot(wls, jnp.real(jnp.array(tnorm)), 'k', label='t tmm')\n", - "plt.scatter(wls, jnp.real(reflected), label='r SAX')\n", - "plt.plot(wls, jnp.real(jnp.array(rnorm)), 'k--', label='r tmm')\n", - "plt.xlabel(\"λ [nm]\")\n", - "plt.ylabel(\"Transmitted and reflected amplitude\")\n", - "plt.legend(loc=\"upper right\")\n", - "plt.title(\"Real part\")\n", - "plt.show()\n", - "\n", - "plt.scatter(wls, jnp.imag(transmitted), label='t SAX')\n", - "plt.plot(wls, jnp.imag(jnp.array(tnorm)), 'k', label='t tmm')\n", - "plt.scatter(wls, jnp.imag(reflected), label='r SAX')\n", - "plt.plot(wls, jnp.imag(jnp.array(rnorm)), 'k--', label='r tmm')\n", - "plt.xlabel(\"λ [nm]\")\n", - "plt.ylabel(\"Transmitted and reflected amplitude\")\n", - "plt.legend(loc=\"upper right\")\n", - "plt.title(\"Imaginary part\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In terms of powers, we get the following. Due to the reflections at the interfaces, resonant behaviour is observed, with evenly-spaced maxima/minima in wavevector space :" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt.scatter(2*jnp.pi/wls, jnp.abs(transmitted)**2, label='T SAX')\n", - "plt.plot(2*jnp.pi/wls, Tnorm, 'k', label='T tmm')\n", - "plt.scatter(2*jnp.pi/wls, jnp.abs(reflected)**2, label='R SAX')\n", - "plt.plot(2*jnp.pi/wls, Rnorm, 'k--', label='R tmm')\n", - "plt.vlines(jnp.arange(3,6)*jnp.pi/(2*500), ymin=0, ymax=1, color='k', linestyle='--', label='m$\\pi$/nd')\n", - "plt.xlabel(\"k = 2$\\pi$/λ [nm]\")\n", - "plt.ylabel(\"Transmitted and reflected intensities\")\n", - "plt.legend(loc=\"upper right\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Optimization test\n", - "\n", - "Let's attempt to minimize transmission at 500 nm by varying thickness." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "@jax.jit\n", - "def loss(thickness):\n", - " params = sax.set_global_params(dielectric_fabry_perot[\"default_params\"], wl=500.)\n", - " params[\"B\"][\"di\"] = thickness\n", - " params[\"air-B\"][\"nj\"] = 2.\n", - " params[\"B\"][\"ni\"] = 2.\n", - " params[\"B-air\"][\"ni\"] = 2.\n", - " detected = dielectric_fabry_perot[\"in\", \"out\"](params)\n", - " return abs(detected)**2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%time loss(500.)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "grad = jax.jit(jax.grad(loss))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%time grad(500.)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "initial_thickness = 500.\n", - "optim_init, optim_update, optim_params = opt.adam(step_size=10)\n", - "optim_state = optim_init(initial_thickness)\n", - "\n", - "@jax.jit\n", - "def train_step(step, optim_state):\n", - " params = optim_params(optim_state)\n", - " lossvalue = loss(params)\n", - " gradvalue = grad(params)\n", - " optim_state = optim_update(step, gradvalue, optim_state)\n", - " return lossvalue, optim_state" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "range_ = tqdm.trange(1000)\n", - "for step in range_:\n", - " lossvalue, optim_state = train_step(step, optim_state)\n", - " range_.set_postfix(loss=f\"{lossvalue:.6f}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "thickness = optim_params(optim_state)\n", - "thickness" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "params = sax.set_global_params(dielectric_fabry_perot[\"default_params\"], wl=wls)\n", - "params[\"B\"][\"di\"] = thickness\n", - "params[\"air-B\"][\"nj\"] = 2.\n", - "params[\"B\"][\"ni\"] = 2.\n", - "params[\"B-air\"][\"ni\"] = 2.\n", - "detected = dielectric_fabry_perot[\"in\", \"out\"](params)\n", - "\n", - "plt.plot(wls, jnp.abs(transmitted)**2, label='Before (500 nm)')\n", - "plt.plot(wls, jnp.abs(detected)**2, label=\"After ({} nm)\".format(thickness))\n", - "plt.vlines(500, 0.6, 1, 'k', linestyle='--')\n", - "plt.xlabel(\"λ [nm]\")\n", - "plt.ylabel(\"Transmitted intensity\")\n", - "plt.legend(loc=\"lower right\")\n", - "plt.title(\"Thickness optimization\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## General Fabry-Pérot étalon\n", - "\n", - "We reuse the propagation matrix above, and instead of simple interface matrices, model Fabry-Pérot mirrors as general lossless reciprocal scatter matrices :\n", - "\n", - "$$ \\left(\\begin{array}{c} \n", - "E_t \\\\\n", - "E_r\n", - "\\end{array}\\right) = E_{out} = SE_{in} = \\left(\\begin{array}{cc} \n", - "t & r \\\\\n", - "r & t\n", - "\\end{array}\\right) \\left(\\begin{array}{c} \n", - "E_0 \\\\\n", - "0\n", - "\\end{array}\\right) $$\n", - "\n", - "For lossless reciprocal systems, we further have the requirements\n", - "\n", - "$$ |t|^2 + |r|^2 = 1 $$\n", - "\n", - "and\n", - "\n", - "$$ \\angle t - \\angle r = \\pm \\pi/2 $$\n", - "\n", - "The general Fabry-Pérot cavity is analytically described 
by :" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# General Fabry-Pérot transfer function (Airy formulas)\n", - "def airy_t13(t12, t23, r21, r23, wl, d=1000., n=1.):\n", - " '''\n", - " Assumptions \n", - " Each mirror lossless, reciprocal : tij = tji, rij = rji\n", - " \n", - " Inputs\n", - " t12 and r12 : S-parameters of the first mirror\n", - " t23 and r23 : S-parameters of the second mirror\n", - " wl : wavelength\n", - " d : gap between the two mirrors (in units of wavelength)\n", - " n : index of the gap between the two mirrors\n", - " \n", - " Returns\n", - " t13 : complex transmission amplitude of the mirror-gap-mirror system\n", - " '''\n", - " # Assume each mirror lossless, reciprocal : tij = tji, rij = rji\n", - " phi = n*2*jnp.pi/wl*d\n", - " return t12*t23*jnp.exp(-1j*phi)/( 1 - r21*r23*jnp.exp(-2j*phi) )\n", - "\n", - "def airy_r13(t12, t23, r21, r23, wl, d=1000., n=1.):\n", - " '''\n", - " Assumptions, inputs : see airy_t13\n", - " \n", - " Returns\n", - " r13 : complex reflection amplitude of the mirror-gap-mirror system\n", - " '''\n", - " phi = n*2*jnp.pi/wl*d\n", - " return r21 + t12*t12*r23*jnp.exp(-2j*phi)/( 1 - r21*r23*jnp.exp(-2j*phi) )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We need to implement the relationship between $t$ and $r$ for lossless reciprocal mirrors. The design parameter will be the amplitude and phase of the tranmission coefficient. The reflection coefficient is then fully determined :" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def t_complex(params):\n", - " # Transmission coefficient (design parameter)\n", - " return params['t_amp']*jnp.exp(-1j*params['t_ang'])\n", - "\n", - "def r_complex(params):\n", - " # Reflection coefficient, derived from transmission coefficient\n", - " # Magnitude from |t|^2 + |r|^2 = 1\n", - " # Phase from phase(t) - phase(r) = pi/2\n", - " r_amp = jnp.sqrt( ( 1. 
- params['t_amp']**2 ) )\n", - " r_ang = params['t_ang'] - jnp.pi/2\n", - " return r_amp*jnp.exp(-1j*r_ang)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's see the expected result for half-mirrors :" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t_initial = jnp.sqrt(0.5)\n", - "d_gap = 2000.\n", - "n_gap = 1.\n", - "params_analytical_test = {\"t_amp\": t_initial, \"t_ang\": 0.0}\n", - "r_initial = r_complex(params_analytical_test)\n", - "\n", - "wls = jnp.linspace(380, 780, 500)\n", - "\n", - "T_analytical_initial = jnp.abs(airy_t13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap))**2\n", - "R_analytical_initial = jnp.abs(airy_r13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap))**2\n", - "\n", - "plt.title(f't={t_initial:1.3f}, d={d_gap} nm, n={n_gap}')\n", - "plt.plot(2*jnp.pi/wls, T_analytical_initial, label='T')\n", - "plt.plot(2*jnp.pi/wls, R_analytical_initial, label='R')\n", - "plt.vlines(jnp.arange(6,11)*jnp.pi/2000, ymin=0, ymax=1, color='k', linestyle='--', label='m$\\pi$/nd')\n", - "plt.xlabel('k = 2$\\pi$/$\\lambda$ (/nm)')\n", - "plt.ylabel('Power (units of input)')\n", - "plt.legend()\n", - "plt.show()\n", - "\n", - "plt.title(f't={t_initial:1.3f}, d={d_gap} nm, n={n_gap}')\n", - "plt.plot(wls, T_analytical_initial, label='T')\n", - "plt.plot(wls, R_analytical_initial, label='R')\n", - "plt.xlabel('$\\lambda$ (nm)')\n", - "plt.ylabel('Power (units of input)')\n", - "plt.legend()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Is power conserved? (to within 0.1%)\n", - "assert jnp.isclose(R_analytical_initial + T_analytical_initial, 1, 0.001).all()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now let's do the same with SAX by defining new elements :" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mirror = {\n", - " (\"in\", \"in\"): r_complex,\n", - " (\"in\", \"out\"): t_complex,\n", - " (\"out\", \"in\"): t_complex, # lambda params: (1 - r_complex(params)**2)/t_complex(params), # t_ji,\n", - " (\"out\", \"out\"): r_complex, # lambda params: -1*r_complex(params), # r_ji,\n", - " \"default_params\": {\n", - " \"t_amp\": jnp.sqrt(0.5),\n", - " \"t_ang\": 0.0,\n", - " }\n", - "}\n", - "\n", - "fabry_perot_tunable = sax.circuit(\n", - " models = {\n", - " \"mirror1\": mirror,\n", - " \"gap\": propagation_i,\n", - " \"mirror2\": mirror,\n", - " },\n", - " connections={\n", - " \"mirror1:out\": \"gap:in\",\n", - " \"gap:out\": \"mirror2:in\",\n", - " },\n", - " ports={\n", - " \"mirror1:in\": \"in\",\n", - " \"mirror2:out\": \"out\",\n", - " },\n", - ")\n", - "\n", - "params = fabry_perot_tunable[\"default_params\"]\n", - "params = sax.copy_params(params)\n", - "params" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fabry_perot_tunable = sax.circuit(\n", - " models = {\n", - " \"mirror1\": mirror,\n", - " \"gap\": propagation_i,\n", - " \"mirror2\": mirror,\n", - " },\n", - " connections={\n", - " \"mirror1:out\": \"gap:in\",\n", - " \"gap:out\": \"mirror2:in\",\n", - " },\n", - " ports={\n", - " \"mirror1:in\": \"in\",\n", - " \"mirror2:out\": \"out\",\n", - " },\n", - ")\n", - "\n", - "params = fabry_perot_tunable[\"default_params\"]\n", - "params = sax.copy_params(params)\n", - "params" - ] - }, - { - 
"cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "N = 100\n", - "wls = jnp.linspace(380, 780, N)\n", - "params = sax.copy_params(fabry_perot_tunable[\"default_params\"])\n", - "params = sax.set_global_params(params, wl=wls)\n", - "params = sax.set_global_params(params, t_amp=jnp.sqrt(0.5))\n", - "params = sax.set_global_params(params, t_ang=0.0)\n", - "params[\"gap\"][\"ni\"] = 1.\n", - "params[\"gap\"][\"di\"] = 2000.\n", - "transmitted_initial = fabry_perot_tunable[\"in\",\"out\"](params)\n", - "reflected_initial = fabry_perot_tunable[\"out\",\"out\"](params)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "T_analytical_initial = jnp.abs(airy_t13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap))**2\n", - "R_analytical_initial = jnp.abs(airy_r13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap))**2\n", - "\n", - "plt.title(f't={t_initial:1.3f}, d={d_gap} nm, n={n_gap}')\n", - "plt.plot(wls, T_analytical_initial, label='T theory')\n", - "plt.scatter(wls, jnp.abs(transmitted_initial)**2, label='T SAX')\n", - "plt.plot(wls, R_analytical_initial, label='R theory')\n", - "plt.scatter(wls, jnp.abs(reflected_initial)**2, label='R SAX')\n", - "#plt.vlines(jnp.arange(6,11)*jnp.pi/2000, ymin=0, ymax=1, color='k', linestyle='--', label='m$\\pi$/nd')\n", - "plt.xlabel('k = 2$\\pi$/$\\lambda$ (/nm)')\n", - "plt.ylabel('Power (units of input)')\n", - "plt.legend()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Wavelength-dependent Fabry-Pérot étalon\n", - "\n", - "Let's repeat with a model where parameters can be wavelength-dependent. To comply with the optimizer object, we will stack all design parameters in a single array :" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ts_initial = jnp.zeros(2*N)\n", - "ts_initial = jax.ops.index_update(ts_initial, jax.ops.index[0:N], jnp.sqrt(0.5))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We will simply loop over all wavelengths, and use different $t$ parameters at each wavelength." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wls = jnp.linspace(380, 780, N)\n", - "transmitted = jnp.zeros_like(wls)\n", - "reflected = jnp.zeros_like(wls)\n", - "\n", - "for i in range(N):\n", - " # Update parameters\n", - " params = sax.copy_params(fabry_perot_tunable[\"default_params\"])\n", - " params = sax.set_global_params(params, wl=wls[i])\n", - " params = sax.set_global_params(params, t_amp=ts_initial[i])\n", - " params = sax.set_global_params(params, t_ang=ts_initial[N+i])\n", - " params[\"gap\"][\"ni\"] = 1.\n", - " params[\"gap\"][\"di\"] = 2000.\n", - " # Perform computation\n", - " transmission_i = fabry_perot_tunable[\"in\",\"out\"](params)\n", - " transmitted = jax.ops.index_update(transmitted, jax.ops.index[i], jnp.abs(transmission_i)**2)\n", - " reflected_i = fabry_perot_tunable[\"in\",\"in\"](params)\n", - " reflected = jax.ops.index_update(reflected, jax.ops.index[i], jnp.abs(reflected_i)**2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt.plot(wls, T_analytical_initial, label='T theory')\n", - "plt.scatter(wls, transmitted, label='T SAX')\n", - "plt.plot(wls, R_analytical_initial, label='R theory')\n", - "plt.scatter(wls, reflected, label='R SAX')\n", - "plt.xlabel(\"λ [nm]\")\n", - "plt.ylabel(\"Transmitted and reflected intensities\")\n", - "plt.legend(loc=\"upper right\")\n", - "plt.title(f't={t_initial:1.3f}, d={d_gap} nm, n={n_gap}')\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Since it seems to work, let's add a target and optimize some harmonics away :" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def lorentzian(l0, dl, wl, A):\n", - " return A/((wl - l0)**2 + (0.5*dl)**2)\n", - "\n", - "target = lorentzian(533, 20, wls, 100)\n", - "\n", - "plt.scatter(wls, transmitted, label='T SAX')\n", - "plt.scatter(wls, reflected, label='R SAX')\n", - "plt.plot(wls, target, 'r', linewidth=2, label='target')\n", - "plt.xlabel(\"λ [nm]\")\n", - "plt.ylabel(\"Transmitted and reflected intensities\")\n", - "plt.legend(loc=\"upper right\")\n", - "plt.title(f't={t_initial:1.3f}, d={d_gap} nm, n={n_gap}')\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Defining the loss as the mean squared error between transmission and target transmission. 
Note that we can use JAX's looping functions to great speedups after compilation here :" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "@jax.jit\n", - "def loss(ts):\n", - "\n", - " N = len(ts[::2])\n", - " wls = jnp.linspace(380, 780, N)\n", - " transmitted = jnp.zeros_like(wls)\n", - " target = lorentzian(533, 20, wls, 100)\n", - " \n", - " def inner_loop(transmitted, i):\n", - " # Update parameters\n", - " params = sax.copy_params(fabry_perot_tunable[\"default_params\"])\n", - " params = sax.set_global_params(params, wl=wls[i])\n", - " params = sax.set_global_params(params, t_amp=ts[i])\n", - " params = sax.set_global_params(params, t_ang=ts[N+i])\n", - " params[\"gap\"][\"ni\"] = 1.\n", - " params[\"gap\"][\"di\"] = 2000.\n", - " # Perform computation\n", - " transmission_i = fabry_perot_tunable[\"in\",\"out\"](params)\n", - " transmitted = jax.ops.index_update(transmitted, jax.ops.index[i], jnp.abs(transmission_i)**2)\n", - " return transmitted, i\n", - "\n", - " transmitted, _ = jax.lax.scan(inner_loop, transmitted, jnp.arange(N, dtype=jnp.int32))\n", - " \n", - " return (jnp.abs(transmitted - target)**2).mean()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "grad = jax.jit(jax.grad(loss))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "optim_init, optim_update, optim_params = opt.adam(step_size=0.001)\n", - "\n", - "@jax.jit\n", - "def train_step(step, optim_state):\n", - " params = optim_params(optim_state)\n", - " lossvalue = loss(params)\n", - " gradvalue = grad(params)\n", - " optim_state = optim_update(step, gradvalue, optim_state)\n", - " return lossvalue, optim_state" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "range_ = tqdm.trange(2000)\n", - "\n", - "optim_state = optim_init(ts_initial)\n", - "for step in range_:\n", - " lossvalue, optim_state = train_step(step, optim_state)\n", - " range_.set_postfix(loss=f\"{lossvalue:.6f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The optimized parameters are now wavelength-dependent :" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ts_optimal = optim_params(optim_state)\n", - "\n", - "plt.scatter(wls, ts_initial[:N], label='t initial')\n", - "plt.scatter(wls, ts_optimal[:N], label='t optimal')\n", - "plt.xlabel(\"λ [nm]\")\n", - "plt.ylabel(\"|t| $(\\lambda)$\")\n", - "plt.legend(loc=\"best\")\n", - "plt.title(f'd={d_gap} nm, n={n_gap}')\n", - "plt.show()\n", - "\n", - "plt.scatter(wls, ts_initial[N:], label='t initial')\n", - "plt.scatter(wls, ts_optimal[N:], label='t optimal')\n", - "plt.xlabel(\"λ [nm]\")\n", - "plt.ylabel(\"angle $t (\\lambda)$ (rad)\")\n", - "plt.legend(loc=\"best\")\n", - "plt.title(f'd={d_gap} nm, n={n_gap}')\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Visualizing the result :" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wls = jnp.linspace(380, 780, N)\n", - "transmitted_optimal = jnp.zeros_like(wls)\n", - "reflected_optimal = jnp.zeros_like(wls)\n", - "\n", - "for i in range(N):\n", - " # Update parameters\n", - " params = sax.copy_params(fabry_perot_tunable[\"default_params\"])\n", - " params = 
sax.set_global_params(params, wl=wls[i])\n", - " params = sax.set_global_params(params, t_amp=ts_optimal[i])\n", - " params = sax.set_global_params(params, t_ang=ts_optimal[N+i])\n", - " params[\"gap\"][\"ni\"] = 1.\n", - " params[\"gap\"][\"di\"] = 2000.\n", - " # Perform computation\n", - " transmission_i = fabry_perot_tunable[\"in\",\"out\"](params)\n", - " transmitted_optimal = jax.ops.index_update(transmitted_optimal, jax.ops.index[i], jnp.abs(transmission_i)**2)\n", - " reflected_i = fabry_perot_tunable[\"in\",\"in\"](params)\n", - " reflected_optimal = jax.ops.index_update(reflected_optimal, jax.ops.index[i], jnp.abs(reflected_i)**2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt.scatter(wls, transmitted_optimal, label='T')\n", - "plt.scatter(wls, reflected_optimal, label='R')\n", - "plt.plot(wls, lorentzian(533, 20, wls, 100), 'r', label='target')\n", - "plt.xlabel(\"λ [nm]\")\n", - "plt.ylabel(\"Transmitted and reflected intensities\")\n", - "plt.legend(loc=\"upper right\")\n", - "plt.title(f'Optimized t($\\lambda$), d={d_gap} nm, n={n_gap}')\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The hard part is now to find physical stacks that physically implement $t(\\lambda)$. However, the ease with which we can modify and complexify the loss function opens opportunities for regularization and more complicated objective functions.\n", - "\n", - "The models above are available in models.thin_film, and can straightforwardly be extended to propagation at an angle, s and p polarizations, nonreciprocal systems, and systems with losses." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "", - "language": "python", - "name": "" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.9" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/examples/03_circuit_from_yaml.ipynb b/examples/03_circuit_from_yaml.ipynb new file mode 100644 index 0000000..9093be3 --- /dev/null +++ b/examples/03_circuit_from_yaml.ipynb @@ -0,0 +1,188 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "competitive-seating", + "metadata": {}, + "source": [ + "# Circuit from YAML\n", + "Sometimes it's useful to be able to define circuits from YAML definitions. To not re-invent the wheel, SAX uses [GDSFactory](https://gdsfactory.readthedocs.io/en/latest/yaml.html)'s YAML netlist spec to define its circuits. This makes it very easy to convert a GDSFactory layout to a SAX circuit model!" 
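+    "\n",
+    "\n",
+    "If the netlist lives in a YAML file instead (for example one exported from a GDSFactory layout), it can be read in as a string first; a minimal sketch, with a hypothetical file name:\n",
+    "\n",
+    "```python\n",
+    "from pathlib import Path\n",
+    "\n",
+    "netlist = Path(\"my_chip.yml\").read_text()  # hypothetical file name\n",
+    "mzi = sax.circuit_from_yaml(netlist)  # same call as with the inline netlist below\n",
+    "```"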
+ ] + }, + { + "cell_type": "markdown", + "id": "narrative-poverty", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "personalized-recipe", + "metadata": {}, + "outputs": [], + "source": [ + "import jax.numpy as jnp\n", + "import matplotlib.pyplot as plt\n", + "import sax" + ] + }, + { + "cell_type": "markdown", + "id": "featured-liberal", + "metadata": {}, + "source": [ + "## MZI" + ] + }, + { + "cell_type": "markdown", + "id": "collectible-feedback", + "metadata": {}, + "source": [ + "Let's first see how we can define a SAX circuit from YAML:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fixed-hurricane", + "metadata": {}, + "outputs": [], + "source": [ + "netlist = \"\"\"\n", + "instances:\n", + " lft:\n", + " component: coupler\n", + " settings:\n", + " coupling: 0.5\n", + " rgt:\n", + " component: coupler\n", + " settings:\n", + " coupling: 0.5\n", + " top:\n", + " component: straight\n", + " settings:\n", + " length: 25.0\n", + " btm:\n", + " component: straight\n", + " settings:\n", + " length: 15.0\n", + "\n", + "connections:\n", + " lft,out0: btm,in0\n", + " btm,out0: rgt,in0\n", + " lft,out1: top,in0\n", + " top,out0: rgt,in1\n", + " \n", + "ports:\n", + " in0: lft,in0\n", + " in1: lft,in1\n", + " out0: rgt,out0\n", + " out1: rgt,out1\n", + " \n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "imposed-treaty", + "metadata": {}, + "outputs": [], + "source": [ + "mzi = sax.circuit_from_yaml(netlist)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dynamic-commonwealth", + "metadata": {}, + "outputs": [], + "source": [ + "wl = jnp.linspace(1.5, 1.6, 1000)\n", + "transmission = jnp.abs(mzi(wl=wl)[\"in0\", \"out0\"]) ** 2\n", + "\n", + "plt.plot(wl * 1e3, transmission)\n", + "plt.xlabel(\"λ [nm]\")\n", + "plt.ylabel(\"T\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "alpha-married", + "metadata": {}, + "source": [ + "That was easy! However, during the above YAML conversion, only models available in `sax.models` were used. What if we want to map the YAML component names to custom models? Let's say we want to use a dispersionless waveguide for the above model for example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "important-association", + "metadata": {}, + "outputs": [], + "source": [ + "def waveguide_without_dispersion(wl=1.55, length=25.0, neff=2.34):\n", + " phase = 2 * jnp.pi * neff * length / wl\n", + " sdict = sax.reciprocal({(\"in0\", \"out0\"): jnp.exp(1j * phase)})\n", + " return sdict" + ] + }, + { + "cell_type": "markdown", + "id": "valuable-candidate", + "metadata": {}, + "source": [ + "We can regenerate the above circuit again, but this time we specify a models mapping:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "former-retro", + "metadata": {}, + "outputs": [], + "source": [ + "mzi = sax.circuit_from_yaml(netlist, models={\"straight\": waveguide_without_dispersion})" + ] + }, + { + "cell_type": "markdown", + "id": "focal-question", + "metadata": {}, + "source": [ + "> The `models=` keyword in `circuit_from_yaml` can be a dictionary **or** an imported python module (like for example `sax.models`). Or a list containing multiple of such dictionary mappings and imported modules." 
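+    "\n",
+    "\n",
+    "A minimal sketch of that last option (assuming the `waveguide_without_dispersion` model defined above), passing a custom mapping together with the built-in `sax.models` module in a list:\n",
+    "\n",
+    "```python\n",
+    "# sketch: custom model mapping plus the built-in sax.models module\n",
+    "mzi = sax.circuit_from_yaml(\n",
+    "    netlist,\n",
+    "    models=[{\"straight\": waveguide_without_dispersion}, sax.models],\n",
+    ")\n",
+    "```"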
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "expired-panel", + "metadata": {}, + "outputs": [], + "source": [ + "wl = jnp.linspace(1.5, 1.6, 1000)\n", + "transmission = jnp.abs(mzi(wl=wl)[\"in0\", \"out0\"]) ** 2\n", + "\n", + "plt.plot(wl, transmission)\n", + "plt.xlabel(\"Wavelength [nm]\")\n", + "plt.ylabel(\"T\")\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/04_multimode_simulations.ipynb b/examples/04_multimode_simulations.ipynb new file mode 100644 index 0000000..dc06423 --- /dev/null +++ b/examples/04_multimode_simulations.ipynb @@ -0,0 +1,246 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "textile-panel", + "metadata": {}, + "source": [ + "# Multimode simulations" + ] + }, + { + "cell_type": "markdown", + "id": "latin-strike", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "prescribed-plant", + "metadata": {}, + "outputs": [], + "source": [ + "from itertools import combinations, combinations_with_replacement, product\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import sax" + ] + }, + { + "cell_type": "markdown", + "id": "requested-calvin", + "metadata": {}, + "source": [ + "## Ports and modes per port" + ] + }, + { + "cell_type": "markdown", + "id": "grateful-programming", + "metadata": {}, + "source": [ + "Let's denote a combination of a port and a mode by a string of the following format: `\"{port}@{mode}\"`. We can obtain all possible port-mode combinations with some magic itertools functions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "going-entity", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "ports = [\"in0\", \"out0\"]\n", + "modes = [\"te\", \"tm\"]\n", + "portmodes = [\n", + " (f\"{p1}@{m1}\", f\"{p2}@{m2}\")\n", + " for (p1, m1), (p2, m2) in combinations_with_replacement(product(ports, modes), 2)\n", + "]\n", + "portmodes" + ] + }, + { + "cell_type": "markdown", + "id": "passive-selling", + "metadata": {}, + "source": [ + "If we would disregard any backreflection, this can be further simplified:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "numerical-novel", + "metadata": {}, + "outputs": [], + "source": [ + "portmodes_without_backreflection = [\n", + " (p1, p2) for p1, p2 in portmodes if p1.split(\"@\")[0] != p2.split(\"@\")[0]\n", + "]\n", + "portmodes_without_backreflection" + ] + }, + { + "cell_type": "markdown", + "id": "hindu-application", + "metadata": {}, + "source": [ + "Sometimes cross-polarization terms can also be ignored:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "encouraging-territory", + "metadata": {}, + "outputs": [], + "source": [ + "portmodes_without_crosspolarization = [\n", + " (p1, p2) for p1, p2 in portmodes if p1.split(\"@\")[1] == p2.split(\"@\")[1]\n", + "]\n", + "portmodes_without_crosspolarization" + ] + }, + { + "cell_type": "markdown", + "id": "mathematical-default", + "metadata": {}, + "source": [ + "## Multimode waveguide" + ] + }, + { + "cell_type": "markdown", + "id": "asian-noise", + "metadata": {}, + "source": [ + "Let's create a waveguide with two ports (`\"in\"`, `\"out\"`) and two modes (`\"te\"`, `\"tm\"`) without backreflection. 
Let's assume there is 5% cross-polarization and that the `\"tm\"`->`\"tm\"` transmission is 10% worse than the `\"te\"`->`\"te\"` transmission. Naturally in more realisic waveguide models these percentages will be length-dependent, but this is just a dummy model serving as an example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "patient-wealth", + "metadata": {}, + "outputs": [], + "source": [ + "def waveguide(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0):\n", + " \"\"\"a simple straight waveguide model\n", + "\n", + " Args:\n", + " wl: wavelength\n", + " neff: waveguide effective index\n", + " ng: waveguide group index (used for linear neff dispersion)\n", + " wl0: center wavelength at which neff is defined\n", + " length: [m] wavelength length\n", + " loss: [dB/m] waveguide loss\n", + " \"\"\"\n", + " dwl = wl - wl0\n", + " dneff_dwl = (ng - neff) / wl0\n", + " neff = neff - dwl * dneff_dwl\n", + " phase = 2 * jnp.pi * neff * length / wl\n", + " transmission = 10 ** (-loss * length / 20) * jnp.exp(1j * phase)\n", + " sdict = sax.reciprocal(\n", + " {\n", + " (\"in0@te\", \"out0@te\"): 0.95 * transmission, # 5% lost to cross-polarization\n", + " (\"in0@te\", \"out0@tm\"): 0.05 * transmission, # 5% cross-polarization\n", + " (\"in0@tm\", \"out0@tm\"): 0.85 * transmission, # 10% worse tm->tm than te->te\n", + " (\"in0@tm\", \"out0@te\"): 0.05 * transmission, # 5% cross-polarization\n", + " }\n", + " )\n", + " return sdict\n", + "\n", + "\n", + "waveguide()" + ] + }, + { + "cell_type": "markdown", + "id": "accessory-sheep", + "metadata": {}, + "source": [ + "## Multimode MZI" + ] + }, + { + "cell_type": "markdown", + "id": "joint-semiconductor", + "metadata": {}, + "source": [ + "We can now combine these models into a circuit in much the same way as before. 
We just need to add the `modes=` keyword:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "suffering-judges", + "metadata": {}, + "outputs": [], + "source": [ + "mzi = sax.circuit(\n", + " instances={\n", + " \"lft\": sax.models.coupler, # single mode models will be automatically converted to multimode models without cross polarization.\n", + " \"top\": sax.partial(waveguide, length=25.0),\n", + " \"btm\": sax.partial(waveguide, length=15.0),\n", + " \"rgt\": sax.models.coupler, # single mode models will be automatically converted to multimode models without cross polarization.\n", + " },\n", + " connections={\n", + " \"lft,out0\": \"btm,in0\",\n", + " \"btm,out0\": \"rgt,in0\",\n", + " \"lft,out1\": \"top,in0\",\n", + " \"top,out0\": \"rgt,in1\",\n", + " },\n", + " ports={\n", + " \"in0\": \"lft,in0\",\n", + " \"in1\": \"lft,in1\",\n", + " \"out0\": \"rgt,out0\",\n", + " \"out1\": \"rgt,out1\",\n", + " },\n", + " modes=(\"te\", \"tm\"),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bridal-insider", + "metadata": {}, + "outputs": [], + "source": [ + "mzi()" + ] + }, + { + "cell_type": "markdown", + "id": "cellular-contemporary", + "metadata": {}, + "source": [ + "we can convert this model back to a singlemode `SDict` as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aggregate-lemon", + "metadata": {}, + "outputs": [], + "source": [ + "mzi_te = sax.singlemode(mzi, mode=\"te\")\n", + "mzi_te()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/05_thinfilm.ipynb b/examples/05_thinfilm.ipynb new file mode 100644 index 0000000..7761a7e --- /dev/null +++ b/examples/05_thinfilm.ipynb @@ -0,0 +1,897 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "96f9e68f", + "metadata": {}, + "source": [ + "# Thin film optimization\n", + "\n", + "contributed by [simbilod](https://github.com/simbilod), adapted by [flaport](https://github.com/flaport)\n", + "\n", + "In this notebook, we apply SAX to thin-film optimization and show how it can be used for wavelength-dependent parameter optimization.\n", + "\n", + "The language of transfer/scatter matrices is commonly used to calculate optical properties of thin-films. Many specialized methods exist for their optimization. However, SAX can be useful to cut down on developer time by circumventing the need to manually take gradients of complicated or often-changed objective functions, and by generating efficient code from simple syntax. " + ] + }, + { + "cell_type": "markdown", + "id": "085ae33c", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a1f8afd7", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.example_libraries.optimizers as opt\n", + "import jax.numpy as jnp\n", + "import matplotlib.pyplot as plt\n", + "import sax # sax circuit simulator\n", + "import tqdm.notebook as tqdm\n", + "from tmm import coh_tmm" + ] + }, + { + "cell_type": "markdown", + "id": "77e24e88", + "metadata": {}, + "source": [ + "## Dielectric mirror Fabry-Pérot\n", + "\n", + "Consider a stack composed of only two materials, $n_A$ and $n_B$. Two types of transfer matrices characterize wave propagation in the system : interfaces described by Fresnel's equations, and propagation." 
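+    "\n",
+    "\n",
+    "Concretely, at normal incidence the two building blocks implemented in the next cell reduce to the Fresnel amplitude coefficients and a single propagation phase (restated here for reference):\n",
+    "\n",
+    "$$ r_{ij} = \\frac{n_i - n_j}{n_i + n_j}, \\qquad t_{ij} = \\frac{2 n_i}{n_i + n_j}, \\qquad p_i = \\exp\\left(2 \\pi j \\, n_i d_i / \\lambda\\right) $$"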
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "fa0f57dc",
+   "metadata": {},
+   "source": [
+    "For the two-material stack, this leads to 4 scatter matrix coefficients. Through reciprocity they can be constructed out of two independent ones :"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b0d3a8eb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def fresnel_mirror_ij(ni=1.0, nj=1.0):\n",
+    "    \"\"\"Model a (Fresnel) interface between two refractive indices\n",
+    "\n",
+    "    Args:\n",
+    "        ni: refractive index of the initial medium\n",
+    "        nj: refractive index of the final medium\n",
+    "    \"\"\"\n",
+    "    r_fresnel_ij = (ni - nj) / (ni + nj)  # i->j reflection\n",
+    "    t_fresnel_ij = 2 * ni / (ni + nj)  # i->j transmission\n",
+    "    r_fresnel_ji = -r_fresnel_ij  # j -> i reflection\n",
+    "    t_fresnel_ji = (1 - r_fresnel_ij ** 2) / t_fresnel_ij  # j -> i transmission\n",
+    "    sdict = {\n",
+    "        (\"in\", \"in\"): r_fresnel_ij,\n",
+    "        (\"in\", \"out\"): t_fresnel_ij,\n",
+    "        (\"out\", \"in\"): t_fresnel_ji,\n",
+    "        (\"out\", \"out\"): r_fresnel_ji,\n",
+    "    }\n",
+    "    return sdict\n",
+    "\n",
+    "\n",
+    "def propagation_i(ni=1.0, di=0.5, wl=0.532):\n",
+    "    \"\"\"Model the phase shift acquired as a wave propagates through medium A\n",
+    "\n",
+    "    Args:\n",
+    "        ni: refractive index of medium (at wavelength wl)\n",
+    "        di: [μm] thickness of layer\n",
+    "        wl: [μm] wavelength\n",
+    "    \"\"\"\n",
+    "    prop_i = jnp.exp(1j * 2 * jnp.pi * ni * di / wl)\n",
+    "    sdict = {\n",
+    "        (\"in\", \"out\"): prop_i,\n",
+    "        (\"out\", \"in\"): prop_i,\n",
+    "    }\n",
+    "    return sdict"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "007b1181",
+   "metadata": {},
+   "source": [
+    "A resonant cavity can be formed when a high index region is surrounded by low-index region :"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "14f86c1e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dielectric_fabry_perot = sax.circuit(\n",
+    "    instances={\n",
+    "        \"air_B\": fresnel_mirror_ij,\n",
+    "        \"B\": propagation_i,\n",
+    "        \"B_air\": fresnel_mirror_ij,\n",
+    "    },\n",
+    "    connections={\n",
+    "        \"air_B,out\": \"B,in\",\n",
+    "        \"B,out\": \"B_air,in\",\n",
+    "    },\n",
+    "    ports={\n",
+    "        \"in\": \"air_B,in\",\n",
+    "        \"out\": \"B_air,out\",\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "settings = sax.get_settings(dielectric_fabry_perot)\n",
+    "settings"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c9285d92",
+   "metadata": {},
+   "source": [
+    "Let's choose $n_A = 1$, $n_B = 2$, $d_B = 500$ nm, and compute over the visible spectrum :"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8ee9d40d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "settings = sax.copy_settings(settings)\n",
+    "settings[\"air_B\"][\"nj\"] = 2.0\n",
+    "settings[\"B\"][\"ni\"] = 2.0\n",
+    "settings[\"B_air\"][\"ni\"] = 2.0\n",
+    "\n",
+    "wls = jnp.linspace(0.380, 0.750, 200)\n",
+    "settings = sax.update_settings(settings, wl=wls)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1ace74e4",
+   "metadata": {},
+   "source": [
+    "Compute transmission and reflection, and compare to another package's results (https://github.com/sbyrnes321/tmm) :"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "999d2eee",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%time\n",
+    "sdict = dielectric_fabry_perot(**settings)\n",
+    "\n",
+    "transmitted = sdict[\"in\", \"out\"]\n",
+    "reflected = sdict[\"in\", \"in\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1e0d4550",
+   "metadata": 
{}, + "outputs": [], + "source": [ + "%%time\n", + "\n", + "# tmm syntax (https://github.com/sbyrnes321/tmm)\n", + "d_list = [jnp.inf, 0.500, jnp.inf]\n", + "n_list = [1, 2, 1]\n", + "# initialize lists of y-values to plot\n", + "rnorm = []\n", + "tnorm = []\n", + "Tnorm = []\n", + "Rnorm = []\n", + "for l in wls:\n", + " rnorm.append(coh_tmm(\"s\", n_list, d_list, 0, l)[\"r\"])\n", + " tnorm.append(coh_tmm(\"s\", n_list, d_list, 0, l)[\"t\"])\n", + " Tnorm.append(coh_tmm(\"s\", n_list, d_list, 0, l)[\"T\"])\n", + " Rnorm.append(coh_tmm(\"s\", n_list, d_list, 0, l)[\"R\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43b55a4d", + "metadata": {}, + "outputs": [], + "source": [ + "plt.scatter(wls * 1e3, jnp.real(transmitted), label=\"t SAX\")\n", + "plt.plot(wls * 1e3, jnp.real(jnp.array(tnorm)), \"k\", label=\"t tmm\")\n", + "plt.scatter(wls * 1e3, jnp.real(reflected), label=\"r SAX\")\n", + "plt.plot(wls * 1e3, jnp.real(jnp.array(rnorm)), \"k--\", label=\"r tmm\")\n", + "plt.xlabel(\"λ [nm]\")\n", + "plt.ylabel(\"Transmitted and reflected amplitude\")\n", + "plt.legend(loc=\"upper right\")\n", + "plt.title(\"Real part\")\n", + "plt.show()\n", + "\n", + "plt.scatter(wls * 1e3, jnp.imag(transmitted), label=\"t SAX\")\n", + "plt.plot(wls * 1e3, jnp.imag(jnp.array(tnorm)), \"k\", label=\"t tmm\")\n", + "plt.scatter(wls * 1e3, jnp.imag(reflected), label=\"r SAX\")\n", + "plt.plot(wls * 1e3, jnp.imag(jnp.array(rnorm)), \"k--\", label=\"r tmm\")\n", + "plt.xlabel(\"λ [nm]\")\n", + "plt.ylabel(\"Transmitted and reflected amplitude\")\n", + "plt.legend(loc=\"upper right\")\n", + "plt.title(\"Imaginary part\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "ae720d84", + "metadata": {}, + "source": [ + "In terms of powers, we get the following. Due to the reflections at the interfaces, resonant behaviour is observed, with evenly-spaced maxima/minima in wavevector space :" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a3b974a5", + "metadata": {}, + "outputs": [], + "source": [ + "plt.scatter(2 * jnp.pi / wls, jnp.abs(transmitted) ** 2, label=\"T SAX\")\n", + "plt.plot(2 * jnp.pi / wls, Tnorm, \"k\", label=\"T tmm\")\n", + "plt.scatter(2 * jnp.pi / wls, jnp.abs(reflected) ** 2, label=\"R SAX\")\n", + "plt.plot(2 * jnp.pi / wls, Rnorm, \"k--\", label=\"R tmm\")\n", + "plt.vlines(jnp.arange(3, 6) * jnp.pi / (2 * 0.5), ymin=0, ymax=1, color=\"k\", linestyle=\"--\", label=\"m$\\pi$/nd\")\n", + "plt.xlabel(\"k = 2$\\pi$/λ [1/nm]\")\n", + "plt.ylabel(\"Transmitted and reflected intensities\")\n", + "plt.legend(loc=\"upper right\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "b83a4463", + "metadata": {}, + "source": [ + "### Optimization test\n", + "\n", + "Let's attempt to minimize transmission at 500 nm by varying thickness." 
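+    "\n",
+    "\n",
+    "As a side note (a sketch, not how the cells below are written): since the objective below is a plain scalar function of the thickness, its value and gradient could also be obtained in a single jitted call with `jax.value_and_grad`, assuming the `loss(thickness)` function defined in the next cell:\n",
+    "\n",
+    "```python\n",
+    "# sketch: assumes the loss(thickness) function defined below\n",
+    "loss_and_grad = jax.jit(jax.value_and_grad(loss))\n",
+    "lossvalue, gradvalue = loss_and_grad(0.5)  # thickness in μm\n",
+    "```"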
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d497768f", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def loss(thickness):\n", + " settings = sax.update_settings(sax.get_settings(dielectric_fabry_perot), wl=0.5)\n", + " settings[\"B\"][\"di\"] = thickness\n", + " settings[\"air_B\"][\"nj\"] = 2.0\n", + " settings[\"B\"][\"ni\"] = 2.0\n", + " settings[\"B_air\"][\"ni\"] = 2.0\n", + " sdict = dielectric_fabry_perot(**settings)\n", + " return jnp.abs(sdict[\"in\", \"out\"]) ** 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e8d21cc", + "metadata": {}, + "outputs": [], + "source": [ + "grad = jax.jit(jax.grad(loss))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18af200b", + "metadata": {}, + "outputs": [], + "source": [ + "initial_thickness = 0.5\n", + "optim_init, optim_update, optim_params = opt.adam(step_size=0.01)\n", + "optim_state = optim_init(initial_thickness)\n", + "\n", + "\n", + "def train_step(step, optim_state):\n", + " thickness = optim_params(optim_state)\n", + " lossvalue = loss(thickness)\n", + " gradvalue = grad(thickness)\n", + " optim_state = optim_update(step, gradvalue, optim_state)\n", + " return lossvalue, optim_state" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "98bcb50b", + "metadata": {}, + "outputs": [], + "source": [ + "range_ = tqdm.trange(100)\n", + "for step in range_:\n", + " lossvalue, optim_state = train_step(step, optim_state)\n", + " range_.set_postfix(loss=f\"{lossvalue:.6f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13d16ace", + "metadata": {}, + "outputs": [], + "source": [ + "thickness = optim_params(optim_state)\n", + "thickness" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28073b3a", + "metadata": {}, + "outputs": [], + "source": [ + "settings = sax.update_settings(sax.get_settings(dielectric_fabry_perot), wl=wls)\n", + "settings[\"B\"][\"di\"] = thickness\n", + "settings[\"air_B\"][\"nj\"] = 2.0\n", + "settings[\"B\"][\"ni\"] = 2.0\n", + "settings[\"B_air\"][\"ni\"] = 2.0\n", + "sdict = dielectric_fabry_perot(**settings)\n", + "detected = sdict[\"in\", \"out\"]\n", + "\n", + "plt.plot(wls * 1e3, jnp.abs(transmitted) ** 2, label=\"Before (500 nm)\")\n", + "plt.plot(wls * 1e3, jnp.abs(detected) ** 2, label=f\"After ({thickness*1e3:.0f} nm)\")\n", + "plt.vlines(0.5 * 1e3, 0.6, 1, \"k\", linestyle=\"--\")\n", + "plt.xlabel(\"λ [nm]\")\n", + "plt.ylabel(\"Transmitted intensity\")\n", + "plt.legend(loc=\"lower right\")\n", + "plt.title(\"Thickness optimization\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "fa41c516", + "metadata": {}, + "source": [ + "## General Fabry-Pérot étalon\n", + "\n", + "We reuse the propagation matrix above, and instead of simple interface matrices, model Fabry-Pérot mirrors as general lossless reciprocal scatter matrices :\n", + "\n", + "$$ \\left(\\begin{array}{c} \n", + "E_t \\\\\n", + "E_r\n", + "\\end{array}\\right) = E_{out} = SE_{in} = \\left(\\begin{array}{cc} \n", + "t & r \\\\\n", + "r & t\n", + "\\end{array}\\right) \\left(\\begin{array}{c} \n", + "E_0 \\\\\n", + "0\n", + "\\end{array}\\right) $$\n", + "\n", + "For lossless reciprocal systems, we further have the requirements\n", + "\n", + "$$ |t|^2 + |r|^2 = 1 $$\n", + "\n", + "and\n", + "\n", + "$$ \\angle t - \\angle r = \\pm \\pi/2 $$\n", + "\n", + "The general Fabry-Pérot cavity is analytically described by :" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "2b0a7441", + "metadata": {}, + "outputs": [], + "source": [ + "def airy_t13(t12, t23, r21, r23, wl, d=1.0, n=1.0):\n", + " \"\"\"General Fabry-Pérot transmission transfer function (Airy formula)\n", + "\n", + " Args:\n", + " t12 and r12 : S-parameters of the first mirror\n", + " t23 and r23 : S-parameters of the second mirror\n", + " wl : wavelength\n", + " d : gap between the two mirrors (in units of wavelength)\n", + " n : index of the gap between the two mirrors\n", + "\n", + " Returns:\n", + " t13 : complex transmission amplitude of the mirror-gap-mirror system\n", + "\n", + " Note:\n", + " Each mirror is assumed to be lossless and reciprocal : tij = tji, rij = rji\n", + " \"\"\"\n", + " phi = n * 2 * jnp.pi * d / wl\n", + " return t12 * t23 * jnp.exp(-1j * phi) / (1 - r21 * r23 * jnp.exp(-2j * phi))\n", + "\n", + "\n", + "def airy_r13(t12, t23, r21, r23, wl, d=1.0, n=1.0):\n", + " \"\"\"General Fabry-Pérot reflection transfer function (Airy formula)\n", + "\n", + " Args:\n", + " t12 and r12 : S-parameters of the first mirror\n", + " t23 and r23 : S-parameters of the second mirror\n", + " wl : wavelength\n", + " d : gap between the two mirrors (in units of wavelength)\n", + " n : index of the gap between the two mirrors\n", + "\n", + " Returns:\n", + " r13 : complex reflection amplitude of the mirror-gap-mirror system\n", + "\n", + " Note:\n", + " Each mirror is assumed to be lossless and reciprocal : tij = tji, rij = rji\n", + " \"\"\"\n", + " phi = n * 2 * jnp.pi * d / wl\n", + " return r21 + t12 * t12 * r23 * jnp.exp(-2j * phi) / (1 - r21 * r23 * jnp.exp(-2j * phi))" + ] + }, + { + "cell_type": "markdown", + "id": "9cc808ed", + "metadata": {}, + "source": [ + "We need to implement the relationship between $t$ and $r$ for lossless reciprocal mirrors. The design parameter will be the amplitude and phase of the tranmission coefficient. 
The reflection coefficient is then fully determined :" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62c11bb6", + "metadata": {}, + "outputs": [], + "source": [ + "def t_complex(t_amp, t_ang):\n", + " return t_amp * jnp.exp(-1j * t_ang)\n", + "\n", + "\n", + "def r_complex(t_amp, t_ang):\n", + " r_amp = jnp.sqrt((1.0 - t_amp ** 2))\n", + " r_ang = t_ang - jnp.pi / 2\n", + " return r_amp * jnp.exp(-1j * r_ang)" + ] + }, + { + "cell_type": "markdown", + "id": "a865ea71", + "metadata": {}, + "source": [ + "Let's see the expected result for half-mirrors :" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0a6b35d", + "metadata": {}, + "outputs": [], + "source": [ + "t_initial = jnp.sqrt(0.5)\n", + "d_gap = 2.0\n", + "n_gap = 1.0\n", + "r_initial = r_complex(t_initial, 0.0)\n", + "\n", + "wls = jnp.linspace(0.38, 0.78, 500)\n", + "\n", + "T_analytical_initial = jnp.abs(airy_t13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap)) ** 2\n", + "R_analytical_initial = jnp.abs(airy_r13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap)) ** 2 \n", + "\n", + "plt.title(f\"t={t_initial:1.3f}, d={d_gap} nm, n={n_gap}\")\n", + "plt.plot(2 * jnp.pi / wls, T_analytical_initial, label=\"T\")\n", + "plt.plot(2 * jnp.pi / wls, R_analytical_initial, label=\"R\")\n", + "plt.vlines(jnp.arange(6, 11) * jnp.pi / 2.0, ymin=0, ymax=1, color=\"k\", linestyle=\"--\", label=\"m$\\pi$/nd\")\n", + "plt.xlabel(\"k = 2$\\pi$/$\\lambda$ [1/nm]\")\n", + "plt.ylabel(\"Power (units of input)\")\n", + "plt.legend()\n", + "plt.show()\n", + "\n", + "plt.title(f\"t={t_initial:1.3f}, d={d_gap} nm, n={n_gap}\")\n", + "plt.plot(wls * 1e3, T_analytical_initial, label=\"T\")\n", + "plt.plot(wls * 1e3, R_analytical_initial, label=\"R\")\n", + "plt.xlabel(\"$\\lambda$ (nm)\")\n", + "plt.ylabel(\"Power (units of input)\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7c66218", + "metadata": {}, + "outputs": [], + "source": [ + "# Is power conserved? 
(to within 0.1%)\n", + "assert jnp.isclose(R_analytical_initial + T_analytical_initial, 1, 0.001).all()" + ] + }, + { + "cell_type": "markdown", + "id": "e91cd0dd", + "metadata": {}, + "source": [ + "Now let's do the same with SAX by defining new elements :" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "80c784a5", + "metadata": {}, + "outputs": [], + "source": [ + "def mirror(t_amp=0.5**0.5, t_ang=0.0):\n", + " r_complex_val = r_complex(t_amp, t_ang)\n", + " t_complex_val = t_complex(t_amp, t_ang)\n", + " sdict = {\n", + " (\"in\", \"in\"): r_complex_val,\n", + " (\"in\", \"out\"): t_complex_val,\n", + " (\"out\", \"in\"): t_complex_val, # (1 - r_complex_val**2)/t_complex_val, # t_ji\n", + " (\"out\", \"out\"): r_complex_val, # -r_complex_val, # r_ji\n", + " }\n", + " return sdict\n", + "\n", + "\n", + "fabry_perot_tunable = sax.circuit(\n", + " instances={\n", + " \"mirror1\": mirror,\n", + " \"gap\": propagation_i,\n", + " \"mirror2\": mirror,\n", + " },\n", + " connections={\n", + " \"mirror1,out\": \"gap,in\",\n", + " \"gap,out\": \"mirror2,in\",\n", + " },\n", + " ports={\n", + " \"in\": \"mirror1,in\",\n", + " \"out\": \"mirror2,out\",\n", + " },\n", + ")\n", + "\n", + "settings = sax.get_settings(fabry_perot_tunable)\n", + "settings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86e2f26e", + "metadata": {}, + "outputs": [], + "source": [ + "fabry_perot_tunable = sax.circuit(\n", + " instances={\n", + " \"mirror1\": mirror,\n", + " \"gap\": propagation_i,\n", + " \"mirror2\": mirror,\n", + " },\n", + " connections={\n", + " \"mirror1,out\": \"gap,in\",\n", + " \"gap,out\": \"mirror2,in\",\n", + " },\n", + " ports={\n", + " \"in\": \"mirror1,in\",\n", + " \"out\": \"mirror2,out\",\n", + " },\n", + ")\n", + "\n", + "settings = sax.get_settings(fabry_perot_tunable)\n", + "settings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f3912513", + "metadata": {}, + "outputs": [], + "source": [ + "N = 100\n", + "wls = jnp.linspace(0.38, 0.78, N)\n", + "settings = sax.get_settings(fabry_perot_tunable)\n", + "settings = sax.update_settings(settings, wl=wls, t_amp=jnp.sqrt(0.5), t_ang=0.0)\n", + "settings[\"gap\"][\"ni\"] = 1.0\n", + "settings[\"gap\"][\"di\"] = 2.0\n", + "transmitted_initial = fabry_perot_tunable(**settings)[\"in\", \"out\"]\n", + "reflected_initial = fabry_perot_tunable(**settings)[\"out\", \"out\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0e93a4e", + "metadata": {}, + "outputs": [], + "source": [ + "T_analytical_initial = jnp.abs(airy_t13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap))**2\n", + "R_analytical_initial = jnp.abs(airy_r13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap))**2\n", + "plt.title(f\"t={t_initial:1.3f}, d={d_gap} nm, n={n_gap}\")\n", + "plt.plot(wls, T_analytical_initial, label=\"T theory\")\n", + "plt.scatter(wls, jnp.abs(transmitted_initial) ** 2, label=\"T SAX\")\n", + "plt.plot(wls, R_analytical_initial, label=\"R theory\")\n", + "plt.scatter(wls, jnp.abs(reflected_initial) ** 2, label=\"R SAX\")\n", + "plt.xlabel(\"k = 2$\\pi$/$\\lambda$ [1/nm]\")\n", + "plt.ylabel(\"Power (units of input)\")\n", + "plt.figlegend(framealpha=1.0)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "90a42e80", + "metadata": {}, + "source": [ + "## Wavelength-dependent Fabry-Pérot étalon\n", + "\n", + "Let's repeat with a model where parameters can be wavelength-dependent. 
To comply with the optimizer object, we will stack all design parameters in a single array :" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e15a411", + "metadata": {}, + "outputs": [], + "source": [ + "ts_initial = jnp.zeros(2 * N)\n", + "ts_initial = jax.ops.index_update(ts_initial, jax.ops.index[0:N], jnp.sqrt(0.5))" + ] + }, + { + "cell_type": "markdown", + "id": "cbb87fa8", + "metadata": {}, + "source": [ + "We will simply loop over all wavelengths, and use different $t$ parameters at each wavelength." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7415ddb", + "metadata": {}, + "outputs": [], + "source": [ + "wls = jnp.linspace(0.38, 0.78, N)\n", + "transmitted = jnp.zeros_like(wls)\n", + "reflected = jnp.zeros_like(wls)\n", + "settings = sax.get_settings(fabry_perot_tunable)\n", + "settings = sax.update_settings(settings, wl=wls, t_amp=ts_initial[:N], t_ang=ts_initial[N:])\n", + "settings[\"gap\"][\"ni\"] = 1.0\n", + "settings[\"gap\"][\"di\"] = 2.0\n", + "# Perform computation\n", + "sdict = fabry_perot_tunable(**settings)\n", + "transmitted = jnp.abs(sdict[\"in\", \"out\"]) ** 2\n", + "reflected = jnp.abs(sdict[\"in\", \"in\"]) ** 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f4149451", + "metadata": {}, + "outputs": [], + "source": [ + "plt.plot(wls * 1e3, T_analytical_initial, label=\"T theory\")\n", + "plt.scatter(wls * 1e3, transmitted, label=\"T SAX\")\n", + "plt.plot(wls * 1e3, R_analytical_initial, label=\"R theory\")\n", + "plt.scatter(wls * 1e3, reflected, label=\"R SAX\")\n", + "plt.xlabel(\"λ [nm]\")\n", + "plt.ylabel(\"Transmitted and reflected intensities\")\n", + "plt.legend(loc=\"upper right\")\n", + "plt.title(f\"t={t_initial:1.3f}, d={d_gap} nm, n={n_gap}\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "6fbc2915", + "metadata": {}, + "source": [ + "Since it seems to work, let's add a target and optimize some harmonics away :" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8d6108a", + "metadata": {}, + "outputs": [], + "source": [ + "def lorentzian(l0, dl, wl, A):\n", + " return A / ((wl - l0) ** 2 + (0.5 * dl) ** 2)\n", + "\n", + "\n", + "target = lorentzian(533.0, 20.0, wls * 1e3, 100.0)\n", + "\n", + "plt.scatter(wls * 1e3, transmitted, label=\"T SAX\")\n", + "plt.scatter(wls * 1e3, reflected, label=\"R SAX\")\n", + "plt.plot(wls * 1e3, target, \"r\", linewidth=2, label=\"target\")\n", + "plt.xlabel(\"λ [nm]\")\n", + "plt.ylabel(\"Transmitted and reflected intensities\")\n", + "plt.legend(loc=\"upper right\")\n", + "plt.title(f\"t={t_initial:1.3f}, d={d_gap} nm, n={n_gap}\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "490f2d79", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def loss(ts):\n", + " N = len(ts[::2])\n", + " wls = jnp.linspace(0.38, 0.78, N)\n", + " target = lorentzian(533.0, 20.0, wls * 1e3, 100.0)\n", + " settings = sax.get_settings(fabry_perot_tunable)\n", + " settings = sax.update_settings(settings, wl=wls, t_amp=ts[:N], t_ang=ts[N:])\n", + " settings[\"gap\"][\"ni\"] = 1.0\n", + " settings[\"gap\"][\"di\"] = 2.0\n", + " sdict = fabry_perot_tunable(**settings)\n", + " transmitted = jnp.abs(sdict[\"in\", \"out\"]) ** 2\n", + " return (jnp.abs(transmitted - target) ** 2).mean()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ed1b3bfe", + "metadata": {}, + "outputs": [], + "source": [ + "grad = jax.jit(jax.grad(loss))" + ] + }, + { 
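+ "cell_type": "markdown", + "id": "5f3a9c21", + "metadata": {}, + "source": [ + "As a quick sanity check (an added sketch), the gradient of the scalar loss should have the same shape as the parameter vector `ts_initial` defined above:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a4b0d32", + "metadata": {}, + "outputs": [], + "source": [ + "# sanity-check sketch: jax.grad of a scalar-valued loss returns an array with\n", + "# the same shape as its input, so this should hold for the 2*N parameter vector\n", + "assert grad(ts_initial).shape == ts_initial.shape" + ] + }, + {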
+ "cell_type": "code", + "execution_count": null, + "id": "721379af", + "metadata": {}, + "outputs": [], + "source": [ + "optim_init, optim_update, optim_params = opt.adam(step_size=0.001)\n", + "\n", + "def train_step(step, optim_state):\n", + " ts = optim_params(optim_state)\n", + " lossvalue = loss(ts)\n", + " gradvalue = grad(ts)\n", + " optim_state = optim_update(step, gradvalue, optim_state)\n", + " return lossvalue, gradvalue, optim_state" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "744e0990", + "metadata": {}, + "outputs": [], + "source": [ + "range_ = tqdm.trange(2000)\n", + "\n", + "optim_state = optim_init(ts_initial)\n", + "for step in range_:\n", + " lossvalue, gradvalue, optim_state = train_step(step, optim_state)\n", + " range_.set_postfix(loss=f\"{lossvalue:.6f}\")" + ] + }, + { + "cell_type": "markdown", + "id": "2451287d", + "metadata": {}, + "source": [ + "The optimized parameters are now wavelength-dependent :" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7893774", + "metadata": {}, + "outputs": [], + "source": [ + "ts_optimal = optim_params(optim_state)\n", + "\n", + "plt.scatter(wls * 1e3, ts_initial[:N], label=\"t initial\")\n", + "plt.scatter(wls * 1e3, ts_optimal[:N], label=\"t optimal\")\n", + "plt.xlabel(\"λ [nm]\")\n", + "plt.ylabel(\"|t| $(\\lambda)$\")\n", + "plt.legend(loc=\"best\")\n", + "plt.title(f\"d={d_gap} nm, n={n_gap}\")\n", + "plt.show()\n", + "\n", + "plt.scatter(wls * 1e3, ts_initial[N:], label=\"t initial\")\n", + "plt.scatter(wls * 1e3, ts_optimal[N:], label=\"t optimal\")\n", + "plt.xlabel(\"λ [nm]\")\n", + "plt.ylabel(\"angle $t (\\lambda)$ (rad)\")\n", + "plt.legend(loc=\"best\")\n", + "plt.title(f\"d={d_gap} nm, n={n_gap}\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "e1b92e1e", + "metadata": {}, + "source": [ + "Visualizing the result :" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "223bd973", + "metadata": {}, + "outputs": [], + "source": [ + "wls = jnp.linspace(0.38, 0.78, N)\n", + "transmitted_optimal = jnp.zeros_like(wls)\n", + "reflected_optimal = jnp.zeros_like(wls)\n", + "\n", + "settings = sax.get_settings(fabry_perot_tunable)\n", + "settings = sax.update_settings(\n", + " settings, wl=wls, t_amp=ts_optimal[:N], t_ang=ts_optimal[N:]\n", + ")\n", + "settings[\"gap\"][\"ni\"] = 1.0\n", + "settings[\"gap\"][\"di\"] = 2.0\n", + "transmitted_optimal = jnp.abs(fabry_perot_tunable(**settings)[\"in\", \"out\"]) ** 2\n", + "reflected_optimal = jnp.abs(fabry_perot_tunable(**settings)[\"in\", \"in\"]) ** 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52ef7be4", + "metadata": {}, + "outputs": [], + "source": [ + "plt.scatter(wls * 1e3, transmitted_optimal, label=\"T\")\n", + "plt.scatter(wls * 1e3, reflected_optimal, label=\"R\")\n", + "plt.plot(wls * 1e3, lorentzian(533, 20, wls * 1e3, 100), \"r\", label=\"target\")\n", + "plt.xlabel(\"λ [nm]\")\n", + "plt.ylabel(\"Transmitted and reflected intensities\")\n", + "plt.legend(loc=\"upper right\")\n", + "plt.title(f\"Optimized t($\\lambda$), d={d_gap} nm, n={n_gap}\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "cd20054c", + "metadata": {}, + "source": [ + "The hard part is now to find physical stacks that physically implement $t(\\lambda)$. 
However, the ease with which we can modify and complexify the loss function opens opportunities for regularization and more complicated objective functions.\n", + "\n", + "The models above are available in `sax.models.thinfilm`, and can straightforwardly be extended to propagation at an angle, s and p polarizations, nonreciprocal systems, and systems with losses." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/06_additive_backend.ipynb b/examples/06_additive_backend.ipynb new file mode 100644 index 0000000..1dd86e4 --- /dev/null +++ b/examples/06_additive_backend.ipynb @@ -0,0 +1,141 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6f097557", + "metadata": {}, + "source": [ + "# Additive Backend: pathlengths and group delays" + ] + }, + { + "cell_type": "markdown", + "id": "56868383", + "metadata": {}, + "source": [ + "Let's go over the core functionality of SAX." + ] + }, + { + "cell_type": "markdown", + "id": "e0d25c8e", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79c98bc0", + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "\n", + "import jax\n", + "import jax.example_libraries.optimizers as opt\n", + "import jax.numpy as jnp\n", + "import matplotlib.pyplot as plt # plotting\n", + "import sax\n", + "import tqdm # progress bars" + ] + }, + { + "cell_type": "markdown", + "id": "6a42b9d9", + "metadata": {}, + "source": [ + "## Parametrized Models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79e7e40e", + "metadata": {}, + "outputs": [], + "source": [ + "def coupler(length=50.0) -> sax.SDict:\n", + " sdict = {\n", + " (\"in0\", \"out0\"): length,\n", + " (\"in0\", \"out1\"): length,\n", + " (\"in1\", \"out0\"): length,\n", + " (\"in1\", \"out1\"): length,\n", + " }\n", + " return sax.reciprocal(sdict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3248838c", + "metadata": {}, + "outputs": [], + "source": [ + "def waveguide(length=100.0) -> sax.SDict:\n", + " sdict = {\n", + " (\"in0\", \"out0\"): length,\n", + " }\n", + " return sax.reciprocal(sdict)" + ] + }, + { + "cell_type": "markdown", + "id": "d650fb8d", + "metadata": {}, + "source": [ + "## Circuit with additive backend" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56e035d0", + "metadata": {}, + "outputs": [], + "source": [ + "mzi = sax.circuit(\n", + " instances={\n", + " \"lft\": coupler,\n", + " \"top\": partial(waveguide, length=500),\n", + " \"btm\": partial(waveguide, length=100),\n", + " \"rgt\": coupler,\n", + " },\n", + " connections={\n", + " \"lft,out0\": \"btm,in0\",\n", + " \"btm,out0\": \"rgt,in0\",\n", + " \"lft,out1\": \"top,in0\",\n", + " \"top,out0\": \"rgt,in1\",\n", + " },\n", + " ports={\n", + " \"in0\": \"lft,in0\",\n", + " \"in1\": \"lft,in1\",\n", + " \"out0\": \"rgt,out0\",\n", + " \"out1\": \"rgt,out1\",\n", + " },\n", + " backend=\"additive\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1f245fd", + "metadata": {}, + "outputs": [], + "source": [ + "mzi()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/index.ipynb b/index.ipynb new file mode 100644 index 0000000..02cbb66 --- /dev/null +++ b/index.ipynb @@ -0,0 
+1,225 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "7dfa345d", + "metadata": {}, + "source": [ + "# SAX\n", + "\n", + "> S + Autograd + XLA" + ] + }, + { + "cell_type": "markdown", + "id": "fef270ba", + "metadata": {}, + "source": [ + "![](docs/images/sax.svg)" + ] + }, + { + "cell_type": "markdown", + "id": "165e0fd9", + "metadata": {}, + "source": [ + "Autograd and XLA for S-parameters - a scatter parameter circuit simulator and\n", + "optimizer for the frequency domain based on [JAX](https://github.com/google/jax).\n", + "\n", + "The simulator was developed for simulating Photonic Integrated Circuits but in fact is\n", + "able to perform any S-parameter based circuit simulation. The goal of SAX is to be a\n", + "thin wrapper around JAX with some basic tools for S-parameter based circuit simulation\n", + "and optimization. Therefore, SAX does not define any special datastructures and tries to\n", + "stay as close as possible to the functional nature of JAX. This makes it very easy to\n", + "get started with SAX as you only need functions and standard python dictionaries. Let's\n", + "dive in...\n", + "\n", + "## Quick Start\n", + "\n", + "[Full Quick Start page](https://flaport.github.io/sax/quick_start) -\n", + "[Documentation](https://flaport.github.io/sax).\n", + "\n", + "Let's first import the SAX library, along with JAX and the JAX-version of numpy:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "429e080c", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02092230", + "metadata": {}, + "outputs": [], + "source": [ + "import sax\n", + "import jax\n", + "import jax.numpy as jnp" + ] + }, + { + "cell_type": "markdown", + "id": "2e3a6634", + "metadata": {}, + "source": [ + "Define a model function for your component. A SAX model is just a function that returns\n", + "an 'S-dictionary'. 
For example a directional coupler:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da79c227", + "metadata": {}, + "outputs": [], + "source": [ + "def coupler(coupling=0.5):\n", + " kappa = coupling**0.5\n", + " tau = (1-coupling)**0.5\n", + " sdict = sax.reciprocal({\n", + " (\"in0\", \"out0\"): tau,\n", + " (\"in0\", \"out1\"): 1j*kappa,\n", + " (\"in1\", \"out0\"): 1j*kappa,\n", + " (\"in1\", \"out1\"): tau,\n", + " })\n", + " return sdict\n", + "\n", + "coupler(coupling=0.3)" + ] + }, + { + "cell_type": "markdown", + "id": "a7686e65", + "metadata": {}, + "source": [ + "Or a waveguide:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56718cf4", + "metadata": {}, + "outputs": [], + "source": [ + "def waveguide(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0):\n", + " dwl = wl - wl0\n", + " dneff_dwl = (ng - neff) / wl0\n", + " neff = neff - dwl * dneff_dwl\n", + " phase = 2 * jnp.pi * neff * length / wl\n", + " amplitude = jnp.asarray(10 ** (-loss * length / 20), dtype=complex)\n", + " transmission = amplitude * jnp.exp(1j * phase)\n", + " sdict = sax.reciprocal({(\"in0\", \"out0\"): transmission})\n", + " return sdict\n", + "\n", + "waveguide(length=100.0)" + ] + }, + { + "cell_type": "markdown", + "id": "5837bd61", + "metadata": {}, + "source": [ + "These component models can then be combined into a circuit:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac89c549", + "metadata": {}, + "outputs": [], + "source": [ + "mzi = sax.circuit(\n", + " instances = {\n", + " \"lft\": coupler,\n", + " \"top\": waveguide,\n", + " \"rgt\": coupler,\n", + " },\n", + " connections={\n", + " \"lft,out0\": \"rgt,in0\",\n", + " \"lft,out1\": \"top,in0\",\n", + " \"top,out0\": \"rgt,in1\",\n", + " },\n", + " ports={\n", + " \"in0\": \"lft,in0\",\n", + " \"in1\": \"lft,in1\",\n", + " \"out0\": \"rgt,out0\",\n", + " \"out1\": \"rgt,out1\",\n", + " },\n", + ")\n", + "\n", + "type(mzi)" + ] + }, + { + "cell_type": "markdown", + "id": "a0f204ce", + "metadata": {}, + "source": [ + "As you can see, the mzi we just created is just another component model function! To simulate it, call the mzi function with the (possibly nested) settings of its subcomponents. Global settings can be added to the 'root' of the circuit call and will be distributed over all subcomponents which have a parameter with the same name (e.g. 'wl'):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d9bc7f6f", + "metadata": {}, + "outputs": [], + "source": [ + "wl = jnp.linspace(1.53, 1.57, 1000)\n", + "result = mzi(wl=wl, lft={'coupling': 0.3}, top={'length': 200.0}, rgt={'coupling': 0.8})\n", + "\n", + "plt.plot(1e3*wl, jnp.abs(result['in0', 'out0'])**2, label=\"in0->out0\")\n", + "plt.plot(1e3*wl, jnp.abs(result['in0', 'out1'])**2, label=\"in0->out1\", ls=\"--\")\n", + "plt.xlabel(\"λ [nm]\")\n", + "plt.ylabel(\"T\")\n", + "plt.grid(True)\n", + "plt.figlegend(ncol=2, loc=\"upper center\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "9bf49bc7", + "metadata": {}, + "source": [ + "Those are the basics. For more info, check out the **full**\n", + "[SAX Quick Start page](https://flaport.github.io/sax/quick_start) or the rest of the [Documentation](https://flaport.github.io/sax).\n", + "\n", + "## Installation\n", + "\n", + "### Dependencies\n", + "\n", + "- [JAX & JAXLIB](https://github.com/google/jax). 
Please read the JAX install instructions [here](https://github.com/google/jax/#installation).\n", + "\n", + "### Installation\n", + "\n", + "```\n", + "pip install sax\n", + "```\n", + "\n", + "## License\n", + "\n", + "Copyright © 2021, Floris Laporte, [Apache-2.0 License](LICENSE)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/jaxinstall.sh b/jaxinstall.sh deleted file mode 100755 index f94a05f..0000000 --- a/jaxinstall.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/sh - -[ -z "$PLATFORM_VERSION" ] && PLATFORM_VERSION="manylinux2010_x86_64" -[ -z "$PYTHON_VERSION" ] && PYTHON_VERSION="cp$(python --version | sed 's/^.*[ ]\([0-9]\)\.\([0-9]\).*/\1\2/g')" -[ -z "$CUDA_VERSION" ] && CUDA_VERSION="nocuda" -[ -z "$JAXLIB_VERSION" ] && JAXLIB_VERSION="$(grep "jaxlib" requirements.txt | sed 's/==/-/g')" -JAXLIB_URL=$JAXLIB_VERSION -if which nvcc > /dev/null 2>&1; then - [ "$CUDA_VERSION" == "nocuda" ] && CUDA_VERSION="cuda$(nvcc --version | grep release | sed 's/^.*release[ ]\([0-9]*\)\.\([0-9]*\).*$/\1\2/g')" - JAXLIB_URL="https://storage.googleapis.com/jax-releases/$CUDA_VERSION/$JAXLIB_VERSION+$CUDA_VERSION-$PYTHON_VERSION-none-$PLATFORM_VERSION.whl" -fi -PYTHON=$(which python) - -echo "python version: $PYTHON_VERSION @ $PYTHON" -echo "platform version: $PLATFORM_VERSION" -echo "cuda version: $CUDA_VERSION" -echo "jaxlib version: $JAXLIB_VERSION" -echo -echo "pip install --upgrade jax" -echo "pip install --upgrade $JAXLIB_URL" -echo -echo - -$PYTHON -m pip install --upgrade jax -$PYTHON -m pip install --upgrade $JAXLIB_URL diff --git a/nbs/00_typing.ipynb b/nbs/00_typing.ipynb new file mode 100644 index 0000000..c01d574 --- /dev/null +++ b/nbs/00_typing.ipynb @@ -0,0 +1,1608 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "513b2287", + "metadata": {}, + "outputs": [], + "source": [ + "# default_exp typing_" + ] + }, + { + "cell_type": "markdown", + "id": "13a3464b", + "metadata": {}, + "source": [ + "# Typing\n", + "\n", + "> SAX types" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9974f734", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "import matplotlib.pyplot as plt\n", + "from fastcore.test import test_eq\n", + "from pytest import approx, raises\n", + "\n", + "import os, sys; sys.stderr = open(os.devnull, \"w\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec6a3f56", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "from __future__ import annotations\n", + "\n", + "import functools\n", + "import inspect\n", + "from collections.abc import Callable as CallableABC\n", + "from typing import Any, Callable, Dict, Tuple, TypedDict, Union, cast, overload\n", + "\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "from natsort import natsorted" + ] + }, + { + "cell_type": "markdown", + "id": "41e26ab2", + "metadata": {}, + "source": [ + "## Core Types" + ] + }, + { + "cell_type": "markdown", + "id": "ca29df5c", + "metadata": {}, + "source": [ + "#### Array" + ] + }, + { + "cell_type": "markdown", + "id": "82c3c344", + "metadata": {}, + "source": [ + "an `Array` is either a jax array or a numpy array:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d301cd9", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "Array = Union[jnp.ndarray, np.ndarray]" + ] + }, + { + "cell_type": "markdown", + "id": 
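"1c2d3e4f", + "metadata": {}, + "source": [ + "> Example (an added illustrative sketch; `_array` is just an arbitrary name):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d3e4f5a", + "metadata": {}, + "outputs": [], + "source": [ + "_array: Array = jnp.array([3.0, 4.0])" + ] + }, + { + "cell_type": "markdown", + "id": 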
"6b86fa37", + "metadata": {}, + "source": [ + "#### Int" + ] + }, + { + "cell_type": "markdown", + "id": "9fa467c9", + "metadata": {}, + "source": [ + "An `Int` is either a built-in `int` or an `Array` [of dtype `int`]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7096ab39", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "Int = Union[int, Array]" + ] + }, + { + "cell_type": "markdown", + "id": "c486cd2c", + "metadata": {}, + "source": [ + "#### Float" + ] + }, + { + "cell_type": "markdown", + "id": "3b59bba6", + "metadata": {}, + "source": [ + "A `Float` is eiter a built-in `float` or an `Array` [of dtype `float`]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7acd4a7a", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "Float = Union[float, Array]" + ] + }, + { + "cell_type": "markdown", + "id": "bdfd6200", + "metadata": {}, + "source": [ + "#### ComplexFloat" + ] + }, + { + "cell_type": "markdown", + "id": "b1a7d255", + "metadata": {}, + "source": [ + "A `ComplexFloat` is either a build-in `complex` or an Array [of dtype `complex`]:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93dacae2", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "ComplexFloat = Union[complex, Float]" + ] + }, + { + "cell_type": "markdown", + "id": "1f1aa099", + "metadata": {}, + "source": [ + "#### Settings" + ] + }, + { + "cell_type": "markdown", + "id": "69839acd", + "metadata": {}, + "source": [ + "A `Settings` dictionary is a nested mapping between setting names [`str`] to either `ComplexFloat` values OR to another lower level `Settings` dictionary." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a7fb49e", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "Settings = Union[Dict[str, ComplexFloat], Dict[str, \"Settings\"]]" + ] + }, + { + "cell_type": "markdown", + "id": "cf202d06", + "metadata": {}, + "source": [ + "Settings dictionaries are used to parametrize a SAX `Model` or a `circuit`. The settings dictionary should have the same hierarchy levels as the circuit:\n", + " \n", + " > Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b26c4ab", + "metadata": {}, + "outputs": [], + "source": [ + "mzi_settings = {\n", + " \"wl\": 1.5, # global settings\n", + " \"lft\": {\"coupling\": 0.5}, # settings for the left coupler\n", + " \"top\": {\"neff\": 3.4}, # settings for the top waveguide\n", + " \"rgt\": {\"coupling\": 0.3}, # settings for the right coupler\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "45043912", + "metadata": {}, + "source": [ + "#### SDict" + ] + }, + { + "cell_type": "markdown", + "id": "c5a7facf", + "metadata": {}, + "source": [ + "An `SDict` is a sparse dictionary based representation of an S-matrix, mapping port name tuples such as `('in0', 'out0')` to `ComplexFloat`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64f1293e", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "SDict = Dict[Tuple[str, str], ComplexFloat]" + ] + }, + { + "cell_type": "markdown", + "id": "07dc470a", + "metadata": {}, + "source": [ + "> Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ad65a70c", + "metadata": {}, + "outputs": [], + "source": [ + "_sdict: SDict = {\n", + " (\"in0\", \"out0\"): 3.0,\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "0e2c7430", + "metadata": {}, + "source": [ + "#### SCoo" + ] + }, + { + "cell_type": "markdown", + "id": "ef614e1d", + "metadata": {}, + "source": [ + "An `SCoo` is a sparse matrix based representation of an S-matrix consisting of three arrays and a port map. The three arrays represent the input port indices [`int`], output port indices [`int`] and the S-matrix values [`ComplexFloat`] of the sparse matrix. The port map maps a port name [`str`] to a port index [`int`]. Only these four arrays **together** and in this specific **order** are considered a valid `SCoo` representation!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e4541349", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "SCoo = Tuple[Array, Array, ComplexFloat, Dict[str, int]]" + ] + }, + { + "cell_type": "markdown", + "id": "ff9719c8", + "metadata": {}, + "source": [ + "> Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "972f8060", + "metadata": {}, + "outputs": [], + "source": [ + "Si = jnp.arange(3, dtype=int)\n", + "Sj = jnp.array([0, 1, 0], dtype=int)\n", + "Sx = jnp.array([3.0, 4.0, 1.0])\n", + "port_map = {\"in0\": 0, \"in1\": 2, \"out0\": 1}\n", + "_scoo: SCoo = Si, Sj, Sx, port_map" + ] + }, + { + "cell_type": "markdown", + "id": "2cba9246", + "metadata": {}, + "source": [ + "#### SDense" + ] + }, + { + "cell_type": "markdown", + "id": "fd28795b", + "metadata": {}, + "source": [ + "an `SDense` is a dense matrix representation of an S-matrix. It's represented by an NxN `ComplexFloat` array and a port map (mapping port names onto port indices)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e6aee3af", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "SDense = Tuple[Array, Dict[str, int]]" + ] + }, + { + "cell_type": "markdown", + "id": "2c4789f7", + "metadata": {}, + "source": [ + "> Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02def7ed", + "metadata": {}, + "outputs": [], + "source": [ + "Sd = jnp.arange(9, dtype=float).reshape(3, 3)\n", + "port_map = {\"in0\": 0, \"in1\": 2, \"out0\": 1}\n", + "_sdense = Sd, port_map" + ] + }, + { + "cell_type": "markdown", + "id": "b0d2e36f", + "metadata": {}, + "source": [ + "#### SType" + ] + }, + { + "cell_type": "markdown", + "id": "b441f21d", + "metadata": {}, + "source": [ + "an `SType` is either an `SDict` **OR** an `SCoo` **OR** an `SDense`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7db692df", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "SType = Union[SDict, SCoo, SDense]" + ] + }, + { + "cell_type": "markdown", + "id": "e0561b07", + "metadata": {}, + "source": [ + "> Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b52a5f2", + "metadata": {}, + "outputs": [], + "source": [ + "obj: SType = _sdict\n", + "obj: SType = _scoo\n", + "obj: SType = _sdense" + ] + }, + { + "cell_type": "markdown", + "id": "afabe91b", + "metadata": {}, + "source": [ + "#### Model" + ] + }, + { + "cell_type": "markdown", + "id": "13808fb1", + "metadata": {}, + "source": [ + "A `Model` is any keyword-only function that returns an `SType`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "130eb3e8", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "Model = Callable[..., SType]" + ] + }, + { + "cell_type": "markdown", + "id": "fad534fb", + "metadata": {}, + "source": [ + "#### ModelFactory" + ] + }, + { + "cell_type": "markdown", + "id": "85781647", + "metadata": {}, + "source": [ + "A `ModelFactory` is any keyword-only function that returns a `Model`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "67523a7c", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "ModelFactory = Callable[..., Model]" + ] + }, + { + "cell_type": "markdown", + "id": "353d8506", + "metadata": {}, + "source": [ + "> Note: SAX sometimes needs to figure out the difference between a `ModelFactory` and a normal `Model` *before* running the function. To do this, SAX will check the return annotation of the function. Any function with a `-> Model` or `-> Callable` annotation will be considered a `ModelFactory`. Any function without this annotation will be considered a normal Model: **don't forget the return annotation of your Model Factory!** To ensure a correct annotation and to ensure forward compatibility, it's recommended to decorate your `ModelFactory` with the `modelfactory` decorator." 
+ ] + }, + { + "cell_type": "markdown", + "id": "e42d3f17", + "metadata": {}, + "source": [ + "#### GeneralModel" + ] + }, + { + "cell_type": "markdown", + "id": "0c4356fc", + "metadata": {}, + "source": [ + "a `GeneralModel` is either a `Model` or a `LogicalNetlist` (will be defined below):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "91897d92", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "GeneralModel = Union[Model, \"LogicalNetlist\"]" + ] + }, + { + "cell_type": "markdown", + "id": "b32b5a31", + "metadata": {}, + "source": [ + "#### Models" + ] + }, + { + "cell_type": "markdown", + "id": "a796dcb2", + "metadata": {}, + "source": [ + "`Models` is a mapping between model names [`str`] and `GeneralModel`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c5ad83b3", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "Models = Dict[str, GeneralModel]" + ] + }, + { + "cell_type": "markdown", + "id": "ed1af61c", + "metadata": {}, + "source": [ + "> Note: sometimes 'component' is used to refer to a `Model` or `GeneralModel`. This is because other tools (such as for example GDSFactory) prefer that terminology." + ] + }, + { + "cell_type": "markdown", + "id": "5d276f63", + "metadata": {}, + "source": [ + "## Netlist Types" + ] + }, + { + "cell_type": "markdown", + "id": "2c946dc2", + "metadata": {}, + "source": [ + "#### Instance" + ] + }, + { + "cell_type": "markdown", + "id": "24694709", + "metadata": {}, + "source": [ + "A netlist `Instance` is a mapping with two keys: `\"component\"`, which should map to a key in a `Models` dictionary, and `\"settings\"`, which are all the necessary settings to instantiate a component:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6df6a74", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "Instance = TypedDict(\n", + " \"Instance\",\n", + " {\n", + " \"component\": str,\n", + " \"settings\": Settings,\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "ff82e535", + "metadata": {}, + "source": [ + "> Note: in SAX, a better name for `\"component\"` in the instance definition would probably be `\"model\"` or `\"model_name\"`. However we chose `\"component\"` here to have a 1-to-1 map between SAX netlists and GDSFactory netlists." + ] + }, + { + "cell_type": "markdown", + "id": "365559c3", + "metadata": {}, + "source": [ + "#### GeneralInstance" + ] + }, + { + "cell_type": "markdown", + "id": "5249fc0f", + "metadata": {}, + "source": [ + "A general instance can be any of the following (`LogicalNetlist` and `Netlist` will be defined below):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84243bb1", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "GeneralInstance = Union[str, Instance, \"LogicalNetlist\", \"Netlist\"]" + ] + }, + { + "cell_type": "markdown", + "id": "04a5992a", + "metadata": {}, + "source": [ + "> For example, this is allowed:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ef6d5ee", + "metadata": {}, + "outputs": [], + "source": [ + "inst: GeneralInstance = \"my_component_model\"\n", + "inst: GeneralInstance = {\n", + " \"component\": \"my_component_model\",\n", + " \"settings\": {},\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "8af40281", + "metadata": {}, + "source": [ + "> ...
and this is not (will be flagged by a static type checker like pyright or mypy):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c13fc4d7", + "metadata": {}, + "outputs": [], + "source": [ + "inst: GeneralInstance = {\n", + " \"component\": \"my_component_model\",\n", + " \"settings\": {},\n", + " \"extra_arg\": \"invalid\",\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "cf9fd876", + "metadata": {}, + "source": [ + "#### Instances" + ] + }, + { + "cell_type": "markdown", + "id": "a9aad796", + "metadata": {}, + "source": [ + "`Instances` is a mapping from instance names [`str`] to a `GeneralInstance`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c86b8426", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "Instances = Union[Dict[str, str], Dict[str, GeneralInstance]]" + ] + }, + { + "cell_type": "markdown", + "id": "0f292b1f", + "metadata": {}, + "source": [ + "#### Netlist" + ] + }, + { + "cell_type": "markdown", + "id": "7b85ab37", + "metadata": {}, + "source": [ + "a `Netlist` is a collection of `\"instances\"`, `\"connections\"` and `\"ports\"`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9fb48e7f", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "\n", + "Netlist = TypedDict(\n", + " \"Netlist\",\n", + " {\n", + " \"instances\": Instances,\n", + " \"connections\": Dict[str, str],\n", + " \"ports\": Dict[str, str],\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "d12f61a8", + "metadata": {}, + "source": [ + "> Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "877d1151", + "metadata": {}, + "outputs": [], + "source": [ + "mzi_netlist: Netlist = {\n", + " \"instances\": {\n", + " \"lft\": \"mmi1x2\", # shorthand if no settings need to be given\n", + " \"top\": { # full instance definition\n", + " \"component\": \"waveguide\",\n", + " \"settings\": {\n", + " \"length\": 100.0,\n", + " },\n", + " },\n", + " \"rgt\": \"mmi2x2\", # shorthand if no settings need to be given\n", + " },\n", + " \"connections\": {\n", + " \"lft,out0\": \"top,in0\",\n", + " \"top,out0\": \"rgt,in0\",\n", + " \"top,out1\": \"rgt,in1\",\n", + " },\n", + " \"ports\": {\n", + " \"in0\": \"lft,in0\",\n", + " \"out0\": \"rgt,out0\",\n", + " \"out1\": \"rgt,out1\",\n", + " },\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "613b0f22", + "metadata": {}, + "source": [ + "#### LogicalNetlist" + ] + }, + { + "cell_type": "markdown", + "id": "f8cb7e75", + "metadata": {}, + "source": [ + "a `LogicalNetlist` is a subset of the more general `Netlist`. It only contains the logical connections and instance names. Not the actual instances. This data structure is mostly used for internal use only." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61556940", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "\n", + "LogicalNetlist = TypedDict(\n", + " \"LogicalNetlist\",\n", + " {\n", + " \"instances\": Dict[str, str],\n", + " \"connections\": Dict[str, str],\n", + " \"ports\": Dict[str, str],\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "3468d6a1", + "metadata": {}, + "source": [ + "> Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ab67ad41", + "metadata": {}, + "outputs": [], + "source": [ + "mzi_logical_netlist: Netlist = {\n", + " \"instances\": {\n", + " \"lft\": \"mmi1x2\",\n", + " \"top\": \"waveguide\",\n", + " \"rgt\": \"mmi2x2\",\n", + " },\n", + " \"connections\": {\n", + " \"lft,out0\": \"top,in0\",\n", + " \"top,out0\": \"rgt,in0\",\n", + " \"top,out1\": \"rgt,in1\",\n", + " },\n", + " \"ports\": {\n", + " \"in0\": \"lft,in0\",\n", + " \"out0\": \"rgt,out0\",\n", + " \"out1\": \"rgt,out1\",\n", + " },\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "dbcb4447", + "metadata": {}, + "source": [ + "## Validation and runtime type-checking:" + ] + }, + { + "cell_type": "markdown", + "id": "6e4c44aa", + "metadata": {}, + "source": [ + "> Note: the type-checking functions below are **NOT** very tight and hence should be used within the right context!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dfd6027b", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def is_float(x: Any) -> bool:\n", + " \"\"\"Check if an object is a `Float`\"\"\"\n", + " if isinstance(x, float):\n", + " return True\n", + " if isinstance(x, np.ndarray):\n", + " return x.dtype in (np.float16, np.float32, np.float64, np.float128)\n", + " if isinstance(x, jnp.ndarray):\n", + " return x.dtype in (jnp.float16, jnp.float32, jnp.float64)\n", + " return False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28374448", + "metadata": {}, + "outputs": [], + "source": [ + "assert is_float(3.0)\n", + "assert not is_float(3)\n", + "assert not is_float(3.0 + 2j)\n", + "assert not is_float(jnp.array(3.0, dtype=complex))\n", + "assert not is_float(jnp.array(3, dtype=int))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "734d6c66", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def is_complex(x: Any) -> bool:\n", + " \"\"\"check if an object is a `ComplexFloat`\"\"\"\n", + " if isinstance(x, complex):\n", + " return True\n", + " if isinstance(x, np.ndarray):\n", + " return x.dtype in (np.complex64, np.complex128)\n", + " if isinstance(x, jnp.ndarray):\n", + " return x.dtype in (jnp.complex64, jnp.complex128)\n", + " return False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d81a703d", + "metadata": {}, + "outputs": [], + "source": [ + "assert not is_complex(3.0)\n", + "assert not is_complex(3)\n", + "assert is_complex(3.0 + 2j)\n", + "assert is_complex(jnp.array(3.0, dtype=complex))\n", + "assert not is_complex(jnp.array(3, dtype=int))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13d572a4", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def is_complex_float(x: Any) -> bool:\n", + " \"\"\"check if an object is either a `ComplexFloat` or a `Float`\"\"\"\n", + " return is_float(x) or is_complex(x)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c533ff0d", + "metadata": {}, + "outputs": [], + "source": [ + 
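"# is_complex_float returns True for float or complex scalars and arrays,\n", + "# and False for plain ints and integer arrays:\n", +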
"assert is_complex_float(3.0)\n", + "assert not is_complex_float(3)\n", + "assert is_complex_float(3.0 + 2j)\n", + "assert is_complex_float(jnp.array(3.0, dtype=complex))\n", + "assert not is_complex_float(jnp.array(3, dtype=int))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "481faad1", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def is_sdict(x: Any) -> bool:\n", + " \"\"\"check if an object is an `SDict` (a SAX S-dictionary)\"\"\"\n", + " return isinstance(x, dict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb0365dc", + "metadata": {}, + "outputs": [], + "source": [ + "assert not is_sdict(object())\n", + "assert is_sdict(_sdict)\n", + "assert not is_sdict(_scoo)\n", + "assert not is_sdict(_sdense)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e47ea884", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def is_scoo(x: Any) -> bool:\n", + " \"\"\"check if an object is an `SCoo` (a SAX sparse S-matrix representation in COO-format)\"\"\"\n", + " return isinstance(x, (tuple, list)) and len(x) == 4" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ea0239a", + "metadata": {}, + "outputs": [], + "source": [ + "assert not is_scoo(object)\n", + "assert not is_scoo(_sdict)\n", + "assert is_scoo(_scoo)\n", + "assert not is_scoo(_sdense)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b88caf0", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def is_sdense(x: Any) -> bool:\n", + " \"\"\"check if an object is an `SDense` (a SAX dense S-matrix representation)\"\"\"\n", + " return isinstance(x, (tuple, list)) and len(x) == 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2eddd63a", + "metadata": {}, + "outputs": [], + "source": [ + "assert not is_sdense(object)\n", + "assert not is_sdense(_sdict)\n", + "assert not is_sdense(_scoo)\n", + "assert is_sdense(_sdense)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32503ed6", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def is_model(model: Any) -> bool:\n", + " \"\"\"check if a callable is a `Model` (a callable returning an `SType`)\"\"\"\n", + " if not callable(model):\n", + " return False\n", + " try:\n", + " sig = inspect.signature(model)\n", + " except ValueError:\n", + " return False\n", + " for param in sig.parameters.values():\n", + " if param.default == inspect.Parameter.empty:\n", + " return False # a proper SAX model does not have any positional arguments.\n", + " if _is_callable_annotation(sig.return_annotation): # model factory\n", + " return False\n", + " return True\n", + "\n", + "def _is_callable_annotation(annotation: Any) -> bool:\n", + " \"\"\"check if an annotation is `Callable`-like\"\"\"\n", + " if isinstance(annotation, str):\n", + " # happens when\n", + " # from __future__ import annotations\n", + " # was imported at the top of the file...\n", + " return annotation.startswith(\"Callable\") or annotation.endswith(\"Model\")\n", + " # TODO: this is not a very robust check...\n", + " try:\n", + " return annotation.__origin__ == CallableABC\n", + " except AttributeError:\n", + " return False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "396caa20", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "assert _is_callable_annotation(Callable)\n", + "assert not _is_callable_annotation(SDict)" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "89bdb647", + "metadata": {}, + "outputs": [], + "source": [ + "def good_model(x=jnp.array(3.0), y=jnp.array(4.0)) -> SDict:\n", + " return {(\"in0\", \"out0\"): jnp.array(3.0)}\n", + "assert is_model(good_model)\n", + "\n", + "def bad_model(positional_argument, x=jnp.array(3.0), y=jnp.array(4.0)) -> SDict:\n", + " return {(\"in0\", \"out0\"): jnp.array(3.0)}\n", + "assert not is_model(bad_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a83edef", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def is_model_factory(model: Any) -> bool:\n", + " \"\"\"check if a callable is a model function.\"\"\"\n", + " if not callable(model):\n", + " return False\n", + " sig = inspect.signature(model)\n", + " if _is_callable_annotation(sig.return_annotation): # model factory\n", + " return True\n", + " return False" + ] + }, + { + "cell_type": "markdown", + "id": "1f9715ed", + "metadata": {}, + "source": [ + "> Note: For a `Callable` to be considered a `ModelFactory` in SAX, it **MUST** have a `Callable` or `Model` return annotation. Otherwise SAX will view it as a `Model` and things might break!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe9cdff0", + "metadata": {}, + "outputs": [], + "source": [ + "def func() -> Model:\n", + " ...\n", + " \n", + "assert is_model_factory(func) # yes, we only check the annotation for now...\n", + "\n", + "def func():\n", + " ...\n", + " \n", + "assert not is_model_factory(func) # yes, we only check the annotation for now..." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6bcb294", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def validate_model(model: Callable):\n", + " \"\"\"Validate the parameters of a model\"\"\"\n", + " positional_arguments = []\n", + " for param in inspect.signature(model).parameters.values():\n", + " if param.default is inspect.Parameter.empty:\n", + " positional_arguments.append(param.name)\n", + " if positional_arguments:\n", + " raise ValueError(\n", + " f\"model '{model}' takes positional arguments {', '.join(positional_arguments)} \"\n", + " \"and hence is not a valid SAX Model! 
A SAX model should ONLY take keyword arguments (or no arguments at all).\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "754399d5", + "metadata": {}, + "outputs": [], + "source": [ + "def good_model(x=jnp.array(3.0), y=jnp.array(4.0)) -> SDict:\n", + " return {(\"in0\", \"out0\"): jnp.array(3.0)}\n", + "\n", + "\n", + "assert validate_model(good_model) is None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "181c72fa", + "metadata": {}, + "outputs": [], + "source": [ + "def bad_model(positional_argument, x=jnp.array(3.0), y=jnp.array(4.0)) -> SDict:\n", + " return {(\"in0\", \"out0\"): jnp.array(3.0)}\n", + "\n", + "\n", + "with raises(ValueError):\n", + " validate_model(bad_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d0c1c97f", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def is_instance(instance: Any) -> bool:\n", + " \"\"\"check if a dictionary is an instance\"\"\"\n", + " if not isinstance(instance, dict):\n", + " return False\n", + " return \"component\" in instance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0706db8", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def is_netlist(netlist: Any) -> bool:\n", + " \"\"\"check if a dictionary is a netlist\"\"\"\n", + " if not isinstance(netlist, dict):\n", + " return False\n", + " if \"instances\" not in netlist:\n", + " return False\n", + " if \"connections\" not in netlist:\n", + " return False\n", + " if \"ports\" not in netlist:\n", + " return False\n", + " return True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e2a0f53", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def is_stype(stype: Any) -> bool:\n", + " \"\"\"check if an object is an SDict, SCoo or SDense\"\"\"\n", + " return is_sdict(stype) or is_scoo(stype) or is_sdense(stype)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "203dc194", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def is_singlemode(S: Any) -> bool:\n", + " \"\"\"check if an stype is single mode\"\"\"\n", + " if not is_stype(S):\n", + " return False\n", + " ports = _get_ports(S)\n", + " return not any((\"@\" in p) for p in ports)\n", + "\n", + "def _get_ports(S: SType):\n", + " if is_sdict(S):\n", + " S = cast(SDict, S)\n", + " ports_set = {p1 for p1, _ in S} | {p2 for _, p2 in S}\n", + " return tuple(natsorted(ports_set))\n", + " else:\n", + " *_, ports_map = S\n", + " assert isinstance(ports_map, dict)\n", + " return tuple(natsorted(ports_map.keys()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d94e1b5d", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def is_multimode(S: Any) -> bool:\n", + " \"\"\"check if an stype is multimode\"\"\"\n", + " if not is_stype(S):\n", + " return False\n", + " \n", + " ports = _get_ports(S)\n", + " return all((\"@\" in p) for p in ports)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "657c524a", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def is_mixedmode(S: Any) -> bool:\n", + " \"\"\"check if an stype is neither single mode nor multimode (hence invalid)\"\"\"\n", + " return not is_singlemode(S) and not is_multimode(S)" + ] + }, + { + "cell_type": "markdown", + "id": "3afe685c", + "metadata": {}, + "source": [ + "## SAX return type helpers\n", + "\n", + "> a.k.a SDict, SDense, SCoo helpers" + ] + }, + { +
"cell_type": "markdown", + "id": "a875f149", + "metadata": {}, + "source": [ + "Convert an `SDict`, `SCoo` or `SDense` into an `SDict` (or convert a model generating any of these types into a model generating an `SDict`):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac8c1d92", + "metadata": {}, + "outputs": [], + "source": [ + "# exporti\n", + "\n", + "@overload\n", + "def sdict(S: Model) -> Model:\n", + " ...\n", + "\n", + "\n", + "@overload\n", + "def sdict(S: SType) -> SDict:\n", + " ..." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a3facb9f", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def sdict(S: Union[Model, SType]) -> Union[Model, SType]:\n", + " \"\"\"Convert an `SCoo` or `SDense` to `SDict`\"\"\"\n", + "\n", + " if is_model(S):\n", + " model = cast(Model, S)\n", + "\n", + " @functools.wraps(model)\n", + " def wrapper(**kwargs):\n", + " return sdict(model(**kwargs))\n", + "\n", + " return wrapper\n", + "\n", + " elif is_scoo(S):\n", + " x_dict = _scoo_to_sdict(*cast(SCoo, S))\n", + " elif is_sdense(S):\n", + " x_dict = _sdense_to_sdict(*cast(SDense, S))\n", + " elif is_sdict(S):\n", + " x_dict = cast(SDict, S)\n", + " else:\n", + " raise ValueError(\"Could not convert arguments to sdict.\")\n", + "\n", + " return x_dict\n", + "\n", + "\n", + "def _scoo_to_sdict(Si: Array, Sj: Array, Sx: Array, ports_map: Dict[str, int]) -> SDict:\n", + " sdict = {}\n", + " inverse_ports_map = {int(i): p for p, i in ports_map.items()}\n", + " for i, (si, sj) in enumerate(zip(Si, Sj)):\n", + " sdict[\n", + " inverse_ports_map.get(int(si), \"\"), inverse_ports_map.get(int(sj), \"\")\n", + " ] = Sx[..., i]\n", + " sdict = {(p1, p2): v for (p1, p2), v in sdict.items() if p1 and p2}\n", + " return sdict\n", + "\n", + "\n", + "def _sdense_to_sdict(S: Array, ports_map: Dict[str, int]) -> SDict:\n", + " sdict = {}\n", + " for p1, i in ports_map.items():\n", + " for p2, j in ports_map.items():\n", + " sdict[p1, p2] = S[..., i, j]\n", + " return sdict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9a372fbf", + "metadata": {}, + "outputs": [], + "source": [ + "assert sdict(_sdict) is _sdict\n", + "assert sdict(_scoo) == {\n", + " (\"in0\", \"in0\"): 3.0,\n", + " (\"in1\", \"in0\"): 1.0,\n", + " (\"out0\", \"out0\"): 4.0,\n", + "}\n", + "assert sdict(_sdense) == {\n", + " (\"in0\", \"in0\"): 0.0,\n", + " (\"in0\", \"out0\"): 1.0,\n", + " (\"in0\", \"in1\"): 2.0,\n", + " (\"out0\", \"in0\"): 3.0,\n", + " (\"out0\", \"out0\"): 4.0,\n", + " (\"out0\", \"in1\"): 5.0,\n", + " (\"in1\", \"in0\"): 6.0,\n", + " (\"in1\", \"out0\"): 7.0,\n", + " (\"in1\", \"in1\"): 8.0,\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "492c5cdd", + "metadata": {}, + "source": [ + "Convert an `SDict`, `SCoo` or `SDense` into an `SCoo` (or convert a model generating any of these types into a model generating an `SCoo`):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1089b251", + "metadata": {}, + "outputs": [], + "source": [ + "# exporti\n", + "\n", + "@overload\n", + "def scoo(S: Callable) -> Callable:\n", + " ...\n", + "\n", + "\n", + "@overload\n", + "def scoo(S: SType) -> SCoo:\n", + " ..." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "017d3328", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def scoo(S: Union[Callable, SType]) -> Union[Callable, SCoo]:\n", + " \"\"\"Convert an `SDict` or `SDense` to `SCoo`\"\"\"\n", + "\n", + " if is_model(S):\n", + " model = cast(Model, S)\n", + "\n", + " @functools.wraps(model)\n", + " def wrapper(**kwargs):\n", + " return scoo(model(**kwargs))\n", + "\n", + " return wrapper\n", + "\n", + " elif is_scoo(S):\n", + " S = cast(SCoo, S)\n", + " elif is_sdense(S):\n", + " S = _sdense_to_scoo(*cast(SDense, S))\n", + " elif is_sdict(S):\n", + " S = _sdict_to_scoo(cast(SDict, S))\n", + " else:\n", + " raise ValueError(\"Could not convert arguments to scoo.\")\n", + "\n", + " return S\n", + "\n", + "\n", + "def _sdense_to_scoo(S: Array, ports_map: Dict[str, int]) -> SCoo:\n", + " Sj, Si = jnp.meshgrid(jnp.arange(S.shape[-1]), jnp.arange(S.shape[-2]))\n", + " return Si.ravel(), Sj.ravel(), S.reshape(*S.shape[:-2], -1), ports_map\n", + "\n", + "\n", + "def _sdict_to_scoo(sdict: SDict) -> SCoo:\n", + " all_ports = {}\n", + " for p1, p2 in sdict:\n", + " all_ports[p1] = None\n", + " all_ports[p2] = None\n", + " ports_map = {p: i for i, p in enumerate(all_ports)}\n", + " Sx = jnp.stack(jnp.broadcast_arrays(*sdict.values()), -1)\n", + " Si = jnp.array([ports_map[p] for p, _ in sdict])\n", + " Sj = jnp.array([ports_map[p] for _, p in sdict])\n", + " return Si, Sj, Sx, ports_map" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e97b31c", + "metadata": {}, + "outputs": [], + "source": [ + "assert scoo(_scoo) is _scoo\n", + "assert scoo(_sdict) == (0, 1, 3.0, {\"in0\": 0, \"out0\": 1})\n", + "Si, Sj, Sx, port_map = scoo(_sdense) # type: ignore\n", + "np.testing.assert_array_equal(Si, jnp.array([0, 0, 0, 1, 1, 1, 2, 2, 2]))\n", + "np.testing.assert_array_equal(Sj, jnp.array([0, 1, 2, 0, 1, 2, 0, 1, 2]))\n", + "np.testing.assert_array_almost_equal(Sx, jnp.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]))\n", + "assert port_map == {\"in0\": 0, \"in1\": 2, \"out0\": 1}" + ] + }, + { + "cell_type": "markdown", + "id": "5e58325b", + "metadata": {}, + "source": [ + "Convert an `SDict`, `SCoo` or `SDense` into an `SDense` (or convert a model generating any of these types into a model generating an `SDense`):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59c66af9", + "metadata": {}, + "outputs": [], + "source": [ + "# exporti\n", + "\n", + "\n", + "@overload\n", + "def sdense(S: Callable) -> Callable:\n", + " ...\n", + "\n", + "\n", + "@overload\n", + "def sdense(S: SType) -> SDense:\n", + " ..." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18a7d28d", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def sdense(S: Union[Callable, SType]) -> Union[Callable, SDense]:\n", + " \"\"\"Convert an `SDict` or `SCoo` to `SDense`\"\"\"\n", + "\n", + " if is_model(S):\n", + " model = cast(Model, S)\n", + "\n", + " @functools.wraps(model)\n", + " def wrapper(**kwargs):\n", + " return sdense(model(**kwargs))\n", + "\n", + " return wrapper\n", + "\n", + " if is_sdict(S):\n", + " S = _sdict_to_sdense(cast(SDict, S))\n", + " elif is_scoo(S):\n", + " S = _scoo_to_sdense(*cast(SCoo, S))\n", + " elif is_sdense(S):\n", + " S = cast(SDense, S)\n", + " else:\n", + " raise ValueError(\"Could not convert arguments to sdense.\")\n", + "\n", + " return S\n", + "\n", + "\n", + "def _scoo_to_sdense(\n", + " Si: Array, Sj: Array, Sx: Array, ports_map: Dict[str, int]\n", + ") -> SDense:\n", + " n_col = len(ports_map)\n", + " S = jnp.zeros((*Sx.shape[:-1], n_col, n_col), dtype=complex)\n", + " S = S.at[..., Si, Sj].add(Sx)\n", + " return S, ports_map\n", + "\n", + "\n", + "def _sdict_to_sdense(sdict: SDict) -> SDense:\n", + " Si, Sj, Sx, ports_map = _sdict_to_scoo(sdict)\n", + " return _scoo_to_sdense(Si, Sj, Sx, ports_map)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "084b7ddb", + "metadata": {}, + "outputs": [], + "source": [ + "assert sdense(_sdense) is _sdense\n", + "Sd, port_map = sdense(_scoo) # type: ignore\n", + "Sd_ = jnp.array([[3.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],\n", + " [0.0 + 0.0j, 4.0 + 0.0j, 0.0 + 0.0j],\n", + " [1.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j]])\n", + "\n", + "np.testing.assert_array_almost_equal(Sd, Sd_)\n", + "assert port_map == {\"in0\": 0, \"in1\": 2, \"out0\": 1}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24e25816", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def modelfactory(func):\n", + " \"\"\"Decorator that marks a function as `ModelFactory`\"\"\"\n", + " sig = inspect.signature(func)\n", + " if _is_callable_annotation(sig.return_annotation): # already model factory\n", + " return func\n", + " func.__signature__ = sig.replace(return_annotation=Model)\n", + " return func" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/01_patched.ipynb b/nbs/01_patched.ipynb new file mode 100644 index 0000000..b8fbd5b --- /dev/null +++ b/nbs/01_patched.ipynb @@ -0,0 +1,149 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "2f9e0b2c-92e0-4b32-9d57-a7c81e7e49bc", + "metadata": {}, + "outputs": [], + "source": [ + "# default_exp patched" + ] + }, + { + "cell_type": "markdown", + "id": "62619887-4d1d-4218-8259-9dbe991cfded", + "metadata": {}, + "source": [ + "# Patched\n", + "\n", + "> We patch some library and objects that don't belong to SAX. Don't worry, it's nothing substantial." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2bc5876-855c-4c0e-b8ae-2142bbd64925", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "import matplotlib.pyplot as plt\n", + "from fastcore.test import test_eq\n", + "from pytest import approx, raises\n", + "import jax.numpy as jnp\n", + "\n", + "import os, sys; sys.stderr = open(os.devnull, \"w\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "514f0ad3-8503-4982-9ae2-dbbbb2e2a3c7", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "from __future__ import annotations\n", + "\n", + "import re\n", + "from fastcore.basics import patch_to\n", + "from flax.core import FrozenDict\n", + "from jaxlib.xla_extension import DeviceArray\n", + "\n", + "from sax.typing_ import is_complex_float, is_float\n", + "from textwrap import dedent" + ] + }, + { + "cell_type": "markdown", + "id": "e37bb974-621b-4fc2-b4a8-d3f51d4f72ad", + "metadata": {}, + "source": [ + "Paching `FrozenDict` to have the same repr as a normal dict:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c4421fc-3f65-4f9a-9e88-9ed1650f9107", + "metadata": {}, + "outputs": [], + "source": [ + "# exporti\n", + "@patch_to(FrozenDict)\n", + "def __repr__(self): # type: ignore\n", + " _dict = lambda d: dict(\n", + " {k: (v if not isinstance(v, self.__class__) else dict(v)) for k, v in d.items()}\n", + " )\n", + " return f\"{self.__class__.__name__}({dict.__repr__(_dict(self))})\"" + ] + }, + { + "cell_type": "markdown", + "id": "4484b223-c22e-4537-8f90-df47ef31d086", + "metadata": {}, + "source": [ + "Patching `DeviceArray` to have less verbose reprs for 0-D arrays:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6eb58dbc-119f-4f98-a2f7-d6f1b1b7488e", + "metadata": {}, + "outputs": [], + "source": [ + "# exporti\n", + "@patch_to(DeviceArray)\n", + "def __repr__(self): # type: ignore\n", + " if self.ndim == 0 and is_float(self):\n", + " v = float(self)\n", + " return repr(round(v, 5)) if abs(v) > 1e-4 else repr(v)\n", + " elif self.ndim == 0 and is_complex_float(self):\n", + " r, i = float(self.real), float(self.imag)\n", + " r = round(r, 5) if abs(r) > 1e-4 else r\n", + " i = round(i, 5) if abs(i) > 1e-4 else i\n", + " s = repr(r + 1j * i)\n", + " if s[0] == \"(\" and s[-1] == \")\":\n", + " s = s[1:-1]\n", + " return s\n", + " else:\n", + " s = super(self.__class__, self).__repr__()\n", + " s = s.replace(\"DeviceArray(\", \" array(\")\n", + " s = re.sub(r\", dtype=.*[,)]\", \"\", s)\n", + " s = re.sub(r\" weak_type=.*[,)]\", \"\", s)\n", + " return dedent(s)+\")\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5f6c14e4-723e-4ad9-9b20-fde3d7b7f86f", + "metadata": {}, + "outputs": [], + "source": [ + "jnp.array(3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02676641-cddd-4f70-8138-b4530dd318cd", + "metadata": {}, + "outputs": [], + "source": [ + "jnp.array([3, 4, 5])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/02_utils.ipynb b/nbs/02_utils.ipynb new file mode 100644 index 0000000..dd9fde3 --- /dev/null +++ b/nbs/02_utils.ipynb @@ -0,0 +1,1184 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "afb5f4e9", + "metadata": {}, + "outputs": [], + "source": [ + "# default_exp utils" + ] + }, + { + "cell_type": "markdown", + "id": "6ecfd145", + "metadata": 
{}, + "source": [ + "# Utils\n", + "\n", + "> General SAX utilities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ab524cf", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "import matplotlib.pyplot as plt\n", + "from fastcore.test import test_eq\n", + "from pytest import approx, raises\n", + "\n", + "import os, sys; sys.stderr = open(os.devnull, \"w\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d1ed3a9", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "from __future__ import annotations\n", + "\n", + "import inspect\n", + "import re\n", + "from functools import lru_cache, partial, wraps\n", + "from typing import Any, Callable, Dict, Iterable, Iterator, Tuple, Union, cast, overload\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jax.scipy as jsp\n", + "from natsort import natsorted\n", + "from sax.typing_ import (\n", + " Array,\n", + " ComplexFloat,\n", + " Float,\n", + " Model,\n", + " ModelFactory,\n", + " SCoo,\n", + " SDense,\n", + " SDict,\n", + " Settings,\n", + " SType,\n", + " is_mixedmode,\n", + " is_model,\n", + " is_model_factory,\n", + " is_scoo,\n", + " is_sdense,\n", + " is_sdict,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23181084", + "metadata": {}, + "outputs": [], + "source": [ + "#export\n", + "def block_diag(*arrs: Array) -> Array:\n", + " \"\"\"create block diagonal matrix with arbitrary batch dimensions \"\"\"\n", + " batch_shape = arrs[0].shape[:-2]\n", + " N = 0\n", + " for arr in arrs:\n", + " if batch_shape != arr.shape[:-2]:\n", + " raise ValueError(\"batch dimensions for given arrays don't match.\")\n", + " m, n = arr.shape[-2:]\n", + " if m != n:\n", + " raise ValueError(\"given arrays are not square.\")\n", + " N += n\n", + "\n", + " block_diag = jax.vmap(jsp.linalg.block_diag, in_axes=0, out_axes=0)(\n", + " *(arr.reshape(-1, arr.shape[-2], arr.shape[-1]) for arr in arrs)\n", + " ).reshape(*batch_shape, N, N)\n", + "\n", + " return block_diag" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e9b15b3", + "metadata": {}, + "outputs": [], + "source": [ + "arr1 = 1 * jnp.ones((1, 2, 2))\n", + "arr2 = 2 * jnp.ones((1, 3, 3))\n", + "\n", + "test_eq(\n", + " block_diag(arr1, arr2),\n", + " [[[1.0, 1.0, 0.0, 0.0, 0.0],\n", + " [1.0, 1.0, 0.0, 0.0, 0.0],\n", + " [0.0, 0.0, 2.0, 2.0, 2.0],\n", + " [0.0, 0.0, 2.0, 2.0, 2.0],\n", + " [0.0, 0.0, 2.0, 2.0, 2.0]]]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d56e112d", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def clean_string(s: str) -> str:\n", + " \"\"\"clean a string such that it is a valid python identifier\"\"\"\n", + " s = s.strip()\n", + " s = s.replace(\".\", \"p\") # point\n", + " s = s.replace(\"-\", \"m\") # minus\n", + " s = re.sub(\"[^0-9a-zA-Z]\", \"_\", s)\n", + " if s[0] in \"0123456789\":\n", + " s = \"_\" + s\n", + " return s" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cfe8bf00", + "metadata": {}, + "outputs": [], + "source": [ + "assert clean_string(\"Hello, string 1.0\") == \"Hello__string_1p0\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a5f62f77", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def copy_settings(settings: Settings) -> Settings:\n", + " \"\"\"copy a parameter dictionary\"\"\"\n", + " return validate_settings(settings) # validation also copies\n", + "\n", 
+ "def validate_settings(settings: Settings) -> Settings:\n", + " \"\"\"Validate a parameter dictionary\"\"\"\n", + " _settings = {}\n", + " for k, v in settings.items():\n", + " if isinstance(v, dict):\n", + " _settings[k] = validate_settings(v)\n", + " else:\n", + " _settings[k] = try_float(v)\n", + " return _settings\n", + "\n", + "def try_float(f: Any) -> Any:\n", + " \"\"\"try converting an object to float, return unchanged object on fail\"\"\"\n", + " try:\n", + " return jnp.asarray(f, dtype=float)\n", + " except (ValueError, TypeError):\n", + " return f" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c7a72e4", + "metadata": {}, + "outputs": [], + "source": [ + "orig_settings = {\"a\": 3, \"c\": jnp.array([9.0, 10.0, 11.0])}\n", + "new_settings = copy_settings(orig_settings)\n", + "\n", + "assert orig_settings[\"a\"] == new_settings[\"a\"]\n", + "assert jnp.all(orig_settings[\"c\"] == new_settings[\"c\"])\n", + "new_settings[\"a\"] = jnp.array(5.0)\n", + "assert orig_settings[\"a\"] == 3\n", + "assert new_settings[\"a\"] == 5\n", + "assert orig_settings[\"c\"] is new_settings[\"c\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea47394a", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def flatten_dict(dic: Dict[str, Any], sep: str = \",\") -> Dict[str, Any]:\n", + " \"\"\"flatten a nested dictionary\"\"\"\n", + " return _flatten_dict(dic, sep=sep)\n", + "\n", + "\n", + "def _flatten_dict(\n", + " dic: Dict[str, Any], sep: str = \",\", frozen: bool = False, parent_key: str = \"\"\n", + ") -> Dict[str, Any]:\n", + " items = []\n", + " for k, v in dic.items():\n", + " new_key = parent_key + sep + k if parent_key else k\n", + " if isinstance(v, dict):\n", + " items.extend(\n", + " _flatten_dict(v, sep=sep, frozen=frozen, parent_key=new_key).items()\n", + " )\n", + " else:\n", + " items.append((new_key, v))\n", + "\n", + " return dict(items)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc454e7d", + "metadata": {}, + "outputs": [], + "source": [ + "nested_dict = {\n", + " \"a\": 3.0,\n", + " \"b\": {\"c\": 4.0},\n", + "}\n", + "\n", + "flat_dict = flatten_dict(nested_dict, sep=\",\")\n", + "assert flat_dict == {\"a\": 3.0, \"b,c\": 4.0}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c03047f", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def unflatten_dict(dic, sep=\",\"):\n", + " \"\"\"unflatten a flattened dictionary \"\"\"\n", + " \n", + " # from: https://gist.github.com/fmder/494aaa2dd6f8c428cede\n", + " items = dict()\n", + "\n", + " for k, v in dic.items():\n", + " keys = k.split(sep)\n", + " sub_items = items\n", + " for ki in keys[:-1]:\n", + " if ki in sub_items:\n", + " sub_items = sub_items[ki]\n", + " else:\n", + " sub_items[ki] = dict()\n", + " sub_items = sub_items[ki]\n", + "\n", + " sub_items[keys[-1]] = v\n", + "\n", + " return items" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd8b221b", + "metadata": {}, + "outputs": [], + "source": [ + "assert unflatten_dict(flat_dict, sep=\",\") == nested_dict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "101fe3dc", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def get_ports(S: Union[Model, SType]) -> Tuple[str, ...]:\n", + " \"\"\"get port names of a model or an stype\"\"\"\n", + " if is_model(S):\n", + " return _get_ports_from_model(cast(Model, S))\n", + " elif is_sdict(S):\n", + " ports_set = 
{p1 for p1, _ in S} | {p2 for _, p2 in S}\n", + " return tuple(natsorted(ports_set))\n", + " elif is_scoo(S) or is_sdense(S):\n", + " *_, ports_map = S\n", + " return tuple(natsorted(ports_map.keys()))\n", + " else:\n", + " raise ValueError(\"Could not extract ports for given S\")\n", + " \n", + "@lru_cache(maxsize=4096) # cache to prevent future tracing\n", + "def _get_ports_from_model(model: Model) -> Tuple[str, ...]:\n", + " S: SType = jax.eval_shape(model)\n", + " return get_ports(S)" + ] + }, + { + "cell_type": "markdown", + "id": "4c23ef24", + "metadata": {}, + "source": [ + "> Note: if a `Model` function is given in stead of an `SDict`, the function will be traced by JAX to obtain the ports of the resulting `SType`. Although this tracing of the function is 'cheap' in comparison to evaluating the model/circuit. It is not for free! Use this function sparingly on your large `Model` or `circuit`!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c70af810", + "metadata": {}, + "outputs": [], + "source": [ + "from sax.typing_ import scoo\n", + "scoo({(\"in0\", \"out0\"): 1.0})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc2eecbe", + "metadata": {}, + "outputs": [], + "source": [ + "def coupler(coupling=0.5):\n", + " return {\n", + " (\"in0\", \"out0\"): coupling**0.5,\n", + " (\"in0\", \"out1\"): 1j*coupling**0.5,\n", + " (\"in1\", \"out0\"): 1j*coupling**0.5,\n", + " (\"in1\", \"out1\"): coupling**0.5,\n", + " }\n", + "\n", + "model = coupler\n", + "assert get_ports(model) == ('in0', 'in1', 'out0', 'out1')\n", + "\n", + "sdict_ = coupler()\n", + "assert get_ports(sdict_) == ('in0', 'in1', 'out0', 'out1')\n", + "\n", + "from sax.typing_ import scoo\n", + "scoo_ = scoo(sdict_)\n", + "assert get_ports(scoo_) == ('in0', 'in1', 'out0', 'out1')\n", + "\n", + "from sax.typing_ import sdense\n", + "sdense_ = sdense(sdict_)\n", + "assert get_ports(sdense_) == ('in0', 'in1', 'out0', 'out1')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "99e8799d", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def get_port_combinations(S: Union[Model, SType]) -> Tuple[Tuple[str, str], ...]:\n", + " \"\"\"get port combinations of a model or an stype\"\"\"\n", + " \n", + " if is_model(S):\n", + " S = cast(Model, S)\n", + " return _get_port_combinations_from_model(S)\n", + " elif is_sdict(S):\n", + " S = cast(SDict, S)\n", + " return tuple(S.keys())\n", + " elif is_scoo(S):\n", + " Si, Sj, _, pm = cast(SCoo, S)\n", + " rpm = {int(i): str(p) for p, i in pm.items()}\n", + " return tuple(natsorted((rpm[int(i)], rpm[int(j)]) for i, j in zip(Si, Sj)))\n", + " elif is_sdense(S):\n", + " _, pm = cast(SDense, S)\n", + " return tuple(natsorted((p1, p2) for p1 in pm for p2 in pm))\n", + " else:\n", + " raise ValueError(\"Could not extract ports for given S\")\n", + " \n", + "@lru_cache(maxsize=4096) # cache to prevent future tracing\n", + "def _get_port_combinations_from_model(model: Model) -> Tuple[Tuple[str, str], ...]:\n", + " S: SType = jax.eval_shape(model)\n", + " return get_port_combinations(S)" + ] + }, + { + "cell_type": "markdown", + "id": "a43d6849", + "metadata": {}, + "source": [ + "> Note: if a `Model` function is given in stead of an `SDict`, the function will be traced by JAX to obtain the port combinations of the resulting `SType`. Although this tracing of the function is 'cheap' in comparison to evaluating the model/circuit. It is not for free! 
Use this function sparingly on your large `Model` or `circuit`!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cf00a0c0", + "metadata": {}, + "outputs": [], + "source": [ + "model = coupler\n", + "assert get_port_combinations(model) == ((\"in0\", \"out0\"), (\"in0\", \"out1\"), (\"in1\", \"out0\"), (\"in1\", \"out1\"))\n", + "\n", + "sdict_ = coupler()\n", + "assert get_port_combinations(sdict_) == ((\"in0\", \"out0\"), (\"in0\", \"out1\"), (\"in1\", \"out0\"), (\"in1\", \"out1\"))\n", + "\n", + "from sax.typing_ import scoo\n", + "scoo_ = scoo(sdict_)\n", + "assert get_port_combinations(scoo_) == ((\"in0\", \"out0\"), (\"in0\", \"out1\"), (\"in1\", \"out0\"), (\"in1\", \"out1\"))\n", + "\n", + "from sax.typing_ import sdense\n", + "sdense_ = sdense(sdict_)\n", + "assert get_port_combinations(sdense_) == ((\"in0\", \"in0\"), (\"in0\", \"in1\"), (\"in0\", \"out0\"), (\"in0\", \"out1\"), \n", + " (\"in1\", \"in0\"), (\"in1\", \"in1\"), (\"in1\", \"out0\"), (\"in1\", \"out1\"), (\"out0\", \"in0\"), (\"out0\", \"in1\"), \n", + " (\"out0\", \"out0\"), (\"out0\", \"out1\"), (\"out1\", \"in0\"), (\"out1\", \"in1\"), (\"out1\", \"out0\"), (\"out1\", \"out1\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1aff65c", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def get_settings(model: Union[Model, ModelFactory]) -> Settings:\n", + " \"\"\"Get the parameters of a SAX model function\"\"\"\n", + " \n", + " signature = inspect.signature(model)\n", + "\n", + " settings: Settings = {\n", + " k: (v.default if not isinstance(v, dict) else v)\n", + " for k, v in signature.parameters.items()\n", + " if v.default is not inspect.Parameter.empty\n", + " }\n", + "\n", + " # make sure an inplace operation of resulting dict does not change the\n", + " # circuit parameters themselves\n", + " return copy_settings(settings)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2deefdc8", + "metadata": {}, + "outputs": [], + "source": [ + "assert get_settings(coupler) == {'coupling': 0.5}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "453ecdaa", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def grouped_interp(wl: Float, wls: Float, phis: Float) -> Float:\n", + " \"\"\"Grouped phase interpolation\"\"\"\n", + " wl = cast(Array, jnp.asarray(wl))\n", + " wls = cast(Array, jnp.asarray(wls))\n", + " # make sure values between -pi and pi\n", + " phis = cast(Array, jnp.asarray(phis)) % (2 * jnp.pi)\n", + " phis = jnp.where(phis > jnp.pi, phis - 2 * jnp.pi, phis) \n", + " if not wls.ndim == 1:\n", + " raise ValueError(\"grouped_interp: wls should be a 1D array\")\n", + " if not phis.ndim == 1:\n", + " raise ValueError(\"grouped_interp: wls should be a 1D array\")\n", + " if not wls.shape == phis.shape:\n", + " raise ValueError(\"grouped_interp: wls and phis shape does not match\")\n", + " return _grouped_interp(wl.reshape(-1), wls, phis).reshape(*wl.shape)\n", + "\n", + "\n", + "@partial(jax.vmap, in_axes=(0, None, None), out_axes=0)\n", + "@jax.jit\n", + "def _grouped_interp(\n", + " wl: Array, # 0D array (not-vmapped) ; 1D array (vmapped)\n", + " wls: Array, # 1D array\n", + " phis: Array, # 1D array\n", + ") -> Array:\n", + " dphi_dwl = (phis[1::2] - phis[::2]) / (wls[1::2] - wls[::2])\n", + " phis = phis[::2]\n", + " wls = wls[::2]\n", + " dwl = (wls[1:] - wls[:-1]).mean(0, keepdims=True)\n", + "\n", + " t = (wl - wls + 1e-5 * dwl) / dwl # small offset to ensure no values are 
zero\n", + " t = jnp.where(jnp.abs(t) < 1, t, 0)\n", + " m0 = jnp.where(t > 0, size=1)[0]\n", + " m1 = jnp.where(t < 0, size=1)[0]\n", + " t = t[m0]\n", + " wl0 = wls[m0]\n", + " wl1 = wls[m1]\n", + " phi0 = phis[m0]\n", + " phi1 = phis[m1]\n", + " dphi_dwl0 = dphi_dwl[m0]\n", + " dphi_dwl1 = dphi_dwl[m1]\n", + " _phi0 = phi0 - 0.5 * (wl1 - wl0) * (\n", + " dphi_dwl0 * (t ** 2 - 2 * t) - dphi_dwl1 * t ** 2\n", + " )\n", + " _phi1 = phi1 - 0.5 * (wl1 - wl0) * (\n", + " dphi_dwl0 * (t - 1) ** 2 - dphi_dwl1 * (t ** 2 - 1)\n", + " )\n", + " phis = jnp.arctan2(\n", + " (1 - t) * jnp.sin(_phi0) + t * jnp.sin(_phi1),\n", + " (1 - t) * jnp.cos(_phi0) + t * jnp.cos(_phi1),\n", + " )\n", + " return phis" + ] + }, + { + "cell_type": "markdown", + "id": "79a7f1a8", + "metadata": {}, + "source": [ + "Grouped interpolation is useful to interpolate phase values where each datapoint is doubled (very close together) to give an indication of the phase variation at that point.\n", + "\n", + "> Note: this interpolation is only accurate in the range `[wls[0], wls[-2])` (`wls[-2]` not included). Any extrapolation outside these bounds can yield unexpected results!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e06bfe4", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "\n", + "wls = jnp.array([2.19999, 2.20001, 2.22499, 2.22501, 2.24999, 2.25001, 2.27499, 2.27501, 2.29999, 2.30001, 2.32499, 2.32501, 2.34999, 2.35001, 2.37499, 2.37501, 2.39999, 2.40001, 2.42499, 2.42501, 2.44999, 2.45001])\n", + "phis = jnp.array([5.17317336, 5.1219654, 4.71259842, 4.66252492, 5.65699608, 5.60817922, 2.03697377, 1.98936119, 6.010146, 5.96358061, 4.96336733, 4.91777933, 5.13912198, 5.09451137, 0.22347545, 0.17979684, 2.74501894, 2.70224092, 0.10403192, 0.06214664, 4.83328794, 4.79225525])\n", + "wl = jnp.array([2.21, 2.27, 1.31, 2.424])\n", + "phi = jnp.array(grouped_interp(wl, wls, phis))\n", + "phi_ref = jnp.array([-1.4901831, 1.3595749, -1.110012 , 2.1775336])\n", + "\n", + "assert ((phi-phi_ref)**2 < 1e-5).all()" + ] + }, + { + "cell_type": "markdown", + "id": "ad7ec9d1", + "metadata": {}, + "source": [ + "> Example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "03b7a938", + "metadata": {}, + "outputs": [], + "source": [ + "wl = jnp.linspace(wls.min(), wls.max(), 10000)\n", + "\n", + "_, ax = plt.subplots(2, 1, sharex=True, figsize=(14, 6))\n", + "plt.sca(ax[0])\n", + "plt.plot(1e3*wls, jnp.arange(wls.shape[0]), marker=\"o\", ls=\"none\")\n", + "plt.grid(True)\n", + "plt.ylabel(\"index\")\n", + "plt.sca(ax[1])\n", + "plt.grid(True)\n", + "plt.plot(1e3*wls, phis, marker=\"o\", c=\"C1\")\n", + "plt.plot(1e3*wl, grouped_interp(wl, wls, phis), c=\"C2\")\n", + "plt.xlabel(\"λ [nm]\")\n", + "plt.ylabel(\"φ\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bd2e89a1", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def merge_dicts(*dicts: Dict) -> Dict:\n", + " \"\"\"merge (possibly deeply nested) dictionaries\"\"\"\n", + " if len(dicts) == 1:\n", + " return dict(_generate_merged_dict(dicts[0], {}))\n", + " elif len(dicts) == 2:\n", + " return dict(_generate_merged_dict(dicts[0], dicts[1]))\n", + " else:\n", + " return merge_dicts(dicts[0], merge_dicts(*dicts[1:]))\n", + " \n", + "\n", + "def _generate_merged_dict(dict1: Dict, dict2: Dict) -> Iterator[Tuple[Any, Any]]:\n", + " # inspired by https://stackoverflow.com/questions/7204805/how-to-merge-dictionaries-of-dictionaries\n", + " keys = {**{k: None 
for k in dict1}, **{k: None for k in dict2}} # keep key order, values irrelevant\n", + " for k in keys:\n", + " if k in dict1 and k in dict2:\n", + " v1, v2 = dict1[k], dict2[k]\n", + " if isinstance(v1, dict) and isinstance(v2, dict):\n", + " v = dict(_generate_merged_dict(v1, v2))\n", + " else:\n", + " # If one of the values is not a dict, you can't continue merging it.\n", + " # Value from second dict overrides one in first and we move on.\n", + " v = v2\n", + " elif k in dict1:\n", + " v = dict1[k]\n", + " else: # k in dict2:\n", + " v = dict2[k]\n", + " \n", + " if isinstance(v, dict):\n", + " yield (k, {**v}) # shallow copy of dict\n", + " else:\n", + " yield (k, v)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27cf74d7", + "metadata": {}, + "outputs": [], + "source": [ + "d = merge_dicts({\"a\": 3}, {\"b\": 4})\n", + "assert d[\"a\"] == 3\n", + "assert d[\"b\"] == 4\n", + "assert tuple(sorted(d)) == (\"a\", \"b\")\n", + "\n", + "d = merge_dicts({\"a\": 3}, {\"a\": 4})\n", + "assert d[\"a\"] == 4\n", + "assert tuple(d) == (\"a\",)\n", + "\n", + "d = merge_dicts({\"a\": 3}, {\"a\": {\"b\": 5}})\n", + "assert d[\"a\"][\"b\"] == 5\n", + "assert tuple(d) == (\"a\",)\n", + "\n", + "d = merge_dicts({\"a\": {\"b\": 5}}, {\"a\": 3})\n", + "assert d[\"a\"] == 3\n", + "assert tuple(d) == (\"a\",)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "604b6b59", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def mode_combinations(\n", + " modes: Iterable[str], cross: bool = False\n", + ") -> Tuple[Tuple[str, str], ...]:\n", + " \"\"\"create mode combinations for a collection of given modes\"\"\"\n", + " if cross:\n", + " mode_combinations = natsorted((m1, m2) for m1 in modes for m2 in modes)\n", + " else:\n", + " mode_combinations = natsorted((m, m) for m in modes)\n", + " return tuple(mode_combinations)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5010a982", + "metadata": {}, + "outputs": [], + "source": [ + "assert mode_combinations(modes=[\"te\", \"tm\"]) == (('te', 'te'), ('tm', 'tm'))\n", + "assert mode_combinations(modes=[\"te\", \"tm\"], cross=True) == (('te', 'te'), ('te', 'tm'), ('tm', 'te'), ('tm', 'tm'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9c11225", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def reciprocal(sdict: SDict) -> SDict:\n", + " \"\"\"Make an SDict reciprocal\"\"\"\n", + " if is_sdict(sdict):\n", + " return {\n", + " **{(p1, p2): v for (p1, p2), v in sdict.items()},\n", + " **{(p2, p1): v for (p1, p2), v in sdict.items()},\n", + " }\n", + " else:\n", + " raise ValueError(\"sax.reciprocal is only valid for SDict types\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30a304a7", + "metadata": {}, + "outputs": [], + "source": [ + "sdict_ = {(\"in0\", \"out0\"): 1.0}\n", + "assert reciprocal(sdict_) == {(\"in0\", \"out0\"): 1.0, (\"out0\", \"in0\"): 1.0}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d5da94d", + "metadata": {}, + "outputs": [], + "source": [ + "# exporti\n", + "\n", + "@overload\n", + "def rename_params(model: ModelFactory, renamings: Dict[str, str]) -> ModelFactory:\n", + " ...\n", + "\n", + "\n", + "@overload\n", + "def rename_params(model: Model, renamings: Dict[str, str]) -> Model:\n", + " ..." 
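To see why `mode_combinations` is useful, here is a small sketch (assuming this notebook's definitions are available) that expands a single-mode reciprocal `SDict` into per-mode keys; this is essentially what the multimode utilities in a later notebook do for you:

```
single_mode = reciprocal({("in0", "out0"): 0.9})

# duplicate every connection for each (mode, mode) pair
multi_mode = {
    (f"{p1}@{m1}", f"{p2}@{m2}"): value
    for (p1, p2), value in single_mode.items()
    for m1, m2 in mode_combinations(("te", "tm"))
}
assert multi_mode == {
    ("in0@te", "out0@te"): 0.9,
    ("in0@tm", "out0@tm"): 0.9,
    ("out0@te", "in0@te"): 0.9,
    ("out0@tm", "in0@tm"): 0.9,
}
```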
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d719c153", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def rename_params(\n", + " model: Union[Model, ModelFactory], renamings: Dict[str, str]\n", + ") -> Union[Model, ModelFactory]:\n", + " \"\"\"rename the parameters of a `Model` or `ModelFactory` given a renamings mapping old parameter names to new.\"\"\"\n", + " \n", + " reversed_renamings = {v: k for k, v in renamings.items()}\n", + " if len(reversed_renamings) < len(renamings):\n", + " raise ValueError(\"Multiple old names point to the same new name!\")\n", + "\n", + " if is_model_factory(model):\n", + " old_model_factory = cast(ModelFactory, model)\n", + " old_settings = get_settings(model)\n", + "\n", + " @wraps(old_model_factory)\n", + " def new_model_factory(**settings):\n", + " old_settings = {\n", + " reversed_renamings.get(k, k): v for k, v in settings.items()\n", + " }\n", + " model = old_model_factory(**old_settings)\n", + " return rename_params(model, renamings)\n", + "\n", + " new_settings = {renamings.get(k, k): v for k, v in old_settings.items()}\n", + " _replace_kwargs(new_model_factory, **new_settings)\n", + "\n", + " return new_model_factory\n", + "\n", + " elif is_model(model):\n", + " old_model = cast(Model, model)\n", + " old_settings = get_settings(model)\n", + "\n", + " @wraps(old_model)\n", + " def new_model(**settings):\n", + " old_settings = {\n", + " reversed_renamings.get(k, k): v for k, v in settings.items()\n", + " }\n", + " return old_model(**old_settings)\n", + "\n", + " new_settings = {renamings.get(k, k): v for k, v in old_settings.items()}\n", + " _replace_kwargs(new_model, **new_settings)\n", + "\n", + " return new_model\n", + "\n", + " else:\n", + " raise ValueError(\n", + " \"rename_params should be used to decorate a Model or ModelFactory.\"\n", + " )\n", + " \n", + "def _replace_kwargs(func: Callable, **kwargs: ComplexFloat):\n", + " \"\"\"Change the kwargs signature of a function\"\"\"\n", + " sig = inspect.signature(func)\n", + " settings = [\n", + " inspect.Parameter(k, inspect.Parameter.KEYWORD_ONLY, default=v)\n", + " for k, v in kwargs.items()\n", + " ]\n", + " func.__signature__ = sig.replace(parameters=settings)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b041b23", + "metadata": {}, + "outputs": [], + "source": [ + "def model(x=jnp.array(3.0), y=jnp.array(4.0), z=jnp.array([3.0, 4.0])) -> SDict:\n", + " return {(\"in0\", \"out0\"): jnp.array(3.0)}\n", + "\n", + "renamings = {\"x\": \"a\", \"y\": \"z\", \"z\": \"y\"}\n", + "new_model = rename_params(model, renamings)\n", + "settings = get_settings(new_model)\n", + "assert settings[\"a\"] == 3.0\n", + "assert settings[\"z\"] == 4.0\n", + "assert jnp.all(settings[\"y\"] == jnp.array([3.0, 4.0]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00885806", + "metadata": {}, + "outputs": [], + "source": [ + "# exporti\n", + "\n", + "@overload\n", + "def rename_ports(S: SDict, renamings: Dict[str, str]) -> SDict:\n", + " ...\n", + "\n", + "\n", + "@overload\n", + "def rename_ports(S: SCoo, renamings: Dict[str, str]) -> SCoo:\n", + " ...\n", + "\n", + "\n", + "@overload\n", + "def rename_ports(S: SDense, renamings: Dict[str, str]) -> SDense:\n", + " ...\n", + "\n", + "\n", + "@overload\n", + "def rename_ports(S: Model, renamings: Dict[str, str]) -> Model:\n", + " ...\n", + "\n", + "\n", + "@overload\n", + "def rename_ports(S: ModelFactory, renamings: Dict[str, str]) -> ModelFactory:\n", + " ..." 
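One edge case guarded against in the implementation above, but not exercised by the test cell, is two old parameter names mapping onto the same new name. A minimal sketch, reusing the `model` defined in the previous test cell (`raises` is imported in the hidden cell at the top of this notebook):

```
with raises(ValueError):
    # both 'x' and 'y' would collide on the new name 'a'
    rename_params(model, {"x": "a", "y": "a"})
```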
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fd79786", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def rename_ports(\n", + " S: Union[SType, Model, ModelFactory], renamings: Dict[str, str]\n", + ") -> Union[SType, Model, ModelFactory]:\n", + " \"\"\"rename the ports of an `SDict`, `Model` or `ModelFactory` given a renamings mapping old port names to new.\"\"\"\n", + " if is_scoo(S):\n", + " Si, Sj, Sx, ports_map = cast(SCoo, S)\n", + " ports_map = {renamings[p]: i for p, i in ports_map.items()}\n", + " return Si, Sj, Sx, ports_map\n", + " elif is_sdense(S):\n", + " Sx, ports_map = cast(SDense, S)\n", + " ports_map = {renamings[p]: i for p, i in ports_map.items()}\n", + " return Sx, ports_map\n", + " elif is_sdict(S):\n", + " sdict = cast(SDict, S)\n", + " original_ports = get_ports(sdict)\n", + " assert len(renamings) == len(original_ports)\n", + " return {(renamings[p1], renamings[p2]): v for (p1, p2), v in sdict.items()}\n", + " elif is_model(S):\n", + " old_model = cast(Model, S)\n", + "\n", + " @wraps(old_model)\n", + " def new_model(**settings) -> SType:\n", + " return rename_ports(old_model(**settings), renamings)\n", + "\n", + " return new_model\n", + " elif is_model_factory(S):\n", + " old_model_factory = cast(ModelFactory, S)\n", + "\n", + " @wraps(old_model_factory)\n", + " def new_model_factory(**settings) -> Callable[..., SType]:\n", + " return rename_ports(old_model_factory(**settings), renamings)\n", + "\n", + " return new_model_factory\n", + " else:\n", + " raise ValueError(\"Cannot rename ports for type {type(S)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd9c42ac", + "metadata": {}, + "outputs": [], + "source": [ + "d = reciprocal({(\"p0\", \"p1\"): 0.1, (\"p1\", \"p2\"): 0.2})\n", + "origports = get_ports(d)\n", + "renamings = {\"p0\": \"in0\", \"p1\": \"out0\", \"p2\": \"in1\"}\n", + "d_ = rename_ports(d, renamings)\n", + "assert tuple(sorted(get_ports(d_))) == tuple(sorted(renamings[p] for p in origports))\n", + "d_ = rename_ports(scoo(d), renamings)\n", + "assert tuple(sorted(get_ports(d_))) == tuple(sorted(renamings[p] for p in origports))\n", + "d_ = rename_ports(sdense(d), renamings)\n", + "assert tuple(sorted(get_ports(d_))) == tuple(sorted(renamings[p] for p in origports))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6e0f4e87", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def update_settings(\n", + " settings: Settings, *compnames: str, **kwargs: ComplexFloat\n", + ") -> Settings:\n", + " \"\"\"update a nested settings dictionary\"\"\"\n", + " _settings = {}\n", + " if not compnames:\n", + " for k, v in settings.items():\n", + " if isinstance(v, dict):\n", + " _settings[k] = update_settings(v, **kwargs)\n", + " else:\n", + " if k in kwargs:\n", + " _settings[k] = try_float(kwargs[k])\n", + " else:\n", + " _settings[k] = try_float(v)\n", + " else:\n", + " for k, v in settings.items():\n", + " if isinstance(v, dict):\n", + " if k == compnames[0]:\n", + " _settings[k] = update_settings(v, *compnames[1:], **kwargs)\n", + " else:\n", + " _settings[k] = v\n", + " else:\n", + " _settings[k] = try_float(v)\n", + " return _settings" + ] + }, + { + "cell_type": "markdown", + "id": "78e2388a", + "metadata": {}, + "source": [ + "> Note: (1) Even though it's possible to update parameter dictionaries in place, this function is convenient to apply certain parameters (e.g. wavelength 'wl' or temperature 'T') globally. 
(2) This operation never updates the given settings dictionary inplace. (3) Any non-float keyword arguments will be silently ignored." + ] + }, + { + "cell_type": "markdown", + "id": "a7642618", + "metadata": {}, + "source": [ + "Assuming you have a settings dictionary for a `circuit` containing a directional coupler `\"dc\"` and a waveguide `\"wg\"`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85293a23", + "metadata": {}, + "outputs": [], + "source": [ + "settings = {\"wl\": 1.55, \"dc\": {\"coupling\": 0.5}, \"wg\": {\"wl\": 1.56, \"neff\": 2.33}}" + ] + }, + { + "cell_type": "markdown", + "id": "68c1cd3b", + "metadata": {}, + "source": [ + "You can update this settings dictionary with some global settings as follows. When updating settings globally like this, each subdictionary of the settings dictionary will be updated with these values (if the key exists in the subdictionary):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a1aaf3c", + "metadata": {}, + "outputs": [], + "source": [ + "settings = update_settings(settings, wl=1.3, coupling=0.3, neff=3.0)\n", + "assert settings == {\"wl\": 1.3, \"dc\": {\"coupling\": 0.3}, \"wg\": {\"wl\": 1.3, \"neff\": 3.0}}" + ] + }, + { + "cell_type": "markdown", + "id": "73737b1f", + "metadata": {}, + "source": [ + "Alternatively, you can set certain settings for a specific component (e.g. 'wg' in this case) as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87c32103", + "metadata": {}, + "outputs": [], + "source": [ + "settings = update_settings(settings, \"wg\", wl=2.0)\n", + "assert settings == {\"wl\": 1.3, \"dc\": {\"coupling\": 0.3}, \"wg\": {\"wl\": 2.0, \"neff\": 3.0}}" + ] + }, + { + "cell_type": "markdown", + "id": "47d3b775", + "metadata": {}, + "source": [ + "note that only the `\"wl\"` belonging to `\"wg\"` has changed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb755065", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def validate_not_mixedmode(S: SType):\n", + " \"\"\"validate that an stype is not 'mixed mode' (i.e. invalid)\n", + "\n", + " Args:\n", + " S: the stype to validate\n", + " \"\"\"\n", + "\n", + " if is_mixedmode(S): # mixed mode\n", + " raise ValueError(\n", + " \"Given SType is neither multimode or singlemode. 
Please check the port \"\n", + " \"names: they should either ALL contain the '@' separator (multimode) \"\n", + " \"or NONE should contain the '@' separator (singlemode).\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b47e00d4", + "metadata": {}, + "outputs": [], + "source": [ + "sdict = {(\"in0\", \"out0\"): 1.0, (\"out0\", \"in0\"): 1.0}\n", + "validate_not_mixedmode(sdict)\n", + "\n", + "sdict = {(\"in0@te\", \"out0@te\"): 1.0, (\"out0@tm\", \"in0@tm\"): 1.0}\n", + "validate_not_mixedmode(sdict)\n", + "\n", + "sdict = {(\"in0@te\", \"out0@te\"): 1.0, (\"out0\", \"in0@tm\"): 1.0}\n", + "with raises(ValueError):\n", + " validate_not_mixedmode(sdict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "af966dfb", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def validate_multimode(S: SType, modes=(\"te\", \"tm\")) -> None:\n", + " \"\"\"validate that an stype is multimode and that the given modes are present.\"\"\"\n", + " try:\n", + " current_modes = set(p.split(\"@\")[1] for p in get_ports(S))\n", + " except IndexError:\n", + " raise ValueError(\"The given stype is not multimode.\")\n", + " for mode in modes:\n", + " if mode not in current_modes:\n", + " raise ValueError(\n", + " f\"Could not find mode '{mode}' in one of the multimode models.\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41fb0142", + "metadata": {}, + "outputs": [], + "source": [ + "sdict = {(\"in0\", \"out0\"): 1.0, (\"out0\", \"in0\"): 1.0}\n", + "with raises(ValueError):\n", + " validate_multimode(sdict)\n", + "\n", + "sdict = {(\"in0@te\", \"out0@te\"): 1.0, (\"out0@tm\", \"in0@tm\"): 1.0}\n", + "validate_multimode(sdict)\n", + "\n", + "sdict = {(\"in0@te\", \"out0@te\"): 1.0, (\"out0\", \"in0@tm\"): 1.0}\n", + "with raises(ValueError):\n", + " validate_multimode(sdict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f51db101", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def validate_sdict(sdict: Any) -> None:\n", + " \"\"\"Validate an `SDict`\"\"\"\n", + " \n", + " if not isinstance(sdict, dict):\n", + " raise ValueError(\"An SDict should be a dictionary.\")\n", + " for ports in sdict:\n", + " if not isinstance(ports, tuple) and not len(ports) == 2:\n", + " raise ValueError(f\"SDict keys should be length-2 tuples. Got {ports}\")\n", + " p1, p2 = ports\n", + " if not isinstance(p1, str) or not isinstance(p2, str):\n", + " raise ValueError(\n", + " f\"SDict ports should be strings. 
Got {ports} \"\n", + " f\"({type(ports[0])}, {type(ports[1])})\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe25c362", + "metadata": {}, + "outputs": [], + "source": [ + "good_sdict = reciprocal({(\"p0\", \"p1\"): 0.1, \n", + " (\"p1\", \"p2\"): 0.2})\n", + "assert validate_sdict(good_sdict) is None\n", + "\n", + "bad_sdict = {\n", + " \"p0,p1\": 0.1,\n", + " (\"p1\", \"p2\"): 0.2,\n", + "}\n", + "with raises(ValueError):\n", + " validate_sdict(bad_sdict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd24b099", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def get_inputs_outputs(ports: Tuple[str, ...]):\n", + " inputs = tuple(p for p in ports if p.lower().startswith(\"in\"))\n", + " outputs = tuple(p for p in ports if not p.lower().startswith(\"in\"))\n", + " if not inputs:\n", + " inputs = tuple(p for p in ports if not p.lower().startswith(\"out\"))\n", + " outputs = tuple(p for p in ports if p.lower().startswith(\"out\"))\n", + " return inputs, outputs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b9ba032d", + "metadata": {}, + "outputs": [], + "source": [ + "assert get_inputs_outputs([\"in0\", \"out0\"]) == (('in0',), ('out0',))\n", + "assert get_inputs_outputs([\"in0\", \"in1\"]) == (('in0', 'in1'), ())\n", + "assert get_inputs_outputs([\"out0\", \"out1\"]) == ((), ('out0', 'out1'))\n", + "assert get_inputs_outputs([\"out0\", \"dc0\"]) == (('dc0',), ('out0',))\n", + "assert get_inputs_outputs([\"dc0\", \"in0\"]) == (('in0',), ('dc0',))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/03_caching.ipynb b/nbs/03_caching.ipynb new file mode 100644 index 0000000..1315818 --- /dev/null +++ b/nbs/03_caching.ipynb @@ -0,0 +1,123 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "10028cb4", + "metadata": {}, + "outputs": [], + "source": [ + "# default_exp caching" + ] + }, + { + "cell_type": "markdown", + "id": "faeaa526", + "metadata": {}, + "source": [ + "# Caching\n", + "\n", + "> SAX Caching" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5c17f54", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "import matplotlib.pyplot as plt\n", + "from fastcore.test import test_eq\n", + "from pytest import approx, raises\n", + "\n", + "import os, sys; sys.stderr = open(os.devnull, \"w\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a22a0cd6", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "from __future__ import annotations\n", + "\n", + "import gc\n", + "from functools import _lru_cache_wrapper, lru_cache, partial, wraps\n", + "from typing import Callable, Optional" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a703105", + "metadata": {}, + "outputs": [], + "source": [ + "#exporti\n", + "\n", + "_cached_functions = []" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ab8a8d5", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def cache(func: Optional[Callable] = None, /, *, maxsize: Optional[int] = None) -> Callable:\n", + " \"\"\"cache a function\"\"\"\n", + " if func is None:\n", + " return partial(cache, maxsize=maxsize)\n", + "\n", + " cached_func = lru_cache(maxsize=maxsize)(func)\n", + "\n", + " @wraps(func)\n", + " def new_func(*args, **kwargs):\n", 
+ " return cached_func(*args, **kwargs)\n", + "\n", + " new_func.cache_clear = cached_func.cache_clear\n", + "\n", + " _cached_functions.append(new_func)\n", + "\n", + " return new_func" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b81ee26", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def cache_clear(*, force: bool=False):\n", + " \"\"\"clear all function caches\"\"\"\n", + " if not force:\n", + " for func in _cached_functions:\n", + " func.cache_clear()\n", + " else:\n", + " gc.collect()\n", + " funcs = [a for a in gc.get_objects() if isinstance(a, _lru_cache_wrapper)]\n", + "\n", + " for func in funcs:\n", + " func.cache_clear()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/04_multimode.ipynb b/nbs/04_multimode.ipynb new file mode 100644 index 0000000..9d95b65 --- /dev/null +++ b/nbs/04_multimode.ipynb @@ -0,0 +1,398 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "1a5b7bd8", + "metadata": {}, + "outputs": [], + "source": [ + "# default_exp multimode" + ] + }, + { + "cell_type": "markdown", + "id": "60e8c162", + "metadata": {}, + "source": [ + "# Multimode\n", + "\n", + "> SAX Multimode utils" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d733ad45", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "import matplotlib.pyplot as plt\n", + "from fastcore.test import test_eq\n", + "from pytest import approx, raises\n", + "\n", + "import os, sys; sys.stderr = open(os.devnull, \"w\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1471ce2a", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "from __future__ import annotations\n", + "\n", + "from functools import wraps\n", + "from typing import Dict, Tuple, Union, cast, overload\n", + "\n", + "import jax.numpy as jnp\n", + "from sax.typing_ import (\n", + " Model,\n", + " SCoo,\n", + " SDense,\n", + " SDict,\n", + " SType,\n", + " is_model,\n", + " is_multimode,\n", + " is_scoo,\n", + " is_sdense,\n", + " is_sdict,\n", + " is_singlemode,\n", + ")\n", + "from sax.utils import (\n", + " block_diag,\n", + " mode_combinations,\n", + " validate_multimode,\n", + " validate_not_mixedmode,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ed2f214", + "metadata": {}, + "outputs": [], + "source": [ + "#exporti\n", + "\n", + "@overload\n", + "def multimode(S: Model, modes: Tuple[str, ...] = (\"te\", \"tm\")) -> Model:\n", + " ...\n", + "\n", + "\n", + "@overload\n", + "def multimode(S: SDict, modes: Tuple[str, ...] = (\"te\", \"tm\")) -> SDict:\n", + " ...\n", + "\n", + "\n", + "@overload\n", + "def multimode(S: SCoo, modes: Tuple[str, ...] = (\"te\", \"tm\")) -> SCoo:\n", + " ...\n", + "\n", + "\n", + "@overload\n", + "def multimode(S: SDense, modes: Tuple[str, ...] = (\"te\", \"tm\")) -> SDense:\n", + " ..." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16b104cf", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def multimode(\n", + " S: Union[SType, Model], modes: Tuple[str, ...] 
= (\"te\", \"tm\")\n", + ") -> Union[SType, Model]:\n", + " \"\"\"Convert a single mode model to a multimode model\"\"\"\n", + " if is_model(S):\n", + " model = cast(Model, S)\n", + "\n", + " @wraps(model)\n", + " def new_model(**params):\n", + " return multimode(model(**params), modes=modes)\n", + "\n", + " return cast(Model, new_model)\n", + "\n", + " S = cast(SType, S)\n", + "\n", + " validate_not_mixedmode(S)\n", + " if is_multimode(S):\n", + " validate_multimode(S, modes=modes)\n", + " return S\n", + "\n", + " if is_sdict(S):\n", + " return _multimode_sdict(cast(SDict, S), modes=modes)\n", + " elif is_scoo(S):\n", + " return _multimode_scoo(cast(SCoo, S), modes=modes)\n", + " elif is_sdense(S):\n", + " return _multimode_sdense(cast(SDense, S), modes=modes)\n", + " else:\n", + " raise ValueError(\"cannot convert to multimode. Unknown stype.\")\n", + "\n", + "\n", + "def _multimode_sdict(sdict: SDict, modes: Tuple[str, ...] = (\"te\", \"tm\")) -> SDict:\n", + " multimode_sdict = {}\n", + " _mode_combinations = mode_combinations(modes)\n", + " for (p1, p2), value in sdict.items():\n", + " for (m1, m2) in _mode_combinations:\n", + " multimode_sdict[f\"{p1}@{m1}\", f\"{p2}@{m2}\"] = value\n", + " return multimode_sdict\n", + "\n", + "\n", + "def _multimode_scoo(scoo: SCoo, modes: Tuple[str, ...] = (\"te\", \"tm\")) -> SCoo:\n", + "\n", + " Si, Sj, Sx, port_map = scoo\n", + " num_ports = len(port_map)\n", + " mode_map = (\n", + " {mode: i for i, mode in enumerate(modes)}\n", + " if not isinstance(modes, dict)\n", + " else cast(Dict, modes)\n", + " )\n", + "\n", + " _mode_combinations = mode_combinations(modes)\n", + "\n", + " Si_m = jnp.concatenate(\n", + " [Si + mode_map[m] * num_ports for m, _ in _mode_combinations], -1\n", + " )\n", + " Sj_m = jnp.concatenate(\n", + " [Sj + mode_map[m] * num_ports for _, m in _mode_combinations], -1\n", + " )\n", + " Sx_m = jnp.concatenate([Sx for _ in _mode_combinations], -1)\n", + " port_map_m = {\n", + " f\"{port}@{mode}\": idx + mode_map[mode] * num_ports\n", + " for mode in modes\n", + " for port, idx in port_map.items()\n", + " }\n", + "\n", + " return Si_m, Sj_m, Sx_m, port_map_m\n", + "\n", + "\n", + "def _multimode_sdense(sdense, modes=(\"te\", \"tm\")):\n", + "\n", + " Sx, port_map = sdense\n", + " num_ports = len(port_map)\n", + " mode_map = (\n", + " {mode: i for i, mode in enumerate(modes)}\n", + " if not isinstance(modes, dict)\n", + " else modes\n", + " )\n", + "\n", + " Sx_m = block_diag(*(Sx for _ in modes))\n", + "\n", + " port_map_m = {\n", + " f\"{port}@{mode}\": idx + mode_map[mode] * num_ports\n", + " for mode in modes\n", + " for port, idx in port_map.items()\n", + " }\n", + "\n", + " return Sx_m, port_map_m" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e4547870", + "metadata": {}, + "outputs": [], + "source": [ + "sdict_s = {(\"in0\", \"out0\"): 1.0}\n", + "sdict_m = multimode(sdict_s)\n", + "assert sdict_m == {(\"in0@te\", \"out0@te\"): 1.0, (\"in0@tm\", \"out0@tm\"): 1.0}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "95dab5a3", + "metadata": {}, + "outputs": [], + "source": [ + "from sax.typing_ import scoo\n", + "scoo_s = scoo(sdict_s)\n", + "scoo_m = multimode(scoo_s)\n", + "test_eq(scoo_m[0], jnp.array([0, 2], dtype=int))\n", + "test_eq(scoo_m[1], jnp.array([1, 3], dtype=int))\n", + "test_eq(scoo_m[2], jnp.array([1.0, 1.0], dtype=float))\n", + "test_eq(scoo_m[3], {\"in0@te\": 0, \"out0@te\": 1, \"in0@tm\": 2, \"out0@tm\": 3})" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "3e5474e1", + "metadata": {}, + "outputs": [], + "source": [ + "from sax.typing_ import sdense\n", + "sdense_s = sdense(sdict_s)\n", + "sdense_m = multimode(sdense_s)\n", + "test_eq(\n", + " sdense_m[0],\n", + " [[0.0 + 0.0j, 1.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],\n", + " [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],\n", + " [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 1.0 + 0.0j],\n", + " [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j]],\n", + ")\n", + "test_eq(sdense_m[1], {\"in0@te\": 0, \"out0@te\": 1, \"in0@tm\": 2, \"out0@tm\": 3})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4ec42c7", + "metadata": {}, + "outputs": [], + "source": [ + "# exporti\n", + "\n", + "@overload\n", + "def singlemode(S: Model, mode: str = \"te\") -> Model:\n", + " ...\n", + "\n", + "\n", + "@overload\n", + "def singlemode(S: SDict, mode: str = \"te\") -> SDict:\n", + " ...\n", + "\n", + "\n", + "@overload\n", + "def singlemode(S: SCoo, mode: str = \"te\") -> SCoo:\n", + " ...\n", + "\n", + "\n", + "@overload\n", + "def singlemode(S: SDense, mode: str = \"te\") -> SDense:\n", + " ..." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "960cfc31", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def singlemode(S: Union[SType, Model], mode: str = \"te\") -> Union[SType, Model]:\n", + " \"\"\"Convert multimode model to a singlemode model\"\"\"\n", + " if is_model(S):\n", + " model = cast(Model, S)\n", + "\n", + " @wraps(model)\n", + " def new_model(**params):\n", + " return singlemode(model(**params), mode=mode)\n", + "\n", + " return cast(Model, new_model)\n", + "\n", + " S = cast(SType, S)\n", + "\n", + " validate_not_mixedmode(S)\n", + " if is_singlemode(S):\n", + " return S\n", + " if is_sdict(S):\n", + " return _singlemode_sdict(cast(SDict, S), mode=mode)\n", + " elif is_scoo(S):\n", + " return _singlemode_scoo(cast(SCoo, S), mode=mode)\n", + " elif is_sdense(S):\n", + " return _singlemode_sdense(cast(SDense, S), mode=mode)\n", + " else:\n", + " raise ValueError(\"cannot convert to multimode. 
Unknown stype.\")\n", + "\n", + "\n", + "def _singlemode_sdict(sdict: SDict, mode: str = \"te\") -> SDict:\n", + " singlemode_sdict = {}\n", + " for (p1, p2), value in sdict.items():\n", + " if p1.endswith(f\"@{mode}\") and p2.endswith(f\"@{mode}\"):\n", + " p1, _ = p1.split(\"@\")\n", + " p2, _ = p2.split(\"@\")\n", + " singlemode_sdict[p1, p2] = value\n", + " return singlemode_sdict\n", + "\n", + "\n", + "def _singlemode_scoo(scoo: SCoo, mode: str = \"te\") -> SCoo:\n", + " Si, Sj, Sx, port_map = scoo\n", + " # no need to touch the data...\n", + " # just removing some ports from the port map should be enough\n", + " port_map = {\n", + " port.split(\"@\")[0]: idx\n", + " for port, idx in port_map.items()\n", + " if port.endswith(f\"@{mode}\")\n", + " }\n", + " return Si, Sj, Sx, port_map\n", + "\n", + "\n", + "def _singlemode_sdense(sdense: SDense, mode: str = \"te\") -> SDense:\n", + " Sx, port_map = sdense\n", + " # no need to touch the data...\n", + " # just removing some ports from the port map should be enough\n", + " port_map = {\n", + " port.split(\"@\")[0]: idx\n", + " for port, idx in port_map.items()\n", + " if port.endswith(f\"@{mode}\")\n", + " }\n", + " return Sx, port_map" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2215f7c9", + "metadata": {}, + "outputs": [], + "source": [ + "sdict_s = singlemode(sdict_m)\n", + "assert sdict_s == {(\"in0\", \"out0\"): 1.0}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3f57f37b", + "metadata": {}, + "outputs": [], + "source": [ + "scoo_s = singlemode(scoo_s)\n", + "test_eq(scoo_s[0], jnp.array([0], dtype=int))\n", + "test_eq(scoo_s[1], jnp.array([1], dtype=int))\n", + "test_eq(scoo_s[2], jnp.array([1.0], dtype=float))\n", + "test_eq(scoo_s[3], {'in0': 0, 'out0': 1})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "94bc44fb", + "metadata": {}, + "outputs": [], + "source": [ + "sdense_s = singlemode(sdense_m)\n", + "test_eq(\n", + " sdense_s[0],\n", + " [[0.0 + 0.0j, 1.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],\n", + " [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],\n", + " [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 1.0 + 0.0j],\n", + " [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j]],\n", + ")\n", + "test_eq(sdense_s[1], {'in0': 0, 'out0': 1})" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/05_models.ipynb b/nbs/05_models.ipynb new file mode 100644 index 0000000..3bafac7 --- /dev/null +++ b/nbs/05_models.ipynb @@ -0,0 +1,672 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "8c16afd7", + "metadata": {}, + "outputs": [], + "source": [ + "# default_exp models" + ] + }, + { + "cell_type": "markdown", + "id": "986cb864", + "metadata": {}, + "source": [ + "# Models\n", + "\n", + "> Default SAX Models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00e62ae1", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "import matplotlib.pyplot as plt\n", + "from fastcore.test import test_eq\n", + "from pytest import approx, raises\n", + "\n", + "import os, sys; sys.stderr = open(os.devnull, \"w\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b9be8b2", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "from __future__ import annotations\n", + "\n", + "from typing import Optional, Tuple\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", 
+ "import sax\n", + "from sax.typing_ import Model, SCoo, SDict\n", + "from sax.utils import get_inputs_outputs, reciprocal" + ] + }, + { + "cell_type": "markdown", + "id": "efa9b619", + "metadata": {}, + "source": [ + "## Standard Models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29f9ef13", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def straight(\n", + " *,\n", + " wl: float = 1.55,\n", + " wl0: float = 1.55,\n", + " neff: float = 2.34,\n", + " ng: float = 3.4,\n", + " length: float = 10.0,\n", + " loss: float = 0.0\n", + ") -> SDict:\n", + " \"\"\"a simple straight waveguide model\"\"\"\n", + " dwl = wl - wl0\n", + " dneff_dwl = (ng - neff) / wl0\n", + " neff = neff - dwl * dneff_dwl\n", + " phase = 2 * jnp.pi * neff * length / wl\n", + " amplitude = jnp.asarray(10 ** (-loss * length / 20), dtype=complex)\n", + " transmission = amplitude * jnp.exp(1j * phase)\n", + " sdict = reciprocal(\n", + " {\n", + " (\"in0\", \"out0\"): transmission,\n", + " }\n", + " )\n", + " return sdict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a99cec57", + "metadata": {}, + "outputs": [], + "source": [ + "straight()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce350ff4", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def coupler(*, coupling: float = 0.5) -> SDict:\n", + " \"\"\"a simple coupler model\"\"\"\n", + " kappa = coupling ** 0.5\n", + " tau = (1 - coupling) ** 0.5\n", + " sdict = reciprocal(\n", + " {\n", + " (\"in0\", \"out0\"): tau,\n", + " (\"in0\", \"out1\"): 1j * kappa,\n", + " (\"in1\", \"out0\"): 1j * kappa,\n", + " (\"in1\", \"out1\"): tau,\n", + " }\n", + " )\n", + " return sdict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fd87a22", + "metadata": {}, + "outputs": [], + "source": [ + "coupler()" + ] + }, + { + "cell_type": "markdown", + "id": "ca06c12f", + "metadata": {}, + "source": [ + "## Model Factories" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7f79da4-04c9-4284-9ef9-88938c05707d", + "metadata": {}, + "outputs": [], + "source": [ + "# exporti\n", + "\n", + "def _validate_ports(ports, num_inputs, num_outputs, diagonal) -> Tuple[Tuple[str,...], Tuple[str,...], int, int]:\n", + " if ports is None:\n", + " if num_inputs is None or num_outputs is None:\n", + " raise ValueError(\n", + " \"if not ports given, you must specify how many input ports \"\n", + " \"and how many output ports a model has.\"\n", + " )\n", + " input_ports = [f\"in{i}\" for i in range(num_inputs)]\n", + " output_ports = [f\"out{i}\" for i in range(num_outputs)]\n", + " else:\n", + " if num_inputs is not None:\n", + " if num_outputs is None:\n", + " raise ValueError(\n", + " \"if num_inputs is given, num_outputs should be given as well.\"\n", + " )\n", + " if num_outputs is not None:\n", + " if num_inputs is None:\n", + " raise ValueError(\n", + " \"if num_outputs is given, num_inputs should be given as well.\"\n", + " )\n", + " if num_inputs is not None and num_outputs is not None:\n", + " if num_inputs + num_outputs != len(ports):\n", + " raise ValueError(\"num_inputs + num_outputs != len(ports)\")\n", + " input_ports = ports[:num_inputs]\n", + " output_ports = ports[num_inputs:]\n", + " else:\n", + " input_ports, output_ports = get_inputs_outputs(ports)\n", + " num_inputs = len(input_ports)\n", + " num_outputs = len(output_ports)\n", + " \n", + " if diagonal:\n", + " if num_inputs != num_outputs:\n", + " raise 
ValueError(\n", + " \"Can only have a diagonal passthru if number of input ports equals the number of output ports!\"\n", + " )\n", + " return input_ports, output_ports, num_inputs, num_outputs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90193d56", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "@sax.cache\n", + "def unitary(\n", + " num_inputs: Optional[int] = None,\n", + " num_outputs: Optional[int] = None,\n", + " ports: Optional[Tuple[str, ...]] = None,\n", + " *,\n", + " jit=True,\n", + " reciprocal=True,\n", + " diagonal=False,\n", + ") -> Model:\n", + " input_ports, output_ports, num_inputs, num_outputs = _validate_ports(ports, num_inputs, num_outputs, diagonal)\n", + " assert num_inputs is not None and num_outputs is not None\n", + " \n", + " # let's create the squared S-matrix:\n", + " N = max(num_inputs, num_outputs)\n", + " S = jnp.zeros((2*N, 2*N), dtype=float)\n", + "\n", + " if not diagonal:\n", + " S = S.at[:N, N:].set(1)\n", + " else:\n", + " r = jnp.arange(N, dtype=int) # reciprocal only works if num_inputs == num_outputs!\n", + " S = S.at[r, N+r].set(1)\n", + "\n", + " if reciprocal:\n", + " if not diagonal:\n", + " S = S.at[N:, :N].set(1)\n", + " else:\n", + " r = jnp.arange(N, dtype=int) # reciprocal only works if num_inputs == num_outputs!\n", + " S = S.at[N+r, r].set(1)\n", + "\n", + " # Now we need to normalize the squared S-matrix\n", + " U, s, V = jnp.linalg.svd(S, full_matrices=False)\n", + " S = jnp.sqrt(U@jnp.diag(jnp.where(s > 1e-12, 1, 0))@V)\n", + " \n", + " # Now create subset of this matrix we're interested in:\n", + " r = jnp.concatenate([jnp.arange(num_inputs, dtype=int), N+jnp.arange(num_outputs, dtype=int)], 0)\n", + " S = S[r, :][:, r]\n", + "\n", + " # let's convert it in SCOO format:\n", + " Si, Sj = jnp.where(S > 1e-6)\n", + " Sx = S[Si, Sj]\n", + " \n", + " # the last missing piece is a port map:\n", + " pm = {\n", + " **{p: i for i, p in enumerate(input_ports)},\n", + " **{p: i + num_inputs for i, p in enumerate(output_ports)},\n", + " }\n", + " \n", + " def func(wl: float = 1.5) -> SCoo:\n", + " wl_ = jnp.asarray(wl)\n", + " Sx_ = jnp.broadcast_to(Sx, (*wl_.shape, *Sx.shape))\n", + " return Si, Sj, Sx_, pm\n", + "\n", + " func.__name__ = f\"unitary_{num_inputs}_{num_outputs}\"\n", + " func.__qualname__ = f\"unitary_{num_inputs}_{num_outputs}\"\n", + " if jit:\n", + " return jax.jit(func)\n", + " return func" + ] + }, + { + "cell_type": "markdown", + "id": "327febf2-70d5-48ed-8b39-bac2d9727db5", + "metadata": {}, + "source": [ + "A unitary model returns an `SCoo` by default:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e95bacdd", + "metadata": {}, + "outputs": [], + "source": [ + "unitary_model = unitary(2, 2)\n", + "unitary_model() # a unitary model returns an SCoo by default" + ] + }, + { + "cell_type": "markdown", + "id": "01ff2567-9b0e-4686-acb4-47cfb11b06ce", + "metadata": {}, + "source": [ + "As you probably already know, it's very easy to convert a model returning any `Stype` into a model returning an `SDict` as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "260dda8c-7e44-4f3f-b96c-e742daaf2885", + "metadata": {}, + "outputs": [], + "source": [ + "unitary_sdict_model = sax.sdict(unitary_model)\n", + "unitary_sdict_model()" + ] + }, + { + "cell_type": "markdown", + "id": "e9a1a46b-b10a-471b-b835-7b6d8d49175a", + "metadata": {}, + "source": [ + "If we need custom port names, we can also just specify them explicitly:" + ] + 
}, + { + "cell_type": "code", + "execution_count": null, + "id": "66abe9b0", + "metadata": {}, + "outputs": [], + "source": [ + "unitary_model = unitary(ports=(\"in0\", \"in1\", \"out0\", \"out1\"))\n", + "unitary_model()" + ] + }, + { + "cell_type": "markdown", + "id": "ebf14c1c-ccfb-44e6-8572-18b366ffab91", + "metadata": {}, + "source": [ + "A unitary model will by default split a signal at an input port equally over all output ports. However, if there are an equal number of input ports as output ports we can in stead create a passthru by setting the `diagonal` flag to `True`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c1d61a3", + "metadata": {}, + "outputs": [], + "source": [ + "passthru_model = unitary(2, 2, diagonal=True)\n", + "sax.sdict(passthru_model())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f93b145", + "metadata": {}, + "outputs": [], + "source": [ + "ports_in=['in0']\n", + "ports_out=['out0', 'out1', 'out2', 'out3', 'out4']\n", + "model = unitary(\n", + " ports=tuple(ports_in+ports_out), jit=True, reciprocal=True\n", + ")\n", + "model = sax.sdict(model)\n", + "model()" + ] + }, + { + "cell_type": "markdown", + "id": "8b149d44-5487-4e50-a2f2-1ad0aca99dac", + "metadata": {}, + "source": [ + "Because this is a pretty common usecase we have a dedicated model factory for this as well. This passthru component just takes the number of links (`'in{i}' -> 'out{i]'`) as input. Alternatively, as before, one can also specify the port names directly but one needs to ensure that `len(ports) == 2*num_links`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4e9f36e-bdfb-419b-9e8d-786e4bf773ca", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "@sax.cache\n", + "def passthru(\n", + " num_links: Optional[int] = None,\n", + " ports: Optional[Tuple[str, ...]] = None,\n", + " *,\n", + " jit=True,\n", + " reciprocal=True,\n", + ") -> Model:\n", + " passthru = unitary(num_links, num_links, ports, jit=jit, reciprocal=reciprocal, diagonal=True)\n", + " passthru.__name__ = f\"passthru_{num_links}_{num_links}\"\n", + " passthru.__qualname__ = f\"passthru_{num_links}_{num_links}\"\n", + " if jit:\n", + " return jax.jit(passthru)\n", + " return passthru" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e8b6cd7-56d4-4697-9a6e-33c929f3d853", + "metadata": {}, + "outputs": [], + "source": [ + "passthru_model = passthru(3)\n", + "passthru_sdict_model = sax.sdict(passthru_model)\n", + "passthru_sdict_model()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "690e718c-2d84-4177-b8ad-ff39cf8c4691", + "metadata": {}, + "outputs": [], + "source": [ + "mzi = sax.circuit(\n", + " instances={\n", + " \"lft\": unitary(1, 2),\n", + " \"top\": unitary(1, 1),\n", + " \"rgt\": unitary(1, 2),\n", + " },\n", + " connections={\n", + " \"lft,out0\": \"rgt,out0\",\n", + " \"lft,out1\": \"top,in0\",\n", + " \"top,out0\": \"rgt,out1\",\n", + " },\n", + " ports = {\n", + " \"in0\": \"lft,in0\",\n", + " \"out0\": \"rgt,in0\",\n", + " }\n", + ")\n", + "mzi()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c68a8993-53e8-4b7c-ac34-62383bae4e2f", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "@sax.cache\n", + "def copier(\n", + " num_inputs: Optional[int] = None,\n", + " num_outputs: Optional[int] = None,\n", + " ports: Optional[Tuple[str, ...]] = None,\n", + " *,\n", + " jit=True,\n", + " reciprocal=True,\n", + " 
diagonal=False,\n", + ") -> Model:\n", + " input_ports, output_ports, num_inputs, num_outputs = _validate_ports(ports, num_inputs, num_outputs, diagonal)\n", + " assert num_inputs is not None and num_outputs is not None\n", + " \n", + " # let's create the squared S-matrix:\n", + " S = jnp.zeros((num_inputs+num_outputs, num_inputs+num_outputs), dtype=float)\n", + "\n", + " if not diagonal:\n", + " S = S.at[:num_inputs, num_inputs:].set(1)\n", + " else:\n", + " r = jnp.arange(num_inputs, dtype=int) # == range(num_outputs) # reciprocal only works if num_inputs == num_outputs!\n", + " S = S.at[r, num_inputs+r].set(1)\n", + "\n", + " if reciprocal:\n", + " if not diagonal:\n", + " S = S.at[num_inputs:, :num_inputs].set(1)\n", + " else:\n", + " r = jnp.arange(num_inputs, dtype=int) # == range(num_outputs) # reciprocal only works if num_inputs == num_outputs!\n", + " S = S.at[num_inputs+r, r].set(1)\n", + "\n", + " # let's convert it in SCOO format:\n", + " Si, Sj = jnp.where(S > 1e-6)\n", + " Sx = S[Si, Sj]\n", + " \n", + " # the last missing piece is a port map:\n", + " pm = {\n", + " **{p: i for i, p in enumerate(input_ports)},\n", + " **{p: i + num_inputs for i, p in enumerate(output_ports)},\n", + " }\n", + " \n", + " def func(wl: float = 1.5) -> SCoo:\n", + " wl_ = jnp.asarray(wl)\n", + " Sx_ = jnp.broadcast_to(Sx, (*wl_.shape, *Sx.shape))\n", + " return Si, Sj, Sx_, pm\n", + "\n", + " func.__name__ = f\"unitary_{num_inputs}_{num_outputs}\"\n", + " func.__qualname__ = f\"unitary_{num_inputs}_{num_outputs}\"\n", + " if jit:\n", + " return jax.jit(func)\n", + " return func" + ] + }, + { + "cell_type": "markdown", + "id": "77014d0f-1a08-4b4b-a011-cf11cd9684b8", + "metadata": {}, + "source": [ + "A copier model is like a unitary model, but copies the input signal over all output signals. Hence, if the model has multiple output ports, this model can be considered to introduce gain. That said, it can sometimes be a useful component." 
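To make the gain statement concrete, here is a small check (a sketch; the 1x2 shape is just an illustrative choice, and the `sax.sdict` conversion follows the pattern shown elsewhere in this notebook):

```python
# for a 1x2 copier, a unit input at in0 is copied to both outputs,
# so the total output power is 2: the model is not passive
copier_1x2_sdict = sax.sdict(copier(1, 2))()
total_output_power = sum(
    abs(v) ** 2 for (p1, p2), v in copier_1x2_sdict.items() if p1 == "in0"
)
print(total_output_power)  # ~2.0
```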
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2a61a9f-51ea-455f-a66f-94719b43520e", + "metadata": {}, + "outputs": [], + "source": [ + "copier_model = copier(2, 2)\n", + "copier_model() # a copier model returns an SCoo by default" + ] + }, + { + "cell_type": "markdown", + "id": "fbd51c6c-c059-4717-9dae-447becc7e8d5", + "metadata": {}, + "source": [ + "As you probably already know, it's very easy to convert a model returning any `Stype` into a model returning an `SDict` as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "becda099-b5d9-494c-a6b2-9828be10c4a5", + "metadata": {}, + "outputs": [], + "source": [ + "copier_sdict_model = sax.sdict(copier_model)\n", + "copier_sdict_model()" + ] + }, + { + "cell_type": "markdown", + "id": "1ad16d1b-f685-4ab9-a8c4-41efbbdae4fa", + "metadata": {}, + "source": [ + "If we need custom port names, we can also just specify them explicitly:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "297d1e45-b881-440b-b887-b38c5517596e", + "metadata": {}, + "outputs": [], + "source": [ + "copier_model = copier(ports=(\"in0\", \"in1\", \"out0\", \"out1\"))\n", + "copier_model()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "608d7114-03ac-4f5f-9790-cf65379d81dd", + "metadata": {}, + "outputs": [], + "source": [ + "ports_in=['in0']\n", + "ports_out=['out0', 'out1', 'out2', 'out3', 'out4']\n", + "model = unitary(\n", + " ports=tuple(ports_in+ports_out), jit=True, reciprocal=True\n", + ")\n", + "model = sax.sdict(model)\n", + "model()" + ] + }, + { + "cell_type": "markdown", + "id": "6d024e5b-95a9-44e5-9dbc-4980249e73a7", + "metadata": {}, + "source": [ + "Because this is a pretty common usecase we have a dedicated model factory for this as well. This passthru component just takes the number of links (`'in{i}' -> 'out{i]'`) as input. Alternatively, as before, one can also specify the port names directly but one needs to ensure that `len(ports) == 2*num_links`." 
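As a quick check of the equal-splitting claim, we can look at the 1x5 splitter `model` created a couple of cells back (assuming that cell has been run); each output should carry one fifth of the input power thanks to the SVD-based normalization inside `unitary`:

```python
# power from in0 to each output port of the 1x5 splitter defined above
powers = {p2: float(abs(v) ** 2) for (p1, p2), v in model().items() if p1 == "in0"}
print(powers)                # ~0.2 for each of out0 ... out4
print(sum(powers.values()))  # ~1.0 in total
```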
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa72677c-b346-43a1-85b9-8171fa3e6cef", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "@sax.cache\n", + "def passthru(\n", + " num_links: Optional[int] = None,\n", + " ports: Optional[Tuple[str, ...]] = None,\n", + " *,\n", + " jit=True,\n", + " reciprocal=True,\n", + ") -> Model:\n", + " passthru = unitary(num_links, num_links, ports, jit=jit, reciprocal=reciprocal, diagonal=True)\n", + " passthru.__name__ = f\"passthru_{num_links}_{num_links}\"\n", + " passthru.__qualname__ = f\"passthru_{num_links}_{num_links}\"\n", + " if jit:\n", + " return jax.jit(passthru)\n", + " return passthru" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a063810-224b-446c-a59d-12a20b56e82e", + "metadata": {}, + "outputs": [], + "source": [ + "passthru_model = passthru(3)\n", + "passthru_sdict_model = sax.sdict(passthru_model)\n", + "passthru_sdict_model()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b2be9a8-1b06-4b60-ac44-853d49b3e9f6", + "metadata": {}, + "outputs": [], + "source": [ + "mzi = sax.circuit(\n", + " instances={\n", + " \"lft\": unitary(1, 2),\n", + " \"top\": unitary(1, 1),\n", + " \"rgt\": unitary(1, 2),\n", + " },\n", + " connections={\n", + " \"lft,out0\": \"rgt,out0\",\n", + " \"lft,out1\": \"top,in0\",\n", + " \"top,out0\": \"rgt,out1\",\n", + " },\n", + " ports = {\n", + " \"in0\": \"lft,in0\",\n", + " \"out0\": \"rgt,in0\",\n", + " }\n", + ")\n", + "mzi()" + ] + }, + { + "cell_type": "markdown", + "id": "bb6e94aa", + "metadata": {}, + "source": [ + "## All Models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ddef1cde", + "metadata": {}, + "outputs": [], + "source": [ + "#exports\n", + "\n", + "models = {\n", + " \"copier\": copier,\n", + " \"coupler\": coupler,\n", + " \"passthru\": passthru,\n", + " \"straight\": straight,\n", + " \"unitary\": unitary,\n", + "}\n", + "\n", + "def get_models(copy: bool=True):\n", + " if copy:\n", + " return {**models}\n", + " return models" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/06_netlist.ipynb b/nbs/06_netlist.ipynb new file mode 100644 index 0000000..a06d144 --- /dev/null +++ b/nbs/06_netlist.ipynb @@ -0,0 +1,771 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "3150cab0", + "metadata": {}, + "outputs": [], + "source": [ + "# default_exp netlist" + ] + }, + { + "cell_type": "markdown", + "id": "0bec8f0c", + "metadata": {}, + "source": [ + "# Netlist\n", + "\n", + "> SAX Netlist" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24ebf73c", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "import matplotlib.pyplot as plt\n", + "from fastcore.test import test_eq\n", + "from pytest import approx, raises\n", + "\n", + "import os, sys; sys.stderr = open(os.devnull, \"w\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c0927c5", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "from __future__ import annotations\n", + "\n", + "import os\n", + "import re\n", + "from functools import partial\n", + "from typing import Callable, Dict, Iterable, Optional, Tuple, Union, cast\n", + "\n", + "from flax.core import FrozenDict\n", + "from natsort import natsorted\n", + "from sax.models import models as default_sax_models\n", + "from sax.typing_ 
import (\n", + " ComplexFloat,\n", + " Instance,\n", + " Instances,\n", + " LogicalNetlist,\n", + " Model,\n", + " ModelFactory,\n", + " Models,\n", + " Netlist,\n", + " Settings,\n", + " is_instance,\n", + " is_model_factory,\n", + " is_netlist,\n", + ")\n", + "from sax.utils import (\n", + " clean_string,\n", + " copy_settings,\n", + " get_settings,\n", + " merge_dicts,\n", + " rename_params,\n", + " rename_ports,\n", + " try_float,\n", + ")\n", + "from yaml import Loader\n", + "from yaml import load as load_yaml" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d4cae003", + "metadata": {}, + "outputs": [], + "source": [ + "# exporti\n", + "\n", + "def _clean_component_names(models: Optional[Models]) -> Models:\n", + " if models is None:\n", + " models = {}\n", + " return {clean_string(comp): model for comp, model in models.items()}\n", + "\n", + "\n", + "def _clean_instance_names(instances: Optional[Instances]) -> Instances:\n", + " _instances = {}\n", + " if instances is None:\n", + " instances = {}\n", + " for name, inst in instances.items():\n", + " if \",\" in name:\n", + " raise ValueError(\n", + " f\"Instance name '{name}' is invalid. It contains the port separator ','.\"\n", + " )\n", + " if \":\" in name:\n", + " raise ValueError(\n", + " f\"Instance name '{name}' is invalid. It contains the port slice symbol ':'.\"\n", + " )\n", + " name = clean_string(name)\n", + " _instances[name] = inst\n", + " return _instances\n", + "\n", + "\n", + "def _component_from_callable(f: Callable, models: Models) -> Tuple[str, Models]:\n", + " _reverse_models = {id(model): name for name, model in models.items()}\n", + " if id(f) in _reverse_models:\n", + " component = _reverse_models[id(f)]\n", + " else:\n", + " component = _funcname(f)\n", + " models[component] = f\n", + " return component, models\n", + "\n", + "\n", + "def _funcname(p: Callable) -> str:\n", + " name = \"\"\n", + " f: Callable = p\n", + " while isinstance(f, partial):\n", + " name = \"{name}p_\"\n", + " if f.args:\n", + " try:\n", + " name = f\"{name[:-1]}{hash(f.args)}_\"\n", + " except TypeError:\n", + " raise TypeError(\n", + " \"when using partials as SAX models, positional arguments of the partial should be hashable.\"\n", + " )\n", + " f = f.func\n", + " return f\"{name}{f.__name__}_{id(f)}\"\n", + "\n", + "\n", + "def _instance_from_callable(f: Callable, models: Models) -> Tuple[Instance, Models]:\n", + " f, settings = _maybe_parse_partial(f)\n", + " component, models = _component_from_callable(f, models)\n", + " instance = Instance(component=component, settings=settings)\n", + " return instance, models\n", + "\n", + "\n", + "def _instance_from_instance(\n", + " name: str,\n", + " instance: Instance,\n", + " models: Models,\n", + " override_settings: Settings,\n", + " global_settings: Dict[str, ComplexFloat],\n", + " default_models=None,\n", + ") -> Tuple[Instance, Models]:\n", + " default_models = default_sax_models if default_models is None else default_models\n", + " component = clean_string(instance[\"component\"])\n", + " if component not in models:\n", + " if component not in default_models:\n", + " raise ValueError(\n", + " f\"Error constructing netlist. Component '{component}' not found.\"\n", + " )\n", + " model = default_models[component]\n", + " else:\n", + " model = models[component]\n", + " if isinstance(model, str):\n", + " if model not in default_models:\n", + " raise ValueError(\n", + " f\"Error constructing netlist. 
Component '{model}' not found.\"\n", + " )\n", + " model = default_models[model]\n", + " if not callable(model): # Model or ModelFactory\n", + " raise ValueError(\n", + " f\"Error constructing netlist. Model for component '{component}' is not callable.\"\n", + " )\n", + " # fmt: off\n", + " _default_settings = get_settings(model)\n", + " _instance_settings = {k: v for k, v in instance.get('settings', {}).items() if k in _default_settings}\n", + " _override_settings = cast(Dict[str, ComplexFloat], override_settings.get(name, {}))\n", + " _override_settings = {k: v for k, v in _override_settings.items() if k in _default_settings}\n", + " _global_settings = {k: v for k, v in global_settings.items() if k in _default_settings}\n", + " settings = merge_dicts(_default_settings, _instance_settings, _override_settings, _global_settings)\n", + " # fmt: on\n", + "\n", + " instance = Instance(component=component, settings=settings)\n", + " models[instance[\"component\"]] = model\n", + " return instance, models\n", + "\n", + "\n", + "def _instance_from_string(\n", + " s: str, models: Models, default_models=None\n", + ") -> Tuple[Instance, Models]:\n", + " default_models = default_sax_models if default_models is None else default_models\n", + " if s not in models:\n", + " if s not in default_models:\n", + " raise ValueError(f\"Error constructing netlist. Component '{s}' not found.\")\n", + " models[s] = default_models[s]\n", + " instance = Instance(component=s, settings={})\n", + " return instance, models\n", + "\n", + "\n", + "def _maybe_parse_partial(p: Callable) -> Tuple[Callable, Dict[str, ComplexFloat]]:\n", + " _settings = {}\n", + " while isinstance(p, partial):\n", + " _settings = merge_dicts(_settings, p.keywords)\n", + " if not p.args:\n", + " p = p.func\n", + " else:\n", + " p = partial(p.func, *p.args)\n", + " break\n", + " return p, _settings\n", + "\n", + "\n", + "def _model_operations(\n", + " models: Optional[Models],\n", + " ops: Optional[Dict[str, Callable]] = None,\n", + " default_models=None,\n", + ") -> Models:\n", + " default_models = default_sax_models if default_models is None else default_models\n", + " if ops is None:\n", + " ops = {}\n", + " if models is None:\n", + " models = {}\n", + "\n", + " _models = {}\n", + " for component, model in models.items():\n", + " if isinstance(model, str):\n", + " if not model in default_models:\n", + " raise ValueError(f\"Could not find model {model}.\")\n", + " model = default_models[model]\n", + "\n", + " _models[component] = model\n", + "\n", + " if not isinstance(model, dict):\n", + " continue\n", + "\n", + " if is_netlist(model):\n", + " continue # TODO: This case should actually be handled...\n", + "\n", + " model = {**model}\n", + "\n", + " if \"model\" not in model:\n", + " raise ValueError(\n", + " \"Invalid model dict for '{component}'. 
Key 'model' not found.\"\n", + " )\n", + "\n", + " if isinstance(model[\"model\"], str):\n", + " if not model[\"model\"] in default_models:\n", + " raise ValueError(f\"Could not find model {model['model']}.\")\n", + " model[\"model\"] = default_models[cast(str, model[\"model\"])]\n", + "\n", + " for op_name, op_func in ops.items():\n", + " assert isinstance(model, dict)\n", + " if op_name not in model:\n", + " continue\n", + " op_args = model[op_name]\n", + " model[\"model\"] = op_func(model[\"model\"], op_args)\n", + "\n", + " _models[component] = model[\"model\"]\n", + " return _models\n", + "\n", + "\n", + "def _split_global_settings(\n", + " settings: Optional[dict], instance_names: Iterable[str]\n", + ") -> Tuple[Settings, Dict[str, ComplexFloat]]:\n", + " if settings:\n", + " override_settings = cast(Dict[str, Settings], copy_settings(settings))\n", + " global_settings: Dict[str, ComplexFloat] = {}\n", + " for k in list(override_settings.keys()):\n", + " if k in instance_names:\n", + " continue\n", + " global_settings[k] = cast(ComplexFloat, try_float(override_settings.pop(k)))\n", + " else:\n", + " override_settings: Dict[str, Settings] = {}\n", + " global_settings: Dict[str, ComplexFloat] = {}\n", + " return override_settings, global_settings\n", + "\n", + "\n", + "def _enumerate_portrange(s):\n", + " if not \":\" in s:\n", + " return [s]\n", + " idx1, idx2 = s.split(\":\")\n", + " s1 = re.sub(\"[0-9]*\", \"\", idx1)\n", + " idx1 = int(re.sub(\"[^0-9]*\", \"\", idx1))\n", + " s2 = re.sub(\"[0-9]*\", \"\", idx2)\n", + " idx2 = int(re.sub(\"[^0-9]*\", \"\", idx2))\n", + " if s1 != s2 and s2 != \"\":\n", + " raise ValueError(\n", + " \"Cannot enumerate portrange {s}, string portion of port differs.\"\n", + " )\n", + " return [f\"{s1}{i}\" for i in range(idx1, idx2)]\n", + "\n", + "\n", + "def _validate_connections(connections):\n", + " # todo: check if instance names are available in instances\n", + " # todo: check if instance ports are used in output ports\n", + " _ports = set()\n", + " old_connections, connections = connections, {}\n", + " for conn1, conn2 in old_connections.items():\n", + " if conn1.count(\",\") != 1 or conn2.count(\",\") != 1:\n", + " raise ValueError(\n", + " \"Connections ports should have format '{instance_name},{port}'. 
\"\n", + " f\"Got '{conn1}'.\"\n", + " )\n", + " name1, port1 = conn1.split(\",\")\n", + " name2, port2 = conn2.split(\",\")\n", + " ports1 = _enumerate_portrange(port1)\n", + " ports2 = _enumerate_portrange(port2)\n", + "\n", + " if len(ports1) != len(ports2):\n", + " if len(ports1) == 1:\n", + " ports1 = [ports1[0] for _ in ports2]\n", + " elif len(ports2) == 1:\n", + " ports2 = [ports2[0] for _ in ports1]\n", + " else:\n", + " raise ValueError(\n", + " f\"Cannot enumerate connection {conn1} -> {conn2}, slice lengths on both sides differ.\"\n", + " )\n", + "\n", + " for port1, port2 in zip(ports1, ports2):\n", + " name1 = clean_string(name1)\n", + " name2 = clean_string(name2)\n", + " port1 = clean_string(port1)\n", + " port2 = clean_string(port2)\n", + " (name1, port1), (name2, port2) = natsorted([(name1, port1), (name2, port2)])\n", + " conn1 = f\"{name1},{port1}\"\n", + " conn2 = f\"{name2},{port2}\"\n", + " if conn1 in _ports:\n", + " raise ValueError(f\"duplicate connection port: '{conn1}'\")\n", + " if conn2 in _ports:\n", + " raise ValueError(f\"duplicate connection port: '{conn2}'\")\n", + " connections[conn1] = conn2\n", + " _ports.add(conn1)\n", + " _ports.add(conn2)\n", + "\n", + " return dict(natsorted([natsorted([k, v]) for k, v in connections.items()]))\n", + "\n", + "\n", + "def _validate_ports(ports):\n", + " # todo: check if instance names are available in instances\n", + " # todo: check if instance ports are used in connections\n", + " _ports = set()\n", + " old_ports, ports = ports, {}\n", + " for port1, conn2 in old_ports.items():\n", + "\n", + " if port1.count(\",\") == 1:\n", + " if conn2.count(\",\") != 0:\n", + " raise ValueError(\n", + " \"Netlist output port '{conn2}' should not contain any ','.\"\n", + " )\n", + " port1, conn2 = conn2, port1\n", + " elif conn2.count(\",\") == 1:\n", + " if port1.count(\",\") != 0:\n", + " raise ValueError(\n", + " \"Netlist output port '{port1}' should not contain any ','.\"\n", + " )\n", + " else:\n", + " raise ValueError(\n", + " \"Netlist output ports should be mapped onto an instance port \"\n", + " \"using the format: '{output_port}': '{instance_name},{port}'. 
\"\n", + " f\"Got: '{port1}': '{conn2}.'\"\n", + " )\n", + "\n", + " name2, port2 = conn2.split(\",\")\n", + " ports1 = _enumerate_portrange(port1)\n", + " ports2 = _enumerate_portrange(port2)\n", + "\n", + " if len(ports1) != len(ports2):\n", + " if len(ports1) == 1:\n", + " ports1 = [ports1[0] for _ in ports2]\n", + " elif len(ports2) == 1:\n", + " ports2 = [ports2[0] for _ in ports1]\n", + " else:\n", + " raise ValueError(\n", + " f\"Cannot enumerate output ports {port1} -> {conn2}, slice lengths on both sides differ.\"\n", + " )\n", + "\n", + " for port1, port2 in zip(ports1, ports2):\n", + " port1 = clean_string(port1) # output_port\n", + " name2 = clean_string(name2)\n", + " port2 = clean_string(port2)\n", + " conn2 = f\"{name2},{port2}\"\n", + " if port1 in _ports:\n", + " raise ValueError(f\"duplicate output port: '{port1}'\")\n", + " if conn2 in _ports:\n", + " raise ValueError(f\"duplicate instance output port: '{conn2}'\")\n", + " ports[port1] = conn2\n", + " _ports.add(port1)\n", + " _ports.add(conn2)\n", + " return dict(natsorted(ports.items()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3cbe7c50", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def netlist(\n", + " *,\n", + " instances: Instances,\n", + " connections: Dict[str, str],\n", + " ports: Dict[str, str],\n", + " models: Optional[Models] = None,\n", + " settings: Optional[Settings] = None,\n", + " default_models=None,\n", + ") -> Tuple[Netlist, Models]:\n", + " \"\"\"Create a `Netlist` and `Models` dictionary\"\"\"\n", + " default_models = default_sax_models if default_models is None else default_models\n", + " models = _clean_component_names(models)\n", + " models = _model_operations(\n", + " models,\n", + " ops={\n", + " \"rename_params\": rename_params,\n", + " \"rename_ports\": rename_ports,\n", + " },\n", + " default_models=default_models,\n", + " )\n", + "\n", + " instances = _clean_instance_names(instances)\n", + " override_settings, global_settings = _split_global_settings(settings, instances)\n", + "\n", + " _instances: Dict[str, Union[Instance, Netlist]] = {}\n", + " for name, instance in instances.items():\n", + " if callable(instance):\n", + " instance, models = _instance_from_callable(instance, models)\n", + " elif isinstance(instance, str):\n", + " instance, models = _instance_from_string(instance, models, default_models)\n", + " if not isinstance(instance, dict):\n", + " raise ValueError(\n", + " f\"invalid instance '{name}': expected str, dict or callable.\"\n", + " )\n", + " if is_instance(instance):\n", + " instance, models = _instance_from_instance(\n", + " name=name,\n", + " instance=cast(Instance, instance),\n", + " models=models,\n", + " override_settings=override_settings,\n", + " global_settings=global_settings,\n", + " default_models=default_models,\n", + " )\n", + " elif is_netlist(instance):\n", + " instance = cast(Netlist, instance)\n", + " instance, models = netlist(\n", + " instances=instance[\"instances\"],\n", + " connections=instance[\"connections\"],\n", + " ports=instance[\"ports\"],\n", + " models=models,\n", + " settings=merge_dicts(global_settings, override_settings),\n", + " default_models=default_models,\n", + " )\n", + " else:\n", + " raise ValueError(\n", + " f\"Instance {name} cannot be interpreted as an Instance or a Netlist.\"\n", + " )\n", + " _instances[name] = instance\n", + "\n", + " _instances = {k: _instances[k] for k in natsorted(_instances.keys())}\n", + " _connections = 
_validate_connections(connections)\n", + " _ports = _validate_ports(ports)\n", + "\n", + " _netlist = Netlist(\n", + " instances=cast(Instances, _instances),\n", + " connections=_connections,\n", + " ports=_ports,\n", + " )\n", + " return _netlist, models" + ] + }, + { + "cell_type": "markdown", + "id": "0ae88733", + "metadata": {}, + "source": [ + "> Example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9aca3e95", + "metadata": {}, + "outputs": [], + "source": [ + "from sax.models import straight, coupler\n", + "mzi_netlist, models = netlist(\n", + " instances={\n", + " \"lft\": \"mmi1x2\", # shorthand if no settings need to be given\n", + " \"top\": { # full instance definition\n", + " \"component\": \"waveguide\",\n", + " \"settings\": {\n", + " \"length\": 100.0,\n", + " },\n", + " },\n", + " \"rgt\": \"mmi2x2\", # shorthand if no settings need to be given\n", + " },\n", + " connections={\n", + " \"lft,out0\": \"top,in0\",\n", + " \"top,out0\": \"rgt,in0\",\n", + " \"top,out1\": \"rgt,in1\",\n", + " },\n", + " ports={\n", + " \"in0\": \"lft,in0\",\n", + " \"out0\": \"rgt,out0\",\n", + " \"out1\": \"rgt,out1\",\n", + " },\n", + " models={\n", + " \"mmi1x2\": coupler,\n", + " \"mmi2x2\": coupler,\n", + " \"waveguide\": straight,\n", + " }\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb63486d", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def netlist_from_yaml(\n", + " yaml: str,\n", + " *,\n", + " models: Optional[Models] = None,\n", + " settings: Optional[Settings] = None,\n", + " default_models=None,\n", + ") -> Tuple[Netlist, Models]:\n", + " \"\"\"Load a sax `Netlist` from yaml definition and a `Models` dictionary\"\"\"\n", + " \n", + " default_models = default_sax_models if default_models is None else default_models\n", + " ext = None\n", + " directory = None\n", + " yaml_path = os.path.abspath(os.path.expanduser(yaml))\n", + " if os.path.isdir(yaml_path):\n", + " raise IsADirectoryError(\n", + " \"Cannot read from yaml path '{yaml_path}'. 
Path is a directory.\"\n", + " )\n", + " elif os.path.exists(yaml_path):\n", + " if ext is None:\n", + " _, *ext_list = os.path.basename(yaml_path).split(\".\")\n", + " ext = f\".{'.'.join(ext_list)}\"\n", + " if directory is None:\n", + " directory = os.path.dirname(yaml_path)\n", + " yaml = open(yaml_path, \"r\").read()\n", + " else:\n", + " yaml_path = None\n", + "\n", + " subnetlists = {}\n", + " if directory is not None and ext is not None:\n", + " subnetlists = {\n", + " re.sub(f\"{ext}$\", \"\", os.path.basename(file)): os.path.join(root, file)\n", + " for root, _, files in os.walk(os.path.abspath(directory))\n", + " for file in files\n", + " if file.endswith(ext)\n", + " }\n", + "\n", + " raw_netlist = load_yaml(yaml, Loader)\n", + "\n", + " for section in [\"instances\", \"connections\", \"ports\"]:\n", + " if section not in raw_netlist:\n", + " raise ValueError(f\"Can not load from yaml: '{section}' not found.\")\n", + "\n", + " raw_instances = raw_netlist[\"instances\"]\n", + " connections = raw_netlist[\"connections\"]\n", + " ports = raw_netlist[\"ports\"]\n", + " override_settings, global_settings = _split_global_settings(settings, raw_instances)\n", + "\n", + " instances = {}\n", + " for name, instance in raw_instances.items():\n", + " if isinstance(instance, str):\n", + " instance = {\"component\": instance, \"settings\": {}}\n", + " elif isinstance(instance, dict):\n", + " if \"component\" not in instance:\n", + " raise ValueError(\n", + " f\"Can not load from yaml: 'component' not found in instance '{name}'.\"\n", + " )\n", + " component = instance[\"component\"]\n", + " component = re.sub(r\"\\.ba$\", \"\", component) # for compat with pb\n", + " if component in subnetlists:\n", + " _override_settings = cast(Settings, override_settings.get(name, {}))\n", + " _settings = merge_dicts(global_settings, _override_settings)\n", + " instance, models = netlist_from_yaml(\n", + " yaml=subnetlists[component],\n", + " models=models,\n", + " settings=_settings,\n", + " default_models=default_models,\n", + " )\n", + " instances[name] = instance\n", + "\n", + " _netlist, _models = netlist(\n", + " instances=instances,\n", + " connections=connections,\n", + " ports=ports,\n", + " models=models,\n", + " settings=settings,\n", + " default_models=default_models,\n", + " )\n", + " return _netlist, _models" + ] + }, + { + "cell_type": "markdown", + "id": "a1e660a8", + "metadata": {}, + "source": [ + "> Example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3358777a", + "metadata": {}, + "outputs": [], + "source": [ + "from sax.models import straight, coupler\n", + "\n", + "mzi_netlist_from_yaml, models = netlist_from_yaml(\"\"\"\n", + "instances:\n", + " \"lft\": \"mmi1x2\"\n", + " \"top\": \n", + " \"component\": \"waveguide\"\n", + " \"settings\": \n", + " \"length\": 100.0\n", + " \"rgt\": \"mmi2x2\"\n", + "connections:\n", + " \"lft,out0\": \"top,in0\"\n", + " \"top,out0\": \"rgt,in0\"\n", + " \"top,out1\": \"rgt,in1\"\n", + "ports:\n", + " \"in0\": \"lft,in0\"\n", + " \"out0\": \"rgt,out0\"\n", + " \"out1\": \"rgt,out1\"\n", + "\"\"\", \n", + "models=models,\n", + ")\n", + "\n", + "assert mzi_netlist_from_yaml == mzi_netlist" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b10020dd", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def logical_netlist(\n", + " *,\n", + " instances: Instances,\n", + " connections: Dict[str, str],\n", + " ports: Dict[str, str],\n", + " models: Optional[Models] = None,\n", + " 
settings: Optional[Settings] = None,\n", + " default_models=None,\n", + ") -> Tuple[LogicalNetlist, Settings, Models]:\n", + " \"\"\"Create a `LogicalNetlist` with separated `Settings` and `Models` dictionary\"\"\"\n", + " \n", + " default_models = default_sax_models if default_models is None else default_models\n", + " _netlist, models = netlist(\n", + " instances=instances,\n", + " connections=connections,\n", + " ports=ports,\n", + " models=models,\n", + " settings=settings,\n", + " default_models=default_models,\n", + " )\n", + " _instances: Dict[str, str] = {}\n", + " _settings: Settings = {}\n", + " _models = models\n", + "\n", + " for name, instance in _netlist[\"instances\"].items():\n", + " if is_netlist(instance):\n", + " instance = cast(Netlist, instance)\n", + " model, _settings[name], _models = logical_netlist(\n", + " instances=instance[\"instances\"],\n", + " connections=instance[\"connections\"],\n", + " ports=instance[\"ports\"],\n", + " models=_models,\n", + " default_models=default_models,\n", + " )\n", + " model_hash = hex(abs(hash(FrozenDict(model))))[2:]\n", + " component = f\"logical_netlist_{model_hash}\"\n", + " _instances[name] = component\n", + " _models[component] = model\n", + " elif is_instance(instance):\n", + " instance = cast(Instance, instance)\n", + " component = instance[\"component\"]\n", + " _instance_settings = instance.get(\"settings\", {})\n", + " _instance_model = cast(Model, _models[component])\n", + "\n", + " if is_model_factory(_instance_model):\n", + " model_factory = cast(ModelFactory, _instance_model)\n", + " _instance_model = model_factory(**_instance_settings)\n", + " instance, _models = _instance_from_callable(_instance_model, _models)\n", + " _instance_settings = instance.get(\"settings\", {})\n", + " component = instance[\"component\"]\n", + "\n", + " _instances[name] = component\n", + " _model_settings = get_settings(_instance_model)\n", + " _instance_settings = {\n", + " k: try_float(\n", + " v\n", + " if (k not in _instance_settings or _instance_settings[k] is None)\n", + " else _instance_settings[k]\n", + " )\n", + " for k, v in _model_settings.items()\n", + " }\n", + " _settings[name] = cast(Settings, _instance_settings)\n", + " else:\n", + " raise ValueError(f\"instance '{name}' is not an Instance or a Netlist.\")\n", + "\n", + " _instances = dict(natsorted(_instances.items()))\n", + "\n", + " _logical_netlist = LogicalNetlist(\n", + " instances=_instances,\n", + " connections=_netlist[\"connections\"],\n", + " ports=_netlist[\"ports\"],\n", + " )\n", + " return (\n", + " _logical_netlist,\n", + " _settings,\n", + " _models,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f77f87eb", + "metadata": {}, + "outputs": [], + "source": [ + "mzi_logical_netlist, mzi_settings, models = logical_netlist(\n", + " instances=mzi_netlist[\"instances\"],\n", + " connections=mzi_netlist[\"connections\"],\n", + " ports=mzi_netlist[\"ports\"],\n", + " models=models,\n", + ")\n", + "\n", + "assert mzi_logical_netlist == {\n", + " \"instances\": {\"lft\": \"mmi1x2\", \"rgt\": \"mmi2x2\", \"top\": \"waveguide\"},\n", + " \"connections\": {\n", + " \"lft,out0\": \"top,in0\",\n", + " \"rgt,in0\": \"top,out0\",\n", + " \"rgt,in1\": \"top,out1\",\n", + " },\n", + " \"ports\": {\"in0\": \"lft,in0\", \"out0\": \"rgt,out0\", \"out1\": \"rgt,out1\"},\n", + "}" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff 
--git a/nbs/07_backends.ipynb b/nbs/07_backends.ipynb new file mode 100644 index 0000000..7509803 --- /dev/null +++ b/nbs/07_backends.ipynb @@ -0,0 +1,170 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "04d466c4", + "metadata": {}, + "outputs": [], + "source": [ + "# default_exp backends.__init__" + ] + }, + { + "cell_type": "markdown", + "id": "079604be", + "metadata": {}, + "source": [ + "# Backend\n", + "\n", + "> SAX Backends" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5382439e", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "import jax.numpy as jnp\n", + "import matplotlib.pyplot as plt\n", + "from fastcore.test import test_eq\n", + "from nbdev import show_doc\n", + "from pytest import approx, raises\n", + "from sax.typing_ import SDense, SDict\n", + "\n", + "import os, sys; sys.stderr = open(os.devnull, \"w\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5c359d6e", + "metadata": {}, + "outputs": [], + "source": [ + "# exporti\n", + "\n", + "from sax.backends.default import evaluate_circuit\n", + "from sax.backends.klu import evaluate_circuit_klu\n", + "from sax.backends.additive import evaluate_circuit_additive" + ] + }, + { + "cell_type": "markdown", + "id": "a39b2084", + "metadata": {}, + "source": [ + "#### circuit_backends" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dca841f0", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "\n", + "circuit_backends = {\n", + " \"default\": evaluate_circuit,\n", + " \"klu\": evaluate_circuit_klu,\n", + " \"additive\": evaluate_circuit_additive,\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "b23a8349", + "metadata": {}, + "source": [ + "SAX allows to easily interchange the backend of a circuit. A SAX backend needs to have the following signature:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5153759", + "metadata": {}, + "outputs": [], + "source": [ + "# hide_input\n", + "from sax.backends.default import evaluate_circuit\n", + "show_doc(evaluate_circuit, doc_string=False)" + ] + }, + { + "cell_type": "markdown", + "id": "df66e846", + "metadata": {}, + "source": [ + "i.e. it takes a dictionary of instance names pointing to `SType`s (usually `SDict`s), a connection dictionary and an (output) ports dictionary. Internally it must construct the output `SType` (usually output `SDict`)." 
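Since the signature is all a backend needs to satisfy, writing a custom one is straightforward. The sketch below is purely illustrative: the name `evaluate_circuit_verbose` is made up, it simply delegates to the default backend, and registering it under a new key in `circuit_backends` assumes that this dictionary is how a backend gets selected:

```python
from typing import Dict

from sax.backends.default import evaluate_circuit
from sax.typing_ import SDict, SType


def evaluate_circuit_verbose(
    instances: Dict[str, SType],
    connections: Dict[str, str],
    ports: Dict[str, str],
) -> SDict:
    # a trivial "backend": print some bookkeeping, then delegate to the default backend
    print(f"evaluating {len(instances)} instances with {len(connections)} connections")
    return evaluate_circuit(instances, connections, ports)


circuit_backends["verbose"] = evaluate_circuit_verbose  # hypothetical registration
```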
+ ] + }, + { + "cell_type": "markdown", + "id": "835a521a", + "metadata": {}, + "source": [ + "> Example\n", + "\n", + "Let's create an MZI `SDict` using the default backend's `evaluate_circuit`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a22955d", + "metadata": {}, + "outputs": [], + "source": [ + "wg_sdict: SDict = {\n", + " (\"in0\", \"out0\"): 0.5 + 0.86603j,\n", + " (\"out0\", \"in0\"): 0.5 + 0.86603j,\n", + "}\n", + "\n", + "τ, κ = 0.5 ** 0.5, 1j * 0.5 ** 0.5\n", + "dc_sdense: SDense = (\n", + " jnp.array([[0, 0, τ, κ], \n", + " [0, 0, κ, τ], \n", + " [τ, κ, 0, 0], \n", + " [κ, τ, 0, 0]]),\n", + " {\"in0\": 0, \"in1\": 1, \"out0\": 2, \"out1\": 3},\n", + ")\n", + "\n", + "mzi_sdict: SDict = evaluate_circuit(\n", + " instances={\n", + " \"dc1\": dc_sdense,\n", + " \"wg\": wg_sdict,\n", + " \"dc2\": dc_sdense,\n", + " },\n", + " connections={\n", + " \"dc1,out0\": \"wg,in0\",\n", + " \"wg,out0\": \"dc2,in0\",\n", + " \"dc1,out1\": \"dc2,in1\",\n", + " },\n", + " ports={\n", + " \"in0\": \"dc1,in0\",\n", + " \"in1\": \"dc1,in1\",\n", + " \"out0\": \"dc2,out0\",\n", + " \"out1\": \"dc2,out1\",\n", + " }\n", + ")\n", + "\n", + "mzi_sdict" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/07a_backends_default.ipynb b/nbs/07a_backends_default.ipynb new file mode 100644 index 0000000..df34e19 --- /dev/null +++ b/nbs/07a_backends_default.ipynb @@ -0,0 +1,725 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "4680bd0e", + "metadata": {}, + "outputs": [], + "source": [ + "# default_exp backends.default" + ] + }, + { + "cell_type": "markdown", + "id": "41e0d311", + "metadata": {}, + "source": [ + "# Backend - default\n", + "\n", + "> Default SAX Backend" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e134758", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "import matplotlib.pyplot as plt\n", + "from fastcore.test import test_eq\n", + "from pytest import approx, raises\n", + "import jax.numpy as jnp\n", + "\n", + "import os, sys; sys.stderr = open(os.devnull, \"w\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a3948228", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "from __future__ import annotations\n", + "\n", + "from typing import Dict\n", + "\n", + "import jax\n", + "from sax.typing_ import SType, SDict, sdense, sdict" + ] + }, + { + "cell_type": "markdown", + "id": "d8369d2f-9109-4375-8cb4-6d7ccabbd356", + "metadata": {}, + "source": [ + "## Citation\n", + "The default SAX backend is based on the following paper:\n", + "\n", + "> Filipsson, Gunnar. \"*A new general computer algorithm for S-matrix calculation of interconnected multiports.*\" 11th European Microwave Conference. IEEE, 1981." 
+ ] + }, + { + "cell_type": "markdown", + "id": "feea9c4f-0528-44e0-8e87-91234c6d29cc", + "metadata": {}, + "source": [ + "## Circuit Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "635246ee", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def evaluate_circuit(\n", + " instances: Dict[str, SType],\n", + " connections: Dict[str, str],\n", + " ports: Dict[str, str],\n", + ") -> SDict:\n", + " \"\"\"evaluate a circuit for the given sdicts.\"\"\"\n", + " \n", + " # it's actually easier working w reverse:\n", + " reversed_ports = {v: k for k, v in ports.items()}\n", + "\n", + " block_diag = {}\n", + " for name, S in instances.items():\n", + " block_diag.update(\n", + " {(f\"{name},{p1}\", f\"{name},{p2}\"): v for (p1, p2), v in sdict(S).items()}\n", + " )\n", + "\n", + " sorted_connections = sorted(connections.items(), key=_connections_sort_key)\n", + " all_connected_instances = {k: {k} for k in instances}\n", + "\n", + " for k, l in sorted_connections:\n", + " name1, _ = k.split(\",\")\n", + " name2, _ = l.split(\",\")\n", + "\n", + " connected_instances = (\n", + " all_connected_instances[name1] | all_connected_instances[name2]\n", + " )\n", + " for name in connected_instances:\n", + " all_connected_instances[name] = connected_instances\n", + "\n", + " current_ports = tuple(\n", + " p\n", + " for instance in connected_instances\n", + " for p in set([p for p, _ in block_diag] + [p for _, p in block_diag])\n", + " if p.startswith(f\"{instance},\")\n", + " )\n", + "\n", + " block_diag.update(_interconnect_ports(block_diag, current_ports, k, l))\n", + "\n", + " for i, j in list(block_diag.keys()):\n", + " is_connected = i == k or i == l or j == k or j == l\n", + " is_in_output_ports = i in reversed_ports and j in reversed_ports\n", + " if is_connected and not is_in_output_ports:\n", + " del block_diag[i, j] # we're no longer interested in these port combinations\n", + "\n", + " circuit_sdict: SDict = {\n", + " (reversed_ports[i], reversed_ports[j]): v\n", + " for (i, j), v in block_diag.items()\n", + " if i in reversed_ports and j in reversed_ports\n", + " }\n", + " return circuit_sdict\n", + "\n", + "\n", + "def _connections_sort_key(connection):\n", + " \"\"\"sort key for sorting a connection dictionary \"\"\"\n", + " part1, part2 = connection\n", + " name1, _ = part1.split(\",\")\n", + " name2, _ = part2.split(\",\")\n", + " return (min(name1, name2), max(name1, name2))\n", + "\n", + "\n", + "def _interconnect_ports(block_diag, current_ports, k, l):\n", + " \"\"\"interconnect two ports in a given model\n", + "\n", + " > Note: the interconnect algorithm is based on equation 6 of 'Filipsson, Gunnar. \n", + " \"A new general computer algorithm for S-matrix calculation of interconnected \n", + " multiports.\" 11th European Microwave Conference. 
IEEE, 1981.'\n", + " \"\"\"\n", + " current_block_diag = {}\n", + " for i in current_ports:\n", + " for j in current_ports:\n", + " vij = _calculate_interconnected_value(\n", + " vij=block_diag.get((i, j), 0.0),\n", + " vik=block_diag.get((i, k), 0.0),\n", + " vil=block_diag.get((i, l), 0.0),\n", + " vkj=block_diag.get((k, j), 0.0),\n", + " vkk=block_diag.get((k, k), 0.0),\n", + " vkl=block_diag.get((k, l), 0.0),\n", + " vlj=block_diag.get((l, j), 0.0),\n", + " vlk=block_diag.get((l, k), 0.0),\n", + " vll=block_diag.get((l, l), 0.0),\n", + " )\n", + " current_block_diag[i, j] = vij\n", + " return current_block_diag\n", + "\n", + "\n", + "@jax.jit\n", + "def _calculate_interconnected_value(vij, vik, vil, vkj, vkk, vkl, vlj, vlk, vll):\n", + " \"\"\"Calculate an interconnected S-parameter value\n", + "\n", + " Note:\n", + " The interconnect algorithm is based on equation 6 in the paper below::\n", + "\n", + " Filipsson, Gunnar. \"A new general computer algorithm for S-matrix calculation\n", + " of interconnected multiports.\" 11th European Microwave Conference. IEEE, 1981.\n", + " \"\"\"\n", + " result = vij + (\n", + " vkj * vil * (1 - vlk)\n", + " + vlj * vik * (1 - vkl)\n", + " + vkj * vll * vik\n", + " + vlj * vkk * vil\n", + " ) / ((1 - vkl) * (1 - vlk) - vkk * vll)\n", + " return result" + ] + }, + { + "cell_type": "markdown", + "id": "9dabee38", + "metadata": {}, + "source": [ + "## Example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ab7ea6c", + "metadata": {}, + "outputs": [], + "source": [ + "wg_sdict: SDict = {\n", + " (\"in0\", \"out0\"): 0.5 + 0.86603j,\n", + " (\"out0\", \"in0\"): 0.5 + 0.86603j,\n", + "}\n", + "\n", + "τ, κ = 0.5 ** 0.5, 1j * 0.5 ** 0.5\n", + "dc_sdense: SDense = (\n", + " jnp.array([[0, 0, τ, κ], \n", + " [0, 0, κ, τ], \n", + " [τ, κ, 0, 0], \n", + " [κ, τ, 0, 0]]),\n", + " {\"in0\": 0, \"in1\": 1, \"out0\": 2, \"out1\": 3},\n", + ")\n", + "\n", + "mzi_sdict: SDict = evaluate_circuit(\n", + " instances={\n", + " \"dc1\": dc_sdense,\n", + " \"wg\": wg_sdict,\n", + " \"dc2\": dc_sdense,\n", + " },\n", + " connections={\n", + " \"dc1,out0\": \"wg,in0\",\n", + " \"wg,out0\": \"dc2,in0\",\n", + " \"dc1,out1\": \"dc2,in1\",\n", + " },\n", + " ports={\n", + " \"in0\": \"dc1,in0\",\n", + " \"in1\": \"dc1,in1\",\n", + " \"out0\": \"dc2,out0\",\n", + " \"out1\": \"dc2,out1\",\n", + " }\n", + ")\n", + "\n", + "mzi_sdict" + ] + }, + { + "cell_type": "markdown", + "id": "48950a3d", + "metadata": {}, + "source": [ + "## Algorithm Walkthrough\n", + "\n", + "> Note: This algorithm gets pretty slow for large circuits. I'd be [very interested in any improvements](#Algorithm-Improvements) that can be made here, especially because - as opposed to the currently faster [KLU backend](backends_klu.html) - the algorithm discussed here is jittable, differentiable and can be used on GPUs." + ] + }, + { + "cell_type": "markdown", + "id": "81e28766", + "metadata": {}, + "source": [ + "Let's walk through all the steps of this algorithm. 
We'll do this for a simple MZI circuit, given by two directional couplers characterised by `dc_sdense` with a phase shifting waveguide `wg_sdict` in between:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e2bdfbf4", + "metadata": {}, + "outputs": [], + "source": [ + "instances={\n", + "    \"dc1\": dc_sdense,\n", + "    \"wg\": wg_sdict,\n", + "    \"dc2\": dc_sdense,\n", + "}\n", + "connections={\n", + "    \"dc1,out0\": \"wg,in0\",\n", + "    \"wg,out0\": \"dc2,in0\",\n", + "    \"dc1,out1\": \"dc2,in1\",\n", + "}\n", + "ports={\n", + "    \"in0\": \"dc1,in0\",\n", + "    \"in1\": \"dc1,in1\",\n", + "    \"out0\": \"dc2,out0\",\n", + "    \"out1\": \"dc2,out1\",\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "6ef72161", + "metadata": {}, + "source": [ + "As a first step, we construct `reversed_ports`: it's actually easier to work with this reversed mapping (we chose the opposite convention in the netlist definition to adhere to the GDSFactory netlist convention):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70302612", + "metadata": {}, + "outputs": [], + "source": [ + "reversed_ports = {v: k for k, v in ports.items()}" + ] + }, + { + "cell_type": "markdown", + "id": "78676045", + "metadata": {}, + "source": [ + "The first real step of the algorithm is to create the block-diagonal `SDict`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0ff6149f", + "metadata": {}, + "outputs": [], + "source": [ + "block_diag = {}\n", + "for name, S in instances.items():\n", + "    block_diag.update(\n", + "        {(f\"{name},{p1}\", f\"{name},{p2}\"): v for (p1, p2), v in sdict(S).items()}\n", + "    )\n" + ] + }, + { + "cell_type": "markdown", + "id": "c4721548", + "metadata": {}, + "source": [ + "We can optionally filter out zeros from the resulting block_diag representation. Just note that this will make the resulting function unjittable (the resulting 'shape' (i.e. keys) of the dictionary would depend on the data itself, which is not allowed in JAX jit). We're doing it here to avoid printing zeros but **internally this is not done by default**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cfce53c1", + "metadata": {}, + "outputs": [], + "source": [ + "block_diag = {k: v for k, v in block_diag.items() if jnp.abs(v) > 1e-10}\n", + "print(len(block_diag))\n", + "block_diag" + ] + }, + { + "cell_type": "markdown", + "id": "1963694a", + "metadata": {}, + "source": [ + "Next, we sort the connections such that similar components are grouped together:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9859c2d4", + "metadata": {}, + "outputs": [], + "source": [ + "sorted_connections = sorted(connections.items(), key=_connections_sort_key)\n", + "sorted_connections" + ] + }, + { + "cell_type": "markdown", + "id": "c75fbf65", + "metadata": {}, + "source": [ + "Now we iterate over the sorted connections and connect components as they come in. Connected components take over the name of the first component in the connection, but we keep a set of components belonging to that key in `all_connected_instances`.\n", + "\n", + "This is how the `all_connected_instances` dictionary looks initially."
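The `all_connected_instances` bookkeeping is essentially a naive union of sets: every time two instances get connected, all members of both groups end up pointing at the same merged set. Here is a compact, standalone sketch of just that bookkeeping (no S-parameters involved), using the same example connections as above, which may make the step-by-step walkthrough below easier to follow:

```python
# standalone illustration of the connected-instances bookkeeping only
example_instances = ["dc1", "wg", "dc2"]
example_connections = {
    "dc1,out0": "wg,in0",
    "wg,out0": "dc2,in0",
    "dc1,out1": "dc2,in1",
}
connected = {k: {k} for k in example_instances}
for k, l in example_connections.items():
    name1, _ = k.split(",")
    name2, _ = l.split(",")
    merged = connected[name1] | connected[name2]
    for name in merged:
        connected[name] = merged  # every member now points to the merged set
    print(f"after connecting {k} <-> {l}: {connected}")
```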
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cf30668c", + "metadata": {}, + "outputs": [], + "source": [ + "all_connected_instances = {k: {k} for k in instances}\n", + "all_connected_instances" + ] + }, + { + "cell_type": "markdown", + "id": "1152056b", + "metadata": {}, + "source": [ + "Normally we would loop over every connection in `sorted_connections` now, but let's just go through it once at first:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "faebbe5f", + "metadata": {}, + "outputs": [], + "source": [ + "# for k, l in sorted_connections:\n", + "k, l = sorted_connections[0]\n", + "k, l" + ] + }, + { + "cell_type": "markdown", + "id": "16a2266c", + "metadata": {}, + "source": [ + "`k` and `l` are the S-matrix indices we're trying to connect. Note that in our sparse `SDict` notation these S-matrix indices are in fact equivalent with the port names `('dc1,out1', 'dc2,in1')`!" + ] + }, + { + "cell_type": "markdown", + "id": "0eaf4add", + "metadata": {}, + "source": [ + "first we split the connection string into an instance name and a port name (we don't use the port name yet):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "667d1a3c", + "metadata": {}, + "outputs": [], + "source": [ + "name1, _ = k.split(\",\")\n", + "name2, _ = l.split(\",\")" + ] + }, + { + "cell_type": "markdown", + "id": "16e8bbe8", + "metadata": {}, + "source": [ + "We then obtain the new set of connected instances." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51bf8615", + "metadata": {}, + "outputs": [], + "source": [ + "connected_instances = all_connected_instances[name1] | all_connected_instances[name2]\n", + "connected_instances" + ] + }, + { + "cell_type": "markdown", + "id": "676c8d61", + "metadata": {}, + "source": [ + "We then iterate over each of the components in this set and make sure each of the component names in that set maps to that set (yes, I know... confusing). We do this to be able to keep track with which components each of the components in the circuit is currently already connected to." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d322e86c", + "metadata": {}, + "outputs": [], + "source": [ + "for name in connected_instances:\n", + " all_connected_instances[name] = connected_instances\n", + " \n", + "all_connected_instances" + ] + }, + { + "cell_type": "markdown", + "id": "b79f49f8", + "metadata": {}, + "source": [ + "now we need to obtain all the ports of the currently connected instances." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61b3cfab", + "metadata": {}, + "outputs": [], + "source": [ + "current_ports = tuple(\n", + " p\n", + " for instance in connected_instances\n", + " for p in set([p for p, _ in block_diag] + [p for _, p in block_diag])\n", + " if p.startswith(f\"{instance},\")\n", + ")\n", + "\n", + "current_ports" + ] + }, + { + "cell_type": "markdown", + "id": "d0e87449", + "metadata": {}, + "source": [ + "Now the [Gunnar Algorithm](#citation) is used. Given a (block-diagonal) 'S-matrix' `block_diag` and a 'connection matrix' `current_ports` we can interconnect port `k` and `l` as follows:\n", + "\n", + "> Note: some creative freedom is used here. In SAX, the matrices we're talking about are in fact represented by a sparse dictionary (an `SDict`), i.e. similar to a COO sparse matrix for which the indices are the port names." 
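For reference, the update implemented by `_calculate_interconnected_value` in the next cell can be written out explicitly. This is simply a transcription of the code into math (with $S$ playing the role of `block_diag` and $S'$ the updated matrix), corresponding to equation 6 of the Filipsson paper cited above:

$$
S'_{ij} = S_{ij} + \frac{S_{kj} S_{il} (1 - S_{lk}) + S_{lj} S_{ik} (1 - S_{kl}) + S_{kj} S_{ll} S_{ik} + S_{lj} S_{kk} S_{il}}{(1 - S_{kl})(1 - S_{lk}) - S_{kk} S_{ll}}
$$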
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "431649e5", + "metadata": {}, + "outputs": [], + "source": [ + "def _interconnect_ports(block_diag, current_ports, k, l):\n", + " current_block_diag = {}\n", + " for i in current_ports:\n", + " for j in current_ports:\n", + " vij = _calculate_interconnected_value(\n", + " vij=block_diag.get((i, j), 0.0),\n", + " vik=block_diag.get((i, k), 0.0),\n", + " vil=block_diag.get((i, l), 0.0),\n", + " vkj=block_diag.get((k, j), 0.0),\n", + " vkk=block_diag.get((k, k), 0.0),\n", + " vkl=block_diag.get((k, l), 0.0),\n", + " vlj=block_diag.get((l, j), 0.0),\n", + " vlk=block_diag.get((l, k), 0.0),\n", + " vll=block_diag.get((l, l), 0.0),\n", + " )\n", + " current_block_diag[i, j] = vij\n", + " return current_block_diag\n", + "\n", + "@jax.jit\n", + "def _calculate_interconnected_value(vij, vik, vil, vkj, vkk, vkl, vlj, vlk, vll):\n", + " result = vij + (\n", + " vkj * vil * (1 - vlk)\n", + " + vlj * vik * (1 - vkl)\n", + " + vkj * vll * vik\n", + " + vlj * vkk * vil\n", + " ) / ((1 - vkl) * (1 - vlk) - vkk * vll)\n", + " return result\n", + "\n", + "block_diag.update(_interconnect_ports(block_diag, current_ports, k, l))" + ] + }, + { + "cell_type": "markdown", + "id": "988c08d9", + "metadata": {}, + "source": [ + "Just as before, we're filtering the zeros from the sparse representation (remember, internally this is **not done by default**)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9ac8164", + "metadata": {}, + "outputs": [], + "source": [ + "block_diag = {k: v for k, v in block_diag.items() if jnp.abs(v) > 1e-10}\n", + "print(len(block_diag))\n", + "block_diag" + ] + }, + { + "cell_type": "markdown", + "id": "db0cd151", + "metadata": {}, + "source": [ + "This is the resulting block-diagonal matrix after interconnecting two ports (i.e. basically saying that those two ports are the same port). Because these ports are now connected we should actually remove them from the S-matrix representation (they are integrated into the S-parameters of the other connections):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "311b034f", + "metadata": {}, + "outputs": [], + "source": [ + "for i, j in list(block_diag.keys()):\n", + " is_connected = i == k or i == l or j == k or j == l\n", + " is_in_output_ports = i in reversed_ports and j in reversed_ports\n", + " if is_connected and not is_in_output_ports:\n", + " del block_diag[i, j] # we're no longer interested in these port combinations\n", + " \n", + "print(len(block_diag))\n", + "block_diag" + ] + }, + { + "cell_type": "markdown", + "id": "b6badcb7", + "metadata": {}, + "source": [ + "Note that this deletion of values **does NOT** make this operation un-jittable. The deletion depends on the ports of the dictionary (i.e. on the dictionary 'shape'), not on the values." 
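A minimal illustration of that point, outside of SAX (the function and the example dictionary below are made up for this demonstration): removing entries based on their keys is static Python-level information, so `jax.jit` has no problem with it, whereas filtering on the traced values would raise a tracer error:

```python
import jax
import jax.numpy as jnp


@jax.jit
def drop_wg_entries(block_diag_like):
    # the keys are static python data, so key-based filtering is fine under jit;
    # filtering on the *values* (e.g. jnp.abs(v) > 1e-10) would fail here
    return {
        (i, j): v
        for (i, j), v in block_diag_like.items()
        if not (i.startswith("wg") or j.startswith("wg"))
    }


drop_wg_entries({
    ("dc1,in0", "dc1,out0"): jnp.array(0.5 + 0.5j),
    ("wg,in0", "wg,out0"): jnp.array(0.5 + 0.86603j),
})
```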
+ ] + }, + { + "cell_type": "markdown", + "id": "89e775cd", + "metadata": {}, + "source": [ + "We now basically have to do those steps again for all other connections:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25eac933", + "metadata": {}, + "outputs": [], + "source": [ + "#for k, l in sorted_connections: \n", + "for k, l in sorted_connections[1:]: # we just did the first iteration of this loop above...\n", + " name1, _ = k.split(\",\")\n", + " name2, _ = l.split(\",\")\n", + " connected_instances = all_connected_instances[name1] | all_connected_instances[name2]\n", + " for name in connected_instances:\n", + " all_connected_instances[name] = connected_instances\n", + " current_ports = tuple(\n", + " p\n", + " for instance in connected_instances\n", + " for p in set([p for p, _ in block_diag] + [p for _, p in block_diag])\n", + " if p.startswith(f\"{instance},\")\n", + " )\n", + " block_diag.update(_interconnect_ports(block_diag, current_ports, k, l))\n", + " for i, j in list(block_diag.keys()):\n", + " is_connected = i == k or i == l or j == k or j == l\n", + " is_in_output_ports = i in reversed_ports and j in reversed_ports\n", + " if is_connected and not is_in_output_ports:\n", + " del block_diag[i, j] # we're no longer interested in these port combinations" + ] + }, + { + "cell_type": "markdown", + "id": "36eb82fc", + "metadata": {}, + "source": [ + "This is the final MZI matrix we're getting:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1c0367e", + "metadata": {}, + "outputs": [], + "source": [ + "block_diag" + ] + }, + { + "cell_type": "markdown", + "id": "03f67ec6", + "metadata": {}, + "source": [ + "All that's left is to rename these internal ports of the format `{instance},{port}` into output ports of the resulting circuit:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cf5ea4af", + "metadata": {}, + "outputs": [], + "source": [ + "circuit_sdict: SDict = {\n", + " (reversed_ports[i], reversed_ports[j]): v\n", + " for (i, j), v in block_diag.items()\n", + " if i in reversed_ports and j in reversed_ports\n", + "}\n", + "circuit_sdict" + ] + }, + { + "cell_type": "markdown", + "id": "7dc6505e", + "metadata": {}, + "source": [ + "And that's it. We evaluated the `SDict` of the full circuit." + ] + }, + { + "cell_type": "markdown", + "id": "170db10b", + "metadata": {}, + "source": [ + "## Algorithm Improvements" + ] + }, + { + "cell_type": "markdown", + "id": "4ed0e69a", + "metadata": {}, + "source": [ + "This algorithm is \n", + "\n", + "* pretty fast for small circuits 🙂\n", + "* jittable 🙂\n", + "* differentiable 🙂\n", + "* GPU-compatible 🙂\n", + "\n", + "This algorithm is however:\n", + "\n", + "* **really slow** for large circuits 😥\n", + "* **pretty slow** to jit the resulting circuit function 😥\n", + "* **pretty slow** to differentiate the resulting circuit function 😥\n", + "\n", + "There are probably still plenty of improvements possible for this algorithm:\n", + "\n", + "* **¿** Network analysis (ft. NetworkX ?) to obtain which ports of the block diagonal representation are relevant to obtain the output connection **?**\n", + "* **¿** Smarter ordering of connections to always have the minimum amount of ports in the intermediate block-diagonal representation **?**\n", + "* **¿** Using `jax.lax.scan` in stead of python native for-loops in `_interconnect_ports` **?**\n", + "* **¿** ... **?**\n", + "\n", + "Bottom line is... 
Do you know how to improve this algorithm or how to implement the above suggestions? Please open a Merge Request!" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/07b_backends_klu.ipynb b/nbs/07b_backends_klu.ipynb new file mode 100644 index 0000000..809f05d --- /dev/null +++ b/nbs/07b_backends_klu.ipynb @@ -0,0 +1,1221 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "1db95e7a", + "metadata": {}, + "outputs": [], + "source": [ + "# default_exp backends.klu" + ] + }, + { + "cell_type": "markdown", + "id": "27c3489e", + "metadata": {}, + "source": [ + "# Backend - KLU\n", + "\n", + "> SAX KLU Backend" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19542d05", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "import sax\n", + "import matplotlib.pyplot as plt\n", + "from fastcore.test import test_eq\n", + "from pytest import approx, raises\n", + "from nbdev import show_doc\n", + "\n", + "import os, sys; sys.stderr = open(os.devnull, \"w\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12cf5743", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "from __future__ import annotations\n", + "\n", + "from typing import Dict\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "from sax.typing_ import SDense, SDict, SType, scoo\n", + "\n", + "try:\n", + " import klujax\n", + "except ImportError:\n", + " klujax = None" + ] + }, + { + "cell_type": "markdown", + "id": "64e42767", + "metadata": {}, + "source": [ + "## Citation\n", + "The KLU backend is using `klujax`, which uses the [SuiteSparse](https://github.com/DrTimothyAldenDavis/SuiteSparse) C++ libraries for sparse matrix evaluations to evaluate the circuit insanely fast on a CPU. The specific algorith being used in question is the KLU algorithm:\n", + "\n", + "> Ekanathan Palamadai Natariajan. \"*KLU - A high performance sparse linear solver for circuit simulation problems.*\"" + ] + }, + { + "cell_type": "markdown", + "id": "2db5821c", + "metadata": {}, + "source": [ + "## Theoretical Background" + ] + }, + { + "cell_type": "markdown", + "id": "3d4d4a4c", + "metadata": {}, + "source": [ + "The core of the KLU algorithm is supported by `klujax`, which internally uses the Suitesparse libraries to solve the sparse system `Ax = b`, in which A is a sparse matrix." + ] + }, + { + "cell_type": "markdown", + "id": "a1985f58", + "metadata": {}, + "source": [ + "Now it only comes down to shoehorn our circuit evaluation into a sparse linear system of equations $Ax=b$ where we need to solve for $x$ using `klujax`. \n", + "Consider the block diagonal matrix $S_{bd}$ of all components in the circuit acting on the fields $x_{in}$ at each of the individual ports of each of the component integrated in $S^{bd}$. The output fields $x^{out}$ at each of those ports is then given by:\n", + "\n", + "$$\n", + "x^{out} = S_{bd} x^{in}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "9c26c04a", + "metadata": {}, + "source": [ + "However, $S_{bd}$ is not the S-matrix of the circuit as it does not encode any connectivity *between* the components. Connecting two component ports basically comes down to enforcing equality between the output fields at one port of a component with the input fields at another port of another (or maybe even the same) component. 
This equality can be enforced by creating an internal connection matrix, connecting all internal ports of the circuit:\n", + "\n", + "$$\n", + "x^{in} = C_{int} x^{out}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "16b3fdad", + "metadata": {}, + "source": [ + "We can thus write the following combined equation:\n", + "\n", + "$$\n", + "x^{in} = C_{int} S_{bd} x^{in}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "f5ceb9a2", + "metadata": {}, + "source": [ + "But this is not the complete story... Some component ports will *not* be *interconnected* with other ports: they will become the new *external ports* (or output ports) of the combined circuit. We can include those external ports into the above equation as follows:\n", + "\n", + "$$\n", + "\\begin{pmatrix} x^{in} \\\\ x^{out}_{ext} \\end{pmatrix} = \\begin{pmatrix} C_{int} & C_{ext} \\\\ C_{ext}^T & 0 \\end{pmatrix} \\begin{pmatrix} S_{bd} x^{in} \\\\ x_{ext}^{in} \\end{pmatrix} \n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "6382a242", + "metadata": {}, + "source": [ + "Note that $C_{ext}$ is obviously **not** a square matrix. Eliminating $x^{in}$ from the equation above finally yields:\n", + "\n", + "$$\n", + "x^{out}_{ext} = C^T_{ext} S_{bd} (\\mathbb{1} - C_{int}S_{bd})^{-1} C_{ext}x_{ext}^{in}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "5f15bdcd", + "metadata": {}, + "source": [ + "We basically found a representation of the circuit S-matrix:\n", + "\n", + "$$\n", + "S = C^T_{ext} S_{bd} (\\mathbb{1} - C_{int}S_{bd})^{-1} C_{ext}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "40ad83b8", + "metadata": {}, + "source": [ + "Obviously, we won't want to calculate the inverse $(\\mathbb{1} - C_{int}S_{bd})^{-1}$, which is the inverse of a very sparse matrix (a connection matrix only has a single 1 per line), which very often is not even sparse itself. In stead we'll use the `solve_klu` function:\n", + "\n", + "$$\n", + "S = C^T_{ext} S_{bd} \\texttt{solve}\\_\\texttt{klu}\\left((\\mathbb{1} - C_{int}S_{bd}), C_{ext}\\right)\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "018fab3d", + "metadata": {}, + "source": [ + "Moreover, $C_{ext}^TS_{bd}$ is also a sparse matrix, therefore we'll also need a `mul_coo` routine:\n", + "\n", + "$$\n", + "S = C^T_{ext} \\texttt{mul}\\_\\texttt{coo}\\left(S_{bd},~~\\texttt{solve}\\_\\texttt{klu}\\left((\\mathbb{1} - C_{int}S_{bd}),~C_{ext}\\right)\\right)\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "d9b78852", + "metadata": {}, + "source": [ + "## Sparse Helper Functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "521638fe", + "metadata": {}, + "outputs": [], + "source": [ + "# hide_input\n", + "show_doc(klujax.solve, doc_string=False, name=\"klujax.solve\")" + ] + }, + { + "cell_type": "markdown", + "id": "01299d5d", + "metadata": {}, + "source": [ + "`klujax.solve` solves the sparse system of equations `Ax=b` for `x`. 
Where `A` is represented by in [COO-format](https://en.wikipedia.org/wiki/Sparse_matrix#Coordinate_list_(COO)) as (`Ai`, `Aj`, `Ax`).\n", + "\n", + "> Example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63e7977a", + "metadata": {}, + "outputs": [], + "source": [ + "Ai = jnp.array([0, 1, 2, 3, 4])\n", + "Aj = jnp.array([1, 3, 4, 0, 2])\n", + "Ax = jnp.array([5, 6, 1, 1, 2])\n", + "b = jnp.array([5, 3, 2, 6, 1])\n", + "\n", + "x = klujax.solve(Ai, Aj, Ax, b)\n", + "x" + ] + }, + { + "cell_type": "markdown", + "id": "3233b04c", + "metadata": {}, + "source": [ + "This result is indeed correct:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "629bbc71", + "metadata": {}, + "outputs": [], + "source": [ + "A = jnp.zeros((5, 5)).at[Ai, Aj].set(Ax)\n", + "print(A)\n", + "print(A@x)" + ] + }, + { + "cell_type": "markdown", + "id": "c7fe6547", + "metadata": {}, + "source": [ + "However, to use this function effectively, we probably need an extra dimension for `Ax`. Indeed, we would like to solve this equation for multiple wavelengths (or more general, for multiple circuit configurations) at once. For this we can use `jax.vmap` to expose `klujax.solve` to more dimensions for `Ax`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6f9686a0", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "solve_klu = None\n", + "if klujax is not None:\n", + " solve_klu = jax.vmap(klujax.solve, (None, None, 0, None), 0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0fcc3e1", + "metadata": {}, + "outputs": [], + "source": [ + "# hide_input\n", + "show_doc(solve_klu, doc_string=False, name=\"solve_klu\")" + ] + }, + { + "cell_type": "markdown", + "id": "6707049d", + "metadata": {}, + "source": [ + "Let's now redefine `Ax` and see what it gives:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d4d7b6f9", + "metadata": {}, + "outputs": [], + "source": [ + "Ai = jnp.array([0, 1, 2, 3, 4])\n", + "Aj = jnp.array([1, 3, 4, 0, 2])\n", + "Ax = jnp.array([[5, 6, 1, 1, 2], \n", + " [5, 4, 3, 2, 1], \n", + " [1, 2, 3, 4, 5]])\n", + "b = jnp.array([5, 3, 2, 6, 1])\n", + "\n", + "x = solve_klu(Ai, Aj, Ax, b)\n", + "x" + ] + }, + { + "cell_type": "markdown", + "id": "5f342c6f", + "metadata": {}, + "source": [ + "This result is indeed correct:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0afefaa6", + "metadata": {}, + "outputs": [], + "source": [ + "A = jnp.zeros((3, 5, 5)).at[:, Ai, Aj].set(Ax)\n", + "jnp.einsum(\"ijk,ik->ij\", A, x)" + ] + }, + { + "cell_type": "markdown", + "id": "7ed8f62d", + "metadata": {}, + "source": [ + "Additionally, we need a way to multiply a sparse COO-matrix with a dense vector. 
This can be done with `klujax.coo_mul_vec`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "606b3e20", + "metadata": {}, + "outputs": [], + "source": [ + "# hide_input\n", + "\n", + "show_doc(klujax.coo_mul_vec, doc_string=False, name=\"klujax.coo_mul_vec\")" + ] + }, + { + "cell_type": "markdown", + "id": "7f57c443", + "metadata": {}, + "source": [ + "However, it's useful to allow a batch dimension, this time *both* in `Ax` and in `b`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5f11c4df", + "metadata": {}, + "outputs": [], + "source": [ + "# exporti\n", + "\n", + "# @jax.jit # TODO: make this available to autograd\n", + "# def mul_coo(Ai, Aj, Ax, b):\n", + "# result = jnp.zeros_like(b).at[..., Ai, :].add(Ax[..., :, None] * b[..., Aj, :])\n", + "# return result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f2f56966", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "mul_coo = None \n", + "if klujax is not None:\n", + " mul_coo = jax.vmap(klujax.coo_mul_vec, (None, None, 0, 0), 0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b3e5e645", + "metadata": {}, + "outputs": [], + "source": [ + "# hide_input\n", + "show_doc(mul_coo, doc_string=False, name=\"mul_coo\")" + ] + }, + { + "cell_type": "markdown", + "id": "d26cc3d9", + "metadata": {}, + "source": [ + "Let's confirm this does the right thing:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bd868973", + "metadata": {}, + "outputs": [], + "source": [ + "mul_coo(Ai, Aj, Ax, x)" + ] + }, + { + "cell_type": "markdown", + "id": "d71e89fa", + "metadata": {}, + "source": [ + "## Circuit Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e77e2026", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def evaluate_circuit_klu(\n", + " instances: Dict[str, SType],\n", + " connections: Dict[str, str],\n", + " ports: Dict[str, str],\n", + "):\n", + " \"\"\"evaluate a circuit using KLU for the given sdicts. \"\"\"\n", + "\n", + " if klujax is None:\n", + " raise ImportError(\n", + " \"Could not import 'klujax'. 
\"\n", + " \"Please install it first before using backend method 'klu'\"\n", + " )\n", + "\n", + " assert solve_klu is not None\n", + " assert mul_coo is not None\n", + "\n", + " connections = {**connections, **{v: k for k, v in connections.items()}}\n", + " inverse_ports = {v: k for k, v in ports.items()}\n", + " port_map = {k: i for i, k in enumerate(ports)}\n", + "\n", + " idx, Si, Sj, Sx, instance_ports = 0, [], [], [], {}\n", + " batch_shape = ()\n", + " for name, instance in instances.items():\n", + " si, sj, sx, ports_map = scoo(instance)\n", + " Si.append(si + idx)\n", + " Sj.append(sj + idx)\n", + " Sx.append(sx)\n", + " if len(sx.shape[:-1]) > len(batch_shape):\n", + " batch_shape = sx.shape[:-1]\n", + " instance_ports.update({f\"{name},{p}\": i + idx for p, i in ports_map.items()})\n", + " idx += len(ports_map)\n", + "\n", + " Si = jnp.concatenate(Si, -1)\n", + " Sj = jnp.concatenate(Sj, -1)\n", + " Sx = jnp.concatenate(\n", + " [jnp.broadcast_to(sx, (*batch_shape, sx.shape[-1])) for sx in Sx], -1\n", + " )\n", + "\n", + " n_col = idx\n", + " n_rhs = len(port_map)\n", + "\n", + " Cmap = {\n", + " int(instance_ports[k]): int(instance_ports[v]) for k, v in connections.items()\n", + " }\n", + " Ci = jnp.array(list(Cmap.keys()), dtype=jnp.int32)\n", + " Cj = jnp.array(list(Cmap.values()), dtype=jnp.int32)\n", + "\n", + " Cextmap = {int(instance_ports[k]): int(port_map[v]) for k, v in inverse_ports.items()}\n", + " Cexti = jnp.stack(list(Cextmap.keys()), 0)\n", + " Cextj = jnp.stack(list(Cextmap.values()), 0)\n", + " Cext = jnp.zeros((n_col, n_rhs), dtype=complex).at[Cexti, Cextj].set(1.0)\n", + "\n", + " # TODO: make this block jittable...\n", + " Ix = jnp.ones((*batch_shape, n_col))\n", + " Ii = Ij = jnp.arange(n_col)\n", + " mask = Cj[None,:] == Si[:, None]\n", + " CSi = jnp.broadcast_to(Ci[None, :], mask.shape)[mask]\n", + "\n", + " # CSi = jnp.where(Cj[None, :] == Si[:, None], Ci[None, :], 0).sum(1)\n", + " mask = (Cj[:, None] == Si[None, :]).any(0)\n", + " CSj = Sj[mask]\n", + " \n", + " if Sx.ndim > 1: # bug in JAX... 
see https://github.com/google/jax/issues/9050\n", + " CSx = Sx[..., mask]\n", + " else:\n", + " CSx = Sx[mask]\n", + " \n", + " # CSj = jnp.where(mask, Sj, 0)\n", + " # CSx = jnp.where(mask, Sx, 0.0)\n", + "\n", + " I_CSi = jnp.concatenate([CSi, Ii], -1)\n", + " I_CSj = jnp.concatenate([CSj, Ij], -1)\n", + " I_CSx = jnp.concatenate([-CSx, Ix], -1)\n", + "\n", + " n_col, n_rhs = Cext.shape\n", + " n_lhs = jnp.prod(jnp.array(batch_shape, dtype=jnp.int32))\n", + " Sx = Sx.reshape(n_lhs, -1)\n", + " I_CSx = I_CSx.reshape(n_lhs, -1)\n", + "\n", + " inv_I_CS_Cext = solve_klu(I_CSi, I_CSj, I_CSx, Cext)\n", + " S_inv_I_CS_Cext = mul_coo(Si, Sj, Sx, inv_I_CS_Cext)\n", + "\n", + " CextT_S_inv_I_CS_Cext = S_inv_I_CS_Cext[..., Cexti, :][..., :, Cextj]\n", + " \n", + " _, n, _ = CextT_S_inv_I_CS_Cext.shape\n", + " S = CextT_S_inv_I_CS_Cext.reshape(*batch_shape, n, n)\n", + "\n", + " return S, port_map" + ] + }, + { + "cell_type": "markdown", + "id": "dce9d8ca", + "metadata": {}, + "source": [ + "## Example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f577522a", + "metadata": {}, + "outputs": [], + "source": [ + "wg_sdict: SDict = {\n", + " (\"in0\", \"out0\"): 0.5 + 0.86603j,\n", + " (\"out0\", \"in0\"): 0.5 + 0.86603j,\n", + "}\n", + "\n", + "τ, κ = 0.5 ** 0.5, 1j * 0.5 ** 0.5\n", + "dc_sdense: SDense = (\n", + " jnp.array([[0, 0, τ, κ], \n", + " [0, 0, κ, τ], \n", + " [τ, κ, 0, 0], \n", + " [κ, τ, 0, 0]]),\n", + " {\"in0\": 0, \"in1\": 1, \"out0\": 2, \"out1\": 3},\n", + ")\n", + "\n", + "mzi_sdense: SDense = evaluate_circuit_klu(\n", + " instances={\n", + " \"dc1\": dc_sdense,\n", + " \"wg\": wg_sdict,\n", + " \"dc2\": dc_sdense,\n", + " },\n", + " connections={\n", + " \"dc1,out0\": \"wg,in0\",\n", + " \"wg,out0\": \"dc2,in0\",\n", + " \"dc1,out1\": \"dc2,in1\",\n", + " },\n", + " ports={\n", + " \"in0\": \"dc1,in0\",\n", + " \"in1\": \"dc1,in1\",\n", + " \"out0\": \"dc2,out0\",\n", + " \"out1\": \"dc2,out1\",\n", + " }\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "60b6d1e6", + "metadata": {}, + "source": [ + "the KLU backend yields `SDense` results by default:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ecc3115e", + "metadata": {}, + "outputs": [], + "source": [ + "mzi_sdense" + ] + }, + { + "cell_type": "markdown", + "id": "4b8e87f3", + "metadata": {}, + "source": [ + "An `SDense` is returned for perfomance reasons. By returning an `SDense` by default we prevent any internal `SDict -> SDense` conversions in deeply hierarchical circuits. It's however very easy to convert `SDense` to `SDict` as a final step. 
To do this, wrap the result (or the function generating the result) with `sdict`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7202ab1e", + "metadata": {}, + "outputs": [], + "source": [ + "sax.sdict(mzi_sdense)" + ] + }, + { + "cell_type": "markdown", + "id": "9f8cd5ea", + "metadata": {}, + "source": [ + "## Algorithm Walkthrough" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "665c856d", + "metadata": {}, + "outputs": [], + "source": [ + "instances={\n", + " \"dc1\": dc_sdense,\n", + " \"wg\": wg_sdict,\n", + " \"dc2\": dc_sdense,\n", + "}\n", + "connections={\n", + " \"dc1,out0\": \"wg,in0\",\n", + " \"wg,out0\": \"dc2,in0\",\n", + " \"dc1,out1\": \"dc2,in1\",\n", + "}\n", + "ports={\n", + " \"in0\": \"dc1,in0\",\n", + " \"in1\": \"dc1,in1\",\n", + " \"out0\": \"dc2,out0\",\n", + " \"out1\": \"dc2,out1\",\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "b5bf5188", + "metadata": {}, + "source": [ + "Let's first enforce $C^T = C$:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d7d4c39", + "metadata": {}, + "outputs": [], + "source": [ + "connections = {**connections, **{v: k for k, v in connections.items()}}\n", + "connections" + ] + }, + { + "cell_type": "markdown", + "id": "a41250ec", + "metadata": {}, + "source": [ + "We'll also need the reversed ports:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9a70463f", + "metadata": {}, + "outputs": [], + "source": [ + "inverse_ports = {v: k for k, v in ports.items()}\n", + "inverse_ports" + ] + }, + { + "cell_type": "markdown", + "id": "1e96443b", + "metadata": {}, + "source": [ + "An the port indices" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e18d182", + "metadata": {}, + "outputs": [], + "source": [ + "port_map = {k: i for i, k in enumerate(ports)}\n", + "port_map" + ] + }, + { + "cell_type": "markdown", + "id": "70924074", + "metadata": {}, + "source": [ + "Let's now create the COO-representation of our block diagonal S-matrix $S_{bd}$:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "835deb8f", + "metadata": {}, + "outputs": [], + "source": [ + "idx, Si, Sj, Sx, instance_ports = 0, [], [], [], {}\n", + "batch_shape = ()\n", + "for name, instance in instances.items():\n", + " si, sj, sx, ports_map = scoo(instance)\n", + " Si.append(si + idx)\n", + " Sj.append(sj + idx)\n", + " Sx.append(sx)\n", + " if len(sx.shape[:-1]) > len(batch_shape):\n", + " batch_shape = sx.shape[:-1]\n", + " instance_ports.update({f\"{name},{p}\": i + idx for p, i in ports_map.items()})\n", + " idx += len(ports_map)\n", + "Si = jnp.concatenate(Si, -1)\n", + "Sj = jnp.concatenate(Sj, -1)\n", + "Sx = jnp.concatenate([jnp.broadcast_to(sx, (*batch_shape, sx.shape[-1])) for sx in Sx], -1)\n", + "\n", + "print(Si)\n", + "print(Sj)\n", + "print(Sx)" + ] + }, + { + "cell_type": "markdown", + "id": "a8edefa1", + "metadata": {}, + "source": [ + "note that we also kept track of the `batch_shape`, i.e. the number of independent simulations (usually number of wavelengths). 
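As a side note (a sketch, not part of the original walkthrough): a wavelength-swept component would make this batch dimension non-trivial. For instance, a waveguide whose S-parameters are 1D arrays over three wavelengths should end up with `batch_shape == (3,)` in the loop above; the phase values below are purely illustrative.

```python
wl = jnp.array([1.50, 1.55, 1.60])       # three wavelengths (illustrative)
phase = 2 * jnp.pi * 2.34 * 10.0 / wl    # assumed neff=2.34, length=10 (same units as wl)
wg_swept: SDict = {
    ("in0", "out0"): jnp.exp(1j * phase),
    ("out0", "in0"): jnp.exp(1j * phase),
}
si, sj, sx, pm = scoo(wg_swept)
print(sx.shape)  # expected: (3, 2) -> batch of 3, two non-zero S-parameters
```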
In the example being used here we don't have a batch dimension (all elements of the `SDict` are `0D`):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7513759a", + "metadata": {}, + "outputs": [], + "source": [ + "batch_shape" + ] + }, + { + "cell_type": "markdown", + "id": "c63b14d7", + "metadata": {}, + "source": [ + "We'll also keep track of the number of columns" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89000432", + "metadata": {}, + "outputs": [], + "source": [ + "n_col = idx\n", + "n_col" + ] + }, + { + "cell_type": "markdown", + "id": "7b0f95da", + "metadata": {}, + "source": [ + "And we'll need to solve the circuit for each output port, i.e. we need to solve `n_rhs` number of equations:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27902354", + "metadata": {}, + "outputs": [], + "source": [ + "n_rhs = len(port_map)\n", + "n_rhs" + ] + }, + { + "cell_type": "markdown", + "id": "05056b46", + "metadata": {}, + "source": [ + "We can represent the internal connection matrix $C_{int}$ as a mapping between port indices:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f3ff1fa7", + "metadata": {}, + "outputs": [], + "source": [ + "Cmap = {int(instance_ports[k]): int(instance_ports[v]) for k, v in connections.items()}\n", + "Cmap" + ] + }, + { + "cell_type": "markdown", + "id": "999164af", + "metadata": {}, + "source": [ + "Therefore, the COO-representation of this connection matrix can be obtained as follows (note that an array of values Cx is not necessary, all non-zero elements in a connection matrix are 1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e4f1728", + "metadata": {}, + "outputs": [], + "source": [ + "Ci = jnp.array(list(Cmap.keys()), dtype=jnp.int32)\n", + "Cj = jnp.array(list(Cmap.values()), dtype=jnp.int32)\n", + "print(Ci)\n", + "print(Cj)" + ] + }, + { + "cell_type": "markdown", + "id": "47d5ed60", + "metadata": {}, + "source": [ + "We can represent the external connection matrix $C_{ext}$ as a map between internal port indices and external port indices:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f56ae166", + "metadata": {}, + "outputs": [], + "source": [ + "Cextmap = {int(instance_ports[k]): int(port_map[v]) for k, v in inverse_ports.items()}\n", + "Cextmap" + ] + }, + { + "cell_type": "markdown", + "id": "2485c01d", + "metadata": {}, + "source": [ + "Just as for the internal matrix we can represent this external connection matrix in COO-format:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cfb419ac", + "metadata": {}, + "outputs": [], + "source": [ + "Cexti = jnp.stack(list(Cextmap.keys()), 0)\n", + "Cextj = jnp.stack(list(Cextmap.values()), 0)\n", + "print(Cexti)\n", + "print(Cextj)" + ] + }, + { + "cell_type": "markdown", + "id": "e46d609c", + "metadata": {}, + "source": [ + "However, we actually need it as a dense representation:\n", + "\n", + "> help needed: can we find a way later on to keep this sparse?" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25c4ec6e", + "metadata": {}, + "outputs": [], + "source": [ + "Cext = jnp.zeros((n_col, n_rhs), dtype=complex).at[Cexti, Cextj].set(1.0)\n", + "Cext" + ] + }, + { + "cell_type": "markdown", + "id": "b31cc279", + "metadata": {}, + "source": [ + "We'll now calculate the row index `CSi` of $C_{int}S_{bd}$ in COO-format:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a7e87d1", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: make this block jittable...\n", + "Ix = jnp.ones((*batch_shape, n_col))\n", + "Ii = Ij = jnp.arange(n_col)\n", + "mask = Cj[None,:] == Si[:, None]\n", + "CSi = jnp.broadcast_to(Ci[None, :], mask.shape)[mask]\n", + "CSi" + ] + }, + { + "cell_type": "markdown", + "id": "a2657582", + "metadata": {}, + "source": [ + "> `CSi`: possible jittable alternative? how do we remove the zeros?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d4ea2cb5", + "metadata": {}, + "outputs": [], + "source": [ + "CSi_ = jnp.where(Cj[None, :] == Si[:, None], Ci[None, :], 0).sum(1) # not used\n", + "CSi_ # not used" + ] + }, + { + "cell_type": "markdown", + "id": "44fc620e", + "metadata": {}, + "source": [ + "The column index `CSj` of $C_{int}S_{bd}$ can more easily be obtained:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f71fdd7c", + "metadata": {}, + "outputs": [], + "source": [ + "mask = (Cj[:, None] == Si[None, :]).any(0)\n", + "CSj = Sj[mask]\n", + "CSj" + ] + }, + { + "cell_type": "markdown", + "id": "99bf7684", + "metadata": {}, + "source": [ + "> `CSj`: possible jittable alternative? how do we remove the zeros?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffa7bc82", + "metadata": {}, + "outputs": [], + "source": [ + "CSj_ = jnp.where(mask, Sj, 0) # not used\n", + "CSj_ # not used" + ] + }, + { + "cell_type": "markdown", + "id": "e0fc6d02", + "metadata": {}, + "source": [ + "Finally, the values `CSx` of $C_{int}S_{bd}$ can be obtained as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59d2e1b2", + "metadata": {}, + "outputs": [], + "source": [ + "if Sx.ndim > 1:\n", + " CSx = Sx[..., mask] # normally this should be enough\n", + "else:\n", + " CSx = Sx[mask] # need separate case bc bug in JAX... see https://github.com/google/jax/issues/9050\n", + " \n", + "CSx" + ] + }, + { + "cell_type": "markdown", + "id": "c02035bd", + "metadata": {}, + "source": [ + "> `CSx`: possible jittable alternative? how do we remove the zeros?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "107e9e80", + "metadata": {}, + "outputs": [], + "source": [ + "CSx_ = jnp.where(mask, Sx, 0.0) # not used\n", + "CSx_ # not used" + ] + }, + { + "cell_type": "markdown", + "id": "9411edf5", + "metadata": {}, + "source": [ + "Now we calculate $\\mathbb{1} - C_{int}S_{bd}$ in an *uncoalesced* way (we might have duplicate indices on the diagonal):\n", + "\n", + "> **uncoalesced**: having duplicate index combinations (i, j) in the representation possibly with different corresponding values. 
This is usually not a problem as in linear operations these values will end up to be summed, usually the behavior you want:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53c3f4ab", + "metadata": {}, + "outputs": [], + "source": [ + "I_CSi = jnp.concatenate([CSi, Ii], -1)\n", + "I_CSj = jnp.concatenate([CSj, Ij], -1)\n", + "I_CSx = jnp.concatenate([-CSx, Ix], -1)\n", + "print(I_CSi)\n", + "print(I_CSj)\n", + "print(I_CSx)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "188e4606", + "metadata": {}, + "outputs": [], + "source": [ + "n_col, n_rhs = Cext.shape\n", + "print(n_col, n_rhs)" + ] + }, + { + "cell_type": "markdown", + "id": "e018c515", + "metadata": {}, + "source": [ + "The batch shape dimension can generally speaking be anything (in the example here 0D). We need to do the necessary reshapings to make the batch shape 1D:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dad259e7", + "metadata": {}, + "outputs": [], + "source": [ + "n_lhs = jnp.prod(jnp.array(batch_shape, dtype=jnp.int32))\n", + "print(n_lhs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9cf69da7", + "metadata": {}, + "outputs": [], + "source": [ + "Sx = Sx.reshape(n_lhs, -1)\n", + "Sx.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "788294fa", + "metadata": {}, + "outputs": [], + "source": [ + "I_CSx = I_CSx.reshape(n_lhs, -1)\n", + "I_CSx.shape" + ] + }, + { + "cell_type": "markdown", + "id": "ec741031", + "metadata": {}, + "source": [ + "We're finally ready to do the most important part of the calculation, which we conveniently leave to `klujax` and `SuiteSparse`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c1d8644", + "metadata": {}, + "outputs": [], + "source": [ + "inv_I_CS_Cext = solve_klu(I_CSi, I_CSj, I_CSx, Cext)" + ] + }, + { + "cell_type": "markdown", + "id": "e51487df", + "metadata": {}, + "source": [ + "one more sparse multiplication:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cccda072", + "metadata": {}, + "outputs": [], + "source": [ + "S_inv_I_CS_Cext = mul_coo(Si, Sj, Sx, inv_I_CS_Cext)" + ] + }, + { + "cell_type": "markdown", + "id": "627a98f7", + "metadata": {}, + "source": [ + "And one more $C_{ext}$ multiplication which we do by clever indexing:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12ff877c", + "metadata": {}, + "outputs": [], + "source": [ + "CextT_S_inv_I_CS_Cext = S_inv_I_CS_Cext[..., Cexti, :][..., :, Cextj]\n", + "CextT_S_inv_I_CS_Cext" + ] + }, + { + "cell_type": "markdown", + "id": "3e1c75b2", + "metadata": {}, + "source": [ + "That's it! We found the S-matrix of the circuit. 
We just need to reshape the batch dimension back into the matrix:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cbf962af", + "metadata": {}, + "outputs": [], + "source": [ + "_, n, _ = CextT_S_inv_I_CS_Cext.shape\n", + "S = CextT_S_inv_I_CS_Cext.reshape(*batch_shape, n, n)\n", + "S" + ] + }, + { + "cell_type": "markdown", + "id": "1f6dea4c", + "metadata": {}, + "source": [ + "Oh and to complete the `SDense` representation we need to specify the port map as well:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21098a45", + "metadata": {}, + "outputs": [], + "source": [ + "port_map" + ] + }, + { + "cell_type": "markdown", + "id": "0cfee3cf", + "metadata": {}, + "source": [ + "## Algorithm Improvements" + ] + }, + { + "cell_type": "markdown", + "id": "26b70270", + "metadata": {}, + "source": [ + "This algorithm is \n", + "\n", + "* very fast for large circuits 🙂\n", + "\n", + "This algorithm is however:\n", + "\n", + "* **not** jittable 😥\n", + "* **not** differentiable 😥\n", + "* **not** GPU-compatible 🙂\n", + "\n", + "There are probably still plenty of improvements possible for this algorithm:\n", + "\n", + "* **¿** make it jittable **?**\n", + "* **¿** make it differentiable (requires making klujax differentiable first) **?**\n", + "* **¿** make it GPU compatible (requires making suitesparse GPU compatible... probably not gonna happen)**?**\n", + "\n", + "Bottom line is... Do you know how to improve this algorithm or how to implement the above suggestions? Please open a Merge Request!" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/07c_backends_additive.ipynb b/nbs/07c_backends_additive.ipynb new file mode 100644 index 0000000..a69d1d1 --- /dev/null +++ b/nbs/07c_backends_additive.ipynb @@ -0,0 +1,391 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "51511eb3", + "metadata": {}, + "outputs": [], + "source": [ + "# default_exp backends.additive" + ] + }, + { + "cell_type": "markdown", + "id": "ca1899a9", + "metadata": {}, + "source": [ + "# Backend - additive\n", + "\n", + "> Additive SAX Backend\n", + "\n", + "Sometimes we would like to calculate circuit path lengths or time delays within a circuit. We could obviously simulate these things with a time domain simulator, but in many cases a simple additive backend (as opposed to the default multiplicative backend) can suffice." 
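To make the contrast concrete, here is a small illustrative sketch (not part of the backend itself): when chaining two waveguides, complex S-parameters compose multiplicatively, whereas quantities like path length or delay compose additively.

```python
import jax.numpy as jnp

# Multiplicative composition: transmissions (phases) multiply when chained.
t_wg1 = jnp.exp(2j * jnp.pi * 2.34 * 100.0 / 1.55)  # illustrative neff and lengths
t_wg2 = jnp.exp(2j * jnp.pi * 2.34 * 200.0 / 1.55)
t_total = t_wg1 * t_wg2

# Additive composition: path lengths simply add when chained.
length_wg1, length_wg2 = 100.0, 200.0
length_total = length_wg1 + length_wg2

print(t_total, length_total)
```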
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "033a8953", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "import matplotlib.pyplot as plt\n", + "from fastcore.test import test_eq\n", + "from pytest import approx, raises\n", + "\n", + "import os, sys; sys.stderr = open(os.devnull, \"w\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f9b1b10b", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "from __future__ import annotations\n", + "\n", + "from typing import Dict, Tuple\n", + "\n", + "import jax.numpy as jnp\n", + "import networkx as nx\n", + "from sax.typing_ import SDict, SType, SDense, sdict\n", + "from sax.netlist import netlist" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0f28218", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def split_port(port: str) -> Tuple[str, str]:\n", + " try:\n", + " instance, port = port.split(\",\")\n", + " except ValueError:\n", + " (port,) = port.split(\",\")\n", + " instance = \"\"\n", + " return instance, port" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "345aaa5b-cefe-49ec-9213-70c73b8c741b", + "metadata": {}, + "outputs": [], + "source": [ + "assert split_port(\"wg,in0\") == ('wg', 'in0') # internal circuit port\n", + "assert split_port(\"out0\") == ('', 'out0') # external circuit port" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da1af226-47d3-4134-aaeb-ac37f8cf5ed7", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def graph_edges(\n", + " instances: Dict[str, SType], connections: Dict[str, str], ports: Dict[str, str]\n", + "):\n", + " zero = jnp.array([0.0], dtype=float)\n", + " edges = {}\n", + " edges.update({split_port(k): split_port(v) for k, v in connections.items()})\n", + " edges.update({split_port(v): split_port(k) for k, v in connections.items()})\n", + " edges.update({split_port(k): split_port(v) for k, v in ports.items()})\n", + " edges.update({split_port(v): split_port(k) for k, v in ports.items()})\n", + " edges = [(n1, n2, {\"type\": \"C\", \"length\": zero}) for n1, n2 in edges.items()]\n", + "\n", + " _instances = {\n", + " **{i1: None for (i1, _), (_, _), _ in edges},\n", + " **{i2: None for (_, _), (i2, _), _ in edges},\n", + " }\n", + " del _instances[\"\"] # external ports don't belong to an instance\n", + "\n", + " for instance in _instances:\n", + " s = instances[instance]\n", + " edges += [\n", + " ((instance, p1), (instance, p2), {\"type\": \"S\", \"length\": jnp.asarray(length, dtype=float).ravel()})\n", + " for (p1, p2), length in sdict(s).items()\n", + " ]\n", + "\n", + " return edges" + ] + }, + { + "cell_type": "markdown", + "id": "bc6e5237-5517-4776-b23f-9bac3200cdfa", + "metadata": {}, + "source": [ + "> Example\n", + "\n", + "> Note: in stead of S-parameters the stypes need to contain *additive* parameters, such as length or time delay." 
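> Note (illustrative sketch): the same dictionary structure could just as well hold group delays instead of lengths; the group index and waveguide length below are assumptions, not values taken from this repository.

```python
c0 = 299792458.0   # speed of light [m/s]
ng = 4.2           # assumed group index
length = 100e-6    # assumed 100 µm waveguide

wg_delay_sdict = {
    ("in0", "out0"): ng * length / c0,   # ~1.4 ps one-way group delay
    ("out0", "in0"): ng * length / c0,
}
```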
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d9a7bed0-c4b2-42b9-aac6-f8371cad3e6b", + "metadata": {}, + "outputs": [], + "source": [ + "wg_sdict = {\n", + " (\"in0\", \"out0\"): jnp.array([100.0, 200.0, 300.0]), # assume for now there are three possible paths between these two ports.\n", + " (\"out0\", \"in0\"): jnp.array([100.0, 200.0, 300.0]), # assume for now there are three possible paths between these two ports.\n", + "}\n", + "\n", + "dc_sdict = {\n", + " (\"in0\", \"out0\"): jnp.array([10.0, 20.0]), # assume for now there are two possible paths between these two ports.\n", + " (\"in0\", \"out1\"): 15.0,\n", + " (\"in1\", \"out0\"): 15.0,\n", + " (\"in1\", \"out1\"): jnp.array([10.0, 20.0]), # assume for now there are two possible paths between these two ports.\n", + "}\n", + "\n", + "instances= {\n", + " \"dc1\": dc_sdict,\n", + " \"wg\": wg_sdict,\n", + " \"dc2\": dc_sdict,\n", + "}\n", + "connections= {\n", + " \"dc1,out0\": \"wg,in0\",\n", + " \"wg,out0\": \"dc2,in0\",\n", + " \"dc1,out1\": \"dc2,in1\",\n", + "}\n", + "ports= {\n", + " \"in0\": \"dc1,in0\",\n", + " \"in1\": \"dc1,in1\",\n", + " \"out0\": \"dc2,out0\",\n", + " \"out1\": \"dc2,out1\",\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "7aaff527-a216-472b-891d-9c56901af914", + "metadata": {}, + "source": [ + "> Note: it is recommended to **not** use an `SDense` representation for the additive backend. Very often an `SDense` representation will introduce **zeros** which will be interpreted as an **existing connection with zero length**. Conversely, in a sparse representation like `SDict` or `SCoo`, non-existing elements will be just that: they will not be present in the internal graph." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a8ed72a-bb73-46aa-8a78-5dea087d7fc7", + "metadata": {}, + "outputs": [], + "source": [ + "edges = graph_edges(instances, connections, ports)\n", + "edges" + ] + }, + { + "cell_type": "markdown", + "id": "d87a25be-b4bb-458c-9bd2-0f5630524c82", + "metadata": {}, + "source": [ + "We made a difference here between edges of 'S'-type (connections through the S-matrix) and edges of 'C'-type (connections through the connection matrix). Connections of 'C'-type obviously always have length zero as they signify per definition the equality of two ports." 
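A quick, purely illustrative sanity check on the edge list built above is to count how many edges of each type were generated:

```python
from collections import Counter

# Each edge is a (node1, node2, properties) tuple; tally the 'S' vs 'C' edges.
Counter(props["type"] for _, _, props in edges)
```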
+ ] + }, + { + "cell_type": "markdown", + "id": "12c3c744-7fdc-400c-9612-c50fb8214819", + "metadata": {}, + "source": [ + "We can create a NetworkX graph from these edges:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bef4c710-10c4-480c-b28e-897e2672136f", + "metadata": {}, + "outputs": [], + "source": [ + "graph = nx.Graph()\n", + "graph.add_edges_from(edges)\n", + "nx.draw_kamada_kawai(graph, with_labels=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "876b683e", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def prune_internal_output_nodes(graph):\n", + " broken = True\n", + " while broken:\n", + " broken = False\n", + " for (i, p), dic in list(graph.adjacency()):\n", + " if (\n", + " i != \"\"\n", + " and len(dic) == 2\n", + " and all(prop.get(\"type\", \"C\") == \"C\" for prop in dic.values())\n", + " ):\n", + " graph.remove_node((i, p))\n", + " graph.add_edge(*dic.keys(), type=\"C\", length=0.0)\n", + " broken = True\n", + " break\n", + " return graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34163540-3f11-467b-8f39-dba25e469da9", + "metadata": {}, + "outputs": [], + "source": [ + "graph = prune_internal_output_nodes(graph)\n", + "nx.draw_kamada_kawai(graph, with_labels=True)" + ] + }, + { + "cell_type": "markdown", + "id": "ae4c42c4-2500-4b47-b2ed-701ab271950e", + "metadata": {}, + "source": [ + "We can now get a list of all possible paths in the network. Note that these paths **must** alternate between an S-edge and a C-edge:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3e695146", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def get_possible_paths(graph, source, target):\n", + " paths = []\n", + " default_props = {\"type\": \"C\", \"length\": 0.0}\n", + " for path in nx.all_simple_edge_paths(graph, source, target):\n", + " prevtype = \"C\"\n", + " for n1, n2 in path:\n", + " curtype = graph.get_edge_data(n1, n2, default_props)[\"type\"]\n", + " if curtype == prevtype == \"S\":\n", + " break\n", + " else:\n", + " prevtype = curtype\n", + " else:\n", + " paths.append(path)\n", + " return paths" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e769550-89a9-4f6b-88c1-916ec40c7a9f", + "metadata": {}, + "outputs": [], + "source": [ + "paths = get_possible_paths(graph, (\"\", \"in0\"), (\"\", \"out0\"))\n", + "paths" + ] + }, + { + "cell_type": "markdown", + "id": "01f7a5f4-8ea4-4578-b561-b12506b79dd3", + "metadata": {}, + "source": [ + "And the path lengths of those paths can be calculated as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d14050f", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def path_lengths(graph, paths):\n", + " lengths = []\n", + " for path in paths:\n", + " length = zero = jnp.array([0.0], dtype=float)\n", + " default_edge_data = {\"type\": \"C\", \"length\": zero}\n", + " for edge in path:\n", + " edge_data = graph.get_edge_data(*edge, default_edge_data)\n", + " length = (length[None,:] + edge_data.get(\"length\", zero)[:, None]).ravel()\n", + " lengths.append(length)\n", + " return lengths" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a445e1c2-af3f-4953-824d-6fbc89df1499", + "metadata": {}, + "outputs": [], + "source": [ + "path_lengths(graph, paths)" + ] + }, + { + "cell_type": "markdown", + "id": "1c963377-0e97-4396-acc6-469517e1afcf", + "metadata": {}, + "source": [ + "This is all brought 
together in the additive KLU backend:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40042509", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def evaluate_circuit_additive(\n", + " instances: Dict[str, SDict],\n", + " connections: Dict[str, str],\n", + " ports: Dict[str, str],\n", + "):\n", + " \"\"\"evaluate a circuit for the given sdicts.\"\"\"\n", + " edges = graph_edges(instances, connections, ports)\n", + "\n", + " graph = nx.Graph()\n", + " graph.add_edges_from(edges)\n", + " prune_internal_output_nodes(graph)\n", + "\n", + " sdict = {}\n", + " for source in ports:\n", + " for target in ports:\n", + " paths = get_possible_paths(graph, source=(\"\", source), target=(\"\", target))\n", + " if not paths:\n", + " continue\n", + " sdict[source, target] = path_lengths(graph, paths)\n", + "\n", + " return sdict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ec2aaf0-1545-49a8-9203-1b86d36b3e84", + "metadata": {}, + "outputs": [], + "source": [ + "evaluate_circuit_additive(instances, connections, ports)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/08_circuit.ipynb b/nbs/08_circuit.ipynb new file mode 100644 index 0000000..8df82f8 --- /dev/null +++ b/nbs/08_circuit.ipynb @@ -0,0 +1,367 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5b79c062", + "metadata": {}, + "outputs": [], + "source": [ + "# default_exp circuit" + ] + }, + { + "cell_type": "markdown", + "id": "2c9c99fe", + "metadata": {}, + "source": [ + "# Circuit\n", + "\n", + "> SAX Circuits" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "97203ca8", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "import matplotlib.pyplot as plt\n", + "from fastcore.test import test_eq\n", + "from pytest import approx, raises\n", + "import jax.numpy as jnp\n", + "\n", + "import os, sys; sys.stderr = open(os.devnull, \"w\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26a50167", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "from __future__ import annotations\n", + "\n", + "from functools import partial\n", + "from typing import Dict, Optional, Tuple, Union, cast\n", + "\n", + "from sax.backends import circuit_backends\n", + "from sax.models import coupler, straight\n", + "from sax.multimode import multimode, singlemode\n", + "from sax.netlist import LogicalNetlist, Netlist, logical_netlist, netlist_from_yaml\n", + "from sax.typing_ import Instances, Model, Models, Netlist, Settings, SType, is_netlist\n", + "from sax.utils import _replace_kwargs, get_settings, merge_dicts, update_settings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3ca49845", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def circuit(\n", + " *,\n", + " instances: Instances,\n", + " connections: Dict[str, str],\n", + " ports: Dict[str, str],\n", + " models: Optional[Models] = None,\n", + " modes: Optional[Tuple[str, ...]] = None,\n", + " settings: Optional[Settings] = None,\n", + " backend: str = \"default\",\n", + " default_models=None,\n", + ") -> Model:\n", + " # assert valid circuit_backend\n", + " if backend not in circuit_backends:\n", + " raise KeyError(\n", + " f\"circuit backend {backend} not found. 
Allowed circuit backends: \"\n", + " f\"{', '.join(circuit_backends.keys())}.\"\n", + " )\n", + "\n", + " evaluate_circuit = circuit_backends[backend]\n", + "\n", + " _netlist, _settings, _models = logical_netlist(\n", + " instances=instances,\n", + " connections=connections,\n", + " ports=ports,\n", + " models=models,\n", + " settings=settings,\n", + " default_models=default_models,\n", + " )\n", + "\n", + " for name in list(_models.keys()):\n", + " if is_netlist(_models[name]):\n", + " netlist_model = cast(LogicalNetlist, _models.pop(name))\n", + " instance_model_names = set(netlist_model[\"instances\"].values())\n", + " instance_models = {k: _models[k] for k in instance_model_names}\n", + " netlist_func = circuit_from_netlist(\n", + " netlist=netlist_model,\n", + " models=instance_models,\n", + " backend=backend,\n", + " modes=modes,\n", + " settings=None, # settings are already integrated in netlist by now.\n", + " default_models=default_models,\n", + " )\n", + " _models[name] = netlist_func\n", + "\n", + " if modes is not None:\n", + " maybe_multimode = partial(multimode, modes=modes)\n", + " connections = {\n", + " f\"{p1}@{mode}\": f\"{p2}@{mode}\"\n", + " for p1, p2 in _netlist[\"connections\"].items()\n", + " for mode in modes\n", + " }\n", + " ports = {\n", + " f\"{p1}@{mode}\": f\"{p2}@{mode}\"\n", + " for p1, p2 in _netlist[\"ports\"].items()\n", + " for mode in modes\n", + " }\n", + " else:\n", + " maybe_multimode = partial(singlemode, mode=\"te\")\n", + " connections = _netlist[\"connections\"]\n", + " ports = _netlist[\"ports\"]\n", + "\n", + " def _circuit(**settings: Settings) -> SType:\n", + " settings = merge_dicts(_settings, settings)\n", + " global_settings = {}\n", + " for k in list(settings.keys()):\n", + " if k in _netlist[\"instances\"]:\n", + " continue\n", + " global_settings[k] = settings.pop(k)\n", + " if global_settings:\n", + " settings = cast(\n", + " Dict[str, Settings], update_settings(settings, **global_settings)\n", + " )\n", + " instances: Dict[str, SType] = {}\n", + " for name, model_name in _netlist[\"instances\"].items():\n", + " model = cast(Model, _models[model_name])\n", + " instances[name] = cast(\n", + " SType, maybe_multimode(model(**settings.get(name, {})))\n", + " )\n", + " S = evaluate_circuit(instances, connections, ports)\n", + " return S\n", + "\n", + " settings = {\n", + " name: get_settings(cast(Model, _models[model]))\n", + " for name, model in _netlist[\"instances\"].items()\n", + " }\n", + " settings = merge_dicts(settings, _settings)\n", + " _replace_kwargs(_circuit, **settings)\n", + "\n", + " return _circuit" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c6be0ea", + "metadata": {}, + "outputs": [], + "source": [ + "mzi = circuit(\n", + " instances={\n", + " \"lft\": \"coupler\",\n", + " \"top\": \"straight\",\n", + " \"btm\": \"straight\",\n", + " \"rgt\": \"coupler\",\n", + " },\n", + " connections={\n", + " \"lft,out0\": \"btm,in0\",\n", + " \"btm,out0\": \"rgt,in0\",\n", + " \"lft,out1\": \"top,in0\",\n", + " \"top,out0\": \"rgt,in1\",\n", + " },\n", + " ports={\n", + " \"in0\": \"lft,in0\",\n", + " \"in1\": \"lft,in1\",\n", + " \"out0\": \"rgt,out0\",\n", + " \"out1\": \"rgt,out1\",\n", + " },\n", + " models={\n", + " \"straight\": straight,\n", + " \"coupler\": coupler,\n", + " }\n", + ")\n", + "\n", + "mzi?" 
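# Sketch (assumption, not from the original notebook): in the `_circuit` closure
# above, settings whose keys do not match an instance name are broadcast as
# *global* settings via `update_settings`. Assuming the `straight` and `coupler`
# models accept a `wl` argument, a global wavelength could then be set in one place:
#
#     mzi(wl=1.55, top={"length": 25.0}, btm={"length": 15.0})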
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bca34a72", + "metadata": {}, + "outputs": [], + "source": [ + "result = mzi(top={\"length\": 25.0}, btm={\"length\": 15.0})\n", + "result = {k: approx(jnp.abs(v)) for k, v in result.items()}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6f6a301f", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def circuit_from_netlist(\n", + " netlist: Union[LogicalNetlist, Netlist],\n", + " *,\n", + " models: Optional[Models] = None,\n", + " modes: Optional[Tuple[str, ...]] = None,\n", + " settings: Optional[Settings] = None,\n", + " backend: str = \"default\",\n", + " default_models=None,\n", + ") -> Model:\n", + " \"\"\"create a circuit model function from a netlist \"\"\"\n", + " instances = netlist[\"instances\"]\n", + " connections = netlist[\"connections\"]\n", + " ports = netlist[\"ports\"]\n", + " _circuit = circuit(\n", + " instances=instances,\n", + " connections=connections,\n", + " ports=ports,\n", + " models=models,\n", + " modes=modes,\n", + " settings=settings,\n", + " backend=backend,\n", + " default_models=default_models,\n", + " )\n", + " return _circuit" + ] + }, + { + "cell_type": "markdown", + "id": "34d7295c", + "metadata": {}, + "source": [ + "> Example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d29f070a", + "metadata": {}, + "outputs": [], + "source": [ + "mzi = circuit_from_netlist(\n", + " netlist = {\n", + " \"instances\": {\n", + " \"lft\": \"coupler\",\n", + " \"top\": \"straight\",\n", + " \"btm\": \"straight\",\n", + " \"rgt\": \"coupler\",\n", + " },\n", + " \"connections\": {\n", + " \"lft,out0\": \"btm,in0\",\n", + " \"btm,out0\": \"rgt,in0\",\n", + " \"lft,out1\": \"top,in0\",\n", + " \"top,out0\": \"rgt,in1\",\n", + " },\n", + " \"ports\": {\n", + " \"in0\": \"lft,in0\",\n", + " \"in1\": \"lft,in1\",\n", + " \"out0\": \"rgt,out0\",\n", + " \"out1\": \"rgt,out1\",\n", + " },\n", + " },\n", + " models={\n", + " \"straight\": straight,\n", + " \"coupler\": coupler,\n", + " }\n", + ")\n", + "\n", + "mzi?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "91deb9a6", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def circuit_from_yaml(\n", + " yaml: str,\n", + " *,\n", + " models: Optional[Models] = None,\n", + " modes: Optional[Tuple[str, ...]] = None,\n", + " settings: Optional[Settings] = None,\n", + " backend: str = \"default\",\n", + " default_models=None,\n", + ") -> Model:\n", + " \"\"\"Load a sax circuit from yaml definition\n", + "\n", + " Args:\n", + " yaml: the yaml string to load\n", + " models: a dictionary which maps component names to model functions\n", + " modes: the modes of the simulation (if not given, single mode\n", + " operation is assumed).\n", + " settings: override netlist instance settings. Use this setting to set\n", + " global settings like for example the wavelength 'wl'.\n", + " backend: \"default\" or \"klu\". How the circuit S-parameters are\n", + " calculated. 
\"klu\" is a CPU-only method which generally speaking is\n", + " much faster for large circuits but cannot be jitted or used for autograd.\n", + " \"\"\"\n", + " netlist, models = netlist_from_yaml(yaml=yaml, models=models, settings=settings)\n", + " circuit = circuit_from_netlist(\n", + " netlist=netlist,\n", + " models=models,\n", + " modes=modes,\n", + " settings=None, # settings are already integrated in the netlist by now\n", + " backend=backend,\n", + " default_models=default_models,\n", + " )\n", + " return circuit" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "beaf04ed", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def circuit_from_gdsfactory(\n", + " component,\n", + " *,\n", + " models: Optional[Models] = None,\n", + " modes: Optional[Tuple[str, ...]] = None,\n", + " settings: Optional[Settings] = None,\n", + " backend: str = \"default\",\n", + " default_models=None,\n", + ") -> Model:\n", + " \"\"\"Load a sax circuit from a GDSFactory component\"\"\"\n", + " circuit = circuit_from_netlist(\n", + " component.get_netlist(),\n", + " models=models,\n", + " modes=modes,\n", + " settings=settings,\n", + " backend=backend,\n", + " default_models=default_models,\n", + " )\n", + " return circuit" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/09_nn.ipynb b/nbs/09_nn.ipynb new file mode 100644 index 0000000..a2661ad --- /dev/null +++ b/nbs/09_nn.ipynb @@ -0,0 +1,148 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "96bd487d", + "metadata": {}, + "outputs": [], + "source": [ + "# default_exp nn.__init__" + ] + }, + { + "cell_type": "markdown", + "id": "36903b9f", + "metadata": {}, + "source": [ + "# NN - Neural Networks\n", + "\n", + "> Utilitites for creating advanced neural network SAX Models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "266f1980", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "import matplotlib.pyplot as plt\n", + "from fastcore.test import test_eq\n", + "from pytest import approx, raises\n", + "\n", + "import os, sys; sys.stderr = open(os.devnull, \"w\")" + ] + }, + { + "cell_type": "markdown", + "id": "75fbba73", + "metadata": {}, + "source": [ + "## Loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8ccb0f09", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "from __future__ import annotations\n", + "\n", + "from sax.nn.loss import huber_loss as huber_loss\n", + "from sax.nn.loss import l2_reg as l2_reg\n", + "from sax.nn.loss import mse as mse" + ] + }, + { + "cell_type": "markdown", + "id": "1a18faa2", + "metadata": {}, + "source": [ + "## Utils" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dbbfdef3", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "\n", + "from sax.nn.utils import (\n", + " cartesian_product as cartesian_product,\n", + " denormalize as denormalize,\n", + " get_normalization as get_normalization,\n", + " get_df_columns as get_df_columns,\n", + " normalize as normalize,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "ef010b3c", + "metadata": {}, + "source": [ + "## Core" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42a738dd", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "\n", + "from sax.nn.core import (\n", + " preprocess as preprocess,\n", + 
" dense as dense,\n", + " generate_dense_weights as generate_dense_weights,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "5e96c647", + "metadata": {}, + "source": [ + "## IO" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8277e2b1", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "\n", + "from sax.nn.io import (\n", + " load_nn_weights_json as load_nn_weights_json,\n", + " save_nn_weights_json as save_nn_weights_json,\n", + " get_available_sizes as get_available_sizes,\n", + " get_dense_weights_path as get_dense_weights_path,\n", + " get_norm_path as get_norm_path,\n", + " load_nn_dense as load_nn_dense,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/09a_nn_loss.ipynb b/nbs/09a_nn_loss.ipynb new file mode 100644 index 0000000..fa60126 --- /dev/null +++ b/nbs/09a_nn_loss.ipynb @@ -0,0 +1,120 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0c9f92fd", + "metadata": {}, + "outputs": [], + "source": [ + "# default_exp nn.loss" + ] + }, + { + "cell_type": "markdown", + "id": "3cf6720e", + "metadata": {}, + "source": [ + "# NN - Loss\n", + "\n", + "> loss functions and utilitites for SAX neural networks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20cf7e5a", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "import matplotlib.pyplot as plt\n", + "from fastcore.test import test_eq\n", + "from pytest import approx, raises\n", + "\n", + "import os, sys; sys.stderr = open(os.devnull, \"w\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d04b4ec7", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "from __future__ import annotations\n", + "\n", + "from typing import Dict\n", + "\n", + "import jax.numpy as jnp\n", + "from sax.typing_ import ComplexFloat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "67d829bc", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def mse(x: ComplexFloat, y: ComplexFloat) -> float:\n", + " \"\"\"mean squared error\"\"\"\n", + " return ((x - y) ** 2).mean()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ecabd990", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def huber_loss(x: ComplexFloat, y: ComplexFloat, delta: float=0.5) -> float:\n", + " \"\"\"huber loss\"\"\"\n", + " return ((delta ** 2) * ((1.0 + ((x - y) / delta) ** 2) ** 0.5 - 1.0)).mean()" + ] + }, + { + "cell_type": "markdown", + "id": "98d84861", + "metadata": {}, + "source": [ + "The huber loss is like the mean squared error close to zero and mean\n", + "absolute error for outliers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c75a2f9", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def l2_reg(weights: Dict[str, ComplexFloat]) -> float:\n", + " \"\"\"L2 regularization loss\"\"\"\n", + " numel = 0\n", + " loss = 0.0\n", + " for w in (v for k, v in weights.items() if k[0] in (\"w\", \"b\")):\n", + " numel = numel + w.size\n", + " loss = loss + (jnp.abs(w) ** 2).sum()\n", + " return loss / numel" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/09b_nn_utils.ipynb b/nbs/09b_nn_utils.ipynb new file mode 100644 
index 0000000..e268eaa --- /dev/null +++ b/nbs/09b_nn_utils.ipynb @@ -0,0 +1,159 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "f6aa811a", + "metadata": {}, + "outputs": [], + "source": [ + "# default_exp nn.utils" + ] + }, + { + "cell_type": "markdown", + "id": "3ac7a21d", + "metadata": {}, + "source": [ + "# NN - Utils\n", + "\n", + "> loss functions and utilitites for SAX neural networks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26e72661", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "import matplotlib.pyplot as plt\n", + "from fastcore.test import test_eq\n", + "from pytest import approx, raises\n", + "\n", + "import os, sys; sys.stderr = open(os.devnull, \"w\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "220cf534", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "from __future__ import annotations\n", + "\n", + "from collections import namedtuple\n", + "from typing import Tuple\n", + "\n", + "import jax.numpy as jnp\n", + "import pandas as pd\n", + "from sax.typing_ import ComplexFloat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61a03469", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def cartesian_product(*arrays: ComplexFloat) -> ComplexFloat:\n", + " \"\"\"calculate the n-dimensional cartesian product of an arbitrary number of arrays\"\"\"\n", + " ixarrays = jnp.ix_(*arrays)\n", + " barrays = jnp.broadcast_arrays(*ixarrays)\n", + " sarrays = jnp.stack(barrays, -1)\n", + " assert isinstance(sarrays, jnp.ndarray)\n", + " product = sarrays.reshape(-1, sarrays.shape[-1])\n", + " assert isinstance(product, jnp.ndarray)\n", + " return product" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2be83fa5", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def denormalize(x: ComplexFloat, mean: ComplexFloat = 0.0, std: ComplexFloat = 1.0) -> ComplexFloat:\n", + " \"\"\"denormalize an array with a given mean and standard deviation\"\"\"\n", + " return x * std + mean" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28a67bb4", + "metadata": {}, + "outputs": [], + "source": [ + "# exporti\n", + "norm = namedtuple(\"norm\", (\"mean\", \"std\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ab609c9b", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def get_normalization(x: ComplexFloat):\n", + " \"\"\"Get mean and standard deviation for a given array\"\"\"\n", + " if isinstance(x, (complex, float)):\n", + " return x, 0.0\n", + " return norm(x.mean(0), x.std(0))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d23ada4", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def get_df_columns(df: pd.DataFrame, *names: str) -> Tuple[ComplexFloat, ...]:\n", + " \"\"\"Get certain columns from a pandas DataFrame as jax.numpy arrays\"\"\"\n", + " tup = namedtuple(\"params\", names)\n", + " params_list = []\n", + " for name in names:\n", + " column_np = df[name].values\n", + " column_jnp = jnp.array(column_np)\n", + " assert isinstance(column_jnp, jnp.ndarray)\n", + " params_list.append(column_jnp.ravel())\n", + " return tup(*params_list)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b93be0b2", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def normalize(\n", + " x: ComplexFloat, mean: ComplexFloat = 0.0, std: ComplexFloat = 
1.0\n", + ") -> ComplexFloat:\n", + " \"\"\"normalize an array with a given mean and standard deviation\"\"\"\n", + " return (x - mean) / std" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/09c_nn_core.ipynb b/nbs/09c_nn_core.ipynb new file mode 100644 index 0000000..87dea63 --- /dev/null +++ b/nbs/09c_nn_core.ipynb @@ -0,0 +1,159 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "b11a4aa9", + "metadata": {}, + "outputs": [], + "source": [ + "# default_exp nn.core" + ] + }, + { + "cell_type": "markdown", + "id": "ad9dbb7f", + "metadata": {}, + "source": [ + "# NN - Core\n", + "\n", + "> Core for SAX neural networks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3512bc49", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "import matplotlib.pyplot as plt\n", + "from fastcore.test import test_eq\n", + "from pytest import approx, raises\n", + "\n", + "import os, sys; sys.stderr = open(os.devnull, \"w\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05f67381", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "from __future__ import annotations\n", + "\n", + "from typing import Callable, Dict, Optional, Tuple, Union\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "from sax.nn.utils import denormalize, normalize\n", + "from sax.typing_ import Array, ComplexFloat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5409f7bf", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def preprocess(*params: ComplexFloat) -> ComplexFloat:\n", + " \"\"\"preprocess parameters\n", + "\n", + " > Note: (1) all arguments are first casted into the same shape. (2) then pairs \n", + " of arguments are divided into each other to create relative arguments. 
(3) all \n", + " arguments are then stacked into one big tensor\n", + " \"\"\"\n", + " x = jnp.stack(jnp.broadcast_arrays(*params), -1)\n", + " assert isinstance(x, jnp.ndarray)\n", + " to_concatenate = [x]\n", + " for i in range(1, x.shape[-1]):\n", + " _x = jnp.roll(x, shift=i, axis=-1)\n", + " to_concatenate.append(x / _x)\n", + " to_concatenate.append(_x / x)\n", + " x = jnp.concatenate(to_concatenate, -1)\n", + " assert isinstance(x, jnp.ndarray)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e11602bb", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def dense(\n", + " weights: Dict[str, Array],\n", + " *params: ComplexFloat,\n", + " x_norm: Tuple[float, float] = (0.0, 1.0),\n", + " y_norm: Tuple[float, float] = (0.0, 1.0),\n", + " preprocess: Callable = preprocess,\n", + " activation: Callable = jax.nn.leaky_relu,\n", + ") -> ComplexFloat:\n", + " \"\"\"simple dense neural network\"\"\"\n", + " x_mean, x_std = x_norm\n", + " y_mean, y_std = y_norm\n", + " x = preprocess(*params)\n", + " x = normalize(x, mean=x_mean, std=x_std)\n", + " for i in range(len([w for w in weights if w.startswith(\"w\")])):\n", + " x = activation(x @ weights[f\"w{i}\"] + weights.get(f\"b{i}\", 0.0))\n", + " y = denormalize(x, mean=y_mean, std=y_std)\n", + " return y" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6f22a32c", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def generate_dense_weights(\n", + " key: Union[int, Array],\n", + " sizes: Tuple[int, ...],\n", + " input_names: Optional[Tuple[str, ...]] = None,\n", + " output_names: Optional[Tuple[str, ...]] = None,\n", + " preprocess=preprocess,\n", + ") -> Dict[str, ComplexFloat]:\n", + " \"\"\"Generate the weights for a dense neural network\"\"\"\n", + "\n", + " if isinstance(key, int):\n", + " key = jax.random.PRNGKey(key)\n", + " assert isinstance(key, jnp.ndarray)\n", + "\n", + " sizes = tuple(s for s in sizes)\n", + " if input_names:\n", + " arr = preprocess(*jnp.ones(len(input_names)))\n", + " assert isinstance(arr, jnp.ndarray)\n", + " sizes = (arr.shape[-1],) + sizes\n", + " if output_names:\n", + " sizes = sizes + (len(output_names),)\n", + "\n", + " keys = jax.random.split(key, 2 * len(sizes))\n", + " rand = jax.nn.initializers.lecun_normal()\n", + " weights = {}\n", + " for i, (m, n) in enumerate(zip(sizes[:-1], sizes[1:])):\n", + " weights[f\"w{i}\"] = rand(keys[2 * i], (m, n))\n", + " weights[f\"b{i}\"] = rand(keys[2 * i + 1], (1, n)).ravel()\n", + "\n", + " return weights" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/09d_nn_io.ipynb b/nbs/09d_nn_io.ipynb new file mode 100644 index 0000000..c6b3a76 --- /dev/null +++ b/nbs/09d_nn_io.ipynb @@ -0,0 +1,306 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "ddc8ac6c", + "metadata": {}, + "outputs": [], + "source": [ + "# default_exp nn.io" + ] + }, + { + "cell_type": "markdown", + "id": "a0489f95", + "metadata": {}, + "source": [ + "# NN - IO\n", + "\n", + "> IO Utilitites for SAX neural networks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ede352c", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "import matplotlib.pyplot as plt\n", + "from fastcore.test import test_eq\n", + "from pytest import approx, raises\n", + "\n", + "import os, sys; sys.stderr = open(os.devnull, \"w\")" + 
] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "01041a0d", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "from __future__ import annotations\n", + "\n", + "import json\n", + "import os\n", + "import re\n", + "from typing import Callable, Dict, List, Optional, Tuple\n", + "\n", + "import jax.numpy as jnp\n", + "from sax.nn.core import dense, preprocess\n", + "from sax.nn.utils import norm\n", + "from sax.typing_ import ComplexFloat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "626189c9", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def load_nn_weights_json(path: str) -> Dict[str, ComplexFloat]:\n", + " \"\"\"Load json weights from given path\"\"\"\n", + " path = os.path.abspath(os.path.expanduser(path))\n", + " weights = {}\n", + " if os.path.exists(path):\n", + " with open(path, \"r\") as file:\n", + " for k, v in json.load(file).items():\n", + " _v = jnp.array(v, dtype=float)\n", + " assert isinstance(_v, jnp.ndarray)\n", + " weights[k] = _v\n", + " return weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5aa53e90", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def save_nn_weights_json(weights: Dict[str, ComplexFloat], path: str):\n", + " \"\"\"Save json weights to given path\"\"\"\n", + " path = os.path.abspath(os.path.expanduser(path))\n", + " os.makedirs(os.path.dirname(path), exist_ok=True)\n", + " with open(path, \"w\") as file:\n", + " _weights = {}\n", + " for k, v in weights.items():\n", + " v = jnp.atleast_1d(jnp.array(v))\n", + " assert isinstance(v, jnp.ndarray)\n", + " _weights[k] = v.tolist()\n", + " json.dump(_weights, file)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dad70875", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def get_available_sizes(\n", + " dirpath: str,\n", + " prefix: str,\n", + " input_names: Tuple[str, ...],\n", + " output_names: Tuple[str, ...],\n", + ") -> List[Tuple[int, ...]]:\n", + " \"\"\"Get all available json weight hidden sizes given filename parameters\n", + "\n", + " > Note: this function does NOT return the input size and the output size \n", + " of the neural network. ONLY the hidden sizes are reported. 
The input \n", + " and output sizes can easily be derived from `input_names` (after \n", + " preprocessing) and `output_names`.\n", + " \"\"\"\n", + " all_weightfiles = os.listdir(dirpath)\n", + " possible_weightfiles = (\n", + " s for s in all_weightfiles if s.endswith(f\"-{'-'.join(output_names)}.json\")\n", + " )\n", + " possible_weightfiles = (\n", + " s\n", + " for s in possible_weightfiles\n", + " if s.startswith(f\"{prefix}-{'-'.join(input_names)}\")\n", + " )\n", + " possible_weightfiles = (re.sub(\"[^0-9x]\", \"\", s) for s in possible_weightfiles)\n", + " possible_weightfiles = (re.sub(\"^x*\", \"\", s) for s in possible_weightfiles)\n", + " possible_weightfiles = (re.sub(\"x[^0-9]*$\", \"\", s) for s in possible_weightfiles)\n", + " possible_hidden_sizes = (s.strip() for s in possible_weightfiles if s.strip())\n", + " possible_hidden_sizes = (\n", + " tuple(hs.strip() for hs in s.split(\"x\") if hs.strip())\n", + " for s in possible_hidden_sizes\n", + " )\n", + " possible_hidden_sizes = (\n", + " tuple(int(hs) for hs in s[1:-1]) for s in possible_hidden_sizes if len(s) > 2\n", + " )\n", + " possible_hidden_sizes = sorted(\n", + " possible_hidden_sizes, key=lambda hs: (len(hs), max(hs))\n", + " )\n", + " return possible_hidden_sizes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d619994", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def get_dense_weights_path(\n", + " *sizes: int,\n", + " input_names: Optional[Tuple[str, ...]] = None,\n", + " output_names: Optional[Tuple[str, ...]] = None,\n", + " dirpath: str = \"weights\",\n", + " prefix: str = \"dense\",\n", + " preprocess=preprocess,\n", + "):\n", + " \"\"\"Create the SAX conventional path for a given weight dictionary\"\"\"\n", + " if input_names:\n", + " num_inputs = preprocess(*jnp.ones(len(input_names))).shape[0]\n", + " sizes = (num_inputs,) + sizes\n", + " if output_names:\n", + " sizes = sizes + (len(output_names),)\n", + " path = os.path.abspath(os.path.join(dirpath, prefix))\n", + " if input_names:\n", + " path = f\"{path}-{'-'.join(input_names)}\"\n", + " if sizes:\n", + " path = f\"{path}-{'x'.join(str(s) for s in sizes)}\"\n", + " if output_names:\n", + " path = f\"{path}-{'-'.join(output_names)}\"\n", + " return f\"{path}.json\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9cab40f3", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "\n", + "def get_norm_path(\n", + " *shape: int,\n", + " input_names: Optional[Tuple[str, ...]] = None,\n", + " output_names: Optional[Tuple[str, ...]] = None,\n", + " dirpath: str = \"norms\",\n", + " prefix: str = \"norm\",\n", + " preprocess=preprocess,\n", + "):\n", + " \"\"\"Create the SAX conventional path for the normalization constants\"\"\"\n", + " if input_names and output_names:\n", + " raise ValueError(\n", + " \"To get the norm name, one can only specify `input_names` OR `output_names`.\"\n", + " )\n", + " if input_names:\n", + " num_inputs = preprocess(*jnp.ones(len(input_names))).shape[0]\n", + " shape = (num_inputs,) + shape\n", + " if output_names:\n", + " shape = shape + (len(output_names),)\n", + " path = os.path.abspath(os.path.join(dirpath, prefix))\n", + " if input_names:\n", + " path = f\"{path}-{'-'.join(input_names)}\"\n", + " if shape:\n", + " path = f\"{path}-{'x'.join(str(s) for s in shape)}\"\n", + " if output_names:\n", + " path = f\"{path}-{'-'.join(output_names)}\"\n", + " return f\"{path}.json\"" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "dd6d5bf8", + "metadata": {}, + "outputs": [], + "source": [ + "# exporti\n", + "class _PartialDense:\n", + " def __init__(self, weights, x_norm, y_norm, input_names, output_names):\n", + " self.weights = weights\n", + " self.x_norm = x_norm\n", + " self.y_norm = y_norm\n", + " self.input_names = input_names\n", + " self.output_names = output_names\n", + "\n", + " def __call__(self, *params: ComplexFloat) -> ComplexFloat:\n", + " return dense(self.weights, *params, x_norm=self.x_norm, y_norm=self.y_norm)\n", + "\n", + " def __repr__(self):\n", + " return f\"{self.__class__.__name__}{repr(self.input_names)}->{repr(self.output_names)}\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b63025ba", + "metadata": {}, + "outputs": [], + "source": [ + "# export\n", + "def load_nn_dense(\n", + " *sizes: int,\n", + " input_names: Optional[Tuple[str, ...]] = None,\n", + " output_names: Optional[Tuple[str, ...]] = None,\n", + " weightprefix=\"dense\",\n", + " weightdirpath=\"weights\",\n", + " normdirpath=\"norms\",\n", + " normprefix=\"norm\",\n", + " preprocess=preprocess,\n", + ") -> Callable:\n", + " \"\"\"Load a pre-trained dense model\"\"\"\n", + " weights_path = get_dense_weights_path(\n", + " *sizes,\n", + " input_names=input_names,\n", + " output_names=output_names,\n", + " prefix=weightprefix,\n", + " dirpath=weightdirpath,\n", + " preprocess=preprocess,\n", + " )\n", + " if not os.path.exists(weights_path):\n", + " raise ValueError(\"Cannot find weights path for given parameters\")\n", + " x_norm_path = get_norm_path(\n", + " input_names=input_names,\n", + " prefix=normprefix,\n", + " dirpath=normdirpath,\n", + " preprocess=preprocess,\n", + " )\n", + " if not os.path.exists(x_norm_path):\n", + " raise ValueError(\"Cannot find normalization for input parameters\")\n", + " y_norm_path = get_norm_path(\n", + " output_names=output_names,\n", + " prefix=normprefix,\n", + " dirpath=normdirpath,\n", + " preprocess=preprocess,\n", + " )\n", + " if not os.path.exists(x_norm_path):\n", + " raise ValueError(\"Cannot find normalization for output parameters\")\n", + " weights = load_nn_weights_json(weights_path)\n", + " x_norm_dict = load_nn_weights_json(x_norm_path)\n", + " y_norm_dict = load_nn_weights_json(y_norm_path)\n", + " x_norm = norm(x_norm_dict[\"mean\"], x_norm_dict[\"std\"])\n", + " y_norm = norm(y_norm_dict[\"mean\"], y_norm_dict[\"std\"])\n", + " partial_dense = _PartialDense(weights, x_norm, y_norm, input_names, output_names)\n", + " return partial_dense" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/99_init.ipynb b/nbs/99_init.ipynb new file mode 100644 index 0000000..8442a9b --- /dev/null +++ b/nbs/99_init.ipynb @@ -0,0 +1,375 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "6c5b8ffd", + "metadata": {}, + "outputs": [], + "source": [ + "# default_exp __init__" + ] + }, + { + "cell_type": "markdown", + "id": "b4250b31", + "metadata": {}, + "source": [ + "# SAX init\n", + "\n", + "> Import everything into the `sax` namespace:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6e82c62a", + "metadata": {}, + "outputs": [], + "source": [ + "# hide\n", + "import matplotlib.pyplot as plt\n", + "from fastcore.test import test_eq\n", + "from pytest import approx, raises\n", + "\n", + "import os, sys; sys.stderr = open(os.devnull, \"w\")" + ] + }, + { 
+ "cell_type": "code", + "execution_count": null, + "id": "d33aee88", + "metadata": {}, + "outputs": [], + "source": [ + "# exporti\n", + "__author__ = \"Floris Laporte\"\n", + "__version__ = \"0.6.0\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70247f53", + "metadata": {}, + "outputs": [], + "source": [ + "# exporti\n", + "from __future__ import annotations" + ] + }, + { + "cell_type": "markdown", + "id": "fe77e5e3", + "metadata": {}, + "source": [ + "## External\n", + "\n", + "utils from other packages available in SAX for convenience:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0f1afc3", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "\n", + "from functools import partial as partial\n", + "from math import pi as pi\n", + "\n", + "from flax.core.frozen_dict import FrozenDict as FrozenDict\n", + "from scipy.constants import c as c" + ] + }, + { + "cell_type": "markdown", + "id": "3d3bde86", + "metadata": {}, + "source": [ + "## Typing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72cfc443", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "\n", + "from sax import typing_ as typing\n", + "from sax.typing_ import (\n", + " Array as Array,\n", + " ComplexFloat as ComplexFloat,\n", + " Float as Float,\n", + " Instance as Instance,\n", + " Instances as Instances,\n", + " LogicalNetlist as LogicalNetlist,\n", + " Model as Model,\n", + " ModelFactory as ModelFactory,\n", + " Models as Models,\n", + " Netlist as Netlist,\n", + " SCoo as SCoo,\n", + " SDense as SDense,\n", + " SDict as SDict,\n", + " Settings as Settings,\n", + " SType as SType,\n", + " is_complex as is_complex,\n", + " is_complex_float as is_complex_float,\n", + " is_float as is_float,\n", + " is_instance as is_instance,\n", + " is_mixedmode as is_mixedmode,\n", + " is_model as is_model,\n", + " is_model_factory as is_model_factory,\n", + " is_multimode as is_multimode,\n", + " is_netlist as is_netlist,\n", + " is_scoo as is_scoo,\n", + " is_sdense as is_sdense,\n", + " is_sdict as is_sdict,\n", + " is_singlemode as is_singlemode,\n", + " modelfactory as modelfactory,\n", + " scoo as scoo,\n", + " sdense as sdense,\n", + " sdict as sdict,\n", + " validate_model as validate_model,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "345808d3", + "metadata": {}, + "source": [ + "## Utils" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8db0fa9e", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "\n", + "from sax import utils as utils\n", + "from sax.utils import (\n", + " block_diag as block_diag,\n", + " clean_string as clean_string,\n", + " copy_settings as copy_settings,\n", + " flatten_dict as flatten_dict,\n", + " get_inputs_outputs as get_inputs_outputs,\n", + " get_port_combinations as get_port_combinations,\n", + " get_ports as get_ports,\n", + " get_settings as get_settings,\n", + " grouped_interp as grouped_interp,\n", + " merge_dicts as merge_dicts,\n", + " mode_combinations as mode_combinations,\n", + " reciprocal as reciprocal,\n", + " rename_params as rename_params,\n", + " rename_ports as rename_ports,\n", + " try_float as try_float,\n", + " unflatten_dict as unflatten_dict,\n", + " update_settings as update_settings,\n", + " validate_multimode as validate_multimode,\n", + " validate_not_mixedmode as validate_not_mixedmode,\n", + " validate_sdict as validate_sdict,\n", + " validate_settings as validate_settings,\n", + ")" + ] + }, + { + 
"cell_type": "markdown", + "id": "d80b6784", + "metadata": {}, + "source": [ + "## Caching" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "98e7f4f0", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "\n", + "from sax import caching as caching\n", + "from sax.caching import (\n", + " cache as cache, \n", + " cache_clear as cache_clear,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "23287a09", + "metadata": {}, + "source": [ + "## Multimode" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e796047f", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "\n", + "from sax import multimode as multimode\n", + "from sax.multimode import (\n", + " multimode as multimode,\n", + " singlemode as singlemode,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "c794abf0", + "metadata": {}, + "source": [ + "## Models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1a070d5", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "\n", + "from sax import models as models\n", + "from sax.models import get_models as get_models, passthru as passthru" + ] + }, + { + "cell_type": "markdown", + "id": "b94feb62", + "metadata": {}, + "source": [ + "## Netlist" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb6b8b12", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "\n", + "from sax import netlist as netlist\n", + "from sax.netlist import (\n", + " logical_netlist as logical_netlist,\n", + " netlist as netlist,\n", + " netlist_from_yaml as netlist_from_yaml,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "5792b5e9", + "metadata": {}, + "source": [ + "## Circuit" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e019e31e", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "\n", + "from sax import circuit as circuit\n", + "from sax.circuit import (\n", + " circuit as circuit,\n", + " circuit_from_gdsfactory as circuit_from_gdsfactory,\n", + " circuit_from_netlist as circuit_from_netlist,\n", + " circuit_from_yaml as circuit_from_yaml,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e2d7c9fc", + "metadata": {}, + "source": [ + "## Backend" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3486a62e", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "\n", + "from sax import backends as backends" + ] + }, + { + "cell_type": "markdown", + "id": "754759b7", + "metadata": {}, + "source": [ + "## Neural Networks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89955a98", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "\n", + "from sax import nn as nn" + ] + }, + { + "cell_type": "markdown", + "id": "3ecf363e", + "metadata": {}, + "source": [ + "## Patches" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a6af34c", + "metadata": {}, + "outputs": [], + "source": [ + "# exports\n", + "\n", + "from sax import patched as _patched" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sax", + "language": "python", + "name": "sax" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..1eb44d0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,11 @@ +[build-system] +requires = ["setuptools", "pip", "build", "wheel"] +build_backend = "setuptools.build_meta" + +[tool.black] +line-length = 88 
+target-version = ['py38'] +include = '\.pyi?$' + +[tool.pyright] +reportPrivateImportUsage = false diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 6331d6e..0000000 --- a/requirements.txt +++ /dev/null @@ -1,10 +0,0 @@ -jax==0.2.0 -jaxlib==0.1.55 -sax -tqdm -ipykernel -matplotlib -sphinx -nbsphinx -sphinx-rtd-theme -tmm diff --git a/sax/__init__.py b/sax/__init__.py index 55f3b37..5de8431 100644 --- a/sax/__init__.py +++ b/sax/__init__.py @@ -1,21 +1,140 @@ -""" SAX """ +# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/99_init.ipynb (unless otherwise specified). + +from __future__ import annotations + + +__all__ = [] + +# Internal Cell __author__ = "Floris Laporte" -__version__ = "0.0.1" +__version__ = "0.6.0" +# Internal Cell +#nbdev_comment from __future__ import annotations -from . import core -from . import utils -from . import models -from . import constants +# Cell -from .core import modelgenerator, circuit +from functools import partial as partial +from math import pi as pi + +from flax.core.frozen_dict import FrozenDict as FrozenDict +from scipy.constants import c as c + +# Cell + +from sax import typing_ as typing +from .typing_ import ( + Array as Array, + ComplexFloat as ComplexFloat, + Float as Float, + Instance as Instance, + Instances as Instances, + LogicalNetlist as LogicalNetlist, + Model as Model, + ModelFactory as ModelFactory, + Models as Models, + Netlist as Netlist, + SCoo as SCoo, + SDense as SDense, + SDict as SDict, + Settings as Settings, + SType as SType, + is_complex as is_complex, + is_complex_float as is_complex_float, + is_float as is_float, + is_instance as is_instance, + is_mixedmode as is_mixedmode, + is_model as is_model, + is_model_factory as is_model_factory, + is_multimode as is_multimode, + is_netlist as is_netlist, + is_scoo as is_scoo, + is_sdense as is_sdense, + is_sdict as is_sdict, + is_singlemode as is_singlemode, + modelfactory as modelfactory, + scoo as scoo, + sdense as sdense, + sdict as sdict, + validate_model as validate_model, +) + +# Cell + +from sax import utils as utils from .utils import ( - load, - save, - set_global_params, - rename_ports, - get_ports, - copy_params, - validate_params, + block_diag as block_diag, + clean_string as clean_string, + copy_settings as copy_settings, + flatten_dict as flatten_dict, + get_inputs_outputs as get_inputs_outputs, + get_port_combinations as get_port_combinations, + get_ports as get_ports, + get_settings as get_settings, + grouped_interp as grouped_interp, + merge_dicts as merge_dicts, + mode_combinations as mode_combinations, + reciprocal as reciprocal, + rename_params as rename_params, + rename_ports as rename_ports, + try_float as try_float, + unflatten_dict as unflatten_dict, + update_settings as update_settings, + validate_multimode as validate_multimode, + validate_not_mixedmode as validate_not_mixedmode, + validate_sdict as validate_sdict, + validate_settings as validate_settings, +) + +# Cell + +from sax import caching as caching +from .caching import ( + cache as cache, + cache_clear as cache_clear, +) + +# Cell + +from sax import multimode as multimode +from .multimode import ( + multimode as multimode, + singlemode as singlemode, +) + +# Cell + +from sax import models as models +from .models import get_models as get_models, passthru as passthru + +# Cell + +from sax import netlist as netlist +from .netlist import ( + logical_netlist as logical_netlist, + netlist as netlist, + netlist_from_yaml as netlist_from_yaml, ) + +# Cell + +from sax import circuit 
as circuit +from .circuit import ( + circuit as circuit, + circuit_from_gdsfactory as circuit_from_gdsfactory, + circuit_from_netlist as circuit_from_netlist, + circuit_from_yaml as circuit_from_yaml, +) + +# Cell + +from sax import backends as backends + +# Cell + +from sax import nn as nn + +# Cell + +from sax import patched as _patched \ No newline at end of file diff --git a/sax/_nbdev.py b/sax/_nbdev.py new file mode 100644 index 0000000..5abcdc1 --- /dev/null +++ b/sax/_nbdev.py @@ -0,0 +1,137 @@ +# AUTOGENERATED BY NBDEV! DO NOT EDIT! + +__all__ = ["index", "modules", "custom_doc_links", "git_url"] + +index = {"Array": "00_typing.ipynb", + "Int": "00_typing.ipynb", + "Float": "00_typing.ipynb", + "ComplexFloat": "00_typing.ipynb", + "Settings": "00_typing.ipynb", + "SDict": "00_typing.ipynb", + "SCoo": "00_typing.ipynb", + "SDense": "00_typing.ipynb", + "SType": "00_typing.ipynb", + "Model": "00_typing.ipynb", + "ModelFactory": "00_typing.ipynb", + "GeneralModel": "00_typing.ipynb", + "Models": "00_typing.ipynb", + "Instance": "00_typing.ipynb", + "GeneralInstance": "00_typing.ipynb", + "Instances": "00_typing.ipynb", + "Netlist": "00_typing.ipynb", + "LogicalNetlist": "00_typing.ipynb", + "is_float": "00_typing.ipynb", + "is_complex": "00_typing.ipynb", + "is_complex_float": "00_typing.ipynb", + "is_sdict": "00_typing.ipynb", + "is_scoo": "00_typing.ipynb", + "is_sdense": "00_typing.ipynb", + "is_model": "00_typing.ipynb", + "is_model_factory": "00_typing.ipynb", + "validate_model": "00_typing.ipynb", + "is_instance": "00_typing.ipynb", + "is_netlist": "00_typing.ipynb", + "is_stype": "00_typing.ipynb", + "is_singlemode": "00_typing.ipynb", + "is_multimode": "00_typing.ipynb", + "is_mixedmode": "00_typing.ipynb", + "sdict": "00_typing.ipynb", + "scoo": "00_typing.ipynb", + "sdense": "00_typing.ipynb", + "modelfactory": "00_typing.ipynb", + "__repr__": "01_patched.ipynb", + "block_diag": "02_utils.ipynb", + "clean_string": "02_utils.ipynb", + "copy_settings": "02_utils.ipynb", + "validate_settings": "02_utils.ipynb", + "try_float": "02_utils.ipynb", + "flatten_dict": "02_utils.ipynb", + "unflatten_dict": "02_utils.ipynb", + "get_ports": "02_utils.ipynb", + "get_port_combinations": "02_utils.ipynb", + "get_settings": "02_utils.ipynb", + "grouped_interp": "02_utils.ipynb", + "merge_dicts": "02_utils.ipynb", + "mode_combinations": "02_utils.ipynb", + "reciprocal": "02_utils.ipynb", + "rename_params": "02_utils.ipynb", + "rename_ports": "02_utils.ipynb", + "update_settings": "02_utils.ipynb", + "validate_not_mixedmode": "02_utils.ipynb", + "validate_multimode": "02_utils.ipynb", + "validate_sdict": "02_utils.ipynb", + "get_inputs_outputs": "02_utils.ipynb", + "cache": "03_caching.ipynb", + "cache_clear": "03_caching.ipynb", + "multimode": "04_multimode.ipynb", + "singlemode": "04_multimode.ipynb", + "straight": "05_models.ipynb", + "coupler": "05_models.ipynb", + "unitary": "05_models.ipynb", + "passthru": "05_models.ipynb", + "copier": "05_models.ipynb", + "get_models": "05_models.ipynb", + "models": "05_models.ipynb", + "netlist": "06_netlist.ipynb", + "netlist_from_yaml": "06_netlist.ipynb", + "logical_netlist": "06_netlist.ipynb", + "circuit_backends": "07_backends.ipynb", + "evaluate_circuit": "07a_backends_default.ipynb", + "solve_klu": "07b_backends_klu.ipynb", + "mul_coo": "07b_backends_klu.ipynb", + "evaluate_circuit_klu": "07b_backends_klu.ipynb", + "split_port": "07c_backends_additive.ipynb", + "graph_edges": "07c_backends_additive.ipynb", + "prune_internal_output_nodes": 
"07c_backends_additive.ipynb", + "get_possible_paths": "07c_backends_additive.ipynb", + "path_lengths": "07c_backends_additive.ipynb", + "evaluate_circuit_additive": "07c_backends_additive.ipynb", + "circuit": "08_circuit.ipynb", + "circuit_from_netlist": "08_circuit.ipynb", + "circuit_from_yaml": "08_circuit.ipynb", + "circuit_from_gdsfactory": "08_circuit.ipynb", + "mse": "09a_nn_loss.ipynb", + "huber_loss": "09a_nn_loss.ipynb", + "l2_reg": "09a_nn_loss.ipynb", + "cartesian_product": "09b_nn_utils.ipynb", + "denormalize": "09b_nn_utils.ipynb", + "norm": "09b_nn_utils.ipynb", + "get_normalization": "09b_nn_utils.ipynb", + "get_df_columns": "09b_nn_utils.ipynb", + "normalize": "09b_nn_utils.ipynb", + "preprocess": "09c_nn_core.ipynb", + "dense": "09c_nn_core.ipynb", + "generate_dense_weights": "09c_nn_core.ipynb", + "load_nn_weights_json": "09d_nn_io.ipynb", + "save_nn_weights_json": "09d_nn_io.ipynb", + "get_available_sizes": "09d_nn_io.ipynb", + "get_dense_weights_path": "09d_nn_io.ipynb", + "get_norm_path": "09d_nn_io.ipynb", + "load_nn_dense": "09d_nn_io.ipynb", + "__author__": "99_init.ipynb", + "__version__": "99_init.ipynb"} + +modules = ["typing_.py", + "patched.py", + "utils.py", + "caching.py", + "multimode.py", + "models.py", + "netlist.py", + "backends/__init__.py", + "backends/default.py", + "backends/klu.py", + "backends/additive.py", + "circuit.py", + "nn/__init__.py", + "nn/loss.py", + "nn/utils.py", + "nn/core.py", + "nn/io.py", + "__init__.py"] + +doc_url = "https://flaport.github.io/sax/" + +git_url = "https://github.com/flaport/sax/" + +def custom_doc_links(name): return None diff --git a/sax/backends/__init__.py b/sax/backends/__init__.py new file mode 100644 index 0000000..39a7c08 --- /dev/null +++ b/sax/backends/__init__.py @@ -0,0 +1,17 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/07_backends.ipynb (unless otherwise specified). + +__all__ = ['circuit_backends'] + +# Internal Cell + +from .default import evaluate_circuit +from .klu import evaluate_circuit_klu +from .additive import evaluate_circuit_additive + +# Cell + +circuit_backends = { + "default": evaluate_circuit, + "klu": evaluate_circuit_klu, + "additive": evaluate_circuit_additive, +} \ No newline at end of file diff --git a/sax/backends/additive.py b/sax/backends/additive.py new file mode 100644 index 0000000..dd95590 --- /dev/null +++ b/sax/backends/additive.py @@ -0,0 +1,122 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/07c_backends_additive.ipynb (unless otherwise specified). 
+ + +from __future__ import annotations + + +__all__ = ['split_port', 'graph_edges', 'prune_internal_output_nodes', 'get_possible_paths', 'path_lengths', + 'evaluate_circuit_additive'] + +# Cell +#nbdev_comment from __future__ import annotations + +from typing import Dict, Tuple + +import jax.numpy as jnp +import networkx as nx +from ..typing_ import SDict, SType, SDense, sdict +from ..netlist import netlist + +# Cell +def split_port(port: str) -> Tuple[str, str]: + try: + instance, port = port.split(",") + except ValueError: + (port,) = port.split(",") + instance = "" + return instance, port + +# Cell +def graph_edges( + instances: Dict[str, SType], connections: Dict[str, str], ports: Dict[str, str] +): + zero = jnp.array([0.0], dtype=float) + edges = {} + edges.update({split_port(k): split_port(v) for k, v in connections.items()}) + edges.update({split_port(v): split_port(k) for k, v in connections.items()}) + edges.update({split_port(k): split_port(v) for k, v in ports.items()}) + edges.update({split_port(v): split_port(k) for k, v in ports.items()}) + edges = [(n1, n2, {"type": "C", "length": zero}) for n1, n2 in edges.items()] + + _instances = { + **{i1: None for (i1, _), (_, _), _ in edges}, + **{i2: None for (_, _), (i2, _), _ in edges}, + } + del _instances[""] # external ports don't belong to an instance + + for instance in _instances: + s = instances[instance] + edges += [ + ((instance, p1), (instance, p2), {"type": "S", "length": jnp.asarray(length, dtype=float).ravel()}) + for (p1, p2), length in sdict(s).items() + ] + + return edges + +# Cell +def prune_internal_output_nodes(graph): + broken = True + while broken: + broken = False + for (i, p), dic in list(graph.adjacency()): + if ( + i != "" + and len(dic) == 2 + and all(prop.get("type", "C") == "C" for prop in dic.values()) + ): + graph.remove_node((i, p)) + graph.add_edge(*dic.keys(), type="C", length=0.0) + broken = True + break + return graph + +# Cell +def get_possible_paths(graph, source, target): + paths = [] + default_props = {"type": "C", "length": 0.0} + for path in nx.all_simple_edge_paths(graph, source, target): + prevtype = "C" + for n1, n2 in path: + curtype = graph.get_edge_data(n1, n2, default_props)["type"] + if curtype == prevtype == "S": + break + else: + prevtype = curtype + else: + paths.append(path) + return paths + +# Cell +def path_lengths(graph, paths): + lengths = [] + for path in paths: + length = zero = jnp.array([0.0], dtype=float) + default_edge_data = {"type": "C", "length": zero} + for edge in path: + edge_data = graph.get_edge_data(*edge, default_edge_data) + length = (length[None,:] + edge_data.get("length", zero)[:, None]).ravel() + lengths.append(length) + return lengths + +# Cell +def evaluate_circuit_additive( + instances: Dict[str, SDict], + connections: Dict[str, str], + ports: Dict[str, str], +): + """evaluate a circuit for the given sdicts.""" + edges = graph_edges(instances, connections, ports) + + graph = nx.Graph() + graph.add_edges_from(edges) + prune_internal_output_nodes(graph) + + sdict = {} + for source in ports: + for target in ports: + paths = get_possible_paths(graph, source=("", source), target=("", target)) + if not paths: + continue + sdict[source, target] = path_lengths(graph, paths) + + return sdict \ No newline at end of file diff --git a/sax/backends/default.py b/sax/backends/default.py new file mode 100644 index 0000000..106a09e --- /dev/null +++ b/sax/backends/default.py @@ -0,0 +1,120 @@ +# AUTOGENERATED! DO NOT EDIT! 
File to edit: nbs/07a_backends_default.ipynb (unless otherwise specified). + + +from __future__ import annotations + + +__all__ = ['evaluate_circuit'] + +# Cell +#nbdev_comment from __future__ import annotations + +from typing import Dict + +import jax +from ..typing_ import SType, SDict, sdense, sdict + +# Cell + +def evaluate_circuit( + instances: Dict[str, SType], + connections: Dict[str, str], + ports: Dict[str, str], +) -> SDict: + """evaluate a circuit for the given sdicts.""" + + # it's actually easier working w reverse: + reversed_ports = {v: k for k, v in ports.items()} + + block_diag = {} + for name, S in instances.items(): + block_diag.update( + {(f"{name},{p1}", f"{name},{p2}"): v for (p1, p2), v in sdict(S).items()} + ) + + sorted_connections = sorted(connections.items(), key=_connections_sort_key) + all_connected_instances = {k: {k} for k in instances} + + for k, l in sorted_connections: + name1, _ = k.split(",") + name2, _ = l.split(",") + + connected_instances = ( + all_connected_instances[name1] | all_connected_instances[name2] + ) + for name in connected_instances: + all_connected_instances[name] = connected_instances + + current_ports = tuple( + p + for instance in connected_instances + for p in set([p for p, _ in block_diag] + [p for _, p in block_diag]) + if p.startswith(f"{instance},") + ) + + block_diag.update(_interconnect_ports(block_diag, current_ports, k, l)) + + for i, j in list(block_diag.keys()): + is_connected = i == k or i == l or j == k or j == l + is_in_output_ports = i in reversed_ports and j in reversed_ports + if is_connected and not is_in_output_ports: + del block_diag[i, j] # we're no longer interested in these port combinations + + circuit_sdict: SDict = { + (reversed_ports[i], reversed_ports[j]): v + for (i, j), v in block_diag.items() + if i in reversed_ports and j in reversed_ports + } + return circuit_sdict + + +def _connections_sort_key(connection): + """sort key for sorting a connection dictionary """ + part1, part2 = connection + name1, _ = part1.split(",") + name2, _ = part2.split(",") + return (min(name1, name2), max(name1, name2)) + + +def _interconnect_ports(block_diag, current_ports, k, l): + """interconnect two ports in a given model + + > Note: the interconnect algorithm is based on equation 6 of 'Filipsson, Gunnar. + "A new general computer algorithm for S-matrix calculation of interconnected + multiports." 11th European Microwave Conference. IEEE, 1981.' + """ + current_block_diag = {} + for i in current_ports: + for j in current_ports: + vij = _calculate_interconnected_value( + vij=block_diag.get((i, j), 0.0), + vik=block_diag.get((i, k), 0.0), + vil=block_diag.get((i, l), 0.0), + vkj=block_diag.get((k, j), 0.0), + vkk=block_diag.get((k, k), 0.0), + vkl=block_diag.get((k, l), 0.0), + vlj=block_diag.get((l, j), 0.0), + vlk=block_diag.get((l, k), 0.0), + vll=block_diag.get((l, l), 0.0), + ) + current_block_diag[i, j] = vij + return current_block_diag + + +@jax.jit +def _calculate_interconnected_value(vij, vik, vil, vkj, vkk, vkl, vlj, vlk, vll): + """Calculate an interconnected S-parameter value + + Note: + The interconnect algorithm is based on equation 6 in the paper below:: + + Filipsson, Gunnar. "A new general computer algorithm for S-matrix calculation + of interconnected multiports." 11th European Microwave Conference. IEEE, 1981. 
+ """ + result = vij + ( + vkj * vil * (1 - vlk) + + vlj * vik * (1 - vkl) + + vkj * vll * vik + + vlj * vkk * vil + ) / ((1 - vkl) * (1 - vlk) - vkk * vll) + return result \ No newline at end of file diff --git a/sax/backends/klu.py b/sax/backends/klu.py new file mode 100644 index 0000000..083d9ac --- /dev/null +++ b/sax/backends/klu.py @@ -0,0 +1,128 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/07b_backends_klu.ipynb (unless otherwise specified). + + +from __future__ import annotations + + +__all__ = ['solve_klu', 'mul_coo', 'evaluate_circuit_klu'] + +# Cell +#nbdev_comment from __future__ import annotations + +from typing import Dict + +import jax +import jax.numpy as jnp +from ..typing_ import SDense, SDict, SType, scoo + +try: + import klujax +except ImportError: + klujax = None + +# Cell +solve_klu = None +if klujax is not None: + solve_klu = jax.vmap(klujax.solve, (None, None, 0, None), 0) + +# Internal Cell + +# @jax.jit # TODO: make this available to autograd +# def mul_coo(Ai, Aj, Ax, b): +# result = jnp.zeros_like(b).at[..., Ai, :].add(Ax[..., :, None] * b[..., Aj, :]) +# return result + +# Cell +mul_coo = None +if klujax is not None: + mul_coo = jax.vmap(klujax.coo_mul_vec, (None, None, 0, 0), 0) + +# Cell +def evaluate_circuit_klu( + instances: Dict[str, SType], + connections: Dict[str, str], + ports: Dict[str, str], +): + """evaluate a circuit using KLU for the given sdicts. """ + + if klujax is None: + raise ImportError( + "Could not import 'klujax'. " + "Please install it first before using backend method 'klu'" + ) + + assert solve_klu is not None + assert mul_coo is not None + + connections = {**connections, **{v: k for k, v in connections.items()}} + inverse_ports = {v: k for k, v in ports.items()} + port_map = {k: i for i, k in enumerate(ports)} + + idx, Si, Sj, Sx, instance_ports = 0, [], [], [], {} + batch_shape = () + for name, instance in instances.items(): + si, sj, sx, ports_map = scoo(instance) + Si.append(si + idx) + Sj.append(sj + idx) + Sx.append(sx) + if len(sx.shape[:-1]) > len(batch_shape): + batch_shape = sx.shape[:-1] + instance_ports.update({f"{name},{p}": i + idx for p, i in ports_map.items()}) + idx += len(ports_map) + + Si = jnp.concatenate(Si, -1) + Sj = jnp.concatenate(Sj, -1) + Sx = jnp.concatenate( + [jnp.broadcast_to(sx, (*batch_shape, sx.shape[-1])) for sx in Sx], -1 + ) + + n_col = idx + n_rhs = len(port_map) + + Cmap = { + int(instance_ports[k]): int(instance_ports[v]) for k, v in connections.items() + } + Ci = jnp.array(list(Cmap.keys()), dtype=jnp.int32) + Cj = jnp.array(list(Cmap.values()), dtype=jnp.int32) + + Cextmap = {int(instance_ports[k]): int(port_map[v]) for k, v in inverse_ports.items()} + Cexti = jnp.stack(list(Cextmap.keys()), 0) + Cextj = jnp.stack(list(Cextmap.values()), 0) + Cext = jnp.zeros((n_col, n_rhs), dtype=complex).at[Cexti, Cextj].set(1.0) + + # TODO: make this block jittable... + Ix = jnp.ones((*batch_shape, n_col)) + Ii = Ij = jnp.arange(n_col) + mask = Cj[None,:] == Si[:, None] + CSi = jnp.broadcast_to(Ci[None, :], mask.shape)[mask] + + # CSi = jnp.where(Cj[None, :] == Si[:, None], Ci[None, :], 0).sum(1) + mask = (Cj[:, None] == Si[None, :]).any(0) + CSj = Sj[mask] + + if Sx.ndim > 1: # bug in JAX... 
see https://github.com/google/jax/issues/9050 + CSx = Sx[..., mask] + else: + CSx = Sx[mask] + + # CSj = jnp.where(mask, Sj, 0) + # CSx = jnp.where(mask, Sx, 0.0) + + I_CSi = jnp.concatenate([CSi, Ii], -1) + I_CSj = jnp.concatenate([CSj, Ij], -1) + I_CSx = jnp.concatenate([-CSx, Ix], -1) + + n_col, n_rhs = Cext.shape + n_lhs = jnp.prod(jnp.array(batch_shape, dtype=jnp.int32)) + Sx = Sx.reshape(n_lhs, -1) + I_CSx = I_CSx.reshape(n_lhs, -1) + + inv_I_CS_Cext = solve_klu(I_CSi, I_CSj, I_CSx, Cext) + S_inv_I_CS_Cext = mul_coo(Si, Sj, Sx, inv_I_CS_Cext) + + CextT_S_inv_I_CS_Cext = S_inv_I_CS_Cext[..., Cexti, :][..., :, Cextj] + + _, n, _ = CextT_S_inv_I_CS_Cext.shape + S = CextT_S_inv_I_CS_Cext.reshape(*batch_shape, n, n) + + return S, port_map \ No newline at end of file diff --git a/sax/caching.py b/sax/caching.py new file mode 100644 index 0000000..1216326 --- /dev/null +++ b/sax/caching.py @@ -0,0 +1,50 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/03_caching.ipynb (unless otherwise specified). + + +from __future__ import annotations + + +__all__ = ['cache', 'cache_clear'] + +# Cell +#nbdev_comment from __future__ import annotations + +import gc +from functools import _lru_cache_wrapper, lru_cache, partial, wraps +from typing import Callable, Optional + +# Internal Cell + +_cached_functions = [] + +# Cell + +def cache(func: Optional[Callable] = None, /, *, maxsize: Optional[int] = None) -> Callable: + """cache a function""" + if func is None: + return partial(cache, maxsize=maxsize) + + cached_func = lru_cache(maxsize=maxsize)(func) + + @wraps(func) + def new_func(*args, **kwargs): + return cached_func(*args, **kwargs) + + new_func.cache_clear = cached_func.cache_clear + + _cached_functions.append(new_func) + + return new_func + +# Cell +def cache_clear(*, force: bool=False): + """clear all function caches""" + if not force: + for func in _cached_functions: + func.cache_clear() + else: + gc.collect() + funcs = [a for a in gc.get_objects() if isinstance(a, _lru_cache_wrapper)] + + for func in funcs: + func.cache_clear() \ No newline at end of file diff --git a/sax/circuit.py b/sax/circuit.py new file mode 100644 index 0000000..e9cefcf --- /dev/null +++ b/sax/circuit.py @@ -0,0 +1,194 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/08_circuit.ipynb (unless otherwise specified). + + +from __future__ import annotations + + +__all__ = ['circuit', 'circuit_from_netlist', 'circuit_from_yaml', 'circuit_from_gdsfactory'] + +# Cell +#nbdev_comment from __future__ import annotations + +from functools import partial +from typing import Dict, Optional, Tuple, Union, cast + +from .backends import circuit_backends +from .models import coupler, straight +from .multimode import multimode, singlemode +from .netlist import LogicalNetlist, Netlist, logical_netlist, netlist_from_yaml +from .typing_ import Instances, Model, Models, Netlist, Settings, SType, is_netlist +from .utils import _replace_kwargs, get_settings, merge_dicts, update_settings + +# Cell + +def circuit( + *, + instances: Instances, + connections: Dict[str, str], + ports: Dict[str, str], + models: Optional[Models] = None, + modes: Optional[Tuple[str, ...]] = None, + settings: Optional[Settings] = None, + backend: str = "default", + default_models=None, +) -> Model: + # assert valid circuit_backend + if backend not in circuit_backends: + raise KeyError( + f"circuit backend {backend} not found. Allowed circuit backends: " + f"{', '.join(circuit_backends.keys())}." 
+ ) + + evaluate_circuit = circuit_backends[backend] + + _netlist, _settings, _models = logical_netlist( + instances=instances, + connections=connections, + ports=ports, + models=models, + settings=settings, + default_models=default_models, + ) + + for name in list(_models.keys()): + if is_netlist(_models[name]): + netlist_model = cast(LogicalNetlist, _models.pop(name)) + instance_model_names = set(netlist_model["instances"].values()) + instance_models = {k: _models[k] for k in instance_model_names} + netlist_func = circuit_from_netlist( + netlist=netlist_model, + models=instance_models, + backend=backend, + modes=modes, + settings=None, # settings are already integrated in netlist by now. + default_models=default_models, + ) + _models[name] = netlist_func + + if modes is not None: + maybe_multimode = partial(multimode, modes=modes) + connections = { + f"{p1}@{mode}": f"{p2}@{mode}" + for p1, p2 in _netlist["connections"].items() + for mode in modes + } + ports = { + f"{p1}@{mode}": f"{p2}@{mode}" + for p1, p2 in _netlist["ports"].items() + for mode in modes + } + else: + maybe_multimode = partial(singlemode, mode="te") + connections = _netlist["connections"] + ports = _netlist["ports"] + + def _circuit(**settings: Settings) -> SType: + settings = merge_dicts(_settings, settings) + global_settings = {} + for k in list(settings.keys()): + if k in _netlist["instances"]: + continue + global_settings[k] = settings.pop(k) + if global_settings: + settings = cast( + Dict[str, Settings], update_settings(settings, **global_settings) + ) + instances: Dict[str, SType] = {} + for name, model_name in _netlist["instances"].items(): + model = cast(Model, _models[model_name]) + instances[name] = cast( + SType, maybe_multimode(model(**settings.get(name, {}))) + ) + S = evaluate_circuit(instances, connections, ports) + return S + + settings = { + name: get_settings(cast(Model, _models[model])) + for name, model in _netlist["instances"].items() + } + settings = merge_dicts(settings, _settings) + _replace_kwargs(_circuit, **settings) + + return _circuit + +# Cell + +def circuit_from_netlist( + netlist: Union[LogicalNetlist, Netlist], + *, + models: Optional[Models] = None, + modes: Optional[Tuple[str, ...]] = None, + settings: Optional[Settings] = None, + backend: str = "default", + default_models=None, +) -> Model: + """create a circuit model function from a netlist """ + instances = netlist["instances"] + connections = netlist["connections"] + ports = netlist["ports"] + _circuit = circuit( + instances=instances, + connections=connections, + ports=ports, + models=models, + modes=modes, + settings=settings, + backend=backend, + default_models=default_models, + ) + return _circuit + +# Cell +def circuit_from_yaml( + yaml: str, + *, + models: Optional[Models] = None, + modes: Optional[Tuple[str, ...]] = None, + settings: Optional[Settings] = None, + backend: str = "default", + default_models=None, +) -> Model: + """Load a sax circuit from yaml definition + + Args: + yaml: the yaml string to load + models: a dictionary which maps component names to model functions + modes: the modes of the simulation (if not given, single mode + operation is assumed). + settings: override netlist instance settings. Use this setting to set + global settings like for example the wavelength 'wl'. + backend: "default" or "klu". How the circuit S-parameters are + calculated. "klu" is a CPU-only method which generally speaking is + much faster for large circuits but cannot be jitted or used for autograd. 
+ """ + netlist, models = netlist_from_yaml(yaml=yaml, models=models, settings=settings) + circuit = circuit_from_netlist( + netlist=netlist, + models=models, + modes=modes, + settings=None, # settings are already integrated in the netlist by now + backend=backend, + default_models=default_models, + ) + return circuit + +# Cell +def circuit_from_gdsfactory( + component, + *, + models: Optional[Models] = None, + modes: Optional[Tuple[str, ...]] = None, + settings: Optional[Settings] = None, + backend: str = "default", + default_models=None, +) -> Model: + """Load a sax circuit from a GDSFactory component""" + circuit = circuit_from_netlist( + component.get_netlist(), + models=models, + modes=modes, + settings=settings, + backend=backend, + default_models=default_models, + ) + return circuit \ No newline at end of file diff --git a/sax/constants.py b/sax/constants.py deleted file mode 100644 index c9a5a5f..0000000 --- a/sax/constants.py +++ /dev/null @@ -1,5 +0,0 @@ -""" A collection of useful constants for SAX simulations. """ - -from math import pi - -c = 299792458.0 diff --git a/sax/core.py b/sax/core.py deleted file mode 100644 index 6ef8a7a..0000000 --- a/sax/core.py +++ /dev/null @@ -1,386 +0,0 @@ -""" SAX core """ - -import functools - -import jax -import jax.numpy as jnp - -from .utils import zero, rename_ports, get_ports, validate_params, copy_params -from .typing import Optional, Callable, Tuple, Dict, ParamsDict, ModelDict, ModelFunc, ComplexFloat - - -def modelgenerator( - ports: Tuple[str, ...], default_params: Optional[ParamsDict] = None, reciprocal: bool = True -) -> Callable: - """function decorator to easily generate a model dictionary - - Args: - ports: the port names of the model (port combination tuples will be the - keys of the model dictionary) - default_params: the dictionary containing the default model parameters. - reciprocal: whether the model is reciprocal or not, i.e. whether - model(i, j) == model(j, i). If a model is reciprocal, the decorated - model function only needs to be defined for i <= j. - - Returns: - a decorator acting on a function-generating function. - """ - ports = tuple(p for p in ports) - num_ports = len(ports) - - def modeldecorator(modelgenerator: Callable) -> ModelDict: - """generator a model dictionary from a function-generating function. - - Args: - modelgenerator: the function-generating function taking two integer - indices as arguments: (i, j). modelgenerator(i, j) needs to return - a function taking a single dictionary argument: the parameters of the - function. modelgenerator(i, j) only needs to be defined for the nonzero - elements. - - Returns: - the model dictionary, for which each of the nonzero - port-combinations is mapped to its corresponding model function. - """ - m: ModelDict = {} - m["default_params"] = {} if default_params is None else copy_params(default_params) - for j in range(num_ports): - for i in range(j + 1): - func = modelgenerator(i, j) - if func is not None: - m[ports[i], ports[j]] = func - if not reciprocal: - func = modelgenerator(j, i) - if func is not None: - m[ports[j], ports[i]] = func - return m - - return modeldecorator - - -def circuit( - models: Dict[str, ModelDict], connections: Dict[str, str], ports: Dict[str, str] -) -> ModelDict: - """create a (sub)circuit model from a collection of models and connections - - Args: - models: a dictionary with as keys the model names and values - the model dictionaries. 
- connections: a dictionary where both keys and values are strings of the - form "modelname:portname" - ports: a dictionary mapping portnames of the form - "modelname:portname" to new unique portnames - - Returns: - the circuit model dictionary with the given port names. - - Example: - A simple mzi can be created as follows:: - - mzi = circuit( - models = { - "left": model_directional_coupler, - "top": model_waveguide, - "bottom": model_waveguide, - "right": model_directional_coupler, - }, - connections={ - "left:p2": "top:in", - "left:p1": "bottom:in", - "top:out": "right:p3", - "bottom:out": "right:p0", - }, - ports={ - "left:p3": "in2", - "left:p0": "in1", - "right:p2": "out2", - "right:p1": "out1", - }, - ) - """ - - models, connections, ports = _validate_circuit_parameters( - models, connections, ports - ) - - for name, model in models.items(): - models[name] = rename_ports(model, {p: f"{name}:{p}" for p in get_ports(model)}) - validate_params(models[name].get("default_params", {})) - modelnames = [[name] for name in models] - - while len(modelnames) > 1: - for names1, names2 in zip(modelnames[::2], modelnames[1::2]): - model1 = models.pop(names1[0]) - model2 = models.pop(names2[0]) - model = _combine_models( - model1, - model2, - None if len(names1) > 1 else names1[0], - None if len(names2) > 1 else names2[0], - ) - names1.extend(names2) - for port1, port2 in [(k, v) for k, v in connections.items()]: - n1, p1 = port1.split(":") - n2, p2 = port2.split(":") - if n1 in names1 and n2 in names1: - del connections[port1] - model = _interconnect_model(model, port1, port2) - models[names1[0]] = model - modelnames = list(reversed(modelnames[::2])) - - model = rename_ports(model, ports) - - return model - - -def _validate_circuit_parameters( - models: Dict[str, ModelDict], connections: Dict[str, str], ports: Dict[str, str] -) -> Tuple[Dict[str, ModelDict], Dict[str, str], Dict[str, str]]: - """validate the netlist parameters of a circuit - - Args: - models: a dictionary with as keys the model names and values - the model dictionaries. - connections: a dictionary where both keys and values are strings of the - form "modelname:portname" - ports: a dictionary mapping portnames of the form - "modelname:portname" to new unique portnames - - Returns: - the validated and possibly slightly modified models, connections and - ports dictionaries. - """ - - all_ports = set() - for name, model in models.items(): - _validate_model_dict(name, model) - for port in get_ports(model): - all_ports.add(f"{name}:{port}") - - if not isinstance(connections, dict): - msg = f"Connections should be a str:str dict or a list of length-2 tuples." - assert all(len(conn) == 2 for conn in connections), msg - connections, _connections = {}, connections - connection_ports = set() - for conn in _connections: - connections[conn[0]] = conn[1] - for port in conn: - msg = f"Duplicate port found in connections: '{port}'" - assert port not in connection_ports, msg - connection_ports.add(port) - - connection_ports = set() - for connection in connections.items(): - for port in connection: - if port in all_ports: - all_ports.remove(port) - msg = f"Connection ports should all be strings. Got: '{port}'" - assert isinstance(port, str), msg - msg = f"Connection ports should have format 'modelname:port'. Got: '{port}'" - assert len(port.split(":")) == 2, msg - name, _port = port.split(":") - msg = f"Model '{name}' used in connection " - msg += f"'{connection[0]}':'{connection[1]}', " - msg += f"but '{name}' not found in models dictionary." 
- assert name in models, msg - msg = f"Port name '{_port}' not found in model '{name}'. " - msg += f"Allowed ports for '{name}': {get_ports(models[name])}" - assert _port in get_ports(models[name]), msg - msg = f"Duplicate port found in connections: '{port}'" - assert port not in connection_ports, msg - connection_ports.add(port) - - output_ports = set() - for port, output_port in ports.items(): - if port in all_ports: - all_ports.remove(port) - msg = f"Ports keys in 'ports' should all be strings. Got: '{port}'" - assert isinstance(port, str), msg - msg = f"Port values in 'ports' should all be strings. Got: '{output_port}'" - assert isinstance(output_port, str), msg - msg = f"Port keys in 'ports' should have format 'model:port'. Got: '{port}'" - assert len(port.split(":")) == 2, msg - msg = f"Port values in 'ports' shouldn't contain a ':'. Got: '{output_port}'" - assert ":" not in output_port, msg - msg = f"Duplicate port found in ports or connections: '{port}'" - assert port not in connection_ports, msg - name, _port = port.split(":") - msg = f"Model '{name}' used in output port " - msg += f"'{port}':'{output_port}', " - msg += f"but '{name}' not found in models dictionary." - assert name in models, msg - msg = f"Port name '{_port}' not found in model '{name}'. " - msg += f"Allowed ports for '{name}': {get_ports(models[name])}" - assert _port in get_ports(models[name]), msg - connection_ports.add(port) - msg = f"Duplicate port found in output ports: '{output_port}'" - assert output_port not in output_ports, msg - output_ports.add(output_port) - - assert not all_ports, f"Unused ports found: {all_ports}" - - return models, connections, ports - - -def _validate_model_dict(name: str, model: ModelDict): - assert isinstance(model, dict), f"Model '{model}' should be a dictionary" - ports = get_ports(model) - assert ports, f"No ports in model {name}" - for p1 in ports: - for p2 in ports: - msg = ( - f"model {name} port combination {p1}->{p2} is no function or callable." - ) - assert callable(model.get((p1, p2), zero)), msg - - -def _namedparamsfunc(func: Callable, name: str, params: ParamsDict) -> ComplexFloat: - """make a model function look for its model name before acting on the parameters - - Args: - func: the original model function acting on a dictionary of parameters - name: the name of the model - params: a dictionary for which the keys are the model names and the - values are the model parameter dictionaries. 
- """ - return func(params[name]) - - -def _combine_models( - model1: ModelDict, - model2: ModelDict, - name1: Optional[str] = None, - name2: Optional[str] = None, -) -> ModelDict: - """Combine two models into a combined model (without connecting any ports) - - Args: - model1: the first model dictionary to combine - model2: the second model dictionary to combine - name1: the name of the first model (can be None for unnamed models) - name2: the name of the second model (can be None for unnamed models) - """ - model: ModelDict = {} - model["default_params"] = {} - for _model, _name in [(model1, name1), (model2, name2)]: - for key, value in _model.items(): - if isinstance(key, str): - if key != "default_params": - model[key] = value - else: - p1, p2 = key - if value is zero or _name is None: - model[p1, p2] = value - else: - model[p1, p2] = _partialmodelfunc(_namedparamsfunc, value, _name) - if _name is None: - model["default_params"].update(_model["default_params"]) - else: - model["default_params"][_name] = copy_params(_model["default_params"]) - return model - - -def _interconnect_model(model: ModelDict, k: str, l: str) -> ModelDict: - """interconnect two ports in a given model - - Args: - model: the component for which to interconnect the given ports - k: the first port name to connect - l: the second port name to connect - - Returns: - the resulting interconnected component, i.e. a component with two ports - less than the original component. - - Note: - The interconnect algorithm is based on equation 6 in the paper below:: - - Filipsson, Gunnar. "A new general computer algorithm for S-matrix calculation - of interconnected multiports." 11th European Microwave Conference. IEEE, 1981. - """ - new_model: ModelDict = {} - new_model["default_params"] = copy_params(model["default_params"]) - ports = get_ports(model) - for i in ports: - for j in ports: - mij = model.get((i, j), zero) - mik = model.get((i, k), zero) - mil = model.get((i, l), zero) - mkj = model.get((k, j), zero) - mkk = model.get((k, k), zero) - mkl = model.get((k, l), zero) - mlj = model.get((l, j), zero) - mlk = model.get((l, k), zero) - mll = model.get((l, l), zero) - if ( - (mij is zero) - and ((mkj is zero) or (mil is zero)) - and ((mlj is zero) or (mik is zero)) - and ((mkj is zero) or (mll is zero) or (mik is zero)) - and ((mlj is zero) or (mkk is zero) or (mil is zero)) - ): - continue - new_model[i, j] = _partialmodelfunc( - _model_ijkl, mij, mik, mil, mkj, mkk, mkl, mlj, mlk, mll - ) - for key in list(new_model.keys()): - if isinstance(key, str): - continue - i, j = key - if i == k or i == l or j == k or j == l: - del new_model[i, j] - return new_model - - -def _model_ijkl( - mij: ModelFunc, - mik: ModelFunc, - mil: ModelFunc, - mkj: ModelFunc, - mkk: ModelFunc, - mkl: ModelFunc, - mlj: ModelFunc, - mlk: ModelFunc, - mll: ModelFunc, - params: ParamsDict, -) -> ComplexFloat: - """combine the given model functions. - - Note: - The interconnect algorithm is based on equation 6 in the paper below:: - - Filipsson, Gunnar. "A new general computer algorithm for S-matrix calculation - of interconnected multiports." 11th European Microwave Conference. IEEE, 1981. 
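For reviewers comparing the old and new engines: the interconnect update that the (now removed) `_model_ijkl` helper computes when ports k and l of an S-matrix S are connected can be written out explicitly. This rendering is reconstructed from the helper's return expression below, not copied from the cited paper:

$$
S'_{ij} = S_{ij} + \frac{S_{kj} S_{il}\,(1 - S_{lk}) + S_{lj} S_{ik}\,(1 - S_{kl}) + S_{kj} S_{ll} S_{ik} + S_{lj} S_{kk} S_{il}}{(1 - S_{kl})(1 - S_{lk}) - S_{kk} S_{ll}}
$$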
- """ - vij = mij(params) - vik = mik(params) - vil = mil(params) - vkj = mkj(params) - vkk = mkk(params) - vkl = mkl(params) - vlj = mlj(params) - vlk = mlk(params) - vll = mll(params) - return vij + ( - vkj * vil * (1 - vlk) - + vlj * vik * (1 - vkl) - + vkj * vll * vik - + vlj * vkk * vil - ) / ((1 - vkl) * (1 - vlk) - vkk * vll) - - -class _partialmodelfunc(functools.partial): - """fun(params) - - Args: - params: parameter dictionary for the model. - - Returns: - Model transmission. - """ - - def __repr__(self): - func = self.func - while hasattr(func, "func"): - func = func.func - return repr(func) diff --git a/sax/models.py b/sax/models.py new file mode 100644 index 0000000..903ff18 --- /dev/null +++ b/sax/models.py @@ -0,0 +1,258 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/05_models.ipynb (unless otherwise specified). + + +from __future__ import annotations + + +__all__ = ['straight', 'coupler', 'unitary', 'passthru', 'copier', 'passthru', 'get_models', 'models'] + +# Cell +#nbdev_comment from __future__ import annotations + +from typing import Optional, Tuple + +import jax +import jax.numpy as jnp +import sax +from .typing_ import Model, SCoo, SDict +from .utils import get_inputs_outputs, reciprocal + +# Cell + +def straight( + *, + wl: float = 1.55, + wl0: float = 1.55, + neff: float = 2.34, + ng: float = 3.4, + length: float = 10.0, + loss: float = 0.0 +) -> SDict: + """a simple straight waveguide model""" + dwl = wl - wl0 + dneff_dwl = (ng - neff) / wl0 + neff = neff - dwl * dneff_dwl + phase = 2 * jnp.pi * neff * length / wl + amplitude = jnp.asarray(10 ** (-loss * length / 20), dtype=complex) + transmission = amplitude * jnp.exp(1j * phase) + sdict = reciprocal( + { + ("in0", "out0"): transmission, + } + ) + return sdict + +# Cell + +def coupler(*, coupling: float = 0.5) -> SDict: + """a simple coupler model""" + kappa = coupling ** 0.5 + tau = (1 - coupling) ** 0.5 + sdict = reciprocal( + { + ("in0", "out0"): tau, + ("in0", "out1"): 1j * kappa, + ("in1", "out0"): 1j * kappa, + ("in1", "out1"): tau, + } + ) + return sdict + +# Internal Cell + +def _validate_ports(ports, num_inputs, num_outputs, diagonal) -> Tuple[Tuple[str,...], Tuple[str,...], int, int]: + if ports is None: + if num_inputs is None or num_outputs is None: + raise ValueError( + "if not ports given, you must specify how many input ports " + "and how many output ports a model has." + ) + input_ports = [f"in{i}" for i in range(num_inputs)] + output_ports = [f"out{i}" for i in range(num_outputs)] + else: + if num_inputs is not None: + if num_outputs is None: + raise ValueError( + "if num_inputs is given, num_outputs should be given as well." + ) + if num_outputs is not None: + if num_inputs is None: + raise ValueError( + "if num_outputs is given, num_inputs should be given as well." + ) + if num_inputs is not None and num_outputs is not None: + if num_inputs + num_outputs != len(ports): + raise ValueError("num_inputs + num_outputs != len(ports)") + input_ports = ports[:num_inputs] + output_ports = ports[num_inputs:] + else: + input_ports, output_ports = get_inputs_outputs(ports) + num_inputs = len(input_ports) + num_outputs = len(output_ports) + + if diagonal: + if num_inputs != num_outputs: + raise ValueError( + "Can only have a diagonal passthru if number of input ports equals the number of output ports!" 
+ ) + return input_ports, output_ports, num_inputs, num_outputs + +# Cell + +@sax.cache +def unitary( + num_inputs: Optional[int] = None, + num_outputs: Optional[int] = None, + ports: Optional[Tuple[str, ...]] = None, + *, + jit=True, + reciprocal=True, + diagonal=False, +) -> Model: + input_ports, output_ports, num_inputs, num_outputs = _validate_ports(ports, num_inputs, num_outputs, diagonal) + assert num_inputs is not None and num_outputs is not None + + # let's create the squared S-matrix: + N = max(num_inputs, num_outputs) + S = jnp.zeros((2*N, 2*N), dtype=float) + + if not diagonal: + S = S.at[:N, N:].set(1) + else: + r = jnp.arange(N, dtype=int) # reciprocal only works if num_inputs == num_outputs! + S = S.at[r, N+r].set(1) + + if reciprocal: + if not diagonal: + S = S.at[N:, :N].set(1) + else: + r = jnp.arange(N, dtype=int) # reciprocal only works if num_inputs == num_outputs! + S = S.at[N+r, r].set(1) + + # Now we need to normalize the squared S-matrix + U, s, V = jnp.linalg.svd(S, full_matrices=False) + S = jnp.sqrt(U@jnp.diag(jnp.where(s > 1e-12, 1, 0))@V) + + # Now create subset of this matrix we're interested in: + r = jnp.concatenate([jnp.arange(num_inputs, dtype=int), N+jnp.arange(num_outputs, dtype=int)], 0) + S = S[r, :][:, r] + + # let's convert it in SCOO format: + Si, Sj = jnp.where(S > 1e-6) + Sx = S[Si, Sj] + + # the last missing piece is a port map: + pm = { + **{p: i for i, p in enumerate(input_ports)}, + **{p: i + num_inputs for i, p in enumerate(output_ports)}, + } + + def func(wl: float = 1.5) -> SCoo: + wl_ = jnp.asarray(wl) + Sx_ = jnp.broadcast_to(Sx, (*wl_.shape, *Sx.shape)) + return Si, Sj, Sx_, pm + + func.__name__ = f"unitary_{num_inputs}_{num_outputs}" + func.__qualname__ = f"unitary_{num_inputs}_{num_outputs}" + if jit: + return jax.jit(func) + return func + +# Cell +@sax.cache +def passthru( + num_links: Optional[int] = None, + ports: Optional[Tuple[str, ...]] = None, + *, + jit=True, + reciprocal=True, +) -> Model: + passthru = unitary(num_links, num_links, ports, jit=jit, reciprocal=reciprocal, diagonal=True) + passthru.__name__ = f"passthru_{num_links}_{num_links}" + passthru.__qualname__ = f"passthru_{num_links}_{num_links}" + if jit: + return jax.jit(passthru) + return passthru + +# Cell + +@sax.cache +def copier( + num_inputs: Optional[int] = None, + num_outputs: Optional[int] = None, + ports: Optional[Tuple[str, ...]] = None, + *, + jit=True, + reciprocal=True, + diagonal=False, +) -> Model: + input_ports, output_ports, num_inputs, num_outputs = _validate_ports(ports, num_inputs, num_outputs, diagonal) + assert num_inputs is not None and num_outputs is not None + + # let's create the squared S-matrix: + S = jnp.zeros((num_inputs+num_outputs, num_inputs+num_outputs), dtype=float) + + if not diagonal: + S = S.at[:num_inputs, num_inputs:].set(1) + else: + r = jnp.arange(num_inputs, dtype=int) # == range(num_outputs) # reciprocal only works if num_inputs == num_outputs! + S = S.at[r, num_inputs+r].set(1) + + if reciprocal: + if not diagonal: + S = S.at[num_inputs:, :num_inputs].set(1) + else: + r = jnp.arange(num_inputs, dtype=int) # == range(num_outputs) # reciprocal only works if num_inputs == num_outputs! 
+ S = S.at[num_inputs+r, r].set(1) + + # let's convert it in SCOO format: + Si, Sj = jnp.where(S > 1e-6) + Sx = S[Si, Sj] + + # the last missing piece is a port map: + pm = { + **{p: i for i, p in enumerate(input_ports)}, + **{p: i + num_inputs for i, p in enumerate(output_ports)}, + } + + def func(wl: float = 1.5) -> SCoo: + wl_ = jnp.asarray(wl) + Sx_ = jnp.broadcast_to(Sx, (*wl_.shape, *Sx.shape)) + return Si, Sj, Sx_, pm + + func.__name__ = f"unitary_{num_inputs}_{num_outputs}" + func.__qualname__ = f"unitary_{num_inputs}_{num_outputs}" + if jit: + return jax.jit(func) + return func + +# Cell +@sax.cache +def passthru( + num_links: Optional[int] = None, + ports: Optional[Tuple[str, ...]] = None, + *, + jit=True, + reciprocal=True, +) -> Model: + passthru = unitary(num_links, num_links, ports, jit=jit, reciprocal=reciprocal, diagonal=True) + passthru.__name__ = f"passthru_{num_links}_{num_links}" + passthru.__qualname__ = f"passthru_{num_links}_{num_links}" + if jit: + return jax.jit(passthru) + return passthru + +# Cell + +models = { + "copier": copier, + "coupler": coupler, + "passthru": passthru, + "straight": straight, + "unitary": unitary, +} + +def get_models(copy: bool=True): + if copy: + return {**models} + return models \ No newline at end of file diff --git a/sax/models/__init__.py b/sax/models/__init__.py deleted file mode 100644 index 8479c78..0000000 --- a/sax/models/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from . import pic -from . import thinfilm diff --git a/sax/models/pic.py b/sax/models/pic.py deleted file mode 100644 index d5e7b13..0000000 --- a/sax/models/pic.py +++ /dev/null @@ -1,56 +0,0 @@ -""" SAX Photonic Integrated Circuit models """ - -import jax.numpy as jnp -from ..utils import zero -from ..core import modelgenerator -from ..typing import Dict, ModelDict, ComplexFloat - - -######################### -## Waveguides ## -######################### - -def model_waveguide_transmission(params: Dict[str, float]) -> ComplexFloat: - neff = params["neff"] - dwl = params["wl"] - params["wl0"] - dneff_dwl = (params["ng"] - params["neff"]) / params["wl0"] - neff = neff - dwl * dneff_dwl - phase = jnp.exp( - jnp.log(2 * jnp.pi * neff * params["length"]) - jnp.log(params["wl"]) - ) - return 10 ** (-params["loss"] * params["length"] / 20) * jnp.exp(1j * phase) - -waveguide: ModelDict = { - ("in", "out"): model_waveguide_transmission, - ("out", "in"): model_waveguide_transmission, - "default_params": { - "length": 25e-6, - "wl": 1.55e-6, - "wl0": 1.55e-6, - "neff": 2.34, - "ng": 3.4, - "loss": 0.0, - }, -} - -######################### -## Directional coupler ## -######################### - -def model_directional_coupler_coupling(params: Dict[str, float]) -> ComplexFloat: - return 1j * params["coupling"] ** 0.5 - -def model_directional_coupler_transmission(params: Dict[str, float]) -> ComplexFloat: - return (1 - params["coupling"]) ** 0.5 - -directional_coupler: ModelDict = { - ("p0", "p1"): model_directional_coupler_transmission, - ("p1", "p0"): model_directional_coupler_transmission, - ("p2", "p3"): model_directional_coupler_transmission, - ("p3", "p2"): model_directional_coupler_transmission, - ("p0", "p2"): model_directional_coupler_coupling, - ("p2", "p0"): model_directional_coupler_coupling, - ("p1", "p3"): model_directional_coupler_coupling, - ("p3", "p1"): model_directional_coupler_coupling, - "default_params": {"coupling": 0.5}, -} \ No newline at end of file diff --git a/sax/models/thinfilm.py b/sax/models/thinfilm.py deleted file mode 100644 index 966ff68..0000000 
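A quick orientation for reviewers of the new functional models in `sax/models.py` above: `straight` and `coupler` are keyword-only callables returning an `SDict`, while `unitary`, `passthru` and `copier` are factories that build (optionally jitted) models. The snippet below is a usage sketch, not part of the diff, and assumes the package from this branch is installed:

```python
import jax.numpy as jnp
from sax.models import straight, coupler, unitary

# SDict models: dicts mapping ("port_in", "port_out") tuples to S-parameters
wg = straight(wl=1.55, length=10.0, loss=0.1)
dc = coupler(coupling=0.5)
print(jnp.abs(wg["in0", "out0"]) ** 2)  # waveguide power transmission
print(jnp.abs(dc["in0", "out1"]) ** 2)  # cross-coupled power of the coupler

# Model factories: unitary(...) builds a (jitted) model that returns an SCoo
splitter = unitary(2, 2)
Si, Sj, Sx, port_map = splitter(wl=1.55)  # sparse indices, values and port -> index map
```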
--- a/sax/models/thinfilm.py +++ /dev/null @@ -1,92 +0,0 @@ -""" SAX thin-film models """ - -import jax.numpy as jnp -from ..utils import zero -from ..core import modelgenerator -from ..typing import Dict, ModelDict, ComplexFloat - - -####################### -## Fresnel interface ## -####################### - -def r_fresnel_ij(params: Dict[str, float]) -> ComplexFloat: - """ - Normal incidence amplitude reflection from Fresnel's equations - ni : refractive index of the initial medium - nj : refractive index of the final medium - """ - return (params["ni"] - params["nj"]) / (params["ni"] + params["nj"]) - -def t_fresnel_ij(params: Dict[str, float]) -> ComplexFloat: - """ - Normal incidence amplitude transmission from Fresnel's equations - ni : refractive index of the initial medium - nj : refractive index of the final medium - """ - return 2 * params["ni"] / (params["ni"] + params["nj"]) - -fresnel_mirror_ij = { - ("in", "in"): r_fresnel_ij, - ("in", "out"): t_fresnel_ij, - ("out", "in"): lambda params: (1 - r_fresnel_ij(params)**2)/t_fresnel_ij(params), # t_ji, - ("out", "out"): lambda params: -1*r_fresnel_ij(params), # r_ji, - "default_params": { - "ni": 1., - "nj": 1., - } -} - -################# -## Propagation ## -################# - -def prop_i(params: Dict[str, float]) -> ComplexFloat: - """ - Phase shift acquired as a wave propagates through medium i - wl : wavelength (arb. units) - ni : refractive index of medium (at wavelength wl) - di : thickness of layer (same arb. unit as wl) - """ - return jnp.exp(1j * 2*jnp.pi * params["ni"] / params["wl"] * params["di"]) - -propagation_i = { - ("in", "out"): prop_i, - ("out", "in"): prop_i, - "default_params": { - "ni": 1., - "di": 500., - "wl": 532., - } -} - -################################# -## Lossless reciprocal element ## -################################# - -def t_complex(params: Dict[str, float]) -> ComplexFloat: - """ - Transmission coefficient (design parameter) - """ - return params['t_amp']*jnp.exp(-1j*params['t_ang']) - -def r_complex(params: Dict[str, float]) -> ComplexFloat: - """ - Reflection coefficient, derived from transmission coefficient - Magnitude from |t|^2 + |r|^2 = 1 - Phase from phase(t) - phase(r) = pi/2 - """ - r_amp = jnp.sqrt( ( 1. - params['t_amp']**2 ) ) - r_ang = params['t_ang'] - jnp.pi/2 - return r_amp*jnp.exp(-1j*r_ang) - -mirror = { - ("in", "in"): r_complex, - ("in", "out"): t_complex, - ("out", "in"): t_complex, - ("out", "out"): r_complex, - "default_params": { - "t_amp": jnp.sqrt(0.5), - "t_ang": 0.0, - } -} \ No newline at end of file diff --git a/sax/multimode.py b/sax/multimode.py new file mode 100644 index 0000000..2987f8a --- /dev/null +++ b/sax/multimode.py @@ -0,0 +1,226 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/04_multimode.ipynb (unless otherwise specified). + + +from __future__ import annotations + + +__all__ = ['multimode', 'singlemode'] + +# Cell +#nbdev_comment from __future__ import annotations + +from functools import wraps +from typing import Dict, Tuple, Union, cast, overload + +import jax.numpy as jnp +from .typing_ import ( + Model, + SCoo, + SDense, + SDict, + SType, + is_model, + is_multimode, + is_scoo, + is_sdense, + is_sdict, + is_singlemode, +) +from .utils import ( + block_diag, + mode_combinations, + validate_multimode, + validate_not_mixedmode, +) + +# Internal Cell + +@overload +def multimode(S: Model, modes: Tuple[str, ...] = ("te", "tm")) -> Model: + ... + + +@overload +def multimode(S: SDict, modes: Tuple[str, ...] = ("te", "tm")) -> SDict: + ... 
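To make the conversion defined below concrete, here is a small sketch (not part of the diff) of what `multimode` and `singlemode` do to an `SDict`: port names gain an `@mode` suffix on the way up and lose it again on the way down.

```python
from sax.multimode import multimode, singlemode

# a single-mode, reciprocal two-port
s = {("in0", "out0"): 0.9j, ("out0", "in0"): 0.9j}

s_mm = multimode(s, modes=("te", "tm"))
# keys now carry mode suffixes, e.g. ("in0@te", "out0@te"), ("in0@tm", "out0@tm"), ...

s_te = singlemode(s_mm, mode="te")
# back to plain ("in0", "out0") keys, keeping only the te entries
```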
+ + +@overload +def multimode(S: SCoo, modes: Tuple[str, ...] = ("te", "tm")) -> SCoo: + ... + + +@overload +def multimode(S: SDense, modes: Tuple[str, ...] = ("te", "tm")) -> SDense: + ... + +# Cell + +def multimode( + S: Union[SType, Model], modes: Tuple[str, ...] = ("te", "tm") +) -> Union[SType, Model]: + """Convert a single mode model to a multimode model""" + if is_model(S): + model = cast(Model, S) + + @wraps(model) + def new_model(**params): + return multimode(model(**params), modes=modes) + + return cast(Model, new_model) + + S = cast(SType, S) + + validate_not_mixedmode(S) + if is_multimode(S): + validate_multimode(S, modes=modes) + return S + + if is_sdict(S): + return _multimode_sdict(cast(SDict, S), modes=modes) + elif is_scoo(S): + return _multimode_scoo(cast(SCoo, S), modes=modes) + elif is_sdense(S): + return _multimode_sdense(cast(SDense, S), modes=modes) + else: + raise ValueError("cannot convert to multimode. Unknown stype.") + + +def _multimode_sdict(sdict: SDict, modes: Tuple[str, ...] = ("te", "tm")) -> SDict: + multimode_sdict = {} + _mode_combinations = mode_combinations(modes) + for (p1, p2), value in sdict.items(): + for (m1, m2) in _mode_combinations: + multimode_sdict[f"{p1}@{m1}", f"{p2}@{m2}"] = value + return multimode_sdict + + +def _multimode_scoo(scoo: SCoo, modes: Tuple[str, ...] = ("te", "tm")) -> SCoo: + + Si, Sj, Sx, port_map = scoo + num_ports = len(port_map) + mode_map = ( + {mode: i for i, mode in enumerate(modes)} + if not isinstance(modes, dict) + else cast(Dict, modes) + ) + + _mode_combinations = mode_combinations(modes) + + Si_m = jnp.concatenate( + [Si + mode_map[m] * num_ports for m, _ in _mode_combinations], -1 + ) + Sj_m = jnp.concatenate( + [Sj + mode_map[m] * num_ports for _, m in _mode_combinations], -1 + ) + Sx_m = jnp.concatenate([Sx for _ in _mode_combinations], -1) + port_map_m = { + f"{port}@{mode}": idx + mode_map[mode] * num_ports + for mode in modes + for port, idx in port_map.items() + } + + return Si_m, Sj_m, Sx_m, port_map_m + + +def _multimode_sdense(sdense, modes=("te", "tm")): + + Sx, port_map = sdense + num_ports = len(port_map) + mode_map = ( + {mode: i for i, mode in enumerate(modes)} + if not isinstance(modes, dict) + else modes + ) + + Sx_m = block_diag(*(Sx for _ in modes)) + + port_map_m = { + f"{port}@{mode}": idx + mode_map[mode] * num_ports + for mode in modes + for port, idx in port_map.items() + } + + return Sx_m, port_map_m + +# Internal Cell + +@overload +def singlemode(S: Model, mode: str = "te") -> Model: + ... + + +@overload +def singlemode(S: SDict, mode: str = "te") -> SDict: + ... + + +@overload +def singlemode(S: SCoo, mode: str = "te") -> SCoo: + ... + + +@overload +def singlemode(S: SDense, mode: str = "te") -> SDense: + ... + +# Cell + +def singlemode(S: Union[SType, Model], mode: str = "te") -> Union[SType, Model]: + """Convert multimode model to a singlemode model""" + if is_model(S): + model = cast(Model, S) + + @wraps(model) + def new_model(**params): + return singlemode(model(**params), mode=mode) + + return cast(Model, new_model) + + S = cast(SType, S) + + validate_not_mixedmode(S) + if is_singlemode(S): + return S + if is_sdict(S): + return _singlemode_sdict(cast(SDict, S), mode=mode) + elif is_scoo(S): + return _singlemode_scoo(cast(SCoo, S), mode=mode) + elif is_sdense(S): + return _singlemode_sdense(cast(SDense, S), mode=mode) + else: + raise ValueError("cannot convert to multimode. 
Unknown stype.") + + +def _singlemode_sdict(sdict: SDict, mode: str = "te") -> SDict: + singlemode_sdict = {} + for (p1, p2), value in sdict.items(): + if p1.endswith(f"@{mode}") and p2.endswith(f"@{mode}"): + p1, _ = p1.split("@") + p2, _ = p2.split("@") + singlemode_sdict[p1, p2] = value + return singlemode_sdict + + +def _singlemode_scoo(scoo: SCoo, mode: str = "te") -> SCoo: + Si, Sj, Sx, port_map = scoo + # no need to touch the data... + # just removing some ports from the port map should be enough + port_map = { + port.split("@")[0]: idx + for port, idx in port_map.items() + if port.endswith(f"@{mode}") + } + return Si, Sj, Sx, port_map + + +def _singlemode_sdense(sdense: SDense, mode: str = "te") -> SDense: + Sx, port_map = sdense + # no need to touch the data... + # just removing some ports from the port map should be enough + port_map = { + port.split("@")[0]: idx + for port, idx in port_map.items() + if port.endswith(f"@{mode}") + } + return Sx, port_map \ No newline at end of file diff --git a/sax/netlist.py b/sax/netlist.py new file mode 100644 index 0000000..eb00d8b --- /dev/null +++ b/sax/netlist.py @@ -0,0 +1,580 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/06_netlist.ipynb (unless otherwise specified). + + +from __future__ import annotations + + +__all__ = ['netlist', 'netlist_from_yaml', 'logical_netlist'] + +# Cell +#nbdev_comment from __future__ import annotations + +import os +import re +from functools import partial +from typing import Callable, Dict, Iterable, Optional, Tuple, Union, cast + +from flax.core import FrozenDict +from natsort import natsorted +from .models import models as default_sax_models +from .typing_ import ( + ComplexFloat, + Instance, + Instances, + LogicalNetlist, + Model, + ModelFactory, + Models, + Netlist, + Settings, + is_instance, + is_model_factory, + is_netlist, +) +from .utils import ( + clean_string, + copy_settings, + get_settings, + merge_dicts, + rename_params, + rename_ports, + try_float, +) +from yaml import Loader +from yaml import load as load_yaml + +# Internal Cell + +def _clean_component_names(models: Optional[Models]) -> Models: + if models is None: + models = {} + return {clean_string(comp): model for comp, model in models.items()} + + +def _clean_instance_names(instances: Optional[Instances]) -> Instances: + _instances = {} + if instances is None: + instances = {} + for name, inst in instances.items(): + if "," in name: + raise ValueError( + f"Instance name '{name}' is invalid. It contains the port separator ','." + ) + if ":" in name: + raise ValueError( + f"Instance name '{name}' is invalid. It contains the port slice symbol ':'." + ) + name = clean_string(name) + _instances[name] = inst + return _instances + + +def _component_from_callable(f: Callable, models: Models) -> Tuple[str, Models]: + _reverse_models = {id(model): name for name, model in models.items()} + if id(f) in _reverse_models: + component = _reverse_models[id(f)] + else: + component = _funcname(f) + models[component] = f + return component, models + + +def _funcname(p: Callable) -> str: + name = "" + f: Callable = p + while isinstance(f, partial): + name = "{name}p_" + if f.args: + try: + name = f"{name[:-1]}{hash(f.args)}_" + except TypeError: + raise TypeError( + "when using partials as SAX models, positional arguments of the partial should be hashable." 
+ ) + f = f.func + return f"{name}{f.__name__}_{id(f)}" + + +def _instance_from_callable(f: Callable, models: Models) -> Tuple[Instance, Models]: + f, settings = _maybe_parse_partial(f) + component, models = _component_from_callable(f, models) + instance = Instance(component=component, settings=settings) + return instance, models + + +def _instance_from_instance( + name: str, + instance: Instance, + models: Models, + override_settings: Settings, + global_settings: Dict[str, ComplexFloat], + default_models=None, +) -> Tuple[Instance, Models]: + default_models = default_sax_models if default_models is None else default_models + component = clean_string(instance["component"]) + if component not in models: + if component not in default_models: + raise ValueError( + f"Error constructing netlist. Component '{component}' not found." + ) + model = default_models[component] + else: + model = models[component] + if isinstance(model, str): + if model not in default_models: + raise ValueError( + f"Error constructing netlist. Component '{model}' not found." + ) + model = default_models[model] + if not callable(model): # Model or ModelFactory + raise ValueError( + f"Error constructing netlist. Model for component '{component}' is not callable." + ) + # fmt: off + _default_settings = get_settings(model) + _instance_settings = {k: v for k, v in instance.get('settings', {}).items() if k in _default_settings} + _override_settings = cast(Dict[str, ComplexFloat], override_settings.get(name, {})) + _override_settings = {k: v for k, v in _override_settings.items() if k in _default_settings} + _global_settings = {k: v for k, v in global_settings.items() if k in _default_settings} + settings = merge_dicts(_default_settings, _instance_settings, _override_settings, _global_settings) + # fmt: on + + instance = Instance(component=component, settings=settings) + models[instance["component"]] = model + return instance, models + + +def _instance_from_string( + s: str, models: Models, default_models=None +) -> Tuple[Instance, Models]: + default_models = default_sax_models if default_models is None else default_models + if s not in models: + if s not in default_models: + raise ValueError(f"Error constructing netlist. Component '{s}' not found.") + models[s] = default_models[s] + instance = Instance(component=s, settings={}) + return instance, models + + +def _maybe_parse_partial(p: Callable) -> Tuple[Callable, Dict[str, ComplexFloat]]: + _settings = {} + while isinstance(p, partial): + _settings = merge_dicts(_settings, p.keywords) + if not p.args: + p = p.func + else: + p = partial(p.func, *p.args) + break + return p, _settings + + +def _model_operations( + models: Optional[Models], + ops: Optional[Dict[str, Callable]] = None, + default_models=None, +) -> Models: + default_models = default_sax_models if default_models is None else default_models + if ops is None: + ops = {} + if models is None: + models = {} + + _models = {} + for component, model in models.items(): + if isinstance(model, str): + if not model in default_models: + raise ValueError(f"Could not find model {model}.") + model = default_models[model] + + _models[component] = model + + if not isinstance(model, dict): + continue + + if is_netlist(model): + continue # TODO: This case should actually be handled... + + model = {**model} + + if "model" not in model: + raise ValueError( + "Invalid model dict for '{component}'. Key 'model' not found." 
+ ) + + if isinstance(model["model"], str): + if not model["model"] in default_models: + raise ValueError(f"Could not find model {model['model']}.") + model["model"] = default_models[cast(str, model["model"])] + + for op_name, op_func in ops.items(): + assert isinstance(model, dict) + if op_name not in model: + continue + op_args = model[op_name] + model["model"] = op_func(model["model"], op_args) + + _models[component] = model["model"] + return _models + + +def _split_global_settings( + settings: Optional[dict], instance_names: Iterable[str] +) -> Tuple[Settings, Dict[str, ComplexFloat]]: + if settings: + override_settings = cast(Dict[str, Settings], copy_settings(settings)) + global_settings: Dict[str, ComplexFloat] = {} + for k in list(override_settings.keys()): + if k in instance_names: + continue + global_settings[k] = cast(ComplexFloat, try_float(override_settings.pop(k))) + else: + override_settings: Dict[str, Settings] = {} + global_settings: Dict[str, ComplexFloat] = {} + return override_settings, global_settings + + +def _enumerate_portrange(s): + if not ":" in s: + return [s] + idx1, idx2 = s.split(":") + s1 = re.sub("[0-9]*", "", idx1) + idx1 = int(re.sub("[^0-9]*", "", idx1)) + s2 = re.sub("[0-9]*", "", idx2) + idx2 = int(re.sub("[^0-9]*", "", idx2)) + if s1 != s2 and s2 != "": + raise ValueError( + "Cannot enumerate portrange {s}, string portion of port differs." + ) + return [f"{s1}{i}" for i in range(idx1, idx2)] + + +def _validate_connections(connections): + # todo: check if instance names are available in instances + # todo: check if instance ports are used in output ports + _ports = set() + old_connections, connections = connections, {} + for conn1, conn2 in old_connections.items(): + if conn1.count(",") != 1 or conn2.count(",") != 1: + raise ValueError( + "Connections ports should have format '{instance_name},{port}'. " + f"Got '{conn1}'." + ) + name1, port1 = conn1.split(",") + name2, port2 = conn2.split(",") + ports1 = _enumerate_portrange(port1) + ports2 = _enumerate_portrange(port2) + + if len(ports1) != len(ports2): + if len(ports1) == 1: + ports1 = [ports1[0] for _ in ports2] + elif len(ports2) == 1: + ports2 = [ports2[0] for _ in ports1] + else: + raise ValueError( + f"Cannot enumerate connection {conn1} -> {conn2}, slice lengths on both sides differ." + ) + + for port1, port2 in zip(ports1, ports2): + name1 = clean_string(name1) + name2 = clean_string(name2) + port1 = clean_string(port1) + port2 = clean_string(port2) + (name1, port1), (name2, port2) = natsorted([(name1, port1), (name2, port2)]) + conn1 = f"{name1},{port1}" + conn2 = f"{name2},{port2}" + if conn1 in _ports: + raise ValueError(f"duplicate connection port: '{conn1}'") + if conn2 in _ports: + raise ValueError(f"duplicate connection port: '{conn2}'") + connections[conn1] = conn2 + _ports.add(conn1) + _ports.add(conn2) + + return dict(natsorted([natsorted([k, v]) for k, v in connections.items()])) + + +def _validate_ports(ports): + # todo: check if instance names are available in instances + # todo: check if instance ports are used in connections + _ports = set() + old_ports, ports = ports, {} + for port1, conn2 in old_ports.items(): + + if port1.count(",") == 1: + if conn2.count(",") != 0: + raise ValueError( + "Netlist output port '{conn2}' should not contain any ','." + ) + port1, conn2 = conn2, port1 + elif conn2.count(",") == 1: + if port1.count(",") != 0: + raise ValueError( + "Netlist output port '{port1}' should not contain any ','." 
+ ) + else: + raise ValueError( + "Netlist output ports should be mapped onto an instance port " + "using the format: '{output_port}': '{instance_name},{port}'. " + f"Got: '{port1}': '{conn2}.'" + ) + + name2, port2 = conn2.split(",") + ports1 = _enumerate_portrange(port1) + ports2 = _enumerate_portrange(port2) + + if len(ports1) != len(ports2): + if len(ports1) == 1: + ports1 = [ports1[0] for _ in ports2] + elif len(ports2) == 1: + ports2 = [ports2[0] for _ in ports1] + else: + raise ValueError( + f"Cannot enumerate output ports {port1} -> {conn2}, slice lengths on both sides differ." + ) + + for port1, port2 in zip(ports1, ports2): + port1 = clean_string(port1) # output_port + name2 = clean_string(name2) + port2 = clean_string(port2) + conn2 = f"{name2},{port2}" + if port1 in _ports: + raise ValueError(f"duplicate output port: '{port1}'") + if conn2 in _ports: + raise ValueError(f"duplicate instance output port: '{conn2}'") + ports[port1] = conn2 + _ports.add(port1) + _ports.add(conn2) + return dict(natsorted(ports.items())) + +# Cell + +def netlist( + *, + instances: Instances, + connections: Dict[str, str], + ports: Dict[str, str], + models: Optional[Models] = None, + settings: Optional[Settings] = None, + default_models=None, +) -> Tuple[Netlist, Models]: + """Create a `Netlist` and `Models` dictionary""" + default_models = default_sax_models if default_models is None else default_models + models = _clean_component_names(models) + models = _model_operations( + models, + ops={ + "rename_params": rename_params, + "rename_ports": rename_ports, + }, + default_models=default_models, + ) + + instances = _clean_instance_names(instances) + override_settings, global_settings = _split_global_settings(settings, instances) + + _instances: Dict[str, Union[Instance, Netlist]] = {} + for name, instance in instances.items(): + if callable(instance): + instance, models = _instance_from_callable(instance, models) + elif isinstance(instance, str): + instance, models = _instance_from_string(instance, models, default_models) + if not isinstance(instance, dict): + raise ValueError( + f"invalid instance '{name}': expected str, dict or callable." + ) + if is_instance(instance): + instance, models = _instance_from_instance( + name=name, + instance=cast(Instance, instance), + models=models, + override_settings=override_settings, + global_settings=global_settings, + default_models=default_models, + ) + elif is_netlist(instance): + instance = cast(Netlist, instance) + instance, models = netlist( + instances=instance["instances"], + connections=instance["connections"], + ports=instance["ports"], + models=models, + settings=merge_dicts(global_settings, override_settings), + default_models=default_models, + ) + else: + raise ValueError( + f"Instance {name} cannot be interpreted as an Instance or a Netlist." 
+ ) + _instances[name] = instance + + _instances = {k: _instances[k] for k in natsorted(_instances.keys())} + _connections = _validate_connections(connections) + _ports = _validate_ports(ports) + + _netlist = Netlist( + instances=cast(Instances, _instances), + connections=_connections, + ports=_ports, + ) + return _netlist, models + +# Cell + +def netlist_from_yaml( + yaml: str, + *, + models: Optional[Models] = None, + settings: Optional[Settings] = None, + default_models=None, +) -> Tuple[Netlist, Models]: + """Load a sax `Netlist` from yaml definition and a `Models` dictionary""" + + default_models = default_sax_models if default_models is None else default_models + ext = None + directory = None + yaml_path = os.path.abspath(os.path.expanduser(yaml)) + if os.path.isdir(yaml_path): + raise IsADirectoryError( + "Cannot read from yaml path '{yaml_path}'. Path is a directory." + ) + elif os.path.exists(yaml_path): + if ext is None: + _, *ext_list = os.path.basename(yaml_path).split(".") + ext = f".{'.'.join(ext_list)}" + if directory is None: + directory = os.path.dirname(yaml_path) + yaml = open(yaml_path, "r").read() + else: + yaml_path = None + + subnetlists = {} + if directory is not None and ext is not None: + subnetlists = { + re.sub(f"{ext}$", "", os.path.basename(file)): os.path.join(root, file) + for root, _, files in os.walk(os.path.abspath(directory)) + for file in files + if file.endswith(ext) + } + + raw_netlist = load_yaml(yaml, Loader) + + for section in ["instances", "connections", "ports"]: + if section not in raw_netlist: + raise ValueError(f"Can not load from yaml: '{section}' not found.") + + raw_instances = raw_netlist["instances"] + connections = raw_netlist["connections"] + ports = raw_netlist["ports"] + override_settings, global_settings = _split_global_settings(settings, raw_instances) + + instances = {} + for name, instance in raw_instances.items(): + if isinstance(instance, str): + instance = {"component": instance, "settings": {}} + elif isinstance(instance, dict): + if "component" not in instance: + raise ValueError( + f"Can not load from yaml: 'component' not found in instance '{name}'." 
+ ) + component = instance["component"] + component = re.sub(r"\.ba$", "", component) # for compat with pb + if component in subnetlists: + _override_settings = cast(Settings, override_settings.get(name, {})) + _settings = merge_dicts(global_settings, _override_settings) + instance, models = netlist_from_yaml( + yaml=subnetlists[component], + models=models, + settings=_settings, + default_models=default_models, + ) + instances[name] = instance + + _netlist, _models = netlist( + instances=instances, + connections=connections, + ports=ports, + models=models, + settings=settings, + default_models=default_models, + ) + return _netlist, _models + +# Cell + +def logical_netlist( + *, + instances: Instances, + connections: Dict[str, str], + ports: Dict[str, str], + models: Optional[Models] = None, + settings: Optional[Settings] = None, + default_models=None, +) -> Tuple[LogicalNetlist, Settings, Models]: + """Create a `LogicalNetlist` with separated `Settings` and `Models` dictionary""" + + default_models = default_sax_models if default_models is None else default_models + _netlist, models = netlist( + instances=instances, + connections=connections, + ports=ports, + models=models, + settings=settings, + default_models=default_models, + ) + _instances: Dict[str, str] = {} + _settings: Settings = {} + _models = models + + for name, instance in _netlist["instances"].items(): + if is_netlist(instance): + instance = cast(Netlist, instance) + model, _settings[name], _models = logical_netlist( + instances=instance["instances"], + connections=instance["connections"], + ports=instance["ports"], + models=_models, + default_models=default_models, + ) + model_hash = hex(abs(hash(FrozenDict(model))))[2:] + component = f"logical_netlist_{model_hash}" + _instances[name] = component + _models[component] = model + elif is_instance(instance): + instance = cast(Instance, instance) + component = instance["component"] + _instance_settings = instance.get("settings", {}) + _instance_model = cast(Model, _models[component]) + + if is_model_factory(_instance_model): + model_factory = cast(ModelFactory, _instance_model) + _instance_model = model_factory(**_instance_settings) + instance, _models = _instance_from_callable(_instance_model, _models) + _instance_settings = instance.get("settings", {}) + component = instance["component"] + + _instances[name] = component + _model_settings = get_settings(_instance_model) + _instance_settings = { + k: try_float( + v + if (k not in _instance_settings or _instance_settings[k] is None) + else _instance_settings[k] + ) + for k, v in _model_settings.items() + } + _settings[name] = cast(Settings, _instance_settings) + else: + raise ValueError(f"instance '{name}' is not an Instance or a Netlist.") + + _instances = dict(natsorted(_instances.items())) + + _logical_netlist = LogicalNetlist( + instances=_instances, + connections=_netlist["connections"], + ports=_netlist["ports"], + ) + return ( + _logical_netlist, + _settings, + _models, + ) \ No newline at end of file diff --git a/sax/nn/__init__.py b/sax/nn/__init__.py new file mode 100644 index 0000000..835cc7a --- /dev/null +++ b/sax/nn/__init__.py @@ -0,0 +1,43 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/09_nn.ipynb (unless otherwise specified). 
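For reviewers, a sketch of how the `netlist` constructor above is meant to be called. It is not part of the diff; the instance names, settings and port names are made up, and the string components are resolved against the default models from `sax.models`:

```python
from sax.netlist import netlist

net, models = netlist(
    instances={
        "lft": "coupler",  # resolved against the default models
        "top": {"component": "straight", "settings": {"length": 25.0}},
        "rgt": "coupler",
    },
    connections={
        "lft,out0": "top,in0",  # "instance,port" on both sides
        "top,out0": "rgt,in0",
    },
    ports={
        "in0": "lft,in0",  # circuit port -> "instance,port"
        "out0": "rgt,out0",
    },
)
# net is a Netlist TypedDict (instances/connections/ports);
# models maps the referenced component names to their model callables
```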
+ + +from __future__ import annotations + + +__all__ = [] + +# Cell +#nbdev_comment from __future__ import annotations + +from .loss import huber_loss as huber_loss +from .loss import l2_reg as l2_reg +from .loss import mse as mse + +# Cell + +from .utils import ( + cartesian_product as cartesian_product, + denormalize as denormalize, + get_normalization as get_normalization, + get_df_columns as get_df_columns, + normalize as normalize, +) + +# Cell + +from .core import ( + preprocess as preprocess, + dense as dense, + generate_dense_weights as generate_dense_weights, +) + +# Cell + +from .io import ( + load_nn_weights_json as load_nn_weights_json, + save_nn_weights_json as save_nn_weights_json, + get_available_sizes as get_available_sizes, + get_dense_weights_path as get_dense_weights_path, + get_norm_path as get_norm_path, + load_nn_dense as load_nn_dense, +) \ No newline at end of file diff --git a/sax/nn/core.py b/sax/nn/core.py new file mode 100644 index 0000000..af12e02 --- /dev/null +++ b/sax/nn/core.py @@ -0,0 +1,86 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/09c_nn_core.ipynb (unless otherwise specified). + + +from __future__ import annotations + + +__all__ = ['preprocess', 'dense', 'generate_dense_weights'] + +# Cell +#nbdev_comment from __future__ import annotations + +from typing import Callable, Dict, Optional, Tuple, Union + +import jax +import jax.numpy as jnp +from .utils import denormalize, normalize +from ..typing_ import Array, ComplexFloat + +# Cell +def preprocess(*params: ComplexFloat) -> ComplexFloat: + """preprocess parameters + + > Note: (1) all arguments are first casted into the same shape. (2) then pairs + of arguments are divided into each other to create relative arguments. (3) all + arguments are then stacked into one big tensor + """ + x = jnp.stack(jnp.broadcast_arrays(*params), -1) + assert isinstance(x, jnp.ndarray) + to_concatenate = [x] + for i in range(1, x.shape[-1]): + _x = jnp.roll(x, shift=i, axis=-1) + to_concatenate.append(x / _x) + to_concatenate.append(_x / x) + x = jnp.concatenate(to_concatenate, -1) + assert isinstance(x, jnp.ndarray) + return x + +# Cell +def dense( + weights: Dict[str, Array], + *params: ComplexFloat, + x_norm: Tuple[float, float] = (0.0, 1.0), + y_norm: Tuple[float, float] = (0.0, 1.0), + preprocess: Callable = preprocess, + activation: Callable = jax.nn.leaky_relu, +) -> ComplexFloat: + """simple dense neural network""" + x_mean, x_std = x_norm + y_mean, y_std = y_norm + x = preprocess(*params) + x = normalize(x, mean=x_mean, std=x_std) + for i in range(len([w for w in weights if w.startswith("w")])): + x = activation(x @ weights[f"w{i}"] + weights.get(f"b{i}", 0.0)) + y = denormalize(x, mean=y_mean, std=y_std) + return y + +# Cell +def generate_dense_weights( + key: Union[int, Array], + sizes: Tuple[int, ...], + input_names: Optional[Tuple[str, ...]] = None, + output_names: Optional[Tuple[str, ...]] = None, + preprocess=preprocess, +) -> Dict[str, ComplexFloat]: + """Generate the weights for a dense neural network""" + + if isinstance(key, int): + key = jax.random.PRNGKey(key) + assert isinstance(key, jnp.ndarray) + + sizes = tuple(s for s in sizes) + if input_names: + arr = preprocess(*jnp.ones(len(input_names))) + assert isinstance(arr, jnp.ndarray) + sizes = (arr.shape[-1],) + sizes + if output_names: + sizes = sizes + (len(output_names),) + + keys = jax.random.split(key, 2 * len(sizes)) + rand = jax.nn.initializers.lecun_normal() + weights = {} + for i, (m, n) in enumerate(zip(sizes[:-1], sizes[1:])): + 
weights[f"w{i}"] = rand(keys[2 * i], (m, n)) + weights[f"b{i}"] = rand(keys[2 * i + 1], (1, n)).ravel() + + return weights \ No newline at end of file diff --git a/sax/nn/io.py b/sax/nn/io.py new file mode 100644 index 0000000..96b3396 --- /dev/null +++ b/sax/nn/io.py @@ -0,0 +1,202 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/09d_nn_io.ipynb (unless otherwise specified). + + +from __future__ import annotations + + +__all__ = ['load_nn_weights_json', 'save_nn_weights_json', 'get_available_sizes', 'get_dense_weights_path', + 'get_norm_path', 'load_nn_dense'] + +# Cell +#nbdev_comment from __future__ import annotations + +import json +import os +import re +from typing import Callable, Dict, List, Optional, Tuple + +import jax.numpy as jnp +from .core import dense, preprocess +from .utils import norm +from ..typing_ import ComplexFloat + +# Cell + +def load_nn_weights_json(path: str) -> Dict[str, ComplexFloat]: + """Load json weights from given path""" + path = os.path.abspath(os.path.expanduser(path)) + weights = {} + if os.path.exists(path): + with open(path, "r") as file: + for k, v in json.load(file).items(): + _v = jnp.array(v, dtype=float) + assert isinstance(_v, jnp.ndarray) + weights[k] = _v + return weights + +# Cell +def save_nn_weights_json(weights: Dict[str, ComplexFloat], path: str): + """Save json weights to given path""" + path = os.path.abspath(os.path.expanduser(path)) + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, "w") as file: + _weights = {} + for k, v in weights.items(): + v = jnp.atleast_1d(jnp.array(v)) + assert isinstance(v, jnp.ndarray) + _weights[k] = v.tolist() + json.dump(_weights, file) + +# Cell +def get_available_sizes( + dirpath: str, + prefix: str, + input_names: Tuple[str, ...], + output_names: Tuple[str, ...], +) -> List[Tuple[int, ...]]: + """Get all available json weight hidden sizes given filename parameters + + > Note: this function does NOT return the input size and the output size + of the neural network. ONLY the hidden sizes are reported. The input + and output sizes can easily be derived from `input_names` (after + preprocessing) and `output_names`. 
+ """ + all_weightfiles = os.listdir(dirpath) + possible_weightfiles = ( + s for s in all_weightfiles if s.endswith(f"-{'-'.join(output_names)}.json") + ) + possible_weightfiles = ( + s + for s in possible_weightfiles + if s.startswith(f"{prefix}-{'-'.join(input_names)}") + ) + possible_weightfiles = (re.sub("[^0-9x]", "", s) for s in possible_weightfiles) + possible_weightfiles = (re.sub("^x*", "", s) for s in possible_weightfiles) + possible_weightfiles = (re.sub("x[^0-9]*$", "", s) for s in possible_weightfiles) + possible_hidden_sizes = (s.strip() for s in possible_weightfiles if s.strip()) + possible_hidden_sizes = ( + tuple(hs.strip() for hs in s.split("x") if hs.strip()) + for s in possible_hidden_sizes + ) + possible_hidden_sizes = ( + tuple(int(hs) for hs in s[1:-1]) for s in possible_hidden_sizes if len(s) > 2 + ) + possible_hidden_sizes = sorted( + possible_hidden_sizes, key=lambda hs: (len(hs), max(hs)) + ) + return possible_hidden_sizes + +# Cell + +def get_dense_weights_path( + *sizes: int, + input_names: Optional[Tuple[str, ...]] = None, + output_names: Optional[Tuple[str, ...]] = None, + dirpath: str = "weights", + prefix: str = "dense", + preprocess=preprocess, +): + """Create the SAX conventional path for a given weight dictionary""" + if input_names: + num_inputs = preprocess(*jnp.ones(len(input_names))).shape[0] + sizes = (num_inputs,) + sizes + if output_names: + sizes = sizes + (len(output_names),) + path = os.path.abspath(os.path.join(dirpath, prefix)) + if input_names: + path = f"{path}-{'-'.join(input_names)}" + if sizes: + path = f"{path}-{'x'.join(str(s) for s in sizes)}" + if output_names: + path = f"{path}-{'-'.join(output_names)}" + return f"{path}.json" + +# Cell + +def get_norm_path( + *shape: int, + input_names: Optional[Tuple[str, ...]] = None, + output_names: Optional[Tuple[str, ...]] = None, + dirpath: str = "norms", + prefix: str = "norm", + preprocess=preprocess, +): + """Create the SAX conventional path for the normalization constants""" + if input_names and output_names: + raise ValueError( + "To get the norm name, one can only specify `input_names` OR `output_names`." 
+ ) + if input_names: + num_inputs = preprocess(*jnp.ones(len(input_names))).shape[0] + shape = (num_inputs,) + shape + if output_names: + shape = shape + (len(output_names),) + path = os.path.abspath(os.path.join(dirpath, prefix)) + if input_names: + path = f"{path}-{'-'.join(input_names)}" + if shape: + path = f"{path}-{'x'.join(str(s) for s in shape)}" + if output_names: + path = f"{path}-{'-'.join(output_names)}" + return f"{path}.json" + +# Internal Cell +class _PartialDense: + def __init__(self, weights, x_norm, y_norm, input_names, output_names): + self.weights = weights + self.x_norm = x_norm + self.y_norm = y_norm + self.input_names = input_names + self.output_names = output_names + + def __call__(self, *params: ComplexFloat) -> ComplexFloat: + return dense(self.weights, *params, x_norm=self.x_norm, y_norm=self.y_norm) + + def __repr__(self): + return f"{self.__class__.__name__}{repr(self.input_names)}->{repr(self.output_names)}" + +# Cell +def load_nn_dense( + *sizes: int, + input_names: Optional[Tuple[str, ...]] = None, + output_names: Optional[Tuple[str, ...]] = None, + weightprefix="dense", + weightdirpath="weights", + normdirpath="norms", + normprefix="norm", + preprocess=preprocess, +) -> Callable: + """Load a pre-trained dense model""" + weights_path = get_dense_weights_path( + *sizes, + input_names=input_names, + output_names=output_names, + prefix=weightprefix, + dirpath=weightdirpath, + preprocess=preprocess, + ) + if not os.path.exists(weights_path): + raise ValueError("Cannot find weights path for given parameters") + x_norm_path = get_norm_path( + input_names=input_names, + prefix=normprefix, + dirpath=normdirpath, + preprocess=preprocess, + ) + if not os.path.exists(x_norm_path): + raise ValueError("Cannot find normalization for input parameters") + y_norm_path = get_norm_path( + output_names=output_names, + prefix=normprefix, + dirpath=normdirpath, + preprocess=preprocess, + ) + if not os.path.exists(x_norm_path): + raise ValueError("Cannot find normalization for output parameters") + weights = load_nn_weights_json(weights_path) + x_norm_dict = load_nn_weights_json(x_norm_path) + y_norm_dict = load_nn_weights_json(y_norm_path) + x_norm = norm(x_norm_dict["mean"], x_norm_dict["std"]) + y_norm = norm(y_norm_dict["mean"], y_norm_dict["std"]) + partial_dense = _PartialDense(weights, x_norm, y_norm, input_names, output_names) + return partial_dense \ No newline at end of file diff --git a/sax/nn/loss.py b/sax/nn/loss.py new file mode 100644 index 0000000..b57c2fb --- /dev/null +++ b/sax/nn/loss.py @@ -0,0 +1,38 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/09a_nn_loss.ipynb (unless otherwise specified). 
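Before the loss functions, a sketch (not part of the diff) tying together the `preprocess`, `generate_dense_weights` and `dense` helpers defined above in `sax.nn`; the parameter names "wl"/"length" and the hidden sizes are placeholders chosen for illustration:

```python
from sax.nn import dense, generate_dense_weights

# two hidden layers of 16 units; the input and output layer sizes are derived
# from the parameter names (after preprocessing)
weights = generate_dense_weights(
    key=42,
    sizes=(16, 16),
    input_names=("wl", "length"),
    output_names=("t_amp", "t_ang"),
)

# evaluate the (randomly initialized) network at wl=1.55, length=10.0
prediction = dense(weights, 1.55, 10.0)
print(prediction.shape)  # (2,): one value per output name
```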
+ + +from __future__ import annotations + + +__all__ = ['mse', 'huber_loss', 'l2_reg'] + +# Cell +#nbdev_comment from __future__ import annotations + +from typing import Dict + +import jax.numpy as jnp +from ..typing_ import ComplexFloat + +# Cell + +def mse(x: ComplexFloat, y: ComplexFloat) -> float: + """mean squared error""" + return ((x - y) ** 2).mean() + +# Cell + +def huber_loss(x: ComplexFloat, y: ComplexFloat, delta: float=0.5) -> float: + """huber loss""" + return ((delta ** 2) * ((1.0 + ((x - y) / delta) ** 2) ** 0.5 - 1.0)).mean() + +# Cell + +def l2_reg(weights: Dict[str, ComplexFloat]) -> float: + """L2 regularization loss""" + numel = 0 + loss = 0.0 + for w in (v for k, v in weights.items() if k[0] in ("w", "b")): + numel = numel + w.size + loss = loss + (jnp.abs(w) ** 2).sum() + return loss / numel \ No newline at end of file diff --git a/sax/nn/utils.py b/sax/nn/utils.py new file mode 100644 index 0000000..debb6a3 --- /dev/null +++ b/sax/nn/utils.py @@ -0,0 +1,62 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/09b_nn_utils.ipynb (unless otherwise specified). + + +from __future__ import annotations + + +__all__ = ['cartesian_product', 'denormalize', 'get_normalization', 'get_df_columns', 'normalize'] + +# Cell +#nbdev_comment from __future__ import annotations + +from collections import namedtuple +from typing import Tuple + +import jax.numpy as jnp +import pandas as pd +from ..typing_ import ComplexFloat + +# Cell +def cartesian_product(*arrays: ComplexFloat) -> ComplexFloat: + """calculate the n-dimensional cartesian product of an arbitrary number of arrays""" + ixarrays = jnp.ix_(*arrays) + barrays = jnp.broadcast_arrays(*ixarrays) + sarrays = jnp.stack(barrays, -1) + assert isinstance(sarrays, jnp.ndarray) + product = sarrays.reshape(-1, sarrays.shape[-1]) + assert isinstance(product, jnp.ndarray) + return product + +# Cell +def denormalize(x: ComplexFloat, mean: ComplexFloat = 0.0, std: ComplexFloat = 1.0) -> ComplexFloat: + """denormalize an array with a given mean and standard deviation""" + return x * std + mean + +# Internal Cell +norm = namedtuple("norm", ("mean", "std")) + +# Cell +def get_normalization(x: ComplexFloat): + """Get mean and standard deviation for a given array""" + if isinstance(x, (complex, float)): + return x, 0.0 + return norm(x.mean(0), x.std(0)) + +# Cell +def get_df_columns(df: pd.DataFrame, *names: str) -> Tuple[ComplexFloat, ...]: + """Get certain columns from a pandas DataFrame as jax.numpy arrays""" + tup = namedtuple("params", names) + params_list = [] + for name in names: + column_np = df[name].values + column_jnp = jnp.array(column_np) + assert isinstance(column_jnp, jnp.ndarray) + params_list.append(column_jnp.ravel()) + return tup(*params_list) + +# Cell +def normalize( + x: ComplexFloat, mean: ComplexFloat = 0.0, std: ComplexFloat = 1.0 +) -> ComplexFloat: + """normalize an array with a given mean and standard deviation""" + return (x - mean) / std \ No newline at end of file diff --git a/sax/patched.py b/sax/patched.py new file mode 100644 index 0000000..2ac0e49 --- /dev/null +++ b/sax/patched.py @@ -0,0 +1,47 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/01_patched.ipynb (unless otherwise specified). 
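And a short sketch (not part of the diff) of how the loss and normalization helpers above would typically be combined into a differentiable training objective; the target data here is a placeholder:

```python
import jax
import jax.numpy as jnp
from sax.nn import dense, generate_dense_weights, l2_reg, mse

weights = generate_dense_weights(0, (16,), input_names=("wl",), output_names=("t",))
wl = jnp.linspace(1.5, 1.6, 101)
target = jnp.linspace(0.4, 0.6, 101)  # placeholder training targets

def loss_fn(weights):
    pred = dense(weights, wl).ravel()
    return mse(pred, target) + 1e-3 * l2_reg(weights)

grads = jax.grad(loss_fn)(weights)  # gradients w.r.t. every weight array
```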
+ + +from __future__ import annotations + + +__all__ = [] + +# Cell +#nbdev_comment from __future__ import annotations + +import re +from fastcore.basics import patch_to +from flax.core import FrozenDict +from jaxlib.xla_extension import DeviceArray + +from .typing_ import is_complex_float, is_float +from textwrap import dedent + +# Internal Cell +@patch_to(FrozenDict) +def __repr__(self): # type: ignore + _dict = lambda d: dict( + {k: (v if not isinstance(v, self.__class__) else dict(v)) for k, v in d.items()} + ) + return f"{self.__class__.__name__}({dict.__repr__(_dict(self))})" + +# Internal Cell +@patch_to(DeviceArray) +def __repr__(self): # type: ignore + if self.ndim == 0 and is_float(self): + v = float(self) + return repr(round(v, 5)) if abs(v) > 1e-4 else repr(v) + elif self.ndim == 0 and is_complex_float(self): + r, i = float(self.real), float(self.imag) + r = round(r, 5) if abs(r) > 1e-4 else r + i = round(i, 5) if abs(i) > 1e-4 else i + s = repr(r + 1j * i) + if s[0] == "(" and s[-1] == ")": + s = s[1:-1] + return s + else: + s = super(self.__class__, self).__repr__() + s = s.replace("DeviceArray(", " array(") + s = re.sub(r", dtype=.*[,)]", "", s) + s = re.sub(r" weak_type=.*[,)]", "", s) + return dedent(s)+")" \ No newline at end of file diff --git a/sax/typing.py b/sax/typing.py deleted file mode 100644 index f97d6bb..0000000 --- a/sax/typing.py +++ /dev/null @@ -1,13 +0,0 @@ -""" Common datastructure types used in SAX """ - -from typing import Optional, Dict, Union, Tuple, Callable, Any - - - -ParamsDict = Dict[str, Union[Dict, float]] - -ModelDict = Dict[Union[Tuple[str, str], str], Union[Callable, ParamsDict]] - -ComplexFloat = Union[complex, float] - -ModelFunc = Callable[[ParamsDict], ComplexFloat] diff --git a/sax/typing_.py b/sax/typing_.py new file mode 100644 index 0000000..6aab97c --- /dev/null +++ b/sax/typing_.py @@ -0,0 +1,418 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/00_typing.ipynb (unless otherwise specified). 
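A sketch (not part of the diff) of the calling convention that the predicates defined below in `sax/typing_.py` encode: a SAX `Model` is a callable that takes keyword arguments only and returns an S-matrix representation. The model bodies here are dummies for illustration:

```python
from sax.typing_ import is_model, validate_model

def good_model(wl=1.55, length=10.0):
    return {("in0", "out0"): 1.0}

def bad_model(wl, length):  # positional arguments -> not a valid SAX model
    return {("in0", "out0"): 1.0}

assert is_model(good_model)
assert not is_model(bad_model)
validate_model(good_model)  # passes silently; would raise for bad_model
```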
+ + +from __future__ import annotations + + +__all__ = ['Array', 'Int', 'Float', 'ComplexFloat', 'Settings', 'SDict', 'SCoo', 'SDense', 'SType', 'Model', + 'ModelFactory', 'GeneralModel', 'Models', 'Instance', 'GeneralInstance', 'Instances', 'Netlist', + 'LogicalNetlist', 'is_float', 'is_complex', 'is_complex_float', 'is_sdict', 'is_scoo', 'is_sdense', + 'is_model', 'is_model_factory', 'validate_model', 'is_instance', 'is_netlist', 'is_stype', 'is_singlemode', + 'is_multimode', 'is_mixedmode', 'sdict', 'scoo', 'sdense', 'modelfactory'] + +# Cell +#nbdev_comment from __future__ import annotations + +import functools +import inspect +from collections.abc import Callable as CallableABC +from typing import Any, Callable, Dict, Tuple, TypedDict, Union, cast, overload + +import jax.numpy as jnp +import numpy as np +from natsort import natsorted + +# Cell +Array = Union[jnp.ndarray, np.ndarray] + +# Cell +Int = Union[int, Array] + +# Cell +Float = Union[float, Array] + +# Cell +ComplexFloat = Union[complex, Float] + +# Cell +Settings = Union[Dict[str, ComplexFloat], Dict[str, "Settings"]] + +# Cell +SDict = Dict[Tuple[str, str], ComplexFloat] + +# Cell +SCoo = Tuple[Array, Array, ComplexFloat, Dict[str, int]] + +# Cell +SDense = Tuple[Array, Dict[str, int]] + +# Cell +SType = Union[SDict, SCoo, SDense] + +# Cell +Model = Callable[..., SType] + +# Cell +ModelFactory = Callable[..., Model] + +# Cell +GeneralModel = Union[Model, "LogicalNetlist"] + +# Cell +Models = Dict[str, GeneralModel] + +# Cell +Instance = TypedDict( + "Instance", + { + "component": str, + "settings": Settings, + }, +) + +# Cell +GeneralInstance = Union[str, Instance, "LogicalNetlist", "Netlist"] + +# Cell +Instances = Union[Dict[str, str], Dict[str, GeneralInstance]] + +# Cell + +Netlist = TypedDict( + "Netlist", + { + "instances": Instances, + "connections": Dict[str, str], + "ports": Dict[str, str], + }, +) + +# Cell + +LogicalNetlist = TypedDict( + "LogicalNetlist", + { + "instances": Dict[str, str], + "connections": Dict[str, str], + "ports": Dict[str, str], + }, +) + +# Cell +def is_float(x: Any) -> bool: + """Check if an object is a `Float`""" + if isinstance(x, float): + return True + if isinstance(x, np.ndarray): + return x.dtype in (np.float16, np.float32, np.float64, np.float128) + if isinstance(x, jnp.ndarray): + return x.dtype in (jnp.float16, jnp.float32, jnp.float64) + return False + +# Cell +def is_complex(x: Any) -> bool: + """check if an object is a `ComplexFloat`""" + if isinstance(x, complex): + return True + if isinstance(x, np.ndarray): + return x.dtype in (np.complex64, np.complex128) + if isinstance(x, jnp.ndarray): + return x.dtype in (jnp.complex64, jnp.complex128) + return False + +# Cell +def is_complex_float(x: Any) -> bool: + """check if an object is either a `ComplexFloat` or a `Float`""" + return is_float(x) or is_complex(x) + +# Cell +def is_sdict(x: Any) -> bool: + """check if an object is an `SDict` (a SAX S-dictionary)""" + return isinstance(x, dict) + +# Cell +def is_scoo(x: Any) -> bool: + """check if an object is an `SCoo` (a SAX sparse S-matrix representation in COO-format)""" + return isinstance(x, (tuple, list)) and len(x) == 4 + +# Cell +def is_sdense(x: Any) -> bool: + """check if an object is an `SDense` (a SAX dense S-matrix representation)""" + return isinstance(x, (tuple, list)) and len(x) == 2 + +# Cell +def is_model(model: Any) -> bool: + """check if a callable is a `Model` (a callable returning an `SType`)""" + if not callable(model): + return False + try: + sig = 
inspect.signature(model) + except ValueError: + return False + for param in sig.parameters.values(): + if param.default == inspect.Parameter.empty: + return False # a proper SAX model does not have any positional arguments. + if _is_callable_annotation(sig.return_annotation): # model factory + return False + return True + +def _is_callable_annotation(annotation: Any) -> bool: + """check if an annotation is `Callable`-like""" + if isinstance(annotation, str): + # happens when + # from __future__ import annotations + # was imported at the top of the file... + return annotation.startswith("Callable") or annotation.endswith("Model") + # TODO: this is not a very robust check... + try: + return annotation.__origin__ == CallableABC + except AttributeError: + return False + +# Cell +def is_model_factory(model: Any) -> bool: + """check if a callable is a model function.""" + if not callable(model): + return False + sig = inspect.signature(model) + if _is_callable_annotation(sig.return_annotation): # model factory + return True + return False + +# Cell +def validate_model(model: Callable): + """Validate the parameters of a model""" + positional_arguments = [] + for param in inspect.signature(model).parameters.values(): + if param.default is inspect.Parameter.empty: + positional_arguments.append(param.name) + if positional_arguments: + raise ValueError( + f"model '{model}' takes positional arguments {', '.join(positional_arguments)} " + "and hence is not a valid SAX Model! A SAX model should ONLY take keyword arguments (or no arguments at all)." + ) + +# Cell +def is_instance(instance: Any) -> bool: + """check if a dictionary is an instance""" + if not isinstance(instance, dict): + return False + return "component" in instance + +# Cell +def is_netlist(netlist: Any) -> bool: + """check if a dictionary is a netlist""" + if not isinstance(netlist, dict): + return False + if not "instances" in netlist: + return False + if not "connections" in netlist: + return False + if not "ports" in netlist: + return False + return True + +# Cell +def is_stype(stype: Any) -> bool: + """check if an object is an SDict, SCoo or SDense""" + return is_sdict(stype) or is_scoo(stype) or is_sdense(stype) + +# Cell +def is_singlemode(S: Any) -> bool: + """check if an stype is single mode""" + if not is_stype(S): + return False + ports = _get_ports(S) + return not any(("@" in p) for p in ports) + +def _get_ports(S: SType): + if is_sdict(S): + S = cast(SDict, S) + ports_set = {p1 for p1, _ in S} | {p2 for _, p2 in S} + return tuple(natsorted(ports_set)) + else: + *_, ports_map = S + assert isinstance(ports_map, dict) + return tuple(natsorted(ports_map.keys())) + +# Cell +def is_multimode(S: Any) -> bool: + """check if an stype is single mode""" + if not is_stype(S): + return False + + ports = _get_ports(S) + return all(("@" in p) for p in ports) + +# Cell +def is_mixedmode(S: Any) -> bool: + """check if an stype is neither single mode nor multimode (hence invalid)""" + return not is_singlemode(S) and not is_multimode(S) + +# Internal Cell + +@overload +def sdict(S: Model) -> Model: + ... + + +@overload +def sdict(S: SType) -> SDict: + ... 
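A quick sketch (not part of the diff) of what `is_model` / `validate_model` accept: a SAX `Model` takes only keyword arguments with defaults and returns an `SType`. The `coupler` below is a hypothetical toy model used purely for illustration:

```
# Sketch, not part of the diff: `coupler` is a made-up example model.
import jax.numpy as jnp
from sax.typing_ import is_model, is_model_factory, validate_model

def coupler(coupling=0.5):
    kappa = jnp.asarray(coupling) ** 0.5
    tau = (1 - jnp.asarray(coupling)) ** 0.5
    return {
        ("in0", "out0"): tau,
        ("in0", "out1"): 1j * kappa,
    }

assert is_model(coupler)              # only keyword arguments with defaults
assert not is_model_factory(coupler)  # return annotation is not Callable-like
validate_model(coupler)               # passes silently

def bad(x):                           # positional argument -> not a valid Model
    return {}

assert not is_model(bad)
# validate_model(bad) raises ValueError
```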
+ +# Cell +def sdict(S: Union[Model, SType]) -> Union[Model, SType]: + """Convert an `SCoo` or `SDense` to `SDict`""" + + if is_model(S): + model = cast(Model, S) + + @functools.wraps(model) + def wrapper(**kwargs): + return sdict(model(**kwargs)) + + return wrapper + + elif is_scoo(S): + x_dict = _scoo_to_sdict(*cast(SCoo, S)) + elif is_sdense(S): + x_dict = _sdense_to_sdict(*cast(SDense, S)) + elif is_sdict(S): + x_dict = cast(SDict, S) + else: + raise ValueError("Could not convert arguments to sdict.") + + return x_dict + + +def _scoo_to_sdict(Si: Array, Sj: Array, Sx: Array, ports_map: Dict[str, int]) -> SDict: + sdict = {} + inverse_ports_map = {int(i): p for p, i in ports_map.items()} + for i, (si, sj) in enumerate(zip(Si, Sj)): + sdict[ + inverse_ports_map.get(int(si), ""), inverse_ports_map.get(int(sj), "") + ] = Sx[..., i] + sdict = {(p1, p2): v for (p1, p2), v in sdict.items() if p1 and p2} + return sdict + + +def _sdense_to_sdict(S: Array, ports_map: Dict[str, int]) -> SDict: + sdict = {} + for p1, i in ports_map.items(): + for p2, j in ports_map.items(): + sdict[p1, p2] = S[..., i, j] + return sdict + +# Internal Cell + +@overload +def scoo(S: Callable) -> Callable: + ... + + +@overload +def scoo(S: SType) -> SCoo: + ... + +# Cell + +def scoo(S: Union[Callable, SType]) -> Union[Callable, SCoo]: + """Convert an `SDict` or `SDense` to `SCoo`""" + + if is_model(S): + model = cast(Model, S) + + @functools.wraps(model) + def wrapper(**kwargs): + return scoo(model(**kwargs)) + + return wrapper + + elif is_scoo(S): + S = cast(SCoo, S) + elif is_sdense(S): + S = _sdense_to_scoo(*cast(SDense, S)) + elif is_sdict(S): + S = _sdict_to_scoo(cast(SDict, S)) + else: + raise ValueError("Could not convert arguments to scoo.") + + return S + + +def _sdense_to_scoo(S: Array, ports_map: Dict[str, int]) -> SCoo: + Sj, Si = jnp.meshgrid(jnp.arange(S.shape[-1]), jnp.arange(S.shape[-2])) + return Si.ravel(), Sj.ravel(), S.reshape(*S.shape[:-2], -1), ports_map + + +def _sdict_to_scoo(sdict: SDict) -> SCoo: + all_ports = {} + for p1, p2 in sdict: + all_ports[p1] = None + all_ports[p2] = None + ports_map = {p: i for i, p in enumerate(all_ports)} + Sx = jnp.stack(jnp.broadcast_arrays(*sdict.values()), -1) + Si = jnp.array([ports_map[p] for p, _ in sdict]) + Sj = jnp.array([ports_map[p] for _, p in sdict]) + return Si, Sj, Sx, ports_map + +# Internal Cell + + +@overload +def sdense(S: Callable) -> Callable: + ... + + +@overload +def sdense(S: SType) -> SDense: + ... 
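With `sdict`, `scoo` and `sdense` in place, a tiny round-trip (not part of the diff; toy values and port names) shows how the three S-matrix representations relate:

```
# Sketch, not part of the diff: toy two-port S-matrix, illustrative port names.
import jax.numpy as jnp
from sax.typing_ import scoo, sdense, sdict

wg = {
    ("in0", "out0"): jnp.exp(2j * jnp.pi * 0.1),
    ("out0", "in0"): jnp.exp(2j * jnp.pi * 0.1),
}

Si, Sj, Sx, pm = scoo(wg)   # SCoo: index arrays, flat value array, port map
S, pm2 = sdense(wg)         # SDense: (..., N, N) array + port map
assert sdict(scoo(wg)).keys() == wg.keys()  # conversions preserve the port pairs
```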
+ +# Cell + +def sdense(S: Union[Callable, SType]) -> Union[Callable, SDense]: + """Convert an `SDict` or `SCoo` to `SDense`""" + + if is_model(S): + model = cast(Model, S) + + @functools.wraps(model) + def wrapper(**kwargs): + return sdense(model(**kwargs)) + + return wrapper + + if is_sdict(S): + S = _sdict_to_sdense(cast(SDict, S)) + elif is_scoo(S): + S = _scoo_to_sdense(*cast(SCoo, S)) + elif is_sdense(S): + S = cast(SDense, S) + else: + raise ValueError("Could not convert arguments to sdense.") + + return S + + +def _scoo_to_sdense( + Si: Array, Sj: Array, Sx: Array, ports_map: Dict[str, int] +) -> SDense: + n_col = len(ports_map) + S = jnp.zeros((*Sx.shape[:-1], n_col, n_col), dtype=complex) + S = S.at[..., Si, Sj].add(Sx) + return S, ports_map + + +def _sdict_to_sdense(sdict: SDict) -> SDense: + Si, Sj, Sx, ports_map = _sdict_to_scoo(sdict) + return _scoo_to_sdense(Si, Sj, Sx, ports_map) + +# Cell + +def modelfactory(func): + """Decorator that marks a function as `ModelFactory`""" + sig = inspect.signature(func) + if _is_callable_annotation(sig.return_annotation): # already model factory + return func + func.__signature__ = sig.replace(return_annotation=Model) + return func \ No newline at end of file diff --git a/sax/utils.py b/sax/utils.py index acba582..e0b406a 100644 --- a/sax/utils.py +++ b/sax/utils.py @@ -1,179 +1,521 @@ -""" Useful functions for working with SAX. """ +# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/02_utils.ipynb (unless otherwise specified). -import pickle -import numpy as np + +from __future__ import annotations + + +__all__ = ['block_diag', 'clean_string', 'copy_settings', 'validate_settings', 'try_float', 'flatten_dict', + 'unflatten_dict', 'get_ports', 'get_port_combinations', 'get_settings', 'grouped_interp', 'merge_dicts', + 'mode_combinations', 'reciprocal', 'rename_params', 'rename_ports', 'update_settings', + 'validate_not_mixedmode', 'validate_multimode', 'validate_sdict', 'get_inputs_outputs'] + +# Cell +#nbdev_comment from __future__ import annotations + +import inspect +import re +from functools import lru_cache, partial, wraps +from typing import Any, Callable, Dict, Iterable, Iterator, Tuple, Union, cast, overload + +import jax import jax.numpy as jnp +import jax.scipy as jsp +from natsort import natsorted +from .typing_ import ( + Array, + ComplexFloat, + Float, + Model, + ModelFactory, + SCoo, + SDense, + SDict, + Settings, + SType, + is_mixedmode, + is_model, + is_model_factory, + is_scoo, + is_sdense, + is_sdict, +) + +# Cell +def block_diag(*arrs: Array) -> Array: + """create block diagonal matrix with arbitrary batch dimensions """ + batch_shape = arrs[0].shape[:-2] + N = 0 + for arr in arrs: + if batch_shape != arr.shape[:-2]: + raise ValueError("batch dimensions for given arrays don't match.") + m, n = arr.shape[-2:] + if m != n: + raise ValueError("given arrays are not square.") + N += n + + block_diag = jax.vmap(jsp.linalg.block_diag, in_axes=0, out_axes=0)( + *(arr.reshape(-1, arr.shape[-2], arr.shape[-1]) for arr in arrs) + ).reshape(*batch_shape, N, N) + + return block_diag + +# Cell + +def clean_string(s: str) -> str: + """clean a string such that it is a valid python identifier""" + s = s.strip() + s = s.replace(".", "p") # point + s = s.replace("-", "m") # minus + s = re.sub("[^0-9a-zA-Z]", "_", s) + if s[0] in "0123456789": + s = "_" + s + return s + +# Cell + +def copy_settings(settings: Settings) -> Settings: + """copy a parameter dictionary""" + return validate_settings(settings) # validation also copies + +def 
validate_settings(settings: Settings) -> Settings: + """Validate a parameter dictionary""" + _settings = {} + for k, v in settings.items(): + if isinstance(v, dict): + _settings[k] = validate_settings(v) + else: + _settings[k] = try_float(v) + return _settings + +def try_float(f: Any) -> Any: + """try converting an object to float, return unchanged object on fail""" + try: + return jnp.asarray(f, dtype=float) + except (ValueError, TypeError): + return f + +# Cell +def flatten_dict(dic: Dict[str, Any], sep: str = ",") -> Dict[str, Any]: + """flatten a nested dictionary""" + return _flatten_dict(dic, sep=sep) + + +def _flatten_dict( + dic: Dict[str, Any], sep: str = ",", frozen: bool = False, parent_key: str = "" +) -> Dict[str, Any]: + items = [] + for k, v in dic.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, dict): + items.extend( + _flatten_dict(v, sep=sep, frozen=frozen, parent_key=new_key).items() + ) + else: + items.append((new_key, v)) -from .typing import Any, Union, Tuple, Callable, Dict, ParamsDict, ModelDict + return dict(items) +# Cell -def load(name: str) -> object: - """load an object using pickle +def unflatten_dict(dic, sep=","): + """unflatten a flattened dictionary """ - Args: - name: the name to load + # from: https://gist.github.com/fmder/494aaa2dd6f8c428cede + items = dict() - Returns: - the unpickled object. - """ - with open(name, "rb") as file: - obj = pickle.load(file) - return obj + for k, v in dic.items(): + keys = k.split(sep) + sub_items = items + for ki in keys[:-1]: + if ki in sub_items: + sub_items = sub_items[ki] + else: + sub_items[ki] = dict() + sub_items = sub_items[ki] + sub_items[keys[-1]] = v -def save(obj: object, name: str): - """save an object using pickle + return items - Args: - obj: the object to save - name: the name to save the object under - """ - with open(name, "wb") as file: - pickle.dump(obj, file) +# Cell +def get_ports(S: Union[Model, SType]) -> Tuple[str, ...]: + """get port names of a model or an stype""" + if is_model(S): + return _get_ports_from_model(cast(Model, S)) + elif is_sdict(S): + ports_set = {p1 for p1, _ in S} | {p2 for _, p2 in S} + return tuple(natsorted(ports_set)) + elif is_scoo(S) or is_sdense(S): + *_, ports_map = S + return tuple(natsorted(ports_map.keys())) + else: + raise ValueError("Could not extract ports for given S") + +@lru_cache(maxsize=4096) # cache to prevent future tracing +def _get_ports_from_model(model: Model) -> Tuple[str, ...]: + S: SType = jax.eval_shape(model) + return get_ports(S) + +# Cell + +def get_port_combinations(S: Union[Model, SType]) -> Tuple[Tuple[str, str], ...]: + """get port combinations of a model or an stype""" + + if is_model(S): + S = cast(Model, S) + return _get_port_combinations_from_model(S) + elif is_sdict(S): + S = cast(SDict, S) + return tuple(S.keys()) + elif is_scoo(S): + Si, Sj, _, pm = cast(SCoo, S) + rpm = {int(i): str(p) for p, i in pm.items()} + return tuple(natsorted((rpm[int(i)], rpm[int(j)]) for i, j in zip(Si, Sj))) + elif is_sdense(S): + _, pm = cast(SDense, S) + return tuple(natsorted((p1, p2) for p1 in pm for p2 in pm)) + else: + raise ValueError("Could not extract ports for given S") + +@lru_cache(maxsize=4096) # cache to prevent future tracing +def _get_port_combinations_from_model(model: Model) -> Tuple[Tuple[str, str], ...]: + S: SType = jax.eval_shape(model) + return get_port_combinations(S) + +# Cell +def get_settings(model: Union[Model, ModelFactory]) -> Settings: + """Get the parameters of a SAX model function""" + + 
signature = inspect.signature(model) + + settings: Settings = { + k: (v.default if not isinstance(v, dict) else v) + for k, v in signature.parameters.items() + if v.default is not inspect.Parameter.empty + } + + # make sure an inplace operation of resulting dict does not change the + # circuit parameters themselves + return copy_settings(settings) + +# Cell +def grouped_interp(wl: Float, wls: Float, phis: Float) -> Float: + """Grouped phase interpolation""" + wl = cast(Array, jnp.asarray(wl)) + wls = cast(Array, jnp.asarray(wls)) + # make sure values between -pi and pi + phis = cast(Array, jnp.asarray(phis)) % (2 * jnp.pi) + phis = jnp.where(phis > jnp.pi, phis - 2 * jnp.pi, phis) + if not wls.ndim == 1: + raise ValueError("grouped_interp: wls should be a 1D array") + if not phis.ndim == 1: + raise ValueError("grouped_interp: wls should be a 1D array") + if not wls.shape == phis.shape: + raise ValueError("grouped_interp: wls and phis shape does not match") + return _grouped_interp(wl.reshape(-1), wls, phis).reshape(*wl.shape) + + +@partial(jax.vmap, in_axes=(0, None, None), out_axes=0) +@jax.jit +def _grouped_interp( + wl: Array, # 0D array (not-vmapped) ; 1D array (vmapped) + wls: Array, # 1D array + phis: Array, # 1D array +) -> Array: + dphi_dwl = (phis[1::2] - phis[::2]) / (wls[1::2] - wls[::2]) + phis = phis[::2] + wls = wls[::2] + dwl = (wls[1:] - wls[:-1]).mean(0, keepdims=True) + + t = (wl - wls + 1e-5 * dwl) / dwl # small offset to ensure no values are zero + t = jnp.where(jnp.abs(t) < 1, t, 0) + m0 = jnp.where(t > 0, size=1)[0] + m1 = jnp.where(t < 0, size=1)[0] + t = t[m0] + wl0 = wls[m0] + wl1 = wls[m1] + phi0 = phis[m0] + phi1 = phis[m1] + dphi_dwl0 = dphi_dwl[m0] + dphi_dwl1 = dphi_dwl[m1] + _phi0 = phi0 - 0.5 * (wl1 - wl0) * ( + dphi_dwl0 * (t ** 2 - 2 * t) - dphi_dwl1 * t ** 2 + ) + _phi1 = phi1 - 0.5 * (wl1 - wl0) * ( + dphi_dwl0 * (t - 1) ** 2 - dphi_dwl1 * (t ** 2 - 1) + ) + phis = jnp.arctan2( + (1 - t) * jnp.sin(_phi0) + t * jnp.sin(_phi1), + (1 - t) * jnp.cos(_phi0) + t * jnp.cos(_phi1), + ) + return phis + +# Cell +def merge_dicts(*dicts: Dict) -> Dict: + """merge (possibly deeply nested) dictionaries""" + if len(dicts) == 1: + return dict(_generate_merged_dict(dicts[0], {})) + elif len(dicts) == 2: + return dict(_generate_merged_dict(dicts[0], dicts[1])) + else: + return merge_dicts(dicts[0], merge_dicts(*dicts[1:])) + + +def _generate_merged_dict(dict1: Dict, dict2: Dict) -> Iterator[Tuple[Any, Any]]: + # inspired by https://stackoverflow.com/questions/7204805/how-to-merge-dictionaries-of-dictionaries + keys = {**{k: None for k in dict1}, **{k: None for k in dict2}} # keep key order, values irrelevant + for k in keys: + if k in dict1 and k in dict2: + v1, v2 = dict1[k], dict2[k] + if isinstance(v1, dict) and isinstance(v2, dict): + v = dict(_generate_merged_dict(v1, v2)) + else: + # If one of the values is not a dict, you can't continue merging it. + # Value from second dict overrides one in first and we move on. + v = v2 + elif k in dict1: + v = dict1[k] + else: # k in dict2: + v = dict2[k] + + if isinstance(v, dict): + yield (k, {**v}) # shallow copy of dict + else: + yield (k, v) -def validate_params(params: ParamsDict): - """validate a parameter dictionary +# Cell - params: the parameter dictionary. This dictionary should be a possibly - nested dictionary of floats. - """ - if not params: - return - - is_dict_dict = all(isinstance(v, dict) for v in params.values()) - if not is_dict_dict: - for k, v in params.items(): - msg = f"Wrong parameter dictionary format. 
Should be a (possibly nested) " - msg += f"dictionary of floats or float arrays. Got: {k}: {v}{type(v)}" - assert ( - isinstance(v, float) - or (isinstance(v, jnp.ndarray) and v.dtype == jnp.float32) - or (isinstance(v, np.ndarray) and v.dtype == np.float32) - ), msg +def mode_combinations( + modes: Iterable[str], cross: bool = False +) -> Tuple[Tuple[str, str], ...]: + """create mode combinations for a collection of given modes""" + if cross: + mode_combinations = natsorted((m1, m2) for m1 in modes for m2 in modes) + else: + mode_combinations = natsorted((m, m) for m in modes) + return tuple(mode_combinations) + +# Cell +def reciprocal(sdict: SDict) -> SDict: + """Make an SDict reciprocal""" + if is_sdict(sdict): + return { + **{(p1, p2): v for (p1, p2), v in sdict.items()}, + **{(p2, p1): v for (p1, p2), v in sdict.items()}, + } else: - for v in params.values(): - validate_params(v) + raise ValueError("sax.reciprocal is only valid for SDict types") +# Internal Cell -def copy_params(params: ParamsDict) -> ParamsDict: - """copy a parameter dictionary +@overload +def rename_params(model: ModelFactory, renamings: Dict[str, str]) -> ModelFactory: + ... - Args: - params: the parameter dictionary to copy - Returns: - the copied parameter dictionary +@overload +def rename_params(model: Model, renamings: Dict[str, str]) -> Model: + ... - Note: - this copy function works recursively on all subdictionaries of the params - dictionary but does NOT copy any non-dictionary values. - """ - validate_params(params) - params = {**params} - if all(isinstance(v, dict) for v in params.values()): - return {k: copy_params(params[k]) for k in params} - return params +# Cell +def rename_params( + model: Union[Model, ModelFactory], renamings: Dict[str, str] +) -> Union[Model, ModelFactory]: + """rename the parameters of a `Model` or `ModelFactory` given a renamings mapping old parameter names to new.""" + reversed_renamings = {v: k for k, v in renamings.items()} + if len(reversed_renamings) < len(renamings): + raise ValueError("Multiple old names point to the same new name!") -def set_global_params(params: ParamsDict, **kwargs) -> ParamsDict: - """add or update the given keyword arguments to each (sub)dictionary of the - given params dictionary + if is_model_factory(model): + old_model_factory = cast(ModelFactory, model) + old_settings = get_settings(model) - Args: - params: the parameter dictionary to update with the given global parameters - **kwargs: the global parameters to update the parameter dictionary with. - These global parameters are often wavelength ('wl') or temperature ('T'). + @wraps(old_model_factory) + def new_model_factory(**settings): + old_settings = { + reversed_renamings.get(k, k): v for k, v in settings.items() + } + model = old_model_factory(**old_settings) + return rename_params(model, renamings) - Returns: - The modified dictionary. + new_settings = {renamings.get(k, k): v for k, v in old_settings.items()} + _replace_kwargs(new_model_factory, **new_settings) - Note: - This operation NEVER updates the given params dictionary inplace. 
+ return new_model_factory - Example: - This is how to change the wavelength to 1600nm for each component in - the nested parameter dictionary:: + elif is_model(model): + old_model = cast(Model, model) + old_settings = get_settings(model) - params = set_global_params(params, wl=1.6e-6) - """ - validate_params(params) - params = copy_params(params) - if all(isinstance(v, dict) for v in params.values()): - return {k: set_global_params(params[k], **kwargs) for k in params} - params.update(kwargs) - validate_params(params) - return params + @wraps(old_model) + def new_model(**settings): + old_settings = { + reversed_renamings.get(k, k): v for k, v in settings.items() + } + return old_model(**old_settings) + new_settings = {renamings.get(k, k): v for k, v in old_settings.items()} + _replace_kwargs(new_model, **new_settings) -def get_ports(model: ModelDict) -> Tuple[str, ...]: - """get port names of the model + return new_model - Args: - model: the model dictionary to get the port names from - """ - ports: Dict[str, Any] = {} - for key in model: - if isinstance(key, str): - continue - p1, p2 = key - ports[p1] = None - ports[p2] = None - return tuple(p for p in ports) + else: + raise ValueError( + "rename_params should be used to decorate a Model or ModelFactory." + ) +def _replace_kwargs(func: Callable, **kwargs: ComplexFloat): + """Change the kwargs signature of a function""" + sig = inspect.signature(func) + settings = [ + inspect.Parameter(k, inspect.Parameter.KEYWORD_ONLY, default=v) + for k, v in kwargs.items() + ] + func.__signature__ = sig.replace(parameters=settings) -def rename_ports( - model: ModelDict, ports: Union[Dict[str, str], Tuple[str]] -) -> ModelDict: - """rename the ports of a model +# Internal Cell - Args: - model: the model dictionary to rename the ports for - ports: a port mapping (dictionary) with keys the old names and values - the new names. - """ - original_ports = get_ports(model) - assert len(ports) == len(original_ports) - if not isinstance(ports, dict): - assert len(ports) == len(set(ports)) - ports = {original_ports[i]: port for i, port in enumerate(ports)} - new_model: ModelDict = {} - for key in model: - if isinstance(key, str): - value = model[key] - if isinstance(value, dict): - value = {**value} - new_model[key] = value - else: - p1, p2 = key - new_model[ports[p1], ports[p2]] = model[p1, p2] - return new_model +@overload +def rename_ports(S: SDict, renamings: Dict[str, str]) -> SDict: + ... -def zero(params: ParamsDict) -> float: - """the zero model function. +@overload +def rename_ports(S: SCoo, renamings: Dict[str, str]) -> SCoo: + ... - Args: - params: the model parameters dictionary - Returns: - This function always returns zero. - """ - return 0.0 +@overload +def rename_ports(S: SDense, renamings: Dict[str, str]) -> SDense: + ... -def cartesian_product(*arrays) -> jnp.ndarray: - """calculate the n-dimensional cartesian product, i.e. create all - possible combinations of all elements in a given collection of arrays. +@overload +def rename_ports(S: Model, renamings: Dict[str, str]) -> Model: + ... - Args: - *arrays: the arrays to calculate the cartesian product for - Returns: - the cartesian product. +@overload +def rename_ports(S: ModelFactory, renamings: Dict[str, str]) -> ModelFactory: + ... 
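Usage sketch for `rename_params` (not part of the diff; `attenuator` and `loss_dB` are made-up names): the wrapper forwards renamed keyword arguments to the wrapped model and exposes the renamed defaults through `get_settings`:

```
# Sketch, not part of the diff: hypothetical toy model and parameter names.
from sax.utils import get_settings, rename_params

def attenuator(loss=0.0):
    return {("in0", "out0"): 10 ** (-loss / 20)}

att_db = rename_params(attenuator, {"loss": "loss_dB"})
att_db(loss_dB=3.0)          # forwarded to attenuator(loss=3.0)
print(get_settings(att_db))  # roughly: {'loss_dB': 0.0}
```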
+ +# Cell + +def rename_ports( + S: Union[SType, Model, ModelFactory], renamings: Dict[str, str] +) -> Union[SType, Model, ModelFactory]: + """rename the ports of an `SDict`, `Model` or `ModelFactory` given a renamings mapping old port names to new.""" + if is_scoo(S): + Si, Sj, Sx, ports_map = cast(SCoo, S) + ports_map = {renamings[p]: i for p, i in ports_map.items()} + return Si, Sj, Sx, ports_map + elif is_sdense(S): + Sx, ports_map = cast(SDense, S) + ports_map = {renamings[p]: i for p, i in ports_map.items()} + return Sx, ports_map + elif is_sdict(S): + sdict = cast(SDict, S) + original_ports = get_ports(sdict) + assert len(renamings) == len(original_ports) + return {(renamings[p1], renamings[p2]): v for (p1, p2), v in sdict.items()} + elif is_model(S): + old_model = cast(Model, S) + + @wraps(old_model) + def new_model(**settings) -> SType: + return rename_ports(old_model(**settings), renamings) + + return new_model + elif is_model_factory(S): + old_model_factory = cast(ModelFactory, S) + + @wraps(old_model_factory) + def new_model_factory(**settings) -> Callable[..., SType]: + return rename_ports(old_model_factory(**settings), renamings) + + return new_model_factory + else: + raise ValueError("Cannot rename ports for type {type(S)}") + +# Cell + +def update_settings( + settings: Settings, *compnames: str, **kwargs: ComplexFloat +) -> Settings: + """update a nested settings dictionary""" + _settings = {} + if not compnames: + for k, v in settings.items(): + if isinstance(v, dict): + _settings[k] = update_settings(v, **kwargs) + else: + if k in kwargs: + _settings[k] = try_float(kwargs[k]) + else: + _settings[k] = try_float(v) + else: + for k, v in settings.items(): + if isinstance(v, dict): + if k == compnames[0]: + _settings[k] = update_settings(v, *compnames[1:], **kwargs) + else: + _settings[k] = v + else: + _settings[k] = try_float(v) + return _settings + +# Cell +def validate_not_mixedmode(S: SType): + """validate that an stype is not 'mixed mode' (i.e. invalid) + + Args: + S: the stype to validate """ - ixarrays = jnp.ix_(*arrays) - barrays = jnp.broadcast_arrays(*ixarrays) - sarrays = jnp.stack(barrays, -1) - product = sarrays.reshape(-1, sarrays.shape[-1]) - return product + + if is_mixedmode(S): # mixed mode + raise ValueError( + "Given SType is neither multimode or singlemode. Please check the port " + "names: they should either ALL contain the '@' separator (multimode) " + "or NONE should contain the '@' separator (singlemode)." + ) + +# Cell +def validate_multimode(S: SType, modes=("te", "tm")) -> None: + """validate that an stype is multimode and that the given modes are present.""" + try: + current_modes = set(p.split("@")[1] for p in get_ports(S)) + except IndexError: + raise ValueError("The given stype is not multimode.") + for mode in modes: + if mode not in current_modes: + raise ValueError( + f"Could not find mode '{mode}' in one of the multimode models." + ) + +# Cell +def validate_sdict(sdict: Any) -> None: + """Validate an `SDict`""" + + if not isinstance(sdict, dict): + raise ValueError("An SDict should be a dictionary.") + for ports in sdict: + if not isinstance(ports, tuple) and not len(ports) == 2: + raise ValueError(f"SDict keys should be length-2 tuples. Got {ports}") + p1, p2 = ports + if not isinstance(p1, str) or not isinstance(p2, str): + raise ValueError( + f"SDict ports should be strings. 
Got {ports} " + f"({type(ports[0])}, {type(ports[1])})" + ) + +# Cell + +def get_inputs_outputs(ports: Tuple[str, ...]): + inputs = tuple(p for p in ports if p.lower().startswith("in")) + outputs = tuple(p for p in ports if not p.lower().startswith("in")) + if not inputs: + inputs = tuple(p for p in ports if not p.lower().startswith("out")) + outputs = tuple(p for p in ports if p.lower().startswith("out")) + return inputs, outputs \ No newline at end of file diff --git a/settings.ini b/settings.ini new file mode 100644 index 0000000..51bed36 --- /dev/null +++ b/settings.ini @@ -0,0 +1,30 @@ +[DEFAULT] +host = github +lib_name = sax +repo_name = sax +user = flaport +description = S + Autograd + XLA +keywords = s-parameters jax xla tensorflow photonics optics +author = Floris Laporte +author_email = floris.laporte@gmail.com +copyright = Floris Laporte (Apache 2.0) +branch = master +version = 0.6.0 +min_python = 3.8 +audience = Developers +language = English +custom_sidebar = True +license = apache2 +status = 2 +nbs_path = . +doc_path = docs +recursive = True +doc_host = https://flaport.github.io +doc_baseurl = /sax/ +git_url = https://github.com/flaport/sax/ +lib_path = sax +title = sax +monospace_docstrings = False +jekyll_styles = note,warning,tip,important +#tst_flags = +#cell_spacing = diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..d77a191 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,35 @@ +[metadata] +name = sax +version = 0.6.0 +description = Autograd and XLA for S-parameters +author = Floris Laporte +author_email = floris.laporte@gmail.com +long_description = file: README.md +long_description_content_type = text/markdown +classifiers = + Development Status :: 3 - Alpha + Operating System :: POSIX :: Linux + +[options] +packages = + sax + sax.nn + sax.backends +install_requires = + fastcore + flax + h5py + jax + jaxlib + natsort + networkx + numpy + pyyaml + tqdm + +[options.package_data] +* = + settings.ini + LICENSE + CONTRIBUTING.md + README.md diff --git a/setup.py b/setup.py deleted file mode 100644 index 35afc54..0000000 --- a/setup.py +++ /dev/null @@ -1,30 +0,0 @@ -# imports -import sys -import sax -import warnings -import setuptools -import subprocess - -# install sax: -setuptools.setup( - name=sax.__name__, - version=sax.__version__, - description=sax.__doc__, - long_description=open("README.md").read(), - author=sax.__author__, - author_email="floris.laporte@gmail.com", - long_description_content_type="text/markdown", - packages=setuptools.find_packages(), - classifiers=[ - "Development Status :: 3 - Alpha", - "Operating System :: POSIX :: Linux", - "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Topic :: Scientific/Engineering", - "Topic :: Scientific/Engineering :: Physics", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "License :: OSI Approved :: Apache Software License" - ], -)
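Finally, a short sketch (not part of the diff; toy data and names) of the settings/port helpers added at the end of sax/utils.py:

```
# Sketch, not part of the diff: illustrative settings dict and port names.
from sax.utils import get_inputs_outputs, update_settings, validate_sdict

validate_sdict({("in0", "out0"): 0.9, ("out0", "in0"): 0.9})  # passes silently
print(get_inputs_outputs(("in0", "in1", "out0")))             # (('in0', 'in1'), ('out0',))

settings = {"wg": {"wl": 1.55e-6}, "dc": {"wl": 1.55e-6, "coupling": 0.5}}
settings = update_settings(settings, wl=1.3e-6)           # update 'wl' in every sub-dict
settings = update_settings(settings, "dc", coupling=0.3)  # update only the 'dc' sub-dict
```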