File tree Expand file tree Collapse file tree 3 files changed +17
-4
lines changed
Expand file tree Collapse file tree 3 files changed +17
-4
lines changed Original file line number Diff line number Diff line change @@ -497,7 +497,7 @@ def get_visible_devices():
497497 )
498498 for v in device_vars :
499499 try :
500- return tuple (int (i ) for i in os .environ [v ].split (',' ))
500+ return v , tuple (int (i ) for i in os .environ [v ].split (',' ))
501501 except ValueError :
502502 # Visible devices set via UUIDs or other non-integer identifiers.
503503 warning ("Setting visible devices via UUIDs or other non-integer"
@@ -507,7 +507,7 @@ def get_visible_devices():
507507 # Environment variable not set
508508 continue
509509
510- return None
510+ return None , None
511511
512512
513513@memoized_func
Original file line number Diff line number Diff line change @@ -1398,11 +1398,18 @@ def _physical_deviceid(self):
13981398 rank = self .comm .Get_rank () if self .comm != MPI .COMM_NULL else 0
13991399 logical_deviceid = rank
14001400
1401- visible_devices = get_visible_devices ()
1401+ visible_device_var , visible_devices = get_visible_devices ()
14021402 if visible_devices is None :
14031403 return logical_deviceid
14041404 else :
1405- return visible_devices [logical_deviceid ]
1405+ try :
1406+ return visible_devices [logical_deviceid ]
1407+ except IndexError :
1408+ errmsg = (f"A deviceid value of { logical_deviceid } is not valid "
1409+ f"with { visible_device_var } ={ visible_devices } . Note that "
1410+ "deviceid corresponds to the logical index within the "
1411+ "visible devices, not the physical device index." )
1412+ raise ValueError (errmsg )
14061413 else :
14071414 return None
14081415
Original file line number Diff line number Diff line change @@ -149,6 +149,12 @@ def test_visible_devices_with_devito_deviceid(self):
149149 # So should be the second of the two visible devices specified (3)
150150 assert argmap ._physical_deviceid == 3
151151
152+ with switchenv ({'CUDA_VISIBLE_DEVICES' : "1" }), switchconfig (deviceid = 0 ):
153+ op1 = Operator (eq )
154+
155+ argmap1 = op1 .arguments ()
156+ assert argmap1 ._physical_deviceid == 1
157+
152158 @pytest .mark .parallel (mode = 2 )
153159 def test_deviceid_per_rank (self , mode ):
154160 """
You can’t perform that action at this time.
0 commit comments