Skip to content

Commit

Permalink
ENH: stats: add random variable infrastructure
Browse files Browse the repository at this point in the history
  • Loading branch information
mdhaber committed Sep 24, 2023
1 parent c66a646 commit 07b7286
Show file tree
Hide file tree
Showing 10 changed files with 4,254 additions and 56 deletions.
6 changes: 0 additions & 6 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,6 @@ commands:
command: |
echo $(git log -1 --pretty=%B) | tee gitlog.txt
echo ${CI_PULL_REQUEST//*pull\//} | tee merge.txt
if [[ $(cat merge.txt) != "" ]]; then
echo "Merging $(cat merge.txt)";
git remote add upstream https://github.com/scipy/scipy.git;
git pull --ff-only upstream "refs/pull/$(cat merge.txt)/merge";
git fetch upstream main;
fi
jobs:
# Build SciPy from source
Expand Down
155 changes: 115 additions & 40 deletions scipy/optimize/_zeros_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,8 +1450,8 @@ def _bracket_root_iv(func, a, b, min, max, factor, args, maxiter):
if not maxiter == maxiter_int or maxiter < 0:
raise ValueError(message)

if not np.all((min <= a) & (a < b) & (b <= max)):
raise ValueError('`min <= a < b <= max` must be True (elementwise).')
# if not np.all((min <= a) & (a < b) & (b <= max)):
# raise ValueError('`min <= a < b <= max` must be True (elementwise).')

return func, a, b, min, max, factor, args, maxiter

Expand Down Expand Up @@ -1486,9 +1486,11 @@ def _bracket_root(func, a, b=None, *, min=None, max=None, factor=None,
args : tuple, optional
Additional positional arguments to be passed to `func`. Must be arrays
broadcastable with `a`, `b`, `min`, and `max`. If the callable to be
differentiated requires arguments that are not broadcastable with the
other arrays, wrap that callable with `func` such that `func` accepts
bracketed requires arguments that are not broadcastable with these
arrays, wrap that callable with `func` such that `func` accepts
only `x` and broadcastable arrays.
maxiter : int, optional
The maximum number of iterations of the algorithm to perform.
Returns
-------
Expand All @@ -1499,46 +1501,52 @@ def _bracket_root(func, a, b=None, *, min=None, max=None, factor=None,
arrays of the same shape.
xl, xr : float
The lower and upper ends of the bracket.
The lower and upper ends of the bracket, if the algorithm
terminated successfully.
fl, fr : float
The function value at the lower and upper ends of the bracket.
nfev : int
The number of times the function was called to find the root.
The number of function evaluations required to find the bracket.
This is distinct from the number of times `func` is *called*
because the function may evaluated at multiple points in a single
call.
nit : int
The number of iterations of Chandrupatla's algorithm performed.
The number of iterations of the algorithm that were performed.
status : int
An integer representing the exit status of the algorithm.
``0`` : The algorithm produced a valid bracket.
``-1`` : The bracket expanded to the allowable limits without finding a bracket.
``-2`` : The maximum number of iterations was reached.
``-3`` : A non-finite value was encountered.
``-4`` : Iteration was terminated by `callback`.
``1`` : The algorithm is proceeding normally (in `callback` only).
``2`` : A bracket was found in the opposite search direction (in `callback` only).
- ``0`` : The algorithm produced a valid bracket.
- ``-1`` : The bracket expanded to the allowable limits without finding a bracket.
- ``-2`` : The maximum number of iterations was reached.
- ``-3`` : A non-finite value was encountered.
- ``-4`` : Iteration was terminated by `callback`.
- ``1`` : The algorithm is proceeding normally (in `callback` only).
- ``2`` : A bracket was found in the opposite search direction (in `callback` only).
success : bool
``True`` when the algorithm terminated successfully (status ``0``).
Notes
-----
-----s
This function generalizes an algorithm found in pieces throughout
`scipy.stats`. The strategy is to iteratively grow the bracket `(l, r)`
until ``func(l) < 0 < func(r)``.
until ``func(l) < 0 < func(r)``. The bracket grows to the left as follows.
- If `min` is not provided, the distance between `b` and `l` is iteratively
increased by `factor`.
- If `min` is provided, the distance between `min` and `l` is iteratively
decreased by `factor`. Note that this *increases* the bracket size.
decreased by `factor`. Note that this also *increases* the bracket size.
Growth of the bracket to the right is analogous.
Growth of the bracket in one direction stops when the endpoint is no longer
finite, the function value at the endpoint is no longer finite, or the
finite, the function value at the endpoint is no longer finite, the
endpoint reaches its limiting value (`min` or `max`). Iteration terminates
when the bracket stops growing in both directions, the bracket surrounds
the root, or a root is found (accidentally).
If multiple brackets are found, only the leftmost one is returned.
If two brackets are found - that is, a bracket is found on both sides in
the same iteration, the smaller of the two is returned.
If roots of the function are found, both `l` and `r` are set to the
leftmost root.
Expand All @@ -1554,14 +1562,27 @@ def _bracket_root(func, a, b=None, *, min=None, max=None, factor=None,

xs = (a, b)
temp = _scalar_optimization_initialize(func, xs, args)
xs, fs, args, shape, dtype = temp
xs, fs, args, shape, dtype = temp # line split for PEP8

# The approach is to treat the left and right searches as though they were
# (almost) totally independent one-sided bracket searches. (The interaction
# is considered when checking for termination and preparing the result
# object.)
# `x` is the "moving" end of the bracket
x = np.concatenate(xs)
f = np.concatenate(fs)
n = len(x) // 2

# `x_last` is the previous location of the moving end of the bracket. If
# the signs of `f` and `f_last` are different, `x` and `x_last` form a
# bracket.
x_last = np.concatenate((x[n:], x[:n]))
f_last = np.concatenate((f[n:], f[:n]))
# `x0` is the "fixed" end of the bracket.
x0 = x_last
# We don't need to retain the corresponding function value, since the
# fixed end of the bracket is only needed to compute the new value of the
# moving end; it is never returned.

min = np.broadcast_to(min, shape).astype(dtype, copy=False).ravel()
max = np.broadcast_to(max, shape).astype(dtype, copy=False).ravel()
Expand All @@ -1572,13 +1593,21 @@ def _bracket_root(func, a, b=None, *, min=None, max=None, factor=None,

active = np.arange(2*n)
args = [np.concatenate((arg, arg)) for arg in args]

# This is needed due to inner workings of `_scalar_optimization_loop`.
# We're abusing it a tiny bit.
shape = shape + (2,)

# `d` is for "distance".
# For searches without a limit, the distance between the fixed end of the
# bracket `x0` and the moving end `x` will grow by `factor` each iteration.
# For searches with a limit, the distance between the `limit` and moving
# end of the bracket `x` will shrink by `factor` each iteration.
i = np.isinf(limit)
ni = ~i
d = np.zeros_like(x)
d[i] = (x[i] - x0[i]) * factor[i]
d[ni] = (limit[ni] - x[ni]) / factor[ni]
d[i] = x[i] - x0[i]
d[ni] = limit[ni] - x[ni]

status = np.full_like(x, _EINPROGRESS, dtype=int) # in progress
nit, nfev = 0, 1 # one function evaluation per side performed above
Expand All @@ -1593,59 +1622,80 @@ def _bracket_root(func, a, b=None, *, min=None, max=None, factor=None,
('x_last', 'x_last'), ('f_last', 'f_last')]

def pre_func_eval(work):
work.x_last = work.x
work.f_last = work.f
i = np.isinf(work.limit)
# Initialize moving end of bracket
x = np.zeros_like(work.x)

x[i] = work.x0[i] + work.d[i]
# Unlimited brackets grow by `factor` by increasing distance from fixed
# end to moving end.
i = np.isinf(work.limit) # indices of unlimited brackets
work.d[i] *= work.factor[i]
x[i] = work.x0[i] + work.d[i]

ni = ~i
x[ni] = work.limit[ni] - work.d[ni]
# Limited brackets grow by decreasing the distance from the limit to
# the moving end.
ni = ~i # indices of limited brackets
work.d[ni] /= work.factor[ni]
x[ni] = work.limit[ni] - work.d[ni]

return x

def post_func_eval(x, f, work):
# Keep track of the previous location of the moving end so that we can
# return a narrower bracket. (The alternative is to remember the
# original fixed end, but then the bracket would be wider than needed.)
work.x_last = work.x
work.f_last = work.f
work.x = x
work.f = f

def check_termination(work):

stop = np.zeros_like(work.x, dtype=bool)

# Condition 1: a valid bracket (or the root itself) has been found
sf = np.sign(work.f)
sf_last = np.sign(work.f_last)

i = (sf_last == -sf) | (sf_last == 0) | (sf == 0)
work.status[i] = _ECONVERGED
stop[i] = True

# If we just found a bracket on the right, we can stop looking on the
# left, and vice-versa. This is a bit tricky.
# Condition 2: the other side's search found a valid bracket.
# (If we just found a bracket with the rightward search, we can stop
# the leftward search, and vice-versa.)
# To do this, we need to set the status of the other side's search;
# this is tricky because `work.status` contains only the *active*
# elements, so we don't immediately know the index of the element we
# need to set - or even if it's still there. (That search may have
# terminated already, e.g. by reaching its `limit`.)
# To facilitate this, `work.active` contains a unit integer index of
# each search. Index `k` (`k < n)` and `k + n` correspond with a
# leftward and rightward search, respectively. Elements are removed
# from `work.active` just as they are removed from `work.status`, so
# we use `work.active` to help find the right location in
# `work.status`.
# Get the integer indices of the elements that can also stop
also_stop = (work.active[i] + work.n) % (2*work.n)
# Check whether they are still active.
# To start, we need to find out whether they would be in `work.active`
# if they are indeed there.
# To start, we need to find out where in `work.active` they would
# appear if they are indeed there.
j = np.searchsorted(work.active, also_stop)
# If the location exceeds the length of the `work.active`, they are
# not there.
j = j[j < len(work.active)]
# Check whether they are still there.
j = j[also_stop == work.active[j]]
# Now convert these to boolean indices
# Now convert these to boolean indices to use with `work.status`.
i = np.zeros_like(stop)
i[j] = True # boolean indices of elements that can also stop
i = i & ~stop
work.status[i] = _ESTOPONESIDE
stop[i] = True

# Condition 3: moving end of bracket reaches limit
i = (work.x == work.limit) & ~stop
work.status[i] = _ELIMITS
stop[i] = True

# Condition 4: non-finite value encountered
i = ~(np.isfinite(work.x) & np.isfinite(work.f)) & ~stop
work.status[i] = _EVALUEERR
stop[i] = True
Expand All @@ -1658,6 +1708,14 @@ def post_termination_check(work):
def customize_result(res, shape):
n = len(res['x']) // 2

# Because we treat the two one-sided searches as though they were
# independent, what we keep track of in `work` and what we want to
# return in `res` look quite different. Combine the results from the
# two one-sided searches before reporting the results to the user.
# - "a" refers to the leftward search (the moving end started at `a`)
# - "b" refers to the rightward search (the moving end started at `b`)
# - "l" refers to the left end of the bracket (closer to -oo)
# - "r" refers to the right end of the bracket (closer to +oo)
xal = res['x'][:n]
xar = res['x_last'][:n]
xbl = res['x_last'][n:]
Expand All @@ -1668,6 +1726,25 @@ def customize_result(res, shape):
fbl = res['f_last'][n:]
fbr = res['f'][n:]

# Initialize the brackets and corresponding function values to return
# to the user. Brackets may not be valid (e.g. there is no root,
# there weren't enough iterations, NaN encountered), but we still need
# to return something. One option would be all NaNs, but what I've
# chosen here is the left- and right-most points at which the function
# has been evaluated. This gives the user some information about what
# interval of the real line has been searched and shows that there is
# no sign change between the two ends.
xl = xal.copy()
fl = fal.copy()
xr = xbr.copy()
fr = fbr.copy()

# `status` indicates whether the bracket is valid or not. If so,
# we want to adjust the bracket we return to be the narrowest possible
# given the points at which we evaluated the function.
# For example if bracket "a" is valid and smaller than bracket "b" OR
# if bracket "a" is valid and bracket "b" is not valid, we want to
# return bracket "a" (and vice versa).
sa = res['status'][:n]
sb = res['status'][n:]

Expand All @@ -1677,16 +1754,12 @@ def customize_result(res, shape):
i1 = ((da <= db) & (sa == 0)) | ((sa == 0) & (sb != 0))
i2 = ((db <= da) & (sb == 0)) | ((sb == 0) & (sa != 0))

xl = xal.copy()
fl = fal.copy()
xr = xbr.copy()
fr = fbr.copy()

xr[i1] = xar[i1]
fr[i1] = far[i1]
xl[i2] = xbl[i2]
fl[i2] = fbl[i2]

# Finish assembling the result object
res['xl'] = xl
res['xr'] = xr
res['fl'] = fl
Expand All @@ -1698,10 +1771,12 @@ def customize_result(res, shape):
# report the status from one side only.
res['status'] = np.choose(sa == 0, (sb, sa))
res['success'] = (res['status'] == 0)

del res['x']
del res['f']
del res['x_last']
del res['f_last']

return shape[:-1]

return _scalar_optimization_loop(work, callback, shape,
Expand Down
14 changes: 7 additions & 7 deletions scipy/optimize/tests/test_zeros.py
Original file line number Diff line number Diff line change
Expand Up @@ -1846,13 +1846,13 @@ def test_input_validation(self):
with pytest.raises(ValueError, match=message):
zeros._bracket_root(lambda x: x, -4, 4, factor=0.5)

message = '`min <= a < b <= max` must be True'
with pytest.raises(ValueError, match=message):
zeros._bracket_root(lambda x: x, 4, -4)
with pytest.raises(ValueError, match=message):
zeros._bracket_root(lambda x: x, -4, 4, max=np.nan)
with pytest.raises(ValueError, match=message):
zeros._bracket_root(lambda x: x, -4, 4, min=10)
# message = '`min <= a < b <= max` must be True'
# with pytest.raises(ValueError, match=message):
# zeros._bracket_root(lambda x: x, 4, -4)
# with pytest.raises(ValueError, match=message):
# zeros._bracket_root(lambda x: x, -4, 4, max=np.nan)
# with pytest.raises(ValueError, match=message):
# zeros._bracket_root(lambda x: x, -4, 4, min=10)

message = "shape mismatch: objects cannot be broadcast"
# raised by `np.broadcast, but the traceback is readable IMO
Expand Down
17 changes: 17 additions & 0 deletions scipy/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,20 @@
Probability distributions
=========================
Random Variables
----------------
.. autosummary::
:toctree: generated/
ContinuousDistribution
ShiftedScaledDistribution
CircularDistribution
LogUniform
Normal
ShiftedScaledNormal
Each univariate distribution is an instance of a subclass of `rv_continuous`
(`rv_discrete` for discrete distributions):
Expand Down Expand Up @@ -627,6 +641,9 @@
from ._covariance import Covariance
from ._sensitivity_analysis import *
from ._survival import *
from ._new_distributions import (
LogUniform, Normal, ShiftedScaledNormal, ShiftedScaledDistribution,
CircularDistribution, ContinuousDistribution)

# Deprecated namespaces, to be removed in v2.0.0
from . import (
Expand Down
Loading

0 comments on commit 07b7286

Please sign in to comment.