如何解决浮点数组中的 k-diff 序列
寻找一种算法来查找最长序列(对、三元组、最多四元组),这些序列在浮点排序数组 k
中被常量非整数差 arr
分隔。有 O(n) 或更好的解决方案吗?
find_sequences(arr=[1.20,2.00,2.20,2.31,3.09,3.43,4.20,5.30],k=1.10,tol=0.01)
# with tolerance of 1% of k,or 0.011,first sequence includes 2.31 but not 3.43
# [[1.20,2.31],[2.00,5.30]]
find_sequences(arr=[1.20,3.00,3.10,tol=0.02)
# tolerance of 2% allows in 3.43
# [[1.20,3.43],5.30]]
# alternatively,return indices - as you can see they're overlapping:
# [[0,3,6],[1,5,7,8]]
通过带有 __eq__
的 np.isclose()
构造函数似乎很容易实现 Tolerance,对此不必太担心。主要想知道是否有一次性解决方案。
与 Leetcode 的 #532 (K-diff Pairs in an Array) 有很大的相似性 https://leetcode.com/problems/k-diff-pairs-in-an-array/
到目前为止,我想出了这个非常慢的熊猫解决方案。
def find_series(s,delta,btol,utol):
"""Finds delta-diff sequences in a float array.
Algorithm:
1) find all matching pairs (M0,M1)
2) recursively find longer sequences.
"""
# step 1: find all matching pairs
m01 = []
for idx,val in s.items():
lower,upper = val + delta - btol,val + delta + utol
is_match = s[idx:].between(lower,upper)
if sum(is_match) == 1:
m01.append([idx,is_match.idxmax()])
elif sum(is_match) > 1: # starting series and tolerances are picked to not allow this to happen
print(f'multiple matches for {idx}:{val}')
m01 = np.array(m01) # np.append / np.vstack are slower
res = pd.DataFrame(data={
'M0': s[m01[:,0]].values,'M1': s[m01[:,1]].values,})
# check if M1 values are found in M0 column
next_ = res['M0'].isin(res['M1'])
n_matches = sum(next_)
if n_matches == 0:
return
# step 2: recursion
next_map = res[next_].set_index('M0')['M1'].to_dict()
i = 2
while True:
next_col = res[f'M{i-1}'].map(next_map)
n_matches = next_col.notna().sum()
if n_matches > 0:
res[f'M{i}'] = next_col
i += 1
else:
break
return res[~next_].to_numpy()
find_series(a,1.1,0.02,0.02)
返回:
array([[1.2,nan],[2.,4.2,5.3 ]])
在更大的数据集上计时
| n | time(ms) |
|-----:|-----------:|
| 200 | 82 |
| 400 | 169 |
| 800 | 391 |
| 1600 | 917 |
| 3200 | 2500 |
解决方法
是的,这可以通过 O(nlog(n)) 中的扫描线技术来完成。假设从一个数字 x,如果 x + a
想法是这样的:为每个数字 x 创建类型 1、2 和 3 的事件。 x 的类型 1 事件发生在位置 x,表明我们应该根据当前可用的数字处理 x。 x 的类型 2 事件发生在位置 x + a,它表明我们现在应该将 x 包含在当前可用的数字集中。正如您所怀疑的,x 的类型 3 事件发生在位置 x + b,它表明我们应该从当前可用的数字中删除 x。
在处理 x 时,当前可用的数字都将小于 x。所以关键是当前可用的每个数字都可以从它自身到 x。当我们处理一个数字时,我们也会给出确定可以导致该数字的最大长度链。因此,对于 x 之前的每个数字,我们知道有多少个数字最适合它,这意味着对于当前可用集合中的所有内容,我们也知道答案。因此,我们在当前可用集合中的所有内容中取最大答案,向其中添加一个,并将其设置为 x 的答案。
日志因素来自我们必须对事件进行排序的事实。以下代码适用于您的示例。
int main() {
vector<double> arr{1.20,2.00,2.20,2.31,3.00,3.10,3.43,4.20,5.30};
int n = arr.size();
double tol = 0.01;
double a = 1.1 * (1 - tol),b = 1.1 * (1 + tol);
vector<pair<double,pair<double,int>>> events;
for (int i = 0; i < n; ++i) {
double x = arr[i];
events.push_back({x,{1,i}});
events.push_back({x + a,{2,i}});
events.push_back({x + b,{3,i}});
}
sort(events.begin(),events.end());
multiset<pair<int,int>> avail; // set of pairs of answer,and index,for each currently available element
vector<int> ans(n,0),prev(n,-1);
for (auto ev : events) {
int type = ev.second.first,idx = ev.second.second;
if (type == 1) { // process x
if (avail.size()) {
ans[idx] = 1 + avail.rbegin()->first; // largest currently available answer
prev[idx] = avail.rbegin()->second;
} else ans[idx] = 1;
} else if (type == 2) { // add in x
avail.insert({ans[idx],idx});
} else if (type == 3) { // remove x
avail.erase(avail.lower_bound({ans[idx],idx}));
}
}
int best = 0,pos = -1;
for (int i = 0; i < n; ++i)
if (ans[i] > best) {
best = max(ans[i],best);
pos = i;
}
vector<double> vals;
while (pos != -1) {
vals.push_back(arr[pos]);
pos = prev[pos];
}
sort(vals.begin(),vals.end());
for (auto x : vals) cout << x << ",";
cout << endl;
}
请注意,这只是找到满足约束的最长序列。当看到你提到的对、三胞胎和四胞胎时,我有点困惑,因为如果你想找到所有四胞胎,可能有 O(n^4) 个具有足够大的 tol 和 close arr 中有足够的值。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。