AlphaGo Zero taught itself Go: you can build one with PyTorch too | with code (7)
The second stage is training (Training): take the freshly generated self-play data and use it to improve the current neural network.
from copy import deepcopy

from torch.utils.data import DataLoader


def train():
    ## AlphaLoss, SelfPlayDataset, load_player, create_optimizer, collate_fn,
    ## evaluate, fetch_new_games and the constants (BATCH_SIZE, TRAIN_STEPS,
    ## REFRESH_TICK, ...) are defined in earlier parts of this series
    criterion = AlphaLoss()
    dataset = SelfPlayDataset()
    player, checkpoint = load_player(current_time, loaded_version)
    optimizer = create_optimizer(player, lr,
                                 param=checkpoint["optimizer"])
    best_player = deepcopy(player)
    dataloader = DataLoader(dataset, collate_fn=collate_fn,
                            batch_size=BATCH_SIZE, shuffle=True)
    total_ite = 0
    last_id = 0

    while True:
        for batch_idx, (state, move, winner) in enumerate(dataloader):

            ## Every TRAIN_STEPS batches, evaluate a copy of the current
            ## network against the best one so far; promote it if it wins
            if total_ite % TRAIN_STEPS == 0:
                pending_player = deepcopy(player)
                result = evaluate(pending_player, best_player)

                if result:
                    best_player = pending_player

            example = {
                "state": state,
                "winner": winner,
                "move": move
            }

            ## One gradient step on the current network
            optimizer.zero_grad()
            winner, probas = player.predict(example["state"])

            loss = criterion(winner, example["winner"],
                             probas, example["move"])
            loss.backward()
            optimizer.step()
            total_ite += 1

            ## Periodically pull freshly generated self-play games
            if total_ite % REFRESH_TICK == 0:
                last_id = fetch_new_games(collection, dataset, last_id)
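The AlphaLoss criterion is not shown in this excerpt. As a point of reference, here is a minimal sketch of what it could look like, following the combined objective from the AlphaGo Zero paper, l = (z - v)^2 - pi^T log(p) (the paper's L2 term c * ||theta||^2 is typically delegated to the optimizer's weight_decay). The class body below is an assumption, not the series' exact code; only the call signature is taken from train() above.

import torch
import torch.nn as nn


class AlphaLoss(nn.Module):
    ## Sketch (assumption): combines the value and policy losses of
    ## AlphaGo Zero into one scalar, matching the call
    ## criterion(pred_winner, winner, pred_probas, probas) in train()
    def forward(self, pred_winner, winner, pred_probas, probas):
        ## Value head: squared error between the predicted outcome v
        ## and the actual game result z in {-1, +1}
        value_error = (winner - pred_winner.view(-1)) ** 2
        ## Policy head: cross-entropy between the MCTS visit-count
        ## distribution pi and the network's move probabilities p
        policy_error = -torch.sum(probas * torch.log(pred_probas + 1e-8),
                                  dim=1)
        return (value_error + policy_error).mean()

With this convention, pred_probas must be actual probabilities (e.g. a softmax output); if the network's policy head returns log-probabilities instead, the torch.log call should be dropped.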