Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 48 additions & 60 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,13 @@ parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

Expand All @@ -50,29 +47,36 @@ coverage.xml
*.py,cover
.hypothesis/
.pytest_cache/
test-results/
pytest-results/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/
# Virtual environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# PyBuilder
target/
# IDEs
.vscode/
.idea/
*.swp
*.swo
*~
.project
.pydevproject
.settings/
.DS_Store

# Claude settings
.claude/*

# Poetry
# Note: We intentionally do NOT ignore poetry.lock
dist/
__pypackages__/

# Jupyter Notebook
.ipynb_checkpoints
Expand All @@ -84,46 +88,30 @@ ipython_config.py
# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Model checkpoints and outputs
*.pth
*.pt
*.ckpt
checkpoints/
outputs/
results/
logs/

# Dataset files (if large)
# Uncomment if needed:
# datasets/
# data/

# Temporary files
*.tmp
*.temp
.tmp/
.temp/
4,253 changes: 4,253 additions & 0 deletions poetry.lock

Large diffs are not rendered by default.

99 changes: 99 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
[tool.poetry]
name = "mat"
version = "0.1.0"
description = "MAT: Mask-Aware Transformer for Large Hole Image Inpainting"
authors = ["Your Name <you@example.com>"]
readme = "README.md"
packages = [{include = "networks"}, {include = "losses"}, {include = "metrics"}, {include = "torch_utils"}, {include = "dnnlib"}, {include = "training"}, {include = "datasets"}]

[tool.poetry.dependencies]
python = "^3.8"
easydict = "*"
future = "*"
matplotlib = "*"
numpy = "*"
opencv-python = "*"
scikit-image = "*"
scipy = "*"
click = "*"
requests = "*"
tqdm = "*"
pyspng = "*"
ninja = "*"
imageio-ffmpeg = "0.4.3"
timm = "*"
psutil = "*"
scikit-learn = "*"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
pytest-cov = "^4.1.0"
pytest-mock = "^3.11.0"

[tool.poetry.scripts]
test = "pytest:main"
tests = "pytest:main"

[tool.pytest.ini_options]
minversion = "7.0"
testpaths = ["tests"]
python_files = ["test_*.py", "*_test.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = [
"-ra",
"--strict-markers",
"--strict-config",
"--cov=networks",
"--cov=losses",
"--cov=metrics",
"--cov=torch_utils",
"--cov=dnnlib",
"--cov=training",
"--cov=datasets",
"--cov-branch",
"--cov-report=term-missing:skip-covered",
"--cov-report=html:htmlcov",
"--cov-report=xml:coverage.xml",
"--cov-fail-under=80",
"-vv",
"--tb=short",
]
markers = [
"unit: Unit tests",
"integration: Integration tests",
"slow: Slow-running tests",
]

[tool.coverage.run]
source = ["networks", "losses", "metrics", "torch_utils", "dnnlib", "training", "datasets"]
omit = [
"*/tests/*",
"*/test_*.py",
"*/__pycache__/*",
"*/site-packages/*",
]

[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"def __repr__",
"if self.debug:",
"if __name__ == .__main__.:",
"raise NotImplementedError",
"pass",
"except ImportError:",
]
show_missing = true
precision = 2
skip_covered = false

[tool.coverage.html]
directory = "htmlcov"

[tool.coverage.xml]
output = "coverage.xml"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Tests package initialization
151 changes: 151 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""Shared pytest fixtures and configuration."""
import os
import shutil
import tempfile
from pathlib import Path
from typing import Generator, Dict, Any

import pytest
import numpy as np
import torch
from PIL import Image


@pytest.fixture
def temp_dir() -> Generator[Path, None, None]:
    """Yield a fresh temporary directory, removed after the test.

    The removal is wrapped in ``try/finally`` so teardown always runs,
    and ``ignore_errors=True`` prevents a locked or read-only file
    (common on Windows) from failing the whole test's teardown.
    """
    temp_path = tempfile.mkdtemp()
    try:
        yield Path(temp_path)
    finally:
        # Best-effort cleanup: never let a stubborn file fail teardown.
        shutil.rmtree(temp_path, ignore_errors=True)


@pytest.fixture
def mock_config() -> Dict[str, Any]:
    """Return a minimal nested configuration dict used by tests."""
    model_cfg = {
        "name": "test_model",
        "input_size": 256,
        "output_size": 256,
        "channels": 3,
    }
    training_cfg = {
        "batch_size": 4,
        "learning_rate": 0.001,
        "epochs": 10,
    }
    dataset_cfg = {
        "root": "/tmp/test_data",
        "train_split": 0.8,
        "val_split": 0.2,
    }
    return {
        "model": model_cfg,
        "training": training_cfg,
        "dataset": dataset_cfg,
    }


@pytest.fixture
def sample_image() -> np.ndarray:
    """Return a random 256x256 RGB image spanning the full uint8 range."""
    shape = (256, 256, 3)
    return np.random.randint(low=0, high=256, size=shape, dtype=np.uint8)


@pytest.fixture
def sample_mask() -> np.ndarray:
    """Return a 256x256 uint8 mask: 255 inside a central square, 0 elsewhere."""
    side = 256
    hole = slice(64, 192)  # the square "hole" region
    mask = np.zeros((side, side), dtype=np.uint8)
    mask[hole, hole] = 255
    return mask


@pytest.fixture
def sample_tensor_image() -> torch.Tensor:
    """Return a random normally-distributed image tensor of shape (1, 3, 256, 256)."""
    batch, channels, height, width = 1, 3, 256, 256
    return torch.randn(batch, channels, height, width)


@pytest.fixture
def sample_tensor_mask() -> torch.Tensor:
    """Return a (1, 1, 256, 256) float mask: 1.0 inside a central square, 0 elsewhere."""
    hole = slice(64, 192)  # the square "hole" region
    mask = torch.zeros(1, 1, 256, 256)
    mask[..., hole, hole] = 1.0
    return mask


@pytest.fixture
def mock_model():
    """Return a tiny stand-in model: a single 3x3 conv; the mask argument is ignored."""

    class _StubModel(torch.nn.Module):
        """Minimal nn.Module matching the (image, mask) call signature."""

        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)

        def forward(self, x, mask=None):
            # `mask` is accepted only for signature compatibility.
            return self.conv(x)

    return _StubModel()


@pytest.fixture
def mock_dataset_path(temp_dir: Path) -> Path:
    """Create a tiny on-disk dataset (5 image/mask PNG pairs) and return its root."""
    root = temp_dir / "dataset"
    images_dir = root / "images"
    masks_dir = root / "masks"
    for directory in (images_dir, masks_dir):
        directory.mkdir(parents=True)

    for idx in range(5):
        # Each image is a flat grayscale shade; each mask is all-black.
        shade = (idx * 50, idx * 50, idx * 50)
        Image.new('RGB', (256, 256), color=shade).save(images_dir / f"image_{idx}.png")
        Image.new('L', (256, 256), color=0).save(masks_dir / f"mask_{idx}.png")

    return root


@pytest.fixture
def device():
    """Return the CUDA device when available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


@pytest.fixture(autouse=True)
def reset_random_seeds():
    """Seed every RNG before each test so results are reproducible.

    Seeds Python's stdlib ``random`` (previously left unseeded, so any
    test using it was non-deterministic), NumPy, and PyTorch — including
    all CUDA devices when available.
    """
    import random  # local import keeps the fixture self-contained

    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)


@pytest.fixture
def capture_stdout(monkeypatch):
    """Redirect ``sys.stdout`` into an in-memory buffer and return the buffer.

    The monkeypatch fixture restores the real stdout automatically at
    teardown.
    """
    import io
    import sys

    buffer = io.StringIO()
    monkeypatch.setattr(sys, 'stdout', buffer)
    return buffer


# Markers for conditional test execution
def pytest_configure(config):
    """Register this project's custom markers with pytest."""
    custom_markers = (
        "gpu: mark test as requiring GPU",
        "slow: mark test as slow running",
    )
    for marker in custom_markers:
        config.addinivalue_line("markers", marker)


def pytest_collection_modifyitems(config, items):
    """Automatically skip GPU-marked tests when CUDA is unavailable."""
    if torch.cuda.is_available():
        return  # GPU present: run everything as collected
    gpu_skip = pytest.mark.skip(reason="GPU not available")
    for test_item in items:
        if "gpu" in test_item.keywords:
            test_item.add_marker(gpu_skip)
Loading