如何解决断言失败:lhs.cols() == rhs.rows() &&“无效矩阵乘积”
嘿,我正在尝试将 2 个矩阵相乘,我知道我的问题是什么,我正在尝试将 2x1 矩阵乘以 NxN 矩阵 ,但 NxN 与 2x1 的大小相同!那么我该如何解决呢?
这是我的代码:
MatrixXf dmap_return(MatrixXf& in)
{
MatrixXf out(in.rows(),in.cols());
float val = 0;
for (int i = 0; i < in.rows(); i++)
{
for (int j = 0; j < in.cols(); j++)
{
val = in(i,j);
out(i,j) = dsigmoid(val);
}
}
return out;
}
MatrixXf output_errors = targets - outputs;
std::cout << output_errors << std::endl << std::endl;;
MatrixXf gradients = dmap_return(outputs);
std::cout << gradients << std::endl << std::endl;;
gradients *= output_errors; // it crushes here
output_errors
是一个 2x1 矩阵,gradients
这是我运行程序时得到的错误截图:
我尽量让这个例子保持简单,但如果您需要额外的代码,请告诉我。
解决方法
数学:
3 个矩阵 A、B、C。
A:n x m(行 x 列)
B: m x p
C: n x p
换句话说:结果矩阵的行数与 A 相同,B 的列数也相同,此外,A 的 num 行必须等于 B 的 num 列。 >
C 的每个字段都是 A 的相关行和 B 的列的点积的结果。
代码:
以下代码假设矩阵实现了 operator ()(size_t row,size_t col)
。
template <typename A,typename B,typename C>
C& mult_matrix(const A& a,const B& b,C& c) noexcept
{
assert(a.cols() == b.rows());
assert(c.rows() == a.rows());
assert(c.cols() == b.cols());
size_t n = a.rows();
size_t m = a.cols();
size_t p = b.cols();
for (size_t i = 0; i < n; ++i) {
for (size_t j = 0; j < p; ++j) {
c(i,j) = 0;
for (size_t k = 0; k < m; ++k) {
c(i,j) += (a(i,k) * b(k,j));
}
}
}
return c;
}
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。