Skip to content

Commit cced020

Browse files
committed
Add 'stream' type
1 parent dd6bfde commit cced020

File tree

1 file changed

+138
-7
lines changed

1 file changed

+138
-7
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 138 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,10 @@ class Own(ValType):
166166
class Borrow(ValType):
167167
rt: ResourceType
168168

169+
@dataclass
170+
class Stream(ValType):
171+
t: ValType
172+
169173
### Despecialization
170174

171175
def despecialize(t):
@@ -193,6 +197,7 @@ def alignment(t):
193197
case Variant(cases) : return alignment_variant(cases)
194198
case Flags(labels) : return alignment_flags(labels)
195199
case Own(_) | Borrow(_) : return 4
200+
case Stream(_) : return 4
196201

197202
def alignment_record(fields):
198203
a = 1
@@ -243,6 +248,7 @@ def elem_size(t):
243248
case Variant(cases) : return elem_size_variant(cases)
244249
case Flags(labels) : return elem_size_flags(labels)
245250
case Own(_) | Borrow(_) : return 4
251+
case Stream(_) : return 4
246252

247253
def elem_size_record(fields):
248254
s = 0
@@ -303,6 +309,8 @@ class ComponentInstance:
303309
active_sync_task: bool
304310
pending_sync_tasks: list[asyncio.Future]
305311
async_subtasks: Table[AsyncSubtask]
312+
readable_streams: Table[ReadableStreamElem]
313+
writable_streams: Table[WritableStreamElem]
306314

307315
def __init__(self):
308316
self.may_leave = True
@@ -314,6 +322,8 @@ def __init__(self):
314322
self.active_sync_task = False
315323
self.pending_sync_tasks = []
316324
self.async_subtasks = Table[AsyncSubtask]()
325+
self.readable_streams = Table[ReadableStreamElem]()
326+
self.writable_streams = Table[WritableStreamElem]()
317327

318328
class HandleTables:
319329
rt_to_table: MutableMapping[ResourceType, Table[HandleElem]]
@@ -388,6 +398,28 @@ def __init__(self, rep, own, scope = None):
388398
self.scope = scope
389399
self.lend_count = 0
390400

401+
class Buffer:
402+
cx: CallContext
403+
ptr: int
404+
length: int
405+
406+
def __init__(self, cx, ptr, length):
407+
self.cx = cx
408+
self.ptr = ptr
409+
self.length = length
410+
411+
class StreamElem:
412+
t: ValType
413+
active: bool
414+
writer: Optional[Buffer]
415+
reader: Optional[Buffer]
416+
417+
def __init__(self, t):
418+
self.t = t
419+
self.active = True
420+
self.writer = None
421+
self.reader = None
422+
391423
class AsyncCallState(IntEnum):
392424
STARTING = 0
393425
STARTED = 1
@@ -400,6 +432,10 @@ class EventCode(IntEnum):
400432
CALL_RETURNED = AsyncCallState.RETURNED
401433
CALL_DONE = AsyncCallState.DONE
402434
YIELDED = 4
435+
STREAM_READ = 5
436+
STREAM_CLOSED = 6
437+
STREAM_WROTE = 7
438+
STREAM_CANCELLED = 8
403439

404440
current_task = asyncio.Lock()
405441

@@ -489,6 +525,12 @@ def async_subtask_made_progress(self, subtask):
489525
subtask.enqueued = True
490526
self.events.put_nowait(subtask)
491527

528+
def stream_new(self, t):
529+
s = StreamElem(t)
530+
rsi = self.inst.readable_streams.add(s)
531+
wsi = self.inst.writable_streams.add(s)
532+
return (rsi, wsi)
533+
492534
def create_borrow(self):
493535
self.borrow_count += 1
494536

@@ -624,6 +666,7 @@ def load(cx, ptr, t):
624666
case Flags(labels) : return load_flags(cx, ptr, labels)
625667
case Own() : return lift_own(cx, load_int(cx, ptr, 4), t)
626668
case Borrow() : return lift_borrow(cx, load_int(cx, ptr, 4), t)
669+
case Stream() : return lift_stream(cx, load_int(cx, ptr, 4), t)
627670

628671
def load_int(cx, ptr, nbytes, signed = False):
629672
return int.from_bytes(cx.opts.memory[ptr : ptr+nbytes], 'little', signed=signed)
@@ -704,11 +747,14 @@ def load_string_from_range(cx, ptr, tagged_code_units):
704747
def load_list(cx, ptr, elem_type):
705748
begin = load_int(cx, ptr, 4)
706749
length = load_int(cx, ptr + 4, 4)
707-
return load_list_from_range(cx, begin, length, elem_type)
750+
return lift_list(cx, begin, length, elem_type)
708751

709-
def load_list_from_range(cx, ptr, length, elem_type):
752+
def lift_list(cx, ptr, length, elem_type):
710753
trap_if(ptr != align_to(ptr, alignment(elem_type)))
711754
trap_if(ptr + length * elem_size(elem_type) > len(cx.opts.memory))
755+
return lift_list_after_checks(cx, ptr, length, elem_type)
756+
757+
def lift_list_after_checks(cx, ptr, length, elem_type):
712758
a = []
713759
for i in range(length):
714760
a.append(load(cx, ptr + i * elem_size(elem_type), elem_type))
@@ -772,6 +818,10 @@ def lift_borrow(cx, i, t):
772818
cx.track_owning_lend(h)
773819
return h.rep
774820

821+
def lift_stream(cx, i, t):
822+
# TODO
823+
pass
824+
775825
### Storing
776826

777827
def store(cx, v, t, ptr):
@@ -797,6 +847,7 @@ def store(cx, v, t, ptr):
797847
case Flags(labels) : store_flags(cx, v, ptr, labels)
798848
case Own() : store_int(cx, lower_own(cx.opts, v, t), ptr, 4)
799849
case Borrow() : store_int(cx, lower_borrow(cx.opts, v, t), ptr, 4)
850+
case Stream() : store_int(cx, lower_stream(cx.opts, v, t), ptr, 4)
800851

801852
def store_int(cx, v, ptr, nbytes, signed = False):
802853
cx.opts.memory[ptr : ptr+nbytes] = int.to_bytes(v, nbytes, 'little', signed=signed)
@@ -989,19 +1040,22 @@ def store_probably_utf16_to_latin1_or_utf16(cx, src, src_code_units):
9891040
return (ptr, latin1_size)
9901041

9911042
def store_list(cx, v, ptr, elem_type):
992-
begin, length = store_list_into_range(cx, v, elem_type)
1043+
begin, length = lower_list(cx, v, elem_type)
9931044
store_int(cx, begin, ptr, 4)
9941045
store_int(cx, length, ptr + 4, 4)
9951046

996-
def store_list_into_range(cx, v, elem_type):
1047+
def lower_list(cx, v, elem_type):
9971048
byte_length = len(v) * elem_size(elem_type)
9981049
trap_if(byte_length >= (1 << 32))
9991050
ptr = cx.opts.realloc(0, 0, alignment(elem_type), byte_length)
10001051
trap_if(ptr != align_to(ptr, alignment(elem_type)))
10011052
trap_if(ptr + byte_length > len(cx.opts.memory))
1053+
lower_list_after_checks(cx, v, ptr, elem_type)
1054+
return (ptr, len(v))
1055+
1056+
def lower_list_after_checks(cx, v, ptr, elem_type):
10021057
for i,e in enumerate(v):
10031058
store(cx, e, elem_type, ptr + i * elem_size(elem_type))
1004-
return (ptr, len(v))
10051059

10061060
def store_record(cx, v, ptr, fields):
10071061
for f in fields:
@@ -1052,6 +1106,10 @@ def lower_borrow(cx, rep, t):
10521106
cx.create_borrow()
10531107
return cx.inst.handles.add(t.rt, h)
10541108

1109+
def lower_stream(cx, rep, t):
1110+
# TODO
1111+
pass
1112+
10551113
### Flattening
10561114

10571115
MAX_FLAT_PARAMS = 16
@@ -1101,6 +1159,7 @@ def flatten_type(t):
11011159
case Variant(cases) : return flatten_variant(cases)
11021160
case Flags(labels) : return ['i32']
11031161
case Own(_) | Borrow(_) : return ['i32']
1162+
case Stream(_) : return ['i32']
11041163

11051164
def flatten_record(fields):
11061165
flat = []
@@ -1162,6 +1221,7 @@ def lift_flat(cx, vi, t):
11621221
case Flags(labels) : return lift_flat_flags(vi, labels)
11631222
case Own() : return lift_own(cx, vi.next('i32'), t)
11641223
case Borrow() : return lift_borrow(cx, vi.next('i32'), t)
1224+
case Stream() : return lift_stream(cx, vi.next('i32'), t)
11651225

11661226
def lift_flat_unsigned(vi, core_width, t_width):
11671227
i = vi.next('i' + str(core_width))
@@ -1184,7 +1244,7 @@ def lift_flat_string(cx, vi):
11841244
def lift_flat_list(cx, vi, elem_type):
11851245
ptr = vi.next('i32')
11861246
length = vi.next('i32')
1187-
return load_list_from_range(cx, ptr, length, elem_type)
1247+
return lift_list(cx, ptr, length, elem_type)
11881248

11891249
def lift_flat_record(cx, vi, fields):
11901250
record = {}
@@ -1248,6 +1308,7 @@ def lower_flat(cx, v, t):
12481308
case Flags(labels) : return lower_flat_flags(v, labels)
12491309
case Own() : return [lower_own(cx, v, t)]
12501310
case Borrow() : return [lower_borrow(cx, v, t)]
1311+
case Stream() : return [lower_stream(cx, v, t)]
12511312

12521313
def lower_flat_signed(i, core_bits):
12531314
if i < 0:
@@ -1259,7 +1320,7 @@ def lower_flat_string(cx, v):
12591320
return [ptr, packed_length]
12601321

12611322
def lower_flat_list(cx, v, elem_type):
1262-
(ptr, length) = store_list_into_range(cx, v, elem_type)
1323+
(ptr, length) = lower_list(cx, v, elem_type)
12631324
return [ptr, length]
12641325

12651326
def lower_flat_record(cx, v, fields):
@@ -1523,3 +1584,73 @@ async def canon_task_yield(task):
15231584
trap_if(task.opts.callback is not None)
15241585
await task.yield_()
15251586
return []
1587+
1588+
### 🔀 `canon stream.new`
1589+
1590+
async def canon_stream_new(t, task, ptr):
1591+
rsi, wsi = task.stream_new(t)
1592+
store(task, wsi, U32(), ptr)
1593+
return [rsi]
1594+
1595+
### 🔀 `canon stream.write`
1596+
1597+
1598+
# TODO: do we want read/write to both support full/partial?
1599+
async def canon_stream_write(t, sync, task, wsi, ptr, length, partial):
1600+
trap_if(length == 0 or length >= 2**31)
1601+
trap_if(ptr != align_to(ptr, alignment(t)))
1602+
trap_if(ptr + length > len(task.opts.memory))
1603+
s = self.inst.writable_streams.get(wsi)
1604+
trap_if(s.t != t)
1605+
trap_if(s.writer is not None)
1606+
s.writer = Buffer(task, ptr, length)
1607+
while s.writer:
1608+
if not s.active:
1609+
return pack_final_result(length - s.writer.length)
1610+
if not s.reader:
1611+
# TODO: this is all wrong
1612+
if not sync:
1613+
return [length - s.writer.length]
1614+
task.inst.calling_sync_import = True
1615+
await ... TODO
1616+
task.inst.calling_sync_import = False
1617+
copy_writer_to_reader(s)
1618+
return [length]
1619+
1620+
def pack_final_result(nwritten):
1621+
assert(nwritten < 2**31)
1622+
return [ (1 << 31) | nwritten ]
1623+
1624+
def copy_writer_to_reader(writer, reader):
1625+
copy_length = min(writer.length, reader.length)
1626+
v = lift_list_after_checks(writer.cx, writer.ptr, copy_length, t)
1627+
lower_list_after_checks(reader.cx, v, reader.ptr, t)
1628+
writer.ptr += copy_length * elem_size(t)
1629+
reader.ptr += copy_length * elem_size(t)
1630+
writer.length -= copy_length
1631+
reader.length -= copy_length
1632+
if writer.length == 0:
1633+
TODO
1634+
if reader.length == 0:
1635+
TODO
1636+
1637+
### 🔀 `canon stream.close`
1638+
1639+
async def canon_stream_close(task, wsi):
1640+
# TODO
1641+
pass
1642+
1643+
### 🔀 `canon stream.read`
1644+
1645+
async def canon_stream_read(sync, task, rsi, ptr, length, partial):
1646+
trap_if(read_length == 0)
1647+
trap_if(ptr != align_to(ptr, alignment(t)))
1648+
trap_if(ptr + read_length > len(task.opts.memory))
1649+
# TODO
1650+
pass
1651+
1652+
### 🔀 `canon stream.drop`
1653+
1654+
async def canon_stream_drop(task, rsi):
1655+
# TODO
1656+
pass

0 commit comments

Comments
 (0)