softranks and softsort operators
PiperOrigin-RevId: 266103604
olivierteboul authored and copybara-github committed Aug 29, 2019
1 parent e853a0d commit 9f0b797
Showing 8 changed files with 705 additions and 0 deletions.
1 change: 1 addition & 0 deletions .travis.yml
@@ -33,6 +33,7 @@ env:
- PROJECT="probabilistic_vqvae"
- PROJECT="psycholab"
- PROJECT="robust_loss"
- PROJECT="soft_sort"
- PROJECT="solver1d"
- PROJECT="state_of_sparsity"
- PROJECT="sufficient_input_subsets"
18 changes: 18 additions & 0 deletions soft_sort/README.md
@@ -0,0 +1,18 @@
# Differentiable Ranks and Sorting Operators for TensorFlow

## Overview

This work was introduced in the following 2019 paper:
Cuturi M., Teboul O., Vert JP: [Differentiable Sorting using Optimal Transport:
The Sinkhorn CDF and Quantile Operator](https://arxiv.org/pdf/1905.11885.pdf)
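
## Example

A minimal usage sketch (assuming this directory is importable as the
`soft_sort` package, as in the tests; `epsilon` is a regularization parameter
forwarded to the underlying `SoftQuantilizer`):

```python
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

from soft_sort import ops

x = tf.constant([3.0, 4.0, 1.0, 5.0, 2.0])

# Differentiable counterparts of sorting and (one-based) ranking; smaller
# epsilon values bring the outputs closer to the hard operators.
print(ops.softsort(x, direction='ASCENDING', epsilon=1e-3))
# -> approximately [1., 2., 3., 4., 5.]
print(ops.softranks(x, zero_based=False, epsilon=1e-3))
# -> approximately [3., 4., 1., 5., 2.]
```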

## License

Licensed under the
[Apache 2.0](https://github.com/google-research/google-research/blob/master/LICENSE)
License.

## Disclaimer

This is not an official Google product.

124 changes: 124 additions & 0 deletions soft_sort/ops.py
@@ -0,0 +1,124 @@
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module defines the softranks and softsort operators."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf
from soft_sort import soft_quantilizer


DIRECTIONS = ('ASCENDING', 'DESCENDING')


def _preprocess(x, axis):
  """Reshapes the input data to rank 2, as required by SoftQuantilizer.

  The SoftQuantilizer expects an input tensor of rank 2, where the first
  dimension is the batch dimension and the soft sorting is applied along the
  second one.

  Args:
    x: Tensor<float> of any dimension.
    axis: (int) the axis to be turned into the second dimension.

  Returns:
    A Tensor<float>[batch, n], where n is the size of the given axis and batch
    is the product of all the other dimensions.
  """
  dims = list(range(x.shape.rank))
  dims[-1], dims[axis] = dims[axis], dims[-1]
  z = tf.transpose(x, dims) if dims[axis] != dims[-1] else x
  return tf.reshape(z, (-1, tf.shape(x)[axis]))
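
# Illustration (a sketch, grounded in test_preprocess in ops_test.py): with x
# of shape [4, 21, 7, 10] and axis=2, the chosen axis is moved to the last
# position and the result is reshaped to [4 * 21 * 10, 7] = [840, 7].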


def _postprocess(x, shape, axis):
  """Applies the inverse transformation of _preprocess.

  Args:
    x: Tensor<float>[batch, n].
    shape: TensorShape of the desired output.
    axis: (int) the axis along which the original tensor was processed.

  Returns:
    A Tensor<float> with the shape given in argument.
  """
  s = list(shape)
  s[axis], s[-1] = s[-1], s[axis]
  z = tf.reshape(x, s)

  # Transpose to get back to the original shape.
  dims = list(range(shape.rank))
  dims[-1], dims[axis] = dims[axis], dims[-1]
  return tf.transpose(z, dims) if dims[axis] != dims[-1] else z
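
# Note: _postprocess(_preprocess(x, axis), x.shape, axis) recovers x for any
# axis; this round trip is checked in test_postprocess in ops_test.py.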


@tf.function
def softsort(x, direction='ASCENDING', axis=-1, **kwargs):
  """Applies the softsort operator on input tensor x.

  This operator acts as a differentiable alternative to tf.sort.

  Args:
    x: the input tensor. It can be either of shape [batch, n] or [n].
    direction: the direction of the sort, either 'ASCENDING' or 'DESCENDING'.
    axis: the axis on which to operate the sort.
    **kwargs: see SoftQuantilizer for possible parameters.

  Returns:
    A tensor of the same shape as the input.
  """
  if direction not in DIRECTIONS:
    raise ValueError('`direction` should be one of {}'.format(DIRECTIONS))

  z = _preprocess(x, axis)
  descending = (direction == 'DESCENDING')
  sorter = soft_quantilizer.SoftQuantilizer(z, descending=descending, **kwargs)
  return _postprocess(sorter.softsort, x.shape, axis)
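
# Example (a sketch; epsilon is forwarded to the SoftQuantilizer and controls
# the accuracy/smoothness trade-off): with a small regularization the output
# approaches tf.sort, e.g.
#   softsort(tf.constant([3., 4., 1., 5., 2.]), epsilon=1e-3)
#   # -> approximately [1., 2., 3., 4., 5.]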


@tf.function
def softranks(x, direction='ASCENDING', axis=-1, zero_based=False, **kwargs):
  """A differentiable argsort-like operator that returns the ranks directly.

  Note that it behaves as the 'inverse' of the argsort operator since it
  returns soft ranks, i.e. real numbers that play the role of indices and
  quantify the relative standing (among all n entries) of each entry of x.

  Args:
    x: Tensor<float> of any shape.
    direction: (str) either 'ASCENDING' or 'DESCENDING', as in tf.sort.
    axis: (int) the axis along which to sort, as in tf.sort.
    zero_based: (bool) whether to return ranks in [0, n-1] (True) or in
      [1, n] (False).
    **kwargs: see SoftQuantilizer for possible parameters.

  Returns:
    A Tensor<float> of the same shape as the input, containing the soft ranks.
  """
  if direction not in DIRECTIONS:
    raise ValueError('`direction` should be one of {}'.format(DIRECTIONS))

  descending = (direction == 'DESCENDING')
  z = _preprocess(x, axis)
  sorter = soft_quantilizer.SoftQuantilizer(z, descending=descending, **kwargs)
  # The soft CDF lies in (0, 1]; scaling by n turns it into soft ranks.
  ranks = sorter.softcdf * tf.cast(tf.shape(z)[1], dtype=x.dtype)
  if zero_based:
    ranks -= tf.cast(1.0, dtype=x.dtype)

  return _postprocess(ranks, x.shape, axis)
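
# Example (a sketch):
#   softranks(tf.constant([3., 4., 1., 5., 2.]), epsilon=1e-3)
# returns approximately [3., 4., 1., 5., 2.], since in this vector each entry
# happens to equal its own one-based ascending rank.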
142 changes: 142 additions & 0 deletions soft_sort/ops_test.py
@@ -0,0 +1,142 @@
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the softsort and softranks operators."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np
import tensorflow.compat.v2 as tf

from soft_sort import ops


class OpsTest(parameterized.TestCase, tf.test.TestCase):

  def setUp(self):
    super(OpsTest, self).setUp()
    tf.random.set_seed(0)
    np.random.seed(seed=0)

  def test_preprocess(self):
    """Tests that _preprocess prepares the tensor as expected."""
    # Test preprocessing with input of dimension 1.
    n = 10
    x = tf.random.uniform((n,), dtype=tf.float64)
    z = ops._preprocess(x, axis=-1)
    self.assertEqual(z.shape.rank, 2)
    self.assertEqual(z.shape, (1, n))
    self.assertAllEqual(z[0], x)

    # Test preprocessing with input of dimension 2.
    x = tf.random.uniform((3, n), dtype=tf.float64)
    z = ops._preprocess(x, axis=-1)
    self.assertEqual(z.shape.rank, 2)
    self.assertEqual(z.shape, x.shape)
    self.assertAllEqual(z, x)

    # Test preprocessing with input of dimension 2, preparing for axis 0.
    x = tf.random.uniform((3, n), dtype=tf.float64)
    z = ops._preprocess(x, axis=0)
    self.assertEqual(z.shape.rank, 2)
    self.assertEqual(z.shape, (x.shape[1], x.shape[0]))
    batch = 1
    self.assertAllEqual(z[batch], x[:, batch])

    # Test preprocessing with input of dimension > 2.
    shape = [4, 21, 7, 10]
    x = tf.random.uniform(shape, dtype=tf.float64)
    axis = 2
    n = shape.pop(axis)
    z = ops._preprocess(x, axis=axis)
    self.assertEqual(z.shape.rank, 2)
    self.assertEqual(z.shape, (np.prod(shape), n))

  def test_postprocess(self):
    """Tests that _postprocess is the inverse of _preprocess."""
    shape = (4, 21, 7, 10)
    for i in range(1, len(shape)):
      x = tf.random.uniform(shape[:i])
      for axis in range(x.shape.rank):
        z = ops._postprocess(ops._preprocess(x, axis), x.shape, axis)
        self.assertAllEqual(x, z)

  def test_softsort(self):
    # Tests that the values are sorted (epsilon being small enough).
    x = tf.constant([3, 4, 1, 5, 2], dtype=tf.float32)
    eps = 1e-3
    sinkhorn_threshold = 1e-3
    values = ops.softsort(x, direction='ASCENDING',
                          epsilon=eps, sinkhorn_threshold=sinkhorn_threshold)
    self.assertEqual(values.shape, x.shape)
    self.assertAllGreater(np.diff(values), 0.0)

    # Since epsilon is not very small, we cannot expect to retrieve the
    # sorted values with high precision.
    tolerance = 1e-1
    self.assertAllClose(tf.sort(x), values, tolerance, tolerance)

    # Test descending sort.
    direction = 'DESCENDING'
    values = ops.softsort(x, direction=direction,
                          epsilon=eps, sinkhorn_threshold=sinkhorn_threshold)
    self.assertEqual(values.shape, x.shape)
    self.assertAllLess(np.diff(values), 0.0)
    self.assertAllClose(
        tf.sort(x, direction=direction), values, tolerance, tolerance)

  @parameterized.named_parameters(
      ('axis0_ascending', 0, 'ASCENDING'),
      ('axis1_ascending', 1, 'ASCENDING'),
      ('axis2_ascending', 2, 'ASCENDING'),
      ('axis0_descending', 0, 'DESCENDING'),
      ('axis1_descending', 1, 'DESCENDING'),
      ('axis2_descending', 2, 'DESCENDING'))
  def test_softranks(self, axis, direction):
    """Tests ops.softranks for a given shape, axis and direction."""
    shape = tf.TensorShape((3, 8, 6))
    n = shape[axis]
    p = int(np.prod(shape) / shape[axis])

    # Build a target tensor of ranks, of rank 2.
    # Those targets are zero based.
    target = tf.constant(
        [np.random.permutation(n) for _ in range(p)], dtype=tf.float32)

    # Turn it into a tensor of the desired shape.
    target = ops._postprocess(target, shape, axis)

    # Apply a monotonic transformation to turn ranks into values.
    sign = 2 * float(direction == 'ASCENDING') - 1
    x = sign * (1.2 * target - 0.4)

    # The softranks of x along the axis should be close to the target.
    eps = 1e-3
    sinkhorn_threshold = 1e-3
    tolerance = 0.5
    for zero_based in [False, True]:
      ranks = ops.softranks(
          x, direction=direction, axis=axis, zero_based=zero_based,
          epsilon=eps, sinkhorn_threshold=sinkhorn_threshold)
      targets = target + 1 if not zero_based else target
      self.assertAllClose(ranks, targets, tolerance, tolerance)


if __name__ == '__main__':
  tf.enable_v2_behavior()
  tf.test.main()
3 changes: 3 additions & 0 deletions soft_sort/requirements.txt
@@ -0,0 +1,3 @@
gin-config>=0.1.4
numpy>=1.16.2
tensorflow>=2.0.0-beta1
24 changes: 24 additions & 0 deletions soft_sort/run.sh
@@ -0,0 +1,24 @@
#!/bin/bash
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -e
set -x

virtualenv -p python3 .
source ./bin/activate

pip install -r soft_sort/requirements.txt
python -m soft_sort.ops_test
