Skip to content

Commit 2114f76

Browse files
authored
Merge branch 'devitocodes:main' into master
2 parents 8913c3e + ff62667 commit 2114f76

File tree

4 files changed

+55
-5
lines changed

4 files changed

+55
-5
lines changed

devito/finite_differences/derivative.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,27 @@ def _eval_at(self, func):
503503
if self.expr.staggered == func.staggered and self.expr.is_Function:
504504
return self
505505

506+
# Check if x0's keys come from a DerivedDimension
506507
x0 = func.indices_ref.getters
508+
psubs = {}
509+
nx0 = x0.copy()
510+
for d, d0 in x0.items():
511+
if d in self.dims:
512+
# d is a valid Derivative dimension
513+
continue
514+
for sd in self.dims:
515+
if sd in d._defines:
516+
# x0 key is a DerivedDimension of the derivative dimension
517+
# e.g f.dx(x0={ix: ix + h_x/2}) for a subdomain
518+
# Set x0 to the derivative dimension and add a substitution
519+
# to the parent
520+
# e.g f.dx(x0={x: x + h_x/2}).subs({x: ix})
521+
psubs[sd] = d
522+
nx0[sd] = nx0.pop(d)._subs(d, sd)
523+
rkw = {'x0': nx0}
524+
if psubs:
525+
rkw['subs'] = (psubs,)
526+
507527
if self.expr.is_Add:
508528
# If `expr` has both staggered and non-staggered terms such as
509529
# `(u(x + h_x/2) + v(x)).dx` then we exploit linearity of FD to split
@@ -512,19 +532,19 @@ def _eval_at(self, func):
512532
mapper = as_mapper(self.expr._args_diff, lambda i: i.staggered)
513533
args = [self.expr.func(*v) for v in mapper.values()]
514534
args.extend([a for a in self.expr.args if a not in self.expr._args_diff])
515-
args = [self._rebuild(expr=a, x0=x0) for a in args]
535+
args = [self._rebuild(a, **rkw) for a in args]
516536
return self.expr.func(*args)
517537
elif self.expr.is_Mul:
518538
# For Mul, We treat the basic case `u(x + h_x/2) * v(x) which is what appear
519539
# in most equation with div(a * u) for example. The expression is re-centered
520540
# at the highest priority index (see _gather_for_diff) to compute the
521541
# derivative at x0.
522-
return self._rebuild(self.expr._gather_for_diff, x0=x0)
542+
return self._rebuild(self.expr._gather_for_diff, **rkw)
523543
else:
524544
# For every other cases, that has more functions or more complexe arithmetic,
525545
# there is not actual way to decide what to do so it’s as safe to use
526546
# the expression as is.
527-
return self._rebuild(x0=x0)
547+
return self._rebuild(self.expr, **rkw)
528548

529549
def _evaluate(self, **kwargs):
530550
# Evaluate finite-difference.

devito/finite_differences/differentiable.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -750,7 +750,12 @@ def dimensions(self):
750750
return self._dimensions
751751

752752
def _evaluate(self, **kwargs):
753-
expr = self.expr._evaluate(**kwargs)
753+
try:
754+
expr = self.expr._evaluate(**kwargs)
755+
except AttributeError:
756+
# There are rare circumstances in which `self.expr` is a plain
757+
# SymPy object rather than an Evaluable
758+
expr = Evaluable._evaluate_maybe_nested(self.expr, **kwargs)
754759

755760
if not kwargs.get('expand', True):
756761
return self._rebuild(expr)
@@ -770,7 +775,10 @@ def free_symbols(self):
770775

771776

772777
class WeightsIndexed(Indexed):
773-
pass
778+
779+
@property
780+
def dimension(self):
781+
return self.function.dimension
774782

775783

776784
class Weights(Array):

tests/test_interpolation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,7 @@ def test_inject_subdomain_sinc(self):
10981098
['p_sr0rsr0xrsr0y'],
10991099
'p_sr0rsr0xrsr0y')
11001100

1101+
@pytest.mark.xfail(reason="OOB issue")
11011102
@pytest.mark.parallel(mode=4)
11021103
def test_interpolate_subdomain_mpi(self, mode):
11031104
"""

tests/test_subdomains.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1778,3 +1778,24 @@ def define(self, dimensions):
17781778
if grid.distributor.myrank == 0:
17791779
assert np.all(np.isclose(fdata[:], gdata[:, 2:-2, 2:-2]))
17801780
assert np.isclose(np.linalg.norm(fdata[:]), norm)
1781+
1782+
def test_mixed_domain_fd_staggered(self):
1783+
grid = Grid(shape=(20, 20, 20))
1784+
x = grid.dimensions[0]
1785+
1786+
class SD1(SubDomain):
1787+
name = 'sd1'
1788+
1789+
def define(self, dimensions):
1790+
x, y, z = dimensions
1791+
return {x: ('left', 2), y: y, z: z}
1792+
1793+
sd1 = SD1(grid=grid)
1794+
f = TimeFunction(name='f', grid=grid, time_order=1, space_order=2,
1795+
staggered=(x,))
1796+
g = TimeFunction(name='g', grid=sd1, time_order=1, space_order=2,
1797+
staggered=(x,))
1798+
1799+
eq = Eq(g, g + f.dx)
1800+
eqe = eq.evaluate
1801+
assert eqe.rhs == g + f.dx(x0=x).evaluate._subs(x, g.dimensions[1])

0 commit comments

Comments
 (0)