-
Notifications
You must be signed in to change notification settings - Fork 373
/
Copy pathpolicy.py
125 lines (97 loc) · 4.44 KB
/
policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import random
from abc import abstractmethod
import numpy as np
class Policy:
    """Base class for cutting-stock placement policies.

    Subclasses implement ``get_action``. The helpers below operate on a stock
    given as a two-axis NumPy array in which cells equal to ``-2`` are treated
    as lying outside the usable stock and cells equal to ``-1`` as free.

    NOTE: ``Policy`` does not use ``abc.ABCMeta``/``abc.ABC``, so the
    ``@abstractmethod`` markers below are documentation only and are not
    enforced when a (sub)class is instantiated.
    """

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def get_action(self, observation, info):
        """Return the next action for ``observation``; implemented by subclasses."""
        pass

    def _get_stock_size_(self, stock):
        """Return the usable ``(width, height)`` of ``stock``.

        ``width`` is the number of indices along axis 0 whose row contains at
        least one cell other than ``-2``; ``height`` is the same count along
        axis 1. Both are NumPy integers.
        """
        stock_w = np.sum(np.any(stock != -2, axis=1))
        stock_h = np.sum(np.any(stock != -2, axis=0))
        return stock_w, stock_h

    def _can_place_(self, stock, position, prod_size):
        """Return whether a ``prod_size`` product can go at ``position``.

        Args:
            stock: two-axis array of cell codes (``-1`` means free).
            position: ``(x, y)`` top-left index of the product (axis 0, axis 1).
            prod_size: ``(w, h)`` extent along axis 0 and axis 1.

        Returns:
            True iff the rectangle ``stock[x:x+w, y:y+h]`` lies entirely inside
            the array and every cell in it equals ``-1``.
        """
        pos_x, pos_y = position
        prod_w, prod_h = prod_size
        # NumPy basic slicing silently truncates at the array edge, and a
        # negative start wraps around. Either way a rectangle that does not
        # actually fit could look "all free", so reject those cases up front.
        if pos_x < 0 or pos_y < 0:
            return False
        if pos_x + prod_w > stock.shape[0] or pos_y + prod_h > stock.shape[1]:
            return False
        return np.all(stock[pos_x : pos_x + prod_w, pos_y : pos_y + prod_h] == -1)
class RandomPolicy(Policy):
    """Place the first in-demand product at a randomly sampled stock/position.

    Products are considered in list order, skipping those with
    ``quantity <= 0``. For each one, up to 100 random ``(stock, position)``
    samples are drawn. Each sample tries the product as given and then
    rotated 90 degrees. The first product for which a valid placement is
    sampled determines the action.
    """

    def __init__(self):
        pass

    def get_action(self, observation, info):
        """Return a randomly sampled placement action.

        Args:
            observation: mapping with ``"products"`` (iterable of mappings
                holding ``"quantity"`` and a two-element ``"size"``) and
                ``"stocks"`` (indexable sequence of stock arrays, as consumed
                by ``_get_stock_size_`` / ``_can_place_``).
            info: unused.

        Returns:
            dict with ``"stock_idx"``, ``"size"`` (reversed if the product was
            placed rotated) and ``"position"``. The placement is only
            reported if ``_can_place_`` accepted it.

            If no product could be placed, ``"position"`` is ``(None, None)``.
            In that case ``"stock_idx"`` / ``"size"`` are the last sampled stock
            index and the last attempted product's size. If no product has
            positive quantity, the result is ``stock_idx=-1``,
            ``size=[0, 0]``, ``position=(0, 0)``.
        """
        list_prods = observation["products"]
        stocks = observation["stocks"]

        prod_size = [0, 0]
        stock_idx = -1
        pos_x, pos_y = 0, 0

        for prod in list_prods:
            if prod["quantity"] <= 0:
                continue
            prod_size = prod["size"]
            pos_x, pos_y = None, None

            for _ in range(100):
                stock_idx = random.randint(0, len(stocks) - 1)
                stock = stocks[stock_idx]
                stock_w, stock_h = self._get_stock_size_(stock)
                prod_w, prod_h = prod_size

                # Candidates go into x/y locals; pos_x/pos_y are only set once
                # _can_place_ accepts a spot, so a failed final sample can
                # never leak out as the returned position.
                if stock_w >= prod_w and stock_h >= prod_h:
                    x = random.randint(0, stock_w - prod_w)
                    y = random.randint(0, stock_h - prod_h)
                    if self._can_place_(stock, (x, y), prod_size):
                        pos_x, pos_y = x, y
                        break

                # Same attempt with the product rotated by 90 degrees.
                if stock_w >= prod_h and stock_h >= prod_w:
                    x = random.randint(0, stock_w - prod_h)
                    y = random.randint(0, stock_h - prod_w)
                    if self._can_place_(stock, (x, y), prod_size[::-1]):
                        prod_size = prod_size[::-1]
                        pos_x, pos_y = x, y
                        break

            if pos_x is not None and pos_y is not None:
                break

        return {"stock_idx": stock_idx, "size": prod_size, "position": (pos_x, pos_y)}
class GreedyPolicy(Policy):
    """First-fit placement policy.

    Products are considered in list order, skipping those with
    ``quantity <= 0``, and stocks in index order. On each stock the product
    is tried as given and then rotated 90 degrees. Within one orientation,
    positions are scanned with ``x`` (axis 0) in the outer loop and ``y``
    (axis 1) in the inner loop. The first free spot found is returned.
    """

    def __init__(self):
        pass

    def get_action(self, observation, info):
        """Return the first-fit placement action.

        Args:
            observation: mapping with ``"products"`` (iterable of mappings
                holding ``"quantity"`` and a two-element ``"size"``) and
                ``"stocks"`` (iterable of stock arrays, as consumed by
                ``_get_stock_size_`` / ``_can_place_``).
            info: unused.

        Returns:
            dict with ``"stock_idx"``, ``"size"`` (reversed if the product was
            placed rotated) and ``"position"`` ``(x, y)``.

            If no product can be placed, ``stock_idx`` is ``-1``, ``size`` is
            the last attempted product's size and ``position`` is
            ``(None, None)``. If no product has positive quantity, the result
            is ``stock_idx=-1``, ``size=[0, 0]``, ``position=(0, 0)``.
        """
        prod_size = [0, 0]
        # Stays (0, 0) only if no product is ever attempted. Once a product is
        # tried, the fallback becomes (None, None), even if that product fits
        # no stock at all. The old code returned (0, 0) and stopped searching
        # in that case.
        position = (0, 0)

        for prod in observation["products"]:
            if prod["quantity"] <= 0:
                continue
            prod_size = prod["size"]
            position = (None, None)

            prod_w, prod_h = prod_size
            # A square product is identical when rotated, so one scan suffices.
            if prod_w == prod_h:
                orientations = (prod_size,)
            else:
                orientations = (prod_size, prod_size[::-1])

            for i, stock in enumerate(observation["stocks"]):
                stock_w, stock_h = self._get_stock_size_(stock)
                for size in orientations:
                    spot = self._first_fit_(stock, stock_w, stock_h, size)
                    if spot is not None:
                        return {"stock_idx": i, "size": size, "position": spot}

        return {"stock_idx": -1, "size": prod_size, "position": position}

    def _first_fit_(self, stock, stock_w, stock_h, size):
        """Return the first ``(x, y)`` where ``size`` fits in ``stock``, else ``None``.

        ``x`` ranges over ``[0, stock_w - w]`` (outer loop) and ``y`` over
        ``[0, stock_h - h]`` (inner loop). Each candidate is checked with
        ``_can_place_``.
        """
        prod_w, prod_h = size
        if stock_w < prod_w or stock_h < prod_h:
            return None
        for x in range(stock_w - prod_w + 1):
            for y in range(stock_h - prod_h + 1):
                if self._can_place_(stock, (x, y), size):
                    return x, y
        return None