From b8a523f977a9b4e63b568a7af83621fdb1c755be Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 12 May 2022 22:22:52 -0700 Subject: [PATCH] Enable colors when we are using a terminal or IPython --- jax/_src/config.py | 2 +- jax/_src/pretty_printer.py | 23 +++++++++++++++++++++-- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index 3f5977415858..78cad0cb4f13 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -476,7 +476,7 @@ def update_thread_local_jit_state(**kw): flags.DEFINE_bool( 'jax_pprint_use_color', - bool_env('JAX_PPRINT_USE_COLOR', False), + bool_env('JAX_PPRINT_USE_COLOR', True), help='Enable jaxpr pretty-printing with colorful syntax highlighting.' ) diff --git a/jax/_src/pretty_printer.py b/jax/_src/pretty_printer.py index 3ad804eba4d0..2f8a5ffdfb46 100644 --- a/jax/_src/pretty_printer.py +++ b/jax/_src/pretty_printer.py @@ -27,6 +27,7 @@ import abc import enum +import sys from functools import partial from typing import List, NamedTuple, Optional, Sequence, Tuple, Union from jax.config import config @@ -36,13 +37,31 @@ except ImportError: colorama = None +def _can_use_color() -> bool: + try: + # Check if we're in IPython or Colab + ipython = get_ipython() # type: ignore[name-defined] + shell = ipython.__class__.__name__ + if shell == "ZMQInteractiveShell": + # Jupyter Notebook + return True + elif "colab" in str(ipython.__class__): + # Google Colab (external or internal) + return True + except NameError: + pass + # Otherwise check if we're in a terminal + return sys.stdout.isatty() + +CAN_USE_COLOR = _can_use_color() class Doc(abc.ABC): __slots__ = () - def format(self, width: int = 80, use_color: bool = False, + def format(self, width: int = 80, use_color: Optional[bool] = None, annotation_prefix=" # ") -> str: - use_color = use_color or config.FLAGS.jax_pprint_use_color + if use_color is None: + use_color = CAN_USE_COLOR and config.FLAGS.jax_pprint_use_color return _format(self, width, use_color=use_color, annotation_prefix=annotation_prefix)