AtCoder Beginner Contest

【pythonでABC190を解説】F - Shift and Inversions

問題概要

問題ページ

問題文

\(0, 1, 2, \dots, N - 1\) を並び替えた数列 \(A = [a_0, a_1, a_2, \dots, a_{N-1}]\) が与えられます。
\(k = 0, 1, 2, \dots, N - 1\) のそれぞれについて、\(b_i = a_{i+k \bmod N}\) で定義される数列 \(B = [b_0, b_1, b_2, \dots, b_{N-1}]\) の転倒数を求めてください。

転倒数とは

数列 \(A = [a_0, a_1, a_2, \dots, a_{N-1}]\) の転倒数とは、\(i < j\) かつ \(a_i > a_j\) を満たす添字の組 \((i, j)\) の個数のことです。

制約

  • 入力は全て整数
  • \(2 ≤ N ≤ 3 \times 10^5\)
  • \(a_0, a_1, a_2, \dots, a_{N-1}\) は \(0, 1, 2, \dots, N - 1\) の並び替えである

問題の考察

数列の先頭の値を末尾に移動する操作を繰り返し、その時点の転倒数をカウントする問題

大きく分けて\(2\)つのポイントがある。

問題のポイント

  • 初期状態の転倒数の数え方
  • 移動後の転倒数の数え方

転倒数のカウント方法

愚直にカウントすると、\(O(N^2)\)かかってしまう。

制約が\(2 ≤ N ≤ 3 \times 10^5\)なのでこれでは間に合わないため、計算量を落とす必要がある。

セグメントツリーを使えばこれを\(O(log N \times N)\)まで落とすことが可能。

転倒数を数え方

  • 数列の先頭から順番に処理(\(O(N)\))
  • 区間\(a_i+1\)から\(N\)までの合計を取得(\(O(log N)\))
  • セグメントツリーの\(a_i\)を1に更新

このように処理することで、数列で\(a_i\)より前に出現した\(a_i\)より大きい値の個数を\(O(log N)\)で取得することができる。

これをすべての\(a_i\)で行うことで転倒数の合計数が\(O(log N \times N)\)でカウントできる。

移動後の転倒数

転倒数の個数が\(O(log N \times N)\)でカウントできました。

しかし、先頭数値の末尾への異動は\(N-1\)回行われるため、全ての\(a_i\)で同じようにカウントすると\(O(log N \times N \times N)\)かかり間に合いません。

ポイントは\(a_i\)の値はランダム値ではなく、「\(a_0, a_1, a_2, \dots, a_{N-1}\) は \(0, 1, 2, \dots, N - 1\) の並び替えである」という点です。

移動後の転倒数の数え方

  • 減少する転倒数
    • \(a_i\)個が先頭からなくなる
    • \(a_i\)より後ろにある\(a_i\)より小さい値の数分転倒数が減少する
    • \(a_i\)個だけ転倒数が増加する
  • 増加する転倒数
    • \(a_i\)が末尾に追加される
    • \(a_i\)より前にある\(a_i\)より大きい値の数分転倒数が増加する
    • \(n - a_i - 1\)個だけ転倒数が増加する

例えば\((5,0,1,2,3,4,6,7,8,)\)という数列を考える。

\(5\)を移動することによって減少する転倒数は「\((5,0),(5,1),(5,2),(5,3),(5,4)\)」、増加する転倒数は「\((6,5),(7,5),(8,5)\)」というようになる。

この例では\(5\)を先頭から末尾に移動することにより、\(5\)個だけ転倒数が減少し、\(9-5-1=3\)個だけ転倒数が増加している。

たびすけ
移動後の転倒数は初期状態の転倒数の個数を基準にして答えられるよ!!

ACコード

import sys


class SegmentTree:
    def __init__(self, size, f=lambda x, y: x + y, default=0):
        self.size = 2 ** (size - 1).bit_length()
        self.default = default
        self.dat = [default] * (self.size * 2)
        self.f = f

    def update(self, i, x):
        i += self.size
        self.dat[i] = x
        while i > 0:
            i >>= 1
            self.dat[i] = self.f(self.dat[i * 2], self.dat[i * 2 + 1])

    def query(self, left, right):
        left += self.size
        right += self.size
        left_res, right_res = self.default, self.default
        while left < right:
            if left & 1:
                left_res = self.f(left_res, self.dat[left])
                left += 1

            if right & 1:
                right -= 1
                right_res = self.f(self.dat[right], right_res)
            left >>= 1
            right >>= 1
        res = self.f(left_res, right_res)
        return res


def solve():
    input = sys.stdin.readline
    mod = 10 ** 9 + 7
    n = int(input().rstrip('\n'))
    a = list(map(int, input().rstrip('\n').split()))
    st = SegmentTree(n)
    cnt = 0
    for i in range(n):
        st.update(a[i], 1)
        cnt += st.query(a[i] + 1, n)
    for i in range(n):
        print(cnt)
        cnt -= a[i]
        cnt += n - a[i] - 1


if __name__ == '__main__':
    solve()

プログラミング

-AtCoder Beginner Contest
-,