Skip to content

Commit

Permalink
Add support for matplotlib issue StructuredLabs#71
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosnid1 committed Feb 14, 2025
1 parent 7c66df5 commit 921d915
Show file tree
Hide file tree
Showing 11 changed files with 186 additions and 0 deletions.
3 changes: 3 additions & 0 deletions examples/triangular_function/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
secrets.toml
.preswald_deploy
.env.structured
6 changes: 6 additions & 0 deletions examples/triangular_function/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Preswald Project

## Setup
1. Configure your data connections in `preswald.toml`
2. Add sensitive information (passwords, API keys) to `secrets.toml`
3. Run your app with `preswald run hello.py`
28 changes: 28 additions & 0 deletions examples/triangular_function/hello.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from preswald import matplotlib
import numpy as np

x = np.linspace(0, 50, 100)
sin = np.sin(x)
cos = np.cos(x)
tan = np.tan(x)

plt = matplotlib.plt
plt.figure(figsize=(10, 6))

plt.plot(x, sin, 'b--', label='sin(x)', linewidth=2)
plt.plot(x, cos, 'r--', label='cos(x)', linewidth=2)
plt.plot(x, tan, 'g--', label='tan(x)', linewidth=2)

plt.title('Sinus Waves', fontsize=16, pad=20)

plt.xlabel('X axis', fontsize=12)
plt.ylabel('Y axis', fontsize=12)

plt.legend(loc='upper right', fontsize=10)

plt.grid(True, linestyle='--', alpha=0.5)

plt.xlim(0, 10)
plt.ylim(-1.5, 1.5)

matplotlib(plt.gcf(), format='svg')
Binary file added examples/triangular_function/images/favicon.ico
Binary file not shown.
Binary file added examples/triangular_function/images/logo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
18 changes: 18 additions & 0 deletions examples/triangular_function/preswald.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[project]
title = "Preswald Project"
version = "0.1.0"
port = 8501

[branding]
name = "Preswald Project"
logo = "images/logo.png"
favicon = "images/favicon.ico"
primaryColor = "#F89613"

[data.csv]
type = "csv"
path = "data/sample.csv"

[logging]
level = "INFO" # Options: DEBUG, INFO, WARNING, ERROR, CRITICAL
format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
23 changes: 23 additions & 0 deletions examples/triangular_function/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "preswald-app"
version = "0.1.0"
description = "A Preswald application"
requires-python = ">=3.8"
dependencies = [
"preswald"
]

[tool.hatch.build.targets.wheel]
packages = ["."]

[tool.black]
line-length = 88
target-version = ['py38']

[tool.isort]
profile = "black"
multi_line_output = 3
4 changes: 4 additions & 0 deletions frontend/src/components/DynamicComponents.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import SpinnerWidget from './widgets/SpinnerWidget';
import TableViewerWidget from './widgets/TableViewerWidget';
import TextInputWidget from './widgets/TextInputWidget';
import UnknownWidget from './widgets/UnknownWidget';
import MatplotlibComponent from './widgets/MatplotlibComponent';
import { cn } from '@/lib/utils';

// Error boundary component
Expand Down Expand Up @@ -234,6 +235,9 @@ const MemoizedComponent = memo(
case 'dag':
return <DAGVisualizationWidget {...commonProps} data={component.data || {}} />;

case 'matplotlib':
return <MatplotlibComponent data={component.data} />;

default:
console.warn(`[DynamicComponents] Unknown component type: ${component.type}`);
return (
Expand Down
24 changes: 24 additions & 0 deletions frontend/src/components/widgets/MatplotlibComponent.jsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import React from 'react';

const MatplotlibComponent = ({ data }) => {
const { image, format } = data;
if (format === 'svg') {
// For SVG, we need to decode and render it as HTML
const decodedSvg = atob(image);
return (
<div
className="w-full overflow-hidden rounded-lg shadow-sm"
dangerouslySetInnerHTML={{ __html: decodedSvg }}
/>
);
}

// For PNG and other formats, use img tag
return (
<div className="w-full overflow-hidden rounded-lg shadow-sm">
<img src={`data:image/${format};base64,${image}`} alt="Matplotlib Plot" />
</div>
);
};

export default MatplotlibComponent;
1 change: 1 addition & 0 deletions preswald/interfaces/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
text,
text_input,
workflow_dag,
matplotlib
)
from .data import connect, get_df, query, view
from .workflow import RetryPolicy, Workflow, WorkflowAnalyzer
Expand Down
79 changes: 79 additions & 0 deletions preswald/interfaces/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
import json
import logging
import uuid
import base64
import io

import numpy as np
import pandas as pd
from matplotlib.figure import Figure

from preswald.engine.service import PreswaldService

Expand Down Expand Up @@ -524,3 +527,79 @@ def separator():
component = {"type": "separator", "id": str(uuid.uuid4())}
service.append_component(component)
return component


class MatplotlibWrapper:
"""
A wrapper class for matplotlib functionality in Preswald.
"""
def __init__(self):
import matplotlib.pyplot as plt
self.plt = plt
self.format = 'png' # default format

def __call__(self, fig, label=None, format='png'):
"""
Original matplotlib function to display a figure
Args:
fig: matplotlib.figure.Figure object
format: 'png' or 'svg' (default: 'png')
"""
if not isinstance(fig, Figure):
raise TypeError("Expected matplotlib.figure.Figure, got {}".format(type(fig)))

if format not in ['png', 'svg']:
raise ValueError("Format must be 'png' or 'svg'")

service = PreswaldService.get_instance()
buf = io.BytesIO()

# Set figure DPI for better quality
if format == 'svg':
fig.savefig(buf, format=format, bbox_inches='tight')
else:
fig.savefig(buf, format=format, bbox_inches='tight', dpi=100)

buf.seek(0)

# For SVG, decode as UTF-8 string first
if format == 'svg':
img_data = base64.b64encode(buf.getvalue()).decode('utf-8')
else:
img_data = base64.b64encode(buf.getvalue()).decode('utf-8')

component = {
"type": "matplotlib",
"data": {
"image": img_data,
"format": format
}
}
service.append_component(component)
return component

def plot(self, *args, format='png', label=None, **kwargs):
"""
Create a line plot, similar to matplotlib.pyplot.plot()
Args:
*args: Arguments passed to plt.plot()
format: 'png' or 'svg' (default: 'png')
**kwargs: Keyword arguments passed to plt.plot()
"""
if format not in ['png', 'svg']:
raise ValueError("Format must be 'png' or 'svg'")

# Create a new figure if one doesn't exist
if not self.plt.get_fignums():
self.plt.figure(figsize=(10, 6))

# Create the plot
self.plt.plot(*args, label=label, **kwargs)

# Return the component
return self(self.plt.gcf(), label=label, format=format)

# Replace the function with an instance of the wrapper class
matplotlib = MatplotlibWrapper()

0 comments on commit 921d915

Please sign in to comment.