Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add keras.ops.searchsorted #19922

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open

Add keras.ops.searchsorted #19922

wants to merge 7 commits into from

Conversation

LarsKue
Copy link
Contributor

@LarsKue LarsKue commented Jun 26, 2024

This is commonly used for spline transformations.

@codecov-commenter
Copy link

codecov-commenter commented Jun 26, 2024

Codecov Report

Attention: Patch coverage is 82.75862% with 5 lines in your changes missing coverage. Please review.

Project coverage is 78.96%. Comparing base (558d38c) to head (3f7fb5e).
Report is 9 commits behind head on master.

Files Patch % Lines
keras/src/ops/numpy.py 73.33% 3 Missing and 1 partial ⚠️
keras/api/_tf_keras/keras/ops/__init__.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19922      +/-   ##
==========================================
- Coverage   79.02%   78.96%   -0.06%     
==========================================
  Files         499      499              
  Lines       46436    46523      +87     
  Branches     8548     8561      +13     
==========================================
+ Hits        36695    36738      +43     
- Misses       8015     8052      +37     
- Partials     1726     1733       +7     
Flag Coverage Δ
keras 78.82% <82.75%> (-0.06%) ⬇️
keras-jax 62.41% <58.62%> (-0.01%) ⬇️
keras-numpy 57.22% <65.51%> (+<0.01%) ⬆️
keras-tensorflow 63.60% <62.06%> (-0.04%) ⬇️
keras-torch 62.37% <62.06%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@LarsKue LarsKue marked this pull request as draft June 26, 2024 13:16
@LarsKue
Copy link
Contributor Author

LarsKue commented Jun 26, 2024

Converted to draft because I will add a test

@LarsKue LarsKue marked this pull request as ready for review June 26, 2024 13:24
@LarsKue LarsKue marked this pull request as draft June 26, 2024 13:41
@LarsKue
Copy link
Contributor Author

LarsKue commented Jun 26, 2024

I opted to try to maximize support for N-D searchsorted (because this is my use-case). However, numpy does not support it. JAX supports it by vmapping, which I implemented.

If you have better suggestions on how we can support N-D searchsorted, I would be happy to implement them.

The tests also need to be updated, still, because self.assertAllEqual does not support multi-dimensional tensors.

Nevertheless, I am marking this as ready for review now, so that you can give feedback. Thank you.

@LarsKue LarsKue marked this pull request as ready for review June 26, 2024 14:42
@fchollet
Copy link
Member

Thanks for the PR!

I opted to try to maximize support for N-D searchsorted (because this is my use-case). However, numpy does not support it. JAX supports it by vmapping, which I implemented.

Since we support vmapping APIs, we could simply not implement N-D support for this op. We should try to stay as close to NumPy as possible in order to minimize user surprise.

A TF test seems to be failing:

>           assert knp.all(knp.searchsorted(a, v) == expected)

keras/src/ops/numpy_test.py:3990: 
...
E     tensorflow.python.framework.errors_impl.InvalidArgumentError:
cannot compute Equal as input #1(zero-based) was expected to be a int32 tensor but is a int64 tensor

@LarsKue
Copy link
Contributor Author

LarsKue commented Jun 27, 2024

@fchollet Thank you for the review!

we could simply not implement N-D support for this op.

You make a good point. In that case, should we raise an error if the user passes an N-D sorted_sequence, or let the backend handle it if it is incompatible? Raising would make the function truly agnostic, but prevent using built-in N-D functionality for backends that do support it.

If we raise an error: Should we do this in keras.ops or in the respective keras.backend functions?

A TF test seems to be failing

We can drop the part that is failing if we only support 1-D.

@@ -3966,6 +3966,13 @@ def test_round(self):
self.assertAllClose(knp.round(x, decimals=-1), np.round(x, decimals=-1))
self.assertAllClose(knp.Round(decimals=-1)(x), np.round(x, decimals=-1))

def test_searchsorted(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also add a test for static shape inference (on KerasTensors).

@fchollet
Copy link
Member

In that case, should we raise an error if the user passes an N-D sorted_sequence, or let the backend handle it if it is incompatible?

Better to do it in each backend function I think!

@gbaned gbaned added this to Assigned Reviewer in PR Queue via automation Jun 28, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
PR Queue
Assigned Reviewer
Development

Successfully merging this pull request may close these issues.

None yet

4 participants