为啥这玩意没在 OI 中普及啊?
Description
有 n 种物品,第 i 种物品的重量为 wi,价值为 vi,一共有 ci 个,求在总重量不超过 m 时的价值之和最大是多少。
形式化地,求:
maxi=1∑nxivis.t.{∀1≤i≤n,xi∈Z∧0≤xi≤ci∑i=1nxiwi≤m
其中 1≤wi≤W
我们有一个 O(nlogn+W3logW) 或 O(n+W3) 的做法。
Solution
考虑将所有物品按 vi/wi(“价值密度”)降序排序,然后贪心从前往后取,这样可以得到一个较优的解。
排除掉所有物品都能放进背包的平凡情况,在贪心从前往后取时,若某种物品不能全部放入背包则立刻退出,设这个解为 p,则其存在如下两个性质:
- 剩余空间 t−∑i=1npiwi 一定在 [0,W) 范围内。
- 存在一个分界点 i0,使得 ∀1≤i<i0,pi=ci 且 ∀i0<i≤n,pi=0
结论:存在一组最优解 o,使得:
i=1∑n∣oi−pi∣≤2W
证明:
考虑一组最优解 o 使得 ∑i=1n∣oi−pi∣ 最小,从 o 出发,每次在当前解 x 中添加一个物品或删除一个物品,逐步调整到 p。要求过程中剩余空间 m−∑i=1nxiwi 在 (−W,W] 中,并且调整要在 ∑i=1n∣oi−pi∣ 步内完成,即不存在冗余的操作。
一种调整方法是:若当前剩余空间 >0,则选择一个要加入的物品加进去,如果不存在这样的物品,则将所有要删除的物品删去,容易证明过程中剩余空间不会 >w。对于剩余空间 <0 的情况同理。
考虑反证,若 ∑i=1n∣oi−pi∣>2W,则调整过程中一定存在两个解 x,y 使得 ∑i=1nxiwi=∑i=1nyiwi,因此有 ∑i=1noiwi=∑i=1n(oi−xi+yi)wi。
考虑 o→x→y 的过程,由于没有冗余操作,有 ∀i,min{oi,yi}≤xi≤max{oi,yi},因此 ∀i,0≤oi−xi+yi≤ci,这说明 o−x+y 也是一组合法解。
考虑将 x→y 的过程视为 o→o−x+y 的调整过程,同样由于调整没有冗余过程,所有被加入的物品的价值密度都至少为 vi0/wi0,所有被删除的物品的价值密度都至多为 vi0/wi0,因此 o−x+y 是一个不劣于 o 的解,且距离 p 更近(∑i=1n∣(oi−xi+yi)−pi∣<∑i=1n∣oi−pi∣),与假设矛盾。
□
除了这个结论之外我们还需要一个前置算法:对于给定 L,我们可以在 O(nlogn+LWlogW) 的时间内对于每个 m∈[0,L] 求出背包问题的答案。
做法是对于每个重量 i,只需要保留价值前 iL 大的物品,然后对于每个重量就变成了任意函数和一个图函数的 (max,+) 卷积,可以使用决策单调性解决。
回到原问题,根据结论我们可以得到最优解一定可以被表示成 p−o−+o+ 的形式,其中 ∑i=1noi−wi≤2W2∧∀i,oi−≤pi,o+ 同理。
因此,我们可以令 L=2W2,ci′=pi,vi′=−vi 跑上面的算法,再令 L=2W2,ci′=ci−pi,vi′=vi 跑一遍,然后枚举 o− 和 o+ 的重量进行求解即可。
考虑优化到 O(n+W3)。首先 O(W3logW) 部分的 log 可以通过 SMAWK 去掉。
考虑 O(nlogn) 部分,主要是按价值密度从大到小排序。
我们实际需求其实只是找到一个排序后分界点 i0,因此可以考虑通过 nth_element
先找出第 n/2 大的价值密度,若太大则向右侧递归,否则向左侧递归,时间复杂度是 T(n)=T(n/2)+O(n)⟹T(n)=O(n)。
另外在「前置算法」中还有对物品按价值排序,但这部分的时间复杂度为 ∑i=1wiLlogiL=O(Llog2L) 的,当 L=2W2 时为 O(W2logW)。
Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
|
#include <vector> #include <cstdio> #include <string> #include <cstring> #include <iostream> #include <algorithm> #define eputchar(c) putc(c, stderr) #define eprintf(...) fprintf(stderr, __VA_ARGS__) #define eputs(str) fputs(str, stderr), putc('\n', stderr) using namespace std;
int n, W, p[3000005]; long long m, weight, value; struct node { int w, v, c;
node() = default; node(int _w, int _v, int _c): w(_w), v(_v), c(_c) {} } a[3000005]; long long f[180005], g[180005], h[180005]; long long ans, dp1[180005], dp2[180005];
void calc(int L, int R, int l, int r) { if (L > R) return; int mid = (L + R) / 2, k = l; for (int i = l; i <= min(r, mid - 1); i++) if (max<long long>(g[i] + f[mid - i], ~0x3f3f3f3f3f3f3f3f) >= max<long long>(g[k] + f[mid - k], ~0x3f3f3f3f3f3f3f3f)) k = i; h[mid] = g[k] + f[mid - k]; calc(L, mid - 1, l, k), calc(mid + 1, R, k, r); return; } void solve(vector<node> vec, long long *dp, int m) { sort(vec.begin(), vec.end(), [](const node &x, const node &y) { return x.w != y.w ? x.w < y.w : x.v > y.v; }); memset(dp, ~0x3f, sizeof(long long) * (m + 5)); dp[0] = 0; for (int l = 0, r; l < (int)vec.size(); l = r + 1) { r = l; while (r + 1 < (int)vec.size() && vec[r + 1].w == vec[l].w) r++; int e = vec[l].w; int idx = 0; for (int i = l; i <= r && idx < m / e; i++) for (int j = 1; j <= vec[i].c && idx < m / e; j++) idx++, g[idx] = g[idx - 1] + vec[i].v; if (!idx) continue; for (int i = 0; i < e; i++) { int cnt = 0; f[0] = ~0x3f3f3f3f3f3f3f3f; for (int j = i; j <= m; j += e) f[++cnt] = dp[j]; calc(1, cnt, 1, idx); cnt = 0; for (int j = i; j <= m; j += e) dp[j] = max(dp[j], h[++cnt]); } } return; }
int main() { scanf("%d%d%lld", &n, &W, &m); for (int i = 1; i <= n; i++) scanf("%d%d%d", &a[i].w, &a[i].v, &a[i].c); sort(a + 1, a + n + 1, [](const node &x, const node &y) { return (long long)x.v * y.w > (long long)y.v * x.w; }); for (int i = 1; i <= n; i++) if ((long long)a[i].w * a[i].c <= m - weight) p[i] = a[i].c, weight += (long long)a[i].w * a[i].c, value += (long long)a[i].v * a[i].c; else { int cnt = (m - weight) / a[i].w; p[i] = cnt, weight += (long long)a[i].w * cnt, value += (long long)a[i].v * cnt; break; } if (p[n] == a[n].c) {printf("%lld\n", value); return 0;} { vector<node> vec; for (int i = 1; i <= n; i++) vec.emplace_back(a[i].w, -a[i].v, p[i]); solve(vec, dp1, 2 * W * W); } { vector<node> vec; for (int i = 1; i <= n; i++) vec.emplace_back(a[i].w, a[i].v, a[i].c - p[i]); solve(vec, dp2, 2 * W * W); } for (int i = 1; i <= 2 * W * W; i++) dp2[i] = max(dp2[i - 1], dp2[i]); ans = value; for (int i = 0; i <= min(weight, 2ll * W * W); i++) ans = max(ans, value + dp1[i] + dp2[min(m - weight + i, 2ll * W * W)]); printf("%lld\n", ans); return 0; }
|
好厉害!