Skip to content

Commit 97490a5

Browse files
committed
feat(python): adding test scripts
1 parent 6c38d8a commit 97490a5

File tree

9 files changed

+3571
-0
lines changed

9 files changed

+3571
-0
lines changed

.gitattributes

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# SCM syntax highlighting & preventing 3-way merges
2+
pixi.lock merge=binary linguist-language=YAML linguist-generated=true

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,7 @@ trace-*.json
7070
# Generated by Tauri
7171
# will have schema files for capabilities auto-completion
7272
/gen/schemas
73+
74+
# pixi environments
75+
.pixi
76+
*.egg-info

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "python/dependencies/librosa"]
2+
path = python/dependencies/librosa
3+
url = https://github.com/librosa/librosa

cspell.config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ words:
66
- AIRI
77
- byteorder
88
- clippy
9+
- coreml
910
- cuda
11+
- directml
1012
- distil
1113
- dtolnay
1214
- DTYPE
@@ -16,6 +18,7 @@ words:
1618
- logprob
1719
- melfilters
1820
- mmaped
21+
- ndarray
1922
- onnx
2023
- probs
2124
- Resampler

hack/hftf

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#!/usr/bin/env python
2+
import sys
3+
import os
4+
from huggingface_hub import hf_hub_download
5+
from huggingface_hub.utils import RepositoryNotFoundError, EntryNotFoundError, LocalEntryNotFoundError
6+
7+
def get_hf_file_path(model_id, file_path):
    """
    Look up a file from a Hugging Face Hub repository in the local cache and
    return its path with any symbolic link resolved.

    The lookup passes ``local_files_only=True`` to ``hf_hub_download``, so it
    never triggers a download.

    Args:
        model_id (str): The ID of the model repository
            (e.g., "onnx-community/whisper-large-v3-turbo").
        file_path (str): The relative path to the file within the repository
            (e.g., "onnx/encoder_model.onnx").

    Returns:
        str: The path returned by ``os.path.realpath`` on success. On failure,
        a human-readable message that always starts with "Error:", so callers
        can detect every failure with ``result.startswith("Error:")``.
    """
    try:
        # Step 1: find the file in the cache without touching the network.
        symlink_path = hf_hub_download(
            repo_id=model_id,
            filename=file_path,
            local_files_only=True,
        )

        # Step 2: resolve the symbolic link to the actual file path.
        return os.path.realpath(symlink_path)

    except LocalEntryNotFoundError:
        return f"Error: File '{file_path}' not found in local cache for repo '{model_id}'. Try downloading it first."
    except RepositoryNotFoundError:
        return f"Error: Model repository '{model_id}' not found on the Hugging Face Hub."
    except EntryNotFoundError:
        # This error is less likely with local_files_only=True, but good to keep.
        return f"Error: File '{file_path}' not found in the repository '{model_id}'."
    except Exception as e:
        # Keep the "Error:" prefix here too; without it callers that test the
        # prefix would treat an unexpected failure as a valid path.
        return f"Error: An unexpected error occurred: {e}"
41+
42+
if __name__ == "__main__":
    # Expect exactly two positional arguments after the script name.
    if len(sys.argv) != 3:
        print("Usage: hftf <model_id> <file_path>")
        print("\nExample:")
        print(" hftf onnx-community/whisper-large-v3-turbo onnx/encoder_model.onnx")
        sys.exit(1)

    repo_id, filename = sys.argv[1:3]
    final_path = get_hf_file_path(repo_id, filename)

    # get_hf_file_path reports failures as message strings rather than raising.
    # Both prefixes are checked so that the catch-all "unexpected error"
    # message also produces a non-zero exit status.
    if final_path.startswith(("Error:", "An unexpected error occurred:")):
        # Send diagnostics to stderr so that `$(hftf ...)` in a shell never
        # captures an error message as if it were a path.
        print(final_path, file=sys.stderr)
        sys.exit(1)

    # Success: the resolved path is the only thing written to stdout.
    print(final_path)

pixi.lock

Lines changed: 3164 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pixi.toml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
[workspace]
2+
authors = ["Neko Ayaka <[email protected]>"]
3+
channels = ["conda-forge"]
4+
name = "candle-examples"
5+
platforms = ["osx-arm64", "linux-64", "win-64"]
6+
version = "0.1.0"
7+
8+
[dependencies]
9+
python = "3.12.*"
10+
pip = ">=25.1.1,<26"
11+
12+
[pypi-dependencies]
13+
setuptools = ">=80.9.0, <81"
14+
numpy = "==2.2"
15+
huggingface-hub = ">=0.33.2, <0.34"
16+
transformers = ">=4.53.1, <5"
17+
onnxruntime = ">=1.22.0, <2"
18+
torch = ">=2.7.1, <3"
19+
torchaudio = ">=2.7.1, <3"
20+
torchvision = ">=0.22.1, <0.23"
21+
matplotlib = ">=3.10.3, <4"
22+
librosa = { git = "https://github.com/librosa/librosa" }
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
from pathlib import Path
4+
from transformers import WhisperProcessor
5+
from os import path
6+
7+
# --- Configuration ---
# Model identifier. It is used to build the .npy file names and to load the
# tokenizer, so it must match the model that generated the data.
MODEL_ID = "onnx-community/lite-whisper-large-v3-ONNX"
# MODEL_ID = "onnx-community/whisper-large-v3-turbo"

# Folder that holds the saved .npy arrays.
INPUT_DIR = Path("verification_data")

# How many of the highest-scoring tokens the logits chart displays.
TOP_K_LOGITS = 20
17+
18+
def plot_mel_spectrogram(features, output_path):
    """Draw the mel spectrogram as a heat map, save it, and return the figure.

    Args:
        features: Array of input features. A leading batch axis of size one
            is removed before plotting.
        output_path: Destination passed to ``plt.savefig``.

    Returns:
        The matplotlib figure that was created.
    """
    # Strip a singleton batch axis; anything else is plotted as given.
    has_unit_batch = features.ndim == 3 and features.shape[0] == 1
    spectrogram = features.squeeze(0) if has_unit_batch else features

    fig, ax = plt.subplots(figsize=(12, 6))
    image = ax.imshow(
        spectrogram,
        aspect='auto',
        origin='lower',
        cmap='viridis',
        interpolation='none',
    )
    fig.colorbar(image, ax=ax, format='%+2.0f dB')

    ax.set_title("Input Mel Spectrogram")
    ax.set_xlabel("Time Steps")
    ax.set_ylabel("Mel Bins")

    plt.tight_layout()
    plt.savefig(output_path)
    print(f"Saved spectrogram plot to {output_path}")
    return fig
33+
34+
def plot_encoder_output(hidden_states, output_path):
    """Generates and saves a heat map of the encoder hidden states.

    Args:
        hidden_states: Array of encoder outputs. Assumed to be laid out as
            (batch, sequence, hidden) or (sequence, hidden) -- TODO confirm
            against the script that writes the .npy file.
        output_path: Destination passed to ``plt.savefig``.

    Returns:
        The matplotlib figure that was created.
    """
    if hidden_states.ndim == 3 and hidden_states.shape[0] == 1:
        hidden_states = hidden_states.squeeze(0)  # Remove batch dimension

    # imshow draws axis 0 along the y-axis and axis 1 along the x-axis.
    # Transpose a (sequence, hidden) matrix so the sequence runs along x and
    # the hidden dimension along y, which is what the axis labels below say.
    if hidden_states.ndim == 2:
        hidden_states = hidden_states.T

    fig, ax = plt.subplots(figsize=(12, 6))
    im = ax.imshow(hidden_states, aspect='auto', origin='lower', cmap='viridis', interpolation='none')
    fig.colorbar(im, ax=ax)
    ax.set_title("Encoder Hidden States")
    ax.set_xlabel("Sequence Length")
    ax.set_ylabel("Hidden Dimension")
    plt.tight_layout()
    plt.savefig(output_path)
    print(f"Saved encoder output plot to {output_path}")
    return fig
49+
50+
def plot_logits(logits, tokenizer, output_path):
    """Generates and saves a horizontal bar chart of the top K logits.

    Args:
        logits: Array indexed as ``logits[0, -1, :]``, i.e. the scores at the
            last sequence position of the first batch item are plotted.
        tokenizer: Object whose ``decode([token_id])`` returns the token text.
        output_path: Destination passed to ``plt.savefig``.

    Returns:
        The matplotlib figure that was created.
    """
    # We want the logits for the *next* token, i.e. the last position.
    last_token_logits = logits[0, -1, :]

    # Never request more bars than there are vocabulary entries.
    k = min(TOP_K_LOGITS, last_token_logits.shape[0])

    # argsort is ascending; reverse it so index 0 holds the largest logit.
    # Together with invert_yaxis() below this puts the best token at the top.
    top_k_indices = np.argsort(last_token_logits)[::-1][:k]
    top_k_values = last_token_logits[top_k_indices]

    # Decode the token IDs to human-readable strings (plain ints, not NumPy
    # scalars, are handed to the tokenizer).
    top_k_tokens = [tokenizer.decode([int(idx)]) for idx in top_k_indices]

    positions = np.arange(k)
    fig, ax = plt.subplots(figsize=(10, 8))
    bars = ax.barh(positions, top_k_values, color='skyblue')

    # Highlight the highest-logit token, which is the first bar after the
    # descending sort.
    if k > 0:
        bars[0].set_color('salmon')

    ax.set_yticks(positions)
    ax.set_yticklabels(top_k_tokens)
    ax.invert_yaxis()  # Display the highest value at the top
    ax.set_xlabel("Logit Value")
    ax.set_title(f"Top {k} Predicted Tokens (First Decoder Step)")

    # Add the logit values as text on the bars.
    for bar in bars:
        width = bar.get_width()
        # A negative bar extends left from zero, so anchor its label at the
        # zero baseline instead of inside the bar.
        label_x_pos = max(width, 0)
        ax.text(label_x_pos, bar.get_y() + bar.get_height() / 2, f' {width:.2f}',
                va='center', ha='left')

    plt.tight_layout()
    plt.savefig(output_path)
    print(f"Saved logits plot to {output_path}")
    return fig
89+
90+
91+
def main():
    """Read the saved arrays, render the three plots, and display them."""
    # Guard clause: nothing to do without the data directory.
    if not INPUT_DIR.is_dir():
        print(f"Error: Directory '{INPUT_DIR}' not found. Please run the data generation script first.")
        return

    # --- Load Data ---
    # All three files share the same prefix derived from the model id.
    file_prefix = path.basename(MODEL_ID)
    array_names = ("input_features", "encoder_output", "step_0_logits")
    try:
        input_features, encoder_output, step_0_logits = (
            np.load(INPUT_DIR / f"{file_prefix}_{array_name}.npy")
            for array_name in array_names
        )
    except FileNotFoundError as e:
        print(f"Error: Missing data file - {e}. Please ensure all .npy files exist in '{INPUT_DIR}'.")
        return

    print("Successfully loaded all .npy files.")

    # --- Load Tokenizer ---
    # Needed to turn the logit indices into readable token text.
    print(f"Loading tokenizer for {MODEL_ID}...")
    tokenizer = WhisperProcessor.from_pretrained(MODEL_ID).tokenizer
    print("Tokenizer loaded.")

    # --- Generate Plots ---
    # Images are written into a "plots" folder, created on demand.
    plots_dir = Path("plots")
    plots_dir.mkdir(exist_ok=True)

    plot_mel_spectrogram(input_features, plots_dir / "mel_spectrogram.png")
    plot_encoder_output(encoder_output, plots_dir / "encoder_output.png")
    plot_logits(step_0_logits, tokenizer, plots_dir / "step_0_logits.png")

    # --- Show Plots ---
    # Opens the figures created above in interactive windows.
    print("\nDisplaying plots. Close the plot windows to exit the script.")
    plt.show()
129+
130+
131+
if __name__ == "__main__":
    # Requires matplotlib (installable with `pip install matplotlib`).
    main()

0 commit comments

Comments
 (0)