-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathspiro.py
140 lines (110 loc) · 4.53 KB
/
spiro.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# pylint: disable=redefined-builtin,unspecified-encoding
import io
import os
import pickle
import pickletools as pt
from dataclasses import dataclass
from struct import pack, unpack
from typing import Any, Optional, TypeVar

import torch
from fickling import pickle as p
from fickling.pickle import Pickled
@dataclass
class GetPlaceholder:
    """Named stand-in for an opcode, optionally wrapping an existing one.

    ``name`` and ``old`` are the dataclass fields. ``data`` is not a field:
    it is filled in right after construction from ``old`` (see below).
    """

    name: str | bytes | None
    old: Optional[p.Opcode] = None

    def __post_init__(self) -> None:
        """Derive ``data`` from the wrapped opcode, or use a filler."""
        if self.old:
            # reuse the wrapped opcode's raw bytes
            self.data = self.old.data
        else:
            # two-byte filler; presumably sized like a BINGET (opcode byte +
            # one-byte index) so length counts work out -- TODO confirm
            self.data = b"hh"
@dataclass
class MemoPlaceholder:
    """Named stand-in for a memoize opcode.

    Only ``name`` is a dataclass field; ``data`` carries no annotation, so it
    is a plain class attribute shared by every instance.
    """

    name: str

    # A single byte (the same value as ``pickle.MEMOIZE``); it is here just so
    # that length counts over a list of ops include this placeholder.
    data = b"\x94"
# Shorthand for a list of fickling opcodes, used in annotations in this module.
Opcodes = list[p.Opcode]
# Generic type variable available for annotations.
T = TypeVar("T")
class Variables:
    """Human names for memory indexes.

    Keeps a running counter of memo slots and a mapping from a caller-chosen
    name to the slot that was current when the name was assigned.
    """

    def __init__(self, counter_start: int = 0):
        # name -> memo slot, filled in by assign()
        self.memo_indexes: dict[str | bytes | int, int] = {}
        # slot that the next assign() call will record
        self.memory_counter = counter_start

    def assign(self, name: str | bytes | int, id: Optional[str] = None) -> p.Memoize:
        """Record ``name`` at the current slot, advance the counter, and
        return a fresh ``Memoize`` opcode.

        ``id`` is accepted but currently unused. A disabled draft kept an
        id -> name table (auto-generating ``_varN`` ids when none was given)
        with a ``gloss`` lookup, intended for display in a debugger.
        Re-assigning an existing name overwrites its recorded slot.
        """
        slot = self.memory_counter
        self.memo_indexes[name] = slot
        self.memory_counter = slot + 1
        return p.Memoize()

    def __getitem__(self, name: str | int | bytes) -> p.BinGet | p.LongBinGet:
        """Return a get opcode for the slot recorded under ``name``.

        Raises ``KeyError`` if ``name`` was never assigned.
        """
        return make_get(self.memo_indexes[name])
class PlaceholderVariables:
    """Name-only variable table that hands out placeholder objects.

    Per the original note, this is the one programs should use. It keeps no
    state: both operations just wrap the given name in a placeholder.
    """

    def assign(self, name: str) -> MemoPlaceholder:
        """Return a ``MemoPlaceholder`` carrying ``name``."""
        placeholder = MemoPlaceholder(name=name)
        return placeholder

    def __getitem__(self, name: str) -> GetPlaceholder:
        """Return a ``GetPlaceholder`` carrying ``name`` (no wrapped opcode)."""
        placeholder = GetPlaceholder(name=name)
        return placeholder
def find_main_pickle(ckpt: str | Any, magic=False) -> tuple[bytes, bytes, bytes]:
"get first bytes, result pickle, last bytes from a torch object or ckpt path"
if isinstance(ckpt, str):
model = torch.load(ckpt) # type: ignore
else:
model = ckpt
buf = io.BytesIO()
# we want protocol 4 and none of the zipfile stuff
# that said fickling has an example of how to use zipfiles
torch.save(model, buf, _use_new_zipfile_serialization=False, pickle_protocol=4)
buf.seek(0)
# https://github.com/pytorch/pytorch/blob/master/torch/serialization.py#L1012-L1024
# a torch.save is:
# 1. magic number
# 2. protocol version
# 3. sys info
# 4. real obj (persistent id) / "result" <- fuck here
# 5. (de)serialized storage keys
# 6. not pickle data, read by THPStorage_setFromFile to set storages from these
# let's find the right parts of the ckpt
# pt.dis will read the buffer up until STOP
# we don't care about the dis, just the indexes
# discard the first three pickles
# for reference, a careless scanner might not scan these,
# they could be good targets
# otoh a good scanner would flag any reduce/global here
devnull = open("/dev/null", "w")
if not magic:
pt.dis(buf, devnull) # magic number
pt.dis(buf, devnull) # protocol version
pt.dis(buf, devnull) # sys info
# figure out where the pickle we want starts/stops
result_start = buf.tell()
pt.dis(buf, devnull)
result_end = buf.tell()
buf.seek(result_start)
main_bytes = buf.read(result_end - result_start) # might be an off by one?
last_bytes = buf.read()
buf.seek(0)
first_bytes = buf.read(result_start)
buf.seek(0)
assert first_bytes + main_bytes + last_bytes == buf.read()
return (first_bytes, main_bytes, last_bytes)
def get_index(op: p.BinGet | p.LongBinGet) -> int:
    """Decode the memo index carried by a get opcode.

    The first byte of ``op.data`` is skipped; the rest is unpacked as a
    little-endian unsigned byte for ``BinGet`` and as a little-endian
    unsigned 32-bit integer otherwise.
    """
    # https://github.com/python/cpython/blob/3.9/Lib/pickle.py#L528-L531
    fmt = "<B" if isinstance(op, p.BinGet) else "<I"
    (index,) = unpack(fmt, op.data[1:])
    return index
def make_get(memo_index: int) -> p.BinGet | p.LongBinGet:
    """Build the get opcode for ``memo_index``.

    Indexes below 256 fit the one-byte ``BINGET`` form; anything else uses
    ``LONG_BINGET`` with a little-endian unsigned 32-bit index.
    """
    if memo_index >= 256:
        long_payload = pack("<I", memo_index)
        return p.LongBinGet(data=pickle.LONG_BINGET + long_payload)
    short_payload = pack("<B", memo_index)
    return p.BinGet(data=pickle.BINGET + short_payload)
def change_frame_len(frame: p.Frame, length_change: int) -> p.Frame:
    """Adjust the length stored in a FRAME opcode by ``length_change``.

    The length is read from ``frame.data`` after its first byte as a
    little-endian unsigned 64-bit integer. ``frame`` is modified in place
    and also returned.
    """
    # pickle targets frames being under 64 * 1024
    # https://github.com/python/cpython/blob/3.9/Lib/pickle.py#L228
    (current_len,) = unpack("<Q", frame.data[1:])
    new_len = current_len + length_change
    frame.data = pickle.FRAME + pack("<Q", new_len)
    return frame
def count_ops(ops: Opcodes | Pickled, op_type: type) -> int:
    """Count the items in ``ops`` that are instances of ``op_type``."""
    return sum(1 for op in ops if isinstance(op, op_type))