Skip to content

Commit

Permalink
(Aegis) move parameter-list to agrf
Browse files Browse the repository at this point in the history
  • Loading branch information
ahyangyi committed Oct 27, 2023
1 parent 469aedb commit db1a09a
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 58 deletions.
58 changes: 58 additions & 0 deletions agrf/parameters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import copy


class Parameter:
def __init__(self, name, default, enum):
self.name = name
self.default = default
self.enum = enum

def add(self, g, s):
g.add_int_parameter(
name=s[f"STR_PARAM_{self.name}"],
description=s[f"STR_PARAM_{self.name}_DESC"],
default=self.default,
limits=(min(self.enum.keys()), max(self.enum.keys())),
enum={k: s[f"STR_PARAM_{self.name}_{v}"] for k, v in self.enum.items()},
)


class ParameterList:
def __init__(self, parameters):
self.parameters = parameters

def add(self, g, s):
for p in self.parameters:
p.add(g, s)

def index(self, name):
return [i for i, p in enumerate(self.parameters) if p.name == name][0]


class SearchSpace:
def __init__(self, choices, parameter_list):
self.choices = choices
self.parameter_list = parameter_list

def copy(self):
return SearchSpace(copy.deepcopy(self.choices), self.parameter_list)

def fix_docs_params(self, cat, options):
[(idx, all_options)] = [
(i, the_options) for i, (the_cat, the_options) in enumerate(self.choices) if the_cat == cat
]
assert all(o in all_options for o in options)
self.choices[idx] = (cat, options)

def iterate_variations(self, i=0, params={}):
if i == len(self.choices):
yield params
else:
for j in self.choices[i][1]:
new_params = params.copy()
new_params[self.choices[i][0]] = j
for variation in self.iterate_variations(i + 1, new_params):
yield variation

def desc(self, params):
return "".join(str(options.index(params[i])) for i, options in self.choices)
59 changes: 1 addition & 58 deletions industry/lib/parameters.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,4 @@
import copy


class Parameter:
def __init__(self, name, default, enum):
self.name = name
self.default = default
self.enum = enum

def add(self, g, s):
g.add_int_parameter(
name=s[f"STR_PARAM_{self.name}"],
description=s[f"STR_PARAM_{self.name}_DESC"],
default=self.default,
limits=(min(self.enum.keys()), max(self.enum.keys())),
enum={k: s[f"STR_PARAM_{self.name}_{v}"] for k, v in self.enum.items()},
)


class ParameterList:
def __init__(self, parameters):
self.parameters = parameters

def add(self, g, s):
for p in self.parameters:
p.add(g, s)

def index(self, name):
return [i for i, p in enumerate(self.parameters) if p.name == name][0]
from agrf.parameters import Parameter, ParameterList, SearchSpace


parameter_list = ParameterList(
Expand Down Expand Up @@ -177,35 +149,6 @@ def index(self, name):
)


class SearchSpace:
def __init__(self, choices, parameter_list):
self.choices = choices
self.parameter_list = parameter_list

def copy(self):
return SearchSpace(copy.deepcopy(self.choices), parameter_list)

def fix_docs_params(self, cat, options):
[(idx, all_options)] = [
(i, the_options) for i, (the_cat, the_options) in enumerate(self.choices) if the_cat == cat
]
assert all(o in all_options for o in options)
self.choices[idx] = (cat, options)

def iterate_variations(self, i=0, params={}):
if i == len(self.choices):
yield params
else:
for j in self.choices[i][1]:
new_params = params.copy()
new_params[self.choices[i][0]] = j
for variation in self.iterate_variations(i + 1, new_params):
yield variation

def desc(self, params):
return "".join(str(options.index(params[i])) for i, options in self.choices)


parameter_choices = SearchSpace(
[
("POLICY", ["AUTARKY", "SELF_SUFFICIENT", "FREE_TRADE", "EXPORT"]),
Expand Down

0 comments on commit db1a09a

Please sign in to comment.