okayunのプログラミング日記

主に自分の解いた問題やCodinGameなどの解法などを載せていきます。

ARC044-B : 最短路問題

ARC044-B : 最短路問題

なんとなく解いた問題で, 難しくはないんですがちゃんと考察したのでまとめておきます.

問題の概要は省いて考察をここに書きます.
与えられる入力を満たすようなグラフの種類はいくつあるかを答えるのですが, 入力で与えられるのは, 各頂点の, 頂点 1 からの最短距離です. つまり, 各頂点の最短距離が変わらないようなグラフを考えればよいです(それはそう).
そこで次の 2 つの条件を考えることが出来ます.

1. 頂点 1 から同じ距離にある頂点同士は接続していても接続していなくてもよい

例えば, 頂点 1 から距離 n にある頂点 a, b があったとします. このとき a, b 同士が接続していようが接続していまいが, 各頂点の最短距離は変わらないので, 接続している場合と接続していない場合の 2 パターンあることがわかります.

2. 頂点 1 からの最短距離が n (n > 1) の頂点は, 最短距離が n - 1 の頂点のどれか 1 つに接続していれば良い

辺のコストはすべて 1 なので, 最短距離が n である頂点は, 最短距離が n - 1 の頂点のどれか一つに接続していればよいです. もちろん, 最短距離が n - 1 より小さい頂点に接続していた場合は, 最短距離が変わってしまうのでありえないことがわかります.

以上の 2 つを考えれば, あとは実装を頑張るだけ(なはず)です.
2 つ注意点を挙げるとすれば, 与えられる入力には条件が成り立たないようなものも存在するのでうまく弾く(例えば, 最短距離が x の頂点はあるが, 最短距離が x - 1 の頂点は存在しないなど)のと, 単純にオーバーフローする可能性があるので気をつけましょうという感じです.

以下が AC したソースコードになります.

#include <iostream>
#include <vector>
#include <algorithm>
#include <map>
#include <functional>
#include <random>
#include <string>
#include <stack>
#include <cassert>
#include <cstring>
#include <cstdlib>
#include <ctime>
#include <cmath>

using std::cin;
using std::cout;
using std::cerr;
using std::endl;

using ll = long long;

const int MOD = int(1e9 + 7);

// x^n mod MOD
ll pow_mod(ll x, ll n) {
  if (n == 0) return 1LL;
  if (n == 1) return x;
  ll ret = pow_mod((x * x) % MOD, n / 2) % MOD;
  if (n & 1) ret = (x * ret) % MOD;
  return ret;
}

int main() {
  cin.tie(0);
  std::ios::sync_with_stdio(false);

  int N;

  cin >> N;

  std::vector<ll> a(N, 0LL);
  ll max_dist = 0, dist[100002]; // dist[i] := 頂点0から距離iにある頂点の数
  ll ans = 1;
  std::fill(dist, dist + 100001, 0);

  for (int i = 0; i < N; ++i) {
    cin >> a[i];
    max_dist = std::max(max_dist, a[i]);
    dist[a[i]]++;
  }

  if (a[0] != 0 || dist[0] != 1) {
    cout << 0 << endl;
    return 0;
  }

  for (int i = 1; i <= max_dist; ++i) {
    if (dist[i] == 0) {
      cout << 0 << endl;
      return 0;
    }

    ll cond1 = pow_mod(2, (dist[i] * (dist[i] - 1)) / 2);
    ll cond2 = (i == 1 ? 1 : (pow_mod((pow_mod(2, dist[i - 1]) - 1), dist[i])));

    ans = (((cond1 * cond2) % MOD) * ans) % MOD;
  }

  cout << (ans % MOD) << endl;

  return 0;
}