I am playing around with Espresso and XDiff, because I think it would be cool for some future version of LossFunctions.jl to generate all the losses, their derivatives, and their properties just based on defining the function (like @makeloss L(x) = 1/2 * x^2).
Taking the example of a simple, popular version of the least squares loss, L(x) = 1/2 * x^2, I found that rdiff changes the operation * to the broadcasted version .*, which seems to prevent constant folding.
Yes, this is expected. In both Julia 0.5 and 0.6, the operator .* is allowed on scalars and behaves the same way as *. The reason to use .* in the corresponding derivative rule is to support broadcasting for tensors. In short, when rdiff can't find a tensor differentiation rule for some operation (e.g. log()), it looks for the corresponding rule for scalars instead and infers a rule for tensors from it. Having .* instead of * in scalar rules is the simplest way to make scalar rules work in both cases.
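To illustrate the point about .* on scalars, here is a minimal sketch in plain Julia that does not use XDiff or Espresso at all. The variable names are made up for the example. It only shows that the dotted operator gives the same result as * on a scalar while also applying elementwise to an array.

```julia
# Scalar case: `.*` and `*` give the same value.
x = 3.0
scalar_plain = 0.5 * (2 * x)
scalar_broadcast = 0.5 .* (2 * x)

# Array case: the dotted form applies the same expression to every element.
v = [1.0, 2.0, 3.0]
elementwise = 0.5 .* (2 .* v)

println(scalar_plain == scalar_broadcast)
println(elementwise)
```

This is why a scalar rule written with .* can be reused unchanged when the argument turns out to be a tensor.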
Note that this may totally change once Julia 0.6 is out and XDiff is adapted to support the new broadcasting rules.
It doesn't mean your example shouldn't work in the current version, though. simplify (as well as many other transformations here) works via a set of extensible rules, and you can easily define new ones:
julia> simplify(:(0.5 .* (2x)))
:(0.5 .* (2x))
julia> @simple_rule (0.5 .* (2 * _x)) _x # _x is a placeholder that matches anything
julia> simplify(:(0.5 .* (2x)))
:x
At some point I will add this rule to Espresso, but right now I'm a bit suspicious about .* and * since they play important (and different!) roles in the differentiation process.
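For readers who want to see what such a rewrite amounts to, here is a hand-written sketch in plain Julia. It is not how Espresso implements @simple_rule. The function name fold_half_double is hypothetical, and the sketch assumes that 2x parses as a call to * with arguments 2 and x, and that .* appears as the function symbol of a call expression.

```julia
# Hypothetical helper: rewrite `0.5 * (2 * e)` or `0.5 .* (2 * e)` to `e`.
# Any expression that does not match the pattern is returned unchanged.
function fold_half_double(ex)
    if isa(ex, Expr) && ex.head == :call && length(ex.args) == 3 &&
       ex.args[1] in (:*, :.*) && ex.args[2] == 0.5
        inner = ex.args[3]
        if isa(inner, Expr) && inner.head == :call && length(inner.args) == 3 &&
           inner.args[1] == :* && inner.args[2] == 2
            return inner.args[3]   # the part matched by the placeholder
        end
    end
    return ex
end

println(fold_half_double(:(0.5 .* (2x))))
println(fold_half_double(:(a + b)))
```

Treating * and .* as interchangeable here is only safe because the constants are scalars; whether that holds for a general rule is exactly the concern raised above.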
Is this expected behaviour? How would you suggest to deal with such cases?