gemma_penzai is a JAX research toolkit for visualizing, manipulating, and
understanding Gemma models, with multimodal support built on Penzai. The
original Penzai mainly supports text-only LLMs, including Gemma 1 and Gemma 2.
This package extends Penzai with vision and multimodal support, so that Gemma 3
can also be used for interpretability research. For background, detailed
documentation on Penzai is available at https://penzai.readthedocs.io.
Gemma is a family of open-weights Large Language Models (LLMs) by Google DeepMind, based on Gemini research and technology.
It has been implemented on several platforms.
However, these implementations make it difficult to visualize Gemma's internal mechanisms. We therefore extend the implementation of Gemma in Penzai, a JAX research toolkit for building, editing, and visualizing neural networks.
Gemma 1 and Gemma 2 are already supported in the original Penzai package; here we mainly add support for Gemma 3, with the following new features:
- Vision Transformers (ViTs) and basic components.
- Multimodal Large Language Models (MLLMs) that combine a vision encoder with an LLM backbone, including a new attention mask.
- Decoding algorithms for MLLMs.
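
As a rough illustration of what a multimodal attention mask involves, the sketch below assumes the scheme described for Gemma 3: text tokens attend causally, while tokens belonging to the same image also attend to each other in both directions. The function name `multimodal_attention_mask` and the `is_image` input are hypothetical and are not part of gemma_penzai. Plain NumPy is used for brevity; the same operations are available in `jax.numpy`.

```python
import numpy as np


def multimodal_attention_mask(is_image):
    """Builds a [seq, seq] boolean mask; True means query i may attend to key j.

    Assumption (not taken from the package): text tokens are causal, and
    tokens in the same contiguous image block see each other bidirectionally.
    """
    is_image = np.asarray(is_image, dtype=bool)
    seq_len = is_image.shape[0]

    # Standard causal mask: each position attends to itself and earlier ones.
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))

    # Give every contiguous run of image tokens its own id (1, 2, ...).
    # Text tokens keep id 0.
    prev = np.concatenate([[False], is_image[:-1]])
    starts = is_image & ~prev
    block_id = np.cumsum(starts) * is_image

    # Two positions share a block if their ids match and are nonzero.
    same_block = (block_id[:, None] == block_id[None, :]) & (block_id[:, None] > 0)

    return causal | same_block


if __name__ == "__main__":
    # One text token, a two-token image, then one more text token.
    print(multimodal_attention_mask([False, True, True, False]).astype(int))
```

Tokens from different images stay causal with respect to each other, because they receive different block ids.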
If you haven't already installed JAX with TPU support, do that first, since the installation process depends on your platform. You can find instructions in the JAX documentation. Afterwards, you can install our package as follows:
git clone https://github.com/google-deepmind/gemma_penzai.git
cd gemma_penzai
pip install --upgrade pip
pip install -e .

Then import it and its dependency penzai using:
import penzai
from penzai import pz
from gemma_penzai import mllm, vision

(penzai.pz is an alias namespace, which makes it easier to reference
common Penzai objects.)
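
To confirm that the installation succeeded before opening a notebook, a small check along these lines may help. It is only a sketch: the helper `missing_packages` is hypothetical, and the default package list is an assumption based on the imports mentioned in this README.

```python
import importlib.util


def missing_packages(names=("jax", "penzai", "treescope", "gemma_penzai")):
    """Returns the top-level packages in `names` that cannot be found."""
    # find_spec returns None when a top-level module is not importable.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("missing:", ", ".join(missing) if missing else "none")
```

An empty result means all listed packages are on the import path; it does not verify that JAX can see an accelerator.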
When working in a Colab or IPython notebook, we recommend also configuring Treescope (Penzai's companion pretty-printer) as the default pretty printer, and enabling some utilities for interactive use:
import treescope
treescope.basic_interactive_setup(autovisualize_arrays=True)

We provide notebooks inside ./notebooks covering the basic usage of Gemma 3,
including the multimodal case.
Our code builds on Penzai and Gemma on JAX.