Skip to content

Commit

Permalink
Merge pull request #16 from ahuang11/marvin
Browse files Browse the repository at this point in the history
Add AI generated colormaps with Marvin
  • Loading branch information
ahuang11 authored Dec 15, 2023
2 parents a9f5e0f + 9448f4f commit 093669e
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 7 deletions.
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,18 @@ pair_tbar(

![example](https://github.com/ahuang11/tastymap/assets/15331990/04ab9ea7-d836-44b8-843d-2cb65eddfe63)

Try to craft your visual delight *interactively* with the TastyKitchen UI, hosted [here](https://huggingface.co/spaces/ahuang11/tastykitchen)
Or if you need suggestions, get help from AI by providing a description of what you're imagining:

```python
from tastymap import ai

tmap = ai.suggest_tmap("Pikachu")
tmap
```

![image](https://github.com/ahuang11/tastymap/assets/15331990/5a6f2bd4-4c4f-449c-9f2a-3352c956400a)

Try to craft your visual delight *interactively* with the TastyKitchen UI, hosted [here](https://huggingface.co/spaces/ahuang11/tastykitchen).

```bash
tastymap ui
Expand Down
7 changes: 7 additions & 0 deletions docs/reference/ai.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# AI

**Installing `marvin` and `pydantic` is required to use this module.**

::: tastymap.ai
options:
show_source: true
12 changes: 12 additions & 0 deletions docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,18 @@ tmap = cook_tmap(["red", "green", "blue"])
tmap >> "rgb"
```

## Suggesting based on a description

You can have AI suggest a `TastyMap` based on a description:

```python
from tastymap import ai

tmap = ai.suggest_tmap("Pikachu")
```

![image](https://github.com/ahuang11/tastymap/assets/15331990/5a6f2bd4-4c4f-449c-9f2a-3352c956400a)

## Using the TastyKitchen UI

You can use the TastyKitchen UI to craft your `TastyMap` interactively:
Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,4 @@ nav:
- Core: reference/core.md
- Models: reference/models.md
- Utils: reference/utils.md
- AI: reference/ai.md
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ ui = [
"pooch",
]

ai = [
"marvin",
"pydantic",
]

[project.urls]
Documentation = "https://ahuang11.github.io/tastymap/"
Source = "https://github.com/ahuang11/tastymap"
Expand Down
2 changes: 1 addition & 1 deletion tastymap/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
from .models import TastyBar, TastyMap
from .ui import TastyKitchen

__version__ = "0.2.0"
__version__ = "0.3.0"

__all__ = ["cook_tmap", "pair_tbar", "TastyMap", "TastyBar", "TastyKitchen"]
72 changes: 72 additions & 0 deletions tastymap/ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
try:
from marvin import ai_fn, ai_model # type: ignore
from pydantic import BaseModel, Field # type: ignore
except ImportError:
raise ImportError(
"Please install marvin and pydantic to use this module, "
"e.g. pip install marvin pydantic"
)

from .core import cook_tmap
from .models import TastyMap


@ai_model(max_tokens=256)
class AIPalette(BaseModel):
colors: list[str] = Field(
default=...,
description="""
A list of colors as existing matplotlib named colors or hex codes,
like `["firebrick", "#FFFFFF", "#000000"]`. If the color is invalid,
find the closest color to the provided color.
""",
)
name: str = Field(..., description="A creative name to describe the colors.")


@ai_fn(max_tokens=256)
def _refine_description(description: str, num_colors: int) -> str: # pragma: no cover
"""
You are a master painter, and well versed in matplotlib colors.
Describe in detail what you imagine when you think of
the provided `description` in descriptive named colors.
Then, share a variety of colors, either as valid matplotlib named colors
or hex codes that best represent the image, so that you can use
it to paint the image, up to `num_colors` colors.
"""


def suggest_tmap(
description: str, num_colors: int = 5, retries: int = 3, verbose: bool = True
) -> TastyMap:
"""
Suggest a TastyMap based on a description of the image.
Args:
description: A description of the image.
num_colors: Number of colors in the colormap. Defaults to 5.
retries: Number of retries to suggest a TastyMap. Defaults to 3.
verbose: Whether to print the AI description. Defaults to True.
Returns:
TastyMap: A new TastyMap instance with the new colormap.
"""
exceptions = []
for _ in range(retries):
try:
ai_description = _refine_description(description, num_colors)
if verbose:
print(ai_description)
ai_palette = AIPalette(ai_description)
return cook_tmap(
["".join(color.split()) for color in ai_palette.colors],
name=ai_palette.name,
)
except Exception as exception:
exceptions.append(exception)
else:
raise ValueError(
f"Attempted to suggest a TastyMap {retries} times, "
f"but failed due to {exceptions}"
)
41 changes: 36 additions & 5 deletions tastymap/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
"run `pip install 'tastymap[ui]'` to install."
)

try:
from .ai import suggest_tmap
except ImportError:
suggest_tmap = None

from .core import cook_tmap, pair_tbar
from .models import ColorModel, TastyMap
from .utils import get_cmap, get_registered_cmaps
Expand Down Expand Up @@ -198,6 +203,21 @@ def __init__(self, **params):
sizing_mode="stretch_width",
margin=(10, 30, 5, 20),
)
if suggest_tmap is None:
colors_suggest = pn.widgets.TextAreaInput(
placeholder="This feature requires the `tastymap[ai]` extra.",
margin=(5, 5, 5, 20),
auto_grow=True,
max_rows=3,
disabled=True,
)
else:
colors_suggest = pn.widgets.TextAreaInput(
placeholder="Enter a description to let AI suggest a colormap",
margin=(5, 5, 5, 20),
auto_grow=True,
max_rows=3,
)
colors_clear = pn.widgets.Button(
name="Clear",
sizing_mode="stretch_width",
Expand All @@ -209,6 +229,7 @@ def __init__(self, **params):
("Text", colors_input),
("Pick", colors_picker),
("Upload", colors_upload),
("Suggest", colors_suggest),
("Clear", colors_clear),
),
self.colors_select,
Expand All @@ -217,6 +238,7 @@ def __init__(self, **params):
colors_input.param.watch(self._add_color, "value")
colors_picker.param.watch(self._add_color, "value")
colors_upload.param.watch(self._add_color, "value")
colors_suggest.param.watch(self._add_color, "value")
colors_clear.on_click(lambda event: setattr(self.colors_select, "value", []))

# tmap widgets
Expand Down Expand Up @@ -400,6 +422,14 @@ def _add_color(self, event):

if isinstance(new_event, bytes):
new_event = new_event.decode("utf-8")
elif "let AI" in event.obj.placeholder:
try:
event.obj.disabled = True
tmap = suggest_tmap(new_event, self.num_colors)
self.custom_name = tmap.cmap.name
new_event = tmap.to_model("hex").tolist()
finally:
event.obj.disabled = False

value = self.colors_select.value
if isinstance(value, dict):
Expand All @@ -416,13 +446,14 @@ def _add_color(self, event):

processed_colors = []
for color in new_event:
if not color.strip() or color.startswith("#"):
color = color.strip().strip(",")
if not color:
continue
try:
if " " in color or "," in color:
color = np.array(
ast.literal_eval(",".join(color.strip().split()))
).astype(float)
if " " in color or color.count(",") == 2:
color = np.array(ast.literal_eval(",".join(color.split()))).astype(
float
)
if any(c > 1 for c in color):
color /= 255
color = tuple(color.round(2))
Expand Down

0 comments on commit 093669e

Please sign in to comment.