2929 Type ,
3030 TypeVar ,
3131 TYPE_CHECKING ,
32+ Union ,
3233)
3334
3435import numpy as np
@@ -93,21 +94,27 @@ def __init__(
9394 * ,
9495 dtype : Type [np .complexfloating ] = np .complex64 ,
9596 noise : 'cirq.NOISE_MODEL_LIKE' = None ,
96- seed : 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None ,
97+ seed : Optional [ Union [ int , np . random . Generator , np . random . RandomState ]] = None ,
9798 split_untangled_states : bool = False ,
9899 ):
99100 """Initializes the simulator.
100101
101102 Args:
102103 dtype: The `numpy.dtype` used by the simulation.
103104 noise: A noise model to apply while simulating.
104- seed: The random seed to use for this simulator.
105+ seed: The random seed or generator to use for this simulator.
105106 split_untangled_states: If True, optimizes simulation by running
106107 unentangled qubit sets independently and merging those states
107108 at the end.
108109 """
109110 self ._dtype = dtype
110- self ._prng = value .parse_random_state (seed )
111+ if isinstance (seed , np .random .RandomState ):
112+ # Convert RandomState to Generator for backward compatibility
113+ self ._prng = np .random .default_rng (seed .get_state ()[1 ][0 ])
114+ elif isinstance (seed , np .random .Generator ):
115+ self ._prng = seed
116+ else :
117+ self ._prng = np .random .default_rng (seed )
111118 self ._noise = devices .NoiseModel .from_noise_model_like (noise )
112119 self ._split_untangled_states = split_untangled_states
113120
@@ -228,6 +235,7 @@ def _run(
228235 circuit : 'cirq.AbstractCircuit' ,
229236 param_resolver : 'cirq.ParamResolver' ,
230237 repetitions : int ,
238+ rng : Optional [np .random .Generator ] = None ,
231239 ) -> Dict [str , np .ndarray ]:
232240 """See definition in `cirq.SimulatesSamples`."""
233241 param_resolver = param_resolver or study .ParamResolver ({})
@@ -254,7 +262,10 @@ def _run(
254262 assert step_result is not None
255263 measurement_ops = [cast (ops .GateOperation , op ) for op in general_ops ]
256264 return step_result .sample_measurement_ops (
257- measurement_ops , repetitions , seed = self ._prng , _allow_repeated = True
265+ measurement_ops ,
266+ repetitions ,
267+ seed = rng if rng is not None else self ._prng ,
268+ _allow_repeated = True ,
258269 )
259270
260271 records : Dict ['cirq.MeasurementKey' , List [Sequence [Sequence [int ]]]] = {}
@@ -395,9 +406,15 @@ def sample(
395406 self ,
396407 qubits : List ['cirq.Qid' ],
397408 repetitions : int = 1 ,
398- seed : 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None ,
409+ seed : Optional [ Union [ int , np . random . Generator , np . random . RandomState ]] = None ,
399410 ) -> np .ndarray :
400- return self ._sim_state .sample (qubits , repetitions , seed )
411+ if isinstance (seed , np .random .RandomState ):
412+ rng = np .random .default_rng (seed .get_state ()[1 ][0 ])
413+ elif isinstance (seed , np .random .Generator ):
414+ rng = seed
415+ else :
416+ rng = np .random .default_rng (seed )
417+ return self ._sim_state .sample (qubits , repetitions , rng )
401418
402419
403420class SimulationTrialResultBase (
0 commit comments