-
Notifications
You must be signed in to change notification settings - Fork 2
Improve testing coverage for hydrological functions with multi-backend support #173
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
ef315cb
a7a4cd3
9b30537
c4c0a08
4f152ae
e2157e4
0af4ead
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| import numpy as np | ||
| import pytest | ||
| from _test_inputs.catchment import * | ||
| from _test_inputs.accumulation import input_field_1c | ||
| from _test_inputs.readers import * | ||
|
|
||
| import earthkit.hydro as ekh | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "river_network, field, locations, expected", | ||
| [ | ||
| ( | ||
| ("cama_nextxy", cama_nextxy_1), | ||
| input_field_1c, | ||
| catchment_query_field_1, | ||
| catchment_max_1c, | ||
| ), | ||
| ], | ||
| indirect=["river_network"], | ||
| ) | ||
| @pytest.mark.parametrize("array_backend", ["numpy", "torch"]) | ||
| def test_catchments_max(river_network, field, locations, expected, array_backend): | ||
| """Test catchment max aggregation.""" | ||
| river_network = river_network.to_device("cpu", array_backend) | ||
| xp = ekh._backends.find.get_array_backend(array_backend) | ||
| result = ekh.catchments.array.max(river_network, xp.asarray(field), locations=locations) | ||
| result = np.asarray(result) | ||
| print("Result:", result) | ||
| print("Expected:", expected) | ||
| np.testing.assert_allclose(result, expected, rtol=1e-6) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| import numpy as np | ||
| import pytest | ||
| from _test_inputs.catchment import * | ||
| from _test_inputs.accumulation import input_field_1c | ||
| from _test_inputs.readers import * | ||
|
|
||
| import earthkit.hydro as ekh | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "river_network, field, locations, expected", | ||
| [ | ||
| ( | ||
| ("cama_nextxy", cama_nextxy_1), | ||
| input_field_1c, | ||
| catchment_query_field_1, | ||
| catchment_mean_1c, | ||
| ), | ||
| ], | ||
| indirect=["river_network"], | ||
| ) | ||
| @pytest.mark.parametrize("array_backend", ["numpy", "torch", "jax"]) | ||
| def test_catchments_mean(river_network, field, locations, expected, array_backend): | ||
| """Test catchment mean aggregation.""" | ||
| river_network = river_network.to_device("cpu", array_backend) | ||
| xp = ekh._backends.find.get_array_backend(array_backend) | ||
| result = ekh.catchments.array.mean(river_network, xp.asarray(field), locations=locations) | ||
| result = np.asarray(result) | ||
| print("Result:", result) | ||
| print("Expected:", expected) | ||
| np.testing.assert_allclose(result, expected, rtol=1e-6) | ||
|
Comment on lines
+10
to
+31
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should test other array backends other than numpy |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| import numpy as np | ||
| import pytest | ||
| from _test_inputs.catchment import * | ||
| from _test_inputs.accumulation import input_field_1c | ||
| from _test_inputs.readers import * | ||
|
|
||
| import earthkit.hydro as ekh | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "river_network, field, locations, expected", | ||
| [ | ||
| ( | ||
| ("cama_nextxy", cama_nextxy_1), | ||
| input_field_1c, | ||
| catchment_query_field_1, | ||
| catchment_min_1c, | ||
| ), | ||
| ], | ||
| indirect=["river_network"], | ||
| ) | ||
| @pytest.mark.parametrize("array_backend", ["numpy", "torch"]) | ||
| def test_catchments_min(river_network, field, locations, expected, array_backend): | ||
| """Test catchment min aggregation.""" | ||
| river_network = river_network.to_device("cpu", array_backend) | ||
| xp = ekh._backends.find.get_array_backend(array_backend) | ||
| result = ekh.catchments.array.min(river_network, xp.asarray(field), locations=locations) | ||
| result = np.asarray(result) | ||
| print("Result:", result) | ||
| print("Expected:", expected) | ||
| np.testing.assert_allclose(result, expected, rtol=1e-6) | ||
|
Comment on lines
+10
to
+31
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should test other array backends other than numpy |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| import numpy as np | ||
| import pytest | ||
| from _test_inputs.catchment import * | ||
| from _test_inputs.accumulation import input_field_1c | ||
| from _test_inputs.readers import * | ||
|
|
||
| import earthkit.hydro as ekh | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "river_network, field, locations, expected", | ||
| [ | ||
| ( | ||
| ("cama_nextxy", cama_nextxy_1), | ||
| input_field_1c, | ||
| catchment_query_field_1, | ||
| catchment_sum_1c, | ||
| ), | ||
| ], | ||
| indirect=["river_network"], | ||
| ) | ||
| @pytest.mark.parametrize("array_backend", ["numpy", "torch", "jax"]) | ||
| def test_catchments_sum(river_network, field, locations, expected, array_backend): | ||
| """Test catchment sum aggregation.""" | ||
| river_network = river_network.to_device("cpu", array_backend) | ||
| xp = ekh._backends.find.get_array_backend(array_backend) | ||
| result = ekh.catchments.array.sum(river_network, xp.asarray(field), locations=locations) | ||
| result = np.asarray(result) | ||
| print("Result:", result) | ||
| print("Expected:", expected) | ||
| np.testing.assert_allclose(result, expected, rtol=1e-6) | ||
|
Comment on lines
+10
to
+31
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should test other array backends other than numpy |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| import numpy as np | ||
| import pytest | ||
| from _test_inputs.distance import * | ||
| from _test_inputs.readers import * | ||
|
|
||
| import earthkit.hydro as ekh | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "river_network, field, expected", | ||
| [ | ||
| ( | ||
| ("cama_nextxy", cama_nextxy_1), | ||
| None, | ||
| distance_1_to_sink_shortest, | ||
| ), | ||
| ], | ||
| indirect=["river_network"], | ||
| ) | ||
| def test_distance_to_sink(river_network, field, expected): | ||
| """Test distance to sink computation.""" | ||
| result = ekh.distance.array.to_sink( | ||
| river_network, field=field, path="shortest", return_type="masked" | ||
| ) | ||
| print("Result:", result) | ||
| print("Expected:", expected) | ||
| np.testing.assert_array_equal(result, expected) | ||
|
Comment on lines
+9
to
+27
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should test other array backends other than numpy |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| import numpy as np | ||
| import pytest | ||
| from _test_inputs.distance import * | ||
| from _test_inputs.readers import * | ||
|
|
||
| import earthkit.hydro as ekh | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "river_network, field, expected", | ||
| [ | ||
| ( | ||
| ("cama_nextxy", cama_nextxy_1), | ||
| None, | ||
| distance_1_to_source_shortest, | ||
| ), | ||
| ], | ||
| indirect=["river_network"], | ||
| ) | ||
| def test_distance_to_source(river_network, field, expected): | ||
| """Test distance to source computation.""" | ||
| result = ekh.distance.array.to_source( | ||
| river_network, field=field, path="shortest", return_type="masked" | ||
| ) | ||
| print("Result:", result) | ||
| print("Expected:", expected) | ||
| np.testing.assert_allclose(result, expected, rtol=1e-6) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| import numpy as np | ||
| import pytest | ||
| from _test_inputs.accumulation import * | ||
| from _test_inputs.readers import * | ||
|
|
||
| import earthkit.hydro as ekh | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "river_network, input_field, flow_downstream, mv", | ||
| [ | ||
| ( | ||
| ("cama_nextxy", cama_nextxy_1), | ||
| input_field_1c, | ||
| downstream_metric_max_1c, | ||
| mv_1c, | ||
| ), | ||
| ( | ||
| ("cama_nextxy", cama_nextxy_1), | ||
| input_field_1e, | ||
| downstream_metric_max_1e, | ||
| mv_1e, | ||
| ), | ||
| ], | ||
| indirect=["river_network"], | ||
| ) | ||
| @pytest.mark.parametrize("array_backend", ["numpy", "torch", "jax"]) | ||
| def test_downstream_metric_max(river_network, input_field, flow_downstream, mv, array_backend): | ||
| river_network = river_network.to_device("cpu", array_backend) | ||
| xp = ekh._backends.find.get_array_backend(array_backend) | ||
| output_field = ekh.downstream.array.max( | ||
| river_network, xp.asarray(input_field), node_weights=None, return_type="masked" | ||
| ) | ||
| output_field = np.asarray(output_field) | ||
| flow_downstream_out = np.asarray(xp.asarray(flow_downstream)) | ||
| print(output_field) | ||
| print(flow_downstream_out) | ||
| assert output_field.dtype == flow_downstream_out.dtype | ||
| np.testing.assert_allclose(output_field, flow_downstream, rtol=1e-6, equal_nan=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should test other array backends other than numpy