diff --git a/Dockerfile b/Dockerfile
index 859e82b40..7dcafb340 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -12,28 +12,52 @@ FROM python:${PYVERSION}-${BASE_IMAGE}
COPY --from=jl /usr/local/julia /usr/local/julia
ENV PATH="/usr/local/julia/bin:${PATH}"
-# Install IPython and other useful libraries:
-RUN pip install --no-cache-dir ipython matplotlib
+# # Install font used for GUI
+# RUN mkdir -p /usr/local/share/fonts/IBM_Plex_Mono && \
+# curl -L https://github.com/IBM/plex/releases/download/v6.4.0/IBM-Plex-Mono.zip -o /tmp/IBM_Plex_Mono.zip && \
+# unzip /tmp/IBM_Plex_Mono.zip -d /usr/local/share/fonts/IBM_Plex_Mono && \
+# rm /tmp/IBM_Plex_Mono.zip
+# RUN fc-cache -f -v
-WORKDIR /pysr
+# Set up a new user named "user" with user ID 1000
+RUN useradd -m -u 1000 user
+USER user
+WORKDIR /home/user/
+ENV HOME=/home/user
+ENV PATH=/home/user/.local/bin:$PATH
-# Caches install (https://stackoverflow.com/questions/25305788/how-to-avoid-reinstalling-packages-when-building-docker-image-for-python-project)
-ADD ./requirements.txt /pysr/requirements.txt
-RUN pip3 install --no-cache-dir -r /pysr/requirements.txt
+RUN python -m venv $HOME/.venv
-# Install PySR:
-# We do a minimal copy so it doesn't need to rerun at every file change:
-ADD ./pyproject.toml /pysr/pyproject.toml
-ADD ./setup.py /pysr/setup.py
-ADD ./pysr /pysr/pysr
-RUN pip3 install --no-cache-dir .
+ENV PYTHON="${HOME}/.venv/bin/python"
+ENV PIP="${PYTHON} -m pip"
+ENV PATH="${HOME}/.venv/bin:${PATH}"
+
+WORKDIR $HOME/pysr
+
+# Install all requirements, and then PySR itself
+COPY --chown=user ./requirements.txt $HOME/pysr/requirements.txt
+RUN $PIP install --no-cache-dir -r $HOME/pysr/requirements.txt
+
+COPY --chown=user ./pyproject.toml $HOME/pysr/pyproject.toml
+COPY --chown=user ./setup.py $HOME/pysr/setup.py
+COPY --chown=user ./pysr $HOME/pysr/pysr
+RUN $PIP install --no-cache-dir ".[gui]"
# Install Julia pre-requisites:
-RUN python3 -c 'import pysr'
+RUN $PYTHON -c 'import pysr'
+
+COPY --chown=user ./gui/*.py $HOME/pysr/gui/
+
+EXPOSE 7860
+ENV GRADIO_ALLOW_FLAGGING=never \
+ GRADIO_NUM_PORTS=1 \
+ GRADIO_SERVER_NAME=0.0.0.0 \
+ GRADIO_THEME=huggingface \
+ SYSTEM=spaces
# metainformation
LABEL org.opencontainers.image.authors = "Miles Cranmer"
LABEL org.opencontainers.image.source = "https://github.com/MilesCranmer/PySR"
LABEL org.opencontainers.image.licenses = "Apache License 2.0"
-CMD ["ipython"]
+CMD ["/home/user/.venv/bin/python", "/home/user/pysr/gui/app.py"]
diff --git a/README.md b/README.md
index 67493c28d..3cb4504df 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,13 @@
+---
+title: PySR
+emoji: 🌍
+colorFrom: green
+colorTo: indigo
+sdk: docker
+pinned: false
+license: apache-2.0
+---
+
[//]: # (Logo:)
diff --git a/gui/app.py b/gui/app.py
new file mode 100644
index 000000000..dee4680f4
--- /dev/null
+++ b/gui/app.py
@@ -0,0 +1,300 @@
+from collections import OrderedDict
+
+import gradio as gr
+import numpy as np
+from data import TEST_EQUATIONS
+from gradio.components.base import Component
+from plots import plot_example_data, plot_pareto_curve
+from processing import processing, stop
+
+
+class ExampleData:
+    """Controls and live preview for a synthetic example dataset.
+
+    Lays out a preview plot next to the controls (equation, number of points,
+    noise level, seed). The preview is redrawn through `plot_example_data`
+    whenever one of the controls changes, and once when the page loads.
+    """
+
+    def __init__(self, demo: gr.Blocks) -> None:
+        with gr.Column(scale=1):
+            self.example_plot = gr.Plot()
+        with gr.Column(scale=1):
+            self.test_equation = gr.Radio(
+                TEST_EQUATIONS, value=TEST_EQUATIONS[0], label="Test Equation"
+            )
+            self.num_points = gr.Slider(
+                minimum=10,
+                maximum=1000,
+                value=200,
+                step=1,
+                label="Number of Data Points",
+            )
+            self.noise_level = gr.Slider(
+                minimum=0, maximum=1, value=0.05, label="Noise Level"
+            )
+            self.data_seed = gr.Number(value=0, label="Random Seed")
+
+        # The list order matches the positional parameters of
+        # `plot_example_data`; every control triggers the same redraw.
+        controls = [
+            self.test_equation,
+            self.num_points,
+            self.noise_level,
+            self.data_seed,
+        ]
+        for control in controls:
+            control.change(
+                plot_example_data,
+                controls,
+                self.example_plot,
+                show_progress=False,
+            )
+
+        # Draw the initial preview when the app is first loaded.
+        demo.load(plot_example_data, controls, self.example_plot)
+
+
+class UploadData:
+    """File-upload widget plus a note explaining which column is the target."""
+
+    def __init__(self) -> None:
+        self.file_input = gr.File(label="Upload a CSV File")
+        # Static help text shown next to the upload widget.
+        self.label = gr.Markdown(
+            value="The rightmost column of your CSV file will be used as the target variable."
+        )
+
+
+class Data:
+    """Two tabs for choosing the dataset: generated example data or a CSV upload."""
+
+    def __init__(self, demo: gr.Blocks) -> None:
+        with gr.Tab(label="Example Data"):
+            self.example_data = ExampleData(demo)
+        with gr.Tab(label="Upload Data"):
+            self.upload_data = UploadData()
+
+
+class BasicSettings:
+    """Widgets for the main search options.
+
+    Covers the operator sets, the number of iterations, the maximum
+    expression complexity and the parsimony coefficient.
+    """
+
+    def __init__(self) -> None:
+        binary_choices = ["+", "-", "*", "/", "^", "max", "min", "mod", "cond"]
+        unary_choices = [
+            "sin",
+            "cos",
+            "tan",
+            "exp",
+            "log",
+            "square",
+            "cube",
+            "sqrt",
+            "abs",
+            "erf",
+            "relu",
+            "round",
+            "sign",
+        ]
+
+        self.binary_operators = gr.CheckboxGroup(
+            choices=binary_choices,
+            value=["+", "-", "*", "/"],
+            label="Binary Operators",
+        )
+        self.unary_operators = gr.CheckboxGroup(
+            choices=unary_choices,
+            value=["sin"],
+            label="Unary Operators",
+        )
+        self.niterations = gr.Slider(
+            minimum=1, maximum=1000, value=40, step=1, label="Number of Iterations"
+        )
+        self.maxsize = gr.Slider(
+            minimum=7, maximum=100, value=20, step=1, label="Maximum Complexity"
+        )
+        self.parsimony = gr.Number(value=0.0032, label="Parsimony Coefficient")
+
+
+class AdvancedSettings:
+    """Widgets for the less commonly changed search options.
+
+    Covers population layout, cycles per iteration, the loss function,
+    adaptive parsimony scaling, the constant optimizer and batching.
+    """
+
+    def __init__(self) -> None:
+        loss_choices = ["L2DistLoss()", "L1DistLoss()", "LogitDistLoss()", "HuberLoss()"]
+        optimizer_choices = ["BFGS", "NelderMead"]
+
+        # Population layout.
+        self.populations = gr.Slider(
+            minimum=2, maximum=100, value=15, step=1, label="Number of Populations"
+        )
+        self.population_size = gr.Slider(
+            minimum=2, maximum=1000, value=33, step=1, label="Population Size"
+        )
+        self.ncycles_per_iteration = gr.Number(value=550, label="Cycles per Iteration")
+
+        # Loss and complexity penalty.
+        self.elementwise_loss = gr.Radio(
+            loss_choices, value="L2DistLoss()", label="Loss Function"
+        )
+        self.adaptive_parsimony_scaling = gr.Number(
+            value=20.0, label="Adaptive Parsimony Scaling"
+        )
+
+        # Constant optimization.
+        self.optimizer_algorithm = gr.Radio(
+            optimizer_choices, value="BFGS", label="Optimizer Algorithm"
+        )
+        self.optimizer_iterations = gr.Slider(
+            minimum=1, maximum=100, value=8, step=1, label="Optimizer Iterations"
+        )
+
+        # Batching.
+        self.batching = gr.Checkbox(value=False, label="Batching")
+        self.batch_size = gr.Slider(
+            minimum=2, maximum=1000, value=50, step=1, label="Batch Size"
+        )
+
+
+class GradioSettings:
+    """Widgets that configure the GUI itself rather than the search."""
+
+    def __init__(self) -> None:
+        self.plot_update_delay = gr.Slider(
+            minimum=1, maximum=100, value=3, label="Plot Update Delay"
+        )
+        # Passed on as `force_run`; `read_csv` uses it to allow very large uploads.
+        self.force_run = gr.Checkbox(value=False, label="Ignore Warnings")
+
+
+class Settings:
+    """Groups the basic, advanced and GUI settings panels into three tabs."""
+
+    def __init__(self):
+        with gr.Tab(label="Basic Settings"):
+            self.basic_settings = BasicSettings()
+        with gr.Tab(label="Advanced Settings"):
+            self.advanced_settings = AdvancedSettings()
+        with gr.Tab(label="Gradio Settings"):
+            self.gradio_settings = GradioSettings()
+
+
+class Results:
+    """Output widgets: two plot tabs, the equations table and a status box."""
+
+    def __init__(self):
+        with gr.Tab(label="Pareto Front"):
+            self.pareto = gr.Plot()
+        with gr.Tab(label="Predictions"):
+            self.predictions_plot = gr.Plot()
+
+        # Read-only table of the equations found so far.
+        table_headers = ["complexity", "loss", "equation"]
+        table_types = ["number", "number", "str"]
+        self.df = gr.Dataframe(
+            headers=table_headers,
+            datatype=table_types,
+            wrap=True,
+            column_widths=[75, 75, 200],
+            interactive=False,
+        )
+
+        # Status line ("Started!", "Running...", error messages, ...).
+        self.messages = gr.Textbox(value="", label="Messages", interactive=False)
+
+
+def flatten_attributes(
+    component_group, absolute_name: str, d: OrderedDict
+) -> OrderedDict:
+    """Collect every Gradio component reachable through public attributes.
+
+    Walks ``component_group.__dict__`` recursively. Each attribute that is a
+    `Component` is stored in ``d`` under its dotted path (for example
+    ``"interface.settings.basic_settings.maxsize"``); any other attribute is
+    searched in turn. Attributes whose name starts with ``_`` are skipped, and
+    a component object that is already in ``d`` is not added a second time.
+
+    Parameters
+    ----------
+    component_group
+        Object whose attributes are inspected. Objects without a ``__dict__``
+        contribute nothing.
+    absolute_name : str
+        Dotted path of ``component_group`` itself, used as the key prefix.
+    d : OrderedDict
+        Accumulator; it is updated in place and also returned.
+
+    Returns
+    -------
+    OrderedDict
+        The same ``d``, in attribute-definition order.
+    """
+    if not hasattr(component_group, "__dict__"):
+        return d
+
+    for name, elem in component_group.__dict__.items():
+        if name.startswith("_"):
+            # Private attribute
+            continue
+
+        # Don't duplicate any items. Compare by identity: `elem in d.values()`
+        # falls back to `==`, which may run an arbitrary `__eq__` on
+        # non-component attributes (some types raise when truth-tested).
+        if any(elem is seen for seen in d.values()):
+            continue
+
+        new_absolute_name = absolute_name + "." + name
+        if isinstance(elem, Component):
+            # Only add components to dict
+            d[new_absolute_name] = elem
+        else:
+            flatten_attributes(elem, new_absolute_name, d)
+
+    return d
+
+
+class AppInterface:
+    """Top-level layout of the app and the wiring of its events.
+
+    The left column holds the data and settings panels; the right column
+    holds the results and the Stop / Run buttons. Clicking Run calls the
+    function built by `create_processing_function` with every collected input
+    component; clicking Stop calls `stop`.
+    """
+
+    def __init__(self, demo: gr.Blocks) -> None:
+        with gr.Row():
+            with gr.Column(scale=2):
+                with gr.Row():
+                    self.data = Data(demo)
+                with gr.Row():
+                    self.settings = Settings()
+            with gr.Column(scale=2):
+                self.results = Results()
+                with gr.Row():
+                    with gr.Column(scale=1):
+                        self.stop = gr.Button(value="Stop")
+                    with gr.Column(scale=1, min_width=200):
+                        self.run = gr.Button()
+
+        # Redraw the Pareto plot whenever the equations table changes.
+        self.results.df.change(
+            plot_pareto_curve,
+            inputs=[self.results.df, self.settings.basic_settings.maxsize],
+            outputs=[self.results.pareto],
+            show_progress=False,
+        )
+
+        # Output widgets are never passed to the processing function as inputs.
+        ignore = ["df", "predictions_plot", "pareto", "messages"]
+        run_handler = create_processing_function(self, ignore=ignore)
+        all_components = flatten_attributes(self, "interface", OrderedDict())
+        run_inputs = [
+            component
+            for path, component in all_components.items()
+            if last_part(path) not in ignore
+        ]
+        run_outputs = [
+            self.results.df,
+            self.results.predictions_plot,
+            self.results.messages,
+        ]
+        self.run.click(
+            run_handler,
+            inputs=run_inputs,
+            outputs=run_outputs,
+            show_progress=True,
+        )
+        self.stop.click(stop)
+
+
+def last_part(k: str) -> str:
+    """Return the text after the final ``.`` in ``k`` (all of ``k`` if there is none)."""
+    _, _, tail = k.rpartition(".")
+    return tail
+
+
+def create_processing_function(interface: AppInterface, ignore=None):
+    """Build the Run-button handler that forwards widget values to `processing`.
+
+    The components of ``interface`` are collected with `flatten_attributes`;
+    the last segment of each dotted path becomes the keyword name under which
+    the matching value is passed to `processing`.
+
+    Parameters
+    ----------
+    interface : AppInterface
+        The fully constructed interface.
+    ignore : iterable of str, optional
+        Attribute names (last path segment) to leave out. Defaults to none.
+
+    Returns
+    -------
+    callable
+        A generator function ``f(*components)``. It expects the values in the
+        same order as the non-ignored components and yields whatever
+        `processing` yields.
+
+    Raises
+    ------
+    AssertionError
+        If two components share the same last path segment, or (when the
+        returned function is called) if the number of values does not match
+        the number of keys.
+    """
+    # `None` instead of a mutable `[]` default.
+    ignored = tuple(ignore) if ignore is not None else ()
+
+    d = flatten_attributes(interface, "interface", OrderedDict())
+    keys = [k for k in map(last_part, d) if k not in ignored]
+
+    # Keyword names must be unique, otherwise values would silently collide.
+    _, idx, counts = np.unique(keys, return_index=True, return_counts=True)
+    if np.any(counts > 1):
+        raise AssertionError("Bad keys: " + ",".join(np.array(keys)[idx[counts > 1]]))
+
+    def f(*components):
+        # Explicit raise rather than `assert`, so the check survives `python -O`.
+        if len(components) != len(keys):
+            raise AssertionError(
+                f"Expected {len(keys)} component values, got {len(components)}"
+            )
+        yield from processing(**dict(zip(keys, components)))
+
+    return f
+
+
+def main():
+    """Build the Gradio app and launch it in debug mode."""
+    demo = gr.Blocks(theme="default")
+    with demo:
+        # The interface registers its widgets on `demo` while being constructed.
+        AppInterface(demo)
+    demo.launch(debug=True)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/gui/data.py b/gui/data.py
new file mode 100644
index 000000000..02a6d1424
--- /dev/null
+++ b/gui/data.py
@@ -0,0 +1,44 @@
+import numpy as np
+import pandas as pd
+
+# Equation strings for the example dataset. `generate_data` rewrites the
+# function names to their `np.` equivalents and `^` to `**`, then evaluates
+# the string with `x` bound to the sampled inputs.
+TEST_EQUATIONS = ["sin(2*x)/x + 0.1*x"]
+
+
+def generate_data(s: str, num_points: int, noise_level: float, data_seed: int):
+    """Sample a noisy 1-D dataset from the equation string ``s``.
+
+    ``x`` is drawn uniformly from [-10, 10]; ``y`` is the equation evaluated
+    at ``x`` plus Gaussian noise with standard deviation ``noise_level``.
+
+    Parameters
+    ----------
+    s : str
+        Expression in the variable ``x``. ``sin``, ``cos``, ``exp``, ``log``
+        and ``tan`` are mapped to NumPy, and ``^`` to ``**``.
+    num_points : int
+        Number of samples. Integer-valued floats (as a UI slider may deliver)
+        are accepted.
+    noise_level : float
+        Standard deviation of the additive noise.
+    data_seed : int or None
+        Seed for the random state; integer-valued floats are accepted and
+        ``None`` lets NumPy pick a seed.
+
+    Returns
+    -------
+    (pandas.DataFrame, numpy.ndarray)
+        A one-column frame ``{"x": x}`` and the noisy targets.
+    """
+    # Widget values may arrive as floats; RandomState and `size=` need ints.
+    seed = None if data_seed is None else int(data_seed)
+    rstate = np.random.RandomState(seed)
+    x = rstate.uniform(-10, 10, int(num_points))
+    for k, v in {
+        "sin": "np.sin",
+        "cos": "np.cos",
+        "exp": "np.exp",
+        "log": "np.log",
+        "tan": "np.tan",
+        "^": "**",
+    }.items():
+        s = s.replace(k, v)
+    # NOTE(security): `eval` executes arbitrary Python. `s` must only ever come
+    # from a trusted, fixed list of equations, never from free-form user text.
+    # The explicit namespace limits the names the expression is meant to use.
+    y = eval(s, {"np": np, "x": x})
+    noise = rstate.normal(0, noise_level, y.shape)
+    y_noisy = y + noise
+    return pd.DataFrame({"x": x}), y_noisy
+
+
+def read_csv(file_input: str, force_run: bool):
+    """Load a CSV file and split it into features and target.
+
+    The last column is the target; all other columns are the features.
+
+    Parameters
+    ----------
+    file_input : str
+        Path of the CSV file.
+    force_run : bool
+        When true, files with more than 10,000 rows are accepted.
+
+    Returns
+    -------
+    (pandas.DataFrame, numpy.ndarray)
+        Feature frame and target array.
+
+    Raises
+    ------
+    ValueError
+        If the file has no rows, has a single column, or has more than
+        10,000 rows while ``force_run`` is false.
+    """
+    frame = pd.read_csv(file_input)
+    n_rows = len(frame)
+    n_cols = len(frame.columns)
+
+    # Sanity checks on the shape of the upload.
+    if n_rows == 0:
+        raise ValueError("The file is empty!")
+    if n_cols == 1:
+        raise ValueError("The file has only one column!")
+    too_large = n_rows > 10_000
+    if too_large and not force_run:
+        raise ValueError(
+            "You have uploaded a file with more than 10,000 rows. "
+            "This will take very long to run. "
+            "Please upload a subsample of the data, "
+            "or check the box 'Ignore Warnings'.",
+        )
+
+    target_column = frame.columns[-1]
+    targets = np.array(frame[target_column])
+    features = frame.drop([target_column], axis=1)
+    return features, targets
diff --git a/gui/plots.py b/gui/plots.py
new file mode 100644
index 000000000..21f97ead1
--- /dev/null
+++ b/gui/plots.py
@@ -0,0 +1,94 @@
+import matplotlib
+
+matplotlib.use("agg")
+
+import numpy as np
+import pandas as pd
+from matplotlib import pyplot as plt
+
+plt.ioff()
+plt.rcParams["font.family"] = "monospace"
+# plt.rcParams["font.family"] = [
+# "IBM Plex Mono",
+# # Fallback fonts:
+# "DejaVu Sans Mono",
+# "Courier New",
+# "monospace",
+# ]
+
+from data import generate_data
+
+
+def plot_pareto_curve(df: pd.DataFrame, maxsize: int):
+    """Plot loss against complexity on log-log axes.
+
+    Parameters
+    ----------
+    df : pandas.DataFrame
+        Equations table with "Complexity", "Loss" and "Equation" columns.
+    maxsize : int
+        Maximum complexity; sets the right-hand x limit.
+
+    Returns
+    -------
+    matplotlib.figure.Figure
+        The figure. It is left empty when ``df`` has no rows or no
+        "Equation" column.
+    """
+    from matplotlib.figure import Figure
+
+    # Build the figure directly instead of via `plt.subplots`: pyplot keeps a
+    # global reference to every figure it creates until it is closed, so a
+    # long-running server that plots on each update would accumulate them.
+    fig = Figure(figsize=(6, 6), dpi=100)
+    ax = fig.subplots()
+
+    if len(df) == 0 or "Equation" not in df.columns:
+        return fig
+
+    ax.loglog(
+        df["Complexity"],
+        df["Loss"],
+        marker="o",
+        linestyle="-",
+        color="#333f48",
+        linewidth=1.5,
+        markersize=6,
+    )
+
+    ax.set_xlim(0.5, maxsize + 1)
+    # Snap the y limits outwards to the nearest powers of two; the small
+    # offset keeps log2 finite when the best loss is exactly zero.
+    ytop = 2 ** (np.ceil(np.log2(df["Loss"].max())))
+    ybottom = 2 ** (np.floor(np.log2(df["Loss"].min() + 1e-20)))
+    ax.set_ylim(ybottom, ytop)
+
+    stylize_axis(ax)
+
+    ax.set_xlabel("Complexity")
+    ax.set_ylabel("Loss")
+    fig.tight_layout(pad=2)
+
+    return fig
+
+
+def plot_example_data(test_equation, num_points, noise_level, data_seed):
+    """Scatter-plot a freshly generated example dataset.
+
+    The four arguments are passed unchanged to `generate_data`.
+
+    Returns
+    -------
+    matplotlib.figure.Figure
+        Scatter of ``y`` against ``x``.
+    """
+    from matplotlib.figure import Figure
+
+    # Build the figure directly instead of via `plt.subplots`: pyplot keeps a
+    # global reference to every figure it creates until it is closed, and this
+    # function runs on every change of the example-data controls.
+    fig = Figure(figsize=(6, 6), dpi=100)
+    ax = fig.subplots()
+
+    X, y = generate_data(test_equation, num_points, noise_level, data_seed)
+    x = X["x"]
+
+    ax.scatter(x, y, alpha=0.7, edgecolors="w", s=50)
+
+    stylize_axis(ax)
+
+    ax.set_xlabel("x")
+    ax.set_ylabel("y")
+    fig.tight_layout(pad=2)
+
+    return fig
+
+
+def plot_predictions(y, ypred):
+    """Scatter-plot predicted values against true values.
+
+    Parameters
+    ----------
+    y : array-like
+        True target values (x axis).
+    ypred : array-like
+        Predicted values (y axis).
+
+    Returns
+    -------
+    matplotlib.figure.Figure
+        The scatter plot; empty inputs give empty axes.
+    """
+    from matplotlib.figure import Figure
+
+    # Build the figure directly instead of via `plt.subplots`: pyplot keeps a
+    # global reference to every figure it creates until it is closed, so
+    # repeated calls during a run would accumulate figures.
+    fig = Figure(figsize=(6, 6), dpi=100)
+    ax = fig.subplots()
+
+    ax.scatter(y, ypred, alpha=0.7, edgecolors="w", s=50)
+
+    stylize_axis(ax)
+
+    ax.set_xlabel("true")
+    ax.set_ylabel("prediction")
+    fig.tight_layout(pad=2)
+
+    return fig
+
+
+def stylize_axis(ax):
+    """Apply the shared look to ``ax``: dashed grid, open frame, outward ticks."""
+    ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
+
+    # Hide the top/right frame lines and push the remaining two outwards.
+    for hidden_side in ("top", "right"):
+        ax.spines[hidden_side].set_visible(False)
+    for visible_side in ("bottom", "left"):
+        ax.spines[visible_side].set_position(("outward", 10))
+
+    # Outward-pointing ticks; minor ticks are smaller than major ones.
+    tick_styles = (("major", 10, 5), ("minor", 8, 3))
+    for which, labelsize, length in tick_styles:
+        ax.tick_params(
+            axis="both",
+            which=which,
+            labelsize=labelsize,
+            direction="out",
+            length=length,
+        )
diff --git a/gui/processing.py b/gui/processing.py
new file mode 100644
index 000000000..74947bfc6
--- /dev/null
+++ b/gui/processing.py
@@ -0,0 +1,239 @@
+import multiprocessing as mp
+import os
+import tempfile
+import time
+from pathlib import Path
+from typing import Callable
+
+import numpy as np
+import pandas as pd
+from data import generate_data, read_csv
+from plots import plot_predictions
+
+
+def empty_df():
+    """Return an empty frame with "Equation", "Loss" and "Complexity" columns."""
+    column_names = ("Equation", "Loss", "Complexity")
+    return pd.DataFrame({column: [] for column in column_names})
+
+
+def pysr_fit(queue: mp.Queue, out_queue: mp.Queue):
+    """Worker loop that fits one PySR model per job.
+
+    Each job read from ``queue`` is a dict with keys ``"X"``, ``"y"`` and
+    ``"kwargs"`` (extra `PySRRegressor` arguments). After each fit, ``None``
+    is put on ``out_queue`` to signal completion. A ``None`` job ends the loop.
+    """
+    import pysr
+
+    # `iter(callable, sentinel)` keeps fetching jobs until `None` arrives.
+    for job in iter(queue.get, None):
+        features, targets, options = job["X"], job["y"], job["kwargs"]
+        regressor = pysr.PySRRegressor(
+            progress=False,
+            timeout_in_seconds=1000,
+            **options,
+        )
+        regressor.fit(features, targets)
+        out_queue.put(None)
+
+
+def pysr_predict(queue: mp.Queue, out_queue: mp.Queue):
+    """Worker loop that loads the current model snapshot and predicts with it.
+
+    Each job read from ``queue`` is a dict with keys ``"X"`` (features),
+    ``"equation_file"`` (path of the search's ``.csv`` file) and ``"index"``
+    (passed to ``model.predict``). The ``.csv.bkup`` and ``.pkl`` files next to
+    the equation file are copied first, so the model is loaded from a snapshot
+    rather than from files the fit process may be rewriting. For every
+    successful job a dict ``{"ypred": ..., "equations": ...}`` is put on
+    ``out_queue``. A ``None`` job ends the loop.
+    """
+    import shutil
+
+    while True:
+        args = queue.get()
+
+        if args is None:
+            break
+
+        X = args["X"]
+        equation_file = str(args["equation_file"])
+        index = args["index"]
+
+        equation_file_pkl = equation_file.replace(".csv", ".pkl")
+        equation_file_bkup = equation_file + ".bkup"
+
+        equation_file_copy = equation_file.replace(".csv", "_copy.csv")
+        equation_file_pkl_copy = equation_file.replace(".csv", "_copy.pkl")
+
+        # TODO: See if there is a way to get a lock on the file
+        # Copy with shutil rather than a shell `cp` string: paths containing
+        # spaces or shell metacharacters are handled, and failures are
+        # reported instead of being ignored. Still best-effort, as before.
+        for src, dst in (
+            (equation_file_bkup, equation_file_copy),
+            (equation_file_pkl, equation_file_pkl_copy),
+        ):
+            try:
+                shutil.copy(src, dst)
+            except OSError as e:
+                print(f"Could not copy {src} to {dst}: {e}")
+
+        # Note that we import pysr late in this process to avoid
+        # pre-compiling the code in two places at once
+        import pysr
+
+        try:
+            model = pysr.PySRRegressor.from_file(equation_file_pkl_copy, verbosity=0)
+        except pd.errors.EmptyDataError:
+            # NOTE(review): nothing is put on `out_queue` for this job, so a
+            # caller blocking on the reply would wait indefinitely.
+            continue
+
+        ypred = model.predict(X, index)
+
+        equations = model.equations_[["complexity", "loss", "equation"]].copy()
+
+        # Remove any row that has worse loss than previous row:
+        equations = equations[equations["loss"].cummin() == equations["loss"]]
+        # TODO: Why is this needed? Are rows not being removed?
+
+        # Capitalize the column names for display
+        equations.columns = ["Complexity", "Loss", "Equation"]
+        out_queue.put(dict(ypred=ypred, equations=equations))
+
+
+class ProcessWrapper:
+    """A started worker process together with its input and output queues.
+
+    ``target`` is called in the new process as ``target(queue, out_queue)``.
+    Both queues hold at most one item.
+    """
+
+    def __init__(self, target: Callable[[mp.Queue, mp.Queue], None]):
+        inbox = mp.Queue(maxsize=1)
+        outbox = mp.Queue(maxsize=1)
+        worker = mp.Process(target=target, args=(inbox, outbox))
+
+        self.queue = inbox
+        self.out_queue = outbox
+        self.process = worker
+        self.process.start()
+
+
+# Identifier of the most recently started run. `processing` stores its own id
+# here and stops once the value no longer matches; `stop` resets it to None.
+ACTIVE_PROCESS = None
+
+
+def _random_string():
+    """Return 16 random lowercase ASCII letters, used as a run identifier."""
+    # `list(...)` splits the alphabet into single letters. `str.split()` with
+    # no whitespace present returns the whole alphabet as one element, which
+    # made every call return the same 416-character string.
+    letters = list("abcdefghijklmnopqrstuvwxyz")
+    return "".join(np.random.choice(letters, 16))
+
+
+def processing(
+    *,
+    file_input,
+    force_run,
+    test_equation,
+    num_points,
+    noise_level,
+    data_seed,
+    niterations,
+    maxsize,
+    binary_operators,
+    unary_operators,
+    plot_update_delay,
+    parsimony,
+    populations,
+    population_size,
+    ncycles_per_iteration,
+    elementwise_loss,
+    adaptive_parsimony_scaling,
+    optimizer_algorithm,
+    optimizer_iterations,
+    batching,
+    batch_size,
+    **kwargs,
+):
+    """Run a PySR search in worker processes and stream intermediate results.
+
+    This is a generator. Every yielded item is a 3-tuple
+    ``(equations, predictions_figure, message)`` where ``equations`` is a
+    DataFrame with "Complexity", "Loss" and "Equation" columns (an invalid
+    upload yields the frame from `empty_df` instead), ``predictions_figure``
+    comes from `plot_predictions`, and ``message`` is a status string
+    ("Started!", "Running...", "Stopped.", "Done." or an error message).
+
+    Data comes from ``file_input`` (a CSV path) when it is not ``None``,
+    otherwise from `generate_data`. The remaining named arguments are passed
+    to the fit worker as `PySRRegressor` options. ``plot_update_delay`` and
+    any extra ``kwargs`` are accepted but not used.
+
+    The run ends when the fit worker reports completion, when
+    ``ACTIVE_PROCESS`` no longer holds this run's id (see `stop`), or when the
+    fit worker exits unexpectedly. Both worker processes and the temporary
+    directory are cleaned up however the generator ends.
+    """
+    import queue as queue_module
+    import shutil
+
+    global ACTIVE_PROCESS
+    cur_process = _random_string()
+    ACTIVE_PROCESS = cur_process
+
+    # Load the data before starting any worker, so an invalid upload does not
+    # leave processes behind.
+    if file_input is not None:
+        try:
+            X, y = read_csv(file_input, force_run)
+        except ValueError as e:
+            # In a generator a `return <value>` is discarded by the consumer;
+            # the error has to be yielded to become visible.
+            yield (empty_df(), plot_predictions([], []), str(e))
+            return
+    else:
+        X, y = generate_data(test_equation, num_points, noise_level, data_seed)
+
+    print("Starting PySR fit process")
+    writer = ProcessWrapper(pysr_fit)
+
+    print("Starting PySR predict process")
+    reader = ProcessWrapper(pysr_predict)
+
+    tmpdirname = tempfile.mkdtemp()
+    equation_file = Path(tmpdirname) / "hall_of_fame.csv"
+    pickle_file = Path(str(equation_file).replace(".csv", ".pkl"))
+
+    # Seconds to wait for a prediction reply before asking again. The predict
+    # worker sends no reply when it skips a snapshot.
+    reply_timeout = 60.0
+
+    try:
+        writer.queue.put(
+            dict(
+                X=X,
+                y=y,
+                kwargs=dict(
+                    niterations=niterations,
+                    maxsize=maxsize,
+                    binary_operators=binary_operators,
+                    unary_operators=unary_operators,
+                    equation_file=equation_file,
+                    parsimony=parsimony,
+                    populations=populations,
+                    population_size=population_size,
+                    ncycles_per_iteration=ncycles_per_iteration,
+                    elementwise_loss=elementwise_loss,
+                    adaptive_parsimony_scaling=adaptive_parsimony_scaling,
+                    optimizer_algorithm=optimizer_algorithm,
+                    optimizer_iterations=optimizer_iterations,
+                    batching=batching,
+                    batch_size=batch_size,
+                ),
+            )
+        )
+
+        last_yield = (
+            pd.DataFrame({"Complexity": [], "Loss": [], "Equation": []}),
+            plot_predictions([], []),
+            "Started!",
+        )
+
+        yield last_yield
+
+        request_sent_at = None  # time of the outstanding predict request
+        while writer.out_queue.empty():
+            if not writer.process.is_alive():
+                # The fit worker only leaves its loop on request; if it is
+                # gone, it crashed and will never report completion.
+                yield (*last_yield[:-1], "PySR fit process exited unexpectedly.")
+                return
+
+            if equation_file.exists() and pickle_file.exists():
+                if not reader.process.is_alive():
+                    print("Restarting PySR predict process")
+                    reader.process.join()
+                    reader = ProcessWrapper(pysr_predict)
+                    request_sent_at = None
+
+                now = time.monotonic()
+                if request_sent_at is None or now - request_sent_at > reply_timeout:
+                    try:
+                        reader.queue.put_nowait(
+                            dict(
+                                X=X,
+                                equation_file=equation_file,
+                                index=-1,
+                            )
+                        )
+                        request_sent_at = now
+                    except queue_module.Full:
+                        # A request is still waiting to be picked up.
+                        pass
+
+                # Poll instead of blocking, so a missing reply cannot hang the
+                # run and a stop request is noticed promptly.
+                try:
+                    out = reader.out_queue.get_nowait()
+                except queue_module.Empty:
+                    out = None
+
+                if out is not None:
+                    request_sent_at = None
+                    predictions = out["ypred"]
+                    equations = out["equations"]
+                    last_yield = (
+                        equations[["Complexity", "Loss", "Equation"]],
+                        plot_predictions(y, predictions),
+                        "Running...",
+                    )
+                    yield last_yield
+
+            if cur_process != ACTIVE_PROCESS:
+                # Kill both reader and writer
+                writer.process.kill()
+                reader.process.kill()
+                yield (*last_yield[:-1], "Stopped.")
+                return
+
+            time.sleep(0.1)
+
+        yield (*last_yield[:-1], "Done.")
+        return
+    finally:
+        # Both workers otherwise wait on their queues forever; end them and
+        # remove the temporary files however the run finished.
+        for worker in (writer, reader):
+            if worker.process.is_alive():
+                worker.process.kill()
+            worker.process.join()
+        shutil.rmtree(tmpdirname, ignore_errors=True)
+
+
+def stop():
+    """Request that the current run stop by clearing ``ACTIVE_PROCESS``.
+
+    A running `processing` generator compares its own id against this global
+    and ends once they differ.
+    """
+    global ACTIVE_PROCESS
+    ACTIVE_PROCESS = None
diff --git a/pyproject.toml b/pyproject.toml
index 4b0f45ea2..b13e54d93 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,6 +19,12 @@ classifiers = [
]
dynamic = ["dependencies"]
+[project.optional-dependencies]
+gui = [
+ "gradio>=4.0.0,<5.0.0",
+ "matplotlib>=3.0.0,<4.0.0",
+]
+
[tool.setuptools]
packages = ["pysr", "pysr._cli", "pysr.test"]
include-package-data = false
@@ -32,14 +38,14 @@ profile = "black"
[tool.rye]
dev-dependencies = [
- "pre-commit>=3.7.0",
- "ipython>=8.23.0",
+ "coverage>=7.5.3",
"ipykernel>=6.29.4",
- "mypy>=1.10.0",
+ "ipython>=8.23.0",
"jax[cpu]>=0.4.26",
- "torch>=2.3.0",
+ "mypy>=1.10.0",
"pandas-stubs>=2.2.1.240316",
- "types-pytz>=2024.1.0.20240417",
+ "pre-commit>=3.7.0",
+ "torch>=2.3.0",
"types-openpyxl>=3.1.0.20240428",
- "coverage>=7.5.3",
+ "types-pytz>=2024.1.0.20240417",
]
diff --git a/pysr/julia_import.py b/pysr/julia_import.py
index 0e032bee1..88ca9a124 100644
--- a/pysr/julia_import.py
+++ b/pysr/julia_import.py
@@ -4,6 +4,8 @@
from types import ModuleType
from typing import cast
+import juliapkg
+
# Check if JuliaCall is already loaded, and if so, warn the user
# about the relevant environment variables. If not loaded,
# set up sensible defaults.
@@ -36,6 +38,14 @@
):
os.environ[k] = os.environ.get(k, default)
+juliapkg.require_julia("~1.6.7, ~1.7, ~1.8, ~1.9, =1.10.0, ^1.10.3")
+juliapkg.add(
+ "SymbolicRegression",
+ "8254be44-1295-4e6a-a16d-46603ac705cb",
+ version="=0.24.5",
+)
+juliapkg.add("Serialization", "9e88b42a-f829-5b0c-bbe9-9e923198166b", version="1")
+
autoload_extensions = os.environ.get("PYSR_AUTOLOAD_EXTENSIONS")
if autoload_extensions is not None:
diff --git a/pysr/juliapkg.json b/pysr/juliapkg.json
deleted file mode 100644
index 045d79e30..000000000
--- a/pysr/juliapkg.json
+++ /dev/null
@@ -1,13 +0,0 @@
-{
- "julia": "~1.6.7, ~1.7, ~1.8, ~1.9, =1.10.0, ^1.10.3",
- "packages": {
- "SymbolicRegression": {
- "uuid": "8254be44-1295-4e6a-a16d-46603ac705cb",
- "version": "=0.24.5"
- },
- "Serialization": {
- "uuid": "9e88b42a-f829-5b0c-bbe9-9e923198166b",
- "version": "1"
- }
- }
-}
diff --git a/pysr/sr.py b/pysr/sr.py
index 0054ce502..ff1983bdc 100644
--- a/pysr/sr.py
+++ b/pysr/sr.py
@@ -957,6 +957,7 @@ def from_file(
feature_names_in: Optional[ArrayLike[str]] = None,
selection_mask: Optional[NDArray[np.bool_]] = None,
nout: int = 1,
+ verbosity=1,
**pysr_kwargs,
):
"""
@@ -986,6 +987,8 @@ def from_file(
Number of outputs of the model.
Not needed if loading from a pickle file.
Default is `1`.
+ verbosity : int
+ What verbosity level to use. 0 means minimal print statements.
**pysr_kwargs : dict
Any other keyword arguments to initialize the PySRRegressor object.
These will overwrite those stored in the pickle file.
@@ -1000,9 +1003,11 @@ def from_file(
pkl_filename = _csv_filename_to_pkl_filename(equation_file)
# Try to load model from .pkl
- print(f"Checking if {pkl_filename} exists...")
+ if verbosity > 0:
+ print(f"Checking if {pkl_filename} exists...")
if os.path.exists(pkl_filename):
- print(f"Loading model from {pkl_filename}")
+ if verbosity > 0:
+ print(f"Loading model from {pkl_filename}")
assert binary_operators is None
assert unary_operators is None
assert n_features_in is None
@@ -1022,10 +1027,11 @@ def from_file(
return model
# Else, we re-create it.
- print(
- f"{pkl_filename} does not exist, "
- "so we must create the model from scratch."
- )
+ if verbosity > 0:
+ print(
+ f"{pkl_filename} does not exist, "
+ "so we must create the model from scratch."
+ )
assert binary_operators is not None or unary_operators is not None
assert n_features_in is not None
diff --git a/requirements.txt b/requirements.txt
index 230f67dce..1f7c104b3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,7 @@ sympy>=1.0.0,<2.0.0
pandas>=0.21.0,<3.0.0
numpy>=1.13.0,<3.0.0
scikit_learn>=1.0.0,<2.0.0
+juliapkg==0.1.13
juliacall==0.9.20
click>=7.0.0,<9.0.0
setuptools>=50.0.0