Default tensor dtype can be configured on the device instead of by a type. #3642

@nathanielsimard

Instead of fixing the default dtype through the backend implementation type, we could set it as context on the device, going from

type Backend = Cuda<bf16>;

to

let device: Device<Cuda> = Default::default();
device.dtype_global(DType::BF16);
let tensor = Tensor::random(...);
assert_eq!(tensor.dtype(), DType::BF16);

or

let device: Device<Cuda> = Default::default();
device.with_dtype(DType::BF16, || {
    let tensor = Tensor::random(...);
    assert_eq!(tensor.dtype(), DType::BF16);
});

The challenge is defining the default behavior. If the default dtype applies to the device across all threads, setting it could have unintended side effects elsewhere in the program. Alternatively, it could be keyed on the StreamId (the thread ID on native targets). The scoped approach also avoids duplicating every tensor initialization API just to accept a custom default dtype.
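
To make the per-thread option concrete, here is a minimal sketch of a scoped default using only the standard library. It is not Burn's API: the `DType` enum, `default_dtype`, and the free function `with_dtype` are hypothetical stand-ins, and a `thread_local!` is assumed as a rough substitute for keying on the StreamId on native targets.

```rust
use std::cell::Cell;

// Hypothetical stand-in for the dtype enum; not Burn's actual `DType`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DType {
    F32,
    F16,
    BF16,
}

thread_local! {
    // One default per thread, approximating a default keyed on the StreamId.
    static DEFAULT_DTYPE: Cell<DType> = Cell::new(DType::F32);
}

/// Returns the default dtype visible to the current thread.
pub fn default_dtype() -> DType {
    DEFAULT_DTYPE.with(|d| d.get())
}

/// Guard that puts the previous default back when dropped,
/// so the scope is also unwound if the closure panics.
struct Restore(DType);

impl Drop for Restore {
    fn drop(&mut self) {
        DEFAULT_DTYPE.with(|d| d.set(self.0));
    }
}

/// Runs `f` with `dtype` as the default for the current thread only.
pub fn with_dtype<R>(dtype: DType, f: impl FnOnce() -> R) -> R {
    let previous = DEFAULT_DTYPE.with(|d| d.replace(dtype));
    let _restore = Restore(previous);
    f()
}
```

In this sketch, nested scopes restore the outer value on exit, and a thread spawned inside a scope still sees the initial default rather than inheriting it. Whether spawned threads or streams should inherit the scope is exactly the kind of default behavior that would need to be decided.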