如何解决优化自定义 2d 卷积,它可以在 pytorch 中的卷积过程中量化 MAC 结果
我是一名学习低功耗存储器电路/内存计算的学生,我正在尝试将乘法累加 (MAC) 操作引入 SRAM 阵列,以便大部分/所有卷积都可以在内存阵列内完成。
由于 SRAM 内部的计算无法像 GPU 那样处理全精度操作,我试图将 MAC 结果的位精度降低到某个点,而不会遭受严重的精度损失。
所以我试图用pytorch平台验证它。
但是,conv2d 似乎不支持这样的功能,所以我正在努力从地面自定义构建卷积。
我正在使用定制的迷你 vgg:128C3-128C3-MP2-256C3-256C3-MP2-512C3-512C3-MP2-1024FC-1024FC-10FC
我仍在编写代码,但我遇到了“CUDA 内存不足”的错误。
我知道为什么我收到错误但无法想出解决方案。
有没有什么办法可以在不改变批量大小或网络的情况下优化代码?
另外,如果有人知道我可以参考的自定义 2d 卷积,我会非常感谢它。
感谢您阅读我的问题!
(请原谅我的写作和代码可能难以阅读。:'(我对 python/pytorch 完全陌生,只知道神经网络的基础知识。)
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
from torch.autograd import Function
import pdb
class Conv2DFunctionCustom(Function):
@staticmethod
def forward(ctx,input,weight,bias=None,stride=1,padding=0,dilation=1,groups=1):
ctx.save_for_backward(input,bias)
ctx.stride,ctx.padding,ctx.dilation,ctx.groups = stride,padding,dilation,groups
zeropad = nn.ZeroPad2d(ctx.padding[0])
batch_size = len(input)
input_channel = len(input[0])
input_size = len(input[0][0])
ochannel = len(weight)
ichannel = len(weight[0])
kernel_size = len(weight[0][0])
if padding:
inp_pad = zeropad(input)
else:
inp_pad = input
input_un_tensor = inp_pad.unfold(1,input_channel,input_channel)
input_reshape = input_un_tensor.transpose(1,4).reshape(len(inp_pad),len(inp_pad[0]),len(inp_pad[0][0]),len(inp_pad[0][0]))
weight_un_tensor = weight.unfold(1,ichannel,ichannel)
weight_reshape = weight_un_tensor.transpose(1,4).reshape(len(weight),len(weight[0]),len(weight[0][0]),len(weight[0][0]))
input_unfold = torch.nn.functional.unfold(input_reshape,(kernel_size,kernel_size))
print('unfoldinput:',input_unfold.size())
mult = input_unfold.transpose(1,2).unfold(2,16,16)[None,:] * weight_reshape.view(weight_reshape.size(0),-1).unfold(1,16)[:,None,None]
psum = torch.sum(mult,dim=4)
psum = torch.sum(psum,dim=3)
out = torch.nn.functional.fold(psum,(input_size,input_size),(1,1)).transpose(0,1)
return out
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。