18 changes: 15 additions & 3 deletions src/opengradient/cli.py
@@ -19,6 +19,7 @@
     DEFAULT_RPC_URL,
 )
 from .types import InferenceMode, LlmInferenceMode, LLM, TEE_LLM
+from .utils import generate_nonce
 
 OG_CONFIG_FILE = Path.home() / ".opengradient_config.json"
 
@@ -280,15 +281,16 @@ def upload_file(obj, file_path: Path, repo_name: str, version: str):
 @click.option(
     "--mode", "inference_mode", type=click.Choice(InferenceModes.keys()), default="VANILLA", help="Inference mode (default: VANILLA)"
 )
-@click.option("--input", "-d", "input_data", type=Dict, help="Input data for inference as a JSON string")
+@click.option("--input", "-d", "input_data", type=Dict, help="Input data for inference as a JSON string, used for TEE inferences")
 @click.option(
     "--input-file",
     "-f",
     type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True, path_type=Path),
     help="JSON file containing input data for inference",
 )
@click.option("--nonce", type=str, required=False, help="A 20 character long hexadecimal nonce")
 @click.pass_context
-def infer(ctx, model_cid: str, inference_mode: str, input_data, input_file: Path):
+def infer(ctx, model_cid: str, inference_mode: str, input_data, input_file: Path, nonce: Optional[str]):
     """
     Run inference on a model.
 
@@ -322,7 +324,14 @@ def infer(ctx, model_cid: str, inference_mode: str, input_data, input_file: Path
             model_input = json.load(file)
 
     click.echo(f'Running {inference_mode} inference for model "{model_cid}"')
-    inference_result = client.infer(model_cid=model_cid, inference_mode=InferenceModes[inference_mode], model_input=model_input)
+    if InferenceModes[inference_mode] is InferenceMode.TEE:
+        if nonce is None:
+            nonce = generate_nonce()
+            click.echo(f'No nonce provided for TEE inference, generating random nonce: {nonce}')
+        else:
+            click.echo(f'Using provided nonce for TEE inference: {nonce}')
+
+    inference_result = client.infer(model_cid=model_cid, inference_mode=InferenceModes[inference_mode], model_input=model_input, tee_nonce=nonce)
 
     click.echo()  # Add a newline for better spacing
     click.secho("✅ Transaction successful", fg="green", bold=True)
@@ -340,6 +349,9 @@ def infer(ctx, model_cid: str, inference_mode: str, input_data, input_file: Path
             inference_result.model_output, indent=2, default=lambda x: x.tolist() if hasattr(x, "tolist") else str(x)
         )
         click.echo(formatted_output)
+        if inference_result.attestation is not None:
+            click.secho("Attestation document:", fg="yellow")
+            click.echo(inference_result.attestation)
     except json.JSONDecodeError as e:
         click.echo(f"Error decoding JSON: {e}", err=True)
         click.echo(f"Error occurred on line {e.lineno}, column {e.colno}", err=True)
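Usage sketch (not part of the diff): one way to exercise the updated `infer` command is through click's test runner. The entry-point name `cli`, the model CID, the `TEE` mode key, and `input.json` are illustrative assumptions, not taken from this change.

```python
# Hypothetical invocation of the updated infer command. The click group name
# `cli`, the model CID, the "TEE" mode key, and input.json are all assumed.
from click.testing import CliRunner

from opengradient.cli import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    [
        "infer", "QmExampleModelCid",
        "--mode", "TEE",
        "--input-file", "input.json",  # must exist: the option uses click.Path(exists=True)
        "--nonce", "ab" * 20,          # 40-char hex nonce; omit it to have one generated
    ],
)
print(result.output)
```

If `--nonce` is omitted, the new branch above generates one with `generate_nonce()` and echoes it before calling `client.infer`.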
10 changes: 8 additions & 2 deletions src/opengradient/client.py
@@ -29,7 +29,7 @@
     FileUploadResult,
 )
 from .defaults import DEFAULT_IMAGE_GEN_HOST, DEFAULT_IMAGE_GEN_PORT, DEFAULT_SCHEDULER_ADDRESS
-from .utils import convert_array_to_model_output, convert_to_model_input, convert_to_model_output
+from .utils import convert_array_to_model_output, convert_to_model_input, convert_to_model_output, get_attestation, generate_nonce
 
 _FIREBASE_CONFIG = {
     "apiKey": "AIzaSyDUVckVtfl-hiteBzPopy1pDD8Uvfncs7w",
@@ -271,6 +271,7 @@ def infer(
         inference_mode: InferenceMode,
         model_input: Dict[str, Union[str, int, float, List, np.ndarray]],
         max_retries: Optional[int] = None,
+        tee_nonce: Optional[str] = None,
     ) -> InferenceResult:
         """
         Perform inference on a model.
@@ -325,7 +326,12 @@ def execute_transaction():
             # TODO: This should return a ModelOutput class object
             model_output = convert_to_model_output(parsed_logs[0]["args"])
 
-            return InferenceResult(tx_hash.hex(), model_output)
+            # TODO: Temporary measure for getting remote attestation
+            attestation = None
+            if InferenceMode(inference_mode_uint8) is InferenceMode.TEE:
+                attestation = get_attestation(tee_nonce or generate_nonce())
+
+            return InferenceResult(tx_hash.hex(), model_output, attestation)
 
         return run_with_retry(execute_transaction, max_retries)
 
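Usage sketch (not part of the diff): calling the extended `infer` directly, assuming an already-constructed `Client` instance; the model CID and input payload are placeholders.

```python
# Minimal sketch of a TEE inference via the Python client. `client`, the model
# CID, and the input payload are placeholders assumed for illustration.
from opengradient.types import InferenceMode
from opengradient.utils import generate_nonce

nonce = generate_nonce()  # optional: infer() falls back to a fresh nonce when one is needed
result = client.infer(
    model_cid="QmExampleModelCid",
    inference_mode=InferenceMode.TEE,
    model_input={"input": [1.0, 2.0, 3.0]},
    tee_nonce=nonce,
)
print(result.transaction_hash)
```

Passing the nonce explicitly lets a caller later compare it against the nonce that attestation documents typically echo back, guarding against replayed responses.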
4 changes: 3 additions & 1 deletion src/opengradient/types.py
@@ -127,10 +127,12 @@ class InferenceResult:
-    This class has two fields
+    This class has three fields:
     transaction_hash (str): Blockchain hash for the transaction
     model_output (Dict[str, np.ndarray]): Output of the ONNX model
+    attestation (Optional[str]): Attestation document for TEE inferences, None for other modes
     """
 
     transaction_hash: str
     model_output: Dict[str, np.ndarray]
+    attestation: Optional[str] = None
 
 
 @dataclass
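Usage sketch (not part of the diff): consuming the extended result, where `result` is assumed to come from a prior `client.infer(...)` call.

```python
# attestation defaults to None, so existing two-argument constructions of
# InferenceResult keep working; only TEE inferences populate the field.
print("tx:", result.transaction_hash)
for name, value in result.model_output.items():
    print(name, getattr(value, "shape", value))
if result.attestation is not None:
    print("attestation document:", result.attestation)
```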
30 changes: 30 additions & 0 deletions src/opengradient/utils.py
@@ -10,6 +10,10 @@
 
 from .types import ModelOutput
 
+import requests
+import os
+import binascii
+
 
 def convert_to_fixed_point(number: float) -> Tuple[int, int]:
     """
@@ -220,3 +224,29 @@ def convert_array_to_model_output(array_data: List) -> ModelOutput:
         jsons=json_data,
         is_simulation_result=array_data[3],
     )
+
+
+def get_attestation(nonce: str) -> str:
+    """Fetch the attestation document for a TEE inference using the given nonce."""
+    attestation_url = "https://18.224.61.175/enclave/attestation"
+    params = {"nonce": nonce}
+
+    try:
+        # The enclave endpoint is addressed by raw IP, so TLS certificate
+        # verification is disabled here for the time being.
+        attestation_response = requests.get(attestation_url, params=params, verify=False)
+    except requests.exceptions.RequestException as e:
+        raise RuntimeError(f"Error occurred while requesting attestation document: {e}") from e
+
+    if attestation_response.status_code != 200:
+        raise RuntimeError(
+            f"Attestation request failed, error status {attestation_response.status_code}: {attestation_response.text}"
+        )
+
+    return attestation_response.text
+
+
+def generate_nonce() -> str:
+    """Generate a random nonce for TEE inferences: 20 bytes as a 40-character hex string."""
+    random_bytes = os.urandom(20)
+    return binascii.hexlify(random_bytes).decode("ascii")
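For reference, a quick sanity check of the nonce helper (not part of the diff): `os.urandom(20)` plus `binascii.hexlify` is functionally equivalent to the standard library's `secrets.token_hex(20)`, and both draw from a cryptographically secure source.

```python
# 20 random bytes hex-encode to a 40-character lowercase hex string,
# matching the CLI's --nonce help text.
import secrets

from opengradient.utils import generate_nonce

nonce = generate_nonce()
assert len(nonce) == 40
assert all(c in "0123456789abcdef" for c in nonce)

# Equivalent using the standard library:
assert len(secrets.token_hex(20)) == 40
```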