Skip to content

Commit

Permalink
support encoding from 'atjson' format ($bytes, $link)
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidBuchanan314 committed Feb 21, 2024
1 parent db073f5 commit 72cca91
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 7 deletions.
21 changes: 17 additions & 4 deletions src/cbrrr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,25 @@ def __init__(self, cid_bytes: bytes) -> None:

@classmethod
def cidv1_dag_cbor_sha256_32_from(cls, data: bytes) -> "CID":
    """
    Build a CID for `data`.

    The CID bytes are CIDV1_DAG_CBOR_SHA256_32_PFX followed by the SHA-256
    digest of `data`, wrapped in `cls` so that callers receive a CID
    instance rather than raw bytes.
    """
    return cls(CIDV1_DAG_CBOR_SHA256_32_PFX + hashlib.sha256(data).digest())

@classmethod
def cidv1_raw_sha256_32_from(cls, data: bytes) -> "CID":
    """
    Build a CID for `data`.

    The CID bytes are CIDV1_RAW_SHA256_32_PFX followed by the SHA-256
    digest of `data`, wrapped in `cls` so that callers receive a CID
    instance rather than raw bytes.
    """
    return cls(CIDV1_RAW_SHA256_32_PFX + hashlib.sha256(data).digest())

@classmethod
def decode(cls, data: bytes | str) -> "CID":
if type(data) is bytes:
return cls(data) # TODO: is this correct??? should we check for and strip leading 0?

if data.startswith("b"):
data = data[1:].rstrip("=") # strip b, and existing padding
data += "=" * ((-len(data)) % 8) # add back correct amount of padding (python is fussy)
decoded = base64.b32decode(data, casefold=True) # TODO: do we care about map01?
return cls(decoded)

raise ValueError("I don't know how to decode this CID")

def encode(self, base="base32") -> str:
"""
Encode to base32
Expand Down Expand Up @@ -59,5 +72,5 @@ def parse_dag_cbor(data: bytes) -> DagCborTypes:
raise ValueError("did not parse to end of buffer")
return parsed

def encode_dag_cbor(obj: DagCborTypes, atjson_mode: bool = False) -> bytes:
    """
    Encode `obj` as DAG-CBOR bytes via the C extension.

    When `atjson_mode` is true, the encoder gives special treatment to
    single-key dicts: {"$link": "<cid string>"} is encoded as a CID (the
    string is parsed with CID.decode) and {"$bytes": "<base64>"} is encoded
    as a byte string. The value must be a str in both cases, otherwise the
    encoder raises TypeError.
    """
    return _cbrrr.encode_dag_cbor(obj, CID, atjson_mode)
176 changes: 173 additions & 3 deletions src/cbrrr/_cbrrr.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
/*
 * Module-lifetime Python objects shared by the parser/encoder.
 * NOTE(review): the initialisation of the first three is not visible in this
 * hunk; judging by the names they are presumably the Python ints 0,
 * 2**64-1 and its bitwise inverse, used for integer range checks -- confirm
 * against PyInit__cbrrr.
 */
static PyObject *PY_ZERO;
static PyObject *PY_UINT64_MAX;
static PyObject *PY_UINT64_MAX_INVERTED;
/* The str "decode"; used as the method name when calling cid_type.decode(). */
static PyObject *PY_STRING_DECODE;

typedef enum {
DCMT_UNSIGNED_INT = 0,
Expand Down Expand Up @@ -462,6 +463,10 @@ cbrrr_parse_dag_cbor(PyObject *self, PyObject *args)







static int
cbrrr_buf_make_room(CbrrrBuf *buf, size_t len)
{
Expand Down Expand Up @@ -531,6 +536,105 @@ cbrrr_write_cbor_varint(CbrrrBuf *buf, DCMajorType type, uint64_t value)
return cbrrr_buf_write(buf, tmp, 9);
}




/*
Decodes maybe-padded base64 according to https://atproto.com/specs/data-model#bytes (RFC-4648, section 4)
Returns 0 on success, -1 on failure (setting a python exception).
*/

/*
 * Lookup table for the standard base64 alphabet (RFC 4648, section 4).
 * Indexed by input byte; yields the 6-bit value, or 0xff for any byte that is
 * not part of the alphabet ('=' padding included).
 *
 * The table must have exactly 256 entries (16 rows of 16) because it is
 * indexed with an arbitrary unsigned char; a shorter table would be read out
 * of bounds for high bytes such as UTF-8 lead bytes 0xF0..0xF4.
 */
static const uint8_t B64_DECODE_LUT[] = {
    /* 0x00 */ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    /* 0x10 */ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    /* 0x20 */ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,   62, 0xff, 0xff, 0xff,   63,
    /* 0x30 */   52,   53,   54,   55,   56,   57,   58,   59,   60,   61, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    /* 0x40 */ 0xff,    0,    1,    2,    3,    4,    5,    6,    7,    8,    9,   10,   11,   12,   13,   14,
    /* 0x50 */   15,   16,   17,   18,   19,   20,   21,   22,   23,   24,   25, 0xff, 0xff, 0xff, 0xff, 0xff,
    /* 0x60 */ 0xff,   26,   27,   28,   29,   30,   31,   32,   33,   34,   35,   36,   37,   38,   39,   40,
    /* 0x70 */   41,   42,   43,   44,   45,   46,   47,   48,   49,   50,   51, 0xff, 0xff, 0xff, 0xff, 0xff,
    /* 0x80 */ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    /* 0x90 */ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    /* 0xa0 */ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    /* 0xb0 */ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    /* 0xc0 */ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    /* 0xd0 */ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    /* 0xe0 */ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    /* 0xf0 */ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
};

static int
cbrrr_write_cbor_bytes_from_b64(CbrrrBuf *buf, const unsigned char *b64_str, size_t str_len)
{
// strip padding
while (str_len && b64_str[str_len - 1] == '=') str_len--;
if ((str_len % 4) == 1) {
PyErr_SetString(PyExc_ValueError, "invalid b64 length");
return -1;
}
size_t decoded_length = (str_len*3)/4;
if (cbrrr_write_cbor_varint(buf, DCMT_BYTE_STRING, decoded_length) < 0) {
return -1;
}
if (cbrrr_buf_make_room(buf, decoded_length) < 0) {
return -1;
}
uint8_t *bufptr = buf->buf + buf->length;
buf->length += decoded_length;

size_t str_i = 0;
uint8_t a, b, c, d;
while (str_i+3 < str_len) {
a = B64_DECODE_LUT[b64_str[str_i++]];
b = B64_DECODE_LUT[b64_str[str_i++]];
c = B64_DECODE_LUT[b64_str[str_i++]];
d = B64_DECODE_LUT[b64_str[str_i++]];
if ((a | b | c | d) & 0x80) {
PyErr_SetString(PyExc_ValueError, "invalid b64 character");
return -1;
}
*bufptr++ = (a << 2) | ((b >> 4) & 0x03);
*bufptr++ = (b << 4) | ((c >> 2) & 0x0f);
*bufptr++ = (c << 6) | ((d >> 0) & 0x3f);
}
switch (str_len - str_i)
{
case 3:
a = B64_DECODE_LUT[b64_str[str_i++]];
b = B64_DECODE_LUT[b64_str[str_i++]];
c = B64_DECODE_LUT[b64_str[str_i++]];
if ((a | b | c) & 0x80) {
PyErr_SetString(PyExc_ValueError, "invalid b64 character");
return -1;
}
*bufptr++ = (a << 2) | ((b >> 4) & 0x03);
*bufptr++ = (b << 4) | ((c >> 2) & 0x0f);
// should we check (c << 6) & 0xff == 0?
break;

case 2:
a = B64_DECODE_LUT[b64_str[str_i++]];
b = B64_DECODE_LUT[b64_str[str_i++]];
if ((a | b) & 0x80) {
PyErr_SetString(PyExc_ValueError, "invalid b64 character");
return -1;
}
*bufptr++ = (a << 2) | ((b >> 4) & 0x03);
// should we check (b << 4) & 0xff == 0?
break;

case 0:
break;

default:
PyErr_SetString(PyExc_AssertionError, "unreachable!?");
return -1;
}

return 0;
}



static int
cbrrr_compare_map_keys(const void *a, const void *b)
{
Expand All @@ -557,7 +661,7 @@ cbrrr_compare_map_keys(const void *a, const void *b)


static int
cbrrr_encode_object(CbrrrBuf *buf, PyObject *obj_in, PyObject* cid_type)
cbrrr_encode_object(CbrrrBuf *buf, PyObject *obj_in, PyObject* cid_type, int atjson_mode)
{
/*
in a slightly unscientific test, frequency counts for each type
Expand Down Expand Up @@ -712,6 +816,62 @@ cbrrr_encode_object(CbrrrBuf *buf, PyObject *obj_in, PyObject* cid_type)
if (keys == NULL) {
break;
}
// TODO: verify all are strings here?
if (atjson_mode && PySequence_Fast_GET_SIZE(keys) == 1) {// logic for $link, $bytes
PyObject *key = PySequence_Fast_GET_ITEM(keys, 0);
if (!PyUnicode_CheckExact(key)) {
PyErr_SetString(PyExc_TypeError, "map keys must be strings");
Py_DECREF(keys);
break;
}
size_t string_len;
const uint8_t *str = PyUnicode_AsUTF8AndSize(key, &string_len); // does this fail gracefully if the item is not a string?
if (str == NULL) {
Py_DECREF(keys);
break;
}
if (string_len == 5 && strcmp(str, "$link") == 0) { // CID
PyObject *cid_str = PyDict_GetItem(obj, key); // borrowed
if (!PyUnicode_CheckExact(cid_str)) { // also handles the case where b64_str is NULL
PyErr_SetString(PyExc_TypeError, "$link field value must be a string");
Py_DECREF(keys);
break;
}
PyObject *cid = PyObject_CallMethodOneArg(cid_type, PY_STRING_DECODE, cid_str);
if (cid == NULL) {
Py_DECREF(keys);
break;
}
if(cbrrr_encode_object(buf, cid, cid_type, 0) < 0) { // call ourselves recursively (max 1 level of recursion though)
Py_DECREF(keys);
Py_DECREF(cid);
break;
}
Py_DECREF(keys);
Py_DECREF(cid);
continue;
}
if (string_len == 6 && strcmp(str, "$bytes") == 0) { // bytes
PyObject *b64_str = PyDict_GetItem(obj, key); // borrowed
if (!PyUnicode_CheckExact(b64_str)) { // also handles the case where b64_str is NULL
PyErr_SetString(PyExc_TypeError, "$bytes field value must be a string");
Py_DECREF(keys);
break;
}
str = PyUnicode_AsUTF8AndSize(b64_str, &string_len); // reusing these variables
if (str == NULL) {
Py_DECREF(keys);
break;
}
if (cbrrr_write_cbor_bytes_from_b64(buf, str, string_len)) {
Py_DECREF(keys);
break;
}
Py_DECREF(keys);
continue;
}
// fallthru
}
qsort( // it's a bit janky but we can sort the key list in-place, I think?
PySequence_Fast_ITEMS(keys),
PySequence_Fast_GET_SIZE(keys),
Expand Down Expand Up @@ -812,17 +972,19 @@ cbrrr_encode_object(CbrrrBuf *buf, PyObject *obj_in, PyObject* cid_type)




static PyObject *
cbrrr_encode_dag_cbor(PyObject *self, PyObject *args)
{
PyObject *obj;
PyObject *cid_type;
PyObject *res;
int atjson_mode;
CbrrrBuf buf;

(void)self; // unused

if (!PyArg_ParseTuple(args, "OO", &obj, &cid_type)) {
if (!PyArg_ParseTuple(args, "OOp", &obj, &cid_type, &atjson_mode)) {
return NULL;
}

Expand All @@ -834,7 +996,7 @@ cbrrr_encode_dag_cbor(PyObject *self, PyObject *args)
return NULL;
}

if (cbrrr_encode_object(&buf, obj, cid_type) < 0) {
if (cbrrr_encode_object(&buf, obj, cid_type, atjson_mode) < 0) {
res = NULL;
} else {
res = PyBytes_FromStringAndSize((const char*)buf.buf, buf.length); // nb: this incurs a copy
Expand Down Expand Up @@ -898,5 +1060,13 @@ PyInit__cbrrr(void)
return NULL;
}

PY_STRING_DECODE = PyUnicode_FromString("decode"); // TODO: do we care about interning?
if (PY_STRING_DECODE == NULL) {
Py_DECREF(PY_ZERO);
Py_DECREF(PY_UINT64_MAX_INVERTED);
Py_DECREF(PY_STRING_DECODE);
return NULL;
}

return m;
}

0 comments on commit 72cca91

Please sign in to comment.