diff --git a/python/hidet/option.py b/python/hidet/option.py index 473bebdd4..fbfc80dba 100644 --- a/python/hidet/option.py +++ b/python/hidet/option.py @@ -10,8 +10,9 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations -from typing import Dict, Any, List, Optional, Callable, Iterable, Tuple +from typing import Dict, Any, List, Optional, Callable, Iterable, Tuple, Union import os +import tomlkit class OptionRegistry: @@ -36,6 +37,81 @@ def __init__( self.checker = checker +def create_toml_doc() -> tomlkit.TOMLDocument: + def nest_flattened_dict(d: Dict[str, Any]) -> Dict[str, Any]: + new_dict = {} + for k, v in d.items(): + if '.' in k: + prefix, suffix = k.split('.', 1) + if prefix not in new_dict: + new_dict[prefix] = {suffix: v} + else: + new_dict[prefix][suffix] = v + else: + new_dict[k] = v + for k, v in new_dict.items(): + if isinstance(v, dict): + new_dict[k] = nest_flattened_dict(v) + return new_dict + + def gen_doc(d: Dict[str, Any], toml_doc: tomlkit.TOMLDocument): + for k, v in d.items(): + if isinstance(v, dict): + table = tomlkit.table() + gen_doc(v, table) + toml_doc.add(k, table) + elif isinstance(v, OptionRegistry): + toml_doc.add(tomlkit.comment(v.description)) + if v.choices is not None: + toml_doc.add(tomlkit.comment(f' choices: {v.choices}')) + if isinstance(v.default_value, (bool, int, float, str)): + toml_doc.add(k, v.default_value) + elif isinstance(v.default_value, Tuple): + # represent tuples are toml arrays, do not allow python lists are default values to avoid ambiguity + val = list(v.default_value) + arr = tomlkit.array() + arr.extend(val) + toml_doc.add(k, arr) + else: + raise ValueError(f'Invalid type of default value for option {k}: {type(v.default_value)}') + toml_doc.add(tomlkit.nl()) + else: + raise ValueError(f'Invalid type of default value for option {k}: {type(v)}') + + fd = nest_flattened_dict(OptionRegistry.registered_options) + doc = tomlkit.document() + gen_doc(fd, doc) + return doc + + +def _load_config(config_file_path: str): + def collapse_nested_dict(d: Dict[str, Any]) -> Dict[str, Union[str, int, float, bool, Tuple]]: + # {"cuda": {"arch": "hopper", "cc": [9, 0]}} -> {"cuda.arch": 90, "cuda.cc": (9, 0)} + ret = {} + for k, v in d.items(): + if isinstance(v, dict): + v = collapse_nested_dict(v) + for k1, v1 in v.items(): + ret[f'{k}.{k1}'] = v1 + continue + if isinstance(v, list): + v = tuple(v) + ret[k] = v + return ret + + with open(config_file_path, 'r') as f: + config_doc = tomlkit.parse(f.read()) + for k, v in collapse_nested_dict(config_doc).items(): + if k not in OptionRegistry.registered_options: + raise KeyError(f'Option {k} found in config file {config_file_path} is not registered.') + OptionRegistry.registered_options[k].default_value = v + + +def _write_default_config(config_file_path: str, config_doc: tomlkit.TOMLDocument): + with open(config_file_path, 'w') as f: + tomlkit.dump(config_doc, f) + + def register_option( name: str, type_hint: str, @@ -177,11 +253,20 @@ def register_hidet_options(): ) register_option( name='cuda.arch', - type_hint='Optional[str]', - default_value=None, - description='The CUDA architecture to compile the kernels for (e.g., "sm_70"). None for auto-detect.', + type_hint='str', + default_value='auto', + description='The CUDA architecture to compile the kernels for (e.g., "sm_70"). "auto" for auto-detect.', ) + config_file_path = os.path.join(os.path.expanduser('~'), '.config', 'hidet') + if not os.path.exists(config_file_path): + os.makedirs(config_file_path) + config_file_path = os.path.join(config_file_path, 'hidet.toml') + if not os.path.exists(config_file_path): + _write_default_config(config_file_path, create_toml_doc()) + else: + _load_config(config_file_path) + register_hidet_options() @@ -662,15 +747,15 @@ def debug_show_verbose_flow_graph(enable: bool = True): class cuda: @staticmethod - def arch(arch: Optional[str] = None): + def arch(arch: str = 'auto'): """ Set the CUDA architecture to use when building CUDA kernels. Parameters ---------- arch: Optional[str] - The CUDA architecture, e.g., 'sm_35', 'sm_70', 'sm_80', etc. None means using the architecture of the first - CUDA GPU on the current machine. Default None. + The CUDA architecture, e.g., 'sm_35', 'sm_70', 'sm_80', etc. "auto" means + using the architecture of the first CUDA GPU on the current machine. Default "auto". """ OptionContext.current().set_option('cuda.arch', arch) @@ -685,7 +770,7 @@ def get_arch() -> str: The CUDA architecture, e.g., 'sm_35', 'sm_70', 'sm_80', etc. """ arch: Optional[str] = OptionContext.current().get_option('cuda.arch') - if arch is None: + if arch == "auto": import hidet.cuda # get the architecture of the first CUDA GPU diff --git a/requirements.txt b/requirements.txt index 17a8871ac..bb21b1f04 100644 --- a/requirements.txt +++ b/requirements.txt @@ -37,5 +37,8 @@ filelock requests +# for configuration +tomlkit + # for parser -lark \ No newline at end of file +lark diff --git a/setup.py b/setup.py index 4252457b6..dda947129 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,8 @@ "requests", "filelock", "cuda-python>=11.6.1; platform_system=='Linux'", - "lark" + "lark", + "tomlkit" ], platforms=["linux"], entry_points={