如何打印Pytorch在网络中的梯度值 您所在的位置:网站首页 python怎么打印变量的值 如何打印Pytorch在网络中的梯度值

如何打印Pytorch在网络中的梯度值

2023-09-14 03:07| 来源: 网络整理| 查看: 265

一、神经网络初始化

我喜欢在网络的构造函数中进行。比如

import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Sequential( nn.Conv2d(3, 64, 3, padding = 1), nn.BatchNorm2d(64), nn.ReLU(True), nn.AvgPool2d(2, 2) ) self.conv21 = nn.Conv2d(64, 64*2, 3, padding = 1 ) self.pool2 = nn.AvgPool2d(2, 2) self.conv31 = nn.Conv2d(64*2, 10, 1) self.pool3 = nn.AvgPool2d(8, 8) self.line = nn.Linear(10,100) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.BatchNorm2d): nn.init.normal_(m.weight,1.0,0.02) nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight,.0, 0.05) nn.init.zeros_(m.bias) def forward(self, x): x = self.conv1(x) x = F.relu(self.conv21(x)) x = self.pool2(x) x = self.conv31(x) x = self.pool3(x) x = x.view(-1, 10) x = self.line(x) return x

      上面代码中的__init__函数中的for m in self.modules(): 循环语句就是对网络结构中的结点进行赋值,那么这段代码是如何进行的呢?它是从外到内依次进行的。       第一次进入循环,m的值会是整个网络,则是

Net( (conv1): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace=True) (3): AvgPool2d(kernel_size=2, stride=2, padding=0) ) (conv21): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (pool2): AvgPool2d(kernel_size=2, stride=2, padding=0) (conv31): Conv2d(128, 10, kernel_size=(1, 1), stride=(1, 1)) (pool3): AvgPool2d(kernel_size=8, stride=8, padding=0) (line): Linear(in_features=10, out_features=100, bias=True) )

这很明显不会进入任何分支,第二次进入循环,m的值是self.conv1右边的值,即为

Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace=True) (3): AvgPool2d(kernel_size=2, stride=2, padding=0) )

这也很明显不会进入任何分支,第三次进入循环,m的值是上一次self.conv1的序列内部分的第一个结构,即

Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

这会进入分支中的第一个if语句。接下来的循环m的值为

BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

这会进入elif isinstance(m, nn.BatchNorm2d)分支结构,如此这样直到退出整个self.conv1序列,然后进入接下来的self.conv21,self.pool2等结点,直到整个网络赋值完成。

二、打印w,b和梯度

我们接着使用下面的例子查看初始化的值和权重

net = Net() print(net.conv21.bias) print(net.conv21.bias.grad) print(net.conv21.weight) print(net.conv21.weight.grad)

输出结果

conv21.bias = Parameter containing: tensor([0., 0.省略中间的0., 0., 0., 0., 0.], requires_grad=True) conv21.bias.grad = None conv21.weight = Parameter containing: tensor([[[[ 0.0026, 0.0542, -0.0444],省略中间的值, [-0.0391, -0.0124, 0.0965]]]],requires_grad=True) conv21.weight.grad = None

能看到偏置项确实为0了,权重也是我们想要的结果,但是为什么所有的偏导都为None呢?那是因为我们并没有进行反向传播,采用下面的代码进行一次反向传播后打印的值就不为None了。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") net.to(device) inputs = torch.rand(4,3,32,32) labels = torch.rand(4)*10//5 criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9) inputs = inputs.to(device) labels = labels.to(device) outputs = net(inputs) loss = criterion(outputs, labels.long()) loss.backward() optimizer.step() print("conv21.bias = ",net.conv21.bias) print("conv21.bias.grad = ",net.conv21.bias.grad) print("conv21.weight = ",net.conv21.weight) print("conv21.weight.grad = ",net.conv21.weight.grad)

小TIP: 在调试代码的时候有时候想打印整个Tensor的值,这时我们可以使用下面的写法临时改变一下torch.set_printoptions(profile="full")将full改为default即是默认的省略输出

那么如何打印Sequential序列中的值呢

for i,m in enumerate(net.conv1.children()): if isinstance(m, nn.Conv2d): print("net.conv1."+str(i)+"(Conv2d).weight = ",m.weight) print("net.conv1."+str(i)+"(Conv2d).weight.grad = ",m.weight.grad) elif isinstance(m, nn.BatchNorm2d): print("net.conv1."+str(i)+"(BatchNorm2d).weight = ",m.weight) print("net.conv1."+str(i)+"(BatchNorm2d).weight.grad = ",m.weight.grad)

即循环遍历序列下的子项,使用isinstance进行判断目标对象并打印。 部分结果展示 在这里插入图片描述



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有