Skip to content

Commit 40aede6

Browse files
Add option to control anti-aliasing in the resize layer (#555)
1 parent cfa77a3 commit 40aede6

File tree

3 files changed

+90
-10
lines changed

3 files changed

+90
-10
lines changed

lib/axon.ex

+6-2
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ defmodule Axon do
390390
391391
You may specify the parameter shape as either a static shape or
392392
as function of the inputs to the given layer. If you specify the
393-
parameter shape as a function, it will be given the
393+
parameter shape as a function, it will be given the
394394
395395
## Options
396396
@@ -2122,18 +2122,22 @@ defmodule Axon do
21222122
21232123
* `:method` - resize method. Defaults to `:nearest`.
21242124
2125+
* `:antialias` - whether an anti-aliasing filter should be used
2126+
when downsampling. Defaults to `true`.
2127+
21252128
* `:channels` - channel configuration. One of `:first` or
21262129
`:last`. Defaults to `:last`.
21272130
21282131
"""
21292132
@doc type: :shape
21302133
def resize(%Axon{} = x, resize_shape, opts \\ []) do
2131-
opts = Keyword.validate!(opts, [:name, method: :nearest, channels: :last])
2134+
opts = Keyword.validate!(opts, [:name, method: :nearest, antialias: true, channels: :last])
21322135
channels = opts[:channels]
21332136

21342137
layer(:resize, [x],
21352138
name: opts[:name],
21362139
method: opts[:method],
2140+
antialias: opts[:antialias],
21372141
channels: channels,
21382142
size: resize_shape,
21392143
op_name: :resize

lib/axon/layers.ex

+44-8
Original file line numberDiff line numberDiff line change
@@ -1915,8 +1915,21 @@ defmodule Axon.Layers do
19151915
must be at least rank 3, with fixed `batch` and `channel` dimensions.
19161916
Resizing will upsample or downsample using the given resize method.
19171917
1918-
Supported resize methods are `:nearest, :linear, :bilinear, :trilinear,
1919-
:cubic, :bicubic, :tricubic`.
1918+
## Options
1919+
1920+
* `:size` - a tuple specifying the resized spatial dimensions.
1921+
Required.
1922+
1923+
* `:method` - the resizing method to use, either of `:nearest`,
1924+
`:bilinear`, `:bicubic`, `:lanczos3`, `:lanczos5`. Defaults to
1925+
`:nearest`.
1926+
1927+
* `:antialias` - whether an anti-aliasing filter should be used
1928+
when downsampling. This has no effect with upsampling. Defaults
1929+
to `true`.
1930+
1931+
* `:channels` - channels location, either `:first` or `:last`.
1932+
Defaults to `:last`.
19201933
19211934
## Examples
19221935
@@ -1951,6 +1964,7 @@ defmodule Axon.Layers do
19511964
:size,
19521965
method: :nearest,
19531966
channels: :last,
1967+
antialias: true,
19541968
mode: :inference
19551969
])
19561970

@@ -1962,22 +1976,36 @@ defmodule Axon.Layers do
19621976
{axis, put_elem(out_shape, axis, out_size)}
19631977
end)
19641978

1979+
antialias = opts[:antialias]
1980+
19651981
resized_input =
19661982
case opts[:method] do
19671983
:nearest ->
19681984
resize_nearest(input, out_shape, spatial_axes)
19691985

19701986
:bilinear ->
1971-
resize_with_kernel(input, out_shape, spatial_axes, &fill_linear_kernel/1)
1987+
resize_with_kernel(input, out_shape, spatial_axes, antialias, &fill_linear_kernel/1)
19721988

19731989
:bicubic ->
1974-
resize_with_kernel(input, out_shape, spatial_axes, &fill_cubic_kernel/1)
1990+
resize_with_kernel(input, out_shape, spatial_axes, antialias, &fill_cubic_kernel/1)
19751991

19761992
:lanczos3 ->
1977-
resize_with_kernel(input, out_shape, spatial_axes, &fill_lanczos_kernel(3, &1))
1993+
resize_with_kernel(
1994+
input,
1995+
out_shape,
1996+
spatial_axes,
1997+
antialias,
1998+
&fill_lanczos_kernel(3, &1)
1999+
)
19782000

19792001
:lanczos5 ->
1980-
resize_with_kernel(input, out_shape, spatial_axes, &fill_lanczos_kernel(5, &1))
2002+
resize_with_kernel(
2003+
input,
2004+
out_shape,
2005+
spatial_axes,
2006+
antialias,
2007+
&fill_lanczos_kernel(5, &1)
2008+
)
19812009

19822010
method ->
19832011
raise ArgumentError,
@@ -2038,12 +2066,13 @@ defmodule Axon.Layers do
20382066

20392067
@f32_eps :math.pow(2, -23)
20402068

2041-
deftransformp resize_with_kernel(input, out_shape, spatial_axes, kernel_fun) do
2069+
deftransformp resize_with_kernel(input, out_shape, spatial_axes, antialias, kernel_fun) do
20422070
for axis <- spatial_axes, reduce: input do
20432071
input ->
20442072
resize_axis_with_kernel(input,
20452073
axis: axis,
20462074
output_size: elem(out_shape, axis),
2075+
antialias: antialias,
20472076
kernel_fun: kernel_fun
20482077
)
20492078
end
@@ -2052,12 +2081,19 @@ defmodule Axon.Layers do
20522081
defnp resize_axis_with_kernel(input, opts) do
20532082
axis = opts[:axis]
20542083
output_size = opts[:output_size]
2084+
antialias = opts[:antialias]
20552085
kernel_fun = opts[:kernel_fun]
20562086

20572087
input_size = Nx.axis_size(input, axis)
20582088

20592089
inv_scale = input_size / output_size
2060-
kernel_scale = max(1, inv_scale)
2090+
2091+
kernel_scale =
2092+
if antialias do
2093+
max(1, inv_scale)
2094+
else
2095+
1
2096+
end
20612097

20622098
sample_f = (Nx.iota({1, output_size}) + 0.5) * inv_scale - 0.5
20632099
x = Nx.abs(sample_f - Nx.iota({input_size, 1})) / kernel_scale

test/axon/layers_test.exs

+40
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,46 @@ defmodule Axon.LayersTest do
10091009
atol: 1.0e-4
10101010
)
10111011
end
1012+
1013+
test "without anti-aliasing" do
1014+
# Upscaling
1015+
1016+
image = Nx.iota({1, 4, 4, 3}, type: :f32)
1017+
1018+
assert_all_close(
1019+
Axon.Layers.resize(image, size: {3, 3}, method: :bicubic, antialias: false),
1020+
Nx.tensor([
1021+
[
1022+
[
1023+
[[1.5427, 2.5427, 3.5427], [5.7341, 6.7341, 7.7341], [9.9256, 10.9256, 11.9256]],
1024+
[[18.3085, 19.3085, 20.3085], [22.5, 23.5, 24.5], [26.6915, 27.6915, 28.6915]],
1025+
[
1026+
[35.0744, 36.0744, 37.0744],
1027+
[39.2659, 40.2659, 41.2659],
1028+
[43.4573, 44.4573, 45.4573]
1029+
]
1030+
]
1031+
]
1032+
]),
1033+
atol: 1.0e-4
1034+
)
1035+
1036+
# Downscaling (no effect)
1037+
1038+
image = Nx.iota({1, 2, 2, 3}, type: :f32)
1039+
1040+
assert_all_close(
1041+
Axon.Layers.resize(image, size: {3, 3}, method: :bicubic, antialias: false),
1042+
Nx.tensor([
1043+
[
1044+
[[-0.5921, 0.4079, 1.4079], [1.1053, 2.1053, 3.1053], [2.8026, 3.8026, 4.8026]],
1045+
[[2.8026, 3.8026, 4.8026], [4.5, 5.5, 6.5], [6.1974, 7.1974, 8.1974]],
1046+
[[6.1974, 7.1974, 8.1974], [7.8947, 8.8947, 9.8947], [9.5921, 10.5921, 11.5921]]
1047+
]
1048+
]),
1049+
atol: 1.0e-4
1050+
)
1051+
end
10121052
end
10131053

10141054
describe "lstm_cell" do

0 commit comments

Comments
 (0)