Skip to content

Conversation

hsjts0u
Copy link

@hsjts0u hsjts0u commented Sep 4, 2025

Add default args for _aten_conv2d, which would otherwise fail in the following code snippet

import torch
from torch.export import export_for_training
import torchax
from torchax import interop
from torch.utils import _pytree as pytree
import jax
from torchax.ops import mappings

class Simple(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=4, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        return x
    

model = Simple()

exported = export_for_training(model, (torch.randn(1, 3, 224, 224),))

def make_shape_struct(x):
    return jax.ShapeDtypeStruct(x.shape, mappings.t2j_dtype(x.dtype))


def map_nth(v, i):
    def f(t):
        if isinstance(t, torch.Tensor):
            return t[i : i + 1]
        return t

    return pytree.tree_map(f, v)


env = torchax.default_env()
with env:
    model = exported.module().to("jax")

    def func_to_export(x):
        # hard code weights in model
        return model(x)

    example_inputs_jax = pytree.tree_map_only(
        torch.Tensor, lambda x: x.to("jax"), map_nth(exported.example_inputs, 0)
    )

    res = jax.jit(interop.jax_view(func_to_export)).lower(*example_inputs_jax[0])

# TypeError: _aten_conv2d() missing 5 required positional arguments: 'bias', 'stride', 'padding', 'dilation', and 'groups'

cc @qihqi

@hsjts0u hsjts0u changed the title Add default args for _aten_con2d Add default args for _aten_conv2d Sep 4, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant