1
1
import configparser
2
+ import hashlib
2
3
import re
3
4
import sys
4
5
from io import BytesIO
5
6
from pathlib import Path
6
7
from types import ModuleType
7
- from typing import Callable , List , Optional , Union
8
+ from typing import Any , Callable , List , Optional , Union
8
9
from zipfile import ZipFile
9
- import hashlib
10
10
11
11
from vcap import BaseCapsule , BaseBackend , BaseStreamState , NodeDescription
12
12
from vcap .loading .errors import IncompatibleCapsuleError , InvalidCapsuleError
@@ -200,19 +200,21 @@ def check_arg_names(func: Callable,
200
200
capsule .name )
201
201
202
202
# Validate the capsule class attributes
203
- capsule_assertions = [
204
- isinstance (capsule .name , str ),
205
- callable (capsule .backend_loader ),
206
- isinstance (capsule .version , int ),
207
- isinstance (capsule .input_type , NodeDescription ),
208
- isinstance (capsule .output_type , NodeDescription ),
209
- isinstance (capsule .options , dict )]
210
-
211
- if not all (capsule_assertions ):
212
- raise InvalidCapsuleError (
213
- f"The capsule has an invalid internal configuration!\n " +
214
- f"Capsule Assertions: { capsule_assertions } " ,
215
- capsule .name )
203
+ _validate_capsule_field (capsule , "name" , capsule .name , str )
204
+ _validate_capsule_field (capsule ,
205
+ "backend_loader" ,
206
+ capsule .backend_loader ,
207
+ callable )
208
+ _validate_capsule_field (capsule , "version" , capsule .version , int )
209
+ _validate_capsule_field (capsule ,
210
+ "input_type" ,
211
+ capsule .input_type ,
212
+ NodeDescription )
213
+ _validate_capsule_field (capsule ,
214
+ "output_type" ,
215
+ capsule .input_type ,
216
+ NodeDescription )
217
+ _validate_capsule_field (capsule , "options" , capsule .options , dict )
216
218
217
219
# Make sure that certain things are NOT attributes (we don't want
218
220
# accidental leftover code from previous capsule versions)
@@ -230,40 +232,42 @@ def check_arg_names(func: Callable,
230
232
pass
231
233
232
234
# Validate the capsule's backend_loader takes the right args
233
- loader = capsule .backend_loader
234
- loader_assertions = [
235
- callable (loader ),
236
- check_arg_names (func = loader , correct = ["capsule_files" , "device" ])]
237
- if not all (loader_assertions ):
235
+ correct_args = ["capsule_files" , "device" ]
236
+ is_correct = check_arg_names (func = capsule .backend_loader ,
237
+ correct = correct_args )
238
+ if not is_correct :
238
239
raise InvalidCapsuleError (
239
- f"The capsule's backend_loader has an invalid configuration !\n "
240
- f"Loader Assertions: { loader_assertions } " ,
240
+ f"The capsule's backend_loader has invalid arguments !\n "
241
+ f"The arguments must be { correct_args } " ,
241
242
capsule .name )
242
243
243
244
# Validate the backend class attributes
244
245
if capsule .backends is not None :
245
246
backend = capsule .backends [0 ]
246
- backend_assertions = [
247
- callable (backend .batch_predict ),
248
- callable (backend .process_frame ),
249
- callable (backend .close ),
250
- isinstance (capsule .backends [0 ], BaseBackend )]
251
- if not all (backend_assertions ):
247
+ _validate_backend_field (capsule ,
248
+ "batch_predict" ,
249
+ backend .batch_predict ,
250
+ callable )
251
+ _validate_backend_field (capsule ,
252
+ "process_frame" ,
253
+ backend .process_frame ,
254
+ callable )
255
+ _validate_backend_field (capsule ,
256
+ "close" ,
257
+ backend .close ,
258
+ callable )
259
+ if not isinstance (capsule .backends [0 ], BaseBackend ):
252
260
raise InvalidCapsuleError (
253
- f"The capsule's backend has an invalid configuration! \n "
254
- f"Backend Assertions: { backend_assertions } " ,
261
+ f"The capsule's backend field must be a class that "
262
+ f"subclasses { BaseBackend . __name__ } " ,
255
263
capsule .name )
256
264
257
265
# Validate the stream state
258
266
stream_state = capsule .stream_state
259
- stream_state_assertions = [
260
- (stream_state is BaseStreamState or
261
- BaseStreamState in stream_state .__bases__ )]
262
- if not all (stream_state_assertions ):
267
+ if not issubclass (stream_state , BaseStreamState ):
263
268
raise InvalidCapsuleError (
264
- "The capsule's stream_state has an invalid configuration!\n "
265
- f"Stream State Assertions: { stream_state_assertions } " ,
266
- capsule .name )
269
+ f"The capsule's stream_state field must be a subclass of "
270
+ f"{ BaseStreamState .__name__ } , got { stream_state } " )
267
271
268
272
# Validate that if the capsule is an encoder, it has a threshold option
269
273
if capsule .capability .encoded :
@@ -288,3 +292,40 @@ def capsule_module_name(data: bytes) -> str:
288
292
"""Creates a unique module name for the given capsule bytes"""
289
293
hash_ = hashlib .sha256 (data ).hexdigest ()
290
294
return f"capsule_{ hash_ } "
295
+
296
+
297
+ _TYPE_CALLABLE = Union [type , callable ]
298
+
299
+
300
+ def _validate_capsule_field (capsule : BaseCapsule ,
301
+ name : str ,
302
+ value : Any ,
303
+ type_ : _TYPE_CALLABLE ) -> None :
304
+ if type_ is callable :
305
+ if not callable (value ):
306
+ raise InvalidCapsuleError (
307
+ f"The capsule has an invalid internal configuration!\n "
308
+ f"Capsule field { name } must be callable, got { type (value )} " ,
309
+ capsule .name )
310
+ elif not isinstance (value , type_ ):
311
+ raise InvalidCapsuleError (
312
+ f"The capsule has an invalid internal configuration!\n "
313
+ f"Capsule field { name } must be of type { type_ } , got { type (value )} " ,
314
+ capsule .name )
315
+
316
+
317
+ def _validate_backend_field (capsule : BaseCapsule ,
318
+ name : str ,
319
+ value : Any ,
320
+ type_ : _TYPE_CALLABLE ) -> None :
321
+ if type_ is callable :
322
+ if not callable (value ):
323
+ raise InvalidCapsuleError (
324
+ f"The capsule's backend has an invalid configuration!\n "
325
+ f"Backend field { name } must be callable, got { type (value )} " ,
326
+ capsule .name )
327
+ elif not isinstance (value , type_ ):
328
+ raise InvalidCapsuleError (
329
+ f"The capsule's backend has an invalid configuration!\n "
330
+ f"Backend field { name } must be of type { type_ } , got { type (value )} " ,
331
+ capsule .name )
0 commit comments