Skip to content

Commit

Permalink
reset scalar values all the time to avoid mutation by the kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
adriendelsalle committed May 24, 2024
1 parent deb3f7a commit 48cc54d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 4 additions & 4 deletions doc/source/examples/test_flow_kernel.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -101,27 +101,27 @@
" (\"elevation\", nb.float64[::1]),\n",
" (\"erosion\", nb.float64[::1]),\n",
" (\"drainage_area\", nb.float64[::1]),\n",
" (\"k_coef\", nb.float64),\n",
" # (\"k_coef\", nb.float64),\n",
" # (\"area_exp\", nb.float64),\n",
" # (\"slope_exp\", nb.float64),\n",
" # (\"dt\", nb.float64),\n",
" # (\"k_coef\", 2e-4),\n",
" (\"k_coef\", 2e-4),\n",
" (\"area_exp\", 0.4),\n",
" (\"slope_exp\", 1.),\n",
" (\"dt\", 2e4),\n",
" ],\n",
" outputs=[\"erosion\"],\n",
" max_receivers=1,\n",
" n_threads=1,\n",
" print_generated_code=True,\n",
" # print_generated_code=True,\n",
" application_order=fs.flow.KernelApplicationOrder.BREADTH_UPSTREAM\n",
")\n",
"\n",
"kernel.bind_data(\n",
" elevation=elevation,\n",
" erosion=erosion,\n",
" drainage_area=drainage_area,\n",
" k_coef=2e-4,\n",
" # k_coef=2e-4,\n",
" # area_exp=0.4,\n",
" # slope_exp=1.,\n",
" # dt=2e4\n",
Expand Down
4 changes: 3 additions & 1 deletion python/fastscapelib/flow/numba_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,9 @@ def _build_node_data_getter(self, max_receivers):
data_dtypes[value.dtype] = [name]

node_content = "\n".join(
[f"node_data.{name} = data.{name}[index]" for name in self._grid_data]
[f"node_data.{name} = data.{name}[index]" for name in self._grid_data] +
[f"node_data.{name} = data.{name}" for name, ty in self._constants.items() if issubclass(ty.__class__, nb.core.types.Type)] +
[f"node_data.{name} = {value}" for name, value in self._constants.items() if not issubclass(value.__class__, nb.core.types.Type)]
)

receivers_view_data = [
Expand Down

0 comments on commit 48cc54d

Please sign in to comment.