Skip to content

Commit

Permalink
Merge pull request #16 from cognitiveailab/serialization
Browse files Browse the repository at this point in the history
Add serialization to TextWorldExpressEnv
  • Loading branch information
MarcCote authored Dec 12, 2024
2 parents b3e4df1 + 788f3e3 commit a132be0
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 1 deletion.
69 changes: 69 additions & 0 deletions tests/test_textworld_express.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from textworld_express import TextWorldExpressEnv


Expand Down Expand Up @@ -119,3 +120,71 @@ def _extract_contents(root):
objs = _extract_contents(obj_tree["locations"]["kitchen"])
expected_nested_objs_in_kitchen = ['cookbook', 'counter', 'cutlery drawer', 'dining chair', 'dining table', 'dishwasher', 'fridge', 'green bell pepper', 'knife', 'oven', 'stove', 'trash can']
assert sorted(objs.keys()) == expected_nested_objs_in_kitchen


def test_reset_with_seed():
env = TextWorldExpressEnv()
obs1, _ = env.reset(seed=42, gameName="cookingworld")
obs2, _ = env.reset(seed=42, gameName="cookingworld")
assert obs1 == obs2


def test_reset_without_seed():
env = TextWorldExpressEnv()
obs1, _ = env.reset(gameName="cookingworld")
obs2, _ = env.reset(gameName="cookingworld")
assert obs1 != obs2


def test_step():
env = TextWorldExpressEnv()
env.reset(gameName="cookingworld")
obs, reward, done, infos = env.step("look around")
assert isinstance(obs, str)
assert isinstance(reward, float)
assert isinstance(done, bool)
assert isinstance(infos, dict)


def test_clone():
env = TextWorldExpressEnv()
infos = env.reset(gameName="cookingworld", seed=42, generateGoldPath=True)
env.step("look around")
solution = env.getGoldActionSequence()
for action in solution[:len(solution) // 2]:
env.step(action)

clone_env = env.clone()
assert env.getRunHistory() == clone_env.getRunHistory()

# Continue the game in both env.
for action in solution[len(solution) // 2:]:
env.step(action)
clone_env.step(action)

assert env.getRunHistory() == clone_env.getRunHistory()


def test_serialize_deserialize():
env = TextWorldExpressEnv()
env.reset(gameName="cookingworld", seed=42)
env.step("look around")
state = env.serialize()
new_env = TextWorldExpressEnv.deserialize(state)
assert env.getRunHistory() == new_env.getRunHistory()


def test_close():
env = TextWorldExpressEnv()
env.reset(gameName="cookingworld")
assert env._gateway.java_process.poll() is None
env.close()
time.sleep(1)
assert env._gateway.java_process.poll() is not None


def test_get_game_names():
env = TextWorldExpressEnv()
game_names = env.getGameNames()
assert isinstance(game_names, list)
assert "cookingworld" in game_names
32 changes: 31 additions & 1 deletion textworld_express/textworld_express.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(self, serverPath=None, envStepLimit=100):
self.seed = None
self.gameName = None
self.gameParams = ""
self.gameFold = None
self.gameFold = "train"
self.generateGoldPath = False

self._obj_tree_tempfile = tempfile.NamedTemporaryFile()
Expand Down Expand Up @@ -266,3 +266,33 @@ def step(self, inputStr:str):
self.addStepToHistory(infos)

return observation, reward, isCompleted, infos

def serialize(self):
state = {
"gameName": self.gameName,
"gameParams": self.gameParams,
"seed": self.seed,
"gameFold": self.gameFold,
"envStepLimit": self.envStepLimit,
"generateGoldPath": self.generateGoldPath,
"actions": [info["lastActionStr"] for info in self.runHistory[1:]],
}
return state

@classmethod
def deserialize(cls, state):
env = cls(envStepLimit=state["envStepLimit"])
env.reset(
seed=state["seed"],
gameFold=state["gameFold"],
gameName=state["gameName"],
gameParams=state["gameParams"],
generateGoldPath=state["generateGoldPath"]
)
for action in state["actions"]:
env.step(action)

return env

def clone(self):
return self.deserialize(self.serialize())

0 comments on commit a132be0

Please sign in to comment.