Skip to content

Commit

Permalink
Merge pull request #13 from anh-tong/pure_jax
Browse files Browse the repository at this point in the history
update workflow
  • Loading branch information
anh-tong authored Sep 8, 2022
2 parents a422340 + e6c6921 commit 631db7e
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pypi-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
- name: Build package
run: python -m build
- name: Publish package
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
if: github.event_name == 'push' && startsWith(github.ref, 'refs/v')
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: anh-tong
Expand Down
64 changes: 62 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,61 @@ This implementation is inspired by [patrick-kidger/signatory](https://github.com

## Examples

<!-- TODO: example with equinox -->
Basic usage

```python
import jax
import jax.random as jrandom

from signax.signature import signature

key = jrandom.PRNGKey(0)
depth = 3

# compute signature for a single path
length = 100
dim = 20
path = jrandom.normal(shape=(length, dim), key=key)
output = signature(path, depth)
# output is a list of array presenting tensor algebra

# compute signature for batches (multiple) of paths
# this is done via `jax.vmap`
batch_size = 20
path = jrandom.normal(shape=(batch_size, length, dim), key=key)
output = jax.vmap(lambda x: signature(x, depth))(path)
```

Integrate with [equinox](https://github.com/patrick-kidger/equinox) library

```python
import equinox as eqx
import jax.random as jrandom

from signax.module import SignatureTransform

# random generator key
key = jrandom.PRNGKey(0)
mlp_key, data_key = jrandom.split(key)

depth=3
length, dim = 100, 3

# we signature transfrom
signature_layer = SignatureTransform(depth=depth)
# finally, getting output via a neural network
last_layer = eqx.nn.MLP(depth=1,
in_size=3 + 3**2 + 3**3,
width_size=4,
out_size=1,
key=mlp_key)

model = eqx.nn.Sequential(layers=[signature_layer, last_layer])
x = jrandom.normal(shape=(length, dim), key=data_key)
output = model(x)
```

Also, check notebooks in `examples` folder for some experiments of [deep signature transforms paper](https://arxiv.org/abs/1905.08494).
## Installation

```
Expand All @@ -33,4 +86,11 @@ Signatory allows dividing a path into chunks and performing asynchronous multith

Because JAX make use of just-in-time (JIT) compilations with XLA, this implementation can be reasonably fast.

We observe that the performance of this implementation is similar to Signatory in CPU and slightly better in GPU. It could be because of the optimized operators of XLA in JAX. As mentioned in the paper, signatory is not fully optimized for CUDA but relies on LibTorch.
We observe that the performance of this implementation is similar to Signatory in CPU and slightly better in GPU. It could be because of the optimized operators of XLA in JAX. As mentioned in the paper, signatory is not fully optimized for CUDA but relies on LibTorch.

## Acknowledgement

This repo is based on
- [Signatory](https://github.com/patrick-kidger/signatory)
- [Deep-Signature-Transforms](https://github.com/patrick-kidger/Deep-Signature-Transforms)
- [Equinox](https://github.com/patrick-kidger/equinox)
30 changes: 19 additions & 11 deletions examples/compare.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Benchmark between Signatory and Signax\n",
"\n",
"This is just a rough comparison. "
]
},
{
"cell_type": "code",
"execution_count": 1,
Expand All @@ -22,6 +31,13 @@
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Gradient computation"
]
},
{
"cell_type": "code",
"execution_count": 2,
Expand Down Expand Up @@ -50,15 +66,7 @@
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
}
],
"outputs": [],
"source": [
"def func(x):\n",
" def _fn(x):\n",
Expand Down Expand Up @@ -89,7 +97,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"4.76 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
"13.7 ms ± 709 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
Expand Down Expand Up @@ -143,7 +151,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"3.83 ms ± 225 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
"10 ms ± 651 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
Expand Down
36 changes: 33 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,47 @@
import pathlib

import setuptools


metadata = {"name": "signax", "version": "1.0", "author": "signax authors"}
HERE = pathlib.Path(__file__).resolve().parent

metadata = {"name": "signax", "version": "0.1.0", "author": "signax authors"}

python_requires = "~=3.7"
install_requires = ["jax>=0.3.10", "equinox"]

classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Financial and Insurance Industry",
"Intended Audience :: Information Technology",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Natural Language :: English",
"Programming Language :: Python :: 3",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Information Analysis",
"Topic :: Scientific/Engineering :: Mathematics",
]

description = "Signax: Signature computation in JAX"
url = "https://github.com/anh-tong/signax"

with open(HERE / "README.md", "r") as f:
readme = f.read()

setuptools.setup(
name=metadata["name"],
version=metadata["version"],
author=metadata["author"],
packages=[metadata["name"]],
ext_package=metadata["name"],
maintainer=metadata["author"],
description=description,
long_description=readme,
long_description_content_type="text/markdown",
url=url,
classifiers=classifiers,
zip_safe=False,
python_requires=python_requires,
install_requires=install_requires,
packages=setuptools.find_packages(exclude=["examples", "test"]),
)

0 comments on commit 631db7e

Please sign in to comment.