18
18
from collections .abc import Mapping
19
19
from collections .abc import Sequence as abcSequence
20
20
from itertools import combinations
21
- from typing import TYPE_CHECKING , Optional
21
+ from typing import TYPE_CHECKING , Any , Optional , Tuple , cast
22
22
23
23
import matplotlib .pyplot as plt
24
24
import numpy as np
@@ -48,15 +48,15 @@ def default_qubit_color() -> str:
48
48
49
49
@staticmethod
50
50
def _draw_2D (
51
- ax : plt .axes . _subplots . AxesSubplot ,
51
+ ax : plt .Axes ,
52
52
pos : np .ndarray ,
53
53
ids : abcSequence [QubitId ],
54
54
plane : tuple = (0 , 1 ),
55
55
with_labels : bool = True ,
56
56
blockade_radius : Optional [float ] = None ,
57
57
draw_graph : bool = True ,
58
58
draw_half_radius : bool = False ,
59
- qubit_colors : Mapping [QubitId , str ] = dict (),
59
+ qubit_colors : Mapping [QubitId , Any ] = dict (),
60
60
masked_qubits : set [QubitId ] = set (),
61
61
are_traps : bool = False ,
62
62
dmm_qubits : Mapping [QubitId , float ] = {},
@@ -68,6 +68,7 @@ def _draw_2D(
68
68
69
69
ix , iy = plane
70
70
71
+ params : dict [str , Any ]
71
72
if are_traps :
72
73
params = dict (
73
74
s = 50 ,
@@ -107,7 +108,7 @@ def _draw_2D(
107
108
dmm_arr [:, iy ],
108
109
marker = "s" ,
109
110
s = 1200 ,
110
- alpha = alpha ,
111
+ alpha = alpha , # type: ignore[arg-type]
111
112
c = "black" if not qubit_colors else ordered_qubit_colors ,
112
113
)
113
114
axes = "xyz"
@@ -194,12 +195,12 @@ def _draw_2D(
194
195
fontsize = 12 if i not in final_plot_det_map else 8.3 ,
195
196
multialignment = "right" ,
196
197
)
197
- txt ._get_wrap_line_width = lambda : 50.0
198
+ txt ._get_wrap_line_width = lambda : 50.0 # type: ignore
198
199
199
200
if draw_half_radius and blockade_radius is not None :
200
201
for p , color in zip (pos , ordered_qubit_colors ):
201
202
circle = plt .Circle (
202
- tuple (p [[ ix , iy ] ]),
203
+ (p [ix ], p [ iy ]),
203
204
blockade_radius / 2 ,
204
205
alpha = 0.1 ,
205
206
color = color ,
@@ -214,7 +215,11 @@ def _draw_2D(
214
215
lines = bonds [:, :, (ix , iy )]
215
216
else :
216
217
lines = np .array ([])
217
- lc = mc .LineCollection (lines , linewidths = 0.6 , colors = "grey" )
218
+ lc = mc .LineCollection (
219
+ cast (abcSequence [np .ndarray ], lines ),
220
+ linewidths = 0.6 ,
221
+ colors = "grey" ,
222
+ )
218
223
ax .add_collection (lc )
219
224
220
225
else :
@@ -258,9 +263,14 @@ def _draw_3D(
258
263
blockade_radius = blockade_radius ,
259
264
draw_half_radius = draw_half_radius ,
260
265
)
261
- fig .get_layout_engine ().set (w_pad = 6.5 )
262
-
263
- for ax , (ix , iy ) in zip (axes , combinations (np .arange (3 ), 2 )):
266
+ _layout_engine = fig .get_layout_engine ()
267
+ assert _layout_engine is not None
268
+ _layout_engine .set (w_pad = 6.5 ) # type: ignore[call-arg]
269
+
270
+ for ax , (ix , iy ) in zip (
271
+ cast (abcSequence [plt .Axes ], axes ),
272
+ combinations (np .arange (3 ), 2 ),
273
+ ):
264
274
RegDrawer ._draw_2D (
265
275
ax ,
266
276
pos ,
@@ -284,7 +294,12 @@ def _draw_3D(
284
294
)
285
295
286
296
else :
287
- fig = plt .figure (figsize = 2 * plt .figaspect (0.5 ))
297
+ fig = plt .figure (
298
+ figsize = cast (
299
+ Tuple [float , float ],
300
+ tuple (2 * np .array (plt .figaspect (0.5 ))),
301
+ )
302
+ )
288
303
289
304
if draw_graph and blockade_radius is not None :
290
305
bonds = {}
@@ -293,6 +308,7 @@ def _draw_3D(
293
308
xj , yj , zj = pos [j ]
294
309
bonds [(i , j )] = [[xi , xj ], [yi , yj ], [zi , zj ]]
295
310
311
+ params : dict [str , Any ]
296
312
if are_traps :
297
313
params = dict (s = 50 , c = "white" , edgecolors = "black" )
298
314
else :
@@ -313,7 +329,7 @@ def _draw_3D(
313
329
coords [0 ],
314
330
coords [1 ],
315
331
coords [2 ],
316
- q ,
332
+ q , # type: ignore[arg-type]
317
333
fontsize = 12 ,
318
334
ha = "left" ,
319
335
va = "bottom" ,
@@ -336,15 +352,21 @@ def _draw_3D(
336
352
y = radius * np .sin (u ) * np .sin (v ) + y0
337
353
z = radius * np .cos (v ) + z0
338
354
# alpha controls opacity
339
- ax .plot_surface (x , y , z , color = color , alpha = 0.1 )
355
+ ax .plot_surface ( # type: ignore[attr-defined]
356
+ x ,
357
+ y ,
358
+ z ,
359
+ color = color ,
360
+ alpha = 0.1 ,
361
+ )
340
362
341
363
if draw_graph and blockade_radius is not None :
342
364
for x , y , z in bonds .values ():
343
365
ax .plot (x , y , z , linewidth = 1.5 , color = "grey" )
344
366
345
367
ax .set_xlabel ("x (µm)" )
346
368
ax .set_ylabel ("y (µm)" )
347
- ax .set_zlabel ("z (µm)" )
369
+ ax .set_zlabel ("z (µm)" ) # type: ignore[attr-defined]
348
370
349
371
@staticmethod
350
372
def _register_dims (
@@ -367,7 +389,7 @@ def _initialize_fig_axes(
367
389
blockade_radius : Optional [float ] = None ,
368
390
draw_half_radius : bool = False ,
369
391
nregisters : int = 1 ,
370
- ) -> tuple [plt .figure . Figure , plt .axes .Axes ]:
392
+ ) -> tuple [plt .Figure , plt .Axes | abcSequence [ plt .Axes ] ]:
371
393
"""Creates the Figure and Axes for drawing the register."""
372
394
diffs = RegDrawer ._register_dims (
373
395
pos ,
@@ -394,7 +416,9 @@ def _initialize_fig_axes_projection(
394
416
blockade_radius : Optional [float ] = None ,
395
417
draw_half_radius : bool = False ,
396
418
nregisters : int = 1 ,
397
- ) -> tuple [plt .figure .Figure , plt .axes .Axes ]:
419
+ ) -> tuple [
420
+ plt .Figure , abcSequence [plt .Axes ] | abcSequence [abcSequence [plt .Axes ]]
421
+ ]:
398
422
"""Creates the Figure and Axes for drawing the register projections."""
399
423
diffs = RegDrawer ._register_dims (
400
424
pos ,
0 commit comments