Non markov exps #22

Open
wants to merge 26 commits into
base: main
26 commits
8cdd934
dummy stash commit
devrz45 Jul 18, 2023
bb4944d
changed rules not working
devrz45 Jul 20, 2023
c060595
Non markovian movement test failing
devrz45 Jul 20, 2023
e66cacf
changed rules.yaml for health clause to be [0.1, 1] instead of [1,1]
dyumanaditya Jul 25, 2023
9b3532a
updated pyreason version
dyumanaditya Jul 25, 2023
fde26dd
updated pyreason version
dyumanaditya Jul 26, 2023
57e8b72
Setup for non-markov training trials
devrz45 Jul 31, 2023
d93a073
Merge branch 'main' into non-markov-exps
devrz45 Aug 4, 2023
036a489
fix bug where bullet moves two steps
dyumanaditya Aug 7, 2023
1dc627f
merge main and non-markov-exps
dyumanaditya Aug 7, 2023
213a56c
added who shot bullet info
dyumanaditya Aug 7, 2023
8213da2
Updated rules to handle the bullet persisting
devrz45 Aug 7, 2023
19d8ffd
non markov setup with multi-agent shooting
devrz45 Aug 8, 2023
70ed5e8
updated rules and pyreason version
dyumanaditya Aug 9, 2023
485d361
Non markov multi-agent test case fail
devrz45 Aug 9, 2023
7ded834
updated rules to fix assertion error
dyumanaditya Aug 10, 2023
a681504
updated rules to fix assertion error
dyumanaditya Aug 10, 2023
3602849
updated rules to fix assertion error
dyumanaditya Aug 10, 2023
3a72b7e
Merge branch 'main' into non-markov-exps
dyumanaditya Aug 10, 2023
850bf62
Added new observation space to check who killed whom
dyumanaditya Aug 10, 2023
1d13609
observations only contain info about current killings not past killings
dyumanaditya Aug 10, 2023
46e5454
new non markov rules and test
devrz45 Aug 10, 2023
75b83f3
Merge branch 'non-markov-exps' of github.com:lab-v2/pyreason-gym into…
devrz45 Aug 10, 2023
07100c3
persistent bullet fix again
dyumanaditya Aug 11, 2023
0013b91
changed rules to fix issue where bullet disappears in spot where anot…
dyumanaditya Aug 12, 2023
6e1d639
Non markov specific test added
devrz45 Aug 13, 2023
66 changes: 66 additions & 0 deletions pyreason_gym/pyreason_grid_world/graph/game_graph.graphml
@@ -7,12 +7,16 @@
<key id="down" for="edge" attr.name="down" attr.type="long" />
<key id="up" for="edge" attr.name="up" attr.type="long" />
<key id="right" for="edge" attr.name="right" attr.type="long" />
<key id="blue-soldier-2" for="node" attr.name="blue-soldier-2" attr.type="long" />
<key id="red-soldier-2" for="node" attr.name="red-soldier-2" attr.type="long" />
<key id="bullet" for="node" attr.name="bullet" attr.type="long" />
<key id="blue-soldier-1" for="node" attr.name="blue-soldier-1" attr.type="long" />
<key id="teamBlue" for="node" attr.name="teamBlue" attr.type="long" />
<key id="shootRightBlue" for="node" attr.name="shootRightBlue" attr.type="long" />
<key id="shootLeftBlue" for="node" attr.name="shootLeftBlue" attr.type="long" />
<key id="shootDownBlue" for="node" attr.name="shootDownBlue" attr.type="long" />
<key id="shootUpBlue" for="node" attr.name="shootUpBlue" attr.type="long" />
<key id="red-soldier-1" for="node" attr.name="red-soldier-1" attr.type="long" />
<key id="justDied" for="node" attr.name="justDied" attr.type="string" />
<key id="teamRed" for="node" attr.name="teamRed" attr.type="long" />
<key id="shootRightRed" for="node" attr.name="shootRightRed" attr.type="long" />
@@ -240,6 +244,7 @@
<data key="shootRightRed">0</data>
<data key="teamRed">1</data>
<data key="justDied">0,0</data>
<data key="red-soldier-1">1</data>
</node>
<node id="blue-soldier-1">
<data key="health">1</data>
@@ -253,6 +258,7 @@
<data key="shootRightBlue">0</data>
<data key="teamBlue">1</data>
<data key="justDied">0,0</data>
<data key="blue-soldier-1">1</data>
</node>
<node id="red-bullet-1">
<data key="teamRed">1</data>
@@ -262,6 +268,42 @@
<data key="teamBlue">1</data>
<data key="bullet">1</data>
</node>
<node id="red-soldier-2">
<data key="health">1</data>
<data key="moveUp">0</data>
<data key="moveDown">0</data>
<data key="moveLeft">0</data>
<data key="moveRight">0</data>
<data key="shootUpRed">0</data>
<data key="shootDownRed">0</data>
<data key="shootLeftRed">0</data>
<data key="shootRightRed">0</data>
<data key="teamRed">1</data>
<data key="justDied">0,0</data>
<data key="red-soldier-2">1</data>
</node>
<node id="blue-soldier-2">
<data key="health">1</data>
<data key="moveUp">0</data>
<data key="moveDown">0</data>
<data key="moveLeft">0</data>
<data key="moveRight">0</data>
<data key="shootUpBlue">0</data>
<data key="shootDownBlue">0</data>
<data key="shootLeftBlue">0</data>
<data key="shootRightBlue">0</data>
<data key="teamBlue">1</data>
<data key="justDied">0,0</data>
<data key="blue-soldier-2">1</data>
</node>
<node id="red-bullet-2">
<data key="teamRed">1</data>
<data key="bullet">1</data>
</node>
<node id="blue-bullet-2">
<data key="teamBlue">1</data>
<data key="bullet">1</data>
</node>
<edge source="0" target="1">
<data key="right">1</data>
</edge>
@@ -357,6 +399,9 @@
<edge source="7" target="red-soldier-1">
<data key="atLoc">1</data>
</edge>
<edge source="7" target="red-soldier-2">
<data key="atLoc">1</data>
</edge>
<edge source="8" target="9">
<data key="right">1</data>
</edge>
@@ -964,6 +1009,9 @@
<edge source="56" target="blue-soldier-1">
<data key="atLoc">1</data>
</edge>
<edge source="56" target="blue-soldier-2">
<data key="atLoc">1</data>
</edge>
<edge source="57" target="58">
<data key="right">1</data>
</edge>
@@ -1070,5 +1118,23 @@
<edge source="blue-soldier-1" target="blue-bullet-1">
<data key="bullet">1</data>
</edge>
<edge source="red-soldier-2" target="red-base">
<data key="team">1</data>
</edge>
<edge source="red-soldier-2" target="7">
<data key="atLoc">1</data>
</edge>
<edge source="red-soldier-2" target="red-bullet-2">
<data key="bullet">1</data>
</edge>
<edge source="blue-soldier-2" target="blue-base">
<data key="team">1</data>
</edge>
<edge source="blue-soldier-2" target="56">
<data key="atLoc">1</data>
</edge>
<edge source="blue-soldier-2" target="blue-bullet-2">
<data key="bullet">1</data>
</edge>
</graph>
</graphml>
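The hunks above add a second soldier and bullet per team by hand, and each soldier needs the same set of node attributes and four edges. As a rough sketch of that pattern (not the script used in this PR; `soldier_elements` and `_with_data` are hypothetical helpers, and the attribute names are copied from the diff), the elements could be generated with the standard library:

```python
import xml.etree.ElementTree as ET

DIRECTIONS = ("Up", "Down", "Left", "Right")


def _with_data(elem, data):
    """Attach <data key="...">value</data> children to a node or edge element."""
    for key, value in data.items():
        ET.SubElement(elem, "data", key=key).text = str(value)
    return elem


def soldier_elements(team, index, location):
    """Build GraphML nodes and edges for one soldier and its bullet (hypothetical helper)."""
    cap = team.capitalize()  # "red" -> "Red", used in teamRed / shootUpRed
    soldier = f"{team}-soldier-{index}"
    bullet = f"{team}-bullet-{index}"

    # Same attribute order as the soldier nodes in the diff.
    soldier_data = {"health": 1}
    soldier_data.update({f"move{d}": 0 for d in DIRECTIONS})
    soldier_data.update({f"shoot{d}{cap}": 0 for d in DIRECTIONS})
    soldier_data.update({f"team{cap}": 1, "justDied": "0,0", soldier: 1})

    nodes = [
        _with_data(ET.Element("node", id=soldier), soldier_data),
        _with_data(ET.Element("node", id=bullet), {f"team{cap}": 1, "bullet": 1}),
    ]
    # The diff links location and soldier in both directions with atLoc.
    edge_specs = [
        (str(location), soldier, "atLoc"),
        (soldier, f"{team}-base", "team"),
        (soldier, str(location), "atLoc"),
        (soldier, bullet, "bullet"),
    ]
    edges = [
        _with_data(ET.Element("edge", source=s, target=t), {key: 1})
        for s, t, key in edge_specs
    ]
    return nodes, edges


nodes, edges = soldier_elements("red", 2, 7)
for elem in nodes + edges:
    print(ET.tostring(elem, encoding="unicode"))
```

The matching `<key>` declarations (for example `red-soldier-2`) would still have to be added to the file header, as in the first hunk.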
18 changes: 10 additions & 8 deletions pyreason_gym/pyreason_grid_world/pyreason_grid_world.py
@@ -11,14 +11,14 @@ def __init__(self, grid_size, num_agents_per_team):

        # Keep track of the next timestep to start
        self.next_time = 0

        # Pyreason settings
        pr.settings.verbose = False
        pr.settings.atom_trace = False
        pr.settings.canonical = True
        pr.settings.inconsistency_check = False
        pr.settings.static_graph_facts = False
        pr.settings.store_interpretation_changes = False
        # pr.settings.store_interpretation_changes = True
        current_path = os.path.abspath(os.path.dirname(__file__))

        # Load the graph
@@ -49,20 +49,21 @@ def move(self, action):
                fact_off = pr.fact_node.Fact(f'red_action_{i+1}_off', f'red-soldier-{i+1}', pr.label.Label(red_available_actions[a]), pr.interval.closed(0,0), self.next_time+1, self.next_time+1)
                facts.append(fact_on)
                facts.append(fact_off)

        for i, a in enumerate(blue_team_actions):
            if a != 8:
                fact_on = pr.fact_node.Fact(f'blue_action_{i+1}', f'blue-soldier-{i+1}', pr.label.Label(blue_available_actions[a]), pr.interval.closed(1,1), self.next_time, self.next_time)
                fact_off = pr.fact_node.Fact(f'blue_action_{i+1}_off', f'blue-soldier-{i+1}', pr.label.Label(blue_available_actions[a]), pr.interval.closed(0,0), self.next_time+1, self.next_time+1)
                facts.append(fact_on)
                facts.append(fact_off)

        self.interpretation = pr.reason(1, again=True, node_facts=facts)
        # pr.save_rule_trace(self.interpretation)
        self.next_time = self.interpretation.time + 1

    def get_obs(self):
        observation = {'red_team': [], 'blue_team': [], 'red_bullets': [], 'blue_bullets': []}

        # Gather bullet info for red and blue bullets
        (red_bullet_positions, blue_bullet_positions), (red_bullet_directions, blue_bullet_directions), (red_killed_who, blue_killed_who) = self._get_bullet_info()
        for red_pos, red_dir in zip(red_bullet_positions, red_bullet_directions):
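The loops in `move` pair every action with two facts: one that sets the action label to `[1,1]` at `next_time`, and one that resets it to `[0,0]` at `next_time + 1`, so an action stays active for a single timestep. The sketch below reproduces only that scheduling with plain tuples so the timing can be checked without pyreason. `ActionFact` and `schedule_action_facts` are hypothetical names, the index-to-label mapping in the example is made up, and treating action `8` as a no-op is inferred from the `if a != 8` check.

```python
from typing import NamedTuple

NO_OP = 8  # assumption: index 8 means "do nothing", inferred from `if a != 8`


class ActionFact(NamedTuple):
    """Plain stand-in for a pyreason node fact (hypothetical, for illustration)."""
    name: str
    node: str
    label: str
    bound: tuple
    t_start: int
    t_end: int


def schedule_action_facts(team, actions, available_actions, next_time):
    """Return an on-fact at next_time and an off-fact at next_time + 1 per action."""
    facts = []
    for i, a in enumerate(actions):
        if a == NO_OP:
            continue
        node = f"{team}-soldier-{i + 1}"
        label = available_actions[a]
        facts.append(ActionFact(f"{team}_action_{i + 1}", node, label,
                                (1, 1), next_time, next_time))
        facts.append(ActionFact(f"{team}_action_{i + 1}_off", node, label,
                                (0, 0), next_time + 1, next_time + 1))
    return facts


# Hypothetical mapping from action index to label; the real one is not in this diff.
red_actions = {0: "moveUp", 4: "shootUpRed"}
for fact in schedule_action_facts("red", [4, NO_OP], red_actions, next_time=3):
    print(fact)
```

In this example only the first soldier produces facts, because the second soldier's action is the no-op index.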
@@ -99,14 +100,14 @@ def get_obs(self):
            observation['blue_team'].append({'pos': np.array(blue_pos_coords, dtype=np.int32), 'health': np.array([blue_health], dtype=np.float32), 'killed': list(blue_killed_who[i-1])})

        return observation

    def get_obstacle_locations(self):
        # Return the coordinates of all the mountains in the grid to be able to draw them
        relevant_edges = [edge for edge in self.interpretation.edges if edge[1]=='mountain']
        obstacle_positions = [int(edge[0]) for edge in relevant_edges]
        obstacle_positions_coords = np.array([[pos%self.grid_size, pos//self.grid_size] for pos in obstacle_positions])
        return obstacle_positions_coords

    def get_base_locations(self):
        # Return the locations of the two bases
        relevant_edges = [edge for edge in self.interpretation.edges if 'base' in edge[0]]
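`get_obstacle_locations` turns an integer node id into `[x, y]` with `pos % grid_size` and `pos // grid_size`, which implies row-major numbering of the grid cells. A minimal sketch of that mapping and its inverse follows. The helper names are hypothetical, and the grid size of 8 in the example is an assumption rather than a value taken from this diff.

```python
def index_to_coords(pos, grid_size):
    """Map a row-major cell index to [x, y], as in get_obstacle_locations."""
    return [pos % grid_size, pos // grid_size]


def coords_to_index(x, y, grid_size):
    """Inverse mapping: [x, y] back to the row-major cell index."""
    return y * grid_size + x


GRID_SIZE = 8  # assumed for illustration only
for pos in (7, 56):  # the cells the new soldiers are attached to in the graph diff
    x, y = index_to_coords(pos, GRID_SIZE)
    assert coords_to_index(x, y, GRID_SIZE) == pos
    print(pos, "->", [x, y])
```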
@@ -127,7 +128,8 @@ def _get_bullet_info(self):
         positions = (red_bullet_positions_coords, blue_bullet_positions_coords)

         # Get info about who killed whom. Stored in the form a list for every agent: (red-killer: [blue-casualties]) or (blue-killer: [red-casualties])
-        kill_info_edges = [edge for edge in self.interpretation.edges if pr.label.Label('killed') in self.interpretation.interpretations_edge[edge].world]
+        kill_info_edges = [edge for edge in self.interpretation.edges if pr.label.Label('killed') in self.interpretation.interpretations_edge[edge].world
+                           and self.interpretation.interpretations_edge[edge].world[pr.label.Label('killed')] == pr.interval.closed(1, 1)]
         kill_info_edges = sorted(kill_info_edges, key=lambda x: int(x[0][-1]))
         red_killed_who_tuple = [(int(edge[0][-1]), int(edge[1][-1])) for edge in kill_info_edges if 'red' in edge[0]]
         blue_killed_who_tuple = [(int(edge[0][-1]), int(edge[1][-1])) for edge in kill_info_edges if 'blue' in edge[0]]
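The change in this hunk narrows the kill lookup: an edge now counts only when its `killed` label is present and its bound equals `[1,1]`, so edges where the label exists with another bound (for example a past kill reset to `[0,0]`) are ignored. The sketch below illustrates that filter with plain dictionaries standing in for the pyreason interpretation. `current_kills` and the `edge_worlds` structure are hypothetical, with bounds written as `(lower, upper)` tuples.

```python
def current_kills(edge_worlds):
    """Split currently active kills into (red, blue) lists of (killer_idx, casualty_idx).

    edge_worlds maps (killer, casualty) edges to {label: (lower, upper)} dicts.
    """
    # Keep only edges whose 'killed' bound is exactly [1, 1].
    kill_edges = [edge for edge, world in edge_worlds.items()
                  if world.get("killed") == (1, 1)]
    # As in the diff, the agent index is read from the last character of the
    # node id, so this only works for single-digit agent numbers.
    kill_edges.sort(key=lambda edge: int(edge[0][-1]))
    red = [(int(k[-1]), int(c[-1])) for k, c in kill_edges if "red" in k]
    blue = [(int(k[-1]), int(c[-1])) for k, c in kill_edges if "blue" in k]
    return red, blue


edge_worlds = {
    ("red-soldier-2", "blue-soldier-1"): {"killed": (1, 1)},
    ("blue-soldier-1", "red-soldier-1"): {"killed": (0, 0)},  # label present, not active
    ("red-soldier-1", "blue-soldier-2"): {"killed": (1, 1)},
}
red_kills, blue_kills = current_kills(edge_worlds)
print(red_kills, blue_kills)
```

With the pre-change condition (label presence only), the `blue-soldier-1` edge in this example would also have been reported as a kill.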