Remove numpy<2 pin #984

Open
janosh opened this issue Jan 28, 2025 · 5 comments

janosh commented Jan 28, 2025

_balanced_partition appears to be the only function that uses numba for jitting.

    @numba.njit
    def _balanced_partition(sizes: NDArray[np.int_], num_parts: int):
        """
        Greedily partition the given set by always inserting
        the largest element into the smallest partition.
        """
        sort_idx = np.argsort(-sizes)  # Sort in descending order
        heap = [(sizes[idx], [idx]) for idx in sort_idx[:num_parts]]
        heapq.heapify(heap)
        for idx in sort_idx[num_parts:]:
            smallest_part = heapq.heappop(heap)
            new_size = smallest_part[0] + sizes[idx]
            new_idx = smallest_part[1] + [idx]  # TODO should this be append to save time/space
            heapq.heappush(heap, (new_size, new_idx))
        return [part[1] for part in heap]

@zulissi this seems like a low-cost function that only mildly benefits from JIT (didn't test, could be wrong, just judging from reading the function). unless there are additional reasons for down-pinning numpy that i'm unaware of, maybe it makes sense to remove that decorator and the version pin on numpy<2 entirely? the benefits of numpy v2 compatibility might be worth more than mildly faster _balanced_partition
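For anyone who wants to try the undecorated variant, here is a minimal sketch of the same greedy scheme without numba. It is not the project's code: the name `balanced_partition` is hypothetical, it uses only the standard library (so it takes plain lists as well as arrays), and it appends in place as the TODO suggests. Tie-breaking between equal sizes may differ from `np.argsort`.

```python
import heapq
from typing import Sequence


def balanced_partition(sizes: Sequence[int], num_parts: int) -> list[list[int]]:
    """Greedy partition: put the largest remaining element into the smallest part."""
    # Indices ordered by size, largest first
    order = sorted(range(len(sizes)), key=lambda i: sizes[i], reverse=True)
    # Seed each part with one of the largest elements: (total size, member indices)
    heap = [(sizes[i], [i]) for i in order[:num_parts]]
    heapq.heapify(heap)
    for i in order[num_parts:]:
        total, members = heapq.heappop(heap)  # currently smallest part
        members.append(i)  # in-place append instead of building a new list
        heapq.heappush(heap, (total + sizes[i], members))
    return [members for _, members in heap]


example_sizes = [7, 3, 5, 1, 4, 2]
parts = balanced_partition(example_sizes, num_parts=2)
totals = sorted(sum(example_sizes[i] for i in part) for part in parts)
```

Timing this against the jitted version on realistic batch sizes (for example with `timeit`) would show whether the decorator is worth keeping; no measurements are claimed here.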

Collaborator

lbluque commented Jan 30, 2025

Hi @janosh 👋

Indeed I think we may be able to get away with just removing it. I haven't profiled the difference though. I do agree numpy 2 compatibility would be great, but since a lot of our ongoing training runs are using it, it would be easier to keep it for now.

I haven't had time to delve into the numba docs, but my guess is that up-to-date versions should work with numpy 2. I'll look into this as soon as I get a chance, or just profile the difference without the numba jit.


bkmi commented Feb 6, 2025

numpy 2.1 is supported in the newest version of numba
https://github.com/numba/numba/releases/tag/0.61.0
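A quick way to check a local environment against that release is sketched below. The helper names are hypothetical, and the `(0, 61)` threshold is taken only from the release linked above; whether earlier numba versions also work with numpy 2.x is not checked here.

```python
import re
from importlib import metadata


def installed_version(package: str):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def major_minor(version: str) -> tuple[int, int]:
    """Extract (major, minor) from a version string such as '0.61.0'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def numba_allows_numpy_2_1(numba_version: str) -> bool:
    """Assumption: numba >= 0.61 is the cutoff for numpy 2.1, per the linked release."""
    return major_minor(numba_version) >= (0, 61)


for name in ("numpy", "numba"):
    print(name, installed_version(name))
```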

Collaborator

lbluque commented Feb 7, 2025

Thanks @bkmi! I opened #1003. However, this is still blocked by torch_geometric 2.6.1.


bkmi commented Feb 11, 2025

It does not look to me like torch_geometric 2.6.1 requires numpy<2:

https://github.com/pyg-team/pytorch_geometric/blob/90bb1397f0313adbe69f90b6ed8bdf123eeea120/pyproject.toml#L35C6-L35C11

although perhaps they depend on a lower version without declaring it.

Collaborator

lbluque commented Feb 12, 2025

It's not explicitly set as a dependency, but you can see in these failing tests the usage of the now-removed numpy.math in dimenet_utils.
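For context, `numpy.math` was only an alias for the standard-library `math` module, so calls through it can usually be swapped for `math` directly. The sketch below assumes the failing code is a factorial-based spherical-harmonic normalization; the function name and formula are illustrative and not copied from `dimenet_utils`.

```python
import math


def sph_harm_prefactor(k: int, m: int) -> float:
    """Normalization prefactor of the spherical harmonic of degree k and order m."""
    # math.factorial replaces the removed np.math.factorial alias
    return math.sqrt(
        (2 * k + 1)
        * math.factorial(k - abs(m))
        / (4 * math.pi * math.factorial(k + abs(m)))
    )


y00 = sph_harm_prefactor(0, 0)  # equals 1 / (2 * sqrt(pi))
```

A change like this works on both numpy 1.x and 2.x, since it no longer touches numpy at all.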
