Commit 81eb6f3
add rl code
1 parent c3257c7
File tree: 13 files changed (+1709 / -88 lines)

13 files changed

+1709
-88
lines changed
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+__pycache__

docs/reinforcement_learning/code/algo/multi_armed_bandits.ipynb

Lines changed: 508 additions & 0 deletions
Large diffs are not rendered by default.
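
GitHub does not render this notebook's 508-line diff inline. For orientation, here is a minimal epsilon-greedy loop of the kind a multi-armed bandits notebook typically walks through (a sketch from general knowledge, not taken from the commit; all names are illustrative):

import numpy as np


def epsilon_greedy_bandit(true_means, steps=1000, epsilon=0.1, seed=0):
    """Run an epsilon-greedy agent on a Gaussian k-armed bandit."""
    rng = np.random.default_rng(seed)
    k = len(true_means)
    Q = np.zeros(k)  # estimated value of each arm
    N = np.zeros(k)  # number of pulls per arm
    total = 0.0
    for _ in range(steps):
        if rng.random() < epsilon:
            a = int(rng.integers(k))  # explore: random arm
        else:
            a = int(np.argmax(Q))     # exploit: current best estimate
        r = rng.normal(true_means[a], 1.0)
        N[a] += 1
        Q[a] += (r - Q[a]) / N[a]     # incremental sample-mean update
        total += r
    return Q, total / steps


print(epsilon_greedy_bandit([0.1, 0.5, 0.9]))

The incremental update keeps Q[a] equal to the running mean of rewards observed for arm a without storing the full history.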
Lines changed: 39 additions & 0 deletions

import numpy as np
from collections import defaultdict

from gym_env.envs.stochastic_grid_world import StochasticGridWorldEnv


def value_iter(env, theta=0.001, discount_factor=1.0):
    def one_step_lookahead(state, V):
        # Expected return of each action under the current value estimate.
        A = np.zeros(len(env._action_to_direction))
        for a in env._action_to_direction:
            for prob, next_state, reward, done in env.P[state][a]:
                A[a] += prob * (reward + discount_factor * V[next_state])
        return A

    # Unseen states default to a value of 0.0.
    V = defaultdict(float)
    while True:
        delta = 0.0
        for s in env.P:
            A = one_step_lookahead(s, V)
            best_action_value = np.max(A)
            delta = max(delta, np.abs(best_action_value - V[s]))
            V[s] = best_action_value

        # Converged: no state value moved by more than theta this sweep.
        if delta < theta:
            break

    # Extract the greedy policy from the converged value function.
    policy = defaultdict(int)
    for s in env.P:
        A = one_step_lookahead(s, V)
        policy[s] = int(np.argmax(A))

    return policy, V


if __name__ == "__main__":
    env = StochasticGridWorldEnv()
    policy, V = value_iter(env)
    print(policy)
    print(V)
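
The sweep in value_iter is the standard Bellman optimality backup. In the code's notation, with $\gamma$ = discount_factor and $\theta$ = theta:

$$V_{k+1}(s) \;=\; \max_{a}\; \sum_{(p,\, s',\, r)\,\in\, P[s][a]} p\,\bigl(r + \gamma\, V_k(s')\bigr),$$

iterated until $\max_s \lvert V_{k+1}(s) - V_k(s)\rvert < \theta$; the returned policy is then greedy with respect to the converged $V$.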

docs/reinforcement_learning/code/gym_env/__init__.py

Whitespace-only changes.

docs/reinforcement_learning/code/gym_env/envs/__init__.py

Whitespace-only changes.

docs/reinforcement_learning/code/gym_env/envs/stochastic_grid_world.py

Lines changed: 223 additions & 85 deletions
Large diffs are not rendered by default.
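
This grid-world diff is also too large to render inline, but the value iteration script above pins down the interface the environment must expose: an `_action_to_direction` mapping over discrete actions and a tabular model `P[state][action]` holding `(prob, next_state, reward, done)` tuples. A minimal stand-in honoring that contract (a hypothetical stub for illustration, not the commit's implementation):

class TabularEnvStub:
    """Exposes only the attributes value_iter reads from the real env."""

    def __init__(self):
        # Four discrete actions; value_iter only reads the keys,
        # never the direction vectors themselves.
        self._action_to_direction = {0: (0, 1), 1: (0, -1), 2: (1, 0), 3: (-1, 0)}
        # P[state][action] -> list of (prob, next_state, reward, done).
        # State 0: action 0 reaches terminal state 1 (reward 1) 80% of the time.
        self.P = {
            0: {
                0: [(0.8, 1, 1.0, True), (0.2, 0, 0.0, False)],
                1: [(1.0, 0, 0.0, False)],
                2: [(1.0, 0, 0.0, False)],
                3: [(1.0, 0, 0.0, False)],
            },
            1: {a: [(1.0, 1, 0.0, True)] for a in range(4)},
        }

Running value_iter on this stub converges to V[0] ≈ 1.0 with a policy that always picks action 0 from state 0.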

docs/reinforcement_learning/code/pyproject.toml

Lines changed: 3 additions & 0 deletions
@@ -6,5 +6,8 @@ readme = "README.md"
 requires-python = ">=3.13"
 dependencies = [
     "gymnasium>=1.1.1",
+    "ipykernel>=6.29.5",
     "pygame>=2.6.1",
+    "torch>=2.7.0",
+    "torchrl>=0.8.0",
 ]
Lines changed: 152 additions & 0 deletions

A new Jupyter notebook (nbformat 4.5, ".venv" kernel, Python 3.13.3) containing the following cells:

In [2]:
from torchrl.envs import GymEnv

In [3]:
env = GymEnv("Pendulum-v1")

In [4]:
reset = env.reset()
reset

Out[4]:
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

In [6]:
reset_with_action = env.rand_action(reset)
reset_with_action

Out[6]:
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

In [7]:
from enum import IntEnum


class ACTIONS(IntEnum):
    NORTH = 0
    SOUTH = 1
    EAST = 2
    WEST = 3

In [8]:
ACTIONS(1)

Out[8]:
<ACTIONS.SOUTH: 1>

(One empty code cell follows at the end of the notebook.)
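
The notebook stops after sampling a random action into the TensorDict. A natural continuation, sketched against the public TorchRL API (env.step() and env.rollout() are real GymEnv methods; treat the snippet itself as illustrative rather than part of the commit):

from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
td = env.rand_action(env.reset())

# step() writes the resulting transition into a "next" sub-tensordict.
stepped = env.step(td)
print(stepped["next", "reward"])       # shape [1] reward tensor
print(stepped["next", "observation"])  # shape [3] next observation

# rollout() chains reset/action/step for a fixed horizon,
# returning a TensorDict stacked along the time dimension.
trajectory = env.rollout(max_steps=10)
print(trajectory.batch_size)           # torch.Size([10])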
