Skip to content

Commit bd5eea5

Browse files
committed
updates for linting
1 parent 5fe237f commit bd5eea5

File tree

4 files changed

+62
-139
lines changed

4 files changed

+62
-139
lines changed

python-xpress-facility-location/input.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"seed": 1,
2+
"seed": 10,
33
"num_parks": 4,
44
"num_schools": 9,
55
"num_sites": 11

python-xpress-facility-location/main.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
import time
1+
2+
import json
23

34
import nextmv
4-
from visuals import draw_sol
55
import numpy as np
6-
import json
6+
from visuals import draw_sol
77

88
try:
99
import xpress as xp
@@ -25,18 +25,29 @@ def main() -> None:
2525
options = nextmv.Options(
2626
nextmv.Option("input", str, "", "Path to input file. Default is stdin.", False),
2727
nextmv.Option("output", str, "", "Path to output file. Default is stdout.", False),
28-
nextmv.Option("duration", int, 30, "Max runtime duration (in seconds).", False),
28+
nextmv.Option("objective", str, "average_distance", "minimizes for average_distance, total_distance, or max_distance", False),
29+
nextmv.Option("parks_override", int, 4, "number of parks to build", False),
2930
)
3031

3132
input = nextmv.load(options=options, path=options.input)
33+
if options.objective not in ["average_distance", "total_distance", "max_distance"]:
34+
raise ValueError("Invalid objective. Must be either 'average_distance', 'total_distance', or 'max_distance'.")
3235

3336
nextmv.log("Solving facility location problem:")
34-
nextmv.log(f" - schools: {input.data.get('num_schools', [])}")
35-
nextmv.log(f" - sites: {input.data.get('num_sites', 0)}")
36-
nextmv.log(f" - parks: {input.data.get('num_parks', 0)}")
37+
nextmv.log(f" - objective: {options.objective}")
38+
nextmv.log(f" - parks_override: {options.parks_override}")
39+
nextmv.log(f" - schools: {input.data.get('num_schools')}")
40+
nextmv.log(f" - sites: {input.data.get('num_sites')}")
41+
nextmv.log(f" - parks: {input.data.get('num_parks')}")
42+
43+
np.random.seed(input.data.get('seed'))
3744

3845
SCHOOLS = range(input.data.get('num_schools')) # set of schools
3946
SITES = range(input.data.get('num_sites')) # set of candidate sites
47+
if options.parks_override is not None:
48+
num_parks = options.parks_override
49+
else:
50+
num_parks = input.data.get('num_parks')
4051

4152
coord_schools = 10 * np.random.random((input.data.get('num_schools'), 2)) # x-y coordinates between 0 and 10 (in km)
4253
coord_sites = 10 * np.random.random((input.data.get('num_sites'), 2))
@@ -51,13 +62,20 @@ def main() -> None:
5162
build = prob.addVariables(SITES, vartype=xp.binary)
5263

5364
# Objective function and constraints
54-
prob.setObjective(xp.Sum(dist[i,j] * serves[i,j] for i in SCHOOLS for j in SITES))
65+
if options.objective == "average_distance":
66+
prob.setObjective(xp.Sum(dist[i,j] * serves[i,j] for i in SCHOOLS for j in SITES) / input.data.get('num_schools'))
67+
elif options.objective == "total_distance":
68+
prob.setObjective(xp.Sum(dist[i,j] * serves[i,j] for i in SCHOOLS for j in SITES))
69+
elif options.objective == "max_distance":
70+
z = prob.addVariable() # add auxiliary variable to the problem
71+
prob.addConstraint(z >= xp.Sum(dist[i,j] * serves[i,j] for j in SITES) for i in SCHOOLS)
72+
prob.setObjective(z) # replaces the old objective function
5573

5674
# Every school must be served by one park
5775
prob.addConstraint(xp.Sum(serves[i,j] for j in SITES) == 1 for i in SCHOOLS)
5876

5977
# Exactly n parks are built:
60-
prob.addConstraint(xp.Sum(build[j] for j in SITES) == input.data.get('num_parks'))
78+
prob.addConstraint(xp.Sum(build[j] for j in SITES) == num_parks)
6179

6280
# Only parks that are built can serve schools
6381
prob.addConstraint(xp.Sum(serves[i,j] for i in SCHOOLS) <= input.data.get('num_schools') * build[j] for j in SITES)
@@ -71,15 +89,30 @@ def main() -> None:
7189

7290
input.options.provider = "xpress"
7391

74-
input_charts = draw_sol(n=input.data.get('num_schools'),m=input.data.get('num_sites'), label="Input Chart", coord_schools=coord_schools, coord_sites=coord_sites, SCHOOLS=SCHOOLS, SITES=SITES)
75-
output_charts = draw_sol(input.data.get('num_schools'),input.data.get('num_sites'),prob,serves,build, "Output Chart", coord_schools=coord_schools, coord_sites=coord_sites, SCHOOLS=SCHOOLS, SITES=SITES)
92+
input_charts = draw_sol(n=input.data.get('num_schools'),
93+
m=input.data.get('num_sites'),
94+
label="Input Chart",
95+
coord_schools=coord_schools,
96+
coord_sites=coord_sites,
97+
SCHOOLS=SCHOOLS,
98+
SITES=SITES)
99+
output_charts = draw_sol(input.data.get('num_schools'),
100+
input.data.get('num_sites'),
101+
prob,
102+
serves,
103+
build,
104+
label="Output Chart",
105+
coord_schools=coord_schools,
106+
coord_sites=coord_sites,
107+
SCHOOLS=SCHOOLS,
108+
SITES=SITES)
76109

77110
output = nextmv.Output(
78111
solution={"solution": solution},
79112
statistics={"result": {"value": value}, "schema": "v1"},
80113
assets=[input_charts, output_charts]
81114
)
82-
115+
83116
nextmv.write(output, path=options.output)
84117

85118

python-xpress-facility-location/problem.lp

Lines changed: 0 additions & 114 deletions
This file was deleted.

python-xpress-facility-location/visuals.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1-
import plotly.graph_objects as go
21
import json
2+
33
import nextmv
4+
import plotly.graph_objects as go
5+
46

5-
def draw_sol(n, m, prob=None, x=None, y=None, label="Chart", coord_schools=None, coord_sites=None, SCHOOLS=None, SITES=None) -> nextmv.Asset:
6-
V = [i for i in range(n + m)]
7+
def draw_sol(n, m, prob=None, x=None, y=None,
8+
label="Chart", coord_schools=None, coord_sites=None,
9+
SCHOOLS=None, SITES=None) -> nextmv.Asset:
10+
V = list(range(n + m))
711
E = []
812

913
# Get coordinates
@@ -18,7 +22,7 @@ def draw_sol(n, m, prob=None, x=None, y=None, label="Chart", coord_schools=None,
1822
E = [(i, n + j) for i in SCHOOLS for j in SITES if xsol[i, j] > 0.5]
1923

2024
# Node colors
21-
node_colS = {i: '#5555ff' for i in SCHOOLS}
25+
node_colS = dict.fromkeys(SCHOOLS, '#5555ff')
2226
node_colA1 = {n + j: '#ff5555' for j in SITES if y and ysol[j] > 0.5}
2327
node_colA0 = {n + j: '#a0a0a0' for j in SITES if not y or ysol[j] < 0.5}
2428
node_col = {**node_colS, **node_colA1, **node_colA0}
@@ -28,10 +32,10 @@ def draw_sol(n, m, prob=None, x=None, y=None, label="Chart", coord_schools=None,
2832
x=[coord[i][0] for i in V],
2933
y=[coord[i][1] for i in V],
3034
mode='markers+text',
31-
marker=dict(
32-
size=10,
33-
color=[node_col[i] for i in V]
34-
),
35+
marker={
36+
"size": 10,
37+
"color": [node_col[i] for i in V]
38+
},
3539
text=[str(i) for i in V],
3640
textposition="top center",
3741
hoverinfo='text'
@@ -46,7 +50,7 @@ def draw_sol(n, m, prob=None, x=None, y=None, label="Chart", coord_schools=None,
4650
x=[x0, x1],
4751
y=[y0, y1],
4852
mode='lines',
49-
line=dict(width=1, color='#888'),
53+
line={"width": 1, "color": '#888'},
5054
hoverinfo='none'
5155
)
5256
edge_traces.append(edge_trace)
@@ -61,9 +65,9 @@ def draw_sol(n, m, prob=None, x=None, y=None, label="Chart", coord_schools=None,
6165
title='School-Area Assignment Network',
6266
showlegend=False,
6367
hovermode='closest',
64-
margin=dict(b=20, l=5, r=5, t=40),
65-
xaxis=dict(showgrid=False, zeroline=False),
66-
yaxis=dict(showgrid=False, zeroline=False),
68+
margin={"b": 20, "l": 5, "r": 5, "t": 40},
69+
xaxis={"showgrid": False, "zeroline": False},
70+
yaxis={"showgrid": False, "zeroline": False},
6771
width=700,
6872
height=700
6973
)

0 commit comments

Comments
 (0)