diff --git a/fiber/chain/commitments.py b/fiber/chain/commitments.py index fe06cf2..1dc523e 100644 --- a/fiber/chain/commitments.py +++ b/fiber/chain/commitments.py @@ -1,3 +1,5 @@ +from typing import Any + from scalecodec import ScaleType from substrateinterface import Keypair, SubstrateInterface from tenacity import retry, stop_after_attempt, wait_exponential @@ -24,23 +26,48 @@ def _serialize_commitment_field(field: CommitmentDataField) -> dict[str, bytes]: return {serialized_data_type: data} -def _deserialize_commitment_field(field: dict[str, bytes | str]) -> CommitmentDataField: - data_type, data = field.items().__iter__().__next__() +def _deserialize_commitment_field(field: dict[str, Any]) -> CommitmentDataField: + # Extract the single key/value pair + data_type, data = next(iter(field.items())) + # Handle explicit empty marker if data_type == EMPTY_COMMITMENT_FIELD_TYPE: return None + # Normalize the payload into bytes + def _normalize_to_bytes(value: Any) -> bytes: + if isinstance(value, (bytes, bytearray)): + return bytes(value) + if isinstance(value, str): + # Expect hex string like 0x... ; fall back to utf-8 if not hex + if value.startswith("0x"): + return bytes.fromhex(value[2:]) + return value.encode("utf-8") + if isinstance(value, (list, tuple)): + # Unwrap single-element nesting like ((1,2,3),) + current: Any = value + while isinstance(current, (list, tuple)) and len(current) == 1 and isinstance(current[0], (list, tuple)): + current = current[0] + # Sequence of integers + return bytes(current) + # Last resort: string conversion + return str(value).encode("utf-8") + + # Support Raw fields that may include a length suffix (e.g., "Raw83") if data_type.startswith(CommitmentDataFieldType.RAW.value): - expected_field_data_length = int(data_type[len(CommitmentDataFieldType.RAW.value):]) + suffix = data_type[len(CommitmentDataFieldType.RAW.value):] + expected_length = int(suffix) if suffix.isdigit() else None data_type = CommitmentDataFieldType.RAW.value - data = bytes.fromhex(data[2:]) - - if len(data) != expected_field_data_length: - raise ValueError(f"Got commitment raw field expecting {expected_field_data_length} data but got {len(data)} data") - - field: CommitmentDataField = (CommitmentDataFieldType(data_type), data) + data_bytes = _normalize_to_bytes(data) + if expected_length is not None and len(data_bytes) != expected_length: + raise ValueError( + f"Got commitment raw field expecting {expected_length} data but got {len(data_bytes)} data" + ) + return (CommitmentDataFieldType(data_type), data_bytes) - return field + # Non-raw fields + data_bytes = _normalize_to_bytes(data) + return (CommitmentDataFieldType(data_type), data_bytes) @retry( @@ -146,13 +173,37 @@ def query_commitment( netuid, hotkey, block, - ).value + ) if not value: return None - fields: list[dict[str, bytes]] = value["info"]["fields"] - mapped_fields = [_deserialize_commitment_field(field) for field in fields] + # The chain may return nested tuples/lists, e.g., (({...},),) + raw_fields: Any = value["info"]["fields"] + + # Unwrap nested containers until we get a sequence of dicts + def _unwrap_fields(container: Any) -> list[dict[str, Any]]: + current = container + while True: + if current is None: + return [] + if isinstance(current, dict): + # Single dict; wrap in list + return [current] + if isinstance(current, (list, tuple)): + if len(current) == 0: + return [] + first = current[0] + if isinstance(first, dict): + return list(current) # type: ignore[return-value] + # Keep unwrapping one level + current = first + continue + # Unknown shape; return empty to be safe + return [] + + fields_list = _unwrap_fields(raw_fields) + mapped_fields = [_deserialize_commitment_field(field) for field in fields_list] return CommitmentQuery( fields=mapped_fields, @@ -211,6 +262,9 @@ def get_raw_commitment( f"Commitment for {hotkey} in netuid {netuid} is of type {data_type.value} and not {CommitmentDataFieldType.RAW.value}" ) + if commitment is None: + return None + return RawCommitmentQuery( data=data, block=commitment.block,