Restrict stopping criterion parameter usage in command line #174

Open · wants to merge 1 commit into base: main
10 changes: 7 additions & 3 deletions examples/CMakeLists.txt
@@ -27,9 +27,13 @@ function (nvbench_add_examples_target target_prefix cuda_std)
target_include_directories(${example_name} PRIVATE "${CMAKE_CURRENT_LIST_DIR}")
target_link_libraries(${example_name} PRIVATE nvbench::main)
set_target_properties(${example_name} PROPERTIES COMPILE_FEATURES cuda_std_${cuda_std})
add_test(NAME ${example_name}
COMMAND "$<TARGET_FILE:${example_name}>" --timeout 0.1 --min-time 1e-5
)
if ("${example_src}" STREQUAL "custom_criterion.cu")
add_test(NAME ${example_name}
COMMAND "$<TARGET_FILE:${example_name}>" --timeout 0.1)
else()
add_test(NAME ${example_name}
COMMAND "$<TARGET_FILE:${example_name}>" --timeout 0.1 --min-time 1e-5)
endif()

# These should not deadlock. If they do, it may be that the CUDA context was created before
# setting CUDA_MODULE_LOAD=EAGER in main, see NVIDIA/nvbench#136.
6 changes: 1 addition & 5 deletions nvbench/benchmark_base.cuh
@@ -243,11 +243,7 @@ struct benchmark_base
/// Control the stopping criterion for the measurement loop.
/// @{
[[nodiscard]] const std::string& get_stopping_criterion() const { return m_stopping_criterion; }
benchmark_base &set_stopping_criterion(std::string criterion)
{
m_stopping_criterion = std::move(criterion);
return *this;
}
benchmark_base &set_stopping_criterion(std::string criterion);
/// @}

protected:
10 changes: 10 additions & 0 deletions nvbench/benchmark_base.cxx
@@ -18,6 +18,8 @@

#include <nvbench/benchmark_base.cuh>

#include <nvbench/criterion_manager.cuh>

#include <nvbench/detail/transform_reduce.cuh>

namespace nvbench
@@ -80,4 +82,12 @@ std::size_t benchmark_base::get_config_count() const
return per_device_count * m_devices.size();
}

benchmark_base &benchmark_base::set_stopping_criterion(std::string criterion)
{
m_stopping_criterion = std::move(criterion);
m_criterion_params = criterion_manager::get().get_criterion(m_stopping_criterion).get_params();
return *this;
}


} // namespace nvbench
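
For illustration (not part of the diff): with this change, selecting a criterion also pulls in that criterion's own default parameter set from the criterion_manager, using the same get_criterion()/get_params() calls as above. A minimal sketch, assuming an "entropy" criterion is registered:

// Sketch only: list the default parameters of a registered criterion.
#include <nvbench/criterion_manager.cuh>

#include <cstdio>
#include <string>

int main()
{
  auto &manager = nvbench::criterion_manager::get();
  nvbench::criterion_params defaults = manager.get_criterion("entropy").get_params();
  for (const std::string &name : defaults.get_names())
  {
    std::printf("entropy parameter: %s\n", name.c_str());
  }
  return 0;
}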
3 changes: 3 additions & 0 deletions nvbench/criterion_manager.cuh
@@ -51,6 +51,9 @@ public:

using params_description = std::vector<std::pair<std::string, nvbench::named_values::type>>;
params_description get_params_description() const;

using params_map = std::unordered_map<std::string, params_description>;
params_map get_params_description_map() const;
};

/**
19 changes: 19 additions & 0 deletions nvbench/criterion_manager.cxx
@@ -104,4 +104,23 @@ nvbench::criterion_manager::params_description criterion_manager::get_params_des
return desc;
}

criterion_manager::params_map criterion_manager::get_params_description_map() const
{
params_map result;

for (auto &[criterion_name, criterion] : m_map)
{
params_description &desc = result[criterion_name];
nvbench::criterion_params params = criterion->get_params();

for (auto param : params.get_names())
{
nvbench::named_values::type type = params.get_type(param);
desc.emplace_back(param, type);
}
}

return result;
}

} // namespace nvbench
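
The new get_params_description_map() gives the option parser a per-criterion view of which parameters exist. A minimal sketch (not part of the diff) of walking that map:

// Sketch only: print every registered criterion and the parameters it accepts.
#include <nvbench/criterion_manager.cuh>

#include <cstdio>

int main()
{
  const auto map = nvbench::criterion_manager::get().get_params_description_map();
  for (const auto &[criterion_name, description] : map)
  {
    for (const auto &[param_name, type] : description)
    {
      (void)type; // nvbench::named_values::type, unused in this sketch
      std::printf("%s accepts --%s\n", criterion_name.c_str(), param_name.c_str());
    }
  }
  return 0;
}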
24 changes: 18 additions & 6 deletions nvbench/detail/measure_cold.cu
@@ -27,6 +27,8 @@

#include <fmt/format.h>

#include <optional>

namespace nvbench::detail
{

@@ -176,8 +178,18 @@ void measure_cold_base::generate_summaries()
mean_cuda_time);
const auto cuda_rel_stdev = cuda_stdev / mean_cuda_time;
const auto noise = cuda_rel_stdev;
const auto max_noise = m_criterion_params.get_float64("max-noise");
const auto min_time = m_criterion_params.get_float64("min-time");

auto get_param = [this](std::optional<nvbench::float64_t> &param, const std::string &name)
{
if (m_criterion_params.has_value(name))
param = m_criterion_params.get_float64(name);
};

std::optional<nvbench::float64_t> max_noise;
get_param(max_noise, "max-noise");

std::optional<nvbench::float64_t> min_time;
get_param(min_time, "min-time");

{
auto &summ = m_state.add_summary("nv/cold/time/gpu/stdev/relative");
@@ -241,15 +253,15 @@ void measure_cold_base::generate_summaries()
{
const auto timeout = m_walltime_timer.get_duration();

if (noise > max_noise)
if (max_noise && noise > *max_noise)
{
printer.log(nvbench::log_level::warn,
fmt::format("Current measurement timed out ({:0.2f}s) "
"while over noise threshold ({:0.2f}% > "
"{:0.2f}%)",
timeout,
noise * 100,
max_noise * 100));
*max_noise * 100));
}
if (m_total_samples < m_min_samples)
{
@@ -260,15 +272,15 @@
m_total_samples,
m_min_samples));
}
if (m_total_cuda_time < min_time)
if (min_time && m_total_cuda_time < *min_time)
{
printer.log(nvbench::log_level::warn,
fmt::format("Current measurement timed out ({:0.2f}s) "
"before accumulating min_time ({:0.2f}s < "
"{:0.2f}s)",
timeout,
m_total_cuda_time,
min_time));
*min_time));
}
}

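With the change above, "max-noise" and "min-time" are only read when the active criterion actually defines them. A standalone sketch of the same optional-lookup pattern (plain C++, not nvbench types):

// Sketch only: read a parameter into an optional only if it is present.
#include <cstdio>
#include <optional>
#include <string>
#include <unordered_map>

int main()
{
  std::unordered_map<std::string, double> criterion_params = {{"max-noise", 0.005}};

  auto get_param = [&](std::optional<double> &param, const std::string &name)
  {
    auto it = criterion_params.find(name);
    if (it != criterion_params.end())
    {
      param = it->second;
    }
  };

  std::optional<double> max_noise;
  get_param(max_noise, "max-noise");

  std::optional<double> min_time;
  get_param(min_time, "min-time"); // absent here, so min_time stays empty

  const double noise = 0.01;
  if (max_noise && noise > *max_noise)
  {
    std::printf("noise threshold exceeded\n");
  }
  if (!min_time)
  {
    std::printf("min-time not defined by this criterion; check skipped\n");
  }
  return 0;
}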
86 changes: 76 additions & 10 deletions nvbench/option_parser.cu
@@ -346,6 +346,9 @@ void option_parser::parse_impl()
}
}

this->apply_criterion_props();
this->check_criterion_props();

// Make sure there's a default printer if needed:
if (!m_have_stdout_printer)
{
@@ -536,7 +539,8 @@ void option_parser::parse_range(option_parser::arg_iterator_t first,
if (it != criterion_params.end())
{
check_params(1);
this->update_criterion_prop(first[0], first[1], it->second);
m_stopping_criterion_properties.push_back(
{int(m_benchmarks.size()) - 1, first[0], first[1], it->second});
first += 2;
}
else
@@ -977,20 +981,13 @@
}

void option_parser::update_criterion_prop(
int benchmark_idx,
const std::string &prop_arg,
const std::string &prop_val,
const nvbench::named_values::type type)
try
{
// If no active benchmark, save args as global.
if (m_benchmarks.empty())
{
m_global_benchmark_args.push_back(prop_arg);
m_global_benchmark_args.push_back(prop_val);
return;
}

benchmark_base &bench = *m_benchmarks.back();
benchmark_base &bench = *m_benchmarks.at(benchmark_idx);
nvbench::criterion_params& criterion_params = bench.get_criterion_params();
std::string name(prop_arg.begin() + 2, prop_arg.end());
if (type == nvbench::named_values::type::float64)
@@ -1028,6 +1025,75 @@ catch (std::exception& e)
e.what());
}

void option_parser::check_criterion_props()
{
const nvbench::criterion_manager::params_map params_map =
nvbench::criterion_manager::get().get_params_description_map();

for (const auto& bench_ptr : m_benchmarks)
{
const std::string &stopping_criterion = bench_ptr->get_stopping_criterion();
auto it_criterion = params_map.find(stopping_criterion);

if (it_criterion == params_map.end())
{
NVBENCH_THROW(std::runtime_error,
"Unknown benchmark stopping criterion `{}`",
stopping_criterion);
}

const nvbench::criterion_manager::params_description &params_desc = it_criterion->second;
const nvbench::criterion_params &params = bench_ptr->get_criterion_params();

std::vector<std::string> param_names = params.get_names();

for (const std::string &name : param_names)
{
auto it_params = std::find_if(params_desc.begin(),
params_desc.end(),
[&name](const auto &param) { return param.first == name; });

if (it_params == params_desc.end())
{
NVBENCH_THROW(std::runtime_error,
"Unknown stopping criterion parameter:\nBenchmark: `{}`\nCriterion: `{}`\nParameter: `{}`",
bench_ptr->get_name(),
it_criterion->first,
name);
}
}

for (const auto& pair : params_desc)
{
auto it_params = std::find(param_names.begin(), param_names.end(), pair.first);

if (it_params == param_names.end())
{
NVBENCH_THROW(std::runtime_error,
"A stopping criterion parameter isn't set:\nBenchmark: `{}`\nCriterion: `{}`\nParameter: `{}`",
bench_ptr->get_name(),
it_criterion->first,
pair.first);
}
}

}
}

void option_parser::apply_criterion_props()
{
for (const stopping_criterion_property& prop : m_stopping_criterion_properties)
{
int beg = (prop.benchmark_idx == -1 ? 0 : prop.benchmark_idx);
int end = (prop.benchmark_idx == -1 ? int(m_benchmarks.size()) : prop.benchmark_idx + 1);

for (int i = beg; i < end; i++)
{
update_criterion_prop(i, prop.arg, prop.val, prop.type);
}
}
}

void option_parser::update_float64_prop(const std::string &prop_arg, const std::string &prop_val)
try
{
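apply_criterion_props() replays the recorded properties once parsing is done: an entry with benchmark_idx == -1 was seen before any --benchmark and fans out to every benchmark, while a non-negative index targets the benchmark that was active when the flag appeared. A simplified standalone sketch of that fan-out (plain C++, not nvbench types; the type field is omitted):

// Sketch only: apply global (-1) and per-benchmark criterion properties.
#include <cstdio>
#include <string>
#include <vector>

struct stopping_criterion_property
{
  int benchmark_idx; // -1 means "applies to every benchmark"
  std::string arg;
  std::string val;
};

int main()
{
  const std::vector<std::string> benchmarks = {"DummyBench", "TestBench"};
  const std::vector<stopping_criterion_property> props = {
    {-1, "--max-noise", "0.4"}, // parsed before the first --benchmark
    {1, "--max-angle", "0.42"}, // parsed while TestBench was active
  };

  for (const auto &prop : props)
  {
    const int beg = (prop.benchmark_idx == -1) ? 0 : prop.benchmark_idx;
    const int end = (prop.benchmark_idx == -1) ? static_cast<int>(benchmarks.size())
                                               : prop.benchmark_idx + 1;
    for (int i = beg; i < end; i++)
    {
      std::printf("apply %s=%s to %s\n",
                  prop.arg.c_str(),
                  prop.val.c_str(),
                  benchmarks[i].c_str());
    }
  }
  return 0;
}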
15 changes: 14 additions & 1 deletion nvbench/option_parser.cuh
@@ -113,11 +113,14 @@ private:
void update_int64_prop(const std::string &prop_arg, const std::string &prop_val);
void update_float64_prop(const std::string &prop_arg, const std::string &prop_val);

void update_criterion_prop(const std::string &prop_arg,
void update_criterion_prop(int benchmark_idx,
const std::string &prop_arg,
const std::string &prop_val,
const nvbench::named_values::type type);

void update_used_device_state() const;
void check_criterion_props();
void apply_criterion_props();

// Command line args
std::vector<std::string> m_args;
@@ -146,6 +149,16 @@ private:

// Used for device modification commands like --log-gpu-clocks
bool m_exit_after_parsing{false};

struct stopping_criterion_property
{
int benchmark_idx;
std::string arg;
std::string val;
nvbench::named_values::type type;
};

std::vector<stopping_criterion_property> m_stopping_criterion_properties;
};

} // namespace nvbench
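
check_criterion_props(), declared above, validates in both directions: every parameter supplied on the command line must be one the selected criterion describes, and every parameter the criterion describes must end up with a value. A standalone sketch of that symmetric check (plain C++, parameter names borrowed from the entropy criterion used in the tests below, purely for illustration):

// Sketch only: two-way validation between described and supplied parameters.
#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

int main()
{
  const std::vector<std::string> described = {"max-angle", "min-r2"};   // what the criterion accepts
  const std::vector<std::string> supplied  = {"max-angle", "max-noise"}; // what the user set

  for (const auto &name : supplied)
  {
    if (std::find(described.begin(), described.end(), name) == described.end())
    {
      std::printf("error: unknown stopping criterion parameter `%s`\n", name.c_str());
    }
  }
  for (const auto &name : described)
  {
    if (std::find(supplied.begin(), supplied.end(), name) == supplied.end())
    {
      std::printf("error: stopping criterion parameter `%s` isn't set\n", name.c_str());
    }
  }
  return 0;
}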
67 changes: 50 additions & 17 deletions testing/option_parser.cu
@@ -1231,23 +1231,56 @@ void test_timeout()

void test_stopping_criterion()
{
nvbench::option_parser parser;
parser.parse(
{"--benchmark", "DummyBench",
"--stopping-criterion", "entropy",
"--max-angle", "0.42",
"--min-r2", "0.6"});
const auto& states = parser_to_states(parser);

ASSERT(states.size() == 1);
ASSERT(states[0].get_stopping_criterion() == "entropy");

const nvbench::criterion_params &criterion_params = states[0].get_criterion_params();
ASSERT(criterion_params.has_value("max-angle"));
ASSERT(criterion_params.has_value("min-r2"));

ASSERT(criterion_params.get_float64("max-angle") == 0.42);
ASSERT(criterion_params.get_float64("min-r2") == 0.6);
{
nvbench::option_parser parser;
parser.parse({"--benchmark",
"DummyBench",
"--stopping-criterion",
"entropy",
"--max-angle",
"0.42",
"--min-r2",
"0.6"});
const auto &states = parser_to_states(parser);

ASSERT(states.size() == 1);
ASSERT(states[0].get_stopping_criterion() == "entropy");

const nvbench::criterion_params &criterion_params = states[0].get_criterion_params();
ASSERT(criterion_params.has_value("max-angle"));
ASSERT(criterion_params.has_value("min-r2"));

ASSERT(criterion_params.get_float64("max-angle") == 0.42);
ASSERT(criterion_params.get_float64("min-r2") == 0.6);
}
{
try
{
nvbench::option_parser parser;
parser.parse({
"--max-angle",
"0.42",
"--benchmark",
"DummyBench",
"--stopping-criterion",
"entropy",
"--min-r2",
"0.6",
"--max-angle",
"0.42",
"--benchmark",
"TestBench",
"--stopping-criterion",
"stdrel",
});
ASSERT(false);
}
catch (const std::runtime_error & /*ex*/)
{
// `max-angle` isn't applicable to `stdrel`
// fmt::print(stderr, "{}", ex.what());
}
}
}

int main()