使用强化学习AlphaZero算法训练五子棋AI
案例目标
通过本案例的学习和课后作业的练习:
- 了解强化学习AlphaZero算法
- 利用AlphaZero算法进行一次五子棋AI训练
你也可以将本案例相关的 ipynb 学习笔记分享到 AI Gallery Notebook 版块获得成长值,分享方法请查看此文档。
案例内容介绍
AlphaZero是一种强化学习算法,近期利用AlphaZero训练出的AI以绝对的优势战胜了多名围棋以及国际象棋冠军。AlphaZero创新点在于,它能够在不依赖于外部先验知识即专家知识、仅仅了解游戏规则的情况下,在棋盘类游戏中获得超越人类的表现。
本次案例将详细的介绍AlphaZero算法核心原理,包括神经网络构建、MCTS搜索、自博弈训练,以代码的形式加深算法理解,算法详情亦可见论文《Mastering the game of Go without human knowledge》。同时本案例提供五子棋强化学习环境,利用AlphaZero进行一次五子棋训练。最后可视化五子棋AI自博弈对局。
由于在标准棋盘下训练一个强力的五子棋AI需要大量的训练时间和资源,本案例将棋盘缩小到了6x6x4,且在运行过程中简化了训练过程,减少了自博弈次数和搜索次数。如果想要完整地训练一个五子棋AlphaZero AI,可在AI Gallery中订阅《Gomoku-训练五子棋小游戏》算法并在ModelArts中进行训练。
源码参考GitHub开源项目AlphaZero_Gomoku
注意事项
-
本案例运行环境为 Pytorch-1.0.0,且需使用 GPU 运行,请查看《ModelArts JupyterLab 硬件规格使用指南》了解切换硬件规格的方法;
-
如果您是第一次使用 JupyterLab,请查看《ModelArts JupyterLab使用指导》了解使用方法;
-
如果您在使用 JupyterLab 过程中碰到报错,请参考《ModelArts JupyterLab常见问题解决办法》尝试解决问题。
-
建议逐步运行
!pip install gym
第 1 步:导入相关的库
import os
import copy
import random
import time
from operator import itemgetter
from collections import defaultdict, deque

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import gym
from gym.spaces import Box, Discrete
import matplotlib.pyplot as plt
from IPython import display
2.进行训练参数配置
为简化训练过程,涉及到影响训练时长的参数都设置的较小,且棋盘大小也减小为6x6,棋子连线降低为4。
# Training hyper-parameters. Kept deliberately small so the demo finishes
# quickly; for a serious agent, raise n_playout / game_batch_num / board size.
board_width = 6 # board width
board_height = 6 # board height
n_in_row = 4 # number of stones in a row needed to win
c_puct = 5 # PUCT exploration constant (higher = more exploration)
n_playout = 100 # MCTS simulations per move
learn_rate = 0.002 # base learning rate
lr_multiplier = 1.0 # adaptive lr multiplier, adjusted from the KL divergence
temperature = 1.0 # softmax temperature for move selection
noise_eps = 0.75 # weight of the policy vs. Dirichlet noise in self-play
dirichlet_alpha = 0.3 # Dirichlet noise concentration parameter
buffer_size = 5000 # replay buffer capacity
train_batch_size = 128 # mini-batch size for each update
update_epochs = 5 # gradient steps per policy update
kl_coeff = 0.02 # target KL divergence used for early stop / lr tuning
checkpoint_freq = 20 # evaluate & save the model every this many batches
mcts_infer = 200 # playout budget of the pure-MCTS evaluation opponent
restore_model = None # path of a pretrained model to load (None = from scratch)
game_batch_num=40 # number of self-play training iterations
model_path="." # directory where models are saved
3.构建环境
五子棋的环境是按照标准gym环境构建的,棋盘宽x高,先在横线、直线或斜对角线上形成n子连线的玩家获胜。
状态空间为[4,棋盘宽,棋盘高],四个维度分别为当前视角下的位置,对手位置,上次位置以及轮次。
class GomokuEnv(gym.Env):
    """Gomoku environment following the gym interface.

    The observation is a [4, board_width, board_height] tensor seen from the
    current player's perspective: own stones, opponent stones, last move,
    and a colour-to-play indicator plane.
    """

    def __init__(self, start_player=0):
        self.start_player = start_player  # index into self.players of who moves first
        self.action_space = Discrete((board_width * board_height))
        self.observation_space = Box(0, 1, shape=(4, board_width, board_height))
        self.reward = 0
        self.info = {}
        self.players = [1, 2]  # player1 and player2

    def step(self, action):
        """Place the current player's stone at `action`.

        Returns (obs, reward, done, info); reward is +1/-1 only on the
        terminal step, from the mover's perspective.
        """
        self.states[action] = self.current_player
        if action in self.availables:
            self.availables.remove(action)
        self.last_move = action
        done, winner = self.game_end()
        reward = 0
        if done:
            if winner == self.current_player:
                reward = 1
            else:
                reward = -1
        # hand the turn to the other player
        self.current_player = (
            self.players[0] if self.current_player == self.players[1]
            else self.players[1]
        )
        # update state
        obs = self.current_state()
        return obs, reward, done, self.info

    def reset(self):
        """Reset the board and return the initial observation."""
        if board_width < n_in_row or board_height < n_in_row:
            raise Exception('board width and height can not be '
                            'less than {}'.format(n_in_row))
        self.current_player = self.players[self.start_player]  # start player
        # keep available moves in a list
        self.availables = list(range(board_width * board_height))
        self.states = {}  # move index -> player id
        self.last_move = -1
        return self.current_state()

    def render(self, mode='human', start_player=0):
        """Print the board to stdout ('B' = player1, 'W' = player2)."""
        width = board_width
        height = board_height
        p1, p2 = self.players
        print()
        for x in range(width):
            print("{0:8}".format(x), end='')
        print('\r\n')
        for i in range(height - 1, -1, -1):
            print("{0:4d}".format(i), end='')
            for j in range(width):
                loc = i * width + j
                p = self.states.get(loc, -1)
                if p == p1:
                    print('B'.center(8), end='')
                elif p == p2:
                    print('W'.center(8), end='')
                else:
                    print('_'.center(8), end='')
            print('\r\n\r\n')

    def has_a_winner(self):
        """Return (True, player) if some player has n_in_row in a line, else (False, -1)."""
        states = self.states
        moved = list(set(range(board_width * board_height)) - set(self.availables))
        if len(moved) < n_in_row * 2 - 1:
            # too few stones on the board for any winning line to exist
            return False, -1
        for m in moved:
            h = m // board_width
            w = m % board_width
            player = states[m]
            # horizontal line starting at m
            if (w in range(board_width - n_in_row + 1) and
                    len(set(states.get(i, -1) for i in range(m, m + n_in_row))) == 1):
                return True, player
            # vertical line starting at m
            if (h in range(board_height - n_in_row + 1) and
                    len(set(states.get(i, -1)
                            for i in range(m, m + n_in_row * board_width, board_width))) == 1):
                return True, player
            # diagonal (up-right) line starting at m
            if (w in range(board_width - n_in_row + 1) and h in range(board_height - n_in_row + 1) and
                    len(set(states.get(i, -1)
                            for i in range(m, m + n_in_row * (board_width + 1), board_width + 1))) == 1):
                return True, player
            # anti-diagonal (up-left) line starting at m
            if (w in range(n_in_row - 1, board_width) and h in range(board_height - n_in_row + 1) and
                    len(set(states.get(i, -1)
                            for i in range(m, m + n_in_row * (board_width - 1), board_width - 1))) == 1):
                return True, player
        return False, -1

    def game_end(self):
        """Check whether the game is ended or not"""
        win, winner = self.has_a_winner()
        if win:
            return True, winner
        elif not len(self.availables):
            # board full without a winner: draw
            return True, -1
        return False, -1

    def current_state(self):
        """return the board state from the perspective of the current player.
        state shape: 4*width*height
        """
        square_state = np.zeros((4, board_width, board_height))
        if self.states:
            moves, players = np.array(list(zip(*self.states.items())))
            move_curr = moves[players == self.current_player]
            move_oppo = moves[players != self.current_player]
            square_state[0][move_curr // board_width,
                            move_curr % board_height] = 1.0
            square_state[1][move_oppo // board_width,
                            move_oppo % board_height] = 1.0
            # indicate the last move location
            square_state[2][self.last_move // board_width,
                            self.last_move % board_height] = 1.0
        if len(self.states) % 2 == 0:
            square_state[3][:, :] = 1.0  # indicate the colour to play
        # flip rows so the board is oriented consistently for the net
        return square_state[:, ::-1, :]

    def start_play(self, player1, player2, start_player=0):
        """start a game between two players"""
        if start_player not in (0, 1):
            raise Exception('start_player should be either 0 (player1 first) '
                            'or 1 (player2 first)')
        self.reset()
        p1, p2 = self.players
        player1.set_player_ind(p1)
        player2.set_player_ind(p2)
        players = {p1: player1, p2: player2}
        while True:
            player_in_turn = players[self.current_player]
            move = player_in_turn.get_action(self)
            self.step(move)
            end, winner = self.game_end()
            if end:
                return winner

    def start_self_play(self, player):
        """ start a self-play game using a MCTS player, reuse the search tree,
        and store the self-play data: (state, mcts_probs, z) for training
        """
        self.reset()
        states, mcts_probs, current_players = [], [], []
        while True:
            move, move_probs = player.get_action(self, return_prob=1)
            # store the data
            states.append(self.current_state())
            mcts_probs.append(move_probs)
            current_players.append(self.current_player)
            # perform a move
            self.step(move)
            end, winner = self.game_end()
            if end:
                # winner from the perspective of the current player of each state
                winners_z = np.zeros(len(current_players))
                if winner != -1:
                    winners_z[np.array(current_players) == winner] = 1.0
                    winners_z[np.array(current_players) != winner] = -1.0
                # reset MCTS root node
                player.reset_player()
                return winner, zip(states, mcts_probs, winners_z)

    def location_to_move(self, location):
        """Convert a (row, col) location into a flat move index; -1 if invalid."""
        if len(location) != 2:
            return -1
        h = location[0]
        w = location[1]
        move = h * board_width + w
        # fixed: upper bound was board_width * board_width, which is wrong
        # for non-square boards
        if move not in range(board_width * board_height):
            return -1
        return move

    def move_to_location(self, move):
        """
        3*3 board's moves like:
        6 7 8
        3 4 5
        0 1 2
        and move 5's location is (1,2)
        """
        h = move // board_width
        w = move % board_width
        return [h, w]
4.构建神经网络
网络结构较为简单,backbone部分是三层卷积神经网络,提取特征后分为两个分支。一个是价值分支,输出当前棋面价值。另一个是决策分支,输出神经网络计算得到的动作对应概率。
class Net(nn.Module):
    """policy-value network module.

    Three shared conv layers, then two heads: a policy head emitting
    log-probabilities over all board positions, and a value head emitting
    a scalar in [-1, 1].
    """

    def __init__(self):
        super(Net, self).__init__()
        # common layers
        self.conv1 = nn.Conv2d(4, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # action policy layers
        self.act_conv1 = nn.Conv2d(128, 4, kernel_size=1)
        self.act_fc1 = nn.Linear(4 * board_width * board_height,
                                 board_width * board_height)
        # state value layers
        self.val_conv1 = nn.Conv2d(128, 2, kernel_size=1)
        self.val_fc1 = nn.Linear(2 * board_width * board_height, 64)
        self.val_fc2 = nn.Linear(64, 1)

    def forward(self, state_input):
        """Return (log action probabilities, state value) for a batch of states."""
        # common layers
        x = F.relu(self.conv1(state_input))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # action policy layers
        x_act = F.relu(self.act_conv1(x))
        x_act = x_act.view(-1, 4 * board_width * board_height)
        # dim=1: softmax over actions (explicit dim avoids the deprecated
        # implicit-dim behaviour of log_softmax)
        x_act = F.log_softmax(self.act_fc1(x_act), dim=1)
        # state value layers
        x_val = F.relu(self.val_conv1(x))
        x_val = x_val.view(-1, 2 * board_width * board_height)
        x_val = F.relu(self.val_fc1(x_val))
        x_val = torch.tanh(self.val_fc2(x_val))  # torch.tanh: F.tanh is deprecated
        return x_act, x_val
class PolicyValueNet:
    """policy-value network wrapper: inference helpers, training step, I/O."""

    def __init__(self, model_file=None):
        # train on GPU when available, else fall back to CPU
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        self.l2_const = 1e-4  # coef of l2 penalty
        # the policy value net module
        self.policy_value_net = Net().to(self.device)
        self.optimizer = optim.Adam(self.policy_value_net.parameters(),
                                    weight_decay=self.l2_const)
        if model_file:
            net_params = torch.load(model_file)
            self.policy_value_net.load_state_dict(net_params)

    def policy_value(self, state_batch):
        """
        input: a batch of states
        output: a batch of action probabilities and state values
        """
        state_batch = Variable(torch.FloatTensor(state_batch).to(self.device))
        log_act_probs, value = self.policy_value_net(state_batch)
        act_probs = np.exp(log_act_probs.data.cpu().numpy())
        return act_probs, value.data.cpu().numpy()

    def policy_value_fn(self, board):
        """
        input: board
        output: a list of (action, probability) tuples for each available
        action and the score of the board state
        """
        legal_positions = board.availables
        current_state = np.ascontiguousarray(board.current_state().reshape(
            -1, 4, board_width, board_height))
        log_act_probs, value = self.policy_value_net(
            Variable(torch.from_numpy(current_state)).to(self.device).float())
        act_probs = np.exp(log_act_probs.data.cpu().numpy().flatten())
        # keep only the probabilities of the legal moves
        act_probs = zip(legal_positions, act_probs[legal_positions])
        value = value.data[0][0]
        return act_probs, value

    def train_step(self, state_batch, mcts_probs, winner_batch, lr):
        """perform a training step; returns (loss, entropy) as floats."""
        # wrap in Variable
        state_batch = Variable(torch.FloatTensor(state_batch).to(self.device))
        mcts_probs = Variable(torch.FloatTensor(mcts_probs).to(self.device))
        winner_batch = Variable(torch.FloatTensor(winner_batch).to(self.device))
        # zero the parameter gradients
        self.optimizer.zero_grad()
        # set learning rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        # forward
        log_act_probs, value = self.policy_value_net(state_batch)
        # define the loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
        # Note: the L2 penalty is incorporated in optimizer
        value_loss = F.mse_loss(value.view(-1), winner_batch)
        policy_loss = -torch.mean(torch.sum(mcts_probs * log_act_probs, 1))
        loss = value_loss + policy_loss
        # backward and optimize
        loss.backward()
        self.optimizer.step()
        # calc policy entropy, for monitoring only
        entropy = -torch.mean(
            torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)
        )
        return loss.item(), entropy.item()

    def get_policy_param(self):
        """Return the network's state_dict."""
        net_params = self.policy_value_net.state_dict()
        return net_params

    def save_model(self, model_file):
        """ save model params to file """
        net_params = self.get_policy_param()  # get model params
        torch.save(net_params, model_file)
5.实现MCTS
AlphaZero利用MCTS来自博弈生成棋局,MCTS搜索原理简述如下:
-
每次模拟通过选择具有最大行动价值Q的边加上取决于所存储的先验概率P和该边的访问计数N(每次访问都被增加一次)的上限置信区间U来遍历树。
-
展开叶子节点,通过神经网络来评估局面s;向量P的值存储在叶子结点扩展的边上。
-
更新行动价值Q等于在该行动下的子树中的所有评估值V的均值。
-
一旦MCTS搜索完成,返回局面s下的落子概率π。
def softmax(x):
    """Numerically stable softmax over a 1-D numpy array."""
    shifted = np.exp(x - np.max(x))  # subtract max to avoid overflow
    return shifted / np.sum(shifted)
def rollout_policy_fn(board):
    """a coarse, fast version of policy_fn used in the rollout phase.

    Returns an iterator of (action, probability) pairs with uniformly random
    (unnormalized) probabilities over the board's available moves.
    """
    # rollout randomly
    action_probs = np.random.rand(len(board.availables))
    return zip(board.availables, action_probs)
def policy_value_fn(board):
    """a function that takes in a state and outputs a list of (action, probability)
    tuples and a score for the state"""
    # return uniform probabilities and 0 score for pure MCTS
    action_probs = np.ones(len(board.availables)) / len(board.availables)
    return zip(board.availables, action_probs), 0
class TreeNode:
    """A node in the MCTS tree.
    Each node keeps track of its own value Q, prior probability P, and
    its visit-count-adjusted prior score u.
    """

    def __init__(self, parent, prior_p):
        self._parent = parent
        self._children = {}  # a map from action to TreeNode
        self._n_visits = 0
        self._Q = 0
        self._u = 0
        self._P = prior_p

    def expand(self, action_priors):
        """Expand tree by creating new children.
        action_priors: a list of tuples of actions and their prior probability
        according to the policy function.
        """
        for action, prob in action_priors:
            if action not in self._children:
                self._children[action] = TreeNode(self, prob)

    def select(self, c_puct):
        """Select action among children that gives maximum action value Q
        plus bonus u(P).
        Return: A tuple of (action, next_node)
        """
        return max(self._children.items(),
                   key=lambda act_node: act_node[1].get_value(c_puct))

    def update(self, leaf_value):
        """Update node values from leaf evaluation.
        leaf_value: the value of subtree evaluation from the current player's
        perspective.
        """
        # Count visit.
        self._n_visits += 1
        # Update Q, a running average of values for all visits.
        self._Q += 1.0 * (leaf_value - self._Q) / self._n_visits

    def update_recursive(self, leaf_value):
        """Like a call to update(), but applied recursively for all ancestors.
        """
        # If it is not root, this node's parent should be updated first.
        # Sign flips at each level: a good result for one player is bad
        # for the other.
        if self._parent:
            self._parent.update_recursive(-leaf_value)
        self.update(leaf_value)

    def get_value(self, c_puct):
        """Calculate and return the value for this node.
        It is a combination of leaf evaluations Q, and this node's prior
        adjusted for its visit count, u.
        c_puct: a number in (0, inf) controlling the relative impact of
        value Q, and prior probability P, on this node's score.
        """
        self._u = (c_puct * self._P *
                   np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
        return self._Q + self._u

    def is_leaf(self):
        """Check if leaf node (i.e. no nodes below this have been expanded)."""
        return self._children == {}

    def is_root(self):
        return self._parent is None
class MCTS:
    """An implementation of Monte Carlo Tree Search."""

    def __init__(self, policy_value_fn, c_puct=5):
        """
        policy_value_fn: a function that takes in a board state and outputs
            a list of (action, probability) tuples and also a score in [-1, 1]
            (i.e. the expected value of the end game score from the current
            player's perspective) for the current player.
        c_puct: a number in (0, inf) that controls how quickly exploration
            converges to the maximum-value policy. A higher value means
            relying on the prior more.
        """
        self._root = TreeNode(None, 1.0)
        self._policy = policy_value_fn
        self._c_puct = c_puct

    def _playout(self, state):
        """Run a single playout from the root to the leaf, getting a value at
        the leaf and propagating it back through its parents.
        State is modified in-place, so a copy must be provided.
        """
        node = self._root
        while True:
            if node.is_leaf():
                break
            # Greedily select next move.
            action, node = node.select(self._c_puct)
            state.step(action)
        # Evaluate the leaf using a network which outputs a list of
        # (action, probability) tuples p and also a score v in [-1, 1]
        # for the current player.
        action_probs, leaf_value = self._policy(state)
        # Check for end of game.
        end, winner = state.game_end()
        if not end:
            node.expand(action_probs)
        else:
            # for end state, return the true leaf_value
            if winner == -1:  # tie
                leaf_value = 0.0
            else:
                leaf_value = (
                    1.0 if winner == state.current_player else -1.0
                )
        # Update value and visit count of nodes in this traversal.
        node.update_recursive(-leaf_value)

    def _playout_p(self, state):
        """Run a single playout from the root to the leaf, getting a value at
        the leaf and propagating it back through its parents.
        State is modified in-place, so a copy must be provided.
        Used by pure MCTS: the leaf is evaluated by a random rollout.
        """
        node = self._root
        while True:
            if node.is_leaf():
                break
            # Greedily select next move.
            action, node = node.select(self._c_puct)
            state.step(action)
        # fixed: was a garbled `action_probs_ = ...`; the policy returns
        # (action_probs, score) and only the probs are needed here
        action_probs, _ = self._policy(state)
        # Check for end of game
        end, winner = state.game_end()
        if not end:
            node.expand(action_probs)
        # Evaluate the leaf node by random rollout
        leaf_value = self._evaluate_rollout(state)
        # Update value and visit count of nodes in this traversal.
        node.update_recursive(-leaf_value)

    def _evaluate_rollout(self, env, limit=1000):
        """Use the rollout policy to play until the end of the game,
        returning +1 if the current player wins, -1 if the opponent wins,
        and 0 if it is a tie.
        """
        player = env.current_player
        for i in range(limit):
            end, winner = env.game_end()
            if end:
                break
            action_probs = rollout_policy_fn(env)
            max_action = max(action_probs, key=itemgetter(1))[0]
            env.step(max_action)
        else:
            # If no break from the loop, issue a warning.
            print("WARNING: rollout reached move limit")
        if winner == -1:  # tie
            return 0
        else:
            return 1 if winner == player else -1

    def get_move_probs(self, state, temp=1e-3):
        """Run all playouts sequentially and return the available actions and
        their corresponding probabilities.
        state: the current game state
        temp: temperature parameter in (0, 1] controls the level of exploration
        """
        for n in range(n_playout):
            state_copy = copy.deepcopy(state)
            self._playout(state_copy)
        # calc the move probabilities based on visit counts at the root node
        act_visits = [(act, node._n_visits)
                      for act, node in self._root._children.items()]
        acts, visits = zip(*act_visits)
        act_probs = softmax(1.0 / temp * np.log(np.array(visits) + 1e-10))
        return acts, act_probs

    def get_move(self, state):
        """Runs all playouts sequentially and returns the most visited action.
        state: the current game state
        Return: the selected action
        """
        for n in range(n_playout):
            state_copy = copy.deepcopy(state)
            self._playout_p(state_copy)
        return max(self._root._children.items(),
                   key=lambda act_node: act_node[1]._n_visits)[0]

    def update_with_move(self, last_move):
        """Step forward in the tree, keeping everything we already know
        about the subtree.
        """
        if last_move in self._root._children:
            self._root = self._root._children[last_move]
            self._root._parent = None
        else:
            self._root = TreeNode(None, 1.0)

    def __str__(self):
        return "MCTS"
6.实现自博弈过程
实现自博弈训练,此处博弈双方分别为基于MCTS的神经网络和纯MCTS,对弈过程中,前者基于神经网络和MCTS获取最优下子策略,而后者则仅根据MCTS搜索下子策略。保存对局数据
class MCTS_Pure:
    """AI player based on pure MCTS with random rollouts (no neural net)."""

    def __init__(self):
        self.mcts = MCTS(policy_value_fn, c_puct)

    def set_player_ind(self, p):
        # assign this player's id (1 or 2) for the current game
        self.player = p

    def reset_player(self):
        # discard the whole search tree
        self.mcts.update_with_move(-1)

    def get_action(self, board):
        sensible_moves = board.availables
        if len(sensible_moves) > 0:
            move = self.mcts.get_move(board)
            self.mcts.update_with_move(-1)
            return move
        else:
            print("WARNING: the board is full")

    def __str__(self):
        return "MCTS {}".format(self.player)
class MCTSPlayer(MCTS_Pure):
    """AI player based on MCTS guided by the policy-value network."""

    def __init__(self, policy_value_function, is_selfplay=0):
        # skip MCTS_Pure.__init__ on purpose: it would build a throwaway
        # pure-MCTS tree that we immediately replace below
        super(MCTS_Pure, self).__init__()
        self.mcts = MCTS(policy_value_function, c_puct)
        self._is_selfplay = is_selfplay

    def get_action(self, env, return_prob=0):
        sensible_moves = env.availables
        # the pi vector returned by MCTS as in the alphaGo Zero paper
        # (board_height instead of a second board_width: correct for
        # non-square boards)
        move_probs = np.zeros(board_width * board_height)
        if len(sensible_moves) > 0:
            acts, probs = self.mcts.get_move_probs(env, temperature)
            move_probs[list(acts)] = probs
            if self._is_selfplay:
                # add Dirichlet Noise for exploration (needed for
                # self-play training)
                move = np.random.choice(
                    acts,
                    p=noise_eps * probs + (1 - noise_eps) * np.random.dirichlet(
                        dirichlet_alpha * np.ones(len(probs))))
                # update the root node and reuse the search tree
                self.mcts.update_with_move(move)
            else:
                # with the default temp=1e-3, it is almost equivalent
                # to choosing the move with the highest prob
                move = np.random.choice(acts, p=probs)
                # reset the root node
                self.mcts.update_with_move(-1)
            if return_prob:
                return move, move_probs
            else:
                return move
        else:
            print("WARNING: the board is full")
7.训练主函数
训练过程包括自我对局,数据生成,模型更新和保存
class TrainPipeline:
    """Self-play training loop: collect games, update the net, evaluate, save."""

    def __init__(self):
        # params of the board and the game
        self.env = GomokuEnv()
        # training params
        self.data_buffer = deque(maxlen=buffer_size)
        self.play_batch_size = 1
        self.best_win_ratio = 0.0
        # start training from an initial policy-value net
        self.policy_value_net = PolicyValueNet(model_file=restore_model)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      is_selfplay=1)
        self.mcts_infer = mcts_infer
        self.lr_multiplier = lr_multiplier

    def get_equi_data(self, play_data):
        """augment the data set by rotation and flipping
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        extend_data = []
        for state, mcts_porb, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(np.flipud(
                    mcts_porb.reshape(board_height, board_width)), i)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """collect self-play data for training"""
        for i in range(n_games):
            winner, play_data = self.env.start_self_play(self.mcts_player)
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)
            # augment the data
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """update the policy-value net"""
        mini_batch = random.sample(self.data_buffer, train_batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(update_epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch,
                mcts_probs_batch,
                winner_batch,
                learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(np.sum(old_probs * (
                    np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                axis=1)
            )
            if kl > kl_coeff * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > kl_coeff * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < kl_coeff / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """
        Evaluate the trained policy by playing against the pure MCTS player
        Note: this is only for monitoring the progress of training
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn)
        pure_mcts_player = MCTS_Pure()
        win_cnt = defaultdict(int)
        for i in range(n_games):
            # alternate who starts, to remove first-move advantage bias
            winner = self.env.start_play(current_mcts_player,
                                         pure_mcts_player,
                                         start_player=i % 2)
            win_cnt[winner] += 1
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.mcts_infer, win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self):
        """run the training pipeline"""
        win_num = 0
        try:
            for i_step in range(game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                    i_step + 1, self.episode_len))
                if len(self.data_buffer) > train_batch_size:
                    loss, entropy = self.policy_update()
                # check the performance of the current model,
                # and save the model params
                if (i_step + 1) % checkpoint_freq == 0:
                    print("current self-play batch: {}".format(i_step + 1))
                    win_ratio = self.policy_evaluate()
                    self.policy_value_net.save_model(os.path.join(model_path, "newest_model.pt"))
                    if win_ratio > self.best_win_ratio:
                        win_num += 1
                        self.best_win_ratio = win_ratio
                        # update the best_policy
                        self.policy_value_net.save_model(os.path.join(model_path, "best_model.pt"))
                        # once the opponent is always beaten, make it stronger
                        if self.best_win_ratio == 1.0 and self.mcts_infer < 5000:
                            self.mcts_infer += 1000
                            self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')
        return win_num
8.开始自博弈训练,并保存模型
GPU训练耗时约4分钟
# 8. Run the self-play training pipeline and report wall-clock time
# (takes roughly 4 minutes on a GPU with the reduced settings above).
start_t = time.time()
training_pipeline = TrainPipeline()
training_pipeline.run()
print("time cost is {}".format(time.time()-start_t))
batch i:1episode_len:13
batch i:2episode_len:16
batch i:3episode_len:13
batch i:4episode_len:11
batch i:5episode_len:11
batch i:6episode_len:15
batch i:7episode_len:13
batch i:8episode_len:13
batch i:9episode_len:19
batch i:10episode_len:14
batch i:11episode_len:11
batch i:12episode_len:19
batch i:13episode_len:15
batch i:14episode_len:22
batch i:15episode_len:8
batch i:16episode_len:16
batch i:17episode_len:15
batch i:18episode_len:13
batch i:19episode_len:15
batch i:20episode_len:17
current self-play batch: 20
num_playouts:200win: 2lose: 8tie:0
batch i:21episode_len:12
batch i:22episode_len:14
batch i:23episode_len:12
batch i:24episode_len:11
batch i:25episode_len:15
batch i:26episode_len:7
batch i:27episode_len:13
batch i:28episode_len:10
batch i:29episode_len:10
batch i:30episode_len:13
batch i:31episode_len:17
batch i:32episode_len:12
batch i:33episode_len:11
batch i:34episode_len:11
batch i:35episode_len:9
batch i:36episode_len:14
batch i:37episode_len:17
batch i:38episode_len:12
batch i:39episode_len:13
batch i:40episode_len:18
current self-play batch: 40
num_playouts:200win: 3lose: 7tie:0
time cost is 250.64277577400208
9.AI对战
(等待第8步运行结束后再运行此步)
加载模型,进行人机对战
# 定义当前玩家
class CurPlayer:
    """Tracks whose turn it is during the visualized match (0 or 1)."""
    player_id = 0
# 可视化部分
class Game(object):
    """Matplotlib-based visualization of a match between two players."""

    def __init__(self, board):
        self.board = board
        self.cell_size = board_width - 1
        self.chess_size = 50 * self.cell_size
        self.whitex = []
        self.whitey = []
        self.blackx = []
        self.blacky = []
        # board background colour
        self.color = "#e4ce9f"
        self.colors = [[self.color] * self.cell_size for _ in range(self.cell_size)]

    def graphic(self, board, player1, player2):
        """Draw the board and show game info"""
        plt_fig, ax = plt.subplots(facecolor=self.color)
        ax.set_facecolor(self.color)
        # draw the board grid as a table
        # mytable = ax.table(cellColours=self.colors, loc='center')
        mytable = plt.table(cellColours=self.colors,
                            colWidths=[1 / board_width] * self.cell_size,
                            loc='center'
                            )
        ax.set_aspect('equal')
        # grid cell size
        cell_height = 1 / board_width
        for pos, cell in mytable.get_celld().items():
            cell.set_height(cell_height)
        mytable.auto_set_font_size(False)
        mytable.set_fontsize(self.cell_size)
        ax.set_xlim([1, board_width * 2 + 1])
        ax.set_ylim([board_height * 2 + 1, 1])
        plt.title("Gomoku")
        plt.axis('off')
        cur_player = CurPlayer()
        while True:
            try:
                if cur_player.player_id == 1:
                    move = player1.get_action(self.board)
                    self.board.step(move)
                    x, y = self.board.move_to_location(move)
                    plt.scatter((y + 1) * 2, (x + 1) * 2, s=self.chess_size, c='white')
                    cur_player.player_id = 0
                elif cur_player.player_id == 0:
                    move = player2.get_action(self.board)
                    self.board.step(move)
                    x, y = self.board.move_to_location(move)
                    plt.scatter((y + 1) * 2, (x + 1) * 2, s=self.chess_size, c='black')
                    cur_player.player_id = 1
                end, winner = self.board.game_end()
                if end:
                    if winner != -1:
                        ax.text(x=board_width, y=(board_height + 1) * 2 + 0.1,
                                s="Game end. Winner is player {}".format(cur_player.player_id), fontsize=10,
                                color='red', weight='bold',
                                horizontalalignment='center')
                    else:
                        ax.text(x=board_width, y=(board_height + 1) * 2 + 0.1,
                                s="Game end. Tie Round".format(cur_player.player_id), fontsize=10, color='red',
                                weight='bold',
                                horizontalalignment='center')
                    return winner
                display.display(plt.gcf())
                display.clear_output(wait=True)
            except Exception:
                # best-effort rendering loop (was a bare `except:`): ignore
                # drawing glitches and keep the game going
                pass

    def start_play(self, player1, player2, start_player=0):
        """start a game between two players"""
        if start_player not in (0, 1):
            raise Exception('start_player should be either 0 (player1 first) '
                            'or 1 (player2 first)')
        self.board.reset()
        p1, p2 = self.board.players
        player1.set_player_ind(p1)
        player2.set_player_ind(p2)
        self.graphic(self.board, player1, player2)
# 初始化棋盘
# Initialize the board environment and its visualization
board = GomokuEnv()
game = Game(board)
# Load the best trained model
best_policy = PolicyValueNet(model_file="best_model.pt")
# Two AI players sharing the same policy-value net
mcts_player = MCTSPlayer(best_policy.policy_value_fn)
# Start the match (fixed: the original call was missing commas and the
# closing parenthesis)
game.start_play(mcts_player, mcts_player, start_player=0)

至此,本案例结束,如果想要完整地训练一个五子棋AlphaZero AI,可在AI Gallery中订阅《Gomoku-训练五子棋小游戏》算法并在ModelArts中进行训练。
华为开发者空间发布
让每位开发者拥有一台云主机
- 点赞
- 收藏
- 关注作者
评论(0)