Skip to content

Commit 386bd17

Browse files
authored
feat: Using args as arguments (#230)
1 parent b35afcd commit 386bd17

File tree

5 files changed

+596
-52
lines changed

5 files changed

+596
-52
lines changed

examples/common/functions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import argparse
12
import logging
23
import os
34
from pathlib import Path
@@ -32,6 +33,15 @@ class ComplexParams(BaseModel):
3233
foo: str
3334

3435

36+
def function_using_argparse(args: argparse.Namespace):
37+
assert args.integer == 1
38+
assert args.floater == 3.14
39+
assert args.stringer == "hello"
40+
assert args.envvar == "from env"
41+
assert args.pydantic_param["x"] == 10
42+
assert args.pydantic_param["foo"] == "bar"
43+
44+
3545
def read_initial_params_as_pydantic(
3646
integer: int,
3747
floater: float,

examples/torch/single_cpu_args.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# single_cpu_train_with_args.py
2+
3+
import argparse # New: for command-line arguments
4+
import time
5+
6+
import torch
7+
import torch.nn as nn
8+
import torch.optim as optim
9+
from torch.utils.data import DataLoader, TensorDataset
10+
11+
12+
def run_single_cpu_training(args: argparse.Namespace):
13+
"""
14+
Runs a simple training loop on a single CPU core.
15+
Accepts parsed arguments for hyperparameters.
16+
"""
17+
print(
18+
f"Parameters: learning_rate={args.learning_rate}, num_epochs={args.num_epochs}, batch_size={args.batch_size}"
19+
)
20+
print("--- Starting Single-CPU Training ---")
21+
print(f"Learning Rate: {args.learning_rate}, Epochs: {args.num_epochs}")
22+
print(f"Batch Size: {args. batch_size}")
23+
24+
# 1. Define a simple model
25+
class SimpleModel(nn.Module):
26+
def __init__(self):
27+
super().__init__()
28+
self.linear = nn.Linear(10, 1) # Input 10 features, output 1
29+
30+
def forward(self, x):
31+
return self.linear(x)
32+
33+
model = SimpleModel()
34+
device = torch.device("cpu") # Explicitly set device to CPU
35+
model.to(device)
36+
37+
# 2. Create a dummy dataset
38+
num_samples = 1000 # Larger dataset to see the difference in speed later
39+
num_features = 10
40+
X = torch.randn(num_samples, num_features)
41+
y = (
42+
torch.sum(X * torch.arange(1, num_features + 1).float(), dim=1, keepdim=True)
43+
+ torch.randn(num_samples, 1) * 0.1
44+
)
45+
46+
dataset = TensorDataset(X, y)
47+
dataloader = DataLoader(
48+
dataset, batch_size=args.batch_size, shuffle=True
49+
) # Use batch_size parameter
50+
51+
# 3. Define optimizer and loss function
52+
optimizer = optim.SGD(
53+
model.parameters(), lr=args.learning_rate
54+
) # Use learning_rate parameter
55+
criterion = nn.MSELoss()
56+
57+
start_time = time.time()
58+
59+
# 4. Training loop
60+
for epoch in range(args.num_epochs): # Use num_epochs parameter
61+
model.train()
62+
total_loss = 0
63+
for batch_idx, (inputs, targets) in enumerate(dataloader):
64+
inputs, targets = inputs.to(device), targets.to(device)
65+
66+
optimizer.zero_grad()
67+
outputs = model(inputs)
68+
loss = criterion(outputs, targets)
69+
loss.backward()
70+
optimizer.step()
71+
total_loss += loss.item()
72+
73+
avg_loss = total_loss / len(dataloader)
74+
print(f"Epoch {epoch+1}/{args.num_epochs}, Loss: {avg_loss:.4f}")
75+
76+
end_time = time.time()
77+
print(f"\nSingle-CPU Training complete in {end_time - start_time:.2f} seconds!")
78+
79+
# Save the model
80+
model_save_path = (
81+
f"single_cpu_model_lr{args.learning_rate}_epochs{args.num_epochs}.pth"
82+
)
83+
torch.save(model.state_dict(), model_save_path)
84+
print(f"Model saved to {model_save_path}")
85+
86+
87+
if __name__ == "__main__":
88+
parser = argparse.ArgumentParser(description="Single-CPU PyTorch Training Example.")
89+
parser.add_argument(
90+
"--learning_rate",
91+
type=float,
92+
default=0.01,
93+
help="Learning rate for the optimizer (default: 0.01)",
94+
)
95+
parser.add_argument(
96+
"--num_epochs",
97+
type=int,
98+
default=50,
99+
help="Number of training epochs (default: 50)",
100+
)
101+
parser.add_argument(
102+
"--batch_size",
103+
type=int,
104+
default=32,
105+
help="Batch size for training (default: 32)",
106+
)
107+
108+
args = parser.parse_args()
109+
run_single_cpu_training(args)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ file-system = "extensions.run_log_store.file_system:FileSystemRunLogstore"
151151

152152
# Release configuration
153153
[tool.semantic_release]
154-
commit_parser = "angular"
154+
commit_parser = "conventional"
155155
major_on_zero = true
156156
allow_zero_version = true
157157
tag_format = "{version}"

runnable/parameters.py

Lines changed: 119 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1+
import argparse
12
import inspect
23
import json
34
import logging
45
import os
56
from typing import Any, Dict, Type
67

7-
import pydantic
88
from pydantic import BaseModel, ConfigDict
99
from typing_extensions import Callable
1010

@@ -48,15 +48,40 @@ def get_user_set_parameters(remove: bool = False) -> Dict[str, JsonParameter]:
4848
return parameters
4949

5050

51+
def return_json_parameters(params: Dict[str, Any]) -> Dict[str, Any]:
52+
"""
53+
Returns the parameters as a JSON serializable dictionary.
54+
55+
Args:
56+
params (dict): The parameters to serialize.
57+
58+
Returns:
59+
dict: The JSON serializable dictionary.
60+
"""
61+
return_params = {}
62+
for key, value in params.items():
63+
if isinstance(value, ObjectParameter):
64+
continue
65+
66+
return_params[key] = value.get_value()
67+
return return_params
68+
69+
5170
def filter_arguments_for_func(
5271
func: Callable[..., Any],
5372
params: Dict[str, Any],
5473
map_variable: MapVariableType = None,
5574
) -> Dict[str, Any]:
5675
"""
5776
Inspects the function to be called as part of the pipeline to find the arguments of the function.
58-
Matches the function arguments to the parameters available either by command line or by up stream steps.
77+
Matches the function arguments to the parameters available either by static parameters or by up stream steps.
5978
79+
The function "func" signature could be:
80+
- def my_function(arg1: int, arg2: str, arg3: float):
81+
- def my_function(arg1: int, arg2: str, arg3: float, **kwargs):
82+
in this case, we would need to send in remaining keyword arguments as a dictionary.
83+
- def my_function(arg1: int, arg2: str, arg3: float, args: argparse.Namespace):
84+
In this case, we need to send the rest of the parameters as attributes of the args object.
6085
6186
Args:
6287
func (Callable): The function to inspect
@@ -72,63 +97,109 @@ def filter_arguments_for_func(
7297
params[key] = JsonParameter(kind="json", value=v)
7398

7499
bound_args = {}
75-
unassigned_params = set(params.keys())
76-
# Check if VAR_KEYWORD is used, it is we send back everything
77-
# If **kwargs is present in the function signature, we send back everything
78-
for name, value in function_args.items():
79-
if value.kind != inspect.Parameter.VAR_KEYWORD:
80-
continue
81-
# Found VAR_KEYWORD, we send back everything as found
82-
for key, value in params.items():
83-
bound_args[key] = params[key].get_value()
100+
missing_required_args: list[str] = []
101+
var_keyword_param = None
102+
namespace_param = None
84103

85-
return bound_args
86-
87-
# Lets return what is asked for then!!
104+
# First pass: Handle regular parameters and identify special parameters
88105
for name, value in function_args.items():
89106
# Ignore any *args
90107
if value.kind == inspect.Parameter.VAR_POSITIONAL:
91108
logger.warning(f"Ignoring parameter {name} as it is VAR_POSITIONAL")
92109
continue
93110

94-
if name not in params:
95-
# No parameter of this name was provided
96-
if value.default == inspect.Parameter.empty:
97-
# No default value is given in the function signature. error as parameter is required.
98-
raise ValueError(
99-
f"Parameter {name} is required for {func.__name__} but not provided"
100-
)
101-
# default value is given in the function signature, nothing further to do.
111+
# Check for **kwargs parameter
112+
if value.kind == inspect.Parameter.VAR_KEYWORD:
113+
var_keyword_param = name
102114
continue
103115

104-
param_value = params[name]
105-
106-
if type(value.annotation) in [
107-
BaseModel,
108-
pydantic._internal._model_construction.ModelMetaclass,
109-
] and not isinstance(param_value, ObjectParameter):
110-
# Even if the annotation is a pydantic model, it can be passed as an object parameter
111-
# We try to cast it as a pydantic model if asked
112-
named_param = params[name].get_value()
113-
114-
if not isinstance(named_param, dict):
115-
# A case where the parameter is a one attribute model
116-
named_param = {name: named_param}
117-
118-
bound_model = bind_args_for_pydantic_model(named_param, value.annotation)
119-
bound_args[name] = bound_model
116+
# Check for argparse.Namespace parameter
117+
if value.annotation == argparse.Namespace:
118+
namespace_param = name
119+
continue
120120

121-
elif value.annotation in [str, int, float, bool]:
122-
# Cast it if its a primitive type. Ensure the type matches the annotation.
123-
bound_args[name] = value.annotation(params[name].get_value())
121+
# Handle regular parameters
122+
if name not in params:
123+
if value.default != inspect.Parameter.empty:
124+
# Default value is given in the function signature, we can use it
125+
bound_args[name] = value.default
126+
else:
127+
# This is a required parameter that's missing
128+
missing_required_args.append(name)
124129
else:
125-
bound_args[name] = params[name].get_value()
126-
127-
unassigned_params.remove(name)
128-
129-
params = {
130-
key: params[key] for key in unassigned_params
131-
} # remove keys from params if they are assigned
130+
# We have a parameter of this name, lets bind it
131+
param_value = params[name]
132+
133+
if (
134+
inspect.isclass(value.annotation)
135+
and issubclass(value.annotation, BaseModel)
136+
) and not isinstance(param_value, ObjectParameter):
137+
# Even if the annotation is a pydantic model, it can be passed as an object parameter
138+
# We try to cast it as a pydantic model if asked
139+
named_param = params[name].get_value()
140+
141+
if not isinstance(named_param, dict):
142+
# A case where the parameter is a one attribute model
143+
named_param = {name: named_param}
144+
145+
bound_model = bind_args_for_pydantic_model(
146+
named_param, value.annotation
147+
)
148+
bound_args[name] = bound_model
149+
150+
elif value.annotation in [str, int, float, bool] and callable(
151+
value.annotation
152+
):
153+
# Cast it if its a primitive type. Ensure the type matches the annotation.
154+
try:
155+
bound_args[name] = value.annotation(params[name].get_value())
156+
except (ValueError, TypeError) as e:
157+
raise ValueError(
158+
f"Cannot cast parameter '{name}' to {value.annotation.__name__}: {e}"
159+
)
160+
else:
161+
# We do not know type of parameter, we send the value as found
162+
bound_args[name] = params[name].get_value()
163+
164+
# Find extra parameters (parameters in params but not consumed by regular function parameters)
165+
consumed_param_names = set(bound_args.keys()) | set(missing_required_args)
166+
extra_params = {k: v for k, v in params.items() if k not in consumed_param_names}
167+
168+
# Second pass: Handle **kwargs and argparse.Namespace parameters
169+
if var_keyword_param is not None:
170+
# Function accepts **kwargs - add all extra parameters directly to bound_args
171+
for param_name, param_value in extra_params.items():
172+
bound_args[param_name] = param_value.get_value()
173+
elif namespace_param is not None:
174+
# Function accepts argparse.Namespace - create namespace with extra parameters
175+
args_namespace = argparse.Namespace()
176+
for param_name, param_value in extra_params.items():
177+
setattr(args_namespace, param_name, param_value.get_value())
178+
bound_args[namespace_param] = args_namespace
179+
elif extra_params:
180+
# Function doesn't accept **kwargs or namespace, but we have extra parameters
181+
# This should only be an error if we also have missing required parameters
182+
# or if the function truly can't handle the extra parameters
183+
if missing_required_args:
184+
# We have both missing required and extra parameters - this is an error
185+
raise ValueError(
186+
f"Function {func.__name__} has parameters {missing_required_args} that are not present in the parameters"
187+
)
188+
# If we only have extra parameters and no missing required ones, we just ignore the extras
189+
# This allows for more flexible parameter passing
190+
191+
# Check for missing required parameters
192+
if missing_required_args:
193+
if var_keyword_param is None and namespace_param is None:
194+
# No way to handle missing parameters
195+
raise ValueError(
196+
f"Function {func.__name__} has parameters {missing_required_args} that are not present in the parameters"
197+
)
198+
# If we have **kwargs or namespace, missing parameters might be handled there
199+
# But if they're truly required (no default), we should still error
200+
raise ValueError(
201+
f"Function {func.__name__} has parameters {missing_required_args} that are not present in the parameters"
202+
)
132203

133204
return bound_args
134205

0 commit comments

Comments
 (0)