@@ -177,6 +177,10 @@ def compute(
         ctx: ITEPEmbeddingBagCollectionContext,
         dist_input: KJTList,
     ) -> List[torch.Tensor]:
+        # We need to explicitly move iter to CPU since it might be moved to GPU
+        # after __init__. This should be done once.
+        self._iter = self._iter.cpu()
+
         if not ctx.is_reindexed:
             dist_input = self._reindex(dist_input)
             ctx.is_reindexed = True
@@ -196,6 +200,10 @@ def output_dist(
     def compute_and_output_dist(
         self, ctx: ITEPEmbeddingBagCollectionContext, input: KJTList
     ) -> LazyAwaitable[KeyedTensor]:
+        # We need to explicitly move iter to CPU since it might be moved to GPU
+        # after __init__. This should be done once.
+        self._iter = self._iter.cpu()
+
         # Insert forward() function of GenericITEPModule into compute_and_output_dist()
         for i, (sharding, features) in enumerate(
             zip(
@@ -424,6 +432,10 @@ def compute(
         ctx: ITEPEmbeddingCollectionContext,
         dist_input: KJTList,
     ) -> List[torch.Tensor]:
+        # We need to explicitly move iter to CPU since it might be moved to GPU
+        # after __init__. This should be done once.
+        self._iter = self._iter.cpu()
+
         for i, (sharding, features) in enumerate(
             zip(
                 self._embedding_collection._sharding_type_to_sharding.keys(),
@@ -450,6 +462,10 @@ def output_dist(
     def compute_and_output_dist(
         self, ctx: ITEPEmbeddingCollectionContext, input: KJTList
     ) -> LazyAwaitable[Dict[str, JaggedTensor]]:
+        # We need to explicitly move iter to CPU since it might be moved to GPU
+        # after __init__. This should be done once.
+        self._iter = self._iter.cpu()
+
         # Insert forward() function of GenericITEPModule into compute_and_output_dist()
         """ """
         for i, (sharding, features) in enumerate(