Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,51 @@
# Python
*.pyc
__pycache__/
*.pyo
*.egg-info/
dist/
build/
*.egg

# Virtual environments
.venv/
venv/
env/

# MetaClaw runtime data
memory_data/skills/
records/
system_prompt_cache.json
evolution_history.jsonl
scheduler_state.json
*.pid

# RL training artifacts
wandb/
checkpoints/
*.ckpt

# MLX model cache (large downloads)
mlx_models/

# Smoke test temp files
tests/.smoke_records/

# OS junk
.DS_Store
Thumbs.db

# IDE
.vscode/
.idea/
*.swp
*.swo
*~

# Secrets
.env
config.yaml

# MLX training output
mlx_metaclaw_output/
80 changes: 80 additions & 0 deletions INTEGRATION_NOTES.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# MLX Backend Integration Notes

## Files to add
- `metaclaw/mlx_backend/__init__.py`
- `metaclaw/mlx_backend/data_types.py`
- `metaclaw/mlx_backend/params.py`
- `metaclaw/mlx_backend/lora.py`
- `metaclaw/mlx_backend/service_client.py`
- `tests/test_mlx_backend.py`

## Files to replace
- `metaclaw/sdk_backend.py` (full replacement with MLX support)

## Files to patch (small edits)

### metaclaw/setup_wizard.py
# In metaclaw/setup_wizard.py, update line ~156:
#
# BEFORE:
# ["auto", "tinker", "mint"],
#
# AFTER:
# ["auto", "tinker", "mint", "mlx"],
#
# This adds "mlx" to the interactive backend selection menu.


### metaclaw/config.py
# In metaclaw/config.py, add these fields to MetaClawConfig:
#
# # MLX backend settings
# mlx_model_path: str = "" # local path or HF repo (e.g. mlx-community/Qwen2.5-7B-4bit)
# mlx_output_dir: str = "./mlx_metaclaw_output"
#
# Update training_backend_label() around line 168:
#
# BEFORE:
# def training_backend_label(self) -> str:
# return "MinT" if self.resolved_backend_key() == "mint" else "Tinker"
#
# AFTER:
# def training_backend_label(self) -> str:
# key = self.resolved_backend_key()
# if key == "mlx":
# return "MLX"
# return "MinT" if key == "mint" else "Tinker"
#
# Update training_backend_banner() around line 171:
#
# BEFORE:
# def training_backend_banner(self) -> str:
# return f"{self.training_backend_label()} cloud RL"
#
# AFTER:
# def training_backend_banner(self) -> str:
# label = self.training_backend_label()
# suffix = "local RL" if self.resolved_backend_key() == "mlx" else "cloud RL"
# return f"{label} {suffix}"


## Optional: pyproject.toml extras

```toml
[project.optional-dependencies]
mlx = ["mlx>=0.22.0", "mlx-lm>=0.21.0", "safetensors"]
```

## Usage

```bash
# Install with MLX extras
pip install -e ".[mlx]"

# Configure
metaclaw setup # select backend → mlx

# Or via env
export METACLAW_RL_BACKEND=mlx
metaclaw start
```
29 changes: 22 additions & 7 deletions metaclaw/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,10 +722,21 @@ def _prompt_len(msgs):
raw_system = _flatten_message_content(m.get("content"))
break
if raw_system:
cached_system = await asyncio.to_thread(
run_llm,
[{"role": "user", "content": raw_system}],
)
# System prompt compression requires an external LLM API.
# When running with a local-only backend (e.g. MLX) and no
# llm_api_key configured, skip compression and use raw prompt.
if self.config.llm_api_key:
try:
cached_system = await asyncio.to_thread(
run_llm,
[{"role": "user", "content": raw_system}],
)
except Exception as exc:
logger.warning(
"[OpenClaw] system prompt compression failed, "
"using raw prompt: %s", exc,
)
cached_system = None
cached_system = (cached_system or raw_system).strip()
self._write_cached_system_prompt(cached_system)

Expand Down Expand Up @@ -953,13 +964,17 @@ async def _forward_to_backend(self, body: dict[str, Any]) -> dict[str, Any]:
sampling_params = self._sdk.SamplingParams(**sp_kwargs)

# Call active backend
response = await self._sampling_client.sample_async(
# include_prompt_logprobs / topk_prompt_logprobs are Tinker-specific;
# MLX (and potentially other local backends) don't support them.
sample_kwargs: dict[str, Any] = dict(
prompt=model_input,
num_samples=1,
sampling_params=sampling_params,
include_prompt_logprobs=False,
topk_prompt_logprobs=0,
)
if backend_key != "mlx":
sample_kwargs["include_prompt_logprobs"] = False
sample_kwargs["topk_prompt_logprobs"] = 0
response = await self._sampling_client.sample_async(**sample_kwargs)

# Decode response tokens → text
seq = response.sequences[0]
Expand Down
32 changes: 32 additions & 0 deletions metaclaw/mlx_backend/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
MLX-native LoRA training backend for MetaClaw.

Provides a local, zero-cloud alternative to the Tinker and MinT backends
using Apple MLX on Apple Silicon. No API key or network required.
"""

from .data_types import (
Datum,
EncodedTextChunk,
ModelInput,
SampleResponse,
SampleSequence,
TensorData,
)
from .params import AdamParams, SamplingParams
from .service_client import SamplingClient, SaveStateResult, ServiceClient, TrainingClient

__all__ = [
"AdamParams",
"Datum",
"EncodedTextChunk",
"ModelInput",
"SampleResponse",
"SampleSequence",
"SamplingClient",
"SamplingParams",
"SaveStateResult",
"ServiceClient",
"TensorData",
"TrainingClient",
]
123 changes: 123 additions & 0 deletions metaclaw/mlx_backend/data_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
Data types that mirror the Tinker SDK surface used by data_formatter.py
and api_server.py.

Training path: TensorData, ModelInput.from_ints(), Datum
Inference path: EncodedTextChunk, ModelInput(chunks=...), SampleSequence, SampleResponse
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Optional

import mlx.core as mx


# ------------------------------------------------------------------ #
# Training types (used by data_formatter.py) #
# ------------------------------------------------------------------ #

@dataclass
class TensorData:
    """Thin wrapper around an MLX array, convertible from PyTorch tensors."""

    array: mx.array

    @classmethod
    def from_torch(cls, tensor) -> "TensorData":
        """Build a TensorData from a PyTorch tensor.

        Detaches from the autograd graph and copies to host memory, then
        hands the buffer to MLX via NumPy interop (``.numpy()``).
        """
        # Fix: the previous version imported numpy here but never used it.
        return cls(array=mx.array(tensor.detach().cpu().numpy()))

    def to_mlx(self) -> mx.array:
        """Return the wrapped MLX array."""
        return self.array

    def __len__(self) -> int:
        # Length of the leading axis; assumes a non-scalar array.
        return self.array.shape[0]


@dataclass
class Datum:
    """One training example in the tinker-cookbook RL convention.

    ``model_input`` carries the input token sequence; ``loss_fn_inputs``
    maps tensor names to per-example tensors consumed by the loss function.
    """

    model_input: "ModelInput"
    # NOTE(review): key schema (e.g. targets/weights/advantages) is defined by
    # the consumer (data_formatter.py) — not visible from this module.
    loss_fn_inputs: Dict[str, TensorData] = field(default_factory=dict)


# ------------------------------------------------------------------ #
# Inference types (used by api_server.py forward_to_backend) #
# ------------------------------------------------------------------ #

@dataclass
class EncodedTextChunk:
    """Mirrors tinker.EncodedTextChunk (pre-tokenized prompt chunk).

    api_server.py constructs it as:
        chunk = sdk.EncodedTextChunk(tokens=list(prompt_ids), type="encoded_text")
    """
    tokens: List[int]
    # `type` shadows the builtin, but the name is part of the Tinker SDK
    # surface being mirrored, so it must stay.
    type: str = "encoded_text"


@dataclass
class SampleSequence:
    """One generated sequence returned by SamplingClient.sample_async().

    api_server.py reads:
        seq = response.sequences[0]
        seq.tokens      -> list[int]   generated token ids
        seq.logprobs    -> list[float] per-token log-probabilities
        seq.stop_reason -> str         why generation stopped
    """
    tokens: List[int]
    logprobs: List[float]
    # Defaults to "stop"; producers may override (e.g. on length cutoff).
    stop_reason: str = "stop"


@dataclass
class SampleResponse:
    """Container returned by SamplingClient.sample_async().

    api_server.py reads only ``response.sequences[0]``, but the list shape
    mirrors the Tinker SDK, which supports multiple samples per request.
    """
    sequences: List[SampleSequence]


# ------------------------------------------------------------------ #
# ModelInput (dual-purpose: training + inference) #
# ------------------------------------------------------------------ #

@dataclass
class ModelInput:
    """Token sequence for model consumption (dual-purpose container).

    Training path (data_formatter.py) populates ``tokens``:
        sdk.ModelInput.from_ints(all_tokens[:-1])

    Inference path (api_server.py) populates ``chunks``:
        sdk.ModelInput(chunks=[chunk])
    """
    tokens: Optional[mx.array] = None
    chunks: Optional[List[EncodedTextChunk]] = None

    @classmethod
    def from_ints(cls, token_ids: List[int]) -> "ModelInput":
        """Build the training-path form from a plain list of token ids."""
        return cls(tokens=mx.array(token_ids, dtype=mx.int32))

    def get_token_ids(self) -> List[int]:
        """Return a plain list of ints regardless of construction path."""
        if self.tokens is None:
            # Inference path: hand back the first chunk's token list as-is.
            return self.chunks[0].tokens if self.chunks else []
        return self.tokens.tolist()

    def __len__(self) -> int:
        # Number of tokens; 0 when neither field is populated.
        if self.tokens is None:
            return len(self.chunks[0].tokens) if self.chunks else 0
        return self.tokens.shape[0]
59 changes: 59 additions & 0 deletions metaclaw/mlx_backend/lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""LoRA layer injection and weight I/O using mlx_lm's built-in tuner."""

from __future__ import annotations

import logging
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

logger = logging.getLogger(__name__)


def inject_lora(
    model: nn.Module,
    rank: int = 16,
    alpha: float = 16.0,
    num_layers: int = -1,
) -> nn.Module:
    """Replace the model's linear layers with LoRA adapters in place.

    Delegates the layer surgery to mlx_lm's tuner utilities, then logs how
    many parameters ended up trainable.  ``num_layers`` is forwarded
    verbatim to mlx_lm (``-1`` presumably means "all layers" — confirm
    against the installed mlx_lm version).  Returns the same model object.
    """
    from mlx_lm.tuner.utils import linear_to_lora_layers

    adapter_config = {
        "rank": rank,
        "alpha": alpha,
        "dropout": 0.0,
        "scale": alpha / rank,
    }
    linear_to_lora_layers(model, num_layers=num_layers, config=adapter_config)

    # Parameter accounting for the log line only — no effect on training.
    trainable_count = sum(p.size for _, p in tree_flatten(model.trainable_parameters()))
    total_count = sum(p.size for _, p in tree_flatten(model.parameters()))
    trainable_pct = 100 * trainable_count / total_count if total_count > 0 else 0
    logger.info(
        "[MLX-LoRA] injected adapters (rank=%d alpha=%.1f): "
        "trainable=%d / %d params (%.2f%%)",
        rank, alpha, trainable_count, total_count, trainable_pct,
    )
    return model


def save_lora_weights(model: nn.Module, path: str | Path) -> Path:
    """Persist the model's trainable (LoRA) tensors as safetensors.

    Creates ``path`` as a directory if needed and writes
    ``adapters.safetensors`` inside it.  Returns the path of the file
    written.  Only trainable parameters (the adapters) are saved.
    """
    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)
    adapter_file = out_dir / "adapters.safetensors"

    tensors = dict(tree_flatten(model.trainable_parameters()))
    mx.save_safetensors(str(adapter_file), tensors)
    logger.info("[MLX-LoRA] saved %d tensors -> %s", len(tensors), adapter_file)
    return adapter_file


def load_lora_weights(model: nn.Module, path: str | Path) -> nn.Module:
    """Load LoRA adapter weights into ``model`` in place and return it.

    ``path`` may be the adapter file itself or a directory containing
    ``adapters.safetensors``.  Loading uses ``strict=False`` so only the
    parameters present in the file (the adapters) are updated.
    """
    target = Path(path)
    if target.is_dir():
        target = target / "adapters.safetensors"

    model.load_weights(str(target), strict=False)
    logger.info("[MLX-LoRA] loaded adapters <- %s", target)
    return model
Loading