Skip to content

Nomination-NRB/RL-snack

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

6 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

RL-snack

强化学习贪吃蛇

概览

image

流程

强化学习贪吃蛇训练流程:

  1. 首先,定义一个Q网络(QNetwork),它将贪吃蛇游戏的状态(包括蛇的位置和苹果的位置等)作为输入,输出每个动作的Q值(预测的奖励)。
  2. 初始化贪吃蛇游戏环境(GameEnvironment)、经验回放(ReplayMemory)、优化器(Adam)和损失函数(MSELoss)。
  3. 开始训练循环,进行多个回合(episode)的训练。
  4. 在每个回合中,调用 run_episode 函数进行游戏回合,模拟贪吃蛇的行为。在每局游戏中,贪吃蛇根据当前状态(蛇的位置和苹果的位置)通过Q网络预测每个动作的Q值,并根据 ε-greedy策略(以一定概率随机探索,否则选择预测Q值最大的动作)来选择下一步的动作。然后执行动作,更新游戏状态,计算奖励和是否结束游戏,并将样本数据(状态、动作、奖励、下一个状态、是否结束)存储到经验回放中。
  5. 在每个回合结束后,通过经验回放(Replay Memory)中的样本数据,调用 learn 函数来更新Q网络的权重,以逼近最优的Q值。learn 函数会使用批量的样本数据来计算损失函数,然后通过优化器来更新Q网络的权重。
  6. 记录每个回合的得分和蛇的长度,并计算滑动窗口的平均得分和平均蛇的长度,用于评估训练的效果。
  7. 循环训练直到达到预定的训练回合数(NUM_EPISODES)为止。
  8. 在训练过程中,可以保存一些训练中间结果,如每隔一定回合保存一次模型的权重。
  9. 训练完成后,获得训练过程中的得分、平均得分、平均蛇的长度和最大蛇的长度等数据,用于分析和可视化训练结果。

目录

RL-snack
├─ Game.py                                # 贪吃蛇,食物,游戏环境初始化类
├─ README.md
├─ WatchAgent.py                          # ai玩游戏/玩家玩游戏
├─ config
│    └─ config.py                         # 基础配置信息
├─ dir_chk                                # 训练好的模型,_5表示训练了5个episode,一个episode的信息在config.py查看
│    ├─ Snake_5
│    ├─ Snake_10
│    ├─ Snake_100
│    ├─ Snake_5000
│    └─ Snake_60000
├─ model.py                               # 神经网络定义
├─ npy                                    # 保存下来的numpy数据
│    ├─ avg_len_of_snake.npy
│    ├─ avg_scores.npy
│    ├─ max_len_of_snake.npy
│    └─ scores.npy
├─ outputImage                            # 保存下来的图片
│    ├─ drawAvgAndMaxLen.png
│    ├─ drawMaxHist.png
│    └─ drawScores.png
├─ replay_buffer.py                       # 经验回放机制类
└─ train.py                               # 执行训练

文件

train.py

代码使用了深度神经网络(Q网络)和经验回放(Replay Memory)来实现深度Q学习(Deep Q Learning)算法。

  1. run_episode 函数:这个函数表示一个完整的游戏回合。在每个回合中,游戏会进行多局游戏(num_games),并且记录每局游戏的奖励(reward)、蛇的平均长度(avg_len_of_snake)以及蛇的最大长度(max_len_of_snake)。
  2. learn 函数:这个函数用于更新神经网络模型的权重,以使其逼近贪吃蛇的最优策略。它使用深度Q学习算法中的样本数据进行训练。首先,从经验回放(memory)中随机采样一批样本数据(BATCH_SIZE),然后计算目标Q值(Q_targets)和预测Q值(Q_expected),并通过均方误差(criterion)计算损失函数,最后使用优化器(optimizer)来更新模型的权重。
  3. train 函数:这个函数是整个训练过程的主函数。它会进行多个回合(NUM_EPISODES)的训练。在每个回合中,调用 run_episode 函数来进行一定数量的游戏回合,并记录每个回合的奖励和蛇的长度。然后,通过 learn 函数来更新神经网络的权重。训练过程会记录游戏回合的得分和蛇的长度,并且每隔一定回合打印出当前的训练进度。
  4. 主程序:初始化模型、游戏环境、经验回放、优化器和损失函数等,然后调用 train 函数进行训练。

model.py

代码定义了一个简单的神经网络模型 QNetwork 和一个函数 get_network_input 来获取游戏状态信息并将其转换成神经网络的输入向量。

  1. QNetwork 类:这是一个继承自 nn.Module 的神经网络模型类,用于近似Q函数。Q函数是强化学习中表示动作价值的函数,输入是状态信息,输出是对应每个动作的Q值。模型使用全连接神经网络结构,并使用ReLU作为激活函数。
    • __init__ 方法:模型的初始化函数,定义神经网络的结构。接收三个参数:input_dim 表示输入向量的维度,hidden_dim 表示隐藏层的维度,output_dim 表示输出向量的维度(代表动作的数量)。
    • forward 方法:前向传播函数,定义了输入向量如何经过各层神经元的计算得到输出。首先通过三个隐藏层(fc1fc2fc3)进行线性变换并使用ReLU激活函数,然后通过最后一个全连接层 fc4 得到输出 l4
  2. get_network_input 函数:这个函数用于获取游戏状态信息并将其转换为神经网络的输入向量。它接收两个参数:player 表示贪吃蛇的对象,apple 表示苹果的对象。
    • 函数逻辑:首先调用 player 对象的 getproximity() 方法,获取贪吃蛇头部周围的环境信息,得到一个代表环境的数组 proximity。然后将贪吃蛇的头部位置、苹果的位置、贪吃蛇的头部方向和环境信息拼接成一个向量 x。最后将向量 x 转换为 PyTorch 的张量(torch.tensor)并返回。
  3. if __name__ == '__main__': 部分:这是主程序部分,用于测试上述代码的功能。
    • 创建了一个 QNetwork 的实例 model,并指定输入向量维度为 10、隐藏层维度为 20 和输出向量维度为 5
    • 创建了一个贪吃蛇游戏环境 env,并初始化了贪吃蛇和苹果的对象 playerapple
    • 调用 envresetgame() 方法来重置游戏环境,然后调用 get_network_input 函数,传入 playerapple 对象,获取游戏状态的输入向量 Input
    • 打印模型结构和获取的输入向量 Input

Game.py

代码实现了一个简单的贪吃蛇游戏环境,包含了三个类:SnakeClassAppleClassGameEnvironment

  1. SnakeClass 类:
    • __init__ 方法:初始化贪吃蛇对象,指定初始位置和方向,记录贪吃蛇的长度(len)。
    • __len__ 方法:返回贪吃蛇的长度。
    • move 方法:更新贪吃蛇的位置,将新的头部位置加入到 prevpos 中,并保持 prevpos 长度为 len + 1,以记录贪吃蛇的轨迹。
    • checkdead 方法:检查贪吃蛇是否死亡,判断是否碰到游戏边界或者碰撞到自己。
    • getproximity 方法:获取贪吃蛇头部周围环境的状态,返回四个方向移动是否会导致贪吃蛇死亡的列表。
  2. AppleClass 类:
    • __init__ 方法:初始化苹果对象,随机生成苹果的初始位置和初始得分。
    • eaten 方法:苹果被吃后,重新生成一个新的苹果位置,并增加得分。
  3. GameEnvironment 类:
    • __init__ 方法:初始化游戏环境,包括贪吃蛇对象(snake)、苹果对象(apple)、游戏结束标志(game_over)、网格大小(gridsize)、奖励设置(reward_nothingreward_deadreward_apple)等。
    • resetgame 方法:重置游戏环境,重新初始化贪吃蛇位置、苹果位置和得分,以及其他游戏状态变量。
    • get_boardstate 方法:获取游戏环境的状态信息,包括贪吃蛇位置、方向、轨迹、苹果位置、得分和游戏结束标志。
    • update_boardstate 方法:根据传入的动作 move 来更新游戏环境状态,包括更新贪吃蛇的方向和位置,判断是否吃到苹果,是否撞到边界或者撞到自己,以及计算奖励。
    • showState 方法:打印游戏环境的状态信息,包括贪吃蛇的位置、方向、轨迹、长度、苹果的位置、得分和游戏结束标志。

在主程序部分,对上述三个类进行了测试:

  • 创建了一个 SnakeClass 的实例 snake,并展示了其初始状态和经过一次移动后的状态。
  • 创建了一个 AppleClass 的实例 apple,并展示了其初始状态和被吃后的状态。
  • 创建了一个 GameEnvironment 的实例 env,指定了网格大小、奖励和惩罚值,展示了其初始状态和重置游戏后的状态。然后通过调用 update_boardstate 方法来模拟游戏环境的更新,并展示更新后的状态。

replay_buffer.py

代码实现了一个经验回放(Replay Memory)缓冲区,用于存储和采样训练数据,用于深度强化学习算法的训练。经验回放缓冲区是深度Q学习算法中的一个重要组成部分,它用于解决样本之间的相关性问题,提高训练的稳定性和效率。

  1. ReplayMemory 类:
    • __init__ 方法:初始化经验回放缓冲区,指定最大容量 max_size,并创建一个空的 buffer 列表用于存储经验元组。
    • push 方法:将一个完整的经验元组(state, action, reward, next_state, done)添加到缓冲区中。
    • sample 方法:从缓冲区中随机采样指定大小的批次数据。采样过程是随机的,用于打破样本之间的相关性。
    • truncate 方法:如果缓冲区的大小超过了最大容量,就会从头部截断多余的数据,保持缓冲区大小不超过 max_size
    • __len__ 方法:返回当前缓冲区中存储的经验元组数量。
  2. if __name__ == '__main__': 部分:这是主程序部分,用于测试 ReplayMemory 类的功能。
    • 创建了一个 ReplayMemory 的实例 replay_memory,并指定经验回放缓冲区的最大容量为 1000
    • 添加了 10 个随机生成的经验元组到缓冲区中,每个经验元组包含了 state(状态)、action(动作)、reward(奖励)、next_state(下一个状态)和 done(是否结束)信息。
    • 输出了缓冲区中存储的经验个数,即 replay_memory 的长度。
    • 从缓冲区中采样了一个大小为 5 的批次的经验数据,并输出采样得到的状态批次、动作批次、奖励批次、下一个状态批次和是否结束批次。

这样,经过以上操作,我们就能够将多个游戏状态信息存储到经验回放缓冲区中,并随机采样这些状态信息的批次,用于训练深度强化学习算法,提高训练效率和稳定性

WatchAgent.py

代码实现了一个使用训练好的深度强化学习模型玩贪吃蛇游戏的功能。

  1. drawboard 函数:用于绘制贪吃蛇游戏界面。根据贪吃蛇的位置和苹果的位置,绘制蛇身和苹果,并以渐变颜色表示蛇身的长度。同时,绘制背景网格用于美化界面。
  2. run_snake_game 函数:主游戏循环函数,用于实现游戏的运行和渲染。
    • 初始化游戏环境和相关参数。
    • 加载预训练的深度强化学习模型 model,用于决策贪吃蛇的动作。
    • 创建 Pygame 窗口和相关设置,设置游戏窗口标题。
    • 在游戏循环中,不断获取当前游戏状态,传入模型,获取模型的输出作为动作,然后更新游戏状态,绘制游戏界面,显示贪吃蛇长度和奖励等信息。
    • 检测游戏是否结束,若结束则重置游戏状态。
  3. 主程序部分,通过调用 run_snake_game 函数来运行贪吃蛇游戏。

训练结果

image

image

image

参考

https://github.com/Rafael1s/Deep-Reinforcement-Learning-Algorithms/tree/master/Snake-Pygame-DQN

About

强化学习贪吃蛇

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages