Add inference helpers & tests #57

Merged 50 commits on Jul 26, 2023
Changes shown below reflect the first 38 of the 50 commits.
Commits
c8ae3cb  Add inference helpers & tests (palp, Jul 23, 2023)
dbad98c  Support testing with hatch (palp, Jul 23, 2023)
7e8fbd7  fixes to hatch script (palp, Jul 23, 2023)
5333836  add inference test action (palp, Jul 23, 2023)
74d261d  change workflow trigger (palp, Jul 23, 2023)
e02de5a  widen trigger to test (palp, Jul 23, 2023)
5cc8f7c  revert changes to workflow triggers (palp, Jul 23, 2023)
927ffff  Install local python in action (palp, Jul 23, 2023)
f1eb786  Trigger on push again (palp, Jul 23, 2023)
4e9ffe5  fix python version (palp, Jul 23, 2023)
cc7e983  add CODEOWNERS and change triggers (palp, Jul 23, 2023)
29a39d0  Report tests results (palp, Jul 23, 2023)
06b9c76  update action versions (palp, Jul 23, 2023)
0703ef8  format (palp, Jul 23, 2023)
0cedb25  Fix typo and add refiner helper (palp, Jul 24, 2023)
9f27cc6  use a shared path loaded from a secret for checkpoints source (palp, Jul 25, 2023)
2ebd30a  typo fix (palp, Jul 25, 2023)
129422b  Merge remote-tracking branch 'base/main' into palp/model-tests (palp, Jul 25, 2023)
8086691  Use device from input and remove duplicated code (palp, Jul 25, 2023)
2ac4c50  PR feedback (palp, Jul 25, 2023)
81e1047  fix call to load_model_from_config (palp, Jul 25, 2023)
5648dce  Move model to gpu (palp, Jul 25, 2023)
ed78819  Refactor helpers (palp, Jul 26, 2023)
9283e34  cleanup (palp, Jul 26, 2023)
da77108  test refiner, prep for 1.0, align with metadata (palp, Jul 26, 2023)
7c26bde  fix paths on second load (palp, Jul 26, 2023)
b094614  deduplicate streamlit code (palp, Jul 26, 2023)
8ae8888  filenames (palp, Jul 26, 2023)
83d6c66  fixes (palp, Jul 26, 2023)
89c74b2  add pydantic to requirements (palp, Jul 26, 2023)
e5a6443  Merge pull request #1 from palp/inference (palp, Jul 26, 2023)
def2cd4  fix usage of `msg` in demo script (palp, Jul 26, 2023)
dd9a1a7  remove double text (palp, Jul 26, 2023)
00b8f10  run black (palp, Jul 26, 2023)
f77d9e5  fix streamlit sampling when returning latents (palp, Jul 26, 2023)
733dfb3  extract function for streamlit output (palp, Jul 26, 2023)
959a7ee  another fix for streamlit outputs (palp, Jul 26, 2023)
cca46d3  fix img2img in streamlit (palp, Jul 26, 2023)
31056b9  Make fp16 optional and fix device param (palp, Jul 26, 2023)
87bdf2a  Merge remote-tracking branch 'base/main' into palp/model-tests (palp, Jul 26, 2023)
42d11ff  PR feedback (palp, Jul 26, 2023)
733d38b  fix dict cast for dataclass (palp, Jul 26, 2023)
2cc5425  Merge remote-tracking branch 'base/main' into palp/model-tests (palp, Jul 26, 2023)
ba60896  run black, update ci script (palp, Jul 26, 2023)
f811542  cache pip dependencies on hosted runners, remove extra runs (palp, Jul 26, 2023)
7e86a69  install package in ci env (palp, Jul 26, 2023)
64b5c9d  fix cache path (palp, Jul 26, 2023)
f471736  PR cleanup (palp, Jul 26, 2023)
197a074  one more cleanup (palp, Jul 26, 2023)
30a7f79  don't cache, it filled up (palp, Jul 26, 2023)
1 change: 1 addition & 0 deletions .github/workflows/CODEOWNERS
@@ -0,0 +1 @@
.github @Stability-AI/infrastructure
34 changes: 34 additions & 0 deletions .github/workflows/test-inference.yml
@@ -0,0 +1,34 @@
name: Test inference

on:
  pull_request:
  push:
    branches:
      - main

jobs:
  test:
    name: "Test inference"
    # This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the environment
    if: github.repository == 'stability-ai/generative-models'
    runs-on: [self-hosted, slurm, g40]
    steps:
      - uses: actions/checkout@v3
      - name: "Symlink checkpoints"
        run: ln -s ${{secrets.SGM_CHECKPOINTS_PATH}} checkpoints
      - name: "Setup python"
        uses: actions/setup-python@v4
        with:
          python-version: "3.10"
- name: "Install Hatch"
run: pip install hatch
- name: "Run inference tests"
run: hatch run ci:test-inference --junit-xml test-results.xml
- name: Surface failing tests
if: always()
uses: pmeier/pytest-results-action@main
with:
path: test-results.xml
summary: true
display-options: fEX
fail-on-empty: true
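Note that the `--junit-xml test-results.xml` flag and the pmeier/pytest-results-action step work as a pair: pytest writes machine-readable results to that path, and the action reads them back to post a pass/fail summary on the run even when the test step itself fails (hence `if: always()`).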
16 changes: 16 additions & 0 deletions pyproject.toml
@@ -32,3 +32,19 @@ include = [

[tool.hatch.build.targets.wheel.force-include]
"./configs" = "sgm/configs"

[tool.hatch.envs.ci]
# Skip for now, since requirements.txt is used by scripts and includes the project
# This should be changed when dependencies are handled by Hatch
skip-install = true

dependencies = [
    "pytest"
]

[tool.hatch.envs.ci.scripts]
test-inference = [
    "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
"pip install -r requirements_pt2.txt",
"pytest -v tests/inference/test_inference.py {args}",
]
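Locally, the same entry point runs with `hatch run ci:test-inference`; extra arguments are forwarded to pytest through the `{args}` placeholder, so for example `hatch run ci:test-inference -m "not inference"` would deselect the GPU-bound tests registered in pytest.ini below.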
3 changes: 3 additions & 0 deletions pytest.ini
@@ -0,0 +1,3 @@
[pytest]
markers =
    inference: mark as inference test (deselect with '-m "not inference"')
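For illustration, a minimal sketch of how a test opts into this marker; the test name and body here are hypothetical, while the real suite lives in tests/inference/test_inference.py:

import pytest


@pytest.mark.inference
def test_sampling_smoke():
    # Hypothetical placeholder body; the real assertions live in
    # tests/inference/test_inference.py.
    assert True

With the marker applied, `pytest -m "not inference"` deselects the test, exactly as the description registered above suggests.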
1 change: 1 addition & 0 deletions requirements_pt13.txt
@@ -13,6 +13,7 @@ torchvision==0.14.1+cu117
torchmetrics
opencv-python==4.6.0.66
fairscale
pydantic
pytorch-lightning==1.8.5
fsspec
kornia==0.6.9
1 change: 1 addition & 0 deletions requirements_pt2.txt
@@ -13,6 +13,7 @@ torchmetrics
torchvision>=0.15.2
opencv-python==4.6.0.66
fairscale
pydantic
pytorch-lightning==2.0.1
fire
fsspec
24 changes: 19 additions & 5 deletions scripts/demo/sampling.py
@@ -1,6 +1,13 @@
import numpy as np
from pytorch_lightning import seed_everything
from scripts.demo.streamlit_helpers import *
+from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
+from sgm.inference.helpers import (
+    do_img2img,
+    do_sample,
+    get_unique_embedder_keys_from_conditioner,
+    perform_save_locally,
+)

SAVE_PATH = "outputs/demo/txt2img/"

@@ -130,6 +137,8 @@ def run_txt2img(

if st.button("Sample"):
st.write(f"**Model I:** {version}")
outputs = st.empty()
st.text("Sampling")
out = do_sample(
state["model"],
sampler,
Expand All @@ -143,6 +152,8 @@ def run_txt2img(
return_latents=return_latents,
filter=filter,
)
show_samples(out, outputs)

return out


@@ -174,6 +185,8 @@ def run_img2img(
    num_samples = num_rows * num_cols

    if st.button("Sample"):
+        outputs = st.empty()
+        st.text("Sampling")
        out = do_img2img(
            repeat(img, "1 ... -> n ...", n=num_samples),
            state["model"],
@@ -183,7 +196,9 @@
            force_uc_zero_embeddings=["txt"] if not is_legacy else [],
            return_latents=return_latents,
            filter=filter,
-        )
+            logger=st,
+        )
+        show_samples(out, outputs)
    return out


@@ -248,8 +263,6 @@ def apply_refiner(
    save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))

    state = init_st(version_dict)
-    if state["msg"]:
-        st.info(state["msg"])
    model = state["model"]

    is_legacy = version_dict["is_legacy"]
Expand All @@ -274,7 +287,6 @@ def apply_refiner(

    version_dict2 = VERSION2SPECS[version2]
    state2 = init_st(version_dict2)
-    st.info(state2["msg"])

    stage2strength = st.number_input(
        "**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
@@ -314,6 +326,7 @@
        samples_z = None

    if add_pipeline and samples_z is not None:
+        outputs = st.empty()
        st.write("**Running Refinement Stage**")
        samples = apply_refiner(
            samples_z,
@@ -323,7 +336,8 @@
            prompt=prompt,
            negative_prompt=negative_prompt if is_legacy else "",
            filter=filter,
-        )
+        )
+        show_samples(samples, outputs)

    if save_locally and samples is not None:
        perform_save_locally(save_path, samples)
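The recurring pattern in these hunks, `outputs = st.empty()` before the slow sampling call and `show_samples(..., outputs)` after it, reserves a placeholder in the page so results render in place rather than appending below the progress text. A rough sketch of what such a helper could look like, assuming `out` is either a batch of images or an `(images, latents)` tuple when `return_latents` is set; the PR's actual implementation lives in scripts/demo/streamlit_helpers.py:

import streamlit as st
from einops import rearrange


def show_samples(out, placeholder):
    # Hypothetical re-implementation for illustration only.
    samples = out[0] if isinstance(out, tuple) else out  # drop latents if present
    grid = rearrange(samples, "n c h w -> (n h) w c")  # stack samples vertically
    placeholder.image(grid.cpu().numpy(), clamp=True)  # fill the reserved slot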