如何提高大整数乘法的效率?
这个周末我参照 Wikipedia 实现了基本的大整数乘法,使用的是 Toom-3 算法。但结果它比长乘法(即小学竖式乘法)还要慢,而且在我测试的范围内始终没有反超。我希望它在 500 位以内就能快过小学乘法,请问该怎么做?
我尝试过优化:预先 reserve 了 vector 的容量,并删除了多余的代码,但效果不明显。
我应该使用 `vector<long long>` 作为存放各个基数位的底层类型吗?
// Type of one limb (one base-`digit_base` "digit" of a big number). It is
// signed so that subtraction can produce a negative intermediate limb before
// the borrow is applied.
using BigIntBase = long long;
// Little-endian sequence of limbs: element 0 is the least significant limb.
using BigIntDigits = std::vector<BigIntBase>;

// Number of decimal digits stored per limb. 9 is the largest power of ten
// whose square still fits in a 64-bit long long, so a limb-by-limb product
// plus two limb-sized addends (the schoolbook inner step) cannot overflow.
static constexpr int digit_base_len = 9;
// b: the radix of one limb, i.e. 10^digit_base_len.
static constexpr BigIntBase digit_base = 1000000000;

static_assert(
    [] {
        BigIntBase power = 1;
        for (int k = 0; k < digit_base_len; ++k)
            power *= 10;
        return power;
    }() == digit_base,
    "digit_base must equal 10^digit_base_len");
// 3037000499 is floor(sqrt(LLONG_MAX)) for a 64-bit long long; this keeps
// (digit_base - 1) * (digit_base + 1) representable.
static_assert(digit_base <= 3037000499LL, "digit_base^2 must fit in BigIntBase");
class BigInt {
public:
BigInt(int digits_capacity = 0,bool nega = false) {
negative = nega;
digits.reserve(digits_capacity);
}
BigInt(BigIntDigits _digits,bool nega = false) {
negative = nega;
digits = _digits;
}
BigInt(const span<const BigIntBase> &range,bool nega = false) {
negative = nega;
digits = BigIntDigits(range.begin(),range.end());
}
BigInt operator+(const BigInt &rhs) {
if ((*this).negative == rhs.negative)
return BigInt(plus((*this).digits,rhs.digits),(*this).negative);
if (greater((*this).digits,rhs.digits))
return BigInt(minus((*this).digits,(*this).negative);
return BigInt(minus(rhs.digits,(*this).digits),rhs.negative);
}
BigInt operator-(const BigInt &rhs) { return *this + BigInt(rhs.digits,!rhs.negative); }
BigInt operator*(const BigInt &rhs) {
if ((*this).digits.empty() || rhs.digits.empty()) {
return BigInt();
} else if ((*this).digits.size() == 1 && rhs.digits.size() == 1) {
BigIntBase val = (*this).digits[0] * rhs.digits[0];
return BigInt(val < digit_base ? BigIntDigits{val} : BigIntDigits{val % digit_base,val / digit_base},(*this).negative ^ rhs.negative);
} else if ((*this).digits.size() == 1)
return BigInt(multiply(rhs,(*this).digits[0]).digits,(*this).negative ^ rhs.negative);
else if (rhs.digits.size() == 1)
return BigInt(multiply((*this),rhs.digits[0]).digits,(*this).negative ^ rhs.negative);
return BigInt(toom3(span((*this).digits),span(rhs.digits)),(*this).negative ^ rhs.negative);
}
string to_string() {
if (this->digits.empty())
return "0";
stringstream ss;
if (this->negative)
ss << "-";
ss << std::to_string(this->digits.back());
for (auto it = this->digits.rbegin() + 1; it != this->digits.rend(); ++it)
ss << setw(digit_base_len) << setfill('0') << std::to_string(*it);
return ss.str();
}
BigInt from_string(string s) {
digits.clear();
negative = s[0] == '-';
for (int pos = max(0,(int)s.size() - digit_base_len); pos >= 0; pos -= digit_base_len)
digits.push_back(stoll(s.substr(pos,digit_base_len)));
if (s.size() % digit_base_len)
digits.push_back(stoll(s.substr(0,s.size() % digit_base_len)));
return *this;
}
private:
bool negative;
BigIntDigits digits;
const span<const BigIntBase> toom3_slice_num(const span<const BigIntBase> &num,const int &n,const int &i) {
int begin = n * i;
if (begin < num.size()) {
const span<const BigIntBase> result = num.subspan(begin,min((int)num.size() - begin,i));
return result;
}
return span<const BigIntBase>();
}
BigIntDigits toom3(const span<const BigIntBase> &num1,const span<const BigIntBase> &num2) {
int i = ceil(max(num1.size() / 3.0,num2.size() / 3.0));
const span<const BigIntBase> m0 = toom3_slice_num(num1,i);
const span<const BigIntBase> m1 = toom3_slice_num(num1,1,i);
const span<const BigIntBase> m2 = toom3_slice_num(num1,2,i);
const span<const BigIntBase> n0 = toom3_slice_num(num2,i);
const span<const BigIntBase> n1 = toom3_slice_num(num2,i);
const span<const BigIntBase> n2 = toom3_slice_num(num2,i);
BigInt pt0 = plus(m0,m2);
BigInt pp0 = m0;
BigInt pp1 = plus(pt0.digits,m1);
BigInt pn1 = pt0 - m1;
BigInt pn2 = multiply(pn1 + m2,2) - m0;
BigInt pin = m2;
BigInt qt0 = plus(n0,n2);
BigInt qp0 = n0;
BigInt qp1 = plus(qt0.digits,n1);
BigInt qn1 = qt0 - n1;
BigInt qn2 = multiply(qn1 + n2,2) - n0;
BigInt qin = n2;
BigInt rp0 = pp0 * qp0;
BigInt rp1 = pp1 * qp1;
BigInt rn1 = pn1 * qn1;
BigInt rn2 = pn2 * qn2;
BigInt rin = pin * qin;
BigInt r0 = rp0;
BigInt r4 = rin;
BigInt r3 = divide(rn2 - rp1,3);
BigInt r1 = divide(rp1 - rn1,2);
BigInt r2 = rn1 - rp0;
r3 = divide(r2 - r3,2) + multiply(rin,2);
r2 = r2 + r1 - r4;
r1 = r1 - r3;
BigIntDigits result = r0.digits;
if (!r1.digits.empty()) {
shift_left(r1.digits,i);
result = plus(result,r1.digits);
}
if (!r2.digits.empty()) {
shift_left(r2.digits,i << 1);
result = plus(result,r2.digits);
}
if (!r3.digits.empty()) {
shift_left(r3.digits,i * 3);
result = plus(result,r3.digits);
}
if (!r4.digits.empty()) {
shift_left(r4.digits,i << 2);
result = plus(result,r4.digits);
}
return result;
}
BigIntDigits plus(const span<const BigIntBase> &lhs,const span<const BigIntBase> &rhs) {
if (lhs.empty())
return BigIntDigits(rhs.begin(),rhs.end());
if (rhs.empty())
return BigIntDigits(lhs.begin(),lhs.end());
int max_length = max(lhs.size(),rhs.size());
BigIntDigits result;
result.reserve(max_length + 1);
for (int w = 0; w < max_length; ++w)
result.push_back((lhs.size() > w ? lhs[w] : 0) + (rhs.size() > w ? rhs[w] : 0));
for (int w = 0; w < result.size() - 1; ++w) {
result[w + 1] += result[w] / digit_base;
result[w] %= digit_base;
}
if (result.back() >= digit_base) {
result.push_back(result.back() / digit_base);
result[result.size() - 2] %= digit_base;
}
return result;
}
BigIntDigits minus(const span<const BigIntBase> &lhs,lhs.end());
BigIntDigits result;
result.reserve(lhs.size() + 1);
for (int w = 0; w < lhs.size(); ++w)
result.push_back((lhs.size() > w ? lhs[w] : 0) - (rhs.size() > w ? rhs[w] : 0));
for (int w = 0; w < result.size() - 1; ++w)
if (result[w] < 0) {
result[w + 1] -= 1;
result[w] += digit_base;
}
while (!result.empty() && !result.back())
result.pop_back();
return result;
}
void shift_left(BigIntDigits &lhs,const int n) {
if (!lhs.empty()) {
BigIntDigits zeros(n,0);
lhs.insert(lhs.begin(),zeros.begin(),zeros.end());
}
}
BigInt divide(const BigInt &lhs,const int divisor) {
BigIntDigits reminder(lhs.digits);
BigInt result(lhs.digits.capacity(),lhs.negative);
for (int w = reminder.size() - 1; w >= 0; --w) {
result.digits.insert(result.digits.begin(),reminder[w] / divisor);
reminder[w - 1] += (reminder[w] % divisor) * digit_base;
}
while (!result.digits.empty() && !result.digits.back())
result.digits.pop_back();
return result;
}
BigInt multiply(const BigInt &lhs,const int multiplier) {
BigInt result(lhs.digits,lhs.negative);
for (int w = 0; w < result.digits.size(); ++w)
result.digits[w] *= multiplier;
for (int w = 0; w < result.digits.size(); ++w)
if (result.digits[w] >= digit_base) {
if (w + 1 == result.digits.size())
result.digits.push_back(result.digits[w] / digit_base);
else
result.digits[w + 1] += result.digits[w] / digit_base;
result.digits[w] %= digit_base;
}
return result;
}
bool greater(const BigIntDigits &lhs,const BigIntDigits &rhs) {
if (lhs.size() == rhs.size()) {
int w = lhs.size() - 1;
while (w >= 0 && lhs[w] == rhs[w])
--w;
return w >= 0 && lhs[w] > rhs[w];
} else
return lhs.size() > rhs.size();
}
};
位数 | 小学乘法 | Toom-3 |
---|---|---|
10 | 4588 | 10003 |
50 | 24147 | 109084 |
100 | 52165 | 286535 |
150 | 92405 | 476275 |
200 | 172156 | 1076570 |
250 | 219599 | 1135946 |
300 | 320939 | 1530747 |
350 | 415655 | 1689745 |
400 | 498172 | 1937327 |
450 | 614467 | 2629886 |
500 | 863116 | 3184277 |
解决方法
问题在于你在 `toom3_slice_num` 中进行了海量(数以百万计)的内存分配。由于传入的数字是常量,这里可以用 `std::span`(或者用一对指向实际区间的迭代器,例如 `std::pair`)来引用切片,而不必复制数据。`toom3` 本身同样到处都是内存分配,堪称“分配地狱”。
`multiply` 可能会额外触发一次重新分配。可以先算出结果所需的位数并预留容量,或者直接把预留的容量设为原大小加 1。
若想让内存分配几乎没有开销,可以把 `vector` 换成 `std::pmr::vector`,并搭配合适的内存资源(分配器)。
另外,如果不使用 `-O2` 或 `-O3` 编译,上述所有优化都将白费。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。