Skip to content

Commit 9238a04

Browse files
authored
Merge branch 'development' into reg_cocktails
2 parents 33b2223 + 40a3987 commit 9238a04

File tree

1 file changed

+76
-73
lines changed

1 file changed

+76
-73
lines changed

examples/40_advanced/example_custom_configuration_space.py

Lines changed: 76 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
66
The following example shows how adjust the configuration space of
77
the search. Currently, there are two changes that can be made to the space:-
8+
89
1. Adjust individual hyperparameters in the pipeline
910
2. Include or exclude components:
1011
a) include: Dictionary containing components to include. Key is the node
@@ -66,76 +67,78 @@ def get_search_space_updates():
6667
return updates
6768

6869

69-
if __name__ == '__main__':
70-
71-
############################################################################
72-
# Data Loading
73-
# ============
74-
X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=True)
75-
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
76-
X,
77-
y,
78-
random_state=1,
79-
)
80-
81-
############################################################################
82-
# Build and fit a classifier with include components
83-
# ==================================================
84-
api = TabularClassificationTask(
85-
search_space_updates=get_search_space_updates(),
86-
include_components={'network_backbone': ['ResNetBackbone'],
87-
'encoder': ['OneHotEncoder']}
88-
)
89-
90-
############################################################################
91-
# Search for an ensemble of machine learning algorithms
92-
# =====================================================
93-
api.search(
94-
X_train=X_train.copy(),
95-
y_train=y_train.copy(),
96-
X_test=X_test.copy(),
97-
y_test=y_test.copy(),
98-
optimize_metric='accuracy',
99-
total_walltime_limit=300,
100-
func_eval_time_limit_secs=50
101-
)
102-
103-
############################################################################
104-
# Print the final ensemble performance
105-
# ====================================
106-
print(api.run_history, api.trajectory)
107-
y_pred = api.predict(X_test)
108-
score = api.score(y_pred, y_test)
109-
print(score)
110-
print(api.show_models())
111-
112-
############################################################################
113-
# Build and fit a classifier with exclude components
114-
# ==================================================
115-
api = TabularClassificationTask(
116-
search_space_updates=get_search_space_updates(),
117-
exclude_components={'network_backbone': ['MLPBackbone'],
118-
'encoder': ['OneHotEncoder']}
119-
)
120-
121-
############################################################################
122-
# Search for an ensemble of machine learning algorithms
123-
# =====================================================
124-
api.search(
125-
X_train=X_train,
126-
y_train=y_train,
127-
X_test=X_test.copy(),
128-
y_test=y_test.copy(),
129-
optimize_metric='accuracy',
130-
total_walltime_limit=300,
131-
func_eval_time_limit_secs=50
132-
)
133-
134-
############################################################################
135-
# Print the final ensemble performance
136-
# ====================================
137-
print(api.run_history, api.trajectory)
138-
y_pred = api.predict(X_test)
139-
score = api.score(y_pred, y_test)
140-
print(score)
141-
print(api.show_models())
70+
############################################################################
71+
# Data Loading
72+
# ============
73+
X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=True)
74+
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
75+
X,
76+
y,
77+
random_state=1,
78+
)
79+
80+
############################################################################
81+
# Build and fit a classifier with include components
82+
# ==================================================
83+
api = TabularClassificationTask(
84+
search_space_updates=get_search_space_updates(),
85+
include_components={'network_backbone': ['MLPBackbone', 'ResNetBackbone'],
86+
'encoder': ['OneHotEncoder']}
87+
)
88+
89+
############################################################################
90+
# Search for an ensemble of machine learning algorithms
91+
# =====================================================
92+
api.search(
93+
X_train=X_train.copy(),
94+
y_train=y_train.copy(),
95+
X_test=X_test.copy(),
96+
y_test=y_test.copy(),
97+
optimize_metric='accuracy',
98+
total_walltime_limit=150,
99+
func_eval_time_limit_secs=30
100+
)
101+
102+
############################################################################
103+
# Print the final ensemble performance
104+
# ====================================
105+
y_pred = api.predict(X_test)
106+
score = api.score(y_pred, y_test)
107+
print(score)
108+
print(api.show_models())
109+
110+
# Print statistics from search
111+
print(api.sprint_statistics())
112+
113+
############################################################################
114+
# Build and fit a classifier with exclude components
115+
# ==================================================
116+
api = TabularClassificationTask(
117+
search_space_updates=get_search_space_updates(),
118+
exclude_components={'network_backbone': ['MLPBackbone'],
119+
'encoder': ['OneHotEncoder']}
120+
)
121+
122+
############################################################################
123+
# Search for an ensemble of machine learning algorithms
124+
# =====================================================
125+
api.search(
126+
X_train=X_train,
127+
y_train=y_train,
128+
X_test=X_test.copy(),
129+
y_test=y_test.copy(),
130+
optimize_metric='accuracy',
131+
total_walltime_limit=150,
132+
func_eval_time_limit_secs=30
133+
)
134+
135+
############################################################################
136+
# Print the final ensemble performance
137+
# ====================================
138+
y_pred = api.predict(X_test)
139+
score = api.score(y_pred, y_test)
140+
print(score)
141+
print(api.show_models())
142+
143+
# Print statistics from search
144+
print(api.sprint_statistics())

0 commit comments

Comments
 (0)