@@ -1915,8 +1915,21 @@ defmodule Axon.Layers do
1915
1915
must be at least rank 3, with fixed `batch` and `channel` dimensions.
1916
1916
Resizing will upsample or downsample using the given resize method.
1917
1917
1918
- Supported resize methods are `:nearest, :linear, :bilinear, :trilinear,
1919
- :cubic, :bicubic, :tricubic`.
1918
+ ## Options
1919
+
1920
+ * `:size` - a tuple specifying the resized spatial dimensions.
1921
+ Required.
1922
+
1923
+ * `:method` - the resizing method to use, either of `:nearest`,
1924
+ `:bilinear`, `:bicubic`, `:lanczos3`, `:lanczos5`. Defaults to
1925
+ `:nearest`.
1926
+
1927
+ * `:antialias` - whether an anti-aliasing filter should be used
1928
+ when downsampling. This has no effect with upsampling. Defaults
1929
+ to `true`.
1930
+
1931
+ * `:channels` - channels location, either `:first` or `:last`.
1932
+ Defaults to `:last`.
1920
1933
1921
1934
## Examples
1922
1935
@@ -1951,6 +1964,7 @@ defmodule Axon.Layers do
1951
1964
:size ,
1952
1965
method: :nearest ,
1953
1966
channels: :last ,
1967
+ antialias: true ,
1954
1968
mode: :inference
1955
1969
] )
1956
1970
@@ -1962,22 +1976,36 @@ defmodule Axon.Layers do
1962
1976
{ axis , put_elem ( out_shape , axis , out_size ) }
1963
1977
end )
1964
1978
1979
+ antialias = opts [ :antialias ]
1980
+
1965
1981
resized_input =
1966
1982
case opts [ :method ] do
1967
1983
:nearest ->
1968
1984
resize_nearest ( input , out_shape , spatial_axes )
1969
1985
1970
1986
:bilinear ->
1971
- resize_with_kernel ( input , out_shape , spatial_axes , & fill_linear_kernel / 1 )
1987
+ resize_with_kernel ( input , out_shape , spatial_axes , antialias , & fill_linear_kernel / 1 )
1972
1988
1973
1989
:bicubic ->
1974
- resize_with_kernel ( input , out_shape , spatial_axes , & fill_cubic_kernel / 1 )
1990
+ resize_with_kernel ( input , out_shape , spatial_axes , antialias , & fill_cubic_kernel / 1 )
1975
1991
1976
1992
:lanczos3 ->
1977
- resize_with_kernel ( input , out_shape , spatial_axes , & fill_lanczos_kernel ( 3 , & 1 ) )
1993
+ resize_with_kernel (
1994
+ input ,
1995
+ out_shape ,
1996
+ spatial_axes ,
1997
+ antialias ,
1998
+ & fill_lanczos_kernel ( 3 , & 1 )
1999
+ )
1978
2000
1979
2001
:lanczos5 ->
1980
- resize_with_kernel ( input , out_shape , spatial_axes , & fill_lanczos_kernel ( 5 , & 1 ) )
2002
+ resize_with_kernel (
2003
+ input ,
2004
+ out_shape ,
2005
+ spatial_axes ,
2006
+ antialias ,
2007
+ & fill_lanczos_kernel ( 5 , & 1 )
2008
+ )
1981
2009
1982
2010
method ->
1983
2011
raise ArgumentError ,
@@ -2038,12 +2066,13 @@ defmodule Axon.Layers do
2038
2066
2039
2067
@ f32_eps :math . pow ( 2 , - 23 )
2040
2068
2041
- deftransformp resize_with_kernel ( input , out_shape , spatial_axes , kernel_fun ) do
2069
+ deftransformp resize_with_kernel ( input , out_shape , spatial_axes , antialias , kernel_fun ) do
2042
2070
for axis <- spatial_axes , reduce: input do
2043
2071
input ->
2044
2072
resize_axis_with_kernel ( input ,
2045
2073
axis: axis ,
2046
2074
output_size: elem ( out_shape , axis ) ,
2075
+ antialias: antialias ,
2047
2076
kernel_fun: kernel_fun
2048
2077
)
2049
2078
end
@@ -2052,12 +2081,19 @@ defmodule Axon.Layers do
2052
2081
defnp resize_axis_with_kernel ( input , opts ) do
2053
2082
axis = opts [ :axis ]
2054
2083
output_size = opts [ :output_size ]
2084
+ antialias = opts [ :antialias ]
2055
2085
kernel_fun = opts [ :kernel_fun ]
2056
2086
2057
2087
input_size = Nx . axis_size ( input , axis )
2058
2088
2059
2089
inv_scale = input_size / output_size
2060
- kernel_scale = max ( 1 , inv_scale )
2090
+
2091
+ kernel_scale =
2092
+ if antialias do
2093
+ max ( 1 , inv_scale )
2094
+ else
2095
+ 1
2096
+ end
2061
2097
2062
2098
sample_f = ( Nx . iota ( { 1 , output_size } ) + 0.5 ) * inv_scale - 0.5
2063
2099
x = Nx . abs ( sample_f - Nx . iota ( { input_size , 1 } ) ) / kernel_scale
0 commit comments