AtCoder Beginner Contest

【pythonでABC185を解説】F - Range Xor Query

問題概要

問題ページ

問題文

長さ \(N\) の整数列 \(A\) があります。
あなたは今からこの数列について \(Q\) 個のクエリを処理します。\(i\) 番目のクエリでは、 \(T_i, X_i, Y_i\) が与えられるので、以下の処理をしてください。

  • \(T_i = 1\) のとき
    \(A_{X_i}\) を \(A_{X_i} \oplus Y_i\) で置き換える
  • \(T_i = 2\) のとき
    \(A_{X_i} \oplus A_{X_i + 1} \oplus A_{X_i + 2} \oplus \dots \oplus A_{Y_i}\) を出力する

ただし \(a \oplus b\) は \(a\) と \(b\) のビット単位 xor を表します。

ビット単位 xor とは

整数 \(A, B\) のビット単位 xor 、\(A \oplus B\) は、以下のように定義されます。

  • \(A \oplus B\) を二進表記した際の \(2^k\) (\(k \geq 0\)) の位の数は、\(A, B\) を二進表記した際の \(2^k\) の位の数のうち一方のみが \(1\) であれば \(1\)、そうでなければ \(0\) である。

例えば、\(3 \oplus 5 = 6\) となります (二進表記すると: \(011 \oplus 101 = 110\))。

制約

  • \(1 \le N \le 300000\)
  • \(1 \le Q \le 300000\)
  • \(0 \le A_i \lt 2^{30}\)
  • \(T_i\) は \(1\) または \(2\)
  • \(T_i = 1\) なら \(1 \le X_i \le N\) かつ \(0 \le Y_i \lt 2^{30}\)
  • \(T_i = 2\) なら \(1 \le X_i \le Y_i \le N\)
  • 入力は全て整数

問題の考察

範囲の排他的論理和(XOR)を求める問題。

排他的論理和は交換法則が成り立つのでセグメントツリーが使える。

セグメントツリーのテンプレのモノイドを変更するだけで対応できる。

def __init__(self, size, f=lambda x, y: x ^ y, default=0):

こんな感じに変更すれば良い。

セグメントツリーの基本的な使い方は以下参照。

たびすけ
セグメントツリーは競プロ頻出です。
この問題はF問題ですが、セグメントツリーの典型的な問題なので難易度自体は高くないです!

ACコード

import sys


class SegmentTree:
    def __init__(self, size, f=lambda x, y: x ^ y, default=0):
        self.size = 2 ** (size - 1).bit_length()
        self.default = default
        self.dat = [default] * (self.size * 2)
        self.f = f

    def update(self, i, x):
        i += self.size
        self.dat[i] = x
        while i > 0:
            i >>= 1
            self.dat[i] = self.f(self.dat[i * 2], self.dat[i * 2 + 1])

    def query(self, left, right):
        left += self.size
        right += self.size
        left_res, right_res = self.default, self.default
        while left < right:
            if left & 1:
                left_res = self.f(left_res, self.dat[left])
                left += 1

            if right & 1:
                right -= 1
                right_res = self.f(self.dat[right], right_res)
            left >>= 1
            right >>= 1
        res = self.f(left_res, right_res)
        return res


def solve():
    input = sys.stdin.readline
    mod = 10 ** 9 + 7
    n, q = list(map(int, input().rstrip('\n').split()))
    a = list(map(int, input().rstrip('\n').split()))
    st = SegmentTree(n)
    for i in range(n):
        st.update(i, a[i])
    for i in range(q):
        t, x, y = list(map(int, input().rstrip('\n').split()))
        if t == 1:
            st.update(x - 1, st.query(x - 1, x) ^ y)
        else:
            print(st.query(x - 1, y))


if __name__ == '__main__':
    solve()

-AtCoder Beginner Contest
-, ,