Skip to content

Commit d4bcff2

Browse files
AbishekSfacebook-github-bot
authored andcommitted
Allow aliases for run_opt (meta-pytorch#1141)
Summary: Lets allow aliases for a runopt. This will give downstream users to have multiple ways of accessing the same runopt. * Introduce new class for `RunOptAlias` which is used to expand on adding aliases to a runopt. * Add a new dict to maintain alias to key values that can be used by `opt.get(name)` * Modify add() to accept list as well, build out the aliases list and modify the previously created dict to fill in alias to primary_key values. * Modify resolve() to check if a different alias is already used in cfg i.e if the "jobPriority" and "job_priority" are aliases for the same one, we don't allow for both to be present in the cfg. * Modify get to look at the alias to primary_key dict as well. Differential Revision: D84157870
1 parent b376e8c commit d4bcff2

File tree

2 files changed

+88
-4
lines changed

2 files changed

+88
-4
lines changed

torchx/specs/api.py

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -886,6 +886,10 @@ def get_type_name(tp: Type[CfgVal]) -> str:
886886
return str(tp)
887887

888888

889+
class RunOptAlias(str):
890+
pass
891+
892+
889893
@dataclass
890894
class runopt:
891895
"""
@@ -896,6 +900,7 @@ class runopt:
896900
opt_type: Type[CfgVal]
897901
is_required: bool
898902
help: str
903+
aliases: list[RunOptAlias] | None = None
899904

900905
@property
901906
def is_type_list_of_str(self) -> bool:
@@ -987,6 +992,7 @@ class runopts:
987992

988993
def __init__(self) -> None:
989994
self._opts: Dict[str, runopt] = {}
995+
self._alias_to_key: Dict[RunOptAlias, str] = {}
990996

991997
def __iter__(self) -> Iterator[Tuple[str, runopt]]:
992998
return self._opts.items().__iter__()
@@ -1016,7 +1022,11 @@ def get(self, name: str) -> Optional[runopt]:
10161022
"""
10171023
Returns option if any was registered, or None otherwise
10181024
"""
1019-
return self._opts.get(name, None)
1025+
if name in self._opts:
1026+
return self._opts[name]
1027+
if name in self._alias_to_key:
1028+
return self._opts[self._alias_to_key[name]]
1029+
return None
10201030

10211031
def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:
10221032
"""
@@ -1031,6 +1041,24 @@ def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:
10311041

10321042
for cfg_key, runopt in self._opts.items():
10331043
val = resolved_cfg.get(cfg_key)
1044+
resolved_name = None
1045+
aliases = runopt.aliases or []
1046+
if val is None:
1047+
for alias in aliases:
1048+
val = resolved_cfg.get(alias)
1049+
if val is not None:
1050+
resolved_name = alias
1051+
break
1052+
else:
1053+
resolved_name = cfg_key
1054+
for alias in aliases:
1055+
duplicate_val = resolved_cfg.get(alias)
1056+
if duplicate_val is not None:
1057+
raise InvalidRunConfigException(
1058+
f"Run option: {resolved_name}, is an alias of another run option already used in the cfg.",
1059+
resolved_name,
1060+
cfg,
1061+
)
10341062

10351063
# check required opt
10361064
if runopt.is_required and val is None:
@@ -1143,9 +1171,38 @@ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]:
11431171
cfg[key] = val
11441172
return cfg
11451173

1174+
def _get_primary_key_and_aliases(
1175+
self,
1176+
cfg_key: list[str] | str,
1177+
) -> tuple[str, list[RunOptAlias]]:
1178+
"""
1179+
Returns the primary key and aliases for the given cfg_key.
1180+
"""
1181+
if isinstance(cfg_key, str):
1182+
return cfg_key, []
1183+
1184+
if len(cfg_key) == 0:
1185+
raise ValueError("cfg_key must be a non-empty list")
1186+
primary_key = None
1187+
aliases = list[RunOptAlias]()
1188+
for name in cfg_key:
1189+
if isinstance(name, RunOptAlias):
1190+
aliases.append(name)
1191+
else:
1192+
if primary_key is not None:
1193+
raise ValueError(
1194+
f"cfg_key must contain a single primary key. Given more than one primary keys: {primary_key}, {name}. If one of them is an alias please use RunOptAlias type instead of str. "
1195+
)
1196+
primary_key = name
1197+
if primary_key is None or primary_key == "":
1198+
raise ValueError(
1199+
"Missing cfg_key. Please provide one other than the aliases."
1200+
)
1201+
return primary_key, aliases
1202+
11461203
def add(
11471204
self,
1148-
cfg_key: str,
1205+
cfg_key: str | list[str],
11491206
type_: Type[CfgVal],
11501207
help: str,
11511208
default: CfgVal = None,
@@ -1156,6 +1213,7 @@ def add(
11561213
value (if any). If the ``default`` is not specified then this option
11571214
is a required option.
11581215
"""
1216+
primary_key, aliases = self._get_primary_key_and_aliases(cfg_key)
11591217
if required and default is not None:
11601218
raise ValueError(
11611219
f"Required option: {cfg_key} must not specify default value. Given: {default}"
@@ -1166,8 +1224,10 @@ def add(
11661224
f"Option: {cfg_key}, must be of type: {type_}."
11671225
f" Given: {default} ({type(default).__name__})"
11681226
)
1169-
1170-
self._opts[cfg_key] = runopt(default, type_, required, help)
1227+
opt = runopt(default, type_, required, help, aliases)
1228+
for alias in aliases:
1229+
self._alias_to_key[alias] = primary_key
1230+
self._opts[primary_key] = opt
11711231

11721232
def update(self, other: "runopts") -> None:
11731233
self._opts.update(other._opts)

torchx/specs/test/api_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
Role,
4242
RoleStatus,
4343
runopt,
44+
RunOptAlias,
4445
runopts,
4546
TORCHX_HOME,
4647
Workspace,
@@ -578,6 +579,29 @@ def test_runopts_add(self) -> None:
578579
# this print is intentional (demonstrates the intended usecase)
579580
print(opts)
580581

582+
def test_runopts_add_with_aliases(self) -> None:
583+
opts = runopts()
584+
opts.add(
585+
["job_priority", RunOptAlias("jobPriority")],
586+
type_=str,
587+
help="priority for the job",
588+
)
589+
self.assertEqual(1, len(opts._opts))
590+
self.assertIsNotNone(opts.get("job_priority"))
591+
self.assertIsNotNone(opts.get("jobPriority"))
592+
593+
def test_runopts_resolve_with_aliases(self) -> None:
594+
opts = runopts()
595+
opts.add(
596+
["job_priority", RunOptAlias("jobPriority")],
597+
type_=str,
598+
help="priority for the job",
599+
)
600+
opts.resolve({"job_priority": "high"})
601+
opts.resolve({"jobPriority": "low"})
602+
with self.assertRaises(InvalidRunConfigException):
603+
opts.resolve({"job_priority": "high", "jobPriority": "low"})
604+
581605
def get_runopts(self) -> runopts:
582606
opts = runopts()
583607
opts.add("run_as", type_=str, help="run as user", required=True)

0 commit comments

Comments
 (0)