@@ -688,241 +688,3 @@ def test_dynamically_switch_inference_training_mode(self) -> None:
688
688
self .assertTrue (m ._is_inference )
689
689
self .assertTrue (m ._eviction_policy_name is None )
690
690
self .assertTrue (m ._eviction_module is None )
691
-
692
- # Pyre-ignore [56]: Pyre was not able to infer the type of argument `torch.cuda.device_count() < 1` to decorator factory `unittest.skipIf`
693
- @unittest .skipIf (
694
- torch .cuda .device_count () < 1 ,
695
- "Not enough GPUs, this test requires at least two GPUs" ,
696
- )
697
- def test_zch_hash_disable_fallback (self ) -> None :
698
- m = HashZchManagedCollisionModule (
699
- zch_size = 30 ,
700
- device = torch .device ("cuda" ),
701
- total_num_buckets = 2 ,
702
- eviction_policy_name = HashZchEvictionPolicyName .SINGLE_TTL_EVICTION ,
703
- eviction_config = HashZchEvictionConfig (
704
- features = [],
705
- single_ttl = 10 ,
706
- ),
707
- max_probe = 4 ,
708
- disable_fallback = True ,
709
- start_bucket = 1 ,
710
- output_segments = [0 , 10 , 20 ],
711
- )
712
- jt = JaggedTensor (
713
- values = torch .arange (0 , 4 , dtype = torch .int64 , device = "cuda" ),
714
- lengths = torch .tensor ([1 , 1 , 1 , 1 ], dtype = torch .int64 , device = "cuda" ),
715
- )
716
- # Run once to insert ids
717
- output0 = m .remap ({"test" : jt })
718
- self .assertTrue (
719
- torch .equal (
720
- output0 ["test" ].values (),
721
- torch .tensor ([8 , 15 , 11 ], dtype = torch .int64 , device = "cuda:0" ),
722
- )
723
- )
724
- self .assertTrue (
725
- torch .equal (
726
- output0 ["test" ].lengths (),
727
- torch .tensor ([1 , 1 , 0 , 1 ], dtype = torch .int64 , device = "cuda:0" ),
728
- )
729
- )
730
- m .reset_inference_mode ()
731
- jt = JaggedTensor (
732
- values = torch .tensor ([9 , 0 , 1 , 4 , 6 , 8 ], dtype = torch .int64 , device = "cuda" ),
733
- lengths = torch .tensor ([1 , 1 , 1 , 1 , 1 , 1 ], dtype = torch .int64 , device = "cuda" ),
734
- )
735
- # Run again in inference mode and only values 0 and 1 exist.
736
- output1 = m .remap ({"test" : jt })
737
- self .assertTrue (
738
- torch .equal (
739
- output1 ["test" ].values (),
740
- torch .tensor ([8 , 15 ], dtype = torch .int64 , device = "cuda:0" ),
741
- )
742
- )
743
- self .assertTrue (
744
- torch .equal (
745
- output1 ["test" ].lengths (),
746
- torch .tensor ([0 , 1 , 1 , 0 , 0 , 0 ], dtype = torch .int64 , device = "cuda:0" ),
747
- )
748
- )
749
-
750
- m = HashZchManagedCollisionModule (
751
- zch_size = 10 ,
752
- device = torch .device ("cuda" ),
753
- total_num_buckets = 2 ,
754
- eviction_policy_name = HashZchEvictionPolicyName .SINGLE_TTL_EVICTION ,
755
- eviction_config = HashZchEvictionConfig (
756
- features = [],
757
- single_ttl = 10 ,
758
- ),
759
- max_probe = 4 ,
760
- start_bucket = 0 ,
761
- output_segments = None ,
762
- disable_fallback = True ,
763
- )
764
- jt = JaggedTensor (
765
- values = torch .arange (0 , 4 , dtype = torch .int64 , device = "cuda" ),
766
- lengths = torch .tensor ([1 , 1 , 1 , 1 ], dtype = torch .int64 , device = "cuda" ),
767
- )
768
- # Run once to insert ids
769
- output0 = m .remap ({"test" : jt })
770
- self .assertTrue (
771
- torch .equal (
772
- output0 ["test" ].values (),
773
- torch .tensor ([3 , 5 , 4 , 6 ], dtype = torch .int64 , device = "cuda:0" ),
774
- )
775
- )
776
- self .assertTrue (
777
- torch .equal (
778
- output0 ["test" ].lengths (),
779
- torch .tensor ([1 , 1 , 1 , 1 ], dtype = torch .int64 , device = "cuda:0" ),
780
- )
781
- )
782
- m .reset_inference_mode ()
783
- jt = JaggedTensor (
784
- values = torch .tensor ([9 , 0 , 1 , 4 , 6 , 8 ], dtype = torch .int64 , device = "cuda" ),
785
- lengths = torch .tensor ([1 , 1 , 1 , 1 , 1 , 1 ], dtype = torch .int64 , device = "cuda" ),
786
- )
787
- # Run again in inference mode and only values 0 and 1 exist.
788
- output1 = m .remap ({"test" : jt })
789
- self .assertTrue (
790
- torch .equal (
791
- output1 ["test" ].values (),
792
- torch .tensor ([3 , 5 ], dtype = torch .int64 , device = "cuda:0" ),
793
- )
794
- )
795
- self .assertTrue (
796
- torch .equal (
797
- output1 ["test" ].lengths (),
798
- torch .tensor ([0 , 1 , 1 , 0 , 0 , 0 ], dtype = torch .int64 , device = "cuda:0" ),
799
- )
800
- )
801
-
802
- # Pyre-ignore [56]: Pyre was not able to infer the type of argument `torch.cuda.device_count() < 1` to decorator factory `unittest.skipIf`
803
- @unittest .skipIf (
804
- torch .cuda .device_count () < 1 ,
805
- "Not enough GPUs, this test requires at least two GPUs" ,
806
- )
807
- def test_zch_hash_zero_rows (self ) -> None :
808
- # When disabling fallback, for missed ids we should return zero rows in output embeddings.
809
- mc_emb_configs = [
810
- EmbeddingBagConfig (
811
- num_embeddings = 10 ,
812
- embedding_dim = 3 ,
813
- name = "table_0" ,
814
- data_type = DataType .FP32 ,
815
- feature_names = ["table_0" ],
816
- pooling = PoolingType .SUM ,
817
- weight_init_max = None ,
818
- weight_init_min = None ,
819
- init_fn = None ,
820
- use_virtual_table = False ,
821
- virtual_table_eviction_policy = None ,
822
- total_num_buckets = 1 ,
823
- )
824
- ]
825
- mc_modules : Dict [str , ManagedCollisionModule ] = {
826
- "table_0" : HashZchManagedCollisionModule (
827
- zch_size = 10 ,
828
- device = torch .device ("cuda" ),
829
- max_probe = 512 ,
830
- tb_logging_frequency = 100 ,
831
- name = "table_0" ,
832
- total_num_buckets = 1 ,
833
- eviction_config = None ,
834
- eviction_policy_name = None ,
835
- opt_in_prob = - 1 ,
836
- percent_reserved_slots = 0 ,
837
- disable_fallback = True ,
838
- )
839
- }
840
- mcebc = ManagedCollisionEmbeddingBagCollection (
841
- EmbeddingBagCollection (
842
- device = torch .device ("cuda" ),
843
- tables = mc_emb_configs ,
844
- is_weighted = False ,
845
- ),
846
- ManagedCollisionCollection (
847
- managed_collision_modules = mc_modules ,
848
- embedding_configs = mc_emb_configs ,
849
- ),
850
- return_remapped_features = True ,
851
- )
852
- lengths = torch .tensor (
853
- [1 , 1 , 1 , 1 , 1 ], dtype = torch .int64 , device = torch .device ("cuda" )
854
- )
855
- values = torch .tensor (
856
- [3 , 4 , 5 , 6 , 8 ],
857
- dtype = torch .int64 ,
858
- device = torch .device ("cuda" ),
859
- )
860
- features = KeyedJaggedTensor (
861
- keys = ["table_0" ],
862
- values = values ,
863
- lengths = lengths ,
864
- )
865
- # Run once to insert ids
866
- res = mcebc .forward (features )
867
- # Pyre-ignore [6]: In call `torch._C._VariableFunctions.abs`, for 1st positional argument, expected `Tensor` but got `Union[JaggedTensor, Tensor]`
868
- mask = torch .abs (res [0 ]["table_0" ]) == 0
869
- # For each row, check if all elements are True (i.e., close to zero)
870
- row_mask = mask .all (dim = 1 )
871
- # Get indices of zero rows
872
- self .assertEqual (torch .nonzero (row_mask , as_tuple = False ).squeeze ().numel (), 0 )
873
- self .assertIsNotNone (res [1 ])
874
- self .assertTrue (
875
- torch .equal (
876
- # Pyre-ignore [16]: Optional type has no attribute `__getitem__`.
877
- res [1 ]["table_0" ].values (),
878
- torch .tensor ([1 , 2 , 8 , 9 , 3 ], dtype = torch .int64 , device = "cuda:0" ),
879
- )
880
- )
881
- self .assertTrue (
882
- torch .equal (
883
- res [1 ]["table_0" ].lengths (),
884
- torch .tensor ([1 , 1 , 1 , 1 , 1 ], dtype = torch .int64 , device = "cuda:0" ),
885
- )
886
- )
887
- # Pyre-ignore [29]: `typing.Union[torch._tensor.Tensor, torch.nn.modules.module.Module]` is not a function
888
- mcebc ._managed_collision_collection ._managed_collision_modules [
889
- "table_0"
890
- ].reset_inference_mode ()
891
- lengths = torch .tensor (
892
- [1 , 1 , 1 , 1 , 1 , 1 ], dtype = torch .int64 , device = torch .device ("cuda" )
893
- )
894
- values = torch .tensor (
895
- [0 , 4 , 5 , 1 , 2 , 8 ],
896
- dtype = torch .int64 ,
897
- device = torch .device ("cuda" ),
898
- )
899
- features = KeyedJaggedTensor (
900
- keys = ["table_0" ],
901
- values = values ,
902
- lengths = lengths ,
903
- )
904
- # Run once to insert ids.
905
- res = mcebc .forward (features )
906
- self .assertTrue (
907
- torch .equal (
908
- res [1 ]["table_0" ].values (),
909
- torch .tensor ([2 , 8 , 3 ], dtype = torch .int64 , device = "cuda:0" ),
910
- )
911
- )
912
- self .assertTrue (
913
- torch .equal (
914
- res [1 ]["table_0" ].lengths (),
915
- torch .tensor ([0 , 1 , 1 , 0 , 0 , 1 ], dtype = torch .int64 , device = "cuda:0" ),
916
- )
917
- )
918
- # Pyre-ignore [6]: In call `torch._C._VariableFunctions.abs`, for 1st positional argument, expected `Tensor` but got `Union[JaggedTensor, Tensor]`
919
- mask = torch .abs (res [0 ]["table_0" ]) == 0
920
- # For each row, check if all elements are True (i.e., close to zero)
921
- row_mask = mask .all (dim = 1 )
922
- # Get indices of zero rows
923
- self .assertTrue (
924
- torch .equal (
925
- torch .tensor ([0 , 3 , 4 ], device = "cuda:0" ),
926
- torch .nonzero (row_mask , as_tuple = False ).squeeze (),
927
- )
928
- )
0 commit comments