diff --git a/earth2grid/_bit_ops.py b/earth2grid/_bit_ops.py new file mode 100644 index 0000000..6cf6d72 --- /dev/null +++ b/earth2grid/_bit_ops.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def compact_bits(bits): + # Remove interleaved 0 bits + bits = bits & 0x5555555555555555 # Mask: 01010101... + # example implementation for 1 byte + # 0a0b0c0d + # 00ab00cd # (x | x >> 1) & 00110011 = 0x33 + # 0000abcd # (x | x >> 2) & 00001111 = 0x0F + # -------- + # abc0d + bits = (bits | (bits >> 1)) & 0x3333333333333333 # noqa + bits = (bits | (bits >> 2)) & 0x0F0F0F0F0F0F0F0F # noqa + bits = (bits | (bits >> 4)) & 0x00FF00FF00FF00FF # noqa + bits = (bits | (bits >> 8)) & 0x0000FFFF0000FFFF # noqa + bits = (bits | (bits >> 16)) & 0x00000000FFFFFFFF # noqa + return bits + + +def spread_bits(bits): + """ + bits is a 32 bit number (stored in int64) + algorithm starts by moving the first 16 bits to the left by 16 + and proceeding recursively + """ + # example implementation for a byte + # 0000abcd + # 00ab00cd # (x | x <<2) & 00110011 = 0x33 + # 0a0b0c0d # (x | x <<1) & 01010100 = 0x55 + # -------- + # abc0d + bits = (bits | (bits << 16)) & 0x0000FFFF0000FFFF # noqa + bits = (bits | (bits << 8)) & 0x00FF00FF00FF00FF # noqa + bits = (bits | (bits << 4)) & 0x0F0F0F0F0F0F0F0F # noqa + bits = (bits | (bits << 2)) & 0x3333333333333333 # noqa + bits = (bits | (bits << 1)) & 0x5555555555555555 # noqa + return bits diff --git a/earth2grid/healpix.py b/earth2grid/healpix.py index 477785f..32db194 100644 --- a/earth2grid/healpix.py +++ b/earth2grid/healpix.py @@ -42,7 +42,7 @@ import numpy as np import torch -from earth2grid import healpix_bare +from earth2grid import _bit_ops, healpix_bare from earth2grid._regrid import Regridder from earth2grid.healpix_bare import ang2pix @@ -155,7 +155,10 @@ def to_xy_cuda(self, x: torch.Tensor, dest: "XY"): if self == PixelOrder.RING: return _apply_cuhpx_remap(cuhpx.ring2flat, x, clockwise=dest.clockwise, origin=dest.origin.name) elif self == PixelOrder.NEST: - return _apply_cuhpx_remap(cuhpx.nest2flat, x, clockwise=dest.clockwise, origin=dest.origin.name) + nside = npix2nside(x.size(-1)) + i_dest = torch.arange(x.shape[-1], dtype=torch.int64, device=x.device) + i = xy2nest(nside, i_dest, dest) + return x[..., i] class Compass(Enum): @@ -192,14 +195,10 @@ def reorder_from_cuda(self, x, src: "PixelOrderT"): return src.to_xy_cuda(x, self) def to_xy_cuda(self, x: torch.Tensor, dest: "XY"): - return _apply_cuhpx_remap( - cuhpx.flat2flat, - x, - src_origin=self.origin.name, - src_clockwise=self.clockwise, - dest_origin=dest.origin.name, - dest_clockwise=dest.clockwise, - ) + nside = npix2nside(x.size(-1)) + i_dest = torch.arange(x.shape[-1], dtype=torch.int64, device=x.device) + i = xy2xy(nside, src=dest, dest=self, i=i_dest) + return x[..., i] def to_ring_cuda(self, x: torch.Tensor): return _apply_cuhpx_remap( @@ -210,7 +209,10 @@ def to_ring_cuda(self, x: torch.Tensor): ) def to_nest_cuda(self, x: torch.Tensor): - return _apply_cuhpx_remap(cuhpx.flat2nest, x, origin=self.origin.name, clockwise=self.clockwise) + nside = npix2nside(x.size(-1)) + i_dest = torch.arange(x.shape[-1], dtype=torch.int64, device=x.device) + i = nest2xy(nside, i_dest, self) + return x[..., i] PixelOrderT = Union[PixelOrder, XY] @@ -224,7 +226,16 @@ def reorder(x: torch.Tensor, src_pixel_order: PixelOrderT, dest_pixel_order: Pix return grid.reorder(dest_pixel_order, x) -def _convert_xyindex(nside: int, src: XY, dest: XY, i): +def xy2xy(nside: int, src: XY, dest: XY, i: torch.Tensor): + """Convert flat index between pixel ordering conventions` + + Args: + i: int64 + + """ + if src == dest: + return i + if src.clockwise != dest.clockwise: i = _flip_xy(nside, i) @@ -259,7 +270,7 @@ def _nest_ipix(self): """convert to nested index number""" i = torch.arange(self._npix(), device="cpu") if isinstance(self.pixel_order, XY): - i_xy = _convert_xyindex(nside=self._nside(), src=self.pixel_order, dest=XY(), i=i) + i_xy = xy2xy(nside=self._nside(), src=self.pixel_order, dest=XY(), i=i) i = xy2nest(self._nside(), i_xy) elif self.pixel_order == PixelOrder.RING: i = healpix_bare.ring2nest(self._nside(), i) @@ -273,7 +284,7 @@ def _nest2me(self, ipix: torch.Tensor) -> torch.Tensor: """return the index in my PIXELORDER corresponding to ipix in NEST ordering""" if isinstance(self.pixel_order, XY): i_xy = nest2xy(self._nside(), ipix) - i_me = _convert_xyindex(nside=self._nside(), src=XY(), dest=self.pixel_order, i=i_xy) + i_me = xy2xy(nside=self._nside(), src=XY(), dest=self.pixel_order, i=i_xy) elif self.pixel_order == PixelOrder.RING: i_me = healpix_bare.nest2ring(self._nside(), ipix) elif self.pixel_order == PixelOrder.NEST: @@ -451,22 +462,6 @@ def backward(ctx, grad): ZOOM_LEVELS = 20 -def _extract_every_other_bit(binary_number): - result = 0 - shift_count = 0 - - for i in range(ZOOM_LEVELS): - # Check if the least significant bit is 1 - # Set the corresponding bit in the result - result |= (binary_number & 1) << shift_count - - # Shift to the next bit to check - binary_number = binary_number >> 2 - shift_count += 1 - - return result - - def _flip_xy(nside: int, i): n2 = nside * nside f = i // n2 @@ -507,29 +502,27 @@ def _rotate_index(nside: int, rotations: int, i): return n2 * f + nside * new_y + new_x -def nest2xy(nside, i): +def nest2xy(nside, i, pixel_order: XY = XY()): """convert NEST to XY index""" tile = i // nside**2 j = i % (nside**2) - x = _extract_every_other_bit(j) - y = _extract_every_other_bit(j >> 1) - return tile * nside**2 + y * nside + x + x = _bit_ops.compact_bits(j) + y = _bit_ops.compact_bits(j >> 1) + xy = tile * nside**2 + y * nside + x + xy = xy2xy(nside, XY(), pixel_order, xy) + return xy -def xy2nest(nside, i): +def xy2nest(nside, i, pixel_order: XY = XY()): """convert XY index to NEST""" + i = xy2xy(nside, pixel_order, XY(), i) tile = i // (nside**2) y = (i % (nside**2)) // nside x = i % nside result = 0 - for i in range(ZOOM_LEVELS): - # Extract the ith bit from the number - extracted_bit = (x >> i) & 1 - result |= extracted_bit << (2 * i) - - extracted_bit = (y >> i) & 1 - result |= extracted_bit << (2 * i + 1) + result |= _bit_ops.spread_bits(x) + result |= _bit_ops.spread_bits(y) << 1 return result | (tile * nside**2) diff --git a/tests/test__bit_ops.py b/tests/test__bit_ops.py new file mode 100644 index 0000000..23f1ac1 --- /dev/null +++ b/tests/test__bit_ops.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from earth2grid._bit_ops import compact_bits, spread_bits + + +def test_bit_ops(): + assert spread_bits(0b11) == 0b101 + i = 99 + assert compact_bits(spread_bits(i)) == i