Skip to content

Commit

Permalink
test: do not use staging for cache tests
Browse files Browse the repository at this point in the history
The HF_ENDPOINT variable is not always taken into account when using the
huggingface_hub client depending on the order of imports.
This modifies the tests to create temporary dorectories under the testing
user account instead.
  • Loading branch information
dacorvo committed Feb 6, 2025
1 parent c510221 commit 836271c
Showing 1 changed file with 6 additions and 13 deletions.
19 changes: 6 additions & 13 deletions tests/cache/test_neuronx_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
import socket
import subprocess
from tempfile import TemporaryDirectory
from time import time

import PIL
import pytest
import torch
from huggingface_hub import HfApi
from transformers import AutoTokenizer
from transformers.testing_utils import ENDPOINT_STAGING

from optimum.neuron import (
NeuronModelForCausalLM,
Expand All @@ -37,27 +37,21 @@


@pytest.fixture
def cache_repos(staging):
def cache_repos():
# Setup: create temporary Hub repository and local cache directory
token = staging["token"]
user = staging["user"]
api = HfApi(endpoint=ENDPOINT_STAGING, token=token)
api = HfApi()
hostname = socket.gethostname()
cache_repo_id = f"{user}/{hostname}-optimum-neuron-cache"
if api.repo_exists(cache_repo_id):
api.delete_repo(cache_repo_id)
cache_repo_id = f"{hostname}-{time()}-optimum-neuron-cache"
cache_repo_id = api.create_repo(cache_repo_id, private=True).repo_id
cache_dir = TemporaryDirectory()
cache_path = cache_dir.name
# Modify environment to force neuronx cache to use temporary caches
previous_env = {}
env_vars = ["NEURON_COMPILE_CACHE_URL", "CUSTOM_CACHE_REPO", "HF_ENDPOINT", "HF_TOKEN"]
env_vars = ["NEURON_COMPILE_CACHE_URL", "CUSTOM_CACHE_REPO"]
for var in env_vars:
previous_env[var] = os.environ.get(var)
os.environ["NEURON_COMPILE_CACHE_URL"] = cache_path
os.environ["CUSTOM_CACHE_REPO"] = cache_repo_id
os.environ["HF_ENDPOINT"] = ENDPOINT_STAGING
os.environ["HF_TOKEN"] = token
yield (cache_path, cache_repo_id)
# Teardown
api.delete_repo(cache_repo_id)
Expand Down Expand Up @@ -173,8 +167,7 @@ def check_traced_cache_entry(cache_path):


def assert_local_and_hub_cache_sync(cache_path, cache_repo_id):
# Since created models are public on the staging endpoint we don't need a token
api = HfApi(endpoint=ENDPOINT_STAGING)
api = HfApi()
remote_files = api.list_repo_files(cache_repo_id)
local_files = get_local_cached_files(cache_path)
for file in local_files:
Expand Down

0 comments on commit 836271c

Please sign in to comment.