Skip to content

Commit d60c13c

Browse files
AbishekSfacebook-github-bot
authored andcommitted
Allow aliases for run_opt (meta-pytorch#1141)
Summary: Lets allow aliases for a runopt. This will give downstream users to have multiple ways of accessing the same runopt. * Introduce new class for `RunOptAlias` which is used to expand on adding aliases to a runopt. * Add a new dict to maintain alias to key values that can be used by `opt.get(name)` * Modify add() to accept list as well, build out the aliases list and modify the previously created dict to fill in alias to primary_key values. * Modify resolve() to check if a different alias is already used in cfg i.e if the "jobPriority" and "job_priority" are aliases for the same one, we don't allow for both to be present in the cfg. * Modify get to look at the alias to primary_key dict as well. Differential Revision: D84157870
1 parent b376e8c commit d60c13c

File tree

2 files changed

+111
-6
lines changed

2 files changed

+111
-6
lines changed

torchx/specs/api.py

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -892,10 +892,14 @@ class runopt:
892892
Represents the metadata about the specific run option
893893
"""
894894

895+
class alias(str):
896+
pass
897+
895898
default: CfgVal
896899
opt_type: Type[CfgVal]
897900
is_required: bool
898901
help: str
902+
aliases: list[alias] | None = None
899903

900904
@property
901905
def is_type_list_of_str(self) -> bool:
@@ -987,6 +991,7 @@ class runopts:
987991

988992
def __init__(self) -> None:
989993
self._opts: Dict[str, runopt] = {}
994+
self._alias_to_key: dict[runopt.alias, str] = {}
990995

991996
def __iter__(self) -> Iterator[Tuple[str, runopt]]:
992997
return self._opts.items().__iter__()
@@ -1014,9 +1019,16 @@ def is_type(obj: CfgVal, tp: Type[CfgVal]) -> bool:
10141019

10151020
def get(self, name: str) -> Optional[runopt]:
10161021
"""
1017-
Returns option if any was registered, or None otherwise
1022+
Returns option if any was registered, or None otherwise.
1023+
First searches for the option by ``name``, then falls-back to matching ``name`` with any
1024+
registered aliases.
1025+
10181026
"""
1019-
return self._opts.get(name, None)
1027+
if name in self._opts:
1028+
return self._opts[name]
1029+
if name in self._alias_to_key:
1030+
return self._opts[self._alias_to_key[name]]
1031+
return None
10201032

10211033
def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:
10221034
"""
@@ -1031,6 +1043,24 @@ def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:
10311043

10321044
for cfg_key, runopt in self._opts.items():
10331045
val = resolved_cfg.get(cfg_key)
1046+
resolved_name = None
1047+
aliases = runopt.aliases or []
1048+
if val is None:
1049+
for alias in aliases:
1050+
val = resolved_cfg.get(alias)
1051+
if alias in cfg or val is not None:
1052+
resolved_name = alias
1053+
break
1054+
else:
1055+
resolved_name = cfg_key
1056+
for alias in aliases:
1057+
duplicate_val = resolved_cfg.get(alias)
1058+
if alias in cfg or duplicate_val is not None:
1059+
raise InvalidRunConfigException(
1060+
f"Duplicate opt name. runopt: `{resolved_name}``, is an alias of runopt: `{alias}`",
1061+
resolved_name,
1062+
cfg,
1063+
)
10341064

10351065
# check required opt
10361066
if runopt.is_required and val is None:
@@ -1050,7 +1080,7 @@ def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:
10501080
)
10511081

10521082
# not required and not set, set to default
1053-
if val is None:
1083+
if val is None and resolved_name is None:
10541084
resolved_cfg[cfg_key] = runopt.default
10551085
return resolved_cfg
10561086

@@ -1143,9 +1173,38 @@ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]:
11431173
cfg[key] = val
11441174
return cfg
11451175

1176+
def _get_primary_key_and_aliases(
1177+
self,
1178+
cfg_key: list[str] | str,
1179+
) -> tuple[str, list[runopt.alias]]:
1180+
"""
1181+
Returns the primary key and aliases for the given cfg_key.
1182+
"""
1183+
if isinstance(cfg_key, str):
1184+
return cfg_key, []
1185+
1186+
if len(cfg_key) == 0:
1187+
raise ValueError("cfg_key must be a non-empty list")
1188+
primary_key = None
1189+
aliases = list[runopt.alias]()
1190+
for name in cfg_key:
1191+
if isinstance(name, runopt.alias):
1192+
aliases.append(name)
1193+
else:
1194+
if primary_key is not None:
1195+
raise ValueError(
1196+
f" Given more than one primary key: {primary_key}, {name}. Please use runopt.alias type for aliases. "
1197+
)
1198+
primary_key = name
1199+
if primary_key is None or primary_key == "":
1200+
raise ValueError(
1201+
"Missing cfg_key. Please provide one other than the aliases."
1202+
)
1203+
return primary_key, aliases
1204+
11461205
def add(
11471206
self,
1148-
cfg_key: str,
1207+
cfg_key: str | list[str],
11491208
type_: Type[CfgVal],
11501209
help: str,
11511210
default: CfgVal = None,
@@ -1156,6 +1215,7 @@ def add(
11561215
value (if any). If the ``default`` is not specified then this option
11571216
is a required option.
11581217
"""
1218+
primary_key, aliases = self._get_primary_key_and_aliases(cfg_key)
11591219
if required and default is not None:
11601220
raise ValueError(
11611221
f"Required option: {cfg_key} must not specify default value. Given: {default}"
@@ -1166,8 +1226,10 @@ def add(
11661226
f"Option: {cfg_key}, must be of type: {type_}."
11671227
f" Given: {default} ({type(default).__name__})"
11681228
)
1169-
1170-
self._opts[cfg_key] = runopt(default, type_, required, help)
1229+
opt = runopt(default, type_, required, help, aliases)
1230+
for alias in aliases:
1231+
self._alias_to_key[alias] = primary_key
1232+
self._opts[primary_key] = opt
11711233

11721234
def update(self, other: "runopts") -> None:
11731235
self._opts.update(other._opts)

torchx/specs/test/api_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,49 @@ def test_runopts_add(self) -> None:
578578
# this print is intentional (demonstrates the intended usecase)
579579
print(opts)
580580

581+
def test_runopts_add_with_aliases(self) -> None:
582+
opts = runopts()
583+
opts.add(
584+
["job_priority", runopt.alias("jobPriority")],
585+
type_=str,
586+
help="priority for the job",
587+
)
588+
self.assertEqual(1, len(opts._opts))
589+
self.assertIsNotNone(opts.get("job_priority"))
590+
self.assertIsNotNone(opts.get("jobPriority"))
591+
592+
def test_runopts_resolve_with_aliases(self) -> None:
593+
opts = runopts()
594+
opts.add(
595+
["job_priority", runopt.alias("jobPriority")],
596+
type_=str,
597+
help="priority for the job",
598+
)
599+
opts.resolve({"job_priority": "high"})
600+
opts.resolve({"jobPriority": "low"})
601+
with self.assertRaises(InvalidRunConfigException):
602+
opts.resolve({"job_priority": "high", "jobPriority": "low"})
603+
604+
def test_runopts_resolve_with_none_valued_aliases(self) -> None:
605+
opts = runopts()
606+
opts.add(
607+
["job_priority", runopt.alias("jobPriority")],
608+
type_=str,
609+
help="priority for the job",
610+
)
611+
opts.add(
612+
["modelTypeName", runopt.alias("model_type_name")],
613+
type_=Union[str, None],
614+
help="ML Hub Model Type to attribute resource utilization for job",
615+
)
616+
resolved_opts = opts.resolve({"model_type_name": None, "jobPriority": "low"})
617+
self.assertEqual(resolved_opts.get("model_type_name"), None)
618+
self.assertEqual(resolved_opts.get("jobPriority"), "low")
619+
self.assertEqual(resolved_opts, {"model_type_name": None, "jobPriority": "low"})
620+
621+
with self.assertRaises(InvalidRunConfigException):
622+
opts.resolve({"model_type_name": None, "modelTypeName": "low"})
623+
581624
def get_runopts(self) -> runopts:
582625
opts = runopts()
583626
opts.add("run_as", type_=str, help="run as user", required=True)

0 commit comments

Comments
 (0)