Edit

【pythonでABC175を解説】D - Moving Piece

問題概要

問題ページ

問題文

高橋君は \(1, 2, \cdots, N\) の番号のついた \(N\) マスから成るマス目の上で、コマを使ってゲームを行おうとしています。マス \(i\) には整数 \(C_i\) が書かれています。また、\(1, 2 …, N\) の順列 \(P_1, P_2, \cdots, P_N\) が与えられています。

これから高橋君は好きなマスを \(1\) つ選んでコマを \(1\) つ置き、\(1\) 回以上 \(K\) 回以下の好きな回数だけ、次のような方法でコマを移動させます。

  • \(1\) 回の移動では、現在コマがマス \(i (1 \leq i \leq N)\) にあるなら、コマをマス \(P_i\) に移動させる。このとき、スコアに \(C_{P_i}\) が加算される。

高橋君のために、ゲーム終了時のスコアとしてあり得る値の最大値を求めてください。(ゲーム開始時のスコアは \(0\) です。)

制約

  • \(2 \leq N \leq 5000\)
  • \(1 \leq K \leq 10^9\)
  • \(1 \leq P_i \leq N\)
  • \(P_i \neq i\)
  • \(P_1, P_2, \cdots, P_N\) は全て異なる
  • \(-10^9 \leq C_i \leq 10^9\)
  • 入力は全て整数である

問題の考察

ACコード

import sys
import collections


class AlgUnionFind():
    def __init__(self, n):
        self.n = n
        self.parents = [-1] * n
        self.d = collections.defaultdict(list)

    def find(self, x):
        if self.parents[x] < 0:
            return x
        else:
            self.parents[x] = self.find(self.parents[x])
            return self.parents[x]

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        if self.parents[x] > self.parents[y]:
            x, y = y, x

        self.parents[x] += self.parents[y]
        self.parents[y] = x

    def size(self, x):
        return -self.parents[self.find(x)]

    def same(self, x, y):
        return self.find(x) == self.find(y)

    def members(self, x):
        root = self.find(x)
        return [i for i in range(self.n) if self.find(i) == root]

    def roots(self):
        return [i for i, x in enumerate(self.parents) if x < 0]

    def group_count(self):
        return len(self.roots())

    def all_group_members(self):
        return {r: self.members(r) for r in self.roots()}

    def set_data(self, x, amt, cnt):
        self.d[self.find(x)] += [amt, cnt]

    def __str__(self):
        return '\n'.join('{}: {}'.format(r, self.members(r)) for r in self.roots())


def solve():
    input = sys.stdin.readline
    mod = 10 ** 9 + 7
    n, k = list(map(int, input().rstrip('\n').split()))
    p = [v - 1 for v in list(map(int, input().rstrip('\n').split()))]
    c = list(map(int, input().rstrip('\n').split()))
    mx = -(10 ** 20)
    uf = AlgUnionFind(n)
    for i in range(n):
        if uf.find(i) == i:
            pos = i
            d = collections.defaultdict(list)
            d[i] += [0, 0]
            total = 0
            for j in range(k):
                if p[pos] not in d:
                    pos = p[pos]
                    total += c[pos]
                    d[pos] += [j, total]
                    mx = max(mx, total)
                else:
                    pos = p[pos]
                    total += c[pos]
                    t_pos = pos
                    for l in range(10 ** 20):
                        uf.union(t_pos, p[t_pos])
                        t_pos = p[t_pos]
                        if t_pos == pos:
                            break
                    uf.set_data(pos, total - d[pos][1], j - d[pos][0] + 1)
                    if total > d[pos][1]:
                        cnt = uf.d[uf.find(pos)][1]
                        amt = uf.d[uf.find(pos)][0]
                        nokori = k - j - 1
                        loop = max(nokori // cnt - 1, 0)
                        total += amt * loop
                        for l in range(nokori - loop * cnt):
                            pos = p[pos]
                            total += c[pos]
                            mx = max(mx, total)
                    break
        else:
            cnt = uf.d[uf.find(i)][1]
            amt = uf.d[uf.find(i)][0]
            loop = max(k // cnt - 1, 0)
            total = amt * loop if amt * loop > 0 else 0
            pos = i
            for l in range(k - loop * cnt):
                pos = p[pos]
                total += c[pos]
                mx = max(mx, total)
    print(mx)


if __name__ == '__main__':
    solve()

-Edit
-