题解 CF1705E【Mark and Professor Koro】

2022-07-18 14:55:35


题意

给定一个数列 $a$,可以进行若干次操作,每次操作可以选择 $a$ 中出现过至少两次的 $x$,删除掉这两个 $x$,再插入新元素 $x+1$。

有 $q$ 次询问,每次询问都永久地将 $a_x$ 修改为 $y$,即强制在线。求:每次询问后,进行若干次操作,可以得到的最大的数值是多少。

分析

将两个相同的数字 $x$ 替换为一个 $x+1$,不难想到转化为二进制数的进位:$2^x + 2^x = 2^{x + 1}$。

而将 $a_x$ 修改为 $y$,可以类似地转化为,先给二进制数减去 $2^{a_x}$ 再加上一个 $2^y$。

将所有可以进位的位置全部进位,直到不能继续操作,即得到了最大的数值。

于是这道题目就转化成了一个高精度计算二进制加减法的题目。当然,由于数据范围较大,python 等语言自带的高精度运算无法通过此题。

可以使用一段伪代码来表示这个过程。

read n, q, a
for i in range(n):
    ans += 2 ** a[i]
for i in range(q):
    read k, l
    k = k - 1
    ans -= 2 ** a[k]
    a[k] = l
    ans += 2 ** a[k]
    print log2(ans)

注意到二进制数的特殊性:每一位上的值,要么是 $0$,要么是 $1$。同时,这道题目也具有特殊性:每次进行加减法的数字,都是 $2$ 的幂,即某个确定的二进制位。

那么我们只需要在程序中模拟二进制数加减法的进位、退位即可。

考虑用树状数组维护该二进制数的一段区间内,位为 $1$ 的个数。

如果一次加法操作之后,第 $x$ 位的值是 $2$,意味着需要进位。可以二分查找第 $x$ 位之后的第一个 $0$ 的位置 $y$,将第 $y$ 位进到 $1$,而 $x \sim y-1$ 位全部修改为 $0$。

如果一次减法操作之后,第 $x$ 位的值是 $-1$,意味着需要借位退位。可以二分查找第 $x$ 位之后第一个 $1$ 的位置 $y$,将第 $y$ 位的值修改为 $0$,而 $x \sim y-1$ 位全部修改为 $1$。

可以在树状数组当中同时维护差分数组 $d_i$ 和 $(i-1)\times d_i$ 的值,实现区间修改和区间查询。

总时间复杂度是 $\mathcal{O}((n+q)\log^2S)$,其中 $S = \sum a_i$。树状数组常数较小,可以轻松通过。

代码

主函数部分代码较长,主要是因为进位、退位部分代码写的比较臭,可重用性比较差,复制粘贴了三遍(

struct FenwickRange {
  int n;
  std::vector<ll> a, b;
  FenwickRange() {}
  FenwickRange(int n) : n(n) { a.resize(n + 1, 0), b.resize(n + 1, 0); }
  int lowbit(int x) { return x & -x; }
  void rangeAdd(int l, int r, ll val) {
    if (l > r) return;
    add(a, l, val);
    add(b, l, (l - 1) * val);
    add(a, r + 1, -val);
    add(b, r + 1, -r * val);
  }
  ll rangeSum(int l, int r) {
    if (l > r) return 0;
    if (l == 0) return rangeSum(1, r);
    ll R = r * sum(a, r) - sum(b, r);
    ll L = (l - 1) * sum(a, l - 1) - sum(b, l - 1);
    return R - L;
  }

 private:
  void add(std::vector<ll>& a, int x, ll val) {
    while (x <= n) {
      a[x] += val;
      x += lowbit(x);
    }
  }
  ll sum(std::vector<ll>& a, int pos) {
    ll res = 0;
    while (pos) {
      res += a[pos];
      pos -= lowbit(pos);
    }
    return res;
  }
};

void solution() {
  int n, q;
  read(n, q);
  std::vector<int> a(n + 1);
  for (int i = 1; i <= n; ++i) read(a[i]);
  const int maxBit = 2e5 + 50;
  FenwickRange fwk(maxBit);
  for (int i = 1; i <= n; ++i) {
    fwk.rangeAdd(a[i], a[i], 1);
    if (fwk.rangeSum(a[i], a[i]) >= 2) {  // 进位
      assert(fwk.rangeSum(a[i], a[i]) == 2);
      // 二分 a[i] 这一位之后的第一个 0 的位置
      int l = 1, r = maxBit - a[i], mid = (l + r) / 2;
      while (l < r) {
        int cnt1 = fwk.rangeSum(a[i] + 1, a[i] + mid);
        int need = mid;
        if (cnt1 >= need) {
          assert(cnt1 == need);
          l = mid + 1;
        } else {
          r = mid;
        }
        mid = (l + r) / 2;
      }
      fwk.rangeAdd(a[i], a[i], -2);
      assert(fwk.rangeSum(a[i] + 1, a[i] + mid - 1) == mid - 1);
      fwk.rangeAdd(a[i] + 1, a[i] + mid - 1, -1);
      fwk.rangeAdd(a[i] + mid, a[i] + mid, 1);
    }
  }

  for (int i = 1; i <= q; ++i) {
    int x, y;
    read(x, y);  // a[x] = y
    int del = a[x];
    fwk.rangeAdd(del, del, -1);
    if (fwk.rangeSum(del, del) <= -1) {  // 退位
      assert(fwk.rangeSum(del, del) == -1);
      // 二分 del 这一位之后的第一个 1 的位置
      int l = 1, r = maxBit - del, mid = (l + r) / 2;
      while (l < r) {
        int cnt1 = fwk.rangeSum(del + 1, del + mid);
        int cnt0 = mid - cnt1;
        int need = mid;
        if (cnt0 >= need) {
          assert(cnt0 == need);
          l = mid + 1;
        } else {
          r = mid;
        }
        mid = (l + r) / 2;
      }
      fwk.rangeAdd(del, del, 2);
      assert(mid - 1 - fwk.rangeSum(del + 1, mid - 1) == mid - 1);
      fwk.rangeAdd(del + 1, del + mid - 1, 1);
      fwk.rangeAdd(del + mid, del + mid, -1);
    }
    a[x] = y;
    fwk.rangeAdd(y, y, 1);
    if (fwk.rangeSum(y, y) >= 2) {  // 进位
      assert(fwk.rangeSum(y, y) == 2);
      // 二分 y 这一位之后的第一个 0 的位置
      int l = 1, r = maxBit - y, mid = (l + r) / 2;
      while (l < r) {
        int cnt1 = fwk.rangeSum(y + 1, y + mid);
        int need = mid;
        if (cnt1 >= need) {
          assert(cnt1 == need);
          l = mid + 1;
        } else {
          r = mid;
        }
        mid = (l + r) / 2;
      }
      fwk.rangeAdd(y, y, -2);
      assert(fwk.rangeSum(y + 1, y + mid - 1) == mid - 1);
      fwk.rangeAdd(y + 1, y + mid - 1, -1);
      fwk.rangeAdd(y + mid, y + mid, 1);
    }

    // 查找最高位
    int l = 1, r = maxBit, mid = (l + r + 1) / 2;
    while (l < r) {
      int cnt1 = fwk.rangeSum(mid, r);
      if (cnt1 >= 1) {
        l = mid;
      } else {
        r = mid - 1;
      }
      mid = (l + r + 1) / 2;
    }
    print(mid);
  }
}