diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index 25cde4c..e2bddac 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -27,7 +27,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.12' - name: Install dependencies run: | diff --git a/requirements.txt b/requirements.txt index 205423b..a934978 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ -ai-edge-torch-nightly -tf_nightly -torchvision -segmentation-models-pytorch -albumentations -opencv-python -numpy +litert-torch==0.8.0 +torchvision==0.24.1 +segmentation-models-pytorch==0.5.0 +albumentations==2.0.8 +opencv-python==4.13.0.92 +numpy==2.4.2 diff --git a/train.py b/train.py index 170dac6..c5854ce 100644 --- a/train.py +++ b/train.py @@ -28,7 +28,7 @@ from urllib.request import urlretrieve import shutil -DATASET_VERSION = 'v2.0' +DATASET_VERSION = 'v2.1' DATASET_ZIP_URL = f'https://github.com/pynicolas/fairscan-dataset/releases/download/{DATASET_VERSION}/fairscan-dataset-{DATASET_VERSION}.zip' BUILD_DIR = "build" @@ -38,7 +38,7 @@ DATASET_ZIP_PATH = BUILD_DIR + "/dataset.zip" DATASET_PARENT_DIR = BUILD_DIR + "/dataset" DATASET_DIR = DATASET_PARENT_DIR + "/fairscan-dataset" -NB_EPOCHS = 25 +NB_EPOCHS = 35 if os.path.isdir(BUILD_DIR): shutil.rmtree(BUILD_DIR) @@ -181,7 +181,8 @@ def evaluate_encoder(encoder_name, model_save_path, device=torch.device('cpu')): end = time.time() print(f"- Epoch {epoch + 1}/{NB_EPOCHS}: train_loss={avg_train_loss:.4f} | Val Loss: {val_loss:.4f}" + - f" | Dice (cont): {dice_cont_mean:.4f} | Dice (disc): {dice_disc_mean:.4f} | {end - start:.1f} seconds") + f" | Dice (cont): {dice_cont_mean:.4f} | Dice (disc): {dice_disc_mean:.4f} | {end - start:.1f} seconds", + flush=True) if dice_disc_mean > best_dice: best_dice = dice_disc_mean @@ -232,8 +233,8 @@ def evaluate_encoder(encoder_name, model_save_path, device=torch.device('cpu')): # Convert to TFLite -import ai_edge_torch -from ai_edge_torch.generative.quantize import quant_recipes +import litert_torch +from litert_torch.generative.quantize import quant_recipes model = smp.DeepLabV3Plus( encoder_name=encoder, @@ -280,10 +281,10 @@ def representative_dataset(): yield (img,) # 3. quant_config -quant_config = quant_recipes.full_int8_dynamic_recipe() +quant_config = quant_recipes.full_dynamic_recipe() # 4. Conversion -edge_model_quantized = ai_edge_torch.convert( +edge_model_quantized = litert_torch.convert( wrapped_model, sample_args=sample_args, sample_kwargs=None,