Instead of defining the default dtype on the backend implementation type, we could set it as a context.
```rust
type Backend = Cuda<bf16>;
```
to
```rust
let device: Device<Cuda> = Default::default();
device.dtype_global(DType::BF16);

let tensor = Tensor::random(...);
assert_eq!(tensor.dtype(), DType::BF16);
```
or
```rust
let device: Device<Cuda> = Default::default();

device.with_dtype(DType::BF16, || {
    let tensor = Tensor::random(...);
    assert_eq!(tensor.dtype(), DType::BF16);
});
```
The challenge would be to define the default behavior: should the default dtype apply to the current device across all threads, which could have unintended side effects? It could also be scoped to the StreamId (which maps to the thread ID on native). The scoped approach also avoids duplicating every tensor initialization API just to support a custom default dtype.
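As a rough illustration of the StreamId/thread-local idea, here is a minimal sketch of how the scoped variant could be backed by per-thread state. The `DType` enum, `current_dtype`, and the standalone `with_dtype` function are hypothetical stand-ins for this example, not the actual API; tensor constructors would read `current_dtype` instead of requiring duplicated `*_with_dtype` variants.

```rust
use std::cell::Cell;

// Hypothetical stand-in for the library's dtype enum.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DType {
    F32,
    BF16,
}

thread_local! {
    // Per-thread override; `None` falls back to the backend's default dtype.
    // Keying this per thread mirrors the StreamId-on-native idea.
    static DTYPE_OVERRIDE: Cell<Option<DType>> = Cell::new(None);
}

/// Returns the dtype that new tensors created on this thread should use.
fn current_dtype(backend_default: DType) -> DType {
    DTYPE_OVERRIDE.with(|d| d.get()).unwrap_or(backend_default)
}

/// Runs `f` with `dtype` as the thread-local default, restoring the previous
/// value afterwards so scopes can nest.
fn with_dtype<R>(dtype: DType, f: impl FnOnce() -> R) -> R {
    let previous = DTYPE_OVERRIDE.with(|d| d.replace(Some(dtype)));
    let result = f();
    DTYPE_OVERRIDE.with(|d| d.set(previous));
    result
}

fn main() {
    let backend_default = DType::F32;
    assert_eq!(current_dtype(backend_default), DType::F32);

    with_dtype(DType::BF16, || {
        // Inside the scope, initialization code would pick up BF16.
        assert_eq!(current_dtype(backend_default), DType::BF16);
    });

    // The override is scoped: other threads and code outside the closure
    // still see the backend default.
    assert_eq!(current_dtype(backend_default), DType::F32);
}
```

Because the override lives in thread-local storage, it never leaks across threads, which sidesteps the unintended side effects a process-wide global would have; a global (`dtype_global`) would instead need synchronized state shared by all streams.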