Strange behavior of jax.lax.scan #24926

Closed. Answered by onnoeberhard.
onnoeberhard asked this question in Q&A

Okay, it seems that the error is indeed caused by summing many small floating-point numbers. In a sequential running sum, each small term is rounded against an ever-larger accumulator, so low-order bits are lost at every step. Replacing the naive addition with Kahan (compensated) summation fixes the problem. I can only assume that the last example I show above (jax.lax.scan(lambda c, y: (c + z, None), 0., zs)) works because JAX does some clever compilation that turns it into something other than a sequential running sum.
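For illustration, here is a minimal sketch contrasting a naive running sum with Kahan summation inside jax.lax.scan. The function names naive_scan_sum and kahan_scan_sum are hypothetical, and the input data is made up for the demonstration. The sketch assumes float32 inputs and that XLA does not reassociate floating-point operations (its default, with fast-math disabled). Otherwise the compensation term (t - s) - y could be simplified away.

```python
import jax
import jax.numpy as jnp


def naive_scan_sum(xs):
    # Plain running sum: the carry is the float32 accumulator.
    def step(c, x):
        return c + x, None

    total, _ = jax.lax.scan(step, jnp.float32(0.0), xs)
    return total


def kahan_scan_sum(xs):
    # Kahan (compensated) summation: the carry is (sum, compensation).
    def step(carry, x):
        s, comp = carry
        y = x - comp          # apply the correction saved from the previous step
        t = s + y             # low-order bits of y may be lost in this addition
        comp = (t - s) - y    # recover the lost part (zero in exact arithmetic)
        return (t, comp), None

    init = (jnp.float32(0.0), jnp.float32(0.0))
    (total, _), _ = jax.lax.scan(step, init, xs)
    return total


# One large term followed by many terms that are each below half an ulp of 1.0.
xs = jnp.concatenate([
    jnp.ones(1, dtype=jnp.float32),
    jnp.full((10_000,), 1e-8, dtype=jnp.float32),
])

print(float(naive_scan_sum(xs)))  # stays at 1.0: every 1e-8 is rounded away
print(float(kahan_scan_sum(xs)))  # close to the true sum, about 1.0001
```

In the naive version, 1.0 + 1e-8 rounds back to 1.0 in float32 at every step, so the small terms vanish entirely. The Kahan version carries the rounding error forward in comp until it is large enough to register in the sum.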

Replies: 1 comment

Answer selected by onnoeberhard
Category: Q&A
Labels: none
Participants: 1