Skip to content

Commit

Permalink
Add missing schema check. Adjust test criteria in autoaugment and fri…
Browse files Browse the repository at this point in the history
…ends.

Signed-off-by: Michal Zientkiewicz <[email protected]>
  • Loading branch information
mzient committed Jun 10, 2024
1 parent cf103f1 commit 999c1cc
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 15 deletions.
1 change: 1 addition & 0 deletions dali/python/nvidia/dali/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2162,6 +2162,7 @@ def visit_op(op):
if id(op) in visited:
return
visited.add(id(op))
op.check_args()
# visit conttributing inputs
for edge in get_op_input_edges(op):
visit_op(get_source_op(edge))
Expand Down
19 changes: 8 additions & 11 deletions dali/test/python/auto_aug/test_auto_augment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -210,8 +210,8 @@ def pipeline():
@params(
(False, "cpu", 256),
(False, "gpu", 512),
(True, "cpu", 400),
(True, "gpu", 348),
(True, "cpu", 2000),
(True, "gpu", 2000),
)
def test_sub_policy(randomly_negate, dev, batch_size):
num_magnitude_bins = 10
Expand Down Expand Up @@ -305,9 +305,8 @@ def third(data, op_id_mag_id):
expected_counts.append(expected)
stat = chisquare(counts, expected_counts)
# assert that the magnitudes negation looks independently enough
# (0.05 <=), but also that it is not too ideal (i.e. like all
# cases happening exactly the expected number of times)
assert 0.05 <= stat.pvalue <= 0.95, f"{stat}"
# (0.01 <=)
assert 0.01 <= stat.pvalue, f"{stat}"


@params(("cpu",), ("gpu",))
Expand Down Expand Up @@ -397,7 +396,7 @@ def second_stage_only(data, op_id_mag_id):
)

policy = Policy("MyPolicy", num_magnitude_bins=num_magnitude_bins, sub_policies=sub_policies)
p = concat_aug_pipeline(batch_size=batch_size, dev=dev, policy=policy)
p = concat_aug_pipeline(batch_size=batch_size, dev=dev, policy=policy, seed=1234)
p.build()

for _ in range(5):
Expand All @@ -415,10 +414,8 @@ def second_stage_only(data, op_id_mag_id):
actual.append(actual_counts[mags])
expected.append(expected_counts[mags])
stat = chisquare(actual, expected)
# assert that the magnitudes negation looks independently enough
# (0.05 <=), but also that it is not too ideal (i.e. like all
# cases happening exactly the expected number of times)
assert 0.05 <= stat.pvalue <= 0.95, f"{stat}"
# assert that the magnitudes negation looks independently enough (0.01 <=)
assert 0.01 <= stat.pvalue, f"{stat}"


def test_policy_presentation():
Expand Down
4 changes: 2 additions & 2 deletions dali/test/python/auto_aug/test_rand_augment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -312,7 +312,7 @@ def pipeline():
actual.append(actual_count[out])
expected.append(expected_counts[out])
stat = chisquare(actual, expected)
assert 0.01 <= stat.pvalue <= 0.99, f"{stat} {actual} {expected}"
assert 0.01 <= stat.pvalue, f"{stat} {actual} {expected}"


def test_wrong_params_fail():
Expand Down
4 changes: 2 additions & 2 deletions dali/test/python/auto_aug/test_trivial_augment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -236,4 +236,4 @@ def pipeline():
stat = chisquare(actual, expected)
stats.append(stat)
mean_p_val = sum(stat.pvalue for stat in stats) / len(stats)
assert 0.05 <= mean_p_val <= 0.95, f"{mean_p_val} {stat} {actual} {expected}"
assert 0.01 <= mean_p_val, f"{mean_p_val} {stat} {actual} {expected}"

0 comments on commit 999c1cc

Please sign in to comment.