题解 CF1668D【Optimal Partition】

2022-04-20 16:04:25


题意

给定一个数组 $a$,长度为 $n (1 \leq n \leq 5 \times 10^5)$,你需要将其分割为若干个连续的子数组,使所有子数组的价值总和最大。

定义 $\texttt{s(l, r)} = a_l + a_{l+1} + a_{l+2} + \cdots + a_r$,子数组 $a[l, r]$ 的价值是:

  • $r-l+1$,$\texttt{s(l, r)} > 0$
  • 0,$\texttt{s(l, r)} = 0$
  • $-(r-l+1)$,$\texttt{s(l, r)} < 0$

思路

有意思的一道题目。

暴力

若忽略掉数据范围的限制,不难想到一种 $O(n^2)$ 的 dp 方案:

  • 先无脑地把前缀和求出来;
  • 设 $dp[i]$ 表示前 $i$ 个元素构成的子数组,划分后可产生的最大价值
  • 显然 $dp[0] = 0$,即空数组不产生价值
  • 考虑转移:枚举 $i, j (j < i)$,则 $dp[i] = \max\left( dp[j] + \texttt{calc(j + 1, i)} \right)$;
  • 其中 $\texttt{calc(j + 1, i)}$ 表示子数组 $a[j + 1, r]$ 的价值,可以由定义计算;
  • $dp[n]$ 就是答案。
dp[0] = 0;
for (int i = 1; i <= n; ++i)
  for (int j = 0; j < i; ++j)
    dp[i] = std::max(dp[i], dp[j] + calc(j + 1, i));
print(dp[n]);

优化

上面的转移方程 $dp[i] = \max\left( dp[j] + \texttt{calc(j + 1, i)} \right)$ 包含了函数调用,有点麻烦,不妨把它拆开,于是产生了三个新的方程:

  • $dp[i] = \max\left( dp[j] + (i-j) \right)$,$\texttt{sum(j + 1, i)} > 0$
  • $dp[i] = \max\left( dp[j] \right)$,$\texttt{sum(j + 1, i)} = 0$
  • $dp[i] = \max\left( dp[j] + (j-i) \right)$,$\texttt{sum(j + 1, i)} < 0$

而刚刚频繁出现的 $\texttt{sum(j + 1, i)}$ 又可以拆成 $s[i] - s[(j+1)-1] = s[i] - s[j]$。于是上面三个式子又可以移项、变形为:

  • $dp[i] = \max\left( dp[j] - j \right) + i$,$s[i] > s[j]$
  • $dp[i] = \max\left( dp[j] \right)$,$s[i] = s[j]$
  • $dp[i] = \max\left( dp[j] + j \right) - i$,$s[i] < s[j]$

于是只需要动态维护 $dp[i] + i$,$dp[i]$,$dp[i] - i$ 的区间最大值即可。

首先对前缀和数组 $s$ 进行离散化,然后把 $s[j]$ 当成下标。每次查询坐标 $s[j]$ 之前(或之后)的区间最大值。当然,对于 $s[i] = s[j]$ 的情况,相当于查询 $s[j]$ 坐标的最大值。

建立三棵线段树,分别维护即可:第零棵维护 $dp[i] + i$,第一棵维护 $dp[i]$,第二棵维护 $dp[i] - i$。dp 数组更新完之后,再进入线段树修改区间最大值(单点修改)。

(又是三棵树,仿佛嗅到了 CF1660F2 的味道)

代码

struct SegTree {
  std::vector<ll> a;
  SegTree(int n) : a((n + 1) * 4, 0) { this->build(1, 1, n); }
#define lson ((o) << 1)
#define rson ((o) << 1 | 1)
  void build(int o, int l, int r) {
    if (l == r) return void(a[o] = -1e18);
    int mid = (l + r) / 2;
    build(lson, l, mid);
    build(rson, mid + 1, r);
    a[o] = std::max(a[lson], a[rson]);
  }
  ll query(int o, int l, int r, int L, int R) {
    if (l >= L && r <= R) return a[o];
    int mid = (l + r) / 2;
    ll res = -1e18;
    if (L <= mid) res = std::max(res, query(lson, l, mid, L, R));
    if (R > mid) res = std::max(res, query(rson, mid + 1, r, L, R));
    return res;
  }
  void change(int o, int l, int r, int x, ll val) {
    if (l == r) return void(a[o] = std::max(a[o], val));
    int mid = (l + r) / 2;
    if (x <= mid)
      change(lson, l, mid, x, val);
    else
      change(rson, mid + 1, r, x, val);
    a[o] = std::max(a[lson], a[rson]);
  }
#undef lson
#undef rson
};

void solution() {
  int n;
  read(n);
  std::vector<int> a(n + 1);
  for (int i = 1; i <= n; ++i) read(a[i]);
  std::vector<ll> s(n + 1);
  for (int i = 1; i <= n; ++i) s[i] = s[i - 1] + a[i];
  // 离散化前缀和数组 s
  std::vector<ll> vs(s.begin(), s.end());
  std::sort(vs.begin(), vs.end());
  std::map<ll, int> belong;
  int tot = 0;
  for (auto i : vs)
    if (!belong.count(i)) belong[i] = ++tot;
  // s2 是离散化后的 s
  std::vector<int> s2(n + 1);
  for (int i = 1; i <= n; ++i) s2[i] = belong[s[i]];

  auto chmax = [](auto& x, auto y) { x = std::max(x, y); };
  std::vector<SegTree> seg(3, SegTree(tot));
  // 下面这行相当于暴力代码的 dp[0] = 0
  for (int i = 0; i < 3; ++i) seg[i].change(1, 1, tot, belong[0], 0);
  for (int i = 1; i <= n; i++) {
    // 对应上述第一个转移方程
    if (s2[i] > 1) chmax(dp[i], seg[2].query(1, 1, tot, 1, s2[i] - 1) + i);
    // 第二个转移方程
    chmax(dp[i], seg[1].query(1, 1, tot, s2[i], s2[i]));
    // 第三个转移方程
    if (s2[i] < tot) chmax(dp[i], seg[0].query(1, 1, tot, s2[i] + 1, tot) - i);
    // 单点修改
    seg[0].change(1, 1, tot, s2[i], dp[i] + i);
    seg[1].change(1, 1, tot, s2[i], dp[i]);
    seg[2].change(1, 1, tot, s2[i], dp[i] - i);
  }
  print(dp[n]);
}