問題概要
問題ページ
-
F - Shift and Inversions
問題ページへ移動する
問題文
\(0, 1, 2, \dots, N - 1\) を並び替えた数列 \(A = [a_0, a_1, a_2, \dots, a_{N-1}]\) が与えられます。
\(k = 0, 1, 2, \dots, N - 1\) のそれぞれについて、\(b_i = a_{i+k \bmod N}\) で定義される数列 \(B = [b_0, b_1, b_2, \dots, b_{N-1}]\) の転倒数を求めてください。
転倒数とは
数列 \(A = [a_0, a_1, a_2, \dots, a_{N-1}]\) の転倒数とは、\(i < j\) かつ \(a_i > a_j\) を満たす添字の組 \((i, j)\) の個数のことです。
制約
- 入力は全て整数
- \(2 ≤ N ≤ 3 \times 10^5\)
- \(a_0, a_1, a_2, \dots, a_{N-1}\) は \(0, 1, 2, \dots, N - 1\) の並び替えである
問題の考察
数列の先頭の値を末尾に移動する操作を繰り返し、その時点の転倒数をカウントする問題
大きく分けて\(2\)つのポイントがある。
問題のポイント
- 初期状態の転倒数の数え方
- 移動後の転倒数の数え方
転倒数のカウント方法
愚直にカウントすると、\(O(N^2)\)かかってしまう。
制約が\(2 ≤ N ≤ 3 \times 10^5\)なのでこれでは間に合わないため、計算量を落とす必要がある。
セグメントツリーを使えばこれを\(O(log N \times N)\)まで落とすことが可能。
転倒数を数え方
- 数列の先頭から順番に処理(\(O(N)\))
- 区間\(a_i+1\)から\(N\)までの合計を取得(\(O(log N)\))
- セグメントツリーの\(a_i\)を1に更新
このように処理することで、数列で\(a_i\)より前に出現した\(a_i\)より大きい値の個数を\(O(log N)\)で取得することができる。
これをすべての\(a_i\)で行うことで転倒数の合計数が\(O(log N \times N)\)でカウントできる。
移動後の転倒数
転倒数の個数が\(O(log N \times N)\)でカウントできました。
しかし、先頭数値の末尾への異動は\(N-1\)回行われるため、全ての\(a_i\)で同じようにカウントすると\(O(log N \times N \times N)\)かかり間に合いません。
ポイントは\(a_i\)の値はランダム値ではなく、「\(a_0, a_1, a_2, \dots, a_{N-1}\) は \(0, 1, 2, \dots, N - 1\) の並び替えである」という点です。
移動後の転倒数の数え方
- 減少する転倒数
- \(a_i\)個が先頭からなくなる
- \(a_i\)より後ろにある\(a_i\)より小さい値の数分転倒数が減少する
- \(a_i\)個だけ転倒数が増加する
- 増加する転倒数
- \(a_i\)が末尾に追加される
- \(a_i\)より前にある\(a_i\)より大きい値の数分転倒数が増加する
- \(n - a_i - 1\)個だけ転倒数が増加する
例えば\((5,0,1,2,3,4,6,7,8,)\)という数列を考える。
\(5\)を移動することによって減少する転倒数は「\((5,0),(5,1),(5,2),(5,3),(5,4)\)」、増加する転倒数は「\((6,5),(7,5),(8,5)\)」というようになる。
この例では\(5\)を先頭から末尾に移動することにより、\(5\)個だけ転倒数が減少し、\(9-5-1=3\)個だけ転倒数が増加している。
ACコード
import sys
class SegmentTree:
def __init__(self, size, f=lambda x, y: x + y, default=0):
self.size = 2 ** (size - 1).bit_length()
self.default = default
self.dat = [default] * (self.size * 2)
self.f = f
def update(self, i, x):
i += self.size
self.dat[i] = x
while i > 0:
i >>= 1
self.dat[i] = self.f(self.dat[i * 2], self.dat[i * 2 + 1])
def query(self, left, right):
left += self.size
right += self.size
left_res, right_res = self.default, self.default
while left < right:
if left & 1:
left_res = self.f(left_res, self.dat[left])
left += 1
if right & 1:
right -= 1
right_res = self.f(self.dat[right], right_res)
left >>= 1
right >>= 1
res = self.f(left_res, right_res)
return res
def solve():
input = sys.stdin.readline
mod = 10 ** 9 + 7
n = int(input().rstrip('\n'))
a = list(map(int, input().rstrip('\n').split()))
st = SegmentTree(n)
cnt = 0
for i in range(n):
st.update(a[i], 1)
cnt += st.query(a[i] + 1, n)
for i in range(n):
print(cnt)
cnt -= a[i]
cnt += n - a[i] - 1
if __name__ == '__main__':
solve()