自学围棋的AlphaGoZero,你也能用PyTorch造一个|附代码实现(2)
2023-05-04 来源:飞速影视
跳跃的样子,写成代码就是:
1class BasicBlock(nn.Module):2 """ 3 Basic residual block with 2 convolutions and a skip connection 4 before the last ReLU activation. 5 """ 6 7 def __init__(self, inplanes, planes, stride=1, downsample=None): 8 super(BasicBlock, self).__init__() 910 self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3,11 stride=stride, padding=1, bias=False)12 self.bn1 = nn.BatchNorm2d(planes)1314 self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,15 stride=stride, padding=1, bias=False)16 self.bn2 = nn.BatchNorm2d(planes)171819 def forward(self, x):20 residual = x2122 out = self.conv1(x)23 out = F.relu(self.bn1(out))2425 out = self.conv2(out)26 out = self.bn2(out)2728 out = residual29 out = F.relu(out)3031 return out
然后,把它加到特征提取模型里面去:
1class Extractor(nn.Module):2 def __init__(self, inplanes, outplanes): 3 super(Extractor, self).__init__() 4 self.conv1 = nn.Conv2d(inplanes, outplanes, stride=1, 5 kernel_size=3, padding=1, bias=False) 6 self.bn1 = nn.BatchNorm2d(outplanes) 7 8 for block in range(BLOCKS): 9 setattr(self, "res{}".format(block), 10 BasicBlock(outplanes, outplanes))111213 def forward(self, x):14 x = F.relu(self.bn1(self.conv1(x)))15 for block in range(BLOCKS - 1):16 x = getattr(self, "res{}".format(block))(x)1718 feature_maps = getattr(self, "res{}".format(BLOCKS - 1))(x)19 return feature_maps
本站仅为学习交流之用,所有视频和图片均来自互联网收集而来,版权归原创者所有,本网站只提供web页面服务,并不提供资源存储,也不参与录制、上传
若本站收录的节目无意侵犯了贵司版权,请发邮件(我们会在3个工作日内删除侵权内容,谢谢。)
www.fs94.org-飞速影视 粤ICP备74369512号