IOI 2009 Bulgaria過去問 Regions

解法

まず木の節点の番号をpreorderで付け直す。すると、ある節点の子の節点の番号は連続する整数になるので、全ての節点についてその子の節点番号の区間を計算しておけば「節点bは節点aの子か?」という問にO(1)で答えることができて便利。これは木を辿るだけなのでO(N)。

次に、queryにおいてr1を固定したときにその答を効率良く計算することを考える。すなわち全ての地域 r について、 (r1, r) という形式のqueryに対する答を先に全て(効率良く)計算しておく。これは cnt1[v] に節点vから根までの最短経路に、地域 r1 から来る従業員が何人いるかを計算すれば求めることができる。cnt1[v]を全ての節点について計算するのはDFSなりを使ってO(N)。

さらに、queryにおいてr2を固定したときにその答を効率良く計算することも考える。これも先と同様に、cnt2[v] に節点vから根までの最短経路に、地域 r2 から来る従業員が何人いるかを計算して求められる。同じくDFSなどを使ってO(N)。

最後に任意のqueryを効率良く計算することを考える。ある節点vから根までの最短経路に地域 r1 からくる従業員が何人いるかを考えると、その人数によって[1, N] の区間がO(|r1|)個に分割されることがわかる(|r1| は地域 r1 から来る従業員の総数)。

たとえば図のような組織図の場合(赤は節点番号、青が出身地域)を考える。いま、地域Cについて考えると[1, N] の区間

[1, 1], [2, 3], [4, 6], [7, 7], [8, 8], [9, 9], [10, 11], [12, 13], [14, 14], [15, 15]

と分割されることがわかる。これらの区間に属する節点から根までの最短経路に居る地域Cからの従業員の数はそれぞれ

0, 1, 2, 3, 0, 1, 1, 2, 3, 1

となる。この分割は最初にpreorderで木を辿るときに計算することができるので、全ての地域についてこの分割をO(N)で計算することができる。(上の例で連続する区間において従業員の数が同じな部分があるが、それはpreorderで木をたどりながら計算するとそういう風に分割されてしまうということである)

さて、このような分割がわかれば、r2に属するすべての節点について、r1の分割のどの区間に属するかを調べればquery (r1, r2) に対する答を計算することができる。これを高速に行うにはちょうどSweep line のような感じでr1の分割とr2の節点をなぞっていってやるとよい。計算量はO(|r1| + |r2|) 。

上で述べた3つの技法を使って100点を得ることができる。まず、最初の2つの技法だけでは100点をとるには時間もメモリも足りない。また最後の技法だけでも100点をとるには時間が足りない。

そこで、一部の地域については最初の2つの技法を使って予め答を計算しておき、残りの地域については最後の技法を使ってそのつど計算させるようにすることを考える。ここで、sqrt(N)人以上の従業員をもつ地域は高々sqrt(N)個であることを使って、人数がsqrt(N)人以上の地域については予め答を計算し、そうでない地域についてはそのつど計算させることにする。

そうすると、最初の2つの技法は高々sqrt(N)回だけ使われるので計算量は O(N*sqrt(N)) であり、最後の技法は地域の大きさが sqrt(N) より小さい地域に対してだけ使われるので計算量は O(Q*sqrt(N)) である。よって合計すると時間計算量は O(N*sqrt(N) + Q*sqrt(N)) で、空間計算量は O(R*sqrt(N) + N) であり、100点を得ることができる。

ちなみに、最後の技法での計算を二分探索で行うと、O(|r1| lg |r2|) および O(|r2| lg |r1|) 時間で1つのqueryに答えることが可能になる。この2つと O(|r1| + |r2|) のアルゴリズムを併用し、かつ従業員が多い地域に関するqueryについてはメモ化するなどすると、空間計算量をぐんと抑えて100点がとれると思う(実装はしていない)。

ソース

/*
 * task: IOI2009 Regions
 * lang: C++
 * name: JAPLJ
 */
#include<cstdio>
#include<vector>
#include<algorithm>

using namespace std;

int N, Q, R;
vector<int> mregions;
int mregions_ans1[500][25000], mregions_ans2[25000][500]; // < 96MB

struct node {
  int region;
  int new_idx;
  vector<int> children;
} input[200050]; // < 5MB

bool operator < (const node& a, const node& b)
{
  return a.new_idx < b.new_idx;
}

struct region_info {
  int num, id, mreg_id;
  vector<int> indices;
  vector<pair<int, int> > intervals;
  region_info() : num(0), mreg_id(-1), indices(), intervals() { }
} regions[25050]; // < 2MB

// O(N)
void dfs(int p, int& num)
{
  int reg = input[p].region;
  input[p].new_idx = num++;
  regions[reg].indices.push_back(input[p].new_idx);
  regions[reg].num++;
  regions[reg].intervals.push_back(make_pair(num, regions[reg].num));
  for(int i=0; i<input[p].children.size(); ++i)
    dfs(input[p].children[i], num);
  regions[reg].num--;
  regions[reg].intervals.push_back(make_pair(num, regions[reg].num));
}

// O(N)
void precalc_fixr1_dfs(const region_info& r, int v, vector<int>& cnt, int pnum)
{
  cnt[v] = pnum;
  if(input[v].region == r.id) pnum++;
  for(int i=0; i<input[v].children.size(); ++i)
    precalc_fixr1_dfs(r, input[v].children[i], cnt, pnum);
}

// O(N)
void precalc_fixr1(const region_info& r)
{
  vector<int> cnt(N);
  if(r.intervals.size() == 0) return;
  precalc_fixr1_dfs(r, 0, cnt, 0);
  for(int i=0; i<N; ++i)
    mregions_ans1[r.mreg_id][input[i].region] += cnt[i];
}

// O(N)
void precalc_fixr2_dfs(const region_info& r, int v, vector<int>& cnt)
{
  int num = 0;
  for(int i=0; i<input[v].children.size(); ++i) {
    precalc_fixr2_dfs(r, input[v].children[i], cnt);
    num += cnt[input[v].children[i]];
  }
  if(input[v].region == r.id) num++;
  cnt[v] = num;
}

// O(N)
void precalc_fixr2(const region_info& r)
{
  vector<int> cnt(N);
  if(r.intervals.size() == 0) return;
  precalc_fixr2_dfs(r, 0, cnt);
  for(int i=0; i<N; ++i)
    mregions_ans2[input[i].region][r.mreg_id] += cnt[i];
}

// O(r1.indices.size() + r2.indices.size())
int solve(const region_info& r1, const region_info& r2)
{
  int ret = 0, pos = 0;
  if(r1.intervals.size() == 0) return 0;
  while(pos < r2.indices.size() && r2.indices[pos] < r1.intervals[0].first)
    pos++;
  for(int i=0; i+1<r1.intervals.size() && pos<r2.indices.size(); ++i) {
    int prev = pos;
    while(pos < r2.indices.size() && r2.indices[pos] < r1.intervals[i+1].first)
      pos++;
    ret += r1.intervals[i].second * (pos - prev);
  }
  return ret;
}

int main()
{
  scanf("%d%d%d", &N, &R, &Q);
  scanf("%d", &input[0].region);
  input[0].region--;
  for(int i=1; i<N; ++i) {
    int super, region;
    scanf("%d%d", &super, &region);
    input[i].region = region-1;
    input[super-1].children.push_back(i);
  }
  int temp = 0;
  dfs(0, temp);
  for(int i=0; i<N; ++i)
    for(int j=0; j<input[i].children.size(); ++j)
      input[i].children[j] = input[input[i].children[j]].new_idx;
  for(int i=0; i<R; ++i) {
    regions[i].id = i;
    if(regions[i].indices.size() > 447) { // 447 = floor(sqrt(200000))
      mregions.push_back(i);
      regions[i].mreg_id = mregions.size()-1;
    }
  }
  sort(input, input+N);
  for(int i=0; i<mregions.size(); ++i) {
    precalc_fixr1(regions[mregions[i]]);
    precalc_fixr2(regions[mregions[i]]);
  }
  for(int i=0; i<Q; ++i) {
    int r1, r2;
    scanf("%d%d", &r1, &r2);
    r1--; r2--;
    if(regions[r1].mreg_id != -1) {
      printf("%d\n", mregions_ans1[regions[r1].mreg_id][r2]);
    } else if(regions[r2].mreg_id != -1) {
      printf("%d\n", mregions_ans2[r1][regions[r2].mreg_id]);
    } else {
      printf("%d\n", solve(regions[r1], regions[r2]));
    }
    fflush(stdout);
  }
  return 0;
}