Edit

【pythonでABC183を解説】F - Confluence

問題概要

問題ページ

F - Confluence
F - Confluence

問題ページへ移動する

問題文

\(N\) 人の生徒が登校しようとしています。生徒 \(i\) はクラス \(C_i\) に属しています。

各生徒はそれぞれの家から出発したあと、他の生徒と合流を繰り返しながら学校へ向かいます。一度合流した生徒が分かれることはありません。

\(Q\) 個のクエリが与えられるので、順番に処理してください。クエリには \(2\) 種類あり、入力形式とクエリの内容は以下の通りです。

  • 1 a b : 生徒 \(a\) を含む集団と、生徒 \(b\) を含む集団が合流する (既に合流しているときは何も起こらない)
  • 2 x y : クエリの時点で既に生徒 \(x\) と合流している生徒(生徒 \(x\) を含む)のうち、クラス \(y\) に属している生徒の数を求める

制約

  • \(1 \leq N \leq 2 \times 10^5\)
  • \(1 \leq Q \leq 2 \times 10^5\)
  • \(1 \leq C_i,a,b,x,y \leq N\)
  • 1 a b のクエリにおいて、\(a \neq b\)
  • 入力はすべて整数

問題の考察

ACコード

import sys
import collections


class AlgUnionFind():
    def __init__(self, n):
        self.n = n
        self.parents = [-1] * n
        self.d = []
        for i in range(n):
            self.d.append(collections.defaultdict(int))

    def find(self, x):
        if self.parents[x] < 0:
            return x
        else:
            self.parents[x] = self.find(self.parents[x])
            return self.parents[x]

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        if self.parents[x] > self.parents[y]:
            x, y = y, x

        self.parents[x] += self.parents[y]
        self.parents[y] = x
        for k, v in self.d[y].items():
            self.d[x][k] += v
            self.d[y][k] -= v

    def size(self, x):
        return -self.parents[self.find(x)]

    def same(self, x, y):
        return self.find(x) == self.find(y)

    def members(self, x):
        root = self.find(x)
        return [i for i in range(self.n) if self.find(i) == root]

    def roots(self):
        return [i for i, x in enumerate(self.parents) if x < 0]

    def group_count(self):
        return len(self.roots())

    def all_group_members(self):
        return {r: self.members(r) for r in self.roots()}

    def __str__(self):
        return '\n'.join('{}: {}'.format(r, self.members(r)) for r in self.roots())


def solve():
    input = sys.stdin.readline
    mod = 10 ** 9 + 7
    n, q = list(map(int, input().rstrip('\n').split()))
    uf = AlgUnionFind(n)
    for i, v in enumerate(list(map(int, input().rstrip('\n').split()))):
        uf.d[i][v - 1] += 1
    for i in range(q):
        a, b, c = list(map(int, input().rstrip('\n').split()))
        b, c = b - 1, c - 1
        if a == 1:
            uf.union(b, c)
        else:
            print(uf.d[uf.find(b)][c])


if __name__ == '__main__':
    solve()

-Edit
-