共通
import sys
import bisect
import itertools
import collections
import fractions
import heapq
import math
from operator import mul
from functools import reduce
from functools import lru_cache
def solve():
    """Contest entry point template: write the solution body here."""
    # Modulus commonly used for answers (1_000_000_007).
    mod = pow(10, 9) + 7
    # Shadow the builtin with sys.stdin.readline for fast line reads; the
    # returned lines keep their trailing '\n'.
    input = sys.stdin.readline


if __name__ == '__main__':
    solve()
これ以降のライブラリは、この共通テンプレートを前提としたコードです。
共通テンプレートを使わない場合は、必要に応じて import 文などを追記してください。
標準入力
整数
1行
# One integer on a single line; write the target variable before '='.
= int(input().rstrip('\n'))
1行リスト
# One line of whitespace-separated integers, as a list.
= list(map(int, input().rstrip('\n').split()))
複数行
# One integer per line; put the number of lines inside range().
= [int(input().rstrip('\n')) for _ in range()]
複数行リスト
# Several lines, each parsed into a list of integers; put the line count inside range().
= [list(map(int, input().rstrip('\n').split())) for _ in range()]
文字列
1行
= str(input().rstrip('\n'))
1行リスト
= list(map(str, str(input().rstrip('\n')).split()))
複数行
= [str(input().rstrip('\n')) for _ in range()]
複数行リスト
= [list(map(int, input().rstrip('\n').split())) for _ in range()]
幅優先探索(BFS)
グリッド
# BFS on an h x w grid with 4-neighbour moves.
h, w = list(map(int, input().rstrip('\n').split()))
# Start cell (row, column) -- TODO: set the actual start position.
sy, sx = 0, 0
# Queue entries are [distance from start, row, column].
ql = collections.deque([[0, sy, sx]])
# fq maps each reached cell (row, column) to its BFS distance; key membership
# is the "visited" test. Query it with `in` / .get(): indexing a missing key
# on a defaultdict inserts it.
fq = collections.defaultdict(list)
fq[sy, sx] = 0
while ql:
    cost, yv, xv = ql.popleft()
    # Separate names for the neighbour so the popped cell is not rebound.
    for ny, nx in ((yv + 1, xv), (yv - 1, xv), (yv, xv + 1), (yv, xv - 1)):
        if 0 <= ny < h and 0 <= nx < w and (ny, nx) not in fq:
            # Add a wall / passability check for the grid here if needed.
            ql.append([cost + 1, ny, nx])
            fq[ny, nx] = cost + 1
ツリー
# BFS over an undirected graph given as an edge list.
mpl = collections.defaultdict(list)
# Number of edge lines to read -- TODO: set it (n - 1 for a tree).
edge_count = 0
for _ in range(edge_count):
    mpa, mpb = list(map(int, input().rstrip('\n').split()))
    # Input vertices are 1-indexed; store them 0-indexed in both directions.
    mpl[mpa - 1].append(mpb - 1)
    mpl[mpb - 1].append(mpa - 1)
# Start vertex (0-indexed) -- TODO: set the actual start.
start = 0
# Queue entries are [distance from start, vertex].
ql = collections.deque([[0, start]])
# fq maps each reached vertex to its BFS distance; key membership is the
# "visited" test (use `in` / .get(): indexing a missing key inserts it).
fq = collections.defaultdict(list)
fq[start] = 0
while ql:
    cost, tmp = ql.popleft()
    for tmn in mpl[tmp]:
        if tmn not in fq:
            ql.append([cost + 1, tmn])
            fq[tmn] = cost + 1
Union-Find
class UnionFind():
    """Disjoint-set union over elements 0..n-1 (union by size + path compression).

    parents[x] is the parent of x, or -(size of x's set) when x is a root.
    """

    def __init__(self, n):
        self.n = n
        self.parents = [-1] * n

    def find(self, x):
        """Return the root of x's set, pointing visited nodes at it."""
        if self.parents[x] < 0:
            return x
        # Union by size keeps trees shallow, so this recursion stays short.
        self.parents[x] = self.find(self.parents[x])
        return self.parents[x]

    def union(self, x, y):
        """Merge the sets containing x and y; no-op if they are already joined."""
        x = self.find(x)
        y = self.find(y)
        if x == y:
            return
        # Sizes are stored negated: make x the root of the larger set.
        if self.parents[x] > self.parents[y]:
            x, y = y, x
        self.parents[x] += self.parents[y]
        self.parents[y] = x

    def size(self, x):
        """Return the number of elements in x's set."""
        return -self.parents[self.find(x)]

    def same(self, x, y):
        """Return True if x and y are in the same set."""
        return self.find(x) == self.find(y)

    def members(self, x):
        """Return the elements of x's set in ascending order (O(n))."""
        root = self.find(x)
        return [i for i in range(self.n) if self.find(i) == root]

    def roots(self):
        """Return all root elements in ascending order."""
        return [i for i, x in enumerate(self.parents) if x < 0]

    def group_count(self):
        """Return the number of disjoint sets."""
        return len(self.roots())

    def all_group_members(self):
        """Return {root: ascending member list} for every set.

        Built in one pass over all elements instead of calling members() per
        root (which costs O(n) per set). Keys follow roots() order.
        """
        groups = {r: [] for r in self.roots()}
        for i in range(self.n):
            groups[self.find(i)].append(i)
        return groups

    def __str__(self):
        return '\n'.join('{}: {}'.format(r, m) for r, m in self.all_group_members().items())
セグメントツリー
基本
class SegmentTree:
    """Point-update / range-fold segment tree.

    f is the combining function (default: addition) and default its identity.
    Nodes use a 1-indexed heap layout in dat: node k has children 2k and 2k+1,
    leaves start at index self.size, and dat[0] is unused.
    """

    def __init__(self, size, f=lambda x, y: x + y, default=0):
        # Number of leaves, rounded up to a power of two.
        self.size = 2 ** (size - 1).bit_length()
        self.default = default
        self.dat = [default] * (self.size * 2)
        self.f = f

    def update(self, i, x):
        """Set element i to x and recompute its ancestors."""
        i += self.size
        self.dat[i] = x
        # Stop after the root (index 1); index 0 is not a tree node.
        while i > 1:
            i >>= 1
            self.dat[i] = self.f(self.dat[i * 2], self.dat[i * 2 + 1])

    def query(self, left, right):
        """Return f folded over elements [left, right), or default if empty.

        Left and right partial results are kept separately so that the
        left-to-right order is preserved for non-commutative f.
        """
        left += self.size
        right += self.size
        left_res, right_res = self.default, self.default
        while left < right:
            if left & 1:
                left_res = self.f(left_res, self.dat[left])
                left += 1
            if right & 1:
                right -= 1
                right_res = self.f(self.dat[right], right_res)
            left >>= 1
            right >>= 1
        return self.f(left_res, right_res)
遅延評価 範囲変更
class AlgSegmentTreeRangeUpdate:
    """Lazy segment tree: assign x to a range, fold f over a range.

    Nodes are numbered 1-indexed like a binary heap (children of k are 2k and
    2k+1, leaves start at self.no); node k is stored at list index k - 1.
    Every element starts as `default`, which is also used as f's identity.
    update() writes x directly into nodes covering several elements and query()
    mixes left/right pieces in one accumulator, so f should be associative,
    commutative and idempotent (e.g. min, max).
    """

    def __init__(self, size, f=lambda x, y: min(x, y), default=2 ** 31 - 1):
        # Tree height; the leaf count is rounded up to a power of two.
        self.size = (size - 1).bit_length()
        self.no = 2 ** self.size
        self.default = default
        self.data = [default] * (self.no * 2)
        # Pending assignment for a node's children; None = nothing pending.
        self.lazy = [None] * (self.no * 2)
        self.f = f

    def get_index(self, left, right):
        """Yield, bottom-up, the ancestor nodes of the boundaries of [left, right).

        These are the nodes whose lazy values must be pushed down before an
        operation and whose data must be recomputed after an update. Levels at
        which the boundary ancestor is covered entirely by the range are skipped.
        """
        l_left = (left + self.no) >> 1
        r_right = (right + self.no) >> 1
        lc = 0 if left & 1 else (l_left & -l_left).bit_length()
        rc = 0 if right & 1 else (r_right & -r_right).bit_length()
        for i in range(self.size):
            if rc <= i:
                yield r_right
            if l_left < r_right and lc <= i:
                yield l_left
            l_left >>= 1
            r_right >>= 1

    def propagates(self, *ids):
        """Push pending assignments from the given nodes to their children.

        ids comes bottom-up from get_index, so it is walked in reverse (top-down).
        """
        for i in reversed(ids):
            v = self.lazy[i - 1]
            if v is None:
                continue
            self.lazy[2 * i - 1] = v
            self.data[2 * i - 1] = v
            self.lazy[2 * i] = v
            self.data[2 * i] = v
            self.lazy[i - 1] = None

    def update(self, left, right, x):
        """Assign x to every element in [left, right)."""
        *ids, = self.get_index(left, right)
        self.propagates(*ids)
        l_left = self.no + left
        r_right = self.no + right
        while l_left < r_right:
            if r_right & 1:
                r_right -= 1
                self.lazy[r_right - 1] = x
                self.data[r_right - 1] = x
            if l_left & 1:
                self.lazy[l_left - 1] = x
                self.data[l_left - 1] = x
                l_left += 1
            l_left >>= 1
            r_right >>= 1
        # Recompute the boundary ancestors bottom-up.
        for i in ids:
            self.data[i - 1] = self.f(self.data[2 * i - 1], self.data[2 * i])

    def query(self, left, right):
        """Return f folded over [left, right), or default for an empty range."""
        self.propagates(*self.get_index(left, right))
        l_left = self.no + left
        r_right = self.no + right
        res = self.default
        while l_left < r_right:
            if r_right & 1:
                r_right -= 1
                res = self.f(res, self.data[r_right - 1])
            if l_left & 1:
                res = self.f(res, self.data[l_left - 1])
                l_left += 1
            l_left >>= 1
            r_right >>= 1
        return res
遅延評価 範囲加算
class SegmentTreeRangeAdd:
    """Lazy segment tree: add x to a range, fold f over a range.

    Nodes are numbered 1-indexed like a binary heap (children of k are 2k and
    2k+1, leaves start at self.no); node k is stored at list index k - 1.
    Every element starts as `default`, which is also used as f's identity.
    update() adds x directly to nodes covering several elements, so f must
    satisfy f(a + x, b + x) == f(a, b) + x (e.g. min, max, not sum); query()
    mixes left/right pieces in one accumulator, so f should be commutative.
    """

    def __init__(self, size, f=lambda x, y: min(x, y), default=2 ** 31 - 1):
        # Tree height; the leaf count is rounded up to a power of two.
        self.size = (size - 1).bit_length()
        self.no = 2 ** self.size
        self.default = default
        self.data = [default] * (self.no * 2)
        # Pending addition for a node's children; 0 = nothing pending.
        # (Must be numeric: update() and propagates() apply `+=` to it.)
        self.lazy = [0] * (self.no * 2)
        self.f = f

    def get_index(self, left, right):
        """Yield, bottom-up, the ancestor nodes of the boundaries of [left, right).

        These are the nodes whose lazy values must be pushed down before an
        operation and whose data must be recomputed after an update. Levels at
        which the boundary ancestor is covered entirely by the range are skipped.
        """
        l_left = (left + self.no) >> 1
        r_right = (right + self.no) >> 1
        lc = 0 if left & 1 else (l_left & -l_left).bit_length()
        rc = 0 if right & 1 else (r_right & -r_right).bit_length()
        for i in range(self.size):
            if rc <= i:
                yield r_right
            if l_left < r_right and lc <= i:
                yield l_left
            l_left >>= 1
            r_right >>= 1

    def propagates(self, *ids):
        """Push pending additions from the given nodes to their children.

        ids comes bottom-up from get_index, so it is walked in reverse (top-down).
        """
        for i in reversed(ids):
            v = self.lazy[i - 1]
            if not v:
                continue
            self.lazy[2 * i - 1] += v
            self.data[2 * i - 1] += v
            self.lazy[2 * i] += v
            self.data[2 * i] += v
            self.lazy[i - 1] = 0

    def update(self, left, right, x):
        """Add x to every element in [left, right)."""
        *ids, = self.get_index(left, right)
        self.propagates(*ids)
        l_left = self.no + left
        r_right = self.no + right
        while l_left < r_right:
            if r_right & 1:
                r_right -= 1
                self.lazy[r_right - 1] += x
                self.data[r_right - 1] += x
            if l_left & 1:
                self.lazy[l_left - 1] += x
                self.data[l_left - 1] += x
                l_left += 1
            l_left >>= 1
            r_right >>= 1
        # The ancestors' lazies were cleared above, so recompute from children.
        for i in ids:
            self.data[i - 1] = self.f(self.data[2 * i - 1], self.data[2 * i])

    def query(self, left, right):
        """Return f folded over [left, right), or default for an empty range."""
        self.propagates(*self.get_index(left, right))
        l_left = self.no + left
        r_right = self.no + right
        res = self.default
        while l_left < r_right:
            if r_right & 1:
                r_right -= 1
                res = self.f(res, self.data[r_right - 1])
            if l_left & 1:
                res = self.f(res, self.data[l_left - 1])
                l_left += 1
            l_left >>= 1
            r_right >>= 1
        return res
二分探索
# Binary search on the answer. cor_v is kept on the side where the check
# passes and inc_v on the side where it fails; the gap is halved until the
# two are adjacent. Adjust both initial bounds to the problem.
cor_v = 10 ** 20  # assumed to pass the check -- adjust as needed
inc_v = -1  # assumed to fail the check -- adjust as needed
while cor_v - inc_v > 1:
    bin_v = (cor_v + inc_v) // 2
    cost = 0
    # Compute the cost for candidate bin_v here (exhaustive evaluation).
    # Check whether the cost satisfies the constraint for bin_v.
    if cost <= bin_v:
        cor_v = bin_v
    else:
        inc_v = bin_v
# If the check is monotone, this is the boundary value closest to inc_v that passes.
print(cor_v)
最小共通祖先(LCA)
class Lca(object):
    """Lowest-common-ancestor queries on a tree via binary lifting.

    graph[v] iterates over the neighbours of vertex v; vertices are
    0..len(graph)-1 and the tree is rooted at `root`.
    """

    def __init__(self, graph, root=0):
        self.graph = graph
        self.root = root
        self.n = len(graph)
        # Enough levels to encode any depth difference (at most n - 1).
        self.bit_len = (self.n - 1).bit_length()
        # depth[v] is the distance from root; -1 means not (yet) reached.
        self.depth = [-1 if i != root else 0 for i in range(self.n)]
        # parent[i][v] is the 2**i-th ancestor of v, or -1 above the root.
        self.parent = [[-1] * self.n for _ in range(self.bit_len)]
        self.bfs()
        self.doubling()

    def bfs(self):
        """Fill depth and direct parents (parent[0]) by BFS from the root."""
        ql = collections.deque([self.root])
        while ql:
            tmp = ql.popleft()
            for tmv in self.graph[tmp]:
                if self.depth[tmv] == -1:
                    self.depth[tmv] = self.depth[tmp] + 1
                    self.parent[0][tmv] = tmp
                    ql.append(tmv)

    def doubling(self):
        """Build parent[i] from parent[i - 1] (jump 2**(i-1) twice)."""
        for i in range(1, self.bit_len):
            for v in range(self.n):
                if self.parent[i - 1][v] != -1:
                    self.parent[i][v] = self.parent[i - 1][self.parent[i - 1][v]]

    def get(self, u, v):
        """Return the lowest common ancestor of u and v.

        Assumes both vertices were reached from the root by bfs().
        """
        # Make v the deeper vertex.
        if self.depth[v] < self.depth[u]:
            u, v = v, u
        du = self.depth[u]
        dv = self.depth[v]
        # Lift v to u's depth using the binary digits of the depth difference.
        for i in range(self.bit_len):
            if (dv - du) >> i & 1:
                v = self.parent[i][v]
        if u == v:
            return u
        # Lift both as long as they stay distinct; they end just below the LCA.
        for i in range(self.bit_len - 1, -1, -1):
            pu, pv = self.parent[i][u], self.parent[i][v]
            if pu != pv:
                u, v = pu, pv
        return self.parent[0][u]
メモ化再帰
@lru_cache(maxsize=None)
def memo_func():
    """Template for memoized recursion: add the state as parameters and
    write the recurrence here.

    lru_cache(maxsize=None) caches every distinct argument tuple without
    eviction, so arguments must be hashable.
    NOTE(review): deep recursions may also need sys.setrecursionlimit --
    confirm per problem.
    """
ワーシャルフロイド
# All-pairs shortest paths (Warshall-Floyd).
# Assumed input: "n m" on the first line, then m lines "a b c" each describing
# an undirected edge a-b (1-indexed) with weight c -- adjust to the problem.
wn, wm = map(int, input().rstrip('\n').split())
inf = 10 ** 13  # "unreachable"; must exceed any real shortest-path length
wl = [[inf] * wn for _ in range(wn)]
for i in range(wn):
    wl[i][i] = 0
for _ in range(wm):
    wa, wb, wc = map(int, input().rstrip('\n').split())
    # Keep the lightest of parallel edges (and never raise a 0 diagonal).
    wl[wa - 1][wb - 1] = min(wl[wa - 1][wb - 1], wc)
    wl[wb - 1][wa - 1] = min(wl[wb - 1][wa - 1], wc)
# The intermediate vertex i must be the outermost loop.
for i in range(wn):
    for j in range(wn):
        for k in range(wn):
            wl[j][k] = min(wl[j][k], wl[j][i] + wl[i][k])
Zアルゴリズム
# Z-algorithm: afterwards z_algorithm[i] (i >= 1) is the length of the
# longest common prefix of st and st[i:]; z_algorithm[0] is left at 0.
# Specify the target string here.
st =
st_len = len(st)
z_algorithm = [0] * st_len
z_i = 1  # position whose Z-value is being computed
z_j = 0  # length of the prefix match known to start at z_i
while z_i < st_len:
    # Extend the match at z_i one character at a time.
    while z_i + z_j < st_len and st[z_j] == st[z_i + z_j]:
        z_j += 1
    z_algorithm[z_i] = z_j
    if z_j == 0:
        z_i += 1
        continue
    # st[z_i:z_i + z_j] equals st[:z_j], so positions inside that window reuse
    # earlier Z-values while they stay strictly inside it.
    z_k = 1
    while z_i + z_k < st_len and z_k + z_algorithm[z_k] < z_j:
        z_algorithm[z_i + z_k] = z_algorithm[z_k]
        z_k += 1
    # Skip the reused positions; the known match shrinks by the same amount.
    z_i += z_k
    z_j -= z_k
組合せ
基本
def combination(n, r):
    """Return the binomial coefficient C(n, r) as an exact integer.

    Returns 0 when r < 0 or r > n (no way to choose).
    """
    if r < 0 or r > n:
        return 0
    # C(n, r) == C(n, n - r); use the smaller side to shorten the loops.
    r = min(n - r, r)
    numerator = 1
    for i in range(n, n - r, -1):
        numerator *= i
    denominator = 1
    for i in range(1, r + 1):
        denominator *= i
    return numerator // denominator
MOD
def combination_mod(n, r, mod):
    """Return C(n, r) modulo mod.

    Returns 0 when r < 0 or r > n. The division uses the Fermat inverse
    pow(x, mod - 2, mod), so mod is assumed prime and larger than
    min(r, n - r) (otherwise that factorial is 0 modulo mod).
    """
    if r < 0 or r > n:
        return 0
    # C(n, r) == C(n, n - r); use the smaller side to shorten the loops.
    r = min(n - r, r)
    numerator = 1
    for i in range(n, n - r, -1):
        numerator = numerator * i % mod
    denominator = 1
    for i in range(1, r + 1):
        denominator = denominator * i % mod
    return numerator * pow(denominator, mod - 2, mod) % mod
重複組合せ
基本
def duplicate_combination(n, r):
    """Return H(n, r) = C(n + r - 1, r): multisets of size r drawn from n kinds.

    Returns 1 for r == 0 (the empty multiset), and 0 for r < 0 or when r > 0
    but there are no kinds to draw from (n <= 0).
    """
    if r < 0:
        return 0
    if r == 0:
        return 1
    if n <= 0:
        return 0
    top = n + r - 1
    # C(top, r) == C(top, top - r); use the smaller side.
    k = min(top - r, r)
    numerator = 1
    for i in range(top, top - k, -1):
        numerator *= i
    denominator = 1
    for i in range(1, k + 1):
        denominator *= i
    return numerator // denominator
MOD
def duplicate_combination_mod(n, r, mod):
    """Return H(n, r) = C(n + r - 1, r) modulo mod.

    Returns 1 for r == 0, and 0 for r < 0 or when r > 0 with n <= 0. The
    division uses the Fermat inverse pow(x, mod - 2, mod), so mod is assumed
    prime and larger than min(r, n - 1).
    """
    if r < 0:
        return 0
    if r == 0:
        return 1
    if n <= 0:
        return 0
    top = n + r - 1
    # C(top, r) == C(top, top - r); use the smaller side.
    k = min(top - r, r)
    numerator = 1
    for i in range(top, top - k, -1):
        numerator = numerator * i % mod
    denominator = 1
    for i in range(1, k + 1):
        denominator = denominator * i % mod
    return numerator * pow(denominator, mod - 2, mod) % mod
順列
基本
def permutation(n, r):
    """Return P(n, r) = n * (n - 1) * ... * (n - r + 1), the ordered selections.

    Returns 0 when r < 0 or r > n.
    """
    if r < 0 or r > n:
        return 0
    result = 1
    for i in range(n, n - r, -1):
        result *= i
    return result
MOD
def permutation_mod(n, r, mod):
    """Return P(n, r) = n * (n - 1) * ... * (n - r + 1) modulo mod.

    Returns 0 when r < 0 or r > n.
    """
    if r < 0 or r > n:
        return 0
    result = 1
    for i in range(n, n - r, -1):
        result = result * i % mod
    return result
階乗
基本
def factorial(denominator_no):
    """Return denominator_no! as an exact integer.

    Delegates to math.factorial, which raises ValueError for negative input.
    """
    value = math.factorial(denominator_no)
    return value
MOD
def factorial_mod(denominator_no, molecule_list, mod):
    """Return denominator_no! / (product of m! for m in molecule_list) modulo mod.

    When the entries of molecule_list sum to denominator_no this is the
    multinomial coefficient. The division uses the Fermat inverse
    pow(x, mod - 2, mod), so mod is assumed prime and the divisor not a
    multiple of mod.
    """
    def _fact_mod(k):
        # k! reduced modulo mod (1 for k <= 1).
        acc = 1
        for j in range(2, k + 1):
            acc = acc * j % mod
        return acc

    top = _fact_mod(denominator_no)
    bottom = 1
    for m in molecule_list:
        bottom = bottom * _fact_mod(m) % mod
    return top * pow(bottom, mod - 2, mod) % mod
たびすけ
頻出ライブラリはスニペットに登録したりして、いつでも使えるようにしておきましょう!