問題概要
問題ページ
-
F - Range Xor Query
問題ページへ移動する
問題文
長さ \(N\) の整数列 \(A\) があります。
あなたは今からこの数列について \(Q\) 個のクエリを処理します。\(i\) 番目のクエリでは、 \(T_i, X_i, Y_i\) が与えられるので、以下の処理をしてください。
- \(T_i = 1\) のとき
\(A_{X_i}\) を \(A_{X_i} \oplus Y_i\) で置き換える - \(T_i = 2\) のとき
\(A_{X_i} \oplus A_{X_i + 1} \oplus A_{X_i + 2} \oplus \dots \oplus A_{Y_i}\) を出力する
ただし \(a \oplus b\) は \(a\) と \(b\) のビット単位 xor を表します。
ビット単位 xor とは
整数 \(A, B\) のビット単位 xor 、\(A \oplus B\) は、以下のように定義されます。
- \(A \oplus B\) を二進表記した際の \(2^k\) (\(k \geq 0\)) の位の数は、\(A, B\) を二進表記した際の \(2^k\) の位の数のうち一方のみが \(1\) であれば \(1\)、そうでなければ \(0\) である。
例えば、\(3 \oplus 5 = 6\) となります (二進表記すると: \(011 \oplus 101 = 110\))。
制約
- \(1 \le N \le 300000\)
- \(1 \le Q \le 300000\)
- \(0 \le A_i \lt 2^{30}\)
- \(T_i\) は \(1\) または \(2\)
- \(T_i = 1\) なら \(1 \le X_i \le N\) かつ \(0 \le Y_i \lt 2^{30}\)
- \(T_i = 2\) なら \(1 \le X_i \le Y_i \le N\)
- 入力は全て整数
問題の考察
範囲の排他的論理和(XOR)を求める問題。
排他的論理和は交換法則が成り立つのでセグメントツリーが使える。
セグメントツリーのテンプレのモノイドを変更するだけで対応できる。
def __init__(self, size, f=lambda x, y: x ^ y, default=0):
こんな感じに変更すれば良い。
セグメントツリーの基本的な使い方は以下参照。
たびすけ
セグメントツリーは競プロ頻出です。
この問題はF問題ですが、セグメントツリーの典型的な問題なので難易度自体は高くないです!
この問題はF問題ですが、セグメントツリーの典型的な問題なので難易度自体は高くないです!
ACコード
import sys
class SegmentTree:
def __init__(self, size, f=lambda x, y: x ^ y, default=0):
self.size = 2 ** (size - 1).bit_length()
self.default = default
self.dat = [default] * (self.size * 2)
self.f = f
def update(self, i, x):
i += self.size
self.dat[i] = x
while i > 0:
i >>= 1
self.dat[i] = self.f(self.dat[i * 2], self.dat[i * 2 + 1])
def query(self, left, right):
left += self.size
right += self.size
left_res, right_res = self.default, self.default
while left < right:
if left & 1:
left_res = self.f(left_res, self.dat[left])
left += 1
if right & 1:
right -= 1
right_res = self.f(self.dat[right], right_res)
left >>= 1
right >>= 1
res = self.f(left_res, right_res)
return res
def solve():
input = sys.stdin.readline
mod = 10 ** 9 + 7
n, q = list(map(int, input().rstrip('\n').split()))
a = list(map(int, input().rstrip('\n').split()))
st = SegmentTree(n)
for i in range(n):
st.update(i, a[i])
for i in range(q):
t, x, y = list(map(int, input().rstrip('\n').split()))
if t == 1:
st.update(x - 1, st.query(x - 1, x) ^ y)
else:
print(st.query(x - 1, y))
if __name__ == '__main__':
solve()