
Commit 791bf3f

Merge pull request #407 from xiongjyu/dev-jericho-llm-prior
fix(xjy): add qwen performance test on atari
2 parents ee61b12 + 9838468 commit 791bf3f

File tree

1 file changed: +352 -0 lines changed

zoo/atari/envs/test_qwen_arati_env.py

Lines changed: 352 additions & 0 deletions
@@ -0,0 +1,352 @@
# run_pong_qwen_ddp.py
import os, re, json, random
from dataclasses import dataclass
from collections import deque, namedtuple
from typing import List, Tuple, Union
import numpy as np
import shutil
from PIL import Image
import torch
import torch.distributed as dist

from transformers import AutoProcessor
from transformers import Qwen2_5_VLForConditionalGeneration

from easydict import EasyDict
from zoo.atari.envs.atari_lightzero_env import AtariEnvLightZero

def to_model_image(arr: Union[np.ndarray, torch.Tensor], channel_last: bool, use_pil: bool):
    """
    Returns:
    - use_pil=True  -> PIL.Image (RGB)
    - use_pil=False -> numpy HWC uint8
    """
    if isinstance(arr, torch.Tensor):
        arr = arr.detach().cpu().numpy()
    arr = np.asarray(arr)

    # 2D grayscale -> HWC
    if arr.ndim == 2:
        arr = arr[:, :, None]

    # normalize layout to HWC
    if channel_last:
        hwc = arr
    else:
        assert arr.ndim == 3 and arr.shape[0] in (1, 3), f"Expect (C,H,W) or (H,W,C), got {arr.shape}"
        hwc = np.transpose(arr, (1, 2, 0))

    # expand grayscale to 3 channels
    if hwc.shape[-1] == 1:
        hwc = np.repeat(hwc, 3, axis=-1)

    # convert to uint8
    if hwc.dtype != np.uint8:
        if hwc.max() <= 1.0:
            hwc = hwc * 255.0
        hwc = np.clip(hwc, 0, 255).astype(np.uint8)

    if use_pil:
        return Image.fromarray(hwc, mode="RGB")
    else:
        return hwc

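# Illustrative usage of to_model_image (not part of the original test): with the config
# below (observation_shape=[3, 64, 64], scale=True, channel_last=False), an observation is
# a float array in [0, 1] of shape (3, 64, 64) and converts to a (64, 64, 3) uint8 image:
#   hwc = to_model_image(np.random.rand(3, 64, 64).astype(np.float32), channel_last=False, use_pil=False)
#   assert hwc.shape == (64, 64, 3) and hwc.dtype == np.uint8
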
def init_distributed():
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    if not dist.is_initialized():
        dist.init_process_group(backend=backend, init_method="env://")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # set the device for this process
    local_rank = int(os.getenv("LOCAL_RANK", rank % max(1, torch.cuda.device_count())))
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    return rank, world_size, local_rank


Transition = namedtuple("Transition", ["step", "image", "action_str"])

class QwenPongPolicy:
    """
    - Keeps a history of n frames (each entry contains only: the image + the action string we chose then).
    - Prompt structure (kept in English for instruction stability):
      environment description + task description + current image <image> + allowed actions (list of strings)
      + interaction history (only history images + the corresponding action strings).
      The model is asked to output a single line containing the chosen action string (e.g. RIGHTFIRE),
      in the form {ACTION: <action_str>}.
    - If parsing fails, an action string is sampled uniformly from the allowed list, then mapped back to an action id.
    - FlashAttention-2 is used when available (falls back automatically otherwise).
    """
    # the 6 official action names
    ID2NAME = {
        0: "NOOP",
        1: "FIRE",
        2: "RIGHT",
        3: "LEFT",
        4: "RIGHTFIRE",
        5: "LEFTFIRE",
    }
    NAME2ID = {v: k for k, v in ID2NAME.items()}

    ACTION_EXPLAIN = {
        "NOOP": "Do nothing (stay still).",
        "FIRE": "Serve a new point (use only at the start of a rally).",
        "RIGHT": "Move your RIGHT paddle UP in this Pong port.",
        "LEFT": "Move your RIGHT paddle DOWN in this Pong port.",
        "RIGHTFIRE": "Move UP and SERVE simultaneously (use only to start a rally).",
        "LEFTFIRE": "Move DOWN and SERVE simultaneously (use only to start a rally).",
    }

    def __init__(self, model_name: str, dtype: torch.dtype, history_n: int,
                 use_pil: bool, channel_last: bool, device: torch.device, save_dir: str = "pong_ddp_frames", save_image=False, rank: int = 0):
        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

        # prefer FlashAttention-2; fall back to the default attention implementation if it is unavailable
        try:
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=dtype,
                device_map={"": device.index},
                attn_implementation="flash_attention_2",
                trust_remote_code=True,
            )
        except (ImportError, ValueError):
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=dtype,
                device_map={"": device.index},
                trust_remote_code=True,
            )

        self.model.eval()

        self.history_n = history_n
        self.buffer: deque[Transition] = deque(maxlen=history_n)
        self.use_pil = use_pil
        self.channel_last = channel_last
        self.device = device
        self.save_image = save_image
        self.save_dir = save_dir
        self.rank = rank
        self.rank_dir = os.path.join(self.save_dir, f"rank{rank:02d}")
        if os.path.exists(self.rank_dir):
            shutil.rmtree(self.rank_dir)

        os.makedirs(self.rank_dir, exist_ok=True)
        self.meta_path = os.path.join(self.rank_dir, "trajectory.jsonl")

    def save_pil_if_enabled(self, img: Image.Image, save_root: str, step: int):
        d = os.path.join(save_root, f"rank{self.rank:02d}")
        os.makedirs(d, exist_ok=True)
        img.save(os.path.join(d, f"frame_{step:06d}.png"))

    def log_step(self, step: int, action_id: int, action_str: str, reward: float):
        """
        Append one record to the single per-rank trajectory file (JSONL: one JSON object per line),
        stored in the same directory as the saved frames.
        """
        rec = {
            "step": int(step),
            "action_id": int(action_id),
            "action": str(action_str),
            "reward": float(reward),
        }
        with open(self.meta_path, "a") as f:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")

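    # Illustrative trajectory.jsonl line produced by log_step (values are made up):
    #   {"step": 12, "action_id": 4, "action": "RIGHTFIRE", "reward": 0.0}
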
    def _build_messages_and_images(self, cur_img, allowed_names: List[str]):
        """
        Order of user.content (as requested):
        1) environment description + task description (text)
        2) current image <image>
        3) allowed actions (list of strings) + a clear explanation of these 6 actions
        4) interaction history (only: history images + the corresponding action strings)
        5) output format requirement: return exactly one line of the form {ACTION: <action_str>}
        """
        content = []
        images_for_processor = []

        # 1) environment + task
        content.append({
            "type": "text",
            "text": (
                "Environment: Atari Pong (ALE) — two paddles rally a ball.\n"
                "Task: You control the RIGHT paddle. Keep your paddle vertically aligned with the ball to return it and avoid conceding.\n"
                "Serving rule: when a new point starts and the ball is not yet in play, you must SERVE using FIRE or *_FIRE; "
                "during an active rally, do NOT use FIRE actions and instead move appropriately."
            )
        })

        # 2) current image
        content.append({"type": "text", "text": "Current state image:"})
        content.append({"type": "image", "image": cur_img})
        images_for_processor.append(cur_img)

        # 3) allowed actions + explanations
        allowed_str = ", ".join(allowed_names)
        # explanation text (only for the currently allowed actions)
        explain_lines = []
        for name in allowed_names:
            desc = self.ACTION_EXPLAIN.get(name, "")
            if desc:
                explain_lines.append(f"- {name}: {desc}")
        explain_text = "\n".join(explain_lines)

        content.append({
            "type": "text",
            "text": (
                f"Available actions (choose exactly one string): {allowed_str}\n"
                "Action semantics:\n"
                f"{explain_text}\n"
                "Heuristic (to guide your choice): if the ball is above your paddle, choose an UP action (RIGHT/RIGHTFIRE when serving); "
                "if the ball is below, choose a DOWN action (LEFT/LEFTFIRE when serving); if perfectly aligned and rally is active, NOOP briefly is acceptable."
            )
        })

        # 4) interaction history (only: history images + the action strings chosen at the time)
        if len(self.buffer) > 0:
            content.append({"type": "text", "text": "Recent interaction history (most recent first):"})
            for tr in list(self.buffer)[::-1]:  # most recent -> oldest
                content.append({"type": "image", "image": tr.image})
                images_for_processor.append(tr.image)
                # followed by the action we chose in that state (action string only)
                content.append({
                    "type": "text",
                    "text": f"You chose the action: {tr.action_str}"
                })

        # 5) output format requirement (return exactly one line {ACTION: <action_str>})
        content.append({
            "type": "text",
            "text": (
                "\nOutput requirement:\n"
                "- Return EXACTLY ONE line in the form: {ACTION: <action_str>}\n"
                f"- <action_str> MUST be one of: {allowed_str}\n"
            )
        })

        messages = [
            {"role": "system", "content": "You are a precise action selector for Atari Pong. Always follow the requested output format."},
            {"role": "user", "content": content},
        ]
        return messages, images_for_processor

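    # Illustrative shape of the returned `messages` (not part of the original test):
    #   [{"role": "system", "content": "You are a precise action selector ..."},
    #    {"role": "user", "content": [
    #        {"type": "text", "text": "Environment: Atari Pong ..."},
    #        {"type": "text", "text": "Current state image:"}, {"type": "image", "image": <current frame>},
    #        {"type": "text", "text": "Available actions ..."},
    #        <history image/text pairs, most recent first>,
    #        {"type": "text", "text": "\nOutput requirement: ..."}]}]
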
    def _parse_action_string(self, text: str, allowed_names: List[str]) -> str:
        # sort by length in descending order so that RIGHTFIRE is not shadowed by RIGHT
        names_sorted = sorted(allowed_names, key=len, reverse=True)

        alt = "|".join(map(re.escape, names_sorted))
        pattern = rf"""\{{\s*"?ACTION"?\s*[::]\s*"?\s*({alt})\s*"?\s*\}}"""

        m = re.search(pattern, text, flags=re.IGNORECASE)
        if m:
            return m.group(1).upper()

        return random.choice(allowed_names)

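    # Illustrative behaviour of _parse_action_string (not part of the original test):
    #   '{ACTION: RIGHTFIRE}'        -> 'RIGHTFIRE'
    #   '{ "action": "leftfire" }'   -> 'LEFTFIRE'   (quotes optional, case-insensitive)
    #   'I will go up.'              -> random.choice(allowed_names)
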
    @torch.inference_mode()
    def decide(self, obs_dict: dict, step: int) -> Tuple[int, str, str]:
        allowed_ids = [i for i, v in enumerate(obs_dict.get("action_mask", [1] * 6)) if int(v) == 1]
        allowed_names = [self.ID2NAME[i] for i in allowed_ids]

        cur_img = to_model_image(obs_dict["observation"], channel_last=self.channel_last, use_pil=self.use_pil)

        messages, images_for_processor = self._build_messages_and_images(cur_img, allowed_names)
        prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)

        inputs = self.processor(
            text=prompt,
            images=images_for_processor,
            return_tensors="pt"
        ).to(self.device)

        out_ids = self.model.generate(
            **inputs,
            max_new_tokens=16,
            temperature=0.0,
            do_sample=False,
            top_p=1.0,
        )
        # generate() returns prompt + completion; keep only the newly generated tokens
        input_len = int(inputs["input_ids"].shape[1])
        gen_only = out_ids[:, input_len:]

        out_text = self.processor.batch_decode(gen_only, skip_special_tokens=True)[0]

        action_str = self._parse_action_string(out_text, allowed_names)
        action_id = self.NAME2ID[action_str]

        if self.use_pil and self.save_image:
            self.save_pil_if_enabled(cur_img, self.save_dir, step)

        return action_id, action_str, out_text

    def record(self, prev_obs: dict, action_id: int, step: int):
        img = to_model_image(prev_obs["observation"], channel_last=self.channel_last, use_pil=self.use_pil)
        action_str = self.ID2NAME[action_id]
        self.buffer.append(Transition(step=step, image=img, action_str=action_str))


if __name__ == "__main__":
    rank, world_size, local_rank = init_distributed()
    device = torch.device("cuda", local_rank) if torch.cuda.is_available() else torch.device("cpu")

    base_seed = 12345
    random.seed(base_seed + rank)
    np.random.seed(base_seed + rank)
    torch.manual_seed(base_seed + rank)

    config = EasyDict(dict(
        collector_env_num=8,
        evaluator_env_num=3,
        n_evaluator_episode=3,
        env_id='PongNoFrameskip-v4',
        env_type='Atari',
        observation_shape=[3, 64, 64],
        collect_max_episode_steps=int(1.08e5),
        eval_max_episode_steps=int(1.08e5),
        gray_scale=False,
        frame_skip=4,
        frame_stack_num=1,
        episode_life=True,
        clip_rewards=True,
        channel_last=False,
        render_mode_human=False,
        scale=True,
        warp_frame=True,
        save_video=False,
        transform2string=False,
        game_wrapper=True,
        stop_value=int(1e6),
        save_replay=False,
        replay_path=None,
    ))
    config.max_episode_steps = config.eval_max_episode_steps
    env = AtariEnvLightZero(config)

    policy = QwenPongPolicy(
        model_name="/fs-computility/niuyazhe/shared/xiongjyu/model/Qwen2.5-VL-3B-Instruct",
        dtype=torch.bfloat16,
        history_n=5,
        use_pil=False,
        channel_last=config.channel_last,
        device=device,
        save_dir="/fs-computility/niuyazhe/shared/xiongjyu/jericho/LightZero/pong_ddp_frames",
        save_image=True,
        rank=rank
    )

    obs = env.reset()
    episode_return, steps = 0.0, 0

    while True:
        action_id, action_str, raw = policy.decide(obs, step=steps)
        prev_obs = obs
        obs, reward, done, info = env.step(action_id)
        policy.log_step(steps, action_id, action_str, reward)

        policy.record(prev_obs, action_id, step=steps)

        episode_return += float(reward)
        steps += 1

        if done or steps >= config.max_episode_steps:
            print(f"[RANK {rank}/{world_size}] return={episode_return}, steps={steps}, info={info}")
            break
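Usage note: init_distributed() reads the process group settings from environment variables (init_method="env://" plus LOCAL_RANK), so the script is intended to be launched with a DDP launcher such as torchrun; a minimal single-node sketch, assuming one evaluation process per GPU (the process count is a placeholder):

torchrun --standalone --nproc_per_node=8 zoo/atari/envs/test_qwen_arati_env.py

Each rank then plays one independent Pong episode with its own seed (base_seed + rank), appends its decisions to <save_dir>/rankNN/trajectory.jsonl (frames are additionally saved as PNGs only when use_pil=True), and prints its episode return when the episode ends.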

0 commit comments