
Add JAX Frontend Support to Ivy Transpiler  #28846

@YushaArif99

Description:
The current implementation of ivy.transpile accepts only "torch" as the source argument, which allows PyTorch functions or classes to be transpiled to target frameworks such as TensorFlow, JAX, or NumPy. This task extends that functionality by adding JAX as a valid source, so that JAX code can be transpiled to other frameworks via Ivy's intermediate representation.

For example, after completing this task, we should be able to transpile JAX code using:

ivy.transpile(func, source="jax", target="tensorflow")

Goals:

The main objective is to implement the first two stages of the transpilation pipeline for JAX:

  1. Lower JAX code to Ivy’s JAX Frontend IR.
  2. Transform the JAX Frontend IR to Ivy’s core representation.
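Stage 1 is a source-to-source rewrite of native JAX code. The sketch below is not Ivy's transformer. It only illustrates the idea with Python's standard `ast` module (Python 3.9+ for `ast.unparse`). It assumes the input uses the `jnp` alias for `jax.numpy` and rewrites every `jnp.<name>` access to a hypothetical frontend path `jax_frontend.numpy.<name>`. The class name `JaxToFrontendTransformer` and the `jax_frontend` alias are illustrative, and import statements are left untouched.

```python
import ast


class JaxToFrontendTransformer(ast.NodeTransformer):
    """Hypothetical sketch: redirect `jnp.<attr>` to `jax_frontend.numpy.<attr>`."""

    def __init__(self, native_alias="jnp", frontend_path="jax_frontend.numpy"):
        self.native_alias = native_alias
        self.frontend_path = frontend_path

    def visit_Attribute(self, node):
        # Rewrite inner attributes first so chains like jnp.linalg.norm are handled.
        self.generic_visit(node)
        if isinstance(node.value, ast.Name) and node.value.id == self.native_alias:
            # Swap the `jnp` root for the (hypothetical) frontend module path.
            new_root = ast.parse(self.frontend_path, mode="eval").body
            return ast.copy_location(
                ast.Attribute(value=new_root, attr=node.attr, ctx=node.ctx), node
            )
        return node


def lower_to_frontend(source):
    """Parse JAX source, apply the rewrite, and return the new source text."""
    tree = JaxToFrontendTransformer().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)


if __name__ == "__main__":
    src = "def f(x):\n    return jnp.sum(jnp.exp(x))"
    print(lower_to_frontend(src))
```

A real pass would also need to handle other import styles (`from jax import numpy`, `jax.lax`, `jax.nn`) and Flax APIs. A name-based rewrite like this one is only the simplest form of the transformation.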

Once these stages are complete, the rest of the pipeline can be reused to target other frameworks like TensorFlow, PyTorch, or NumPy. The steps would look as follows:

source='jax' → target='jax_frontend'  
source='jax_frontend' → target='ivy'  
source='ivy' → target='tensorflow'/'torch'/etc.  

This mirrors the existing pipeline for PyTorch:

source='torch' → target='torch_frontend'  
source='torch_frontend' → target='ivy'  
source='ivy' → target='jax'/'numpy'/etc.  
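The staged design means a single source → target request can be resolved by chaining per-stage translators keyed on (source, target) pairs. The sketch below is a minimal, hypothetical model of that idea. `STAGES`, `plan`, and `run` are illustrative names, and each stage is a placeholder that only records that it ran. In Ivy, the real stages are configured in files such as source_to_frontend_translator_config.py and frontend_to_ivy_translator_config.py.

```python
from typing import Callable, Dict, List, Tuple

Stage = Callable[[str], str]


def _tag(label: str) -> Stage:
    # Placeholder stage: appends a marker instead of transforming code.
    return lambda code: f"{code} -> {label}"


# Hypothetical registry: one entry per single-hop stage.
STAGES: Dict[Tuple[str, str], Stage] = {
    ("torch", "torch_frontend"): _tag("torch_frontend"),
    ("torch_frontend", "ivy"): _tag("ivy"),
    ("jax", "jax_frontend"): _tag("jax_frontend"),  # added by this issue
    ("jax_frontend", "ivy"): _tag("ivy"),           # added by this issue
    ("ivy", "tensorflow"): _tag("tensorflow"),
    ("ivy", "torch"): _tag("torch"),
    ("ivy", "jax"): _tag("jax"),
    ("ivy", "numpy"): _tag("numpy"),
}


def plan(source: str, target: str) -> List[Tuple[str, str]]:
    """Return the hops source -> <source>_frontend -> ivy -> target."""
    frontend = f"{source}_frontend"
    hops = [(source, frontend), (frontend, "ivy"), ("ivy", target)]
    missing = [hop for hop in hops if hop not in STAGES]
    if missing:
        raise NotImplementedError(f"no stage registered for {missing}")
    return hops


def run(code: str, source: str, target: str) -> str:
    """Apply every stage in the plan, feeding each output into the next."""
    for hop in plan(source, target):
        code = STAGES[hop](code)
    return code
```

Under this model, adding JAX as a source only means registering the two new first-stage entries. The `("ivy", target)` hops are shared with the PyTorch pipeline.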

Key Tasks:

  1. Add Native Framework-Specific Implementations for Core Transformation Passes:

    • For example, implement native_jax_recursive_transformer.py to traverse and transform native JAX source code.
    • Use native_torch_recursive_transformer.py as a reference (example here)
  2. Define the Transformation Pipeline for JAX to JAX Frontend IR:

    • Create a new pipeline in source_to_frontend_translator_config.py to handle the stage source='jax', target='jax_frontend' (example here).
  3. Define the Transformation Pipeline for JAX Frontend IR to Ivy:

    • Add another pipeline in frontend_to_ivy_translator_config.py to handle the stage source='jax_frontend', target='ivy' (example here).
  4. Add Stateful Classes for Flax APIs:

    • Implement a stateful class for the flax.nnx.Module API that inherits from ivy.Module.
    • Reference the existing implementation for PyTorch's nn.Module (example here)
    • This allows for sequential lowering:
      nnx.Module → (frontend nnx.Module) → ivy.Module → (target keras.Model/keras.Layer)
      
  5. Understand and Leverage Reusability:

    • Explore reusable components in the existing PyTorch pipeline, especially for AST transformers and configuration management.
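For task 4, the point of inheriting from ivy.Module is that state declared at the frontend level stays visible after each lowering step. In Flax NNX, a module's state lives in nnx.Variable objects (such as nnx.Param) assigned as attributes. The sketch below does not use Ivy or Flax. `IvyModuleStandIn` is a stand-in for ivy.Module, and `Variable`, `Param`, and `FrontendNNXModule` are hypothetical frontend classes. The sketch shows one way a base class could discover nested Variables by attribute path.

```python
class Variable:
    """Hypothetical frontend counterpart of nnx.Variable: wraps a value."""

    def __init__(self, value):
        self.value = value


class Param(Variable):
    """Hypothetical frontend counterpart of nnx.Param."""


class IvyModuleStandIn:
    """Stand-in for ivy.Module: exposes all nested Variables by dotted path."""

    def variables(self, prefix=""):
        found = {}
        for name, attr in vars(self).items():
            if isinstance(attr, Variable):
                found[prefix + name] = attr
            elif isinstance(attr, IvyModuleStandIn):
                # Recurse into submodules, extending the dotted path.
                found.update(attr.variables(prefix + name + "."))
        return found


class FrontendNNXModule(IvyModuleStandIn):
    """Hypothetical frontend nnx.Module; state discovery comes from the base."""

    def __call__(self, *args, **kwargs):
        raise NotImplementedError("subclasses define the forward computation")


# Usage: a user-style module written against the frontend class.
class Linear(FrontendNNXModule):
    def __init__(self, din, dout):
        self.w = Param([[0.0] * dout for _ in range(din)])
        self.b = Param([0.0] * dout)
        self.din = din  # plain attribute: not a Variable, so not tracked

    def __call__(self, x):
        # Computes x @ w + b with plain lists (x has length din).
        return [
            sum(x[i] * self.w.value[i][j] for i in range(self.din)) + self.b.value[j]
            for j in range(len(self.b.value))
        ]


class MLP(FrontendNNXModule):
    def __init__(self):
        self.layer = Linear(2, 3)
```

Here `MLP().variables()` yields the keys `layer.w` and `layer.b`. A real implementation would also have to cover the nnx details that this sketch ignores, such as Variable subclasses beyond Param and in-place state updates.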

Testing:

  • Familiarize yourself with the transpilation flow by exploring the existing transpiler tests.
  • Add appropriate tests to validate JAX source transpilation at each stage of the pipeline.
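One way to validate an individual stage is to assert structural properties of the code it emits. For example, no native `jnp` or `jax` references should survive the source='jax' → target='jax_frontend' stage. The sketch below is a pytest-style example built on the standard `ast` module. `attribute_roots` is a hypothetical helper, and the `lowered` string is a placeholder for real transpiler output.

```python
import ast


def attribute_roots(source):
    """Return the root names of every dotted attribute chain in `source`."""
    roots = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Attribute):
            base = node.value
            # Walk down `a.b.c` to its leftmost Name, `a`.
            while isinstance(base, ast.Attribute):
                base = base.value
            if isinstance(base, ast.Name):
                roots.add(base.id)
    return roots


def test_jax_stage_has_no_native_references():
    # Placeholder for the output of the source='jax' -> 'jax_frontend' stage.
    lowered = "def f(x):\n    return jax_frontend.numpy.sum(jax_frontend.numpy.exp(x))"
    roots = attribute_roots(lowered)
    assert "jnp" not in roots and "jax" not in roots
    assert roots == {"jax_frontend"}
```

Structural checks like this complement numerical tests that compare the original JAX function's outputs against the transpiled function's outputs.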

Additional Notes:

  • Keep in mind the modular and extensible design of the transpiler, ensuring that the new implementation integrates smoothly into the existing architecture.
  • Be prepared for nuances and intricacies in the JAX/Flax APIs, especially nnx.Module and nnx.Variable.

Metadata

Assignees

No one assigned

    Labels

    JAX Frontend: Developing the JAX Frontend, checklist triggered by commenting add_frontend_checklist
    ToDo: A ToDo list of tasks
    Transpiler: Anything related to transpiling

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests