Skip to content

Commit

Permalink
Merge pull request #22 from NVlabs/nb/gpu-reorder-wo-cuhpx
Browse files Browse the repository at this point in the history
Avoid using cuhpx for nest and XY reordering
  • Loading branch information
nbren12 authored Jan 23, 2025
2 parents efb891c + 1d661f6 commit a5d5a3f
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 42 deletions.
51 changes: 51 additions & 0 deletions earth2grid/_bit_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def compact_bits(bits):
    """Pack the even-position bits of ``bits`` into the low 32 bits.

    Bits at odd positions are discarded; the bit at position ``2 * k``
    ends up at position ``k``.  This is the inverse of ``spread_bits``.

    Only ``&``, ``|`` and ``>>`` against Python int constants are used,
    so any operand supporting those operators works.

    One-byte illustration of the idea::

        0a0b0c0d                                   (after the 0x55 mask)
        00ab00cd    (x | x >> 1) & 0b00110011      (0x33)
        0000abcd    (x | x >> 2) & 0b00001111      (0x0F)
    """
    # Each stage merges neighbouring groups of surviving bits: shift the
    # upper group down next to the lower one, then mask away the leftovers.
    stages = (
        (1, 0x3333333333333333),
        (2, 0x0F0F0F0F0F0F0F0F),
        (4, 0x00FF00FF00FF00FF),
        (8, 0x0000FFFF0000FFFF),
        (16, 0x00000000FFFFFFFF),
    )
    # Start by dropping the interleaved (odd-position) bits: 0b0101...
    packed = bits & 0x5555555555555555
    for shift, keep in stages:
        packed = (packed | (packed >> shift)) & keep  # noqa
    return packed


def spread_bits(bits):
    """Interleave zero bits between the low 32 bits of ``bits``.

    The bit at position ``k`` (``0 <= k < 32``) moves to position
    ``2 * k``; every odd position of the result is zero.  This is the
    inverse of ``compact_bits``.

    ``bits`` is expected to be a 32 bit number (stored in a wider integer
    such as int64).  Anything above bit 31 is masked off before spreading:
    without that, high input bits would pass straight through the first
    mask and silently alias the output of a different, valid input.

    The algorithm starts by moving the upper 16 bits left by 16 and then
    repeats with groups of half the size.

    One-byte illustration of the idea::

        0000abcd
        00ab00cd    (x | x << 2) & 0b00110011      (0x33)
        0a0b0c0d    (x | x << 1) & 0b01010101      (0x55)
    """
    # Restrict to the documented 32 bit input domain (mirrors the initial
    # mask applied by compact_bits).
    bits = bits & 0x00000000FFFFFFFF
    bits = (bits | (bits << 16)) & 0x0000FFFF0000FFFF  # noqa
    bits = (bits | (bits << 8)) & 0x00FF00FF00FF00FF  # noqa
    bits = (bits | (bits << 4)) & 0x0F0F0F0F0F0F0F0F  # noqa
    bits = (bits | (bits << 2)) & 0x3333333333333333  # noqa
    bits = (bits | (bits << 1)) & 0x5555555555555555  # noqa
    return bits
77 changes: 35 additions & 42 deletions earth2grid/healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import numpy as np
import torch

from earth2grid import healpix_bare
from earth2grid import _bit_ops, healpix_bare
from earth2grid._regrid import Regridder
from earth2grid.healpix_bare import ang2pix

Expand Down Expand Up @@ -155,7 +155,10 @@ def to_xy_cuda(self, x: torch.Tensor, dest: "XY"):
if self == PixelOrder.RING:
return _apply_cuhpx_remap(cuhpx.ring2flat, x, clockwise=dest.clockwise, origin=dest.origin.name)
elif self == PixelOrder.NEST:
return _apply_cuhpx_remap(cuhpx.nest2flat, x, clockwise=dest.clockwise, origin=dest.origin.name)
nside = npix2nside(x.size(-1))
i_dest = torch.arange(x.shape[-1], dtype=torch.int64, device=x.device)
i = xy2nest(nside, i_dest, dest)
return x[..., i]


class Compass(Enum):
Expand Down Expand Up @@ -192,14 +195,10 @@ def reorder_from_cuda(self, x, src: "PixelOrderT"):
return src.to_xy_cuda(x, self)

def to_xy_cuda(self, x: torch.Tensor, dest: "XY"):
return _apply_cuhpx_remap(
cuhpx.flat2flat,
x,
src_origin=self.origin.name,
src_clockwise=self.clockwise,
dest_origin=dest.origin.name,
dest_clockwise=dest.clockwise,
)
nside = npix2nside(x.size(-1))
i_dest = torch.arange(x.shape[-1], dtype=torch.int64, device=x.device)
i = xy2xy(nside, src=dest, dest=self, i=i_dest)
return x[..., i]

def to_ring_cuda(self, x: torch.Tensor):
return _apply_cuhpx_remap(
Expand All @@ -210,7 +209,10 @@ def to_ring_cuda(self, x: torch.Tensor):
)

def to_nest_cuda(self, x: torch.Tensor):
return _apply_cuhpx_remap(cuhpx.flat2nest, x, origin=self.origin.name, clockwise=self.clockwise)
nside = npix2nside(x.size(-1))
i_dest = torch.arange(x.shape[-1], dtype=torch.int64, device=x.device)
i = nest2xy(nside, i_dest, self)
return x[..., i]


PixelOrderT = Union[PixelOrder, XY]
Expand All @@ -224,7 +226,16 @@ def reorder(x: torch.Tensor, src_pixel_order: PixelOrderT, dest_pixel_order: Pix
return grid.reorder(dest_pixel_order, x)


def _convert_xyindex(nside: int, src: XY, dest: XY, i):
def xy2xy(nside: int, src: XY, dest: XY, i: torch.Tensor):
"""Convert flat index between pixel ordering conventions`
Args:
i: int64
"""
if src == dest:
return i

if src.clockwise != dest.clockwise:
i = _flip_xy(nside, i)

Expand Down Expand Up @@ -259,7 +270,7 @@ def _nest_ipix(self):
"""convert to nested index number"""
i = torch.arange(self._npix(), device="cpu")
if isinstance(self.pixel_order, XY):
i_xy = _convert_xyindex(nside=self._nside(), src=self.pixel_order, dest=XY(), i=i)
i_xy = xy2xy(nside=self._nside(), src=self.pixel_order, dest=XY(), i=i)
i = xy2nest(self._nside(), i_xy)
elif self.pixel_order == PixelOrder.RING:
i = healpix_bare.ring2nest(self._nside(), i)
Expand All @@ -273,7 +284,7 @@ def _nest2me(self, ipix: torch.Tensor) -> torch.Tensor:
"""return the index in my PIXELORDER corresponding to ipix in NEST ordering"""
if isinstance(self.pixel_order, XY):
i_xy = nest2xy(self._nside(), ipix)
i_me = _convert_xyindex(nside=self._nside(), src=XY(), dest=self.pixel_order, i=i_xy)
i_me = xy2xy(nside=self._nside(), src=XY(), dest=self.pixel_order, i=i_xy)
elif self.pixel_order == PixelOrder.RING:
i_me = healpix_bare.nest2ring(self._nside(), ipix)
elif self.pixel_order == PixelOrder.NEST:
Expand Down Expand Up @@ -451,22 +462,6 @@ def backward(ctx, grad):
ZOOM_LEVELS = 20


def _extract_every_other_bit(binary_number):
result = 0
shift_count = 0

for i in range(ZOOM_LEVELS):
# Check if the least significant bit is 1
# Set the corresponding bit in the result
result |= (binary_number & 1) << shift_count

# Shift to the next bit to check
binary_number = binary_number >> 2
shift_count += 1

return result


def _flip_xy(nside: int, i):
n2 = nside * nside
f = i // n2
Expand Down Expand Up @@ -507,29 +502,27 @@ def _rotate_index(nside: int, rotations: int, i):
return n2 * f + nside * new_y + new_x


def nest2xy(nside, i):
def nest2xy(nside, i, pixel_order: XY = XY()):
"""convert NEST to XY index"""
tile = i // nside**2
j = i % (nside**2)
x = _extract_every_other_bit(j)
y = _extract_every_other_bit(j >> 1)
return tile * nside**2 + y * nside + x
x = _bit_ops.compact_bits(j)
y = _bit_ops.compact_bits(j >> 1)
xy = tile * nside**2 + y * nside + x
xy = xy2xy(nside, XY(), pixel_order, xy)
return xy


def xy2nest(nside, i):
def xy2nest(nside, i, pixel_order: XY = XY()):
"""convert XY index to NEST"""
i = xy2xy(nside, pixel_order, XY(), i)
tile = i // (nside**2)
y = (i % (nside**2)) // nside
x = i % nside

result = 0
for i in range(ZOOM_LEVELS):
# Extract the ith bit from the number
extracted_bit = (x >> i) & 1
result |= extracted_bit << (2 * i)

extracted_bit = (y >> i) & 1
result |= extracted_bit << (2 * i + 1)
result |= _bit_ops.spread_bits(x)
result |= _bit_ops.spread_bits(y) << 1
return result | (tile * nside**2)


Expand Down
21 changes: 21 additions & 0 deletions tests/test__bit_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from earth2grid._bit_ops import compact_bits, spread_bits


def test_bit_ops():
    """Spot-check bit spreading and that compacting undoes it."""
    # Two adjacent set bits land on even positions 0 and 2.
    interleaved = spread_bits(0b11)
    assert interleaved == 0b101

    # Spreading then compacting must return the starting value.
    value = 99
    round_trip = compact_bits(spread_bits(value))
    assert round_trip == value

0 comments on commit a5d5a3f

Please sign in to comment.