Export escn_l6_m2_lay12_2M_s2ef_export_cuda_9182024 using aot_inductor #1000

Open
ipcamit opened this issue Feb 6, 2025 · 1 comment
Open
Assignees

Comments

@ipcamit
Copy link

ipcamit commented Feb 6, 2025

What would you like to report?

I want to use the exported escn_l6_m2_lay12_2M_s2ef_export_cuda_9182024.pt2 as a shared library via torch._export.aot_compile. I tried a naive attempt as follows:

import os

import torch
from torch.export import Dim, export

device = "cuda"  # compile target; referenced below when choosing options

model = torch.export.load("./escn_l6_m2_lay12_2M_s2ef_export_cuda_9182024.pt2")

inputs = sim_export_input(batch_size=1, device="cuda")  # helper that builds example inputs

atoms_dim = Dim("atoms_dim", min=2, max=1000)
edges_dim = Dim("edges_dim", min=2, max=10000)

dynamic_shapes = {
    "pos": {0: atoms_dim, 1: None},
    "batch_idx": {0: atoms_dim},
    "natoms": {0: None},
    "atomic_numbers": {0: atoms_dim},
    "edge_index": {0: None, 1: edges_dim},
    "edge_distance": {0: edges_dim},
    "edge_distance_vec": {0: edges_dim, 1: None},
}


# module_fx / inputs_cpu below come from the original export script that
# produced the .pt2 file; they are not defined in this snippet.
with torch.inference_mode():
    exported_prog = export(module_fx, args=inputs_cpu, dynamic_shapes=dynamic_shapes)
    export_path = os.path.join("./", "escn_l6_m2_lay12_2M_s2ef_export_cuda_9182024.pt2")
    torch.export.save(exported_prog, export_path)

so_path = os.path.join("./", "escn_l6_m2_lay12_2M_s2ef_export_cpu_9182024.so")
aot_compile_options = {"aot_inductor.output_path": so_path}
if device == "cuda":
    aot_compile_options.update({"max_autotune": True})

so_path = torch._export.aot_compile(
    model,
    args=inputs,
    dynamic_shapes=dynamic_shapes,
    options=aot_compile_options,
)

But I get the following error:

ValueError: Node keys mismatch; missing key(s): {'arg_5', 'arg_1', 'arg_4', 'arg_3', 'arg_6', 'arg_2', 'arg_0'}; extra key(s): {'edge_distance', 'batch_idx', 'pos', 'natoms', 'atomic_numbers', 'edge_index', 'edge_distance_vec'}.
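
For reference, the placeholder names an ExportedProgram expects can be read off its FX graph; a minimal inspection sketch, using the model loaded above:

# List the placeholder (input) nodes of the loaded program; a name-keyed
# dynamic_shapes dict has to match these names exactly.
for node in model.graph.nodes:
    if node.op == "placeholder":
        print(node.name)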

I tried different variations of it, such as:

dynamic_shapes = {
    "args_0": {0: atoms_dim, 1: None},
    "args_1": {0: atoms_dim},
    "args_2": {0: None},
    "args_3": {0: atoms_dim},
    "args_4": {0: None, 1: edges_dim},
    "args_5": {0: edges_dim},
    "args_6": {0: edges_dim, 1: None},
}

But I keep getting variations of the above error with permutations of arg_*.

ValueError: Node keys mismatch; missing key(s): {'arg_0', 'arg_4', 'arg_2', 'arg_3', 'arg_1', 'arg_5', 'arg_6'}; extra key(s): {'args_5', 'args_4', 'args_6', 'args_2', 'args_3', 'args_1', 'args_0'}.
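
One variation that sidesteps name matching entirely, assuming torch._export.aot_compile accepts the positional form that torch.export.export documents, is a tuple of per-argument specs in forward() order (a sketch, not verified against this model):

# Hedged sketch: dynamic_shapes as a positional tuple, one spec per input
# in the original argument order, so no placeholder names are involved.
dynamic_shapes = (
    {0: atoms_dim, 1: None},   # pos
    {0: atoms_dim},            # batch_idx
    {0: None},                 # natoms
    {0: atoms_dim},            # atomic_numbers
    {0: None, 1: edges_dim},   # edge_index
    {0: edges_dim},            # edge_distance
    {0: edges_dim, 1: None},   # edge_distance_vec
)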

What is the easiest way to convert this model to a .so file?

Also, there seems to be a mismatch between escn_l6_m2_lay12_2M_s2ef.pt and escn_l6_m2_lay12_2M_s2ef_export_cuda_9182024.pt2. The .pt2 file reports the number of parameters as

model = torch.export.load("escn_l6_m2_lay12_2M_s2ef_export_cuda_9182024.pt2")
torch.sum(torch.tensor([p.numel() for p in model.module().parameters() if p.requires_grad]))
>>> tensor(51869952)

but the escn_l6_m2_lay12_2M_s2ef.pt checkpoint only has 51852408 parameters:

ckpt = torch.load("escn_l6_m2_lay12_2M_s2ef.pt", map_location=torch.device("cpu"))
n_total = sum(t.numel() for t in ckpt["state_dict"].values())
n_total
>>> 51852408
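
The gap is 51869952 - 51852408 = 17544 elements. A hedged sketch for locating it, diffing the two key sets (the checkpoint keys may carry a prefix such as "module.", which would need stripping first):

# Diff the key sets of the checkpoint and the exported module to see which
# tensors account for the 17,544-element difference.
ckpt_sd = ckpt["state_dict"]
exp_sd = model.module().state_dict()
print("only in checkpoint:", sorted(set(ckpt_sd) - set(exp_sd))[:10])
print("only in export:", sorted(set(exp_sd) - set(ckpt_sd))[:10])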
ipcamit commented Feb 6, 2025

I also tried recreating the model from the config file as:

from fairchem.core.models.escn.escn_exportable import eSCN

model_new = eSCN(
    num_layers=12,
    max_neighbors=20,
    max_num_elements=100,
    cutoff=12.0,
    sphere_channels=128,
    hidden_channels=256,
    lmax=6,
    mmax=2,
    num_sphere_samples=128,
    distance_function="gaussian",
    basis_width_scalar=2.0,
    export=True,
    distance_resolution=0.02,
    edge_channels=128,
)

Now, while the number of trainable parameters matches the original model:

torch.sum(torch.tensor([p.numel() for p in model_new.parameters() if p.requires_grad]))
>>> tensor(51869952)

The state dicts are different:


len(list(model_new.state_dict().keys()))
>>> 4547

len(list(model.module().state_dict().keys()))
>>> 521
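
A hedged sanity check: compare total element counts per state_dict rather than key counts, since export may flatten or rename tensors while preserving the underlying storage:

# If the two modules really hold the same tensors, these totals should match
# even though the key counts above differ.
def total_numel(sd):
    return sum(t.numel() for t in sd.values())

print(total_numel(model_new.state_dict()))       # freshly constructed eSCN
print(total_numel(model.module().state_dict()))  # exported program's module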
