如何统计网络的大小,可以试一试torch.numel()函数
torch.numel()函数,可以计算出单个tensor元素的个数
一、对单个tensor使用,求tensor元素的个数
x = torch.randn((1, 3, 5, 7))
x.numel()
torch.numel()
输出105
二、求整个网络的参数
n_p = sum(x.numel() for x in model.parameters()) # number parameters
如下示意图,可以计算网络的参数量
一个线性层,输入维度为1,输出维度为100
这个网络有200个参数,可以用x.numel() 巧妙计算出整个网络所需要的参数量
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__()self.linear = torch.nn.Linear(1, 100) # 输入1、输出的维度都是100def forward(self, x):out = self.linear(x)return outnet = LinearModel()
n_p = sum(x.numel() for x in net.parameters()) # number parameters
print(n_p) ## ------>输出为200
本文链接:https://my.lmcjl.com/post/3157.html
展开阅读全文
4 评论