From a362f9e97c96815bd61732b5265d288400e0e0a5 Mon Sep 17 00:00:00 2001
From: Stephen Becker
Date: Thu, 24 Oct 2013 16:53:20 -0400
Subject: [PATCH] Modifying smooth_logsumexp.m to allow scaling and to be numerically stable against over-flow errors.

---
 smooth_logsumexp.m | 30 ++++++++++++++++++++++++------
 1 file changed, 24 insertions(+), 6 deletions(-)

diff --git a/smooth_logsumexp.m b/smooth_logsumexp.m
index ed2d716..693fcbc 100644
--- a/smooth_logsumexp.m
+++ b/smooth_logsumexp.m
@@ -1,19 +1,37 @@
-function op = smooth_logsumexp()
+function op = smooth_logsumexp(sigma)
 % SMOOTH_LOGSUMEXP The function log(sum(exp(x)))
 % returns a smooth function to calculate
 % log( sum( exp(x) ) )
 %
+% SMOOTH_LOGSUMEXP( SIGMA ) is a scaled version
+% that calculates sigma*log(sum(exp(x/sigma))), for sigma > 0.
+% As sigma --> 0, this becomes a good approximation
+% of max(x).
+% The Lipschitz constant of the gradient is 1/sigma.
+% By default, sigma = 1.
+%
 % For a fancier version (with offsets),
-% see also smooth_LogLLogistic.m
+% see also smooth_logLLogistic.m
+
+if nargin < 1 || isempty(sigma), sigma = 1; end
+op = @(x)smooth_logsumexp_impl(x,sigma);
 
-op = @smooth_logsumexp_impl;
+function [ v, g ] = smooth_logsumexp_impl( x, sigma )
 
-function [ v, g ] = smooth_logsumexp_impl( x )
-expx = exp(x);
+% Even for moderate values of x/sigma, exp(x/sigma)
+% will overflow before we have a chance to take
+% its logarithm. So we subtract off the max value
+% and treat it separately:
+
+c = max(x(:)); % global max: c must be a scalar even when x is a matrix
+expx = exp((x-c)/sigma);
 sum_expx = sum(expx(:));
-v = log(sum_expx);
+v = sigma*log(sum_expx) + c;
+
 if nargout > 1,
     g = expx ./ sum_expx;
+    % (the factor of e^{-c/sigma} cancels from both the numerator
+    % and denominator)
 end
 
 % TFOCS v1.3 by Stephen Becker, Emmanuel Candes, and Michael Grant.