soft_sort/README.md
@@ -0,0 +1,18 @@
# Differentiable Ranks and Sorting operators for Tensorflow

## Overview

This work was introduced in the following 2019 paper:
Cuturi M., Teboul O., Vert JP: [Differentiable Sorting using Optimal Transport:
The Sinkhorn CDF and Quantile Operator](https://arxiv.org/pdf/1905.11885.pdf)

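## Example

A minimal sketch of how the operators can be called (the `epsilon` value is
illustrative; see `SoftQuantilizer` for the available parameters):

```python
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

from soft_sort import ops

x = tf.constant([3.0, 4.0, 1.0, 5.0, 2.0])

# Differentiable counterparts of sorting and ranking.
sorted_values = ops.softsort(x, direction='ASCENDING', epsilon=1e-3)
ranks = ops.softranks(x, direction='ASCENDING', zero_based=True, epsilon=1e-3)
```
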
## License

Licensed under the
[Apache 2.0](https://github.com/google-research/google-research/blob/master/LICENSE)
License.

## Disclaimer

This is not an official Google product.

soft_sort/ops.py
@@ -0,0 +1,124 @@
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module defines the softranks and softsort operators."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf
from soft_sort import soft_quantilizer


DIRECTIONS = ('ASCENDING', 'DESCENDING')


def _preprocess(x, axis):
  """Reshapes the input data to make it rank 2, as required by SoftQuantilizer.

  The SoftQuantilizer expects an input tensor of rank 2, where the first
  dimension is the batch dimension and the soft sorting is applied along the
  second one.

  Args:
    x: Tensor<float> of any rank.
    axis: (int) the axis to be turned into the second dimension.

  Returns:
    A Tensor<float>[batch, n] where n is the size of the `axis` dimension and
    batch is the product of all other dimensions.
  """
  dims = list(range(x.shape.rank))
  dims[-1], dims[axis] = dims[axis], dims[-1]
  # After the swap, the two entries are equal only when `axis` already was the
  # last dimension, in which case no transposition is needed.
  z = tf.transpose(x, dims) if dims[axis] != dims[-1] else x
  return tf.reshape(z, (-1, tf.shape(x)[axis]))


def _postprocess(x, shape, axis):
  """Applies the inverse transformation of _preprocess.

  Args:
    x: Tensor<float>[batch, n].
    shape: TensorShape of the desired output.
    axis: (int) the axis along which the original tensor was processed.

  Returns:
    A Tensor<float> with the given shape.
  """
  # Unflatten the batch dimension, keeping `axis` in last position.
  s = list(shape)
  s[axis], s[-1] = s[-1], s[axis]
  z = tf.reshape(x, s)

  # Transpose to get back to the original shape.
  dims = list(range(shape.rank))
  dims[-1], dims[axis] = dims[axis], dims[-1]
  return tf.transpose(z, dims) if dims[axis] != dims[-1] else z


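# Example: with x of shape [4, 21, 7, 10] and axis=2, _preprocess returns a
# tensor of shape [840, 7] (840 = 4 * 21 * 10); _postprocess applied to the
# result, with the original shape and axis, restores shape [4, 21, 7, 10].

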
@tf.function
def softsort(x, direction='ASCENDING', axis=-1, **kwargs):
  """Applies the softsort operator on input tensor x.

  This operator acts as a differentiable alternative to tf.sort.

  Args:
    x: the input tensor. It can be either of shape [batch, n] or [n].
    direction: the direction of the sort, either 'ASCENDING' or 'DESCENDING'.
    axis: the axis on which to operate the sort.
    **kwargs: see SoftQuantilizer for possible parameters.

  Returns:
    A tensor of the same shape as the input.
  """
  if direction not in DIRECTIONS:
    raise ValueError('`direction` should be one of {}'.format(DIRECTIONS))

  z = _preprocess(x, axis)
  descending = (direction == 'DESCENDING')
  sorter = soft_quantilizer.SoftQuantilizer(z, descending=descending, **kwargs)
  return _postprocess(sorter.softsort, x.shape, axis)


@tf.function
def softranks(x, direction='ASCENDING', axis=-1, zero_based=False, **kwargs):
  """A differentiable argsort-like operator that returns the ranks directly.

  Note that it behaves as the 'inverse' of the argsort operator since it
  returns soft ranks, i.e. real numbers that play the role of indices and
  quantify the relative standing (among all n entries) of each entry of x.

  Args:
    x: Tensor<float> of any shape.
    direction: (str) either 'ASCENDING' or 'DESCENDING', as in tf.sort.
    axis: (int) the axis along which to sort, as in tf.sort.
    zero_based: (bool) whether to return ranks in [0, n-1] (True) or in
      [1, n] (False).
    **kwargs: see SoftQuantilizer for possible parameters.

  Returns:
    A Tensor<float> of the same shape as the input containing the soft ranks.
  """
  if direction not in DIRECTIONS:
    raise ValueError('`direction` should be one of {}'.format(DIRECTIONS))

  descending = (direction == 'DESCENDING')
  z = _preprocess(x, axis)
  sorter = soft_quantilizer.SoftQuantilizer(z, descending=descending, **kwargs)
  # Scale the soft CDF values by n to obtain one-based soft ranks.
  ranks = sorter.softcdf * tf.cast(tf.shape(z)[1], dtype=x.dtype)
  if zero_based:
    ranks -= tf.cast(1.0, dtype=x.dtype)

  return _postprocess(ranks, x.shape, axis)
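
Unlike `tf.sort` and `tf.argsort`, both operators are differentiable in their
input, so they can be used inside a training objective. A minimal sketch of
taking gradients through `softranks` (the toy loss and `epsilon` value are
illustrative):

```python
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

from soft_sort import ops

x = tf.Variable([3.0, 4.0, 1.0, 5.0, 2.0])
with tf.GradientTape() as tape:
  # A toy loss on the soft ranks. A hard argsort would provide no useful
  # gradient here, whereas softranks propagates one to x.
  loss = tf.reduce_sum(ops.softranks(x, epsilon=1e-3) ** 2)
grad = tape.gradient(loss, x)
```
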
soft_sort/ops_test.py
@@ -0,0 +1,142 @@
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the softsort and softranks operators."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np
import tensorflow.compat.v2 as tf

from soft_sort import ops


class OpsTest(parameterized.TestCase, tf.test.TestCase):

  def setUp(self):
    super(OpsTest, self).setUp()
    tf.random.set_seed(0)
    np.random.seed(seed=0)

  def test_preprocess(self):
    """Tests that _preprocess prepares the tensor as expected."""
    # Test preprocessing with an input of rank 1.
    n = 10
    x = tf.random.uniform((n,), dtype=tf.float64)
    z = ops._preprocess(x, axis=-1)
    self.assertEqual(z.shape.rank, 2)
    self.assertEqual(z.shape, (1, n))
    self.assertAllEqual(z[0], x)

    # Test preprocessing with an input of rank 2.
    x = tf.random.uniform((3, n), dtype=tf.float64)
    z = ops._preprocess(x, axis=-1)
    self.assertEqual(z.shape.rank, 2)
    self.assertEqual(z.shape, x.shape)
    self.assertAllEqual(z, x)

    # Test preprocessing with an input of rank 2, preparing for axis 0.
    x = tf.random.uniform((3, n), dtype=tf.float64)
    z = ops._preprocess(x, axis=0)
    self.assertEqual(z.shape.rank, 2)
    self.assertEqual(z.shape, (x.shape[1], x.shape[0]))
    batch = 1
    self.assertAllEqual(z[batch], x[:, batch])

    # Test preprocessing with an input of rank > 2.
    shape = [4, 21, 7, 10]
    x = tf.random.uniform(shape, dtype=tf.float64)
    axis = 2
    n = shape.pop(axis)
    z = ops._preprocess(x, axis=axis)
    self.assertEqual(z.shape.rank, 2)
    self.assertEqual(z.shape, (np.prod(shape), n))

  def test_postprocess(self):
    """Tests that _postprocess is the inverse of _preprocess."""
    shape = (4, 21, 7, 10)
    for i in range(1, len(shape)):
      x = tf.random.uniform(shape[:i])
      for axis in range(x.shape.rank):
        z = ops._postprocess(ops._preprocess(x, axis), x.shape, axis)
        self.assertAllEqual(x, z)

  def test_softsort(self):
    # Tests that the values are sorted (epsilon being small enough).
    x = tf.constant([3, 4, 1, 5, 2], dtype=tf.float32)
    eps = 1e-3
    sinkhorn_threshold = 1e-3
    values = ops.softsort(x, direction='ASCENDING',
                          epsilon=eps, sinkhorn_threshold=sinkhorn_threshold)
    self.assertEqual(values.shape, x.shape)
    self.assertAllGreater(np.diff(values), 0.0)

    # Since epsilon is not very small, we cannot expect to retrieve the sorted
    # values with high precision.
    tolerance = 1e-1
    self.assertAllClose(tf.sort(x), values, tolerance, tolerance)

    # Test descending sort.
    direction = 'DESCENDING'
    values = ops.softsort(x, direction=direction,
                          epsilon=eps, sinkhorn_threshold=sinkhorn_threshold)
    self.assertEqual(values.shape, x.shape)
    self.assertAllLess(np.diff(values), 0.0)
    self.assertAllClose(
        tf.sort(x, direction=direction), values, tolerance, tolerance)

  @parameterized.named_parameters(
      ('ascending_axis_0', 0, 'ASCENDING'),
      ('ascending_axis_1', 1, 'ASCENDING'),
      ('ascending_axis_2', 2, 'ASCENDING'),
      ('descending_axis_0', 0, 'DESCENDING'),
      ('descending_axis_1', 1, 'DESCENDING'),
      ('descending_axis_2', 2, 'DESCENDING'))
  def test_softranks(self, axis, direction):
    """Tests ops.softranks for a given shape, axis and direction."""
    shape = tf.TensorShape((3, 8, 6))
    n = shape[axis]
    p = int(np.prod(shape) / shape[axis])

    # Build a rank-2 tensor of target ranks. Those targets are zero based.
    target = tf.constant(
        [np.random.permutation(n) for _ in range(p)], dtype=tf.float32)

    # Turn it into a tensor of the desired shape.
    target = ops._postprocess(target, shape, axis)

    # Apply a monotonic transformation to turn ranks into values.
    sign = 2 * float(direction == 'ASCENDING') - 1
    x = sign * (1.2 * target - 0.4)

    # The softranks of x along the axis should be close to the target.
    eps = 1e-3
    sinkhorn_threshold = 1e-3
    tolerance = 0.5
    for zero_based in [False, True]:
      ranks = ops.softranks(
          x, direction=direction, axis=axis, zero_based=zero_based,
          epsilon=eps, sinkhorn_threshold=sinkhorn_threshold)
      targets = target + 1 if not zero_based else target
      self.assertAllClose(ranks, targets, tolerance, tolerance)


if __name__ == '__main__':
  tf.enable_v2_behavior()
  tf.test.main()
soft_sort/requirements.txt
@@ -0,0 +1,3 @@
gin-config>=0.1.4
numpy>=1.16.2
tensorflow>=2.0.0-beta1
soft_sort/run.sh
@@ -0,0 +1,24 @@
#!/bin/bash
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -e
set -x

virtualenv -p python3 .
source ./bin/activate

pip install -r soft_sort/requirements.txt
python -m soft_sort.ops_test
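
The relative paths suggest the script is meant to be run from the repository
root (e.g. `bash soft_sort/run.sh`), so that `soft_sort/requirements.txt`
resolves and `soft_sort.ops_test` is importable as a module.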