-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
104 lines (80 loc) · 2.96 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import pygame as pg
import sys
from game_objects import BallObj, PaddleObj
import neat
import os
# Window configuration: initialise pygame, then create the shared game window
# and the objects the training loop relies on (screen, size, bg_color, clock).
pg.init()
size = (800, 600)  # (width, height) passed to pg.display.set_mode
screen = pg.display.set_mode(size)
pg.display.set_caption('Pong')
bg_color = pg.Color('grey12')  # fill colour used to clear each frame
clock = pg.time.Clock()  # used by the game loop to cap the frame rate
def train_genome(genomes, config):
    """Evaluate one generation of NEAT genomes by letting each one play Pong.

    Every genome gets its own feed-forward network, ball and paddle. All
    agents play at the same time in the shared window. Each frame:

    * fitness increases by 1 when the ball's bottom edge is exactly on the
      paddle's top edge and the ball overlaps the paddle horizontally;
    * fitness decreases by 1 while the ball's top edge is at or below the
      paddle's bottom edge;
    * agents whose paddle reports ``alive == False`` are dropped.

    The generation ends once no paddles remain.

    Args:
        genomes: iterable of ``(genome_id, genome)`` pairs, as supplied by
            ``neat.Population.run``.
        config: the NEAT config used to build each genome's network.
    """
    # Parallel lists: index i in each list belongs to the same agent.
    nets = []
    ge = []
    paddles = []
    balls = []
    for _genome_id, genome in genomes:
        genome.fitness = 0  # start with fitness level of 0
        nets.append(neat.nn.FeedForwardNetwork.create(genome, config))
        balls.append(BallObj.Ball(size, size[0] / 2, size[1] / 2))
        paddles.append(PaddleObj.Paddle(size, size[0] / 2, size[1]))
        ge.append(genome)

    while paddles:
        screen.fill(bg_color)

        # Handle input
        for event in pg.event.get():
            if event.type == pg.QUIT:
                pg.quit()
                sys.exit()

        # Step every agent and remember which ones are still alive. The lists
        # are pruned only after the pass. Removing items while enumerating
        # them would skip the agent right after each removed one.
        survivors = []
        for i, (paddle, ball, net, genome) in enumerate(zip(paddles, balls, nets, ge)):
            # NOTE(review): exact equality only rewards a hit if the ball lands
            # precisely on the paddle's top edge. Confirm that Ball.update
            # snaps it there on contact.
            if ball.body.bottom == paddle.body.top and (
                    ball.body.right >= paddle.body.left
                    and ball.body.left <= paddle.body.right):
                genome.fitness += 1
            if ball.body.top >= paddle.body.bottom:
                genome.fitness -= 1

            output = net.activate(
                (paddle.body.x, ball.body.x, ball.body.y, ball.speed_x))
            # The three checks are independent, so several movement() calls
            # can happen in one frame. Their combined effect depends on
            # Paddle.movement.
            if output[0] > 0.5:
                paddle.movement(1)
            if output[1] > 0.5:
                paddle.movement(-1)
            if output[2] > 0.5:
                paddle.movement(0)

            paddle.update()
            ball.update(paddle)
            paddle.draw(screen)
            ball.draw(screen)

            if paddle.alive:
                survivors.append(i)

        # Rebuild by position (not list.remove, which matches by equality).
        paddles = [paddles[i] for i in survivors]
        balls = [balls[i] for i in survivors]
        nets = [nets[i] for i in survivors]
        ge = [ge[i] for i in survivors]

        # Update the window
        pg.display.flip()
        clock.tick(60)
def run(config_file, generations=10):
    """Evolve Pong-playing networks with NEAT and report the best genome.

    Args:
        config_file: path to the neat-python configuration file.
        generations: maximum number of generations passed to
            ``Population.run``. The default of 10 matches the previously
            hard-coded value.

    Returns:
        The best genome returned by ``Population.run``. It is also printed
        to stdout.
    """
    config = neat.config.Config(neat.DefaultGenome, neat.DefaultReproduction,
                                neat.DefaultSpeciesSet, neat.DefaultStagnation,
                                config_file)
    # Create the population, which is the top-level object for a NEAT run.
    p = neat.Population(config)
    # Add a stdout reporter to show progress in the terminal.
    p.add_reporter(neat.StdOutReporter(True))
    stats = neat.StatisticsReporter()
    p.add_reporter(stats)
    # To save a checkpoint every 5 generations, enable:
    # p.add_reporter(neat.Checkpointer(5))
    winner = p.run(train_genome, generations)
    # show final stats
    print('\nBest genome:\n{!s}'.format(winner))
    return winner
if __name__ == '__main__':
    # Locate the NEAT config next to this file, independent of the current
    # working directory. abspath() is needed because on Python < 3.9 the
    # main script's __file__ may be relative (e.g. 'main.py'). dirname() then
    # returns '', and the old '' + '/neat_config' pointed at the filesystem root.
    local_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(local_dir, 'neat_config')
    run(config_path)