forked from facebook/redex
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen_simple_module.py
183 lines (132 loc) · 4.11 KB
/
gen_simple_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import base64
import io
import logging
import os
import zipfile
from collections import namedtuple
try:
import lzma # noqa(F401)
import tarfile
has_tar_lzma = True
except ImportError:
has_tar_lzma = False
# Parsed command-line configuration:
#   inputs - the raw "name=filename" positional arguments
#   output - resolved path of the Python wrapper to generate
#   tarxz  - True to archive as tar.xz, False to archive as zip
Args = namedtuple("Args", "inputs output tarxz")
def parse_args():
    """Parse the command line into an ``Args`` tuple.

    Positional arguments are ``name=filename`` pairs; ``-o/--out`` names the
    generated Python wrapper and is required. ``--force-zip`` selects the
    zip backend even when tarfile with lzma support was importable.

    Returns:
        Args: ``inputs`` is the raw list of ``name=filename`` strings,
        ``output`` the realpath of the output file, and ``tarxz`` whether
        the tar.xz backend should be used.
    """
    parser = argparse.ArgumentParser(
        description="Generate simple module with the given file"
    )
    parser.add_argument("args", nargs="+", help="name=filename list")
    parser.add_argument(
        "-o",
        "--out",
        # The output path is mandatory: without it args.out would be None
        # and the args.out[0] below would fail with an opaque TypeError.
        required=True,
        nargs=1,
        type=os.path.realpath,
        help="Generated python wrapper",
    )
    parser.add_argument(
        "--force-zip",
        action="store_true",
        help="Force the use of zip, even when tar and lzma are available",
    )
    args = parser.parse_args()
    # has_tar_lzma is only read here, so no ``global`` statement is needed.
    return Args(args.args, args.out[0], has_tar_lzma and not args.force_zip)
def compress_zip(inputs, compression=zipfile.ZIP_DEFLATED):
    """Pack the files in *inputs* into an in-memory zip archive.

    Each file is stored under its basename, so inputs should have distinct
    basenames.

    Args:
        inputs: iterable of paths of the files to add.
        compression: zipfile compression method. Defaults to ZIP_DEFLATED;
            zipfile's own default, ZIP_STORED, performs no compression.

    Returns:
        io.BytesIO: the archive bytes, rewound to position 0.
    """
    logging.info("Compressing as zip")
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w", compression=compression) as zf:
        for path in inputs:
            logging.info("Adding %s", path)
            with open(path, "rb") as f:
                # writestr() takes str/bytes data, not a file object, so the
                # contents are read first (passing ``f`` raises TypeError).
                zf.writestr(os.path.basename(path), f.read())
    buf.seek(0)
    return buf
def compress_tar_xz(inputs):
    """Pack the files in *inputs* into an in-memory xz-compressed tar archive.

    Each file is stored under its basename, so inputs should have distinct
    basenames.

    Args:
        inputs: iterable of paths of the files to add.

    Returns:
        io.BytesIO: the archive bytes, rewound to position 0.
    """
    logging.info("Compressing as tar.xz")
    buf = io.BytesIO()
    # The context manager closes the TarFile (finishing the xz stream) on
    # success, and still releases it if reading an input raises. It does
    # not close ``buf``, which the caller owns.
    with tarfile.open(fileobj=buf, mode="w:xz") as tar:
        for path in inputs:
            logging.info("Adding %s", path)
            # Record the member under its basename rather than its full path.
            info = tar.gettarinfo(path, arcname=os.path.basename(path))
            with open(path, "rb") as f:
                tar.addfile(info, fileobj=f)
    buf.seek(0)
    return buf
def compress_and_base_64(inputs, tar_xz):
    """Archive the files in *inputs* and return the archive base64-encoded.

    Args:
        inputs: iterable of file paths to archive.
        tar_xz: use the tar.xz backend when true, the zip backend otherwise.

    Returns:
        bytes: the base64 encoding of the archive.
    """
    compress = compress_tar_xz if tar_xz else compress_zip
    with compress(inputs) as buf:
        # getvalue() returns a copy. getbuffer() would export a view, and the
        # BytesIO cannot be closed by the with-statement while that view is
        # alive; releasing it in time relied on CPython refcounting.
        return base64.b64encode(buf.getvalue())
# Skeleton of the generated module. write_py_wrapper fills it via str.format:
#   {extra_imports}            - import lines for the chosen archive backend
#   {base64_blob}              - base64 text of the archive
#   {compression_specific_api} - definitions of _load(name) and _all()
# The backslash after the opening quotes keeps the shebang on line 1 of the
# generated file; anywhere else a shebang is just an ordinary comment.
_FILE_TEMPLATE = """\
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import base64
import io
import re
{extra_imports}
_BASE64BLOB = "{base64_blob}"
{compression_specific_api}
"""
# def get_api_level_file(level):
# name = f"framework_classes_api_{{level}}.txt"
# return _load(name)
# def get_api_levels():
# name_re = re.compile(r"^framework_classes_api_(\\d+)\\.txt$")
# return {{
# int(match.group(1)) for name in _all() for match in [name_re.match(name)] if match
# }}
# Import lines spliced into {extra_imports} when the archive is tar.xz.
_TAR_XZ_IMPORTS = """
import lzma # noqa(F401)
import tarfile
"""
# Reader API spliced into {compression_specific_api} for tar.xz archives.
# This text becomes Python source, so the function bodies must stay
# indented. _TAR is only read (never rebound), so no ``global`` is needed.
_TAR_XZ_API = """
_TAR = tarfile.open(mode="r:xz", fileobj=io.BytesIO(base64.b64decode(_BASE64BLOB)))


def _load(name):
    return _TAR.extractfile(name).read()


def _all():
    return _TAR.getnames()
"""
# Import line spliced into {extra_imports} when the archive is zip.
_ZIP_IMPORTS = "import zipfile"
# Reader API spliced into {compression_specific_api} for zip archives.
# This text becomes Python source, so the function bodies must stay
# indented. _load reads ``name`` (it previously referenced the undefined
# ``names``, raising NameError). _ZIP is only read, so no ``global`` is needed.
_ZIP_API = """
_ZIP = zipfile.ZipFile(io.BytesIO(base64.b64decode(_BASE64BLOB)), "r")


def _load(name):
    return _ZIP.read(name)


def _all():
    return _ZIP.namelist()
"""
def write_py_wrapper(base_64_bytes_blob, files, filename, tar_xz):
    """Write the generated Python module to *filename*.

    The module embeds the base64 archive. For each entry of *files* it binds
    one module-level variable to ``_load(<basename of the file>)``.

    Args:
        base_64_bytes_blob: base64-encoded archive as ASCII bytes.
        files: mapping of variable name -> input file path.
        filename: path of the module to write.
        tar_xz: True if the archive is tar.xz, False if it is zip.

    Raises:
        ValueError: if a key of *files* is not usable as a Python variable
            name.
    """
    import keyword

    # Validate before opening, so a bad name does not leave a truncated file.
    for key in files:
        if not key.isidentifier() or keyword.iskeyword(key):
            raise ValueError(f"{key!r} is not a valid Python identifier")
    base64_str = base_64_bytes_blob.decode("ascii")
    module_text = _FILE_TEMPLATE.format(
        extra_imports=_TAR_XZ_IMPORTS if tar_xz else _ZIP_IMPORTS,
        base64_blob=base64_str,
        compression_specific_api=_TAR_XZ_API if tar_xz else _ZIP_API,
    )
    with open(filename, "w", encoding="utf-8") as f:
        f.write(module_text)
        for key, val in files.items():
            # repr() produces a correctly escaped string literal even when
            # the basename contains quotes or backslashes.
            f.write(f"{key} = _load({os.path.basename(val)!r})\n")
def main():
    """Command-line entry point: archive the inputs and write the wrapper.

    Each positional argument is ``name=filename``. The files are archived
    (tar.xz or zip per the parsed options), and a module binding each
    ``name`` to its file's contents is written to the output path. A
    repeated ``name`` keeps its last value.

    Raises:
        SystemExit: if an argument lacks ``=`` or has an empty name, or if
            two inputs share a basename (the archive stores files by
            basename only, so one would shadow the other).
    """
    args = parse_args()
    files = {}
    for key_val in args.inputs:
        # partition() splits on the first "=" only, so filenames may contain
        # "=". Find()-based slicing silently mangled arguments without "="
        # because find() returns -1.
        key, sep, val = key_val.partition("=")
        if not sep or not key:
            raise SystemExit(f"Expected name=filename, got {key_val!r}")
        files[key] = val
    seen = set()
    for val in files.values():
        base = os.path.basename(val)
        if base in seen:
            raise SystemExit(
                f"Duplicate file basename {base!r}: archive members are "
                "stored by basename"
            )
        seen.add(base)
    base64_blob = compress_and_base_64(files.values(), args.tarxz)
    write_py_wrapper(base64_blob, files, args.output, args.tarxz)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()