問題概要
問題ページ
-
D - Moving Piece
問題ページへ移動する
問題文
高橋君は \(1, 2, \cdots, N\) の番号のついた \(N\) マスから成るマス目の上で、コマを使ってゲームを行おうとしています。マス \(i\) には整数 \(C_i\) が書かれています。また、\(1, 2 …, N\) の順列 \(P_1, P_2, \cdots, P_N\) が与えられています。
これから高橋君は好きなマスを \(1\) つ選んでコマを \(1\) つ置き、\(1\) 回以上 \(K\) 回以下の好きな回数だけ、次のような方法でコマを移動させます。
- \(1\) 回の移動では、現在コマがマス \(i (1 \leq i \leq N)\) にあるなら、コマをマス \(P_i\) に移動させる。このとき、スコアに \(C_{P_i}\) が加算される。
高橋君のために、ゲーム終了時のスコアとしてあり得る値の最大値を求めてください。(ゲーム開始時のスコアは \(0\) です。)
制約
- \(2 \leq N \leq 5000\)
- \(1 \leq K \leq 10^9\)
- \(1 \leq P_i \leq N\)
- \(P_i \neq i\)
- \(P_1, P_2, \cdots, P_N\) は全て異なる
- \(-10^9 \leq C_i \leq 10^9\)
- 入力は全て整数である
問題の考察
ACコード
import sys
import collections
class AlgUnionFind():
def __init__(self, n):
self.n = n
self.parents = [-1] * n
self.d = collections.defaultdict(list)
def find(self, x):
if self.parents[x] < 0:
return x
else:
self.parents[x] = self.find(self.parents[x])
return self.parents[x]
def union(self, x, y):
x = self.find(x)
y = self.find(y)
if x == y:
return
if self.parents[x] > self.parents[y]:
x, y = y, x
self.parents[x] += self.parents[y]
self.parents[y] = x
def size(self, x):
return -self.parents[self.find(x)]
def same(self, x, y):
return self.find(x) == self.find(y)
def members(self, x):
root = self.find(x)
return [i for i in range(self.n) if self.find(i) == root]
def roots(self):
return [i for i, x in enumerate(self.parents) if x < 0]
def group_count(self):
return len(self.roots())
def all_group_members(self):
return {r: self.members(r) for r in self.roots()}
def set_data(self, x, amt, cnt):
self.d[self.find(x)] += [amt, cnt]
def __str__(self):
return '\n'.join('{}: {}'.format(r, self.members(r)) for r in self.roots())
def solve():
input = sys.stdin.readline
mod = 10 ** 9 + 7
n, k = list(map(int, input().rstrip('\n').split()))
p = [v - 1 for v in list(map(int, input().rstrip('\n').split()))]
c = list(map(int, input().rstrip('\n').split()))
mx = -(10 ** 20)
uf = AlgUnionFind(n)
for i in range(n):
if uf.find(i) == i:
pos = i
d = collections.defaultdict(list)
d[i] += [0, 0]
total = 0
for j in range(k):
if p[pos] not in d:
pos = p[pos]
total += c[pos]
d[pos] += [j, total]
mx = max(mx, total)
else:
pos = p[pos]
total += c[pos]
t_pos = pos
for l in range(10 ** 20):
uf.union(t_pos, p[t_pos])
t_pos = p[t_pos]
if t_pos == pos:
break
uf.set_data(pos, total - d[pos][1], j - d[pos][0] + 1)
if total > d[pos][1]:
cnt = uf.d[uf.find(pos)][1]
amt = uf.d[uf.find(pos)][0]
nokori = k - j - 1
loop = max(nokori // cnt - 1, 0)
total += amt * loop
for l in range(nokori - loop * cnt):
pos = p[pos]
total += c[pos]
mx = max(mx, total)
break
else:
cnt = uf.d[uf.find(i)][1]
amt = uf.d[uf.find(i)][0]
loop = max(k // cnt - 1, 0)
total = amt * loop if amt * loop > 0 else 0
pos = i
for l in range(k - loop * cnt):
pos = p[pos]
total += c[pos]
mx = max(mx, total)
print(mx)
if __name__ == '__main__':
solve()