diff --git a/pandasai/core/code_execution/environment.py b/pandasai/core/code_execution/environment.py index f81fc3428..bf42c69ae 100644 --- a/pandasai/core/code_execution/environment.py +++ b/pandasai/core/code_execution/environment.py @@ -29,6 +29,8 @@ def get_environment() -> dict: "pd": import_dependency("pandas"), "plt": import_dependency("matplotlib.pyplot"), "np": import_dependency("numpy"), + "px": import_dependency("plotly.express"), + "go": import_dependency("plotly.graph_objects"), } return env diff --git a/pandasai/core/code_generation/code_cleaning.py b/pandasai/core/code_generation/code_cleaning.py index 244e653b1..3c6de674b 100644 --- a/pandasai/core/code_generation/code_cleaning.py +++ b/pandasai/core/code_generation/code_cleaning.py @@ -2,13 +2,11 @@ import os.path import re import uuid -from pathlib import Path import astor from pandasai.agent.state import AgentState from pandasai.constants import DEFAULT_CHART_DIRECTORY -from pandasai.core.code_execution.code_executor import CodeExecutor from pandasai.query_builders.sql_parser import SQLParser from ...exceptions import MaliciousQueryError @@ -146,9 +144,12 @@ def clean_code(self, code: str) -> str: tuple: Cleaned code as a string and a list of additional dependencies. """ code = self._replace_output_filenames_with_temp_chart(code) + code = self._replace_output_filenames_with_temp_json_chart(code) - # If plt.show is in the code, remove that line - code = re.sub(r"plt.show\(\)", "", code) + code = self._remove_make_dirs(code) + + # If plt.show or fig.show is in the code, remove that line + code = re.sub(r"\b(?:plt|fig)\.show\(\)", "", code) tree = ast.parse(code) new_body = [] @@ -180,3 +181,31 @@ def _replace_output_filenames_with_temp_chart(self, code: str) -> str: lambda m: f"{m.group(1)}{chart_path}{m.group(1)}", code, ) + + def _replace_output_filenames_with_temp_json_chart(self, code: str) -> str: + """ + Replace output file names with "temp_chart.json" (in case of usage of plotly). + """ + _id = uuid.uuid4() + chart_path = os.path.join(DEFAULT_CHART_DIRECTORY, f"temp_chart_{_id}.json") + chart_path = chart_path.replace("\\", "\\\\") + return re.sub( + r"""(['"])([^'"]*\.json)\1""", + lambda m: f"{m.group(1)}{chart_path}{m.group(1)}", + code, + ) + + def _remove_make_dirs(self, code: str) -> str: + """ + Remove any directory creation commands from the code. + """ + # Remove lines that create directories, except for the default chart directory DEFAULT_CHART_DIRECTORY + code_lines = code.splitlines() + cleaned_lines = [] + for line in code_lines: + if DEFAULT_CHART_DIRECTORY not in line and ( + "os.makedirs(" in line or "os.mkdir(" in line + ): + continue + cleaned_lines.append(line) + return "\n".join(cleaned_lines) diff --git a/pandasai/core/prompts/templates/generate_python_code_with_sql.tmpl b/pandasai/core/prompts/templates/generate_python_code_with_sql.tmpl index 53eb0e28e..c8c6f280d 100644 --- a/pandasai/core/prompts/templates/generate_python_code_with_sql.tmpl +++ b/pandasai/core/prompts/templates/generate_python_code_with_sql.tmpl @@ -29,4 +29,4 @@ At the end, declare "result" variable as a dictionary of type and value in the f Generate python code and return full updated code: -### Note: Use only relevant table for query and do aggregation, sorting, joins and grouby through sql query \ No newline at end of file +### Note: Use only relevant table for query and do aggregation, sorting, joins and group by through sql query \ No newline at end of file diff --git a/pandasai/core/prompts/templates/shared/output_type_template.tmpl b/pandasai/core/prompts/templates/shared/output_type_template.tmpl index c792693c4..be78693a0 100644 --- a/pandasai/core/prompts/templates/shared/output_type_template.tmpl +++ b/pandasai/core/prompts/templates/shared/output_type_template.tmpl @@ -1,5 +1,5 @@ {% if not output_type %} -type (possible values "string", "number", "dataframe", "plot"). Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" } +type (possible values "string", "number", "dataframe", "plot", "iplot"). No other type available. "plot" is when "matplotlib" is used; "iplot" when "plotly" is used. Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" } or { "type": "iplot", "value": "temp_chart.json" } {% elif output_type == "number" %} type (must be "number"), value must int. Example: { "type": "number", "value": 125 } {% elif output_type == "string" %} @@ -8,4 +8,6 @@ type (must be "string"), value must be string. Example: { "type": "string", "val type (must be "dataframe"), value must be pd.DataFrame or pd.Series. Example: { "type": "dataframe", "value": pd.DataFrame({...}) } {% elif output_type == "plot" %} type (must be "plot"), value must be string. Example: { "type": "plot", "value": "temp_chart.png" } +{% elif output_type == "iplot" %} +type (must be "iplot"), value must be string. Example: { "type": "iplot", "value": "temp_chart.json" } {% endif %} \ No newline at end of file diff --git a/pandasai/core/response/__init__.py b/pandasai/core/response/__init__.py index 4ac15d2b9..64b53286e 100644 --- a/pandasai/core/response/__init__.py +++ b/pandasai/core/response/__init__.py @@ -2,6 +2,7 @@ from .chart import ChartResponse from .dataframe import DataFrameResponse from .error import ErrorResponse +from .interactive_chart import InteractiveChartResponse from .number import NumberResponse from .parser import ResponseParser from .string import StringResponse @@ -10,6 +11,7 @@ "ResponseParser", "BaseResponse", "ChartResponse", + "InteractiveChartResponse", "DataFrameResponse", "NumberResponse", "StringResponse", diff --git a/pandasai/core/response/interactive_chart.py b/pandasai/core/response/interactive_chart.py new file mode 100644 index 000000000..9e0e46788 --- /dev/null +++ b/pandasai/core/response/interactive_chart.py @@ -0,0 +1,55 @@ +import json +import os +from typing import Any + +from .base import BaseResponse + + +class InteractiveChartResponse(BaseResponse): + def __init__(self, value: Any, last_code_executed: str): + super().__init__(value, "ichart", last_code_executed) + + self._validate() + + def _get_chart(self) -> dict: + if isinstance(self.value, dict): + return self.value + + if isinstance(self.value, str): + if os.path.exists(self.value): + with open(self.value, "rb") as f: + return json.load(f) + + return json.loads(self.value) + + raise ValueError( + "Invalid value type for InteractiveChartResponse. Expected dict or str." + ) + + def save(self, path: str): + img = self._get_chart() + with open(path, "w") as f: + json.dump(img, f) + + def __str__(self) -> str: + return self.value if isinstance(self.value, str) else json.dumps(self.value) + + def get_dict_image(self) -> dict: + return self._get_chart() + + def _validate(self): + if not isinstance(self.value, (dict, str)): + raise ValueError( + "InteractiveChartResponse value must be a dict or a str representing a file path." + ) + + # if a string, it can be a path to a file or a JSON string + if isinstance(self.value, str): + try: + json.loads(self.value) # Check if it's a valid JSON string + except json.JSONDecodeError: + # If it fails, check if it's a valid file path + if not os.path.exists(self.value): + raise ValueError( + "InteractiveChartResponse value must be a valid file path or a JSON string." + ) diff --git a/pandasai/core/response/parser.py b/pandasai/core/response/parser.py index f83fea313..70a9f2404 100644 --- a/pandasai/core/response/parser.py +++ b/pandasai/core/response/parser.py @@ -8,6 +8,7 @@ from .base import BaseResponse from .chart import ChartResponse from .dataframe import DataFrameResponse +from .interactive_chart import InteractiveChartResponse from .number import NumberResponse from .string import StringResponse @@ -26,6 +27,8 @@ def _generate_response(self, result: dict, last_code_executed: str = None): return DataFrameResponse(result["value"], last_code_executed) elif result["type"] == "plot": return ChartResponse(result["value"], last_code_executed) + elif result["type"] == "iplot": + return InteractiveChartResponse(result["value"], last_code_executed) else: raise InvalidOutputValueMismatch(f"Invalid output type: {result['type']}") @@ -72,4 +75,19 @@ def _validate_response(self, result: dict): "Invalid output: Expected a plot save path str but received an incompatible type." ) + elif result["type"] == "iplot": + if not isinstance(result["value"], (str, dict)): + raise InvalidOutputValueMismatch( + "Invalid output: Expected a plot save path str but received an incompatible type." + ) + + if isinstance(result["value"], dict): + return True + + path_to_plot_pattern = r"^(\/[\w.-]+)+(/[\w.-]+)*$|^[^\s/]+(/[\w.-]+)*$" + if not bool(re.match(path_to_plot_pattern, result["value"])): + raise InvalidOutputValueMismatch( + "Invalid output: Expected a plot save path str but received an incompatible type." + ) + return True diff --git a/poetry.lock b/poetry.lock index 9f6a1e918..0fff5a153 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.0.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. [[package]] name = "annotated-types" @@ -35,7 +35,7 @@ typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} [package.extras] doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21.0b1)"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1) ; python_version >= \"3.10\"", "uvloop (>=0.21.0b1) ; platform_python_implementation == \"CPython\" and platform_system != \"Windows\""] trio = ["trio (>=0.26.1)"] [[package]] @@ -206,7 +206,7 @@ files = [ [package.extras] dev = ["Pygments", "build", "chardet", "pre-commit", "pytest", "pytest-cov", "pytest-dependency", "ruff", "tomli", "twine"] hard-encoding-detection = ["chardet"] -toml = ["tomli"] +toml = ["tomli ; python_version < \"3.11\""] types = ["chardet (>=5.1.0)", "mypy", "pytest", "pytest-cov", "pytest-dependency"] [[package]] @@ -377,7 +377,7 @@ files = [ ] [package.extras] -toml = ["tomli"] +toml = ["tomli ; python_full_version <= \"3.11.0a6\""] [[package]] name = "cycler" @@ -500,7 +500,7 @@ description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" groups = ["dev"] -markers = "python_version < \"3.11\"" +markers = "python_version <= \"3.10\"" files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, @@ -524,7 +524,7 @@ files = [ [package.extras] docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4.1)"] testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"] -typing = ["typing-extensions (>=4.12.2)"] +typing = ["typing-extensions (>=4.12.2) ; python_version < \"3.11\""] [[package]] name = "fonttools" @@ -587,18 +587,18 @@ files = [ ] [package.extras] -all = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "fs (>=2.2.0,<3)", "lxml (>=4.0)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres", "pycairo", "scipy", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0)", "xattr", "zopfli (>=0.1.4)"] +all = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "fs (>=2.2.0,<3)", "lxml (>=4.0)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres ; platform_python_implementation == \"PyPy\"", "pycairo", "scipy ; platform_python_implementation != \"PyPy\"", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0) ; python_version <= \"3.12\"", "xattr ; sys_platform == \"darwin\"", "zopfli (>=0.1.4)"] graphite = ["lz4 (>=1.7.4.2)"] -interpolatable = ["munkres", "pycairo", "scipy"] +interpolatable = ["munkres ; platform_python_implementation == \"PyPy\"", "pycairo", "scipy ; platform_python_implementation != \"PyPy\""] lxml = ["lxml (>=4.0)"] pathops = ["skia-pathops (>=0.5.0)"] plot = ["matplotlib"] repacker = ["uharfbuzz (>=0.23.0)"] symfont = ["sympy"] -type1 = ["xattr"] +type1 = ["xattr ; sys_platform == \"darwin\""] ufo = ["fs (>=2.2.0,<3)"] -unicode = ["unicodedata2 (>=15.1.0)"] -woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] +unicode = ["unicodedata2 (>=15.1.0) ; python_version <= \"3.12\""] +woff = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "zopfli (>=0.1.4)"] [[package]] name = "h11" @@ -653,7 +653,7 @@ httpcore = "==1.*" idna = "*" [package.extras] -brotli = ["brotli", "brotlicffi"] +brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""] cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] @@ -706,7 +706,7 @@ files = [ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] enabler = ["pytest-enabler (>=2.2)"] @@ -1092,6 +1092,58 @@ pillow = ">=6.2.0" pyparsing = ">=2.3.1" python-dateutil = ">=2.7" +[[package]] +name = "narwhals" +version = "1.42.1" +description = "Extremely lightweight compatibility layer between dataframe libraries" +optional = false +python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.10\"" +files = [ + {file = "narwhals-1.42.1-py3-none-any.whl", hash = "sha256:7a270d44b94ccdb277a799ae890c42e8504c537c1849f195eb14717c6184977a"}, + {file = "narwhals-1.42.1.tar.gz", hash = "sha256:50a5635b11aeda98cf9c37e839fd34b0a24159f59a4dfae930290ad698320494"}, +] + +[package.extras] +cudf = ["cudf (>=24.10.0)"] +dask = ["dask[dataframe] (>=2024.8)"] +duckdb = ["duckdb (>=1.0)"] +ibis = ["ibis-framework (>=6.0.0)", "packaging", "pyarrow-hotfix", "rich"] +modin = ["modin"] +pandas = ["pandas (>=0.25.3)"] +polars = ["polars (>=0.20.3)"] +pyarrow = ["pyarrow (>=11.0.0)"] +pyspark = ["pyspark (>=3.5.0)"] +pyspark-connect = ["pyspark[connect] (>=3.5.0)"] +sqlframe = ["sqlframe (>=3.22.0)"] + +[[package]] +name = "narwhals" +version = "2.8.0" +description = "Extremely lightweight compatibility layer between dataframe libraries" +optional = false +python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "narwhals-2.8.0-py3-none-any.whl", hash = "sha256:6304856676ba4a79fd34148bda63aed8060dd6edb1227edf3659ce5e091de73c"}, + {file = "narwhals-2.8.0.tar.gz", hash = "sha256:52e0b22d54718264ae703bd9293af53b04abc995a1414908c3b807ba8c913858"}, +] + +[package.extras] +cudf = ["cudf (>=24.10.0)"] +dask = ["dask[dataframe] (>=2024.8)"] +duckdb = ["duckdb (>=1.1)"] +ibis = ["ibis-framework (>=6.0.0)", "packaging", "pyarrow-hotfix", "rich"] +modin = ["modin"] +pandas = ["pandas (>=1.1.3)"] +polars = ["polars (>=0.20.4)"] +pyarrow = ["pyarrow (>=13.0.0)"] +pyspark = ["pyspark (>=3.5.0)"] +pyspark-connect = ["pyspark[connect] (>=3.5.0)"] +sqlframe = ["sqlframe (>=3.22.0,!=3.39.3)"] + [[package]] name = "nodeenv" version = "1.9.1" @@ -1234,7 +1286,7 @@ files = [ numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.21.0", markers = "python_version == \"3.10\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -1358,7 +1410,7 @@ docs = ["furo", "olefile", "sphinx (>=7.3)", "sphinx-copybutton", "sphinx-inline fpx = ["olefile"] mic = ["olefile"] tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] -typing = ["typing-extensions"] +typing = ["typing-extensions ; python_version < \"3.10\""] xmp = ["defusedxml"] [[package]] @@ -1378,6 +1430,30 @@ docs = ["furo (>=2024.8.6)", "proselint (>=0.14)", "sphinx (>=8.0.2)", "sphinx-a test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=8.3.2)", "pytest-cov (>=5)", "pytest-mock (>=3.14)"] type = ["mypy (>=1.11.2)"] +[[package]] +name = "plotly" +version = "6.3.1" +description = "An open-source interactive data visualization library for Python" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "plotly-6.3.1-py3-none-any.whl", hash = "sha256:8b4420d1dcf2b040f5983eed433f95732ed24930e496d36eb70d211923532e64"}, + {file = "plotly-6.3.1.tar.gz", hash = "sha256:dd896e3d940e653a7ce0470087e82c2bd903969a55e30d1b01bb389319461bb0"}, +] + +[package.dependencies] +narwhals = ">=1.15.1" +packaging = "*" + +[package.extras] +dev = ["plotly[dev-optional]"] +dev-build = ["build", "jupyter", "plotly[dev-core]"] +dev-core = ["pytest", "requests", "ruff (==0.11.12)"] +dev-optional = ["anywidget", "colorcet", "fiona (<=1.9.6) ; python_version <= \"3.8\"", "geopandas", "inflect", "numpy", "orjson", "pandas", "pdfrw", "pillow", "plotly-geo", "plotly[dev-build]", "plotly[kaleido]", "polars[timezone]", "pyarrow", "pyshp", "pytz", "scikit-image", "scipy", "shapely", "statsmodels", "vaex ; python_version <= \"3.9\"", "xarray"] +express = ["numpy"] +kaleido = ["kaleido (>=1.0.0)"] + [[package]] name = "pluggy" version = "1.5.0" @@ -1481,7 +1557,7 @@ typing-extensions = ">=4.12.2" [package.extras] email = ["email-validator (>=2.0.0)"] -timezone = ["tzdata"] +timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == \"Windows\""] [[package]] name = "pydantic-core" @@ -1945,7 +2021,7 @@ description = "A lil' TOML parser" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "python_version < \"3.11\"" +markers = "python_version <= \"3.10\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, @@ -2040,7 +2116,7 @@ files = [ ] [package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\""] h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] @@ -2064,7 +2140,7 @@ platformdirs = ">=3.9.1,<5" [package.extras] docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] -test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8) ; platform_python_implementation == \"PyPy\" or platform_python_implementation == \"CPython\" and sys_platform == \"win32\" and python_version >= \"3.13\"", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10) ; platform_python_implementation == \"CPython\""] [[package]] name = "zipp" @@ -2080,14 +2156,14 @@ files = [ ] [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] enabler = ["pytest-enabler (>=2.2)"] -test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] +test = ["big-O", "importlib-resources ; python_version < \"3.9\"", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.8,<3.12" -content-hash = "1ee6133c27f712829a1606972a660f210112fcbf5f44193ccd77b179ab69e85d" +content-hash = "6fd75dfa8c4bd5bc74a6fc8900b7fc6bb9b05881458d1be83ba3baef750b48ad" diff --git a/pyproject.toml b/pyproject.toml index d14915923..f5a061445 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ seaborn = "^0.12.2" sqlglot = "^25.0.3" pyarrow = ">=14.0.1,<19.0.0" pyyaml = "^6.0.2" +plotly = "^6.1.1" [tool.poetry.group.dev] optional = true diff --git a/tests/unit_tests/core/code_execution/test_environment.py b/tests/unit_tests/core/code_execution/test_environment.py index 2e77198ab..06f098647 100644 --- a/tests/unit_tests/core/code_execution/test_environment.py +++ b/tests/unit_tests/core/code_execution/test_environment.py @@ -18,6 +18,7 @@ def test_get_environment_with_secure_mode(self, mock_import_dependency): self.assertIn("pd", env) self.assertIn("plt", env) self.assertIn("np", env) + self.assertIn("px", env) @patch("pandasai.core.code_execution.environment.import_dependency") def test_get_environment_without_secure_mode(self, mock_import_dependency): @@ -28,6 +29,7 @@ def test_get_environment_without_secure_mode(self, mock_import_dependency): self.assertIn("pd", env) self.assertIn("plt", env) self.assertIn("np", env) + self.assertIn("px", env) self.assertIsInstance(env["pd"], MagicMock) @patch("pandasai.core.code_execution.environment.importlib.import_module") diff --git a/tests/unit_tests/core/code_generation/test_code_cleaning.py b/tests/unit_tests/core/code_generation/test_code_cleaning.py index f11e72f56..71b85fa82 100644 --- a/tests/unit_tests/core/code_generation/test_code_cleaning.py +++ b/tests/unit_tests/core/code_generation/test_code_cleaning.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock from pandasai.agent.state import AgentState +from pandasai.constants import DEFAULT_CHART_DIRECTORY from pandasai.core.code_generation.code_cleaning import CodeCleaner from pandasai.dataframe.base import DataFrame from pandasai.exceptions import MaliciousQueryError @@ -106,6 +107,23 @@ def test_replace_output_filenames_with_temp_chart(self): ) self.assertRegex(code, expected_pattern) + def test_replace_output_filenames_with_temp_json_chart(self): + handler = self.cleaner + handler.context = MagicMock() + handler.context.config.save_charts = True + handler.context.logger = MagicMock() # Mock logger + handler.context.last_prompt_id = 123 + handler.context.config.save_charts_path = "/custom/path" + + code = 'some text "hello.json" more text' + + code = handler._replace_output_filenames_with_temp_json_chart(code) + + expected_pattern = re.compile( + r'some text "exports[/\\]+charts[/\\]+temp_chart_.*\.json" more text' + ) + self.assertRegex(code, expected_pattern) + def test_replace_output_filenames_with_temp_chart_windows_paths(self): handler = self.cleaner handler.context = MagicMock() @@ -158,6 +176,18 @@ def test_replace_output_filenames_with_temp_chart_empty_code(self): result, expected_code, f"Expected '{expected_code}', but got '{result}'" ) + def test_replace_output_filenames_with_temp_json_chart_empty_code(self): + handler = self.cleaner + + code = "" + expected_code = "" # It should remain empty, as no substitution is made + + result = handler._replace_output_filenames_with_temp_json_chart(code) + + self.assertEqual( + result, expected_code, f"Expected '{expected_code}', but got '{result}'" + ) + def test_replace_output_filenames_with_temp_chart_no_png(self): handler = self.cleaner @@ -170,6 +200,43 @@ def test_replace_output_filenames_with_temp_chart_no_png(self): result, expected_code, f"Expected '{expected_code}', but got '{result}'" ) + def test_replace_output_filenames_with_temp_json_chart_no_json(self): + handler = self.cleaner + + code = "some text without json" + + result = handler._replace_output_filenames_with_temp_json_chart(code) + + self.assertEqual(result, code, f"Expected '{code}', but got '{result}'") + + def test_remove_make_dirs(self): + handler = self.cleaner + + code = "os.makedirs('/some/path')\nplt.show()\nfig.show()" + expected_code = "plt.show()\nfig.show()" # Should remove the os.makedirs line + result = handler._remove_make_dirs(code) + self.assertEqual( + result, expected_code, f"Expected '{expected_code}', but got '{result}'" + ) + + code = "os.mkdir('/some/path')\nplt.show()\nfig.show()" + expected_code = "plt.show()\nfig.show()" # Should remove the os.mkdir line + result = handler._remove_make_dirs(code) + self.assertEqual( + result, expected_code, f"Expected '{expected_code}', but got '{result}'" + ) + + def test_do_not_remove_make_default_chart_dir(self): + handler = self.cleaner + + code = f"os.makedirs('{DEFAULT_CHART_DIRECTORY}')\nplt.show()\nfig.show()" + result = handler._remove_make_dirs(code) + self.assertEqual(result, code, f"Expected '{code}', but got '{result}'") + + code = f"os.mkdir('{DEFAULT_CHART_DIRECTORY}')\nplt.show()\nfig.show()" + result = handler._remove_make_dirs(code) + self.assertEqual(result, code, f"Expected '{code}', but got '{result}'") + if __name__ == "__main__": unittest.main() diff --git a/tests/unit_tests/helpers/test_optional_dependency.py b/tests/unit_tests/helpers/test_optional_dependency.py index 660acabd9..b14a5dcea 100644 --- a/tests/unit_tests/helpers/test_optional_dependency.py +++ b/tests/unit_tests/helpers/test_optional_dependency.py @@ -32,3 +32,4 @@ def test_env_for_necessary_deps(): assert "pd" in env assert "plt" in env assert "np" in env + assert "px" in env diff --git a/tests/unit_tests/helpers/test_responses.py b/tests/unit_tests/helpers/test_responses.py index 100898a0c..6b936e31f 100644 --- a/tests/unit_tests/helpers/test_responses.py +++ b/tests/unit_tests/helpers/test_responses.py @@ -4,11 +4,13 @@ from unittest.mock import MagicMock, patch import pandas as pd +import pytest from PIL import Image from pandasai.core.response import ( ChartResponse, DataFrameResponse, + InteractiveChartResponse, NumberResponse, StringResponse, ) @@ -55,6 +57,25 @@ def test_parse_valid_plot(self): self.assertEqual(response.last_code_executed, None) self.assertEqual(response.type, "chart") + def test_parse_valid_interactive_plot(self): + path = "path/to/interactive_plot.json" + # mock os.path.exists to return True for the plot path + with patch("os.path.exists", return_value=True) as mock_exists: + result = {"type": "iplot", "value": path} + response = self.response_parser.parse(result) + self.assertIsInstance(response, InteractiveChartResponse) + self.assertEqual(response.value, path) + self.assertEqual(response.last_code_executed, None) + self.assertEqual(response.type, "ichart") + + mock_exists.assert_called_once_with(path) + + def test_parse_invalid_interactive_plot_because_of_not_existing_file(self): + path = "path/to/interactive_plot.json" + with pytest.raises(ValueError): + result = {"type": "iplot", "value": path} + self.response_parser.parse(result) + def test_plot_img_show_triggered(self): result = { "type": "plot", diff --git a/tests/unit_tests/prompts/test_sql_prompt.py b/tests/unit_tests/prompts/test_sql_prompt.py index afcee67ed..03cba9fa2 100644 --- a/tests/unit_tests/prompts/test_sql_prompt.py +++ b/tests/unit_tests/prompts/test_sql_prompt.py @@ -21,7 +21,7 @@ class TestGeneratePythonCodeWithSQLPrompt: [ ( "", - """type (possible values "string", "number", "dataframe", "plot"). Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" }""", + """type (possible values "string", "number", "dataframe", "plot", "iplot"). No other type available. "plot" is when "matplotlib" is used; "iplot" when "plotly" is used. Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" } or { "type": "iplot", "value": "temp_chart.json" }""", ), ( "number", @@ -35,6 +35,10 @@ class TestGeneratePythonCodeWithSQLPrompt: "plot", """type (must be "plot"), value must be string. Example: { "type": "plot", "value": "temp_chart.png" }""", ), + ( + "iplot", + """type (must be "iplot"), value must be string. Example: { "type": "iplot", "value": "temp_chart.json" }""", + ), ( "string", """type (must be "string"), value must be string. Example: { "type": "string", "value": f"The highest salary is {highest_salary}." }""", @@ -103,5 +107,5 @@ def execute_sql_query(sql_query: str) -> pd.DataFrame Generate python code and return full updated code: -### Note: Use only relevant table for query and do aggregation, sorting, joins and grouby through sql query''' # noqa: E501 +### Note: Use only relevant table for query and do aggregation, sorting, joins and group by through sql query''' # noqa: E501 ) diff --git a/tests/unit_tests/response/test_interactive_chart_response.py b/tests/unit_tests/response/test_interactive_chart_response.py new file mode 100644 index 000000000..bb610ad3c --- /dev/null +++ b/tests/unit_tests/response/test_interactive_chart_response.py @@ -0,0 +1,87 @@ +import base64 +import io +import json + +import pytest + +from pandasai.core.response.interactive_chart import InteractiveChartResponse + + +@pytest.fixture +def sample_json(): + # Create a small test plotly dictionary + return { + "data": [ + { + "x": [1, 2, 3], + "y": [4, 5, 6], + "type": "scatter", + "mode": "lines+markers", + "marker": {"color": "red"}, + } + ], + "layout": { + "title": "Test Chart", + "xaxis": {"title": "X Axis"}, + "yaxis": {"title": "Y Axis"}, + }, + "config": { + "responsive": True, + "displayModeBar": True, + "showSendToCloud": False, + }, + "image": { + "width": 100, + "height": 100, + "format": "png", + "data": base64.b64encode(io.BytesIO(b"test_image_data").getvalue()).decode( + "utf-8" + ), + }, + } + + +@pytest.fixture +def interactive_chart_response(sample_json): + return InteractiveChartResponse(sample_json, "test_code") + + +def test_interactive_chart_response_initialization(interactive_chart_response): + assert interactive_chart_response.type == "ichart" + assert interactive_chart_response.last_code_executed == "test_code" + + +def test_get_interactive_chart_from_json(interactive_chart_response): + chart = interactive_chart_response._get_chart() + assert isinstance(chart, dict) + assert chart["image"]["width"] == 100 + + +def test_get_interactive_chart_from_string(sample_json): + response = InteractiveChartResponse(json.dumps(sample_json), "test_code") + chart = response._get_chart() + assert isinstance(chart, dict) + assert chart["image"]["width"] == 100 + + +def test_get_interactive_chart_from_unsupported_format(): + with pytest.raises(ValueError): + response = InteractiveChartResponse(1, "test_code") + response._get_chart() + + +def test_save_interactive_chart(interactive_chart_response, tmp_path): + output_path = tmp_path / "output.json" + interactive_chart_response.save(str(output_path)) + assert output_path.exists() + + +def test_get_dict_interactive_chart(interactive_chart_response): + chart_dict = interactive_chart_response.get_dict_image() + assert isinstance(chart_dict, dict) + assert "data" in chart_dict + assert "layout" in chart_dict + assert "config" in chart_dict + assert "image" in chart_dict + assert isinstance(chart_dict["image"], dict) + assert "data" in chart_dict["image"]