Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pandasai/core/code_execution/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def get_environment() -> dict:
"pd": import_dependency("pandas"),
"plt": import_dependency("matplotlib.pyplot"),
"np": import_dependency("numpy"),
"px": import_dependency("plotly.express"),
"go": import_dependency("plotly.graph_objects"),
}

return env
Expand Down
37 changes: 33 additions & 4 deletions pandasai/core/code_generation/code_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
import os.path
import re
import uuid
from pathlib import Path

import astor

from pandasai.agent.state import AgentState
from pandasai.constants import DEFAULT_CHART_DIRECTORY
from pandasai.core.code_execution.code_executor import CodeExecutor
from pandasai.query_builders.sql_parser import SQLParser

from ...exceptions import MaliciousQueryError
Expand Down Expand Up @@ -146,9 +144,12 @@ def clean_code(self, code: str) -> str:
tuple: Cleaned code as a string and a list of additional dependencies.
"""
code = self._replace_output_filenames_with_temp_chart(code)
code = self._replace_output_filenames_with_temp_json_chart(code)

# If plt.show is in the code, remove that line
code = re.sub(r"plt.show\(\)", "", code)
code = self._remove_make_dirs(code)

# If plt.show or fig.show is in the code, remove that line
code = re.sub(r"\b(?:plt|fig)\.show\(\)", "", code)

tree = ast.parse(code)
new_body = []
Expand Down Expand Up @@ -180,3 +181,31 @@ def _replace_output_filenames_with_temp_chart(self, code: str) -> str:
lambda m: f"{m.group(1)}{chart_path}{m.group(1)}",
code,
)

def _replace_output_filenames_with_temp_json_chart(self, code: str) -> str:
"""
Replace output file names with "temp_chart.json" (in case of usage of plotly).
"""
_id = uuid.uuid4()
chart_path = os.path.join(DEFAULT_CHART_DIRECTORY, f"temp_chart_{_id}.json")
chart_path = chart_path.replace("\\", "\\\\")
return re.sub(
r"""(['"])([^'"]*\.json)\1""",
lambda m: f"{m.group(1)}{chart_path}{m.group(1)}",
code,
)

def _remove_make_dirs(self, code: str) -> str:
"""
Remove any directory creation commands from the code.
"""
# Remove lines that create directories, except for the default chart directory DEFAULT_CHART_DIRECTORY
code_lines = code.splitlines()
cleaned_lines = []
for line in code_lines:
if DEFAULT_CHART_DIRECTORY not in line and (
"os.makedirs(" in line or "os.mkdir(" in line
):
continue
cleaned_lines.append(line)
return "\n".join(cleaned_lines)
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ At the end, declare "result" variable as a dictionary of type and value in the f

Generate python code and return full updated code:

### Note: Use only relevant table for query and do aggregation, sorting, joins and grouby through sql query
### Note: Use only relevant table for query and do aggregation, sorting, joins and group by through sql query
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{% if not output_type %}
type (possible values "string", "number", "dataframe", "plot"). Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" }
type (possible values "string", "number", "dataframe", "plot", "iplot"). No other type available. "plot" is when "matplotlib" is used; "iplot" when "plotly" is used. Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" } or { "type": "iplot", "value": "temp_chart.json" }
{% elif output_type == "number" %}
type (must be "number"), value must int. Example: { "type": "number", "value": 125 }
{% elif output_type == "string" %}
Expand All @@ -8,4 +8,6 @@ type (must be "string"), value must be string. Example: { "type": "string", "val
type (must be "dataframe"), value must be pd.DataFrame or pd.Series. Example: { "type": "dataframe", "value": pd.DataFrame({...}) }
{% elif output_type == "plot" %}
type (must be "plot"), value must be string. Example: { "type": "plot", "value": "temp_chart.png" }
{% elif output_type == "iplot" %}
type (must be "iplot"), value must be string. Example: { "type": "iplot", "value": "temp_chart.json" }
{% endif %}
2 changes: 2 additions & 0 deletions pandasai/core/response/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .chart import ChartResponse
from .dataframe import DataFrameResponse
from .error import ErrorResponse
from .interactive_chart import InteractiveChartResponse
from .number import NumberResponse
from .parser import ResponseParser
from .string import StringResponse
Expand All @@ -10,6 +11,7 @@
"ResponseParser",
"BaseResponse",
"ChartResponse",
"InteractiveChartResponse",
"DataFrameResponse",
"NumberResponse",
"StringResponse",
Expand Down
55 changes: 55 additions & 0 deletions pandasai/core/response/interactive_chart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import json
import os
from typing import Any

from .base import BaseResponse


class InteractiveChartResponse(BaseResponse):
def __init__(self, value: Any, last_code_executed: str):
super().__init__(value, "ichart", last_code_executed)

self._validate()

def _get_chart(self) -> dict:
if isinstance(self.value, dict):
return self.value

if isinstance(self.value, str):
if os.path.exists(self.value):
with open(self.value, "rb") as f:
return json.load(f)

return json.loads(self.value)

raise ValueError(
"Invalid value type for InteractiveChartResponse. Expected dict or str."
)

def save(self, path: str):
img = self._get_chart()
with open(path, "w") as f:
json.dump(img, f)

def __str__(self) -> str:
return self.value if isinstance(self.value, str) else json.dumps(self.value)

def get_dict_image(self) -> dict:
return self._get_chart()

def _validate(self):
if not isinstance(self.value, (dict, str)):
raise ValueError(
"InteractiveChartResponse value must be a dict or a str representing a file path."
)

# if a string, it can be a path to a file or a JSON string
if isinstance(self.value, str):
try:
json.loads(self.value) # Check if it's a valid JSON string
except json.JSONDecodeError:
# If it fails, check if it's a valid file path
if not os.path.exists(self.value):
raise ValueError(
"InteractiveChartResponse value must be a valid file path or a JSON string."
)
18 changes: 18 additions & 0 deletions pandasai/core/response/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .base import BaseResponse
from .chart import ChartResponse
from .dataframe import DataFrameResponse
from .interactive_chart import InteractiveChartResponse
from .number import NumberResponse
from .string import StringResponse

Expand All @@ -26,6 +27,8 @@ def _generate_response(self, result: dict, last_code_executed: str = None):
return DataFrameResponse(result["value"], last_code_executed)
elif result["type"] == "plot":
return ChartResponse(result["value"], last_code_executed)
elif result["type"] == "iplot":
return InteractiveChartResponse(result["value"], last_code_executed)
else:
raise InvalidOutputValueMismatch(f"Invalid output type: {result['type']}")

Expand Down Expand Up @@ -72,4 +75,19 @@ def _validate_response(self, result: dict):
"Invalid output: Expected a plot save path str but received an incompatible type."
)

elif result["type"] == "iplot":
if not isinstance(result["value"], (str, dict)):
raise InvalidOutputValueMismatch(
"Invalid output: Expected a plot save path str but received an incompatible type."
)

if isinstance(result["value"], dict):
return True

path_to_plot_pattern = r"^(\/[\w.-]+)+(/[\w.-]+)*$|^[^\s/]+(/[\w.-]+)*$"
if not bool(re.match(path_to_plot_pattern, result["value"])):
raise InvalidOutputValueMismatch(
"Invalid output: Expected a plot save path str but received an incompatible type."
)

return True
Loading
Loading