
❓ [Question] torchscript int8 quantization degradation in recent versions #3173

Open
seymurkafkas opened this issue Sep 22, 2024 · 1 comment
Labels: question (Further information is requested)

@seymurkafkas (Contributor)

TS INT8 degradation in later versions

Hi all, I get degraded results after INT8 quantization with TorchScript, after updating my torch_tensorrt, torch, and tensorrt versions. I have listed the dependencies for both cases below; is this expected?

Earlier Version (Works Well):
Torch: 2.0.1
CUDA: 11.8
torch_tensorrt: 1.4.0
TensorRT: 8.5.3.1
GPU: A100
Python: 3.9

Later Version (Degradation in Results):
Torch: 2.4.0
CUDA: 12.1
torch_tensorrt: 2.4.0
TensorRT: 10.1.0
GPU: A100
Python: 3.11

Script (approximate reproduction, as I can't submit the model):

import torch
import time
from pathlib import Path
import PIL.Image
import torch_tensorrt

import torch_tensorrt.ptq  # in the later version this lives at torch_tensorrt.ts.ptq (see note below)
from torchvision.transforms.functional import to_tensor, center_crop
from torch.utils.data import Dataset, DataLoader

class CalibrationDataset(Dataset):
    def __init__(self, tile_size: int, model: torch.nn.Module, dtype: torch.dtype) -> None:
        self._tile_size = tile_size
        self._images = [f for f in Path("images").glob("**/*")]
        self._length = len(self._images)
        print("Dataset size:", self._length)
        self._model = model
        self._dtype = dtype

    def __len__(self) -> int:
        return self._length

    def _to_tensor(self, img_path: Path) -> torch.Tensor:
        pil_img = PIL.Image.open(img_path).convert("RGB")
        return to_tensor(pil_img).to(device="cuda", dtype=self._dtype).unsqueeze(0)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        print(f"CalibrationDataset called with {idx=}")
        input_file = self._images[idx]
        input_tensor = center_crop(self._to_tensor(input_file), output_size=self._tile_size)
        return input_tensor, self._model(input_tensor)



def compile_to_tensorrt_and_quantize() -> None:
    HALF = True
    dtype = torch.float16
    batch_size, tile_size = 1, 538

    # Proprietary upscaling model; cannot submit its code.
    model = ImageToImageModel.create(checkpoint="base", half=HALF, device=torch.device("cuda"))
    with torch.no_grad():
        calibration_dataset = CalibrationDataset(tile_size=tile_size, model=model, dtype=dtype)
        testing_dataloader = DataLoader(
            calibration_dataset, batch_size=4, shuffle=True, num_workers=0
        )

        calibrator = torch_tensorrt.ptq.DataLoaderCalibrator(
            testing_dataloader,
            cache_file="./calibration.cache",
            use_cache=False,
            algo_type=torch_tensorrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
            device=torch.device("cuda"),
        )
        # `dummy_input` is the compile-time example input; `inputs` is reused
        # below for tracing and benchmarking.
        dummy_input = torch.randn(1, 3, tile_size, tile_size, device=torch.device("cuda"), dtype=dtype)
        inputs = torch.randn(1, 3, tile_size, tile_size, device=torch.device("cuda"), dtype=dtype)
        torch_script_module = torch.jit.trace(model, example_inputs=inputs)

        with torch_tensorrt.logging.debug():
            trt_ts_module = torch_tensorrt.compile(
                torch_script_module,
                truncate_long_and_double=True,
                inputs=[dummy_input],
                enabled_precisions={torch.int8},
                calibrator=calibrator,
                device={
                    "device_type": torch_tensorrt.DeviceType.GPU,
                    "gpu_id": 0,
                    "dla_core": 0,
                    "allow_gpu_fallback": False,
                    "disable_tf32": False
                },
            )


        torch.jit.save(trt_ts_module, "trt_OLD.ts")

    print("Benchmark")
    times = []
    for _ in range(5):
        t1 = time.monotonic()
        out = trt_ts_module(inputs)
        print(out)
        torch.cuda.synchronize()
        times.append(time.monotonic() - t1)

    print(times)


if __name__ == "__main__":
    compile_to_tensorrt_and_quantize()

Note: in the later version, you need to switch import torch_tensorrt.ptq to import torch_tensorrt.ts.ptq; the rest of the script is identical.
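
For anyone running the script against both dependency sets, a small import shim works (my own workaround sketch, not an official API; I then reference ptq.DataLoaderCalibrator instead of the fully qualified name):

try:
    from torch_tensorrt.ts import ptq  # later layout (torch_tensorrt 2.x)
except ImportError:
    from torch_tensorrt import ptq  # earlier layout (torch_tensorrt 1.x)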

While the earlier versions work well (I get a quantized model whose results are close enough to the original model's), with the later version I get garbage outputs. Something appears to be wrong with the calibration: the output tensor values always fall within a small range (roughly 0.18 to 0.21), whereas they should take any value between -1 and 1. I'm posting the quantization script only approximately; unfortunately I cannot post the model details, as the model is proprietary.
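
This is roughly how I measure the degradation (a minimal sketch reusing model and trt_ts_module from the script above):

with torch.no_grad():
    x = torch.randn(1, 3, 538, 538, device="cuda", dtype=torch.float16)
    ref = model(x)            # original FP16 model
    quant = trt_ts_module(x)  # INT8-quantized TorchScript engine
    # In the later version, quant collapses to roughly [0.18, 0.21],
    # while ref spans roughly [-1, 1].
    print("ref   min/max:", ref.min().item(), ref.max().item())
    print("quant min/max:", quant.min().item(), quant.max().item())
    print("mean abs diff:", (ref - quant).abs().mean().item())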

Would appreciate any form of help :). I'd also be happy to submit a fix for the underlying issue (if one is present).

seymurkafkas added the question (Further information is requested) label on Sep 22, 2024
@seymurkafkas (Contributor, Author)

I have also tried the script with the following dependencies to bisect the issue:
Torch: 2.2.1
TensorRT: 8.6.1
torch_tensorrt: 2.2.0
Python: 3.11
CUDA: 12.1

With these dependencies, it also works as expected (good results).
