diff --git a/.github/workflows/pypi-publish.yml b/.github/workflows/pypi-publish.yml index 1b33e64..c0237b5 100644 --- a/.github/workflows/pypi-publish.yml +++ b/.github/workflows/pypi-publish.yml @@ -25,7 +25,7 @@ jobs: - name: Build package run: python -m build - name: Publish package - if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') + if: github.event_name == 'push' && startsWith(github.ref, 'refs/v') uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 with: user: anh-tong diff --git a/README.md b/README.md index 230124b..c0b3dd0 100644 --- a/README.md +++ b/README.md @@ -11,8 +11,61 @@ This implementation is inspired by [patrick-kidger/signatory](https://github.com ## Examples - +Basic usage +```python +import jax +import jax.random as jrandom + +from signax.signature import signature + +key = jrandom.PRNGKey(0) +depth = 3 + +# compute signature for a single path +length = 100 +dim = 20 +path = jrandom.normal(shape=(length, dim), key=key) +output = signature(path, depth) +# output is a list of array presenting tensor algebra + +# compute signature for batches (multiple) of paths +# this is done via `jax.vmap` +batch_size = 20 +path = jrandom.normal(shape=(batch_size, length, dim), key=key) +output = jax.vmap(lambda x: signature(x, depth))(path) +``` + +Integrate with [equinox](https://github.com/patrick-kidger/equinox) library + +```python +import equinox as eqx +import jax.random as jrandom + +from signax.module import SignatureTransform + +# random generator key +key = jrandom.PRNGKey(0) +mlp_key, data_key = jrandom.split(key) + +depth=3 +length, dim = 100, 3 + +# we signature transfrom +signature_layer = SignatureTransform(depth=depth) +# finally, getting output via a neural network +last_layer = eqx.nn.MLP(depth=1, + in_size=3 + 3**2 + 3**3, + width_size=4, + out_size=1, + key=mlp_key) + +model = eqx.nn.Sequential(layers=[signature_layer, last_layer]) +x = jrandom.normal(shape=(length, dim), key=data_key) +output = model(x) +``` + +Also, check notebooks in `examples` folder for some experiments of [deep signature transforms paper](https://arxiv.org/abs/1905.08494). ## Installation ``` @@ -33,4 +86,11 @@ Signatory allows dividing a path into chunks and performing asynchronous multith Because JAX make use of just-in-time (JIT) compilations with XLA, this implementation can be reasonably fast. -We observe that the performance of this implementation is similar to Signatory in CPU and slightly better in GPU. It could be because of the optimized operators of XLA in JAX. As mentioned in the paper, signatory is not fully optimized for CUDA but relies on LibTorch. \ No newline at end of file +We observe that the performance of this implementation is similar to Signatory in CPU and slightly better in GPU. It could be because of the optimized operators of XLA in JAX. As mentioned in the paper, signatory is not fully optimized for CUDA but relies on LibTorch. + +## Acknowledgement + +This repo is based on +- [Signatory](https://github.com/patrick-kidger/signatory) +- [Deep-Signature-Transforms](https://github.com/patrick-kidger/Deep-Signature-Transforms) +- [Equinox](https://github.com/patrick-kidger/equinox) \ No newline at end of file diff --git a/examples/compare.ipynb b/examples/compare.ipynb index e928a9e..103ee1e 100644 --- a/examples/compare.ipynb +++ b/examples/compare.ipynb @@ -1,5 +1,14 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Benchmark between Signatory and Signax\n", + "\n", + "This is just a rough comparison. " + ] + }, { "cell_type": "code", "execution_count": 1, @@ -22,6 +31,13 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Gradient computation" + ] + }, { "cell_type": "code", "execution_count": 2, @@ -50,15 +66,7 @@ "name": "#%%\n" } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - } - ], + "outputs": [], "source": [ "def func(x):\n", " def _fn(x):\n", @@ -89,7 +97,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "4.76 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "13.7 ms ± 709 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -143,7 +151,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "3.83 ms ± 225 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "10 ms ± 651 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], diff --git a/setup.py b/setup.py index 8a63034..c120a52 100644 --- a/setup.py +++ b/setup.py @@ -1,17 +1,47 @@ +import pathlib + import setuptools -metadata = {"name": "signax", "version": "1.0", "author": "signax authors"} +HERE = pathlib.Path(__file__).resolve().parent + +metadata = {"name": "signax", "version": "0.1.0", "author": "signax authors"} python_requires = "~=3.7" install_requires = ["jax>=0.3.10", "equinox"] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Financial and Insurance Industry", + "Intended Audience :: Information Technology", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Natural Language :: English", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Information Analysis", + "Topic :: Scientific/Engineering :: Mathematics", +] + +description = "Signax: Signature computation in JAX" +url = "https://github.com/anh-tong/signax" + +with open(HERE / "README.md", "r") as f: + readme = f.read() + setuptools.setup( name=metadata["name"], version=metadata["version"], author=metadata["author"], - packages=[metadata["name"]], - ext_package=metadata["name"], + maintainer=metadata["author"], + description=description, + long_description=readme, + long_description_content_type="text/markdown", + url=url, + classifiers=classifiers, + zip_safe=False, python_requires=python_requires, install_requires=install_requires, + packages=setuptools.find_packages(exclude=["examples", "test"]), )