-
Notifications
You must be signed in to change notification settings - Fork 108
Add tests for NumPy language context and fix import path #2690
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
88b16f3
a9a0400
55ec948
d2f213c
fc912bd
5efab78
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| from thunder.numpy import size as np_size | ||
| from thunder.core.langctxs import langctx, Languages, resolve_language | ||
| from thunder.core.proxies import TensorProxy | ||
| from thunder.core.trace import detached_trace | ||
| from thunder.core.devices import cpu | ||
| from thunder.core.dtypes import float32 | ||
|
|
||
|
|
||
| def test_numpy_langctx_registration_and_len_size(): | ||
| with detached_trace(): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice trick here |
||
| t = TensorProxy(shape=(2, 3), device=cpu, dtype=float32) | ||
|
|
||
| with langctx(Languages.NUMPY): | ||
| assert len(t) == 2 # axis 0 length | ||
| assert t.size() == 6 # total elements | ||
| assert np_size(t) == 6 | ||
|
Comment on lines
+14
to
+16
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why did you choose to test these?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I needed to check that any numpy function is actually working correctly under the context manager. If you have any better ideas for testing, please, suggest one, I'm ready to fix |
||
|
|
||
|
|
||
| def test_numpy_langctx_resolve_language(): | ||
| numpy_ctx_by_enum = resolve_language(Languages.NUMPY) | ||
| numpy_ctx_by_name = resolve_language("numpy") | ||
|
|
||
| assert numpy_ctx_by_enum is numpy_ctx_by_name | ||
| assert numpy_ctx_by_enum.name == "numpy" | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great catch!