Skip to content

Commit

Permalink
fix calling, add to README, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
joshday committed Aug 2, 2018
1 parent d693d56 commit 79fc414
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 1 deletion.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ value(p, x[1], s[1]) # evaluate on element, scaled by scalar
value(p, x, s[1]) # evaluate on array, scaled by scalar
value(p, x, s) # evaluate on array, element-wise scaling

# value via calling the Penalty object
p = L1Penalty()
p([1,2,3])

# derivatives and gradients
deriv(p, x[1]) # derivative
deriv(p, x[1], s[1]) # scaled derivative
Expand Down
2 changes: 1 addition & 1 deletion src/PenaltyFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ include("elementpenalty.jl")
include("arraypenalty.jl")

# Make Penalties Callable
for T in filter(isconcretetype, union(subtypes(ElementPenalty),
for T in filter(!isabstracttype, union(subtypes(ElementPenalty),
subtypes(ProxableElementPenalty),
subtypes(ArrayPenalty)))
@eval (pen::$T)(args...) = value(pen, args...)
Expand Down
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ end
@test r(@inferred(deriv(p, θ))) r(v2)
@test r.(value.(Ref(p), fill(θ, 5))) r.(fill(v1, 5))
@test r.(deriv.(Ref(p), fill(θ, 5))) r.(fill(v2, 5))
@test value(p, θ) == p(θ)
@test value(p, θ, s) == p(θ, s)
if isa(p, P.ProxableElementPenalty)
@test r(@inferred(prox(p, θ, s))) r(v3)
@test r.(prox.(Ref(p), fill(θ, 5), Ref(s))) r.(fill(v3, 5))
Expand Down Expand Up @@ -195,6 +197,7 @@ end
s = .05
# FIXME: @inference broken. seems like a type instability
@test value(p, Θ) sum(svd(Θ).S)
@test value(p, Θ) == p(Θ)
@test value(p, Θ, s) s * sum(svd(Θ).S)
prox!(p, Θ, s)
end
Expand All @@ -203,6 +206,7 @@ end
Θ = randn(10)
s = .05
@test @inferred(value(p, Θ)) norm(Θ)
@test value(p, Θ) == p(Θ)
prox!(p, Θ, s)

Θ = .01 * ones(10)
Expand All @@ -212,6 +216,7 @@ end
C = randn(5, 10)
p = MahalanobisPenalty(C)
θ = rand(10)
@test value(p, θ) == p(θ)
s = .05
@test @inferred(value(p, θ)) 0.5 * dot(C * θ, C * θ)
prox!(p, θ, s)
Expand All @@ -220,6 +225,7 @@ end
p = GroupLassoPenalty()
s = scaled(p, .1)
Θ = randn(10)
@test value(p, Θ) == p(Θ)
@test @inferred(value(p, Θ, .1)) value(s, Θ)

Θ2 = copy(Θ)
Expand Down

0 comments on commit 79fc414

Please sign in to comment.