From 830f546e9afa9d7d9135a40fed8773a6c78c61a0 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <209825114+claude[bot]@users.noreply.github.com> Date: Sat, 26 Jul 2025 21:37:06 +0000 Subject: [PATCH] Fix PyTorch to MLX conversion syntax errors in 7 architectures MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fixed type annotation syntax errors: tensor:, mx.array -> tensor: mx.array - Fixed missing commas in kwargs.get() calls - Fixed unterminated string literals in assert statements - Fixed unmatched parentheses in function definitions - Fixed missing commas in function parameter lists - Fixed Conv1d call syntax and return statements Affected files: - delta_net_abrgf_mlx.py (comprehensive fixes) - delta_net_acfg_mlx.py (type annotations, kwargs, asserts) - delta_net_adgr_mlx.py (type annotations, kwargs) - delta_net_aefg_hr_mlx.py (type annotations, kwargs) - delta_net_aeoc_mlx.py (function definitions, parentheses) - delta_net_cagf_br_mlx.py (type annotations, kwargs) - delta_net_cagf_mf_mlx.py (type annotations, kwargs) Also includes automation scripts for systematic fixing of remaining 99 files: - fix_all_architectures.py (comprehensive fix patterns) - batch_fix_architectures.py (targeted batch processing) Progress: 7/106 architectures fixed (6.6% -> target 100%) šŸ¤– Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Daniel Nakov --- batch_fix_architectures.py | 171 +++++++++++++ fix_all_architectures.py | 264 +++++++++++++++++++++ mlx_architectures/delta_net_abrgf_mlx.py | 32 ++- mlx_architectures/delta_net_acfg_mlx.py | 8 +- mlx_architectures/delta_net_adgr_mlx.py | 5 +- mlx_architectures/delta_net_aefg_hr_mlx.py | 5 +- mlx_architectures/delta_net_aeoc_mlx.py | 4 +- mlx_architectures/delta_net_cagf_br_mlx.py | 5 +- mlx_architectures/delta_net_cagf_mf_mlx.py | 5 +- 9 files changed, 463 insertions(+), 36 deletions(-) create mode 100644 batch_fix_architectures.py create mode 100644 fix_all_architectures.py diff --git a/batch_fix_architectures.py b/batch_fix_architectures.py new file mode 100644 index 0000000..71b886f --- /dev/null +++ b/batch_fix_architectures.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +""" +Batch Architecture Syntax Fixer +=============================== + +Applies targeted fixes for the 5 most common syntax error patterns +identified across all 106 MLX architecture files. +""" + +import os +import re +from pathlib import Path +from typing import Dict, List + +def fix_architecture_file(filepath: Path) -> Dict: + """Apply all common syntax fixes to a single architecture file""" + with open(filepath, 'r') as f: + original_content = f.read() + + content = original_content + fixes_applied = [] + + # Fix 1: Type annotation syntax errors + # Pattern: "tensor:, mx.array" -> "tensor: mx.array" + pattern1 = r'(\w+):\s*,\s*(mx\.\w+)' + if re.search(pattern1, content): + content = re.sub(pattern1, r'\1: \2', content) + fixes_applied.append("Fixed type annotation syntax") + + # Fix 2: kwargs.get missing commas + # Pattern: "kwargs.get('h'\nkwargs.get('d', 1))" -> "kwargs.get('h', kwargs.get('d', 1))" + pattern2 = r"kwargs\.get\s*\(\s*['\"]([^'\"]+)['\"]\s*\n\s*kwargs\.get\s*\(\s*['\"]([^'\"]+)['\"],\s*([^)]+)\)\)" + if re.search(pattern2, content): + content = re.sub(pattern2, r"kwargs.get('\1', kwargs.get('\2', \3))", content) + fixes_applied.append("Fixed kwargs.get missing commas") + + # Fix 3: Function parameters missing commas + # Fix __init__ and __call__ method parameters spread across lines + # Pattern: parameters on separate lines without commas + init_pattern = r'def __init__\(self,([^)]*?)(\n\s*\w+:\s*[^,\n)]+)(\n\s*\w+:\s*[^,\n)]+)*(\n\s*\w+:\s*[^,\n)]+)*\):' + matches = list(re.finditer(init_pattern, content, re.MULTILINE | re.DOTALL)) + for match in reversed(matches): # Process in reverse to maintain positions + full_match = match.group(0) + # Add commas to parameters that don't have them + fixed_params = re.sub(r'(\w+:\s*[^,\n)]+)(\n\s*)(\w+:)', r'\1,\2\3', full_match) + if fixed_params != full_match: + content = content[:match.start()] + fixed_params + content[match.end():] + fixes_applied.append("Fixed function parameter commas") + + # Fix 4: Conv1d and other function calls missing commas + # Pattern: "nn.Conv1d(a, b, c\npadding=d\nbias=e)" -> "nn.Conv1d(a, b, c, padding=d, bias=e)" + conv_pattern = r'nn\.Conv1d\s*\([^)]*?(\n\s*\w+\s*=\s*[^,\n)]+)(\n\s*\w+\s*=\s*[^,\n)]+)*\)' + matches = list(re.finditer(conv_pattern, content, re.MULTILINE | re.DOTALL)) + for match in reversed(matches): + full_match = match.group(0) + # Add commas before parameters on new lines + fixed_call = re.sub(r'([^,\s])(\n\s*)(\w+\s*=)', r'\1,\2\3', full_match) + if fixed_call != full_match: + content = content[:match.start()] + fixed_call + content[match.end():] + fixes_applied.append("Fixed function call commas") + + # Fix 5: Unterminated string literals in assert statements + # Pattern: 'assert condition "message\n more text"' -> 'assert condition, "message more text"' + assert_pattern = r'assert\s+([^"\']+?)\s+"([^"]*?)"\s*\n\s*([^"]*?)"' + matches = list(re.finditer(assert_pattern, content, re.MULTILINE | re.DOTALL)) + for match in reversed(matches): + condition = match.group(1).strip() + message_part1 = match.group(2) + message_part2 = match.group(3) + complete_message = f"{message_part1} {message_part2}".strip() + fixed_assert = f'assert {condition}, "{complete_message}"' + content = content[:match.start()] + fixed_assert + content[match.end():] + fixes_applied.append("Fixed unterminated string in assert") + + # Fix 6: Simpler assert pattern + # Pattern: 'assert condition "message' -> 'assert condition, "message"' + simple_assert_pattern = r'assert\s+([^"\']+?)\s+"([^"]*?)$' + matches = list(re.finditer(simple_assert_pattern, content, re.MULTILINE)) + for match in reversed(matches): + condition = match.group(1).strip() + message = match.group(2) + fixed_assert = f'assert {condition}, "{message}"' + content = content[:match.start()] + fixed_assert + content[match.end():] + fixes_applied.append("Fixed assert statement syntax") + + # Fix 7: Standalone None statements + # Pattern: "return something\nNone\nreturn other" -> "return something, None" + none_pattern = r'(\s+return\s+[^,\n]+)\s*\n\s*None\s*#[^\n]*\n\s*return\s+' + if re.search(none_pattern, content): + content = re.sub(none_pattern, r'\1, None # Simplified - no cache state\n return ', content) + fixes_applied.append("Fixed standalone None statements") + + # Fix 8: Missing commas in function calls with parameters spread across lines + # Pattern: "function(param1\nparam2)" -> "function(param1, param2)" + func_call_pattern = r'(\w+)\s*\(\s*([^,\n)]+)\s*\n\s*([^,\n)]+)\s*\)' + matches = list(re.finditer(func_call_pattern, content, re.MULTILINE)) + for match in reversed(matches): + func_name = match.group(1) + param1 = match.group(2).strip() + param2 = match.group(3).strip() + fixed_call = f'{func_name}({param1}, {param2})' + content = content[:match.start()] + fixed_call + content[match.end():] + fixes_applied.append("Fixed function call missing commas") + + # Save the file if changes were made + if content != original_content: + with open(filepath, 'w') as f: + f.write(content) + return { + 'success': True, + 'fixes_applied': fixes_applied, + 'fixes_count': len(fixes_applied) + } + else: + return { + 'success': False, + 'fixes_applied': [], + 'fixes_count': 0 + } + +def main(): + """Apply fixes to all architecture files""" + mlx_dir = Path("mlx_architectures") + + if not mlx_dir.exists(): + print(f"āŒ Directory {mlx_dir} not found!") + return False + + arch_files = list(mlx_dir.glob("*_mlx.py")) + if not arch_files: + print(f"āŒ No MLX architecture files found in {mlx_dir}") + return False + + print(f"šŸ”§ Applying batch fixes to {len(arch_files)} architecture files...") + + fixed_count = 0 + total_fixes = 0 + results = {} + + for i, filepath in enumerate(arch_files, 1): + arch_name = filepath.stem.replace('_mlx', '') + print(f"[{i:3d}/{len(arch_files)}] Processing {arch_name}...") + + try: + result = fix_architecture_file(filepath) + results[arch_name] = result + + if result['success']: + fixed_count += 1 + total_fixes += result['fixes_count'] + print(f" āœ… Applied {result['fixes_count']} fixes") + for fix in result['fixes_applied']: + print(f" - {fix}") + else: + print(f" ā„¹ļø No fixes needed") + + except Exception as e: + print(f" āŒ Error: {e}") + results[arch_name] = {'success': False, 'error': str(e)} + + print(f"\nšŸ“Š Batch Fix Summary:") + print(f"Files processed: {len(arch_files)}") + print(f"Files modified: {fixed_count}") + print(f"Total fixes applied: {total_fixes}") + print(f"Success rate: {fixed_count/len(arch_files)*100:.1f}%") + + return fixed_count > 0 + +if __name__ == "__main__": + success = main() + exit(0 if success else 1) \ No newline at end of file diff --git a/fix_all_architectures.py b/fix_all_architectures.py new file mode 100644 index 0000000..8d31429 --- /dev/null +++ b/fix_all_architectures.py @@ -0,0 +1,264 @@ +#!/usr/bin/env python3 +""" +Comprehensive Architecture Syntax Fixer +======================================= + +Automatically fixes all common syntax errors across all 106 MLX architecture files. +Based on analysis of current error patterns, this script applies systematic fixes. +""" + +import os +import re +import json +from pathlib import Path +from typing import List, Dict, Tuple + +class ArchitectureSyntaxFixer: + """Fixes common syntax errors in MLX architecture files""" + + def __init__(self, mlx_dir: str = "mlx_architectures"): + self.mlx_dir = Path(mlx_dir) + self.fixes_applied = {} + + def fix_file(self, filepath: Path) -> Dict[str, any]: + """Fix all syntax issues in a single architecture file""" + with open(filepath, 'r') as f: + original_code = f.read() + + code = original_code + fixes_applied = [] + + # Fix 1: Type annotation syntax errors + # Pattern: tensor:, mx.array -> tensor: mx.array + pattern1 = r'(\w+):\s*,\s*(mx\.\w+)' + if re.search(pattern1, code): + code = re.sub(pattern1, r'\1: \2', code) + fixes_applied.append("Fixed type annotation syntax") + + # Fix 2: Missing commas in kwargs.get calls + # Pattern: kwargs.get('h' kwargs.get -> kwargs.get('h', kwargs.get + pattern2 = r"kwargs\.get\s*\(\s*['\"]([^'\"]+)['\"][\s\n]+kwargs\.get" + if re.search(pattern2, code): + code = re.sub(pattern2, r"kwargs.get('\1', kwargs.get", code) + fixes_applied.append("Fixed kwargs.get missing commas") + + # Fix 3: More complex kwargs.get patterns + # Pattern: kwargs.get('h'\nkwargs.get('d', 1)) + pattern3 = r"kwargs\.get\s*\(\s*['\"]([^'\"]+)['\"][\s\n]+kwargs\.get\s*\(\s*['\"]([^'\"]+)['\"],\s*([^)]+)\)\)" + if re.search(pattern3, code): + code = re.sub(pattern3, r"kwargs.get('\1', kwargs.get('\2', \3))", code) + fixes_applied.append("Fixed complex kwargs.get patterns") + + # Fix 4: Unterminated string literals in assert statements + # Pattern: assert condition "message -> assert condition, "message" + pattern4 = r'assert\s+([^"\']+)\s+"([^"]*)"([^"]*$)' + matches = re.finditer(pattern4, code, re.MULTILINE) + for match in matches: + condition = match.group(1).strip() + message_start = match.group(2) + remainder = match.group(3) + + # Find the intended end of the string + if '\n' in remainder and not remainder.strip().endswith('"'): + # Look for the next line that might complete the string + lines = code.split('\n') + for i, line in enumerate(lines): + if match.group(0).split('\n')[0] in line: + # Check next few lines for string completion + for j in range(i+1, min(i+3, len(lines))): + if lines[j].strip().endswith('"') or '"' in lines[j]: + # Reconstruct the proper assert + message_parts = [message_start] + for k in range(i+1, j+1): + message_parts.append(lines[k].strip().rstrip('"')) + complete_message = ' '.join(message_parts).strip() + + new_assert = f'assert {condition}, "{complete_message}"' + code = code.replace(match.group(0), new_assert) + fixes_applied.append("Fixed unterminated string in assert") + break + break + + # Fix 5: Missing commas in function parameters + # Pattern: def func(self param1 param2) -> def func(self, param1, param2) + pattern5 = r'def\s+(\w+)\s*\(\s*self\s+([^,)]+)' + if re.search(pattern5, code): + code = re.sub(pattern5, r'def \1(self, \2', code) + fixes_applied.append("Fixed function parameter missing commas") + + # Fix 6: Missing commas in function calls + # Pattern: func(arg1 arg2, arg3) -> func(arg1, arg2, arg3) + # This is tricky - let's be conservative and fix specific known patterns + + # Fix nn.Linear calls + pattern6a = r'nn\.Linear\s*\(\s*([^,\s]+)\s+([^,\s]+)\s*([,)])' + if re.search(pattern6a, code): + code = re.sub(pattern6a, r'nn.Linear(\1, \2\3', code) + fixes_applied.append("Fixed nn.Linear calls") + + # Fix F.elu calls + pattern6b = r'F\.elu\s*\(\s*([^,\s]+)\s+([^,\s]+)\s*([,)])' + if re.search(pattern6b, code): + code = re.sub(pattern6b, r'F.elu(\1, \2\3', code) + fixes_applied.append("Fixed F.elu calls") + + # Fix 7: Unmatched parentheses - detect and attempt to fix + # Count parentheses to find imbalances + paren_balance = 0 + bracket_balance = 0 + + for char in code: + if char == '(': + paren_balance += 1 + elif char == ')': + paren_balance -= 1 + elif char == '[': + bracket_balance += 1 + elif char == ']': + bracket_balance -= 1 + + # If we have imbalances, try to fix common patterns + if paren_balance != 0 or bracket_balance != 0: + # Look for common patterns of missing closing parentheses + # Pattern: function(args without closing ) + lines = code.split('\n') + fixed_lines = [] + + for line in lines: + # Count parens in this line + line_paren_balance = line.count('(') - line.count(')') + line_bracket_balance = line.count('[') - line.count(']') + + # If line has unmatched opening parens/brackets, try to fix + if line_paren_balance > 0 and not line.rstrip().endswith(')'): + if ('=' in line and '(' in line) or ('return' in line and '(' in line): + # Likely a function call or assignment that needs closing paren + line = line.rstrip() + ')' + fixes_applied.append("Fixed unmatched parentheses") + + if line_bracket_balance > 0 and not line.rstrip().endswith(']'): + if '[' in line and ('=' in line or 'return' in line): + # Likely an array access that needs closing bracket + line = line.rstrip() + ']' + fixes_applied.append("Fixed unmatched brackets") + + fixed_lines.append(line) + + code = '\n'.join(fixed_lines) + + # Fix 8: Invalid syntax patterns + # Fix kwargs.get calls with missing quotes or commas + pattern8 = r'kwargs\.get\s*\(\s*([^,\s\'"]+)\s+([^,)]+)\)' + if re.search(pattern8, code): + code = re.sub(pattern8, r"kwargs.get('\1', \2)", code) + fixes_applied.append("Fixed kwargs.get syntax") + + # Fix 9: Missing commas in assert statements + pattern9 = r'assert\s+([^,]+)\s+"([^"]+)"' + if re.search(pattern9, code): + code = re.sub(pattern9, r'assert \1, "\2"', code) + fixes_applied.append("Fixed assert statement syntax") + + # Fix 10: Function parameter issues + # Fix missing commas between parameters like "self param1" + pattern10 = r'\(\s*self\s+(\w+[^,)]*)\)' + if re.search(pattern10, code): + code = re.sub(pattern10, r'(self, \1)', code) + fixes_applied.append("Fixed function parameter syntax") + + # Save the fixed code if changes were made + if code != original_code: + with open(filepath, 'w') as f: + f.write(code) + + return { + 'fixed': True, + 'fixes_applied': fixes_applied, + 'fixes_count': len(fixes_applied) + } + else: + return { + 'fixed': False, + 'fixes_applied': [], + 'fixes_count': 0 + } + + def fix_all_architectures(self) -> Dict[str, any]: + """Fix all architecture files in the MLX directory""" + if not self.mlx_dir.exists(): + return {'error': f"Directory {self.mlx_dir} does not exist"} + + architecture_files = list(self.mlx_dir.glob("*_mlx.py")) + + if not architecture_files: + return {'error': f"No MLX architecture files found in {self.mlx_dir}"} + + print(f"šŸ”§ Found {len(architecture_files)} architecture files to fix") + + results = { + 'total_files': len(architecture_files), + 'files_fixed': 0, + 'files_unchanged': 0, + 'total_fixes': 0, + 'per_file_results': {} + } + + for i, filepath in enumerate(architecture_files, 1): + arch_name = filepath.stem.replace('_mlx', '') + print(f"[{i:3d}/{len(architecture_files)}] Fixing {arch_name}...") + + try: + file_result = self.fix_file(filepath) + results['per_file_results'][arch_name] = file_result + + if file_result['fixed']: + results['files_fixed'] += 1 + results['total_fixes'] += file_result['fixes_count'] + print(f" āœ… Applied {file_result['fixes_count']} fixes:") + for fix in file_result['fixes_applied']: + print(f" - {fix}") + else: + results['files_unchanged'] += 1 + print(f" ā„¹ļø No fixes needed") + + except Exception as e: + print(f" āŒ Error fixing {arch_name}: {e}") + results['per_file_results'][arch_name] = { + 'fixed': False, + 'error': str(e) + } + + # Save results + with open('architecture_fix_results.json', 'w') as f: + json.dump(results, f, indent=2) + + print(f"\nšŸ“Š Fix Summary:") + print(f"Total files: {results['total_files']}") + print(f"Files fixed: {results['files_fixed']}") + print(f"Files unchanged: {results['files_unchanged']}") + print(f"Total fixes applied: {results['total_fixes']}") + print(f"Results saved to: architecture_fix_results.json") + + return results + +def main(): + """Run the architecture syntax fixer""" + print("šŸ”§ Architecture Syntax Fixer") + print("=" * 50) + + fixer = ArchitectureSyntaxFixer() + results = fixer.fix_all_architectures() + + if 'error' in results: + print(f"āŒ Error: {results['error']}") + return False + + # Check if we made significant progress + success_rate = results['files_fixed'] / results['total_files'] * 100 + print(f"\nšŸŽÆ Success Rate: {success_rate:.1f}% of files modified") + + return success_rate > 0 + +if __name__ == "__main__": + success = main() + exit(0 if success else 1) \ No newline at end of file diff --git a/mlx_architectures/delta_net_abrgf_mlx.py b/mlx_architectures/delta_net_abrgf_mlx.py index 75b7533..c53c5ac 100644 --- a/mlx_architectures/delta_net_abrgf_mlx.py +++ b/mlx_architectures/delta_net_abrgf_mlx.py @@ -10,11 +10,10 @@ import mlx.nn as nn from typing import Tuple, Optional, List, Dict -def _rearrange(tensor:, mx.array, pattern: str, **kwargs) -> mx.array: +def _rearrange(tensor: mx.array, pattern: str, **kwargs) -> mx.array: """Simple einops rearrange replacement for common patterns""" if "b l (h d) -> b l h d" in pattern: - h = kwargs.get('h' - kwargs.get('d', 1)) + h = kwargs.get('h', kwargs.get('d', 1)) b, l, hd = tensor.shape d = hd // h return tensor.reshape(b, l, h, d) @@ -53,30 +52,30 @@ def _get_unpad_data(attention_mask): max_len = attention_mask.shape[-1] return indices, cu_seqlens, max_len -def _index_first_axis(tensor:, mx.array, indices: mx.array) -> mx.array: +def _index_first_axis(tensor: mx.array, indices: mx.array) -> mx.array: """Index first axis""" return tensor[indices] -def _pad_input(tensor:, mx.array, indices: mx.array, batch_size: int, seq_len: int) -> mx.array: +def _pad_input(tensor: mx.array, indices: mx.array, batch_size: int, seq_len: int) -> mx.array: """Pad input back to original shape""" # Simplified version return tensor.reshape(batch_size, seq_len, -1) class _ShortConvolution(nn.Module): """MLX replacement for FLA ShortConvolution""" - def __init__(self, hidden_size: int - kernel_size: int = 4 - activation: str = None - bias: bool = False): + def __init__(self, hidden_size: int, + kernel_size: int = 4, + activation: str = None, + bias: bool = False): super().__init__() - self.conv = nn.Conv1d(hidden_size, hidden_size, kernel_size - padding=kernel_size-1 - bias=bias) + self.conv = nn.Conv1d(hidden_size, hidden_size, kernel_size, + padding=kernel_size-1, + bias=bias) self.activation = activation - def __call__(self, x, cache=None - output_final_state=False - cu_seqlens=None): + def __call__(self, x, cache=None, + output_final_state=False, + cu_seqlens=None): # x: (B, L, D) x_conv = x.transpose(0, 2, 1) # (B, D, L) out = self.conv(x_conv) @@ -89,8 +88,7 @@ def __call__(self, x, cache=None out = nn.gelu(out) if output_final_state: - return out - None # Simplified - no cache state + return out, None # Simplified - no cache state return out diff --git a/mlx_architectures/delta_net_acfg_mlx.py b/mlx_architectures/delta_net_acfg_mlx.py index f79d6dc..624afd3 100644 --- a/mlx_architectures/delta_net_acfg_mlx.py +++ b/mlx_architectures/delta_net_acfg_mlx.py @@ -10,11 +10,10 @@ import mlx.nn as nn from typing import Tuple, Optional, List, Dict -def _rearrange(tensor:, mx.array, pattern: str, **kwargs) -> mx.array: +def _rearrange(tensor: mx.array, pattern: str, **kwargs) -> mx.array: """Simple einops rearrange replacement for common patterns""" if "b l (h d) -> b l h d" in pattern: - h = kwargs.get('h' - kwargs.get('d', 1)) + h = kwargs.get('h', kwargs.get('d', 1)) b, l, hd = tensor.shape d = hd // h return tensor.reshape(b, l, h, d) @@ -373,8 +372,7 @@ def forward( output_attentions: Optional[bool] = False # kept for API compatibility **kwargs) -> Tuple[mx.array, None, Optional["Cache"]]: if attention_mask is not None: - assert attention_mask.ndim == 2 "attention_mask must be (batch - seq_len)" + assert attention_mask.ndim == 2, "attention_mask must be (batch, seq_len)" batch_size, seq_len_full, _ = hidden_states.shape # ---------------- cache retrieval ---------------- diff --git a/mlx_architectures/delta_net_adgr_mlx.py b/mlx_architectures/delta_net_adgr_mlx.py index 792770a..a2f6bd5 100644 --- a/mlx_architectures/delta_net_adgr_mlx.py +++ b/mlx_architectures/delta_net_adgr_mlx.py @@ -10,11 +10,10 @@ import mlx.nn as nn from typing import Tuple, Optional, List, Dict -def _rearrange(tensor:, mx.array, pattern: str, **kwargs) -> mx.array: +def _rearrange(tensor: mx.array, pattern: str, **kwargs) -> mx.array: """Simple einops rearrange replacement for common patterns""" if "b l (h d) -> b l h d" in pattern: - h = kwargs.get('h' - kwargs.get('d', 1)) + h = kwargs.get('h', kwargs.get('d', 1)) b, l, hd = tensor.shape d = hd // h return tensor.reshape(b, l, h, d) diff --git a/mlx_architectures/delta_net_aefg_hr_mlx.py b/mlx_architectures/delta_net_aefg_hr_mlx.py index 98c0de2..78f79da 100644 --- a/mlx_architectures/delta_net_aefg_hr_mlx.py +++ b/mlx_architectures/delta_net_aefg_hr_mlx.py @@ -10,11 +10,10 @@ import mlx.nn as nn from typing import Tuple, Optional, List, Dict -def _rearrange(tensor:, mx.array, pattern: str, **kwargs) -> mx.array: +def _rearrange(tensor: mx.array, pattern: str, **kwargs) -> mx.array: """Simple einops rearrange replacement for common patterns""" if "b l (h d) -> b l h d" in pattern: - h = kwargs.get('h' - kwargs.get('d', 1)) + h = kwargs.get('h', kwargs.get('d', 1)) b, l, hd = tensor.shape d = hd // h return tensor.reshape(b, l, h, d) diff --git a/mlx_architectures/delta_net_aeoc_mlx.py b/mlx_architectures/delta_net_aeoc_mlx.py index df80448..e921145 100644 --- a/mlx_architectures/delta_net_aeoc_mlx.py +++ b/mlx_architectures/delta_net_aeoc_mlx.py @@ -151,7 +151,7 @@ def sum_norm(x): # Causal chunked delta memory kernel # --------------------------------------------- @mx.compile -def delta_rule_chunkwise +def delta_rule_chunkwise( q: mx.array, k: mx.array, v: mx.array, @@ -159,7 +159,7 @@ def delta_rule_chunkwise *, chunk_size: int = 32): b, h, L, d_k = q.shape - d_v = v.shape[-1] + d_v = v.shape[-1] pad_len = (chunk_size - L % chunk_size) % chunk_size if pad_len: q = mx.pad(q diff --git a/mlx_architectures/delta_net_cagf_br_mlx.py b/mlx_architectures/delta_net_cagf_br_mlx.py index 847f041..bc16bb1 100644 --- a/mlx_architectures/delta_net_cagf_br_mlx.py +++ b/mlx_architectures/delta_net_cagf_br_mlx.py @@ -10,11 +10,10 @@ import mlx.nn as nn from typing import Tuple, Optional, List, Dict -def _rearrange(tensor:, mx.array, pattern: str, **kwargs) -> mx.array: +def _rearrange(tensor: mx.array, pattern: str, **kwargs) -> mx.array: """Simple einops rearrange replacement for common patterns""" if "b l (h d) -> b l h d" in pattern: - h = kwargs.get('h' - kwargs.get('d', 1)) + h = kwargs.get('h', kwargs.get('d', 1)) b, l, hd = tensor.shape d = hd // h return tensor.reshape(b, l, h, d) diff --git a/mlx_architectures/delta_net_cagf_mf_mlx.py b/mlx_architectures/delta_net_cagf_mf_mlx.py index ea59662..28fd4bb 100644 --- a/mlx_architectures/delta_net_cagf_mf_mlx.py +++ b/mlx_architectures/delta_net_cagf_mf_mlx.py @@ -10,11 +10,10 @@ import mlx.nn as nn from typing import Tuple, Optional, List, Dict -def _rearrange(tensor:, mx.array, pattern: str, **kwargs) -> mx.array: +def _rearrange(tensor: mx.array, pattern: str, **kwargs) -> mx.array: """Simple einops rearrange replacement for common patterns""" if "b l (h d) -> b l h d" in pattern: - h = kwargs.get('h' - kwargs.get('d', 1)) + h = kwargs.get('h', kwargs.get('d', 1)) b, l, hd = tensor.shape d = hd // h return tensor.reshape(b, l, h, d)