Skip to content

Commit

Permalink
Replace Mypy with Pyright (#206)
Browse files Browse the repository at this point in the history
Signed-off-by: Mattt Zmuda <mattt@replicate.com>
mattt authored Nov 27, 2023
1 parent 7926e2b commit 5f7ae72
Showing 12 changed files with 45 additions and 39 deletions.
1 change: 0 additions & 1 deletion .vscode/extensions.json
Original file line number Diff line number Diff line change
@@ -3,6 +3,5 @@
"charliermarsh.ruff",
"ms-python.python",
"ms-python.vscode-pylance",
"ms-python.mypy-type-checker"
]
}
5 changes: 1 addition & 4 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -15,16 +15,13 @@
}
},
"python.languageServer": "Pylance",
"python.analysis.typeCheckingMode": "basic",
"python.testing.pytestArgs": [
"-vvv",
"python"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"mypy-type-checker.args": [
"--show-column-numbers",
"--no-pretty"
],
"ruff.lint.args": [
"--config=pyproject.toml"
],
7 changes: 1 addition & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -17,8 +17,8 @@ dependencies = [
"typing_extensions>=4.5.0",
]
optional-dependencies = { dev = [
"mypy",
"pylint",
"pyright",
"pytest",
"pytest-asyncio",
"pytest-recording",
@@ -39,11 +39,6 @@ packages = ["replicate"]
[tool.setuptools.package-data]
"replicate" = ["py.typed"]

[tool.mypy]
plugins = "pydantic.mypy"
exclude = ["tests/"]
enable_incomplete_feature = ["Unpack"]

[tool.pylint.main]
disable = [
"C0301", # Line too long
21 changes: 16 additions & 5 deletions replicate/collection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, Iterator, List, Optional, Union, overload

from typing_extensions import deprecated

@@ -32,13 +32,24 @@ def id(self) -> str:
"""
return self.slug

def __iter__(self): # noqa: ANN204
return iter(self.models)
def __iter__(self) -> Iterator[Model]:
if self.models is not None:
return iter(self.models)
return iter([])

@overload
def __getitem__(self, index: int) -> Optional[Model]:
...

def __getitem__(self, index) -> Optional[Model]:
@overload
def __getitem__(self, index: slice) -> Optional[List[Model]]:
...

def __getitem__(
self, index: Union[int, slice]
) -> Union[Optional[Model], Optional[List[Model]]]:
if self.models is not None:
return self.models[index]

return None

def __len__(self) -> int:
6 changes: 3 additions & 3 deletions replicate/json.py
Original file line number Diff line number Diff line change
@@ -31,10 +31,10 @@ def encode_json(
if isinstance(obj, io.IOBase):
return upload_file(obj)
if HAS_NUMPY:
if isinstance(obj, np.integer):
if isinstance(obj, np.integer): # type: ignore
return int(obj)
if isinstance(obj, np.floating):
if isinstance(obj, np.floating): # type: ignore
return float(obj)
if isinstance(obj, np.ndarray):
if isinstance(obj, np.ndarray): # type: ignore
return obj.tolist()
return obj
2 changes: 1 addition & 1 deletion replicate/pagination.py
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@
pass


class Page(pydantic.BaseModel, Generic[T]):
class Page(pydantic.BaseModel, Generic[T]): # type: ignore
"""
A page of results from the API.
"""
2 changes: 1 addition & 1 deletion replicate/resource.py
Original file line number Diff line number Diff line change
@@ -10,7 +10,7 @@
from replicate.client import Client


class Resource(pydantic.BaseModel):
class Resource(pydantic.BaseModel): # type: ignore
"""
A base class for representing a single object on the server.
"""
14 changes: 7 additions & 7 deletions replicate/stream.py
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@
from replicate.prediction import Predictions


class ServerSentEvent(pydantic.BaseModel):
class ServerSentEvent(pydantic.BaseModel): # type: ignore
"""
A server-sent event.
"""
@@ -136,10 +136,10 @@ def __iter__(self) -> Iterator[ServerSentEvent]:
if sse is not None:
if sse.event == "done":
return
elif sse.event == "error":
if sse.event == "error":
raise RuntimeError(sse.data)
else:
yield sse

yield sse

async def __aiter__(self) -> AsyncIterator[ServerSentEvent]:
decoder = EventSource.Decoder()
@@ -149,10 +149,10 @@ async def __aiter__(self) -> AsyncIterator[ServerSentEvent]:
if sse is not None:
if sse.event == "done":
return
elif sse.event == "error":
if sse.event == "error":
raise RuntimeError(sse.data)
else:
yield sse

yield sse


def stream(
8 changes: 6 additions & 2 deletions replicate/training.py
Original file line number Diff line number Diff line change
@@ -231,6 +231,8 @@ def create( # type: ignore
Create a new training using the specified model version as a base.
"""

url = None

# Support positional arguments for backwards compatibility
if args:
if shorthand := args[0] if len(args) > 0 else None:
@@ -245,12 +247,12 @@ def create( # type: ignore
params["webhook_completed"] = args[4]
if len(args) > 5:
params["webhook_events_filter"] = args[5]

elif model and version:
url = _create_training_url_from_model_and_version(model, version)
elif model is None and isinstance(version, str):
url = _create_training_url_from_shorthand(version)
else:

if not url:
raise ValueError("model and version or shorthand version must be specified")

body = _create_training_body(input, **params)
@@ -376,6 +378,8 @@ def _create_training_url_from_model_and_version(
owner, name = model.owner, model.name
elif isinstance(model, tuple):
owner, name = model[0], model[1]
else:
raise ValueError("model must be a Model or a tuple of (owner, name)")

if isinstance(version, Version):
version_id = version.id
12 changes: 7 additions & 5 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -37,10 +37,8 @@ mccabe==0.7.0
# via pylint
multidict==6.0.4
# via yarl
mypy==1.4.1
# via replicate (pyproject.toml)
mypy-extensions==1.0.0
# via mypy
nodeenv==1.8.0
# via pyright
packaging==23.1
# via
# pytest
@@ -55,6 +53,8 @@ pydantic-core==2.3.0
# via pydantic
pylint==3.0.2
# via replicate (pyproject.toml)
pyright==1.1.337
# via replicate (pyproject.toml)
pytest==7.4.0
# via
# pytest-asyncio
@@ -79,7 +79,6 @@ tomlkit==0.12.1
# via pylint
typing-extensions==4.7.1
# via
# mypy
# pydantic
# pydantic-core
# replicate (pyproject.toml)
@@ -89,3 +88,6 @@ wrapt==1.15.0
# via vcrpy
yarl==1.9.2
# via vcrpy

# The following packages are considered to be unsafe in a requirements file:
# setuptools
4 changes: 2 additions & 2 deletions script/lint
Original file line number Diff line number Diff line change
@@ -4,8 +4,8 @@ set -e

STATUS=0

echo "Running mypy"
python -m mypy replicate || STATUS=$?
echo "Running pyright"
python -m pyright replicate || STATUS=$?
echo ""

echo "Running pylint"
2 changes: 0 additions & 2 deletions script/setup
Original file line number Diff line number Diff line change
@@ -3,5 +3,3 @@
set -e

python -m pip install -r requirements.txt -r requirements-dev.txt .

yes | python -m mypy --install-types replicate || true

0 comments on commit 5f7ae72

Please sign in to comment.