IOI 2008 Egypt過去問 PYRAMID BASE

問題概要

横M、縦Nの二次元フィールドに辺がx軸y軸とそれぞれ平行な長方形の障害物がP個ある。i個目の障害物はCiの金で撤去可能であり、与えられた予算はBである。このフィールド上にできるだけ広い正方形の土地を確保せよ。

  • 1 <= M <= 1,000,000
  • 1 <= N <= 1,000,000
  • 1 <= Ci <= 7,000

次の3つのテストグループがある。

35点分
  • B = 0 (障害物は撤去できない)
  • P <= 1,000
35点分
  • 0 < B <= 2,000,000,000
  • P <= 30,000
30点分
  • B = 0 (障害物は撤去できない)
  • P <= 400,000

解法

まず最初に共通の性質として、答となる最大の正方形の辺は必ずどれかの長方形に接している。よって長方形に接するところだけを調べれば良い、というのが基本。

2つ目の35点分の解法

つまり障害物を撤去できるとき。基本的には正方形の一辺の長さについて二分探索を行う。ある長さ L の正方形がフィールド上に確保できるかどうかを現実的な時間で判定できればこのケースは解けることになる。

答となる正方形の辺が長方形に接することを利用すると、全ての長方形について、それに接する一辺の長さ L の正方形を置いたときに、その正方形と重なる長方形を撤去するためのコストが B 以下ならばよい。しかし、これをそのまま実装すると時間がかかりすぎる。

そこでまず問題を変形する。長さ L の正方形をフィールド上にそのまま確保するのではなくて、全ての長方形の横の長さを左に、縦の長さを下に、それぞれ L だけ伸ばしたフィールド上において一辺の長さが 1 の正方形を確保すると考える。(このときの一辺の長さが 1 の正方形は、目的とする一辺の長さ L の正方形の左下端になっている)

するとこの問題は次のようにして解ける。最初に要素数 N の配列 S を用意し、全要素を 0 で初期化する。そして走査線を左から右に走らせつつ、次の操作を行う。

  • 走査線が長方形の左端に達した場合、その長方形の下y座標(y1とする)から上y座標(y2とする)まで、S[y1..y2] にそれぞれこの長方形を撤去するためのコスト(cとする)を加算する。
  • 走査線が長方形の右端に達した場合、S[y1..y2] からそれぞれ c を引く。
  • いずれの場合でも、上記の2操作が終わった後に S[1..N-L+1] から最小の要素を見つける。その最小の要素の値が B 以下であれば、フィールド上に一辺の長さ L の正方形を確保することは可能であると判定する。そうでない場合は走査を続ける。

これをそのまま実装すると当然遅いので、Sをsegment treeで実装してやれば一回の操作(範囲に対する加算、減算、最小値)はそれぞれ O(log N) で済む。また、走査線は長方形の左右両端のみに興味があるので、そこだけ操作するようにすればこの判定は全体で O(P log N) で動く。

あとはこれを使って二分探索をするだけで、その範囲は 1 から min(M, N) までなので、全体の計算量は O(P log N log min(M, N)) となる。segment treeの実装は慣れていれば優しいし、他の実装も軽めなので辛くはない。

1つ目の35点分および3つ目の30点分の解法

つまり障害物を撤去できないとき。ここで考えるのは、「正方形の左端のx座標を固定したとき、とりうる最大の右端のx座標は何か?」ということである。この値が左端のx座標について広義単調増加であることから、次のアルゴリズムを開発できる。

最初に要素数 N の配列 S を用意し、全要素を 0 で初期化する。また、答となる値を sol とし、最初は0にしておく。そして走査線を2本、左から右に走らせつつ、次の操作を行う。ここで2本の走査線のうち1本は、上記の問における「左端のx座標」に対応し(左走査線とする)、もう1本は「とりうる最大の右端のx座標」に対応する(右走査線とする)。

  • 右走査線が長方形の左端に達した場合、S[y1..y2] にそれぞれ(正の値ならなんでもよいが) 1 を加算する。
  • 左走査線が長方形の右端に達した場合、S[y1..y2] からそれぞれ 1 を引く。
  • いずれの場合でも、min(右走査線のx座標 - 左走査線のx座標, S[1..N]の中で0が連続する最大の区間の長さ)と sol を比較し、その値がsolより多ければ sol をその値にする。

ここで「右走査線のx座標 - 左走査線のx座標」がその時にとりうる正方形の一辺の最大の長さの上界になってることは言うまでもないとして、「S[1..N]の中で0が連続する最大の区間の長さ」というのが「走査線に挟まれた部分において、左走査線から右走査線までの間に、障害物が占めるマスがひとつもないようなy座標が連続する最大の長さ」であることを考えれば、これで最適解が出せることに納得がいくはずである。

この場合もそのままの実装では当然遅いので、Sをsegment treeで実装してやれば一回の操作が O(log N) で済む。この走査も同じく長方形の両端だけを走査するので時間計算量は全体で O(P log N) である。

実装に関しては、segment treeの実装はそこまで難しくない(とはいっても、やや応用的な範囲になるのかな?)。ただ走査の部分が(少なくとも僕は)バグバグでどうしようもなく手間取った。まあこれは単に僕がSweep-line系のアルゴリズムの実装が死ぬほど苦手だというだけかもしれない。

ソース

MacBook Pro上のVirtual BoxにIOI2010の競技環境を作った上で、最大ケースが4秒ちょっとぐらい。時間制限は5秒になってるけど、本番に間に合うかどうかは非常に怪しいのでもっと賢い実装が必要かも。そもそもいくら走査の実装が苦手だからといってもあまりにもひどいと思う……。

segtree, solveB0 が B = 0 のときの解法で、 segtree2, solve が B > 0 のときの解法。

#include<cstdio>
#include<vector>
#include<algorithm>

using namespace std;

struct segtree {
  struct seg {
    int add;
    int zleft, zright, zmax;
  } *S;
  int n;
  void init(int x, int l, int r) {
    S[x].add = 0;
    S[x].zleft = S[x].zright = S[x].zmax = r-l;
    if(r-l <= 1) return;
    int mid = (l+r)/2;
    init(x*2, l, mid);
    init(x*2+1, mid, r);
  }
  segtree(int siz) {
    n = siz;
    S = new seg[4*siz+50];
    init(1, 1, siz+1);
  }
  inline int get_r(int x, int cl, int cr, int l, int r) {
    if(r-l <= 0 || S[x].add > 0) return 0;
    return min(S[x].zright, cr-l);
  }
  inline int get_l(int x, int cl, int cr, int l, int r) {
    if(r-l <= 0 || S[x].add > 0) return 0;
    return min(S[x].zleft, r-cl);
  }
  int query(int x, int cl, int cr, int l, int r, int psum) {
    if(r-l <= 0) return 0;
    if(psum+S[x].add > 0) return 0;
    if(cl == l && cr == r) return psum == 0 ? S[x].zmax : 0;
    int ret = 0, mid = (cl+cr)/2;
    if(r<=mid) {
      ret = query(2*x, cl, mid, l, r, psum+S[x].add);
    } else if(mid<=l) {
      ret = query(2*x+1, mid, cr, l, r, psum+S[x].add);
    } else {
      ret = query(2*x, cl, mid, l, mid, psum+S[x].add);
      ret = max(ret, query(2*x+1, mid, cr, mid, r, psum+S[x].add));
      ret = max(ret, get_r(2*x, cl, mid, l, mid) + get_l(2*x+1, mid, cr, mid, r));
    }
    return ret;
  }
  int query(int l, int r) {
    return query(1, 1, n+1, l, r, 0);
  }
  void incr(int x, int cl, int cr, int l, int r, int v, int psum) {
    int mid = (cl+cr)/2;
    if(cl == l && cr == r) {
      S[x].add += v;
      if(S[x].add > 0) S[x].zleft = S[x].zright = S[x].zmax = 0;
      else {
        if(r-l == 1) S[x].zleft = S[x].zright = S[x].zmax = 1;
        else {
          S[x].zleft = get_l(2*x, cl, mid, cl, mid);
          if(S[x].zleft == mid-cl) S[x].zleft += get_l(2*x+1, mid, cr, mid, cr);
          S[x].zright = get_r(2*x+1, mid, cr, mid, cr);
          if(S[x].zright == cr-mid) S[x].zright += get_r(2*x, cl, mid, cl, mid);
          S[x].zmax = max(S[2*x].zmax, S[2*x+1].zmax);
          S[x].zmax = max(S[x].zmax, get_r(2*x, cl, mid, cl, mid) + get_l(2*x+1, mid, cr, mid, cr));
        }
      }
      return;
    }
    if(l<mid) incr(2*x, cl, mid, l, min(r, mid), v, psum+S[x].add);
    if(mid<r) incr(2*x+1, mid, cr, max(l, mid), r, v, psum+S[x].add);
    if(S[x].add > 0) {
      S[x].zleft = S[x].zright = S[x].zmax = 0;
      return;
    }
    S[x].zleft = get_l(2*x, cl, mid, cl, mid);
    if(S[x].zleft == mid-cl) S[x].zleft += get_l(2*x+1, mid, cr, mid, cr);
    S[x].zright = get_r(2*x+1, mid, cr, mid, cr);
    if(S[x].zright == cr-mid) S[x].zright += get_r(2*x, cl, mid, cl, mid);
    S[x].zmax = max(S[2*x].zmax, S[2*x+1].zmax);
    S[x].zmax = max(S[x].zmax, get_r(2*x, cl, mid, cl, mid) + get_l(2*x+1, mid, cr, mid, cr));  
  }
  void insert(int l, int r) {
    incr(1, 1, n+1, l, r, 1, 0);
  }
  void remove(int l, int r) {
    incr(1, 1, n+1, l, r, -1, 0);
  }
};

struct side {
  int x, cost;
  int y1, y2;
  bool start;
  side(int a, int b, int c, int d=0, bool e=false)
    : x(a), y1(b), y2(c), cost(d), start(e) { }
};
struct side_cmp {
  bool operator ()(const side& a, const side& b) {
    return a.x == b.x ? a.y1 < b.y1 : a.x < b.x;
  }
};
struct side_cmp_b {
  bool operator ()(const side& a, const side& b) {
    return a.x == b.x ? (a.start != b.start ? b.start : false) : a.x < b.x;
  }
};
int solveB0(int M, int N, int P,
            vector<int> x1, vector<int> y1, 
            vector<int> x2, vector<int> y2, vector<int> c)
{
  vector<side> E1, E2;
  for(int i=0; i<P; ++i) {
    E1.push_back(side(x1[i], y1[i], y2[i]));
    E2.push_back(side(x2[i], y1[i], y2[i]));
  }
  E1.push_back(side(N+1, 1000000000, 1));
  E2.push_back(side(1, -1000000000, 1));
  sort(E1.begin(), E1.end(), side_cmp());
  sort(E2.begin(), E2.end(), side_cmp());
  int left = 0, right = 0, sol = 0;
  segtree S(M);
  while(right < E1.size()) {
    int ptr;
    bool fail = false;
    while(right < E1.size()) {
      sol = max(sol, min(E1[right].x-E2[left].x, S.query(1, M+1)));
      ptr = right;
      if(S.query(1, M+1) < E1[right].x-E2[left].x) {  
        fail = true;
        break;
      }
      while(ptr < E1.size() && E1[right].x==E1[ptr].x) {
        if(E1[ptr].x <= N) { 
          S.insert(E1[ptr].y1, E1[ptr].y2);
        }
        ptr++;
      }
      if(S.query(1, M+1) < E1[right].x-E2[left].x) {
        fail = true;
        break;
      }
      right = ptr;
      if(right >= E1.size()) break;
      sol = max(sol, min(E1[right].x-E2[left].x, S.query(1, M+1)));
    }
    if(right >= E1.size()) break;
    sol = max(sol, min(E1[right].x-E2[left].x, S.query(1, M+1)));
    if(fail) {
      while(left+1 < E2.size() && S.query(1, M+1) < E1[right].x-E2[left].x) {
        int ptr2 = left+1;
        while(ptr2 < E2.size() && E2[left+1].x == E2[ptr2].x) {
          if(ptr2 > 0) {
            S.remove(E2[ptr2].y1, E2[ptr2].y2);
          }
          ptr2++;      
        }
        left = ptr2-1;
        fail = false;
      }
      for(int p=right; p<ptr; ++p)
        S.remove(E1[p].y1, E1[p].y2);
      ptr = right;
    }
    right = ptr;
    sol = max(sol, min(E1[right].x-E2[left].x, S.query(1, M+1)));
    if(N+1-E2[left].x <= sol) break;
  }
  return min(sol, min(M, N));
}

struct segtree2 {
  struct seg {
    int val, add;
  } *S;
  int n;
  void init(int x, int l, int r) {
    S[x].val = S[x].add = 0;
    if(r-l <= 1) return;
    int mid = (l+r)/2;
    init(2*x, l, mid);
    init(2*x+1, mid, r);
  }
  segtree2(int siz) {
    n = siz;
    S = new seg[4*n+50];
    init(1, 1, n+1);
  }
  ~segtree2() { delete[] S; }
  void incr(int x, int cl, int cr, int l, int r, int v) {
    if(cl == l && cr == r) {
      S[x].add += v;
      return;
    }
    int mid = (cl+cr)/2;
    if(l<mid) incr(2*x, cl, mid, l, min(mid, r), v);
    if(mid<r) incr(2*x+1, mid, cr, max(mid, l), r, v);
    S[x].val = min(S[2*x].val+S[2*x].add, S[2*x+1].val+S[2*x+1].add);
  }
  void incr(int l, int r, int v) {
    incr(1, 1, n+1, l, r, v);
  }
  int query(int x, int cl, int cr, int l, int r) {
    if(cl == l && cr == r) {
      return S[x].val + S[x].add;
    }
    int mid = (cl+cr)/2;
    if(r<=mid) {
      return query(2*x, cl, mid, l, r) + S[x].add;
    } else if(mid<=l) {
      return query(2*x+1, mid, cr, l, r) + S[x].add;
    } else {
      return min(query(2*x, cl, mid, l, mid), query(2*x+1, mid, cr, mid, r)) + S[x].add;
    }
  }
  int query(int l, int r) {
    return query(1, 1, n+1, l, r);
  }
};

bool check(int M, int N, int B, int L, int P,
           vector<int> x1, vector<int> y1, 
           vector<int> x2, vector<int> y2, vector<int> c)
{
  segtree2 S(M);
  vector<side> E;
  for(int i=0; i<P; ++i) {
    x1[i] = max(1, x1[i]-L+1);
    y1[i] = max(1, y1[i]-L+1);
  }
  for(int i=0; i<P; ++i) {
    E.push_back(side(x1[i], y1[i], y2[i], c[i], true));
    E.push_back(side(x2[i], y1[i], y2[i], -c[i], false));
  }
  E.push_back(side(1, -1, -1, true));
  E.push_back(side(N+1-L, -1, -1, false));
  sort(E.begin(), E.end(), side_cmp_b());
  for(int i=0; i<E.size(); ) {
    int p;
    if(E[i].x > N+1-L) break;
    for(p=i; p<E.size() && E[i].x==E[p].x; ++p) 
      if(E[p].y1 >= 1)
        S.incr(E[p].y1, E[p].y2, E[p].cost);
    i = p;
    if(S.query(1, M-L+2) <= B) return true;
  }
  return false;
}

int solve(int M, int N, int B, int P,
          vector<int> x1, vector<int> y1, 
          vector<int> x2, vector<int> y2, vector<int> c)
{
  int lo = 1, hi = min(M, N);
  if(!check(M, N, B, lo, P, x1, y1, x2, y2, c)) return 0;
  while(hi-lo>1) {
    int mid = (hi+lo)/2;
    if(check(M, N, B, mid, P, x1, y1, x2, y2, c))
      lo = mid;
    else
      hi = mid;
  }
  if(check(M, N, B, hi, P, x1, y1, x2, y2, c)) return hi;
  return lo;
}

int main()
{
  int M, N, B, P;
  scanf("%d%d%d%d", &N, &M, &B, &P);
  vector<int> x1(P), y1(P), x2(P), y2(P), c(P);
  for(int i=0; i<P; ++i) {
    scanf("%d%d%d%d%d", &x1[i], &y1[i], &x2[i], &y2[i], &c[i]);
    x2[i]++; y2[i]++;
  }
  if(B == 0) printf("%d\n", solveB0(M, N, P, x1, y1, x2, y2, c));
  else printf("%d\n", solve(M, N, B, P, x1, y1, x2, y2, c));
  return 0;
}