Edit

【pythonでABC190を解説】E - Magical Ornament

問題概要

問題ページ

https://atcoder.jp/contests/abc190/tasks/abc190_e
https://atcoder.jp/contests/abc190/tasks/abc190_e

問題ページへ移動する

問題文

AtCoder 王国には \(1, 2, \dots, N\) の番号がついた \(N\) 種類の魔法石が流通しています。
高橋くんは、魔法石を一列に並べて飾りを作ろうとしています。
魔法石には隣り合わせにできる組とできない組があります。
隣り合わせにできる組は \((\)魔法石 \(A_1,\) 魔法石 \(B_1), (\)魔法石 \(A_2,\) 魔法石 \(B_2), \dots, (\)魔法石 \(A_M,\) 魔法石 \(B_M)\) の \(M\) 組で、それ以外の組は隣り合わせることができません。(これらの組において、石の順序は不問です。)
魔法石 \(C_1, C_2, \dots, C_K\) をそれぞれ \(1\) 個以上含む魔法石の列を作ることができるか判定し、作れる場合はそのような列を作るのに必要な魔法石の個数の最小値を求めてください。

制約

  • 入力は全て整数
  • \(1 ≤ N ≤ 10^5\)
  • \(0 ≤ M ≤ 10^5\)
  • \(1 ≤ A_i < B_i ≤ N\)
  • \(i ≠ j\) ならば \((A_i, B_i) ≠ (A_j, B_j)\)
  • \(1 ≤ K ≤ 17\)
  • \(1 ≤ C_1 < C_2 < \dots < C_K ≤ N\)

問題の考察

ACコード

import sys
import collections


def solve():
    input = sys.stdin.readline
    mod = 10 ** 9 + 7
    n, m = list(map(int, input().rstrip('\n').split()))
    edges = [[] for _ in range(n)]
    for i in range(m):
        edge_a, edge_b = list(map(int, input().rstrip('\n').split()))
        edges[edge_a-1].append(edge_b-1)
        edges[edge_b-1].append(edge_a-1)
    k = int(input().rstrip('\n'))
    c = list(map(int, input().rstrip('\n').split()))
    d = collections.defaultdict(int)
    for i in range(len(c)):
        c[i] -= 1
        d[c[i]] = i

    dist = [[10 ** 20] * k for _ in range(k)]
    fq = [-1] * n
    for i in range(k):
        ql = [[0, c[i]]]
        ql = collections.deque(ql)
        while ql:
            cost, now_edge = ql.popleft()
            if now_edge in d:
                dist[i][d[now_edge]] = min(dist[i][d[now_edge]], cost)
                dist[d[now_edge]][i] = dist[i][d[now_edge]]
            if fq[now_edge] < i:
                fq[now_edge] = i
                for next_edge in edges[now_edge]:
                    ql.append((cost + 1, next_edge))

    f = 10 ** 20
    dp = [[f] * (1 << k) for _ in range(k)]
    for i in range(k):
        dp[i][1 << i] = 1
    for i in range(1 << k):
        for j in range(k):
            if dp[j][i] != f:
                for l in range(k):
                    if j != l:
                        n_bit = 1 << l
                        if i & n_bit == 0:
                            dp[l][i|n_bit] = min(dp[l][i|n_bit], dp[j][i] + dist[j][l])
    ans = f
    for i in range(k):
        ans = min(ans, dp[i][-1])
    print(ans if ans != f else -1)


if __name__ == '__main__':
    solve()

-Edit
-