【Pytorch API笔记3】用torch.numel()来统计网络的参数量

如何统计网络的大小,可以试一试torch.numel()函数
torch.numel()函数,可以计算出单个tensor元素的个数

一、对单个tensor使用,求tensor元素的个数

x = torch.randn((1, 3, 5, 7))
x.numel()
torch.numel()

输出105

二、求整个网络的参数

  n_p = sum(x.numel() for x in model.parameters())  # number parameters

如下示意图,可以计算网络的参数量
一个线性层,输入维度为1,输出维度为100
这个网络有200个参数,可以用x.numel() 巧妙计算出整个网络所需要的参数量

class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__()self.linear = torch.nn.Linear(1, 100) # 输入1、输出的维度都是100def forward(self, x):out = self.linear(x)return outnet = LinearModel()
n_p = sum(x.numel() for x in net.parameters())  # number parameters
print(n_p)  ##  ------>输出为200 

本文链接:https://my.lmcjl.com/post/3157.html

展开阅读全文

4 评论

留下您的评论.