
Commit a6cb68d

Finish Copilot code
1 parent 53adf9a commit a6cb68d

10 files changed (+605, -291 lines)


pytensor/link/jax/dispatch/subtensor.py

Lines changed: 27 additions & 6 deletions
@@ -31,11 +31,18 @@
     """


+@jax_funcify.register(AdvancedSubtensor1)
+def jax_funcify_AdvancedSubtensor1(op, node, **kwargs):
+    def advanced_subtensor1(x, ilist):
+        return x[ilist]
+
+    return advanced_subtensor1
+
+
 @jax_funcify.register(Subtensor)
 @jax_funcify.register(AdvancedSubtensor)
-@jax_funcify.register(AdvancedSubtensor1)
 def jax_funcify_Subtensor(op, node, **kwargs):
-    idx_list = getattr(op, "idx_list", None)
+    idx_list = op.idx_list

     def subtensor(x, *ilists):
         indices = indices_from_subtensor(ilists, idx_list)
@@ -47,10 +54,24 @@ def subtensor(x, *ilists):
     return subtensor


-@jax_funcify.register(IncSubtensor)
 @jax_funcify.register(AdvancedIncSubtensor1)
+def jax_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
+    if getattr(op, "set_instead_of_inc", False):
+
+        def jax_fn(x, y, ilist):
+            return x.at[ilist].set(y)
+
+    else:
+
+        def jax_fn(x, y, ilist):
+            return x.at[ilist].add(y)
+
+    return jax_fn
+
+
+@jax_funcify.register(IncSubtensor)
 def jax_funcify_IncSubtensor(op, node, **kwargs):
-    idx_list = getattr(op, "idx_list", None)
+    idx_list = op.idx_list

     if getattr(op, "set_instead_of_inc", False):

@@ -77,8 +98,8 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):

 @jax_funcify.register(AdvancedIncSubtensor)
 def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
-    idx_list = getattr(op, "idx_list", None)
-
+    idx_list = op.idx_list
+
     if getattr(op, "set_instead_of_inc", False):

         def jax_fn(x, indices, y):
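
Note (editor's sketch, not part of the commit): the new AdvancedSubtensor1 and AdvancedIncSubtensor1 dispatchers above lean on jax.numpy's integer-array indexing and functional .at[] updates. A minimal illustration of those semantics, assuming only jax is installed:

# Editor's sketch (not from this commit): the jax.numpy semantics the new
# dispatchers rely on -- integer-array indexing and functional .at[] updates.
import jax.numpy as jnp

x = jnp.arange(5.0)            # [0., 1., 2., 3., 4.]
ilist = jnp.array([0, 2, 2])

taken = x[ilist]               # AdvancedSubtensor1: integer-array indexing -> [0., 2., 2.]

y = jnp.ones(3)
set_res = x.at[ilist].set(y)   # set_instead_of_inc=True  -> [1., 1., 1., 3., 4.]
add_res = x.at[ilist].add(y)   # set_instead_of_inc=False -> duplicates accumulate: [1., 1., 4., 3., 4.]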

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 12 additions & 11 deletions
@@ -20,7 +20,6 @@
 )
 from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
 from pytensor.tensor import TensorType
-from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
     AdvancedIncSubtensor1,
@@ -29,7 +28,7 @@
     IncSubtensor,
     Subtensor,
 )
-from pytensor.tensor.type_other import MakeSlice, NoneTypeT, SliceType
+from pytensor.tensor.type_other import MakeSlice


 def slice_new(self, start, stop, step):
@@ -239,15 +238,15 @@ def {function_name}({", ".join(input_names)}):
 @register_funcify_and_cache_key(AdvancedIncSubtensor)
 def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
     if isinstance(op, AdvancedSubtensor):
-        x, y, tensor_inputs = node.inputs[0], None, node.inputs[1:]
+        _, _, tensor_inputs = node.inputs[0], None, node.inputs[1:]
     else:
-        x, y, *tensor_inputs = node.inputs
+        _, _, *tensor_inputs = node.inputs

     # Reconstruct indexing information from idx_list and tensor inputs
     basic_idxs = []
     adv_idxs = []
     input_idx = 0
-
+
     for i, entry in enumerate(op.idx_list):
         if isinstance(entry, slice):
             # Basic slice index
@@ -256,12 +255,14 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
             # Advanced tensor index
             if input_idx < len(tensor_inputs):
                 idx_input = tensor_inputs[input_idx]
-                adv_idxs.append({
-                    "axis": i,
-                    "dtype": idx_input.type.dtype,
-                    "bcast": idx_input.type.broadcastable,
-                    "ndim": idx_input.type.ndim,
-                })
+                adv_idxs.append(
+                    {
+                        "axis": i,
+                        "dtype": idx_input.type.dtype,
+                        "bcast": idx_input.type.broadcastable,
+                        "ndim": idx_input.type.ndim,
+                    }
+                )
                 input_idx += 1

     # Special implementation for consecutive integer vector indices
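
Note (editor's sketch, not part of the commit): the loop above splits op.idx_list into basic and advanced indices; literal slice entries stay basic, while every other entry consumes the next tensor input in order. A stand-alone illustration of that walk, using hypothetical placeholder entries in place of PyTensor's real idx_list types:

# Editor's sketch (not from this commit), with hypothetical placeholder entries:
# slices stay basic, anything else consumes the next tensor input in order.
import numpy as np

idx_list = [slice(1, 3), "adv-index placeholder"]   # stand-in for something like x[1:3, [0, 2]]
tensor_inputs = [np.array([0, 2])]                  # advanced indices, in input order

basic_idxs, adv_idxs, input_idx = [], [], 0
for axis, entry in enumerate(idx_list):
    if isinstance(entry, slice):
        basic_idxs.append((axis, entry))            # basic slice index
    elif input_idx < len(tensor_inputs):
        idx = tensor_inputs[input_idx]
        adv_idxs.append({"axis": axis, "dtype": idx.dtype, "ndim": idx.ndim})
        input_idx += 1

# basic_idxs: [(0, slice(1, 3))]; adv_idxs: one entry for axis 1 with its dtype/ndim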

pytensor/link/pytorch/dispatch/subtensor.py

Lines changed: 7 additions & 5 deletions
@@ -9,7 +9,7 @@
     Subtensor,
     indices_from_subtensor,
 )
-from pytensor.tensor.type_other import MakeSlice, SliceType
+from pytensor.tensor.type_other import MakeSlice


 def check_negative_steps(indices):
@@ -63,8 +63,8 @@ def makeslice(start, stop, step):
 @pytorch_funcify.register(AdvancedSubtensor1)
 @pytorch_funcify.register(AdvancedSubtensor)
 def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
-    idx_list = getattr(op, "idx_list", None)
-
+    idx_list = op.idx_list
+
     def advsubtensor(x, *flattened_indices):
         indices = indices_from_subtensor(flattened_indices, idx_list)
         check_negative_steps(indices)
@@ -105,7 +105,7 @@ def inc_subtensor(x, y, *flattened_indices):
 @pytorch_funcify.register(AdvancedIncSubtensor)
 @pytorch_funcify.register(AdvancedIncSubtensor1)
 def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
-    idx_list = getattr(op, "idx_list", None)
+    idx_list = op.idx_list
     inplace = op.inplace
     ignore_duplicates = getattr(op, "ignore_duplicates", False)

@@ -139,7 +139,9 @@ def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices):

     else:
         # Check if we have slice indexing in idx_list
-        has_slice_indexing = any(isinstance(entry, slice) for entry in idx_list) if idx_list else False
+        has_slice_indexing = (
+            any(isinstance(entry, slice) for entry in idx_list) if idx_list else False
+        )
         if has_slice_indexing:
             raise NotImplementedError(
                 "IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"

pytensor/sparse/basic.py

Lines changed: 123 additions & 4 deletions
@@ -1317,8 +1317,6 @@ def perform(self, node, inputs, outputs):
         z[0] = y

     def grad(self, inputs, gout):
-        from pytensor.sparse.math import sp_sum
-
         (x, s) = inputs
         (gz,) = gout
         return [col_scale(gz, s), sp_sum(x * gz, axis=0)]
@@ -1368,8 +1366,6 @@ def perform(self, node, inputs, outputs):
         z[0] = scipy.sparse.csc_matrix((y_data, indices, indptr), (M, N))

     def grad(self, inputs, gout):
-        from pytensor.sparse.math import sp_sum
-
         (x, s) = inputs
         (gz,) = gout
         return [row_scale(gz, s), sp_sum(x * gz, axis=1)]
@@ -1435,6 +1431,126 @@ def row_scale(x, s):
     return col_scale(x.T, s).T


+class SpSum(Op):
+    """
+
+    WARNING: judgement call...
+    We are not using the structured in the comparison or hashing
+    because it doesn't change the perform method therefore, we
+    *do* want Sums with different structured values to be merged
+    by the merge optimization and this requires them to compare equal.
+    """
+
+    __props__ = ("axis",)
+
+    def __init__(self, axis=None, sparse_grad=True):
+        super().__init__()
+        self.axis = axis
+        self.structured = sparse_grad
+        if self.axis not in (None, 0, 1):
+            raise ValueError("Illegal value for self.axis.")
+
+    def make_node(self, x):
+        x = as_sparse_variable(x)
+        assert x.format in ("csr", "csc")
+
+        if self.axis is not None:
+            out_shape = (None,)
+        else:
+            out_shape = ()
+
+        z = TensorType(dtype=x.dtype, shape=out_shape)()
+        return Apply(self, [x], [z])
+
+    def perform(self, node, inputs, outputs):
+        (x,) = inputs
+        (z,) = outputs
+        if self.axis is None:
+            z[0] = np.asarray(x.sum())
+        else:
+            z[0] = np.asarray(x.sum(self.axis)).ravel()
+
+    def grad(self, inputs, gout):
+        (x,) = inputs
+        (gz,) = gout
+        if x.dtype not in continuous_dtypes:
+            return [x.zeros_like(dtype=config.floatX)]
+        if self.structured:
+            if self.axis is None:
+                r = gz * sp_ones_like(x)
+            elif self.axis == 0:
+                r = col_scale(sp_ones_like(x), gz)
+            elif self.axis == 1:
+                r = row_scale(sp_ones_like(x), gz)
+            else:
+                raise ValueError("Illegal value for self.axis.")
+        else:
+            o_format = x.format
+            x = dense_from_sparse(x)
+            if _is_sparse_variable(gz):
+                gz = dense_from_sparse(gz)
+            if self.axis is None:
+                r = ptb.second(x, gz)
+            else:
+                ones = ptb.ones_like(x)
+                if self.axis == 0:
+                    r = specify_broadcastable(gz.dimshuffle("x", 0), 0) * ones
+                elif self.axis == 1:
+                    r = specify_broadcastable(gz.dimshuffle(0, "x"), 1) * ones
+                else:
+                    raise ValueError("Illegal value for self.axis.")
+            r = SparseFromDense(o_format)(r)
+        return [r]
+
+    def infer_shape(self, fgraph, node, shapes):
+        r = None
+        if self.axis is None:
+            r = [()]
+        elif self.axis == 0:
+            r = [(shapes[0][1],)]
+        else:
+            r = [(shapes[0][0],)]
+        return r
+
+    def __str__(self):
+        return f"{self.__class__.__name__}{{axis={self.axis}}}"
+
+
+def sp_sum(x, axis=None, sparse_grad=False):
+    """
+    Calculate the sum of a sparse matrix along the specified axis.
+
+    It operates a reduction along the specified axis. When `axis` is `None`,
+    it is applied along all axes.
+
+    Parameters
+    ----------
+    x
+        Sparse matrix.
+    axis
+        Axis along which the sum is applied. Integer or `None`.
+    sparse_grad : bool
+        `True` to have a structured grad.
+
+    Returns
+    -------
+    object
+        The sum of `x` in a dense format.
+
+    Notes
+    -----
+    The grad implementation is controlled with the `sparse_grad` parameter.
+    `True` will provide a structured grad and `False` will provide a regular
+    grad. For both choices, the grad returns a sparse matrix having the same
+    format as `x`.
+
+    This op does not return a sparse matrix, but a dense tensor matrix.
+
+    """
+
+    return SpSum(axis, sparse_grad)(x)
+
+
 class Diag(Op):
     """Extract the diagonal of a square sparse matrix as a dense vector.

@@ -1944,3 +2060,6 @@ def grad(self, inputs, grads):


 construct_sparse_from_list = ConstructSparseFromList()
+
+# Import sp_sum from math to maintain backward compatibility
+# This must be at the end to avoid circular imports
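
Note (editor's sketch, not part of the commit): at perform time the new SpSum op defers to scipy's own sum, producing a 0-d array for axis=None and a dense 1-d vector otherwise, as this scipy-only illustration shows:

# Editor's sketch (not from this commit): what SpSum.perform computes, using
# scipy directly -- a 0-d array for axis=None, a dense 1-d vector otherwise.
import numpy as np
import scipy.sparse

x = scipy.sparse.csr_matrix(np.array([[1.0, 0.0, 2.0],
                                       [0.0, 3.0, 0.0]]))

total = np.asarray(x.sum())              # array(6.)           axis=None -> scalar
col_sums = np.asarray(x.sum(0)).ravel()  # array([1., 3., 2.]) axis=0    -> length n_cols
row_sums = np.asarray(x.sum(1)).ravel()  # array([3., 3.])     axis=1    -> length n_rows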
