目录
- pytorch查看模型model参数parameters
- pytorch查看模型参数总结
- 1:DNN_printer
- 2:parameters
- 3:get_model_complexity_info()
- 4:twww.devze.comorchstat
pytorch查看模型model参数parameters
示例1:pytorch自带的faster r-cnn模型
import torch
import torchvision
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
for name, p in model.named_parameters():
print(name)
print(p.requires_grad)
print(...)
#或者
for p in model.parameters():
print(p)
print(...)
示例2:自定义网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512]
self.features = self._vgg_layers(cfg)
def _vgg_layers(self, cfg):
layers = []
in_channels = 3
for x in cfg:
if x == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
layers += [nn.Conv2d(in_channels, x ,kernel_size=3, padding=1),
nn.BATchNorm2d(x),
nn.ReLU(inplace=True)
]
in_channels = x
return nn.Sequential(*layers)
def forward(self, data):
out_map = self.features(data)
return out_map
Model = Net()
for name, p in model.named_parameters():
print(name)
编程客栈 print(p.requires_grad)
print(...)
#或者
for p in model.parameters():开发者_JS开发
print(p)
print(...)
在自定义网络中,model.parameters()方法继承自nn.Module
pytorch查看模型参数总结
1:DNN_printer
其中(3, 32, 32)是输入的大小,其他方法中的参数同理
from DNN_printer import DNN_printer
batch_size = 512
def train(epoch):
print('\nEpoch: %d' % epoch)
net.train()
train_loss = 0
correct = 0
total = 0
// put the code here and you can get the result
DNN_printer(net, (3, 32, 32),batch_size)
结果

2:parameters
def cnn_paras_count(net):
"""cnn参数量统计, 使用方式cnn_paras_count(net)"""
# Find total parameters and trainable parameters
total_params = sum(p.numel() for p in net.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_parajavascriptms = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')
return total_params, total_trainable_params
cnn_paras_count(net)
直接输出参数量,然后自己计算
需要注意的是,一般模型中参数是以float32保存的,也就是一个参数由4个bytes表示,那么就可以将参数量转化为存储大小。
例如:
- 44426个参数*4 / 1024 ≈ 174KB
3:get_model_complexity_info()
from ptflops import get_model_complexity_info from torchvision import models net = models.mobilenet_v2() ops, params = get_model_complexity_info(net, (3, 224, 224), as编程客栈_strings=True, print_per_layer_stat=True, verbose=True)

4:torchstat
from torchstat import stat import torchvision.models as models model = models.resnet152(android) stat(model, (3, 224, 224))
输出

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。
加载中,请稍侯......
精彩评论