Skip to content

Commit 1c626c1

Browse files
Yixin Bao and facebook-github-bot
authored and committed
Remove flaky test in oss test_zch_hash_disable_fallback
Summary: as titled. Differential Revision: D83786896
1 parent d958b4d commit 1c626c1

File tree

1 file changed

+0
-238
lines changed

1 file changed

+0
-238
lines changed

torchrec/modules/tests/test_hash_mc_modules.py

Lines changed: 0 additions & 238 deletions
Original file line numberDiff line numberDiff line change
@@ -688,241 +688,3 @@ def test_dynamically_switch_inference_training_mode(self) -> None:
688688
self.assertTrue(m._is_inference)
689689
self.assertTrue(m._eviction_policy_name is None)
690690
self.assertTrue(m._eviction_module is None)
691-
692-
# Pyre-ignore [56]: Pyre was not able to infer the type of argument `torch.cuda.device_count() < 1` to decorator factory `unittest.skipIf`
693-
@unittest.skipIf(
694-
torch.cuda.device_count() < 1,
695-
"Not enough GPUs, this test requires at least two GPUs",
696-
)
697-
def test_zch_hash_disable_fallback(self) -> None:
698-
m = HashZchManagedCollisionModule(
699-
zch_size=30,
700-
device=torch.device("cuda"),
701-
total_num_buckets=2,
702-
eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
703-
eviction_config=HashZchEvictionConfig(
704-
features=[],
705-
single_ttl=10,
706-
),
707-
max_probe=4,
708-
disable_fallback=True,
709-
start_bucket=1,
710-
output_segments=[0, 10, 20],
711-
)
712-
jt = JaggedTensor(
713-
values=torch.arange(0, 4, dtype=torch.int64, device="cuda"),
714-
lengths=torch.tensor([1, 1, 1, 1], dtype=torch.int64, device="cuda"),
715-
)
716-
# Run once to insert ids
717-
output0 = m.remap({"test": jt})
718-
self.assertTrue(
719-
torch.equal(
720-
output0["test"].values(),
721-
torch.tensor([8, 15, 11], dtype=torch.int64, device="cuda:0"),
722-
)
723-
)
724-
self.assertTrue(
725-
torch.equal(
726-
output0["test"].lengths(),
727-
torch.tensor([1, 1, 0, 1], dtype=torch.int64, device="cuda:0"),
728-
)
729-
)
730-
m.reset_inference_mode()
731-
jt = JaggedTensor(
732-
values=torch.tensor([9, 0, 1, 4, 6, 8], dtype=torch.int64, device="cuda"),
733-
lengths=torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.int64, device="cuda"),
734-
)
735-
# Run again in inference mode and only values 0 and 1 exist.
736-
output1 = m.remap({"test": jt})
737-
self.assertTrue(
738-
torch.equal(
739-
output1["test"].values(),
740-
torch.tensor([8, 15], dtype=torch.int64, device="cuda:0"),
741-
)
742-
)
743-
self.assertTrue(
744-
torch.equal(
745-
output1["test"].lengths(),
746-
torch.tensor([0, 1, 1, 0, 0, 0], dtype=torch.int64, device="cuda:0"),
747-
)
748-
)
749-
750-
m = HashZchManagedCollisionModule(
751-
zch_size=10,
752-
device=torch.device("cuda"),
753-
total_num_buckets=2,
754-
eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
755-
eviction_config=HashZchEvictionConfig(
756-
features=[],
757-
single_ttl=10,
758-
),
759-
max_probe=4,
760-
start_bucket=0,
761-
output_segments=None,
762-
disable_fallback=True,
763-
)
764-
jt = JaggedTensor(
765-
values=torch.arange(0, 4, dtype=torch.int64, device="cuda"),
766-
lengths=torch.tensor([1, 1, 1, 1], dtype=torch.int64, device="cuda"),
767-
)
768-
# Run once to insert ids
769-
output0 = m.remap({"test": jt})
770-
self.assertTrue(
771-
torch.equal(
772-
output0["test"].values(),
773-
torch.tensor([3, 5, 4, 6], dtype=torch.int64, device="cuda:0"),
774-
)
775-
)
776-
self.assertTrue(
777-
torch.equal(
778-
output0["test"].lengths(),
779-
torch.tensor([1, 1, 1, 1], dtype=torch.int64, device="cuda:0"),
780-
)
781-
)
782-
m.reset_inference_mode()
783-
jt = JaggedTensor(
784-
values=torch.tensor([9, 0, 1, 4, 6, 8], dtype=torch.int64, device="cuda"),
785-
lengths=torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.int64, device="cuda"),
786-
)
787-
# Run again in inference mode and only values 0 and 1 exist.
788-
output1 = m.remap({"test": jt})
789-
self.assertTrue(
790-
torch.equal(
791-
output1["test"].values(),
792-
torch.tensor([3, 5], dtype=torch.int64, device="cuda:0"),
793-
)
794-
)
795-
self.assertTrue(
796-
torch.equal(
797-
output1["test"].lengths(),
798-
torch.tensor([0, 1, 1, 0, 0, 0], dtype=torch.int64, device="cuda:0"),
799-
)
800-
)
801-
802-
# Pyre-ignore [56]: Pyre was not able to infer the type of argument `torch.cuda.device_count() < 1` to decorator factory `unittest.skipIf`
803-
@unittest.skipIf(
804-
torch.cuda.device_count() < 1,
805-
"Not enough GPUs, this test requires at least two GPUs",
806-
)
807-
def test_zch_hash_zero_rows(self) -> None:
808-
# When disabling fallback, for missed ids we should return zero rows in output embeddings.
809-
mc_emb_configs = [
810-
EmbeddingBagConfig(
811-
num_embeddings=10,
812-
embedding_dim=3,
813-
name="table_0",
814-
data_type=DataType.FP32,
815-
feature_names=["table_0"],
816-
pooling=PoolingType.SUM,
817-
weight_init_max=None,
818-
weight_init_min=None,
819-
init_fn=None,
820-
use_virtual_table=False,
821-
virtual_table_eviction_policy=None,
822-
total_num_buckets=1,
823-
)
824-
]
825-
mc_modules: Dict[str, ManagedCollisionModule] = {
826-
"table_0": HashZchManagedCollisionModule(
827-
zch_size=10,
828-
device=torch.device("cuda"),
829-
max_probe=512,
830-
tb_logging_frequency=100,
831-
name="table_0",
832-
total_num_buckets=1,
833-
eviction_config=None,
834-
eviction_policy_name=None,
835-
opt_in_prob=-1,
836-
percent_reserved_slots=0,
837-
disable_fallback=True,
838-
)
839-
}
840-
mcebc = ManagedCollisionEmbeddingBagCollection(
841-
EmbeddingBagCollection(
842-
device=torch.device("cuda"),
843-
tables=mc_emb_configs,
844-
is_weighted=False,
845-
),
846-
ManagedCollisionCollection(
847-
managed_collision_modules=mc_modules,
848-
embedding_configs=mc_emb_configs,
849-
),
850-
return_remapped_features=True,
851-
)
852-
lengths = torch.tensor(
853-
[1, 1, 1, 1, 1], dtype=torch.int64, device=torch.device("cuda")
854-
)
855-
values = torch.tensor(
856-
[3, 4, 5, 6, 8],
857-
dtype=torch.int64,
858-
device=torch.device("cuda"),
859-
)
860-
features = KeyedJaggedTensor(
861-
keys=["table_0"],
862-
values=values,
863-
lengths=lengths,
864-
)
865-
# Run once to insert ids
866-
res = mcebc.forward(features)
867-
# Pyre-ignore [6]: In call `torch._C._VariableFunctions.abs`, for 1st positional argument, expected `Tensor` but got `Union[JaggedTensor, Tensor]`
868-
mask = torch.abs(res[0]["table_0"]) == 0
869-
# For each row, check if all elements are True (i.e., close to zero)
870-
row_mask = mask.all(dim=1)
871-
# Get indices of zero rows
872-
self.assertEqual(torch.nonzero(row_mask, as_tuple=False).squeeze().numel(), 0)
873-
self.assertIsNotNone(res[1])
874-
self.assertTrue(
875-
torch.equal(
876-
# Pyre-ignore [16]: Optional type has no attribute `__getitem__`.
877-
res[1]["table_0"].values(),
878-
torch.tensor([1, 2, 8, 9, 3], dtype=torch.int64, device="cuda:0"),
879-
)
880-
)
881-
self.assertTrue(
882-
torch.equal(
883-
res[1]["table_0"].lengths(),
884-
torch.tensor([1, 1, 1, 1, 1], dtype=torch.int64, device="cuda:0"),
885-
)
886-
)
887-
# Pyre-ignore [29]: `typing.Union[torch._tensor.Tensor, torch.nn.modules.module.Module]` is not a function
888-
mcebc._managed_collision_collection._managed_collision_modules[
889-
"table_0"
890-
].reset_inference_mode()
891-
lengths = torch.tensor(
892-
[1, 1, 1, 1, 1, 1], dtype=torch.int64, device=torch.device("cuda")
893-
)
894-
values = torch.tensor(
895-
[0, 4, 5, 1, 2, 8],
896-
dtype=torch.int64,
897-
device=torch.device("cuda"),
898-
)
899-
features = KeyedJaggedTensor(
900-
keys=["table_0"],
901-
values=values,
902-
lengths=lengths,
903-
)
904-
# Run once to insert ids.
905-
res = mcebc.forward(features)
906-
self.assertTrue(
907-
torch.equal(
908-
res[1]["table_0"].values(),
909-
torch.tensor([2, 8, 3], dtype=torch.int64, device="cuda:0"),
910-
)
911-
)
912-
self.assertTrue(
913-
torch.equal(
914-
res[1]["table_0"].lengths(),
915-
torch.tensor([0, 1, 1, 0, 0, 1], dtype=torch.int64, device="cuda:0"),
916-
)
917-
)
918-
# Pyre-ignore [6]: In call `torch._C._VariableFunctions.abs`, for 1st positional argument, expected `Tensor` but got `Union[JaggedTensor, Tensor]`
919-
mask = torch.abs(res[0]["table_0"]) == 0
920-
# For each row, check if all elements are True (i.e., close to zero)
921-
row_mask = mask.all(dim=1)
922-
# Get indices of zero rows
923-
self.assertTrue(
924-
torch.equal(
925-
torch.tensor([0, 3, 4], device="cuda:0"),
926-
torch.nonzero(row_mask, as_tuple=False).squeeze(),
927-
)
928-
)

0 commit comments

Comments (0)