diff --git a/src/opengradient/cli.py b/src/opengradient/cli.py index 6c291d5..fbf076d 100644 --- a/src/opengradient/cli.py +++ b/src/opengradient/cli.py @@ -19,6 +19,7 @@ DEFAULT_RPC_URL, ) from .types import InferenceMode, LlmInferenceMode, LLM, TEE_LLM +from .utils import generate_nonce OG_CONFIG_FILE = Path.home() / ".opengradient_config.json" @@ -280,15 +281,16 @@ def upload_file(obj, file_path: Path, repo_name: str, version: str): @click.option( "--mode", "inference_mode", type=click.Choice(InferenceModes.keys()), default="VANILLA", help="Inference mode (default: VANILLA)" ) -@click.option("--input", "-d", "input_data", type=Dict, help="Input data for inference as a JSON string") +@click.option("--input", "-d", "input_data", type=Dict, help="Input data for inference as a JSON string, used for TEE inferences") @click.option( "--input-file", "-f", type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True, path_type=Path), help="JSON file containing input data for inference", ) +@click.option("--nonce", type=str, required=False, help="A 20 character long hexadecimal nonce") @click.pass_context -def infer(ctx, model_cid: str, inference_mode: str, input_data, input_file: Path): +def infer(ctx, model_cid: str, inference_mode: str, input_data, input_file: Path, nonce: Optional[str]): """ Run inference on a model. @@ -322,7 +324,14 @@ def infer(ctx, model_cid: str, inference_mode: str, input_data, input_file: Path model_input = json.load(file) click.echo(f'Running {inference_mode} inference for model "{model_cid}"') - inference_result = client.infer(model_cid=model_cid, inference_mode=InferenceModes[inference_mode], model_input=model_input) + if InferenceModes[inference_mode] is InferenceMode.TEE: + if nonce is None: + nonce = generate_nonce() + click.echo(f'No nonce provided for TEE inference, generating random nonce: {nonce}') + else: + click.echo(f'Using provided nonce for TEE inference: {nonce}') + + inference_result = client.infer(model_cid=model_cid, inference_mode=InferenceModes[inference_mode], model_input=model_input, tee_nonce=nonce) click.echo() # Add a newline for better spacing click.secho("✅ Transaction successful", fg="green", bold=True) @@ -340,6 +349,8 @@ def infer(ctx, model_cid: str, inference_mode: str, input_data, input_file: Path inference_result.model_output, indent=2, default=lambda x: x.tolist() if hasattr(x, "tolist") else str(x) ) click.echo(formatted_output) + click.secho("Attestation document: ", fg="yellow") + click.echo(inference_result.attestation) except json.JSONDecodeError as e: click.echo(f"Error decoding JSON: {e}", err=True) click.echo(f"Error occurred on line {e.lineno}, column {e.colno}", err=True) diff --git a/src/opengradient/client.py b/src/opengradient/client.py index e84df2d..3688fbc 100644 --- a/src/opengradient/client.py +++ b/src/opengradient/client.py @@ -29,7 +29,7 @@ FileUploadResult, ) from .defaults import DEFAULT_IMAGE_GEN_HOST, DEFAULT_IMAGE_GEN_PORT, DEFAULT_SCHEDULER_ADDRESS -from .utils import convert_array_to_model_output, convert_to_model_input, convert_to_model_output +from .utils import convert_array_to_model_output, convert_to_model_input, convert_to_model_output, get_attestation, generate_nonce _FIREBASE_CONFIG = { "apiKey": "AIzaSyDUVckVtfl-hiteBzPopy1pDD8Uvfncs7w", @@ -271,6 +271,7 @@ def infer( inference_mode: InferenceMode, model_input: Dict[str, Union[str, int, float, List, np.ndarray]], max_retries: Optional[int] = None, + tee_nonce: Optional[str] = generate_nonce(), ) -> InferenceResult: """ Perform inference on a model. @@ -325,7 +326,11 @@ def execute_transaction(): # TODO: This should return a ModelOutput class object model_output = convert_to_model_output(parsed_logs[0]["args"]) - return InferenceResult(tx_hash.hex(), model_output) + # TODO: Temporary measure for getting remote attestation + if InferenceMode(inference_mode_uint8) is InferenceMode.TEE: + attestation = get_attestation(tee_nonce) + + return InferenceResult(tx_hash.hex(), model_output, attestation) return run_with_retry(execute_transaction, max_retries) diff --git a/src/opengradient/types.py b/src/opengradient/types.py index 7a4421e..4622ac5 100644 --- a/src/opengradient/types.py +++ b/src/opengradient/types.py @@ -127,10 +127,12 @@ class InferenceResult: This class has two fields transaction_hash (str): Blockchain hash for the transaction model_output (Dict[str, np.ndarray]): Output of the ONNX model + attestation (Optinonal[Dict[str, str]]): Decoded attestation document for TEE inferences """ transaction_hash: str model_output: Dict[str, np.ndarray] + attestation: Optional[Dict[str, str]] @dataclass diff --git a/src/opengradient/utils.py b/src/opengradient/utils.py index 0b4a2fe..8d2cde1 100644 --- a/src/opengradient/utils.py +++ b/src/opengradient/utils.py @@ -10,6 +10,10 @@ from .types import ModelOutput +import requests +import os +import binascii + def convert_to_fixed_point(number: float) -> Tuple[int, int]: """ @@ -220,3 +224,30 @@ def convert_array_to_model_output(array_data: List) -> ModelOutput: jsons=json_data, is_simulation_result=array_data[3], ) + +def get_attestation(nonce: str) -> str: + """Gets """ + attestation_url = f"https://18.224.61.175/enclave/attestation" + params = { + "nonce" : nonce + } + print("nonce is: ", nonce) + + try: + attestation_response=requests.get(attestation_url, params=params, verify=False) + + if attestation_response.status_code == 200: + attestation_document = attestation_response.text + else: + raise RuntimeError("Attestation request failed, error status %s: %s", attestation_response.status_code, attestation_response.text) + except requests.exceptions.RequestException as e: + raise RuntimeError("Error occured while requesting attestation document: %s", e) + except Exception as e: + raise RuntimeError("Verification for attestation document failed: %s" % e) + + return attestation_document + +def generate_nonce() -> str: + """Generate nonce for TEE inferences.""" + random_bytes = os.urandom(20) + return binascii.hexlify(random_bytes).decode('ascii') \ No newline at end of file