
How to count the FLOPs and number of parameters of a 1D CNN on the PyTorch platform?


My network is a 1D CNN and I want to count its FLOPs and number of parameters. I used the public "flops_counter" script, but I am not sure about the input size. When I run it with size (128, 1, 50) I get the error "Expected 3-dimensional input for 3-dimensional weight [128, 1, 50], but got 4-dimensional input of size [1, 128, 1, 50] instead". When I run it with size (128, 50) I get the error "RuntimeError: Given groups=1, weight of size [128, 1, 50], expected input [1, 128, 50] to have 1 channels, but got 128 channels instead".
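For reference, nn.Conv1d expects a 3-D input of shape (batch, channels, length), and the flops_counter script below prepends a batch dimension of 1 to whatever input_res it is given, which is exactly what both error messages describe. A minimal sketch of the shape rules (the tensor sizes here are illustrative):

import torch
import torch.nn as nn

conv1 = nn.Conv1d(1,128,50,stride=3)   # in_channels=1, as in the model below
x = torch.randn(4,1,260)               # (batch, channels, length)
print(conv1(x).shape)                  # torch.Size([4, 128, 71])

# flops_counter adds the batch dimension itself, so input_res must be
# (channels, length) without a batch axis:
#   input_res=(128,1,50) -> tensor (1,128,1,50): 4-D input, first error
#   input_res=(128,50)   -> tensor (1,128,50): 128 channels, second error
#   input_res=(1,L)      -> tensor (1,1,L): what Conv1d(1,128,...) expects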

import torch
from models.cnn import net
from flops_counter import get_model_complexity_info

model = net()
# Flops&params
flops,params = get_model_complexity_info(model,(128,50),as_strings=True,print_per_layer_stat=True)
print('Flops:  ' + flops)
print('Params: ' + params)

Here is my 1D CNN.

from __future__ import print_function
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


# create cnn
class net(nn.Module):
    def __init__(self):
        super(net,self).__init__()
        self.conv1 = nn.Conv1d(1,128,50,stride=3)
        self.conv2 = nn.Conv1d(128,32,7,stride=1)
        self.conv3 = nn.Conv1d(32,32,9,stride=1)
        self.fc1 = nn.Linear(32,128)
        self.fc2 = nn.Linear(128,5)

        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(32)

        self.dropout = nn.Dropout2d(0.5)
        self.flatten = nn.Flatten()

    # forward propagation
    def forward(self,x):
        x = F.relu(self.conv1(x))
        x = self.bn1(x)
        x = F.max_pool1d(x,2,stride=3)

        x = self.dropout(F.relu(self.conv2(x)))
        # x = F.relu(self.conv2(x))
        x = self.bn2(x)
        x = F.max_pool1d(x,2,stride=2)

        x = self.dropout(F.relu(self.conv3(x)))

        x = self.flatten(x)
        x = self.dropout(self.fc1(x))
        output = self.fc2(x)

        return output

Here is the flops_counter code.

'''
copyright (C) 2019 Sovrasov V. - All Rights Reserved
 * You may use,distribute and modify this code under the
 * terms of the MIT license.
 * You should have received a copy of the MIT license with
 * this file. If not visit https://opensource.org/licenses/MIT
'''

import sys
from functools import partial

import torch
import torch.nn as nn
import numpy as np


def get_model_complexity_info(model,input_res,print_per_layer_stat=True,as_strings=True,input_constructor=None,ost=sys.stdout,verbose=False,ignore_modules=[],custom_modules_hooks={}):
    assert type(input_res) is tuple
    assert len(input_res) >= 1
    assert isinstance(model,nn.Module)
    global CUSTOM_MODULES_MAPPING
    CUSTOM_MODULES_MAPPING = custom_modules_hooks
    flops_model = add_flops_counting_methods(model)
    flops_model.eval()
    flops_model.start_flops_count(ost=ost,verbose=verbose,ignore_list=ignore_modules)
    if input_constructor:
        input = input_constructor(input_res)
        _ = flops_model(**input)
    else:
        try:
            batch = torch.ones(()).new_empty((1,*input_res),dtype=next(flops_model.parameters()).dtype,device=next(flops_model.parameters()).device)
        except StopIteration:
            batch = torch.ones(()).new_empty((1,*input_res))

        _ = flops_model(batch)

    flops_count,params_count = flops_model.compute_average_flops_cost()
    if print_per_layer_stat:
        print_model_with_flops(flops_model,flops_count,params_count,ost=ost)
    flops_model.stop_flops_count()
    CUSTOM_MODULES_MAPPING = {}

    if as_strings:
        return flops_to_string(flops_count),params_to_string(params_count)

    return flops_count,params_count


def flops_to_string(flops,units='GMac',precision=2):
    if units is None:
        if flops // 10**9 > 0:
            return str(round(flops / 10.**9,precision)) + ' GMac'
        elif flops // 10**6 > 0:
            return str(round(flops / 10.**6,precision)) + ' MMac'
        elif flops // 10**3 > 0:
            return str(round(flops / 10.**3,precision)) + ' KMac'
        else:
            return str(flops) + ' Mac'
    else:
        if units == 'GMac':
            return str(round(flops / 10.**9,precision)) + ' ' + units
        elif units == 'MMac':
            return str(round(flops / 10.**6,precision)) + ' ' + units
        elif units == 'KMac':
            return str(round(flops / 10.**3,precision)) + ' ' + units
        else:
            return str(flops) + ' Mac'


def params_to_string(params_num,units=None,precision=2):
    if units is None:
        if params_num // 10 ** 6 > 0:
            return str(round(params_num / 10 ** 6,2)) + ' M'
        elif params_num // 10 ** 3:
            return str(round(params_num / 10 ** 3,2)) + ' k'
        else:
            return str(params_num)
    else:
        if units == 'M':
            return str(round(params_num / 10.**6,precision)) + ' ' + units
        elif units == 'K':
            return str(round(params_num / 10.**3,precision)) + ' ' + units
        else:
            return str(params_num)


def print_model_with_flops(model,total_flops,total_params,units='GMac',precision=3,ost=sys.stdout):

    def accumulate_params(self):
        if is_supported_instance(self):
            return self.__params__
        else:
            sum = 0
            for m in self.children():
                sum += m.accumulate_params()
            return sum

    def accumulate_flops(self):
        if is_supported_instance(self):
            return self.__flops__ / model.__batch_counter__
        else:
            sum = 0
            for m in self.children():
                sum += m.accumulate_flops()
            return sum

    def flops_repr(self):
        accumulated_params_num = self.accumulate_params()
        accumulated_flops_cost = self.accumulate_flops()
        return ', '.join([params_to_string(accumulated_params_num,units='M',precision=precision),'{:.3%} Params'.format(accumulated_params_num / total_params),flops_to_string(accumulated_flops_cost,units=units,precision=precision),'{:.3%} MACs'.format(accumulated_flops_cost / total_flops),self.original_extra_repr()])

    def add_extra_repr(m):
        m.accumulate_flops = accumulate_flops.__get__(m)
        m.accumulate_params = accumulate_params.__get__(m)
        flops_extra_repr = flops_repr.__get__(m)
        if m.extra_repr != flops_extra_repr:
            m.original_extra_repr = m.extra_repr
            m.extra_repr = flops_extra_repr
            assert m.extra_repr != m.original_extra_repr

    def del_extra_repr(m):
        if hasattr(m,'original_extra_repr'):
            m.extra_repr = m.original_extra_repr
            del m.original_extra_repr
        if hasattr(m,'accumulate_flops'):
            del m.accumulate_flops

    model.apply(add_extra_repr)
    print(model,file=ost)
    model.apply(del_extra_repr)


def get_model_parameters_number(model):
    params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return params_num


def add_flops_counting_methods(net_main_module):
    # adding additional methods to the existing module object,
    # this is done this way so that each function has access to self object
    net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
    net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
    net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
    net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module)

    net_main_module.reset_flops_count()

    return net_main_module


def compute_average_flops_cost(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.
    Returns current mean flops consumption per image.
    """

    batches_count = self.__batch_counter__
    flops_sum = 0
    params_sum = 0
    for module in self.modules():
        if is_supported_instance(module):
            flops_sum += module.__flops__
    params_sum = get_model_parameters_number(self)
    return flops_sum / batches_count,params_sum


def start_flops_count(self,**kwargs):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.
    Activates the computation of mean flops consumption per image.
    Call it before you run the network.
    """
    add_batch_counter_hook_function(self)

    seen_types = set()
    def add_flops_counter_hook_function(module,ost,verbose,ignore_list):
        if type(module) in ignore_list:
            seen_types.add(type(module))
            if is_supported_instance(module):
                module.__params__ = 0
        elif is_supported_instance(module):
            if hasattr(module,'__flops_handle__'):
                return
            if type(module) in CUSTOM_MODULES_MAPPING:
                handle = module.register_forward_hook(CUSTOM_MODULES_MAPPING[type(module)])
            else:
                handle = module.register_forward_hook(MODULES_MAPPING[type(module)])
            module.__flops_handle__ = handle
            seen_types.add(type(module))
        else:
            if verbose and not type(module) in (nn.Sequential,nn.ModuleList) and not type(module) in seen_types:
                print('Warning: module ' + type(module).__name__ + ' is treated as a zero-op.',file=ost)
            seen_types.add(type(module))

    self.apply(partial(add_flops_counter_hook_function,**kwargs))


def stop_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.
    Stops computing the mean flops consumption per image.
    Call whenever you want to pause the computation.
    """
    remove_batch_counter_hook_function(self)
    self.apply(remove_flops_counter_hook_function)


def reset_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.
    Resets statistics computed so far.
    """
    add_batch_counter_variables_or_reset(self)
    self.apply(add_flops_counter_variable_or_reset)


# ---- Internal functions
def empty_flops_counter_hook(module,input,output):
    module.__flops__ += 0


def upsample_flops_counter_hook(module,input,output):
    output_size = output[0]
    batch_size = output_size.shape[0]
    output_elements_count = batch_size
    for val in output_size.shape[1:]:
        output_elements_count *= val
    module.__flops__ += int(output_elements_count)


def relu_flops_counter_hook(module,input,output):
    active_elements_count = output.numel()
    module.__flops__ += int(active_elements_count)


def linear_flops_counter_hook(module,input,output):
    input = input[0]
    output_last_dim = output.shape[-1]  # pytorch checks dimensions,so here we don't care much
    module.__flops__ += int(np.prod(input.shape) * output_last_dim)


def pool_flops_counter_hook(module,input,output):
    input = input[0]
    module.__flops__ += int(np.prod(input.shape))


def bn_flops_counter_hook(module,input,output):
    input = input[0]

    batch_flops = np.prod(input.shape)
    if module.affine:
        batch_flops *= 2
    module.__flops__ += int(batch_flops)


def deconv_flops_counter_hook(conv_module,input,output):
    # Can have multiple inputs,getting the first one
    input = input[0]

    batch_size = input.shape[0]
    input_height,input_width = input.shape[2:]

    kernel_height,kernel_width = conv_module.kernel_size
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups

    filters_per_channel = out_channels // groups
    conv_per_position_flops = kernel_height * kernel_width * in_channels * filters_per_channel

    active_elements_count = batch_size * input_height * input_width
    overall_conv_flops = conv_per_position_flops * active_elements_count
    bias_flops = 0
    if conv_module.bias is not None:
        output_height,output_width = output.shape[2:]
        bias_flops = out_channels * batch_size * output_height * output_width
    overall_flops = overall_conv_flops + bias_flops

    conv_module.__flops__ += int(overall_flops)


def conv_flops_counter_hook(conv_module,input,output):
    # Can have multiple inputs,getting the first one
    input = input[0]

    batch_size = input.shape[0]
    output_dims = list(output.shape[2:])

    kernel_dims = list(conv_module.kernel_size)
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups

    filters_per_channel = out_channels // groups
    conv_per_position_flops = int(np.prod(kernel_dims)) * in_channels * filters_per_channel

    active_elements_count = batch_size * int(np.prod(output_dims))

    overall_conv_flops = conv_per_position_flops * active_elements_count

    bias_flops = 0

    if conv_module.bias is not None:

        bias_flops = out_channels * active_elements_count

    overall_flops = overall_conv_flops + bias_flops

    conv_module.__flops__ += int(overall_flops)


def batch_counter_hook(module,input,output):
    batch_size = 1
    if len(input) > 0:
        # Can have multiple inputs,getting the first one
        input = input[0]
        batch_size = len(input)
    else:
        print('Warning! No positional inputs found for a module, assuming batch size is 1.')
    module.__batch_counter__ += batch_size


def rnn_flops(flops,rnn_module,w_ih,w_hh,input_size):
    # matrix matrix mult ih state and internal state
    flops += w_ih.shape[0]*w_ih.shape[1]
    # matrix matrix mult hh state and internal state
    flops += w_hh.shape[0]*w_hh.shape[1]
    if isinstance(rnn_module,(nn.RNN,nn.RNNCell)):
        # add both operations
        flops += rnn_module.hidden_size
    elif isinstance(rnn_module,(nn.GRU,nn.GRUCell)):
        # hadamard of r
        flops += rnn_module.hidden_size
        # adding operations from both states
        flops += rnn_module.hidden_size*3
        # last two hadamard product and add
        flops += rnn_module.hidden_size*3
    elif isinstance(rnn_module,(nn.LSTM,nn.LSTMCell)):
        # adding operations from both states
        flops += rnn_module.hidden_size*4
        # two hadamard product and add for C state
        flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
        # final hadamard
        flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
    return flops


def rnn_flops_counter_hook(rnn_module,input,output):
    """
    Takes into account batch goes at first position,contrary
    to pytorch common rule (but actually it doesn't matter).
    IF sigmoid and tanh are made hard,only a comparison FLOPS should be accurate
    """
    flops = 0
    inp = input[0]  # input is a tuple containing a sequence to process and (optionally) hidden state
    batch_size = inp.shape[0]
    seq_length = inp.shape[1]
    num_layers = rnn_module.num_layers

    for i in range(num_layers):
        w_ih = rnn_module.__getattr__('weight_ih_l' + str(i))
        w_hh = rnn_module.__getattr__('weight_hh_l' + str(i))
        if i == 0:
            input_size = rnn_module.input_size
        else:
            input_size = rnn_module.hidden_size
        flops = rnn_flops(flops,rnn_module,w_ih,w_hh,input_size)
        if rnn_module.bias:
            b_ih = rnn_module.__getattr__('bias_ih_l' + str(i))
            b_hh = rnn_module.__getattr__('bias_hh_l' + str(i))
            flops += b_ih.shape[0] + b_hh.shape[0]

    flops *= batch_size
    flops *= seq_length
    if rnn_module.bidirectional:
        flops *= 2
    rnn_module.__flops__ += int(flops)


def rnn_cell_flops_counter_hook(rnn_cell_module,input,output):
    flops = 0
    inp = input[0]
    batch_size = inp.shape[0]
    w_ih = rnn_cell_module.__getattr__('weight_ih')
    w_hh = rnn_cell_module.__getattr__('weight_hh')
    input_size = inp.shape[1]
    flops = rnn_flops(flops,rnn_cell_module,w_ih,w_hh,input_size)
    if rnn_cell_module.bias:
        b_ih = rnn_cell_module.__getattr__('bias_ih')
        b_hh = rnn_cell_module.__getattr__('bias_hh')
        flops += b_ih.shape[0] + b_hh.shape[0]

    flops *= batch_size
    rnn_cell_module.__flops__ += int(flops)


def add_batch_counter_variables_or_reset(module):

    module.__batch_counter__ = 0


def add_batch_counter_hook_function(module):
    if hasattr(module,'__batch_counter_handle__'):
        return

    handle = module.register_forward_hook(batch_counter_hook)
    module.__batch_counter_handle__ = handle


def remove_batch_counter_hook_function(module):
    if hasattr(module,'__batch_counter_handle__'):
        module.__batch_counter_handle__.remove()
        del module.__batch_counter_handle__


def add_flops_counter_variable_or_reset(module):
    if is_supported_instance(module):
        if hasattr(module,'__flops__') or hasattr(module,'__params__'):
        print('Warning: variables __flops__ or __params__ are already '
                'defined for the module ' + type(module).__name__ +
                ' ptflops can affect your code!')
        module.__flops__ = 0
        module.__params__ = get_model_parameters_number(module)

CUSTOM_MODULES_MAPPING = {}

MODULES_MAPPING = {
    # convolutions
    nn.Conv1d: conv_flops_counter_hook,
    nn.Conv2d: conv_flops_counter_hook,
    nn.Conv3d: conv_flops_counter_hook,
    # activations
    nn.ReLU: relu_flops_counter_hook,
    nn.PReLU: relu_flops_counter_hook,
    nn.ELU: relu_flops_counter_hook,
    nn.LeakyReLU: relu_flops_counter_hook,
    nn.ReLU6: relu_flops_counter_hook,
    # poolings
    nn.MaxPool1d: pool_flops_counter_hook,
    nn.AvgPool1d: pool_flops_counter_hook,
    nn.AvgPool2d: pool_flops_counter_hook,
    nn.MaxPool2d: pool_flops_counter_hook,
    nn.MaxPool3d: pool_flops_counter_hook,
    nn.AvgPool3d: pool_flops_counter_hook,
    nn.AdaptiveMaxPool1d: pool_flops_counter_hook,
    nn.AdaptiveAvgPool1d: pool_flops_counter_hook,
    nn.AdaptiveMaxPool2d: pool_flops_counter_hook,
    nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
    nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
    nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
    # BNs
    nn.BatchNorm1d: bn_flops_counter_hook,
    nn.BatchNorm2d: bn_flops_counter_hook,
    nn.BatchNorm3d: bn_flops_counter_hook,
    # FC
    nn.Linear: linear_flops_counter_hook,
    # Upscale
    nn.Upsample: upsample_flops_counter_hook,
    # Deconvolution
    nn.ConvTranspose2d: deconv_flops_counter_hook,
    # RNN
    nn.RNN: rnn_flops_counter_hook,
    nn.GRU: rnn_flops_counter_hook,
    nn.LSTM: rnn_flops_counter_hook,
    nn.RNNCell: rnn_cell_flops_counter_hook,
    nn.LSTMCell: rnn_cell_flops_counter_hook,
    nn.GRUCell: rnn_cell_flops_counter_hook
}


def is_supported_instance(module):
    if type(module) in MODULES_MAPPING or type(module) in CUSTOM_MODULES_MAPPING:
        return True
    return False


def remove_flops_counter_hook_function(module):
    if is_supported_instance(module):
        if hasattr(module,'__flops_handle__'):
            module.__flops_handle__.remove()
            del module.__flops_handle__

Solution

Below is working code using the ptflops package. You need to pay attention to the length of the input sequence. The PyTorch documentation for Conv1d gives the output length as

L_out = floor((L_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)

This lets you work backwards from the first fully connected layer to the input size the network requires (see the comments in the model definition).
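To make the back-tracking concrete, here is a small helper (my own sketch, not part of ptflops) that applies the L_out formula layer by layer; starting from an input length of 260 it reproduces the lengths noted in the comments below, ending at 1, so the flattened feature size is 32 * 1 = 32, which is exactly fc1's in_features:

def l_out(l_in,kernel_size,stride=1,padding=0,dilation=1):
    # L_out formula from the PyTorch Conv1d/MaxPool1d documentation
    return (l_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

l = l_out(260,50,stride=3)  # conv1 -> 71
l = l_out(l,2,stride=3)     # max_pool1d -> 24
l = l_out(l,7,stride=1)     # conv2 -> 18
l = l_out(l,2,stride=2)     # max_pool1d -> 9
l = l_out(l,9,stride=1)     # conv3 -> 1
print(32 * l)               # 32 features going into fc1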

from ptflops import get_model_complexity_info
import torch.nn as nn
import torch.nn.functional as F


class net(nn.Module):
    def __init__(self):
        super(net,self).__init__()
        self.conv1 = nn.Conv1d(1,128,50,stride=3)  # Lin = 260
        # max_pool1d(x,2,stride=3)  # Lin = 71
        self.conv2 = nn.Conv1d(128,32,7,stride=1)  # Lin = 24
        # max_pool1d(x,2,stride=2)  # Lin = 18
        self.conv3 = nn.Conv1d(32,32,9,stride=1)  # Lin = 9
        self.fc1 = nn.Linear(32,128)
        self.fc2 = nn.Linear(128,5)

        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(32)

        self.dropout = nn.Dropout2d(0.5)
        self.flatten = nn.Flatten()

    # forward propagation
    def forward(self,x):
        x = F.relu(self.conv1(x))
        x = self.bn1(x)
        x = F.max_pool1d(x,2,stride=3)

        x = self.dropout(F.relu(self.conv2(x)))
        # x = F.relu(self.conv2(x))
        x = self.bn2(x)
        x = F.max_pool1d(x,2,stride=2)

        x = self.dropout(F.relu(self.conv3(x)))

        x = self.flatten(x)
        x = self.dropout(self.fc1(x))
        output = self.fc2(x)

        return output


macs,params = get_model_complexity_info(net(),(1,260),as_strings=False,print_per_layer_stat=True,verbose=True)
print('{:<30}  {:<8}'.format('Computational complexity: ',macs))
print('{:<30}  {:<8}'.format('Number of parameters: ',params))

Output:

net(
  0.05 M, 100.000% Params, 0.001 GMac, 100.000% MACs,
  (conv1): Conv1d(0.007 M, 13.143% Params, 0.0 GMac, 45.733% MACs, 1, 128, kernel_size=(50,), stride=(3,))
  (conv2): Conv1d(0.029 M, 57.791% Params, 0.001 GMac, 50.980% MACs, 128, 32, kernel_size=(7,), stride=(1,))
  (conv3): Conv1d(0.009 M, 18.619% Params, 0.0 GMac, 0.913% MACs, 32, 32, kernel_size=(9,), stride=(1,))
  (fc1): Linear(0.004 M, 8.504% Params, 0.0 GMac, 0.404% MACs, in_features=32, out_features=128, bias=True)
  (fc2): Linear(0.001 M, 1.299% Params, 0.0 GMac, 0.063% MACs, in_features=128, out_features=5, bias=True)
  (bn1): BatchNorm1d(0.0 M, 0.515% Params, 0.0 GMac, 1.793% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm1d(0.0 M, 0.129% Params, 0.0 GMac, 0.114% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout2d(0.0 M, 0.000% Params, 0.0 GMac, 0.000% MACs, p=0.5, inplace=False)
  (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GMac, 0.000% MACs, )
)
Computational complexity:       1013472.0
Number of parameters:           49669  
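Two remarks on these numbers (my additions, not from the original answer): ptflops counts multiply-accumulate operations, hence the GMac units, so the 1013472 MACs correspond to roughly twice as many FLOPs under the common 2-FLOPs-per-MAC convention. The parameter count can also be verified by hand: conv1 = 1*128*50 + 128 = 6528, conv2 = 128*32*7 + 32 = 28704, conv3 = 32*32*9 + 32 = 9248, fc1 = 32*128 + 128 = 4224, fc2 = 128*5 + 5 = 645, bn1 = 2*128 = 256, bn2 = 2*32 = 64, which sums to 49669 and matches what PyTorch itself reports:

# cross-check the parameter count directly in PyTorch
model = net()
print(sum(p.numel() for p in model.parameters() if p.requires_grad))  # 49669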

