diff --git a/agrf/parameters/__init__.py b/agrf/parameters/__init__.py new file mode 100644 index 00000000..0fa8958c --- /dev/null +++ b/agrf/parameters/__init__.py @@ -0,0 +1,58 @@ +import copy + + +class Parameter: + def __init__(self, name, default, enum): + self.name = name + self.default = default + self.enum = enum + + def add(self, g, s): + g.add_int_parameter( + name=s[f"STR_PARAM_{self.name}"], + description=s[f"STR_PARAM_{self.name}_DESC"], + default=self.default, + limits=(min(self.enum.keys()), max(self.enum.keys())), + enum={k: s[f"STR_PARAM_{self.name}_{v}"] for k, v in self.enum.items()}, + ) + + +class ParameterList: + def __init__(self, parameters): + self.parameters = parameters + + def add(self, g, s): + for p in self.parameters: + p.add(g, s) + + def index(self, name): + return [i for i, p in enumerate(self.parameters) if p.name == name][0] + + +class SearchSpace: + def __init__(self, choices, parameter_list): + self.choices = choices + self.parameter_list = parameter_list + + def copy(self): + return SearchSpace(copy.deepcopy(self.choices), self.parameter_list) + + def fix_docs_params(self, cat, options): + [(idx, all_options)] = [ + (i, the_options) for i, (the_cat, the_options) in enumerate(self.choices) if the_cat == cat + ] + assert all(o in all_options for o in options) + self.choices[idx] = (cat, options) + + def iterate_variations(self, i=0, params={}): + if i == len(self.choices): + yield params + else: + for j in self.choices[i][1]: + new_params = params.copy() + new_params[self.choices[i][0]] = j + for variation in self.iterate_variations(i + 1, new_params): + yield variation + + def desc(self, params): + return "".join(str(options.index(params[i])) for i, options in self.choices) diff --git a/industry/lib/parameters.py b/industry/lib/parameters.py index 257fa2ef..416cb362 100644 --- a/industry/lib/parameters.py +++ b/industry/lib/parameters.py @@ -1,32 +1,4 @@ -import copy - - -class Parameter: - def __init__(self, name, default, enum): - self.name = name - self.default = default - self.enum = enum - - def add(self, g, s): - g.add_int_parameter( - name=s[f"STR_PARAM_{self.name}"], - description=s[f"STR_PARAM_{self.name}_DESC"], - default=self.default, - limits=(min(self.enum.keys()), max(self.enum.keys())), - enum={k: s[f"STR_PARAM_{self.name}_{v}"] for k, v in self.enum.items()}, - ) - - -class ParameterList: - def __init__(self, parameters): - self.parameters = parameters - - def add(self, g, s): - for p in self.parameters: - p.add(g, s) - - def index(self, name): - return [i for i, p in enumerate(self.parameters) if p.name == name][0] +from agrf.parameters import Parameter, ParameterList, SearchSpace parameter_list = ParameterList( @@ -177,35 +149,6 @@ def index(self, name): ) -class SearchSpace: - def __init__(self, choices, parameter_list): - self.choices = choices - self.parameter_list = parameter_list - - def copy(self): - return SearchSpace(copy.deepcopy(self.choices), parameter_list) - - def fix_docs_params(self, cat, options): - [(idx, all_options)] = [ - (i, the_options) for i, (the_cat, the_options) in enumerate(self.choices) if the_cat == cat - ] - assert all(o in all_options for o in options) - self.choices[idx] = (cat, options) - - def iterate_variations(self, i=0, params={}): - if i == len(self.choices): - yield params - else: - for j in self.choices[i][1]: - new_params = params.copy() - new_params[self.choices[i][0]] = j - for variation in self.iterate_variations(i + 1, new_params): - yield variation - - def desc(self, params): - return "".join(str(options.index(params[i])) for i, options in self.choices) - - parameter_choices = SearchSpace( [ ("POLICY", ["AUTARKY", "SELF_SUFFICIENT", "FREE_TRADE", "EXPORT"]),