10
10
import copy
11
11
import enum
12
12
import unittest
13
+ from typing import Tuple , Union
13
14
from unittest .mock import MagicMock
14
15
15
16
import torch
16
17
17
18
from torchrec .distributed .embedding_types import EmbeddingComputeKernel
18
- from torchrec .distributed .test_utils .test_model import ModelInput , TestNegSamplingModule
19
+ from torchrec .distributed .test_utils .test_model import (
20
+ ModelInput ,
21
+ TestNegSamplingModule ,
22
+ TestSparseNN ,
23
+ )
19
24
from torchrec .distributed .train_pipeline .pipeline_context import TrainPipelineContext
20
25
from torchrec .distributed .train_pipeline .runtime_forwards import PipelinedForward
21
-
22
26
from torchrec .distributed .train_pipeline .tests .test_train_pipelines_base import (
23
27
TrainPipelineSparseDistTestBase ,
24
28
)
25
- from torchrec .distributed .train_pipeline .tracing import (
26
- ArgInfo ,
27
- ArgInfoStepFactory ,
28
- CallArgs ,
29
- NodeArgsHelper ,
30
- PipelinedPostproc ,
31
- )
29
+ from torchrec .distributed .train_pipeline .tracing import CallArgs , PipelinedPostproc
32
30
from torchrec .distributed .train_pipeline .utils import _rewrite_model
33
31
from torchrec .distributed .types import ShardingType
34
32
from torchrec .sparse .jagged_tensor import KeyedJaggedTensor
@@ -40,6 +38,15 @@ class ModelType(enum.Enum):
40
38
PIPELINED = "pipelined"
41
39
42
40
41
@torch.fx.wrap
def enrich_hstu_features(
    kjt: KeyedJaggedTensor, hstu_factor: float
) -> KeyedJaggedTensor:
    """Scale the weights of ``kjt`` in place and hand the same object back.

    The function is registered through ``torch.fx.wrap`` so it can be called
    from a module ``forward`` that is symbolically traced.

    Args:
        kjt: the jagged tensor whose ``_weights`` are scaled; it is mutated.
        hstu_factor: multiplier applied to the weights.

    Returns:
        The very same ``kjt`` instance that was passed in (not a copy).
    """
    # Nothing to scale when the KJT carries no weights.
    if kjt._weights is None:
        return kjt
    # Deliberately in place: callers observe the change on their own KJT.
    kjt._weights *= hstu_factor
    return kjt
48
+
49
+
43
50
class TrainPipelineUtilsTest (TrainPipelineSparseDistTestBase ):
44
51
# pyre-fixme[56]: Pyre was not able to infer the type of argument
45
52
@unittest .skipIf (
@@ -257,3 +264,115 @@ def test_restore_from_snapshot(self) -> None:
257
264
]
258
265
for source_model_type , recipient_model_type in variants :
259
266
self ._test_restore_from_snapshot (source_model_type , recipient_model_type )
267
+
268
+ # pyre-fixme[56]: Pyre was not able to infer the type of argument
269
+ @unittest .skipIf (
270
+ not torch .cuda .is_available (),
271
+ "Not enough GPUs, this test requires at least one GPU" ,
272
+ )
273
+ def test_rewrite_model_with_fx_wrap (self ) -> None :
274
+ sharding_type = ShardingType .TABLE_WISE .value
275
+ kernel_type = EmbeddingComputeKernel .FUSED .value
276
+ fused_params = {}
277
+
278
+ class TestPostProcModule (torch .nn .Module ):
279
+ def __init__ (self , f : float ):
280
+ super ().__init__ ()
281
+ self .f = f
282
+
283
+ def forward (self , x : KeyedJaggedTensor ) -> KeyedJaggedTensor :
284
+ return enrich_hstu_features (x , self .f )
285
+
286
+ postproc_module = TestPostProcModule (0.3 )
287
+
288
+ class TestModel (TestSparseNN ):
289
+ use_postproc_module : bool = False
290
+
291
+ def forward (
292
+ self ,
293
+ input : ModelInput ,
294
+ ) -> Union [torch .Tensor , Tuple [torch .Tensor , torch .Tensor ]]:
295
+ if (type (self )).use_postproc_module :
296
+ input = self .postproc_module (input )
297
+ else :
298
+ input = enrich_hstu_features (input , 0.3 )
299
+ return self .dense_forward (input , self .sparse_forward (input ))
300
+
301
+ model = TestModel (
302
+ tables = self .tables ,
303
+ weighted_tables = self .weighted_tables ,
304
+ dense_device = self .device ,
305
+ sparse_device = torch .device ("meta" ),
306
+ postproc_module = postproc_module ,
307
+ )
308
+
309
+ sharded_model , optim = self ._generate_sharded_model_and_optimizer (
310
+ model , sharding_type , kernel_type , fused_params
311
+ )
312
+
313
+ # Try to rewrite model using a function for postproc
314
+ # EBC forwards not overwritten to PipelinedForward due to KJT modification
315
+ self .assertFalse (model .use_postproc_module )
316
+ _rewrite_model (
317
+ model = sharded_model ,
318
+ batch = None ,
319
+ context = TrainPipelineContext (),
320
+ dist_stream = None ,
321
+ pipeline_postproc = True ,
322
+ )
323
+ self .assertNotIsInstance (
324
+ # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
325
+ # `sparse`.
326
+ sharded_model .module .sparse .ebc .forward ,
327
+ PipelinedForward ,
328
+ )
329
+ self .assertNotIsInstance (
330
+ # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
331
+ # `sparse`.
332
+ sharded_model .module .sparse .weighted_ebc .forward ,
333
+ PipelinedForward ,
334
+ )
335
+
336
+ # Now use postproc module
337
+ TestModel .use_postproc_module = True
338
+ self .assertTrue (model .use_postproc_module )
339
+ _rewrite_model (
340
+ model = sharded_model ,
341
+ batch = None ,
342
+ context = TrainPipelineContext (),
343
+ dist_stream = None ,
344
+ pipeline_postproc = True ,
345
+ )
346
+
347
+ # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `sparse`.
348
+ self .assertIsInstance (sharded_model .module .sparse .ebc .forward , PipelinedForward )
349
+ self .assertIsInstance (
350
+ # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
351
+ # `sparse`.
352
+ sharded_model .module .sparse .weighted_ebc .forward ,
353
+ PipelinedForward ,
354
+ )
355
+ self .assertEqual (
356
+ # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
357
+ # `sparse`.
358
+ sharded_model .module .sparse .ebc .forward ._args .args [0 ]
359
+ .steps [0 ]
360
+ .postproc_module ,
361
+ # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
362
+ # `postproc_module`.
363
+ sharded_model .module .postproc_module ,
364
+ )
365
+ self .assertEqual (
366
+ # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
367
+ # `sparse`.
368
+ sharded_model .module .sparse .weighted_ebc .forward ._args .args [0 ]
369
+ .steps [0 ]
370
+ .postproc_module ,
371
+ # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
372
+ # `postproc_module`.
373
+ sharded_model .module .postproc_module ,
374
+ )
375
+ state_dict = sharded_model .state_dict ()
376
+ missing_keys , unexpected_keys = sharded_model .load_state_dict (state_dict )
377
+ self .assertEqual (missing_keys , [])
378
+ self .assertEqual (unexpected_keys , [])
0 commit comments