|  | 
|  | 1 | +import logging | 
|  | 2 | + | 
|  | 3 | +import torch | 
|  | 4 | +from torch.utils._pytree import tree_any | 
|  | 5 | + | 
|  | 6 | + | 
|  | 7 | +log = logging.getLogger(__name__) | 
|  | 8 | + | 
|  | 9 | +from ._device_daemon import daemon | 
|  | 10 | +from ._meta_parser import prepare_for_sending, to_device_no_copy | 
|  | 11 | + | 
|  | 12 | + | 
# Maps hook names (e.g. "deviceCount", "malloc") to Python callables that
# forward the call to the device daemon; populated by _register_same_name.
_IMPL_REGISTRY = {}
|  | 14 | + | 
|  | 15 | + | 
|  | 16 | +# Define all the implementations in the registry | 
def _register_same_name(name, with_log=False):
    """Register a hook that forwards calls of the same ``name`` to the daemon.

    The generated hook is stored in ``_IMPL_REGISTRY`` under ``name``; when
    ``with_log`` is true, each invocation is logged before being forwarded.
    """

    def _forward(*hook_args, **hook_kwargs):
        # Optionally trace the call, then hand it off to the device daemon.
        if with_log:
            log.info("Calling hook %s", name)
        return daemon.exec(name, *hook_args, **hook_kwargs)

    _IMPL_REGISTRY[name] = _forward
|  | 24 | + | 
|  | 25 | + | 
# Hooks that map one-to-one onto daemon commands; the allocation-related
# hooks (malloc/free) are registered with call logging enabled.
for _hook, _logged in (
    ("deviceCount", False),
    ("getDevice", False),
    ("uncheckedSetDevice", False),
    ("exchangeDevice", False),
    ("malloc", True),
    ("free", True),
):
    _register_same_name(_hook, _logged)
del _hook, _logged

# Library handle used below to attach the fallback kernel ("_" namespace).
_openreg_lib = torch.library.Library("_", "IMPL")
|  | 34 | + | 
|  | 35 | + | 
def _openreg_kernel_fallback(op, *args, **kwargs):
    """Generic fallback kernel for ops dispatched to the "openreg" device.

    Strategy: compute the output metadata locally (via the meta device, or by
    asking the daemon for the output shape for data-dependent ops), allocate
    the output, then run the actual compute on the device daemon.

    A few ops are special-cased up front because routing them through the
    generic path would recurse back into this fallback.
    """
    log.info("Calling kernel %s", op)

    # Special ops needed to avoid infinite recursion
    if op is torch.ops.aten._copy_from.default:
        from_, to_ = args
        if from_.device.type == to_.device.type:
            # Device-to-device copy: rewrite as copy_ and fall through to the
            # generic in-place handling below.
            assert from_.device.type == "openreg"
            op = torch.ops.aten.copy_.default
            # handled below as a regular copy
        elif from_.device.type == "openreg":
            # Device -> host: fetch the data from the daemon, then copy it
            # into the destination tensor on the host side.
            args, _ = prepare_for_sending((from_,), {})
            host_mem = daemon.exec("send_data", *args)
            return to_.copy_(host_mem)
        elif to_.device.type == "openreg":
            # Host -> device: push the host tensor into the daemon's storage.
            args, _ = prepare_for_sending((to_,), {})
            daemon.exec("recv_data", from_, *args)
            return to_
        else:
            raise RuntimeError("Should not happen")
    elif op is torch.ops.aten.set_.source_Tensor:
        # Rewrite set_(Tensor) as the storage-based overload, forwarding the
        # source tensor's storage, offset, size and stride explicitly.
        return torch.ops.aten.set_.source_Storage_storage_offset(
            args[0],
            args[1].untyped_storage(),
            args[1].storage_offset(),
            args[1].size(),
            args[1].stride(),
        )
    elif op is torch.ops.aten._local_scalar_dense.default:
        # Scalar extraction: fetch the data from the daemon and read the
        # value out of the returned host memory.
        args, _ = prepare_for_sending(args, {})
        host_mem = daemon.exec("send_data", *args)
        return host_mem.item()

    # op_name: qualified name eventually sent to the daemon (None means the
    # generic metadata/allocation path below must fill it in).
    # post_process: optional fixup run after the device-side compute.
    op_name = None
    post_process = None
    if "out" in op._overloadname:
        # Note that all structured native op will call here
        if isinstance(kwargs["out"], tuple):
            raise RuntimeError(f"out= variant {op} with tuple out= not supported")
        if kwargs["out"].nelement() == 0:
            # Out variant that needs a resize, convert to an out of place
            # and handle generically below
            orig_out = kwargs["out"]
            del kwargs["out"]
            if op._overloadname != "out":
                raise RuntimeError(
                    "Cannot retranslate non-default out= variant form 0 size"
                )
            op = op.overloadpacket.default

            def _post_process():
                # Move the freshly-allocated result into the original out
                # tensor (via set_) so callers observe the resize in place.
                nonlocal real_res
                orig_out.set_(real_res)
                real_res = orig_out

            post_process = _post_process

        else:
            # No metadata update to do, just run the op on the device
            op_name = op.overloadpacket._qualified_op_name
            real_res = kwargs["out"]
    elif not tree_any(lambda obj: isinstance(obj, torch.Tensor), (args, kwargs)):
        # No Tensor argument means factory function
        # They should decompose and be handled in our c++ side directly
        raise RuntimeError(f"{op} not handled yet.")
    elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default:
        # Only handle inplace ops returning their first arg
        assert len(args) >= 1, f"Inplace {op} needs at least one arg"
        assert (
            len(op._schema.returns) == 1
        ), f"NYI Inplace {op} with more than one return"
        op_name = op.overloadpacket._qualified_op_name
        real_res = args[0]
    elif any(r.alias_info is not None for r in op._schema.returns):
        # View ops
        if op is torch.ops.aten.view.default:
            return torch.ops.aten._unsafe_view(*args, **kwargs)
        raise RuntimeError(f"{op} view op is not handled yet")

    if op_name is None:
        # 1. Compute updated metadata
        if torch.Tag.dynamic_output_shape not in op.tags:
            # Usual case: run the meta op to see the output metadata
            meta_args, meta_kwargs = to_device_no_copy("meta", args, kwargs)
            meta_res = op(*meta_args, **meta_kwargs)

            # 2. Allocate the output
            real_res, _ = to_device_no_copy("openreg", meta_res, {})
        else:
            # Slow version for data-dependent functions:
            # Run the op on the device just to get the output shape
            args_, kwargs_ = prepare_for_sending(args, kwargs)
            shape = daemon.exec(
                "get_op_output_shape",
                op.overloadpacket._qualified_op_name,
                args_,
                kwargs_,
            )

            # 2. Allocate the output
            real_res = args[0].new(shape)

        # 3. Move to out variant
        kwargs["out"] = real_res
        # Let overload resolution find the out= overload
        op_name = op.overloadpacket._qualified_op_name

    # 4. Run the compute and populate the output on the device
    args, kwargs = prepare_for_sending(args, kwargs)
    daemon.exec("run_op", op_name, args, kwargs)

    if post_process is not None:
        post_process()

    return real_res
|  | 151 | + | 
|  | 152 | + | 
# Route every op reaching the PrivateUse1 dispatch key through our fallback.
_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1")
0 commit comments