代码模板 (Python)
Tip
本文记录 Python 语言的算法竞赛代码模板,全部使用 built-in 模块,有不严谨或错误的地方还望评论区斧正🤗。更好的阅读体验参见 专题解析 部分。
常用数据结构 API¶
列表¶
int 转 list
list(int) 转 int
nums = [1, 2, 3]
num = int(''.join(map(str, nums)))
def lst_int(nums):
return int(''.join(map(str, nums)))
列表特性
比较大小的时候,不管长度如何,依次比较到第一个元素不相等的位置
比如 [1, 2, 3] < [2, 3] 因为在比较 1 < 2 的时候就终止。
嵌套列表推导:展平二维数组
字符串¶
s1.startswith(s2, beg = 0, end = len(s2)): 用于检查字符串 s1 是否以字符串 s2 开头。是则返回 True。如果指定 beg 和 end,则在 s1 [beg: end] 范围内查找。- 使用
ascii_lowercase遍历 26 个字母:
队列¶
from collections import deque
list1 = [0, 1, 2, 3]
q = deque(list1)
q.append(4) # 向右侧加
q.appendleft(-1) # 向左侧加
q.extend(可迭代元素) # 向右侧添加可迭代元素
q.extendleft(可迭代元素)
q = q.pop() # 移除最右端并返回元素值
l = q.popleft() # 移除最左端
q.count(1) # 统计元素个数 1
# 返回 string 指定范围中 str 首次出现的位置
string.index(str, beg=0, end=len(string))
string.index(" ")
list(map(s.index,s)) # 返回字符索引数组,如 "abcba"-> [0,1,2,1,0]
字典¶
d.pop(key) # 返回 key 对应的 value,并在字典中删除这个键值对
d.get(key, default_value=None) # 返回 key 对应的 value,不存在则返回 default_value
d.keys() # 键构成的可迭代对象
d.values() # 值构成的可迭代对象
d.items() # 键值对构成的可迭代对象
d = defaultdict(list) # 指定了具有默认值空列表的字典
d[key] = value # 创建一个键值对
字典推导器:字母表对应下标
也可以使用 zip 初始化 dict
2606. 找到最大开销的子字符串 - 力扣(LeetCode)
Counter¶
from collections import Counter
list1 = ["a", "a", "a", "b", "c", "c", "f", "g", "g", "g", "f"]
dic = Counter(list1)
print(dic)
# Counter({'a': 3, 'g': 3, 'c': 2, 'f': 2, 'b': 1})
list1 = ["a", "a", "a", "b", "c", "f", "g", "g", "c", "11", "g", "f", "10", "2"]
print(Counter(list1).most_common(3))
# 结果:[('a', 3), ('g', 3), ('c', 2)]
list1 = ["a", "a", "a", "b", "c", "f", "g", "g", "c", "11", "g", "f", "10", "2"]
print(Counter(list1).most_common(1))
# 结果:[('a', 3)]
most_common(k) 时间复杂度 \(O(n \log k)\)。
map 映射函数¶
用法:
# 计算平方数
def square(x):
return x ** 2
# 计算列表各个元素的平方
map(square, [1,2,3,4,5])
# [1, 4, 9, 16, 25]
# 使用 lambda 匿名函数
map(lambda x: x ** 2, [1, 2, 3, 4, 5])
# [1, 4, 9, 16, 25]
# 提供了两个列表,对相同位置的列表数据进行相加
map(lambda x, y: x + y, [1, 3, 5, 7, 9], [2, 4, 6, 8, 10])
# [3, 7, 11, 15, 19]
自定义 set 规则¶
class MySet(set):
def add(self, element):
sorted_element = tuple(sorted(element))
if not any(sorted_element == e for e in self):
super().add(sorted_element)
s = MySet()
s.add((2, 1, 1))
s.add((1, 2, 1))
print(s) # 输出:{(1, 1, 2)}
I/O¶
快读快写:
import sys
sys.setrecursionlimit(1000000)
input=lambda:sys.stdin.readline().strip()
write=lambda x:sys.stdout.write(str(x)+'\n')
读到文件结尾:
import sys
n = 10 ** 7
primes = []
is_prime = [1] * (n + 1)
is_prime[0] = is_prime[1] = 0
for i in range(2, n + 1):
if is_prime[i]: primes.append(i)
for p in primes:
if i * p > n: break
is_prime[i * p] = 0
if i % p == 0: break
a = [0] * (n + 1)
for i in range(2, n + 1):
a[i] = a[i - 1] + (1 if not is_prime[i] else 0)
for line in sys.stdin:
input = line.strip()
l, r = map(int, input.split())
print(a[r] - a[l - 1])
有序列表 / 有序集合¶
SortedList 相当于 C++ 中的 multiset。常见用法如下:
- 添加元素:
s.add(val),\(O(\log n)\); - 添加一组可迭代元素:
s.upadte(iterable),\(O(k \log n)\); - 查找元素:
s.count(val)返回元素的个数,\(O(\log n)\); - 删除元素:
s.remove(val),\(O(\log n)\); - 删除指定下标元素:
s.pop(index = -1),\(O(\log n)\)。
源码参见 python-sortedcontainers | grantjenks - (github.com)。
贪心¶
排序贪心¶
语言整理 有一群人排好队, 每个人身高为 \(h_i\), 前面恰好 \(k_i\) 个身高不小于他的人数; 将 $ (h_1, k_1), (h_2, k_2), ..., (h_i, k_i), ...$ 打乱,请你恢复原来的顺序
思考
- 恢复顺序的依据: 通过 \((h_i, k_i)\) 能唯一确定此人的位置;
- 由于被打乱, 恢复顺序一定需要排序;
- 由于 \(k_i\) 为前面身高不小于他的人数, 可想到按照身高降序排序, 同身高内部按照 \(k_i\) 升序排序;
- 对排序后的数组遍历, 按照 "插入排序" 的思想找到合适的位置
def reconstructQueue(self, people: List[List[int]]) -> List[List[int]]:
# [7, 0] [7, 1] [6, 1] [5, 0] [5, 2] [4, 4]
people.sort(key = lambda x: -x[0] * 10 ** 5 + x[1])
res = []
for i, p in enumerate(people):
h, k = p[0], p[1]
if k == i:
res.append(p)
elif k < i:
res.insert(k, p)
return res
857. 雇佣 K 名工人的最低成本 - 力扣(LeetCode)
排序贪心 + 堆维护 \(k\) 个最小和。对任意一个员工子集,按照比例支付费用、且任意一个员工不少于阈值,等价于支付费用 = \(\max(\frac{w_i}{q_i}) \times \sum_{i=1}^{k}q_i\)。按照 \(\frac{w_i}{q_i}\) 从大到小排序,维护对应的 \(k\) 个最小 \(q\) 的和。
def mincostToHireWorkers(self, quality: List[int], wage: List[int], k: int) -> float:
n = len(quality)
nums = sorted([(w / q, q) for w, q in zip(wage, quality)])
hq, s = [], 0
for i in range(k):
heappush(hq, -nums[i][1])
s += nums[i][1]
res = nums[k - 1][0] * s
for i in range(k, n):
mxw = -heappop(hq)
s -= mxw
mxw = min(mxw, nums[i][1])
heappush(hq, -mxw)
s += mxw
res = min(res, nums[i][0] * s)
return res
2589. 完成所有任务的最少时间 - 力扣(LeetCode)
区间选点问题:选点数量可能超过 1 + 右端点排序贪心
为什么不能按照左端点排序?
如果按照左端点排序:
- 当下一个区间比当前区间先结束时,选点会在当前区间中间;
- 当下一个区间前缀和当前区间后缀相交时,选点在当前区间的尾部。
两者无法统一。
考虑按照 右端点排序 / 结束时间排序:
- 当下一个区间比当前区间先开始,选点可以在当前区间的尾部。
- 当下一个区间前缀和当前区间后缀相交时,选点也可以在区间的尾部。
因此二者是统一的。
从前向后考虑区间,当前区间,我们希望当前区间的后缀去匹配更多的后续区间的前缀,因此选点应该越靠后越好,即在当前区间的尾部。当相邻区间不相交时,选点是当前区间独占的。
因此,使用 \(sel\) 维护选择的点,每次进入新区间,首先考察范围内已经选择的点的个数;剩余的点需要贪心的安排在尾部。
时间复杂度:\(O(n \log n + n U)\), \(U\) 为最大结束时间。
def findMinimumTime(self, nums: List[List[int]]) -> int:
nums.sort(key = lambda x: x[1])
n, m = len(nums), nums[-1][1]
sel = [0] * (m + 1)
for l, r, c in nums:
c -= sum(sel[l: r + 1])
if c > 0:
for i in range(r, l - 1, -1):
if sel[i]: continue
sel[i] = 1
c -= 1
if c == 0: break
return sum(sel)
2024_CA_省 C.训练士兵
P10387 蓝桥杯 2024 省 A 训练士兵 - 洛谷 (luogu.com.cn)
题意:
共 n 人,每人需要 c_i 次训练,每人单独训练每次花费 p_i 元,团购训练花费 S 元,求所有人完成训练的最小花费。
思路:
- 团购价不变,有些人训练次数 \(c_i\) 少,有些人多。一开始团购价 \(S\) 往往比所有人单独训练价格 \(tot\) 更小;随着人训练完成后退出,\(tot\) 应该动态更新维护;当团购不合适时 \(S>tot\),单独操作;
- 联想到贪心 + 排序,尽可能贪心的多团购,讨论时按照所有人需要的训练次数从小到大排序;
- \(tot\) 初始为 \(\sum p_i\),用 \(res\) 记录答案,初始为 \(0\);用 \(cnt\) 记录 已经团购的次数,按照训练次数升序遍历, 如果 \(tot \ge S\),团购合适;\(res ← res + (c_i - cnt) \times S\);\(cnt \leftarrow c_i\),否则团购不合适,\(res ← res + (c_i - cnt) \times p_i\);
- 每一次遍历完成,代表此人训练完成且退出,需要动态维护 \(tot\),即 \(tot ← tot- p_i\)。
import sys
input = lambda:sys.stdin.readline().strip()
n, S = map(int, input().split())
nums = [[0, 0]] * n # 用于排序
p, c = [0] * n, [0] * n
# 数据预处理
for i in range(n):
nums[i] = list(map(int, input().split()))
# 排序:根据 nums [i][1] 即次数排序,默认是由低到高
nums.sort(key = lambda x: x[1])
for i in range(n):
p[i], c[i] = nums[i][0], nums[i][1]
res = cnt = 0
tot = sum(p)
for i in range(n):
if tot >= S: # 团购合适
res += (c[i] - cnt) * S
cnt = c[i]
else: # 团购不合适,此人单独训练
res += (c[i] - cnt) * p[i]
tot -= p[i] # 第 i 人完成训练,减去他的单独训练成本
print(res)
反悔贪心¶
- 贪心:尽可能;
- 反悔:有更优的出现了。
按照截止日期排序,尽可能不跳过每一个课程。反悔条件(cur > y)满足时从反悔堆反悔用时最大的课程。
def scheduleCourse(self, courses: List[List[int]]) -> int:
# 按照截至日期排序
courses.sort(key = lambda x: x[1])
hq = []
res, cur = 0, 0
for x, y in courses:
cur += x # 贪心:尽可能不跳过每一个课程
heapq.heappush(hq, -x) # 反悔堆:存放所有课程耗时
if cur > y: # 反悔条件:超过截止日期
cur += heapq.heappop(hq)
else:
res += 1
return res
def magicTower(self, nums: List[int]) -> int:
if sum(nums) + 1<= 0:
return -1
hq = []
res, cur = 0, 1
for x in nums:
cur += x # 贪心:尽可能不使用移动
if x < 0: # 反悔堆
heapq.heappush(hq, x)
if cur <= 0: # 反悔条件:血量不是正值
res += 1
cur -= heapq.heappop(hq) # 从反悔堆中,贪心回复血量
return res
1642. 可以到达的最远建筑 - 力扣(LeetCode)
def furthestBuilding(self, heights: List[int], bricks: int, ladders: int) -> int:
n = len(heights)
d = [max(0, heights[i] - heights[i - 1]) for i in range(1, n)]
hq = []
for res, x in enumerate(d):
# ladders - len(hq) 代表剩余梯子数量
heapq.heappush(hq, x) # 贪心 + 反悔堆
if ladders - len(hq) < 0: # 反悔条件:梯子不够了
bricks -= heapq.heappop(hq)
if bricks < 0:
return res
return n - 1
循环反悔贪心 + 反悔堆后置(需要贪心完成后才能加入当前值)
def minRefuelStops(self, target: int, startFuel: int, stations: List[List[int]]) -> int:
stations.append([target, 0])
n = len(stations)
pre = 0
res, cur = 0, startFuel
hq = []
for x, y in stations:
cur -= x - pre # 贪心:尽可能耗油不加油
pre = x
while hq and cur < 0: # 反悔条件:剩余油不够了
res += 1
cur -= heapq.heappop(hq)
if cur < 0 and not hq:
return -1
heapq.heappush(hq, -y) # 反悔堆:保存没加的油
return res
也是一个二维贪心问题。尽可能优先考虑利润维度。通过栈实现反悔操作。
def findMaximumElegance(self, items: List[List[int]], k: int) -> int:
items.sort(reverse = True)
s = set() # 只出现一次的种类 c
stk = [] # 反悔栈:出现两次以上的利润 p
res = total_profit = 0
for i, (p, c) in enumerate(items):
if i < k:
total_profit += p
if c not in s: # 种类 c 首次出现, 对应 p 一定最大, 一定保留
s.add(c)
else:
stk.append(p) # 反悔栈:存放第二次及以后出现的更小的 p
elif stk and c not in s:
# 只有 c 没有出现在 s 中时,才尝试反悔一个出现两次及以上的 p
total_profit += p - stk.pop()
s.add(c)
# 贪心:s 的长度只增不减
res = max(res, total_profit + len(s) ** 2)
return res
消消乐贪心¶
配合哈希表 / 哈希集合,在 \(O(n)\) 复杂度内,通过对乱序枚举到的每一个 \(x\) ,贪心找出符合性质 / 限制的整组数据并且消除。
最长连续子序列
给定一个未排序的整数数组 nums ,找出数字连续的最长子序列。
对于任何一个数 \(x\) ,向两边贪心找到相邻的这一组数,将其消除。
def longestConsecutive(self, nums: List[int]) -> int:
s = set(nums)
res = 0
for x in nums:
if x not in s: continue
cur = 1
s.remove(x)
y = x + 1
while y in s:
s.remove(y)
cur, y = cur + 1, y + 1
y = x - 1
while y in s:
s.remove(y)
cur, y = cur + 1, y - 1
res = max(res, cur)
return res
2007. 从双倍数组中还原原数组 - 力扣(LeetCode)
对于任何一个数 \(x\) ,如果是奇数则是最小出发数;否则向下贪心折半,直到得到最小出发数(奇数或者最小可达的偶数)。从最小出发数,出发,贪心删除整组数据。时间复杂度:\(O(n)\)
def findOriginalArray(self, changed: List[int]) -> List[int]:
n = len(changed)
if n & 1: return []
res = []
cnt = Counter(changed)
for i, x in enumerate(changed):
if cnt[x] == 0: continue
if x == 0:
if cnt[0] & 1: return []
res.extend(cnt[0] //2 * [0])
cnt[0] = 0
continue
while x & 1 == 0 and cnt[x // 2] > 0: x //= 2
y = x
while cnt[y] > 0:
if cnt[y * 2] < cnt[y]: return []
res.extend(cnt[y] * [y])
cnt[y * 2] -= cnt[y]
cnt[y] = 0
if cnt[y * 2]: y = 2 * y
else: y = 4 * y
return res
贪心集合划分¶
划分集合和不超过 \(k\) 的最少划分数:排序+回溯贪心
1986. 完成任务的最少工作时间段 - 力扣(LeetCode)
一种做法是 \(O(3^n+n\cdot 2^n)\) 的状压 dp,实际上可以给出 最坏情况下 \(O(2^n)\) 做法。
首先按照从到大小排序。 对于 \(a[i]\),选择 1:分配到新的组中;选择 2:分配到已经存在的、不超载的组。最坏情况下不超过 \(O(2^n)\),实际上由于剪枝操作,接近于线性复杂度。
def minSessions(self, nums: List[int], k: int) -> int:
nums.sort(reverse = True)
n = len(nums)
cnt = [0] * n
res = inf
def dfs(i, cur):
nonlocal res
if cur >= res: return
if i == n:
res = cur
return
x = nums[i]
for j in range(cur):
if cnt[j] + x <= k:
cnt[j] += x
dfs(i + 1, cur)
cnt[j] -= x
cnt[cur] += x
dfs(i + 1, cur + 1)
cnt[cur] -= x
dfs(0, 0)
return res
划分集合和不超过 \(k\) 的最少划分数:集合长度不超过 2
排序 + 双指针贪心。
def numRescueBoats(self, nums: List[int], k: int) -> int:
nums.sort(reverse = True)
n = len(nums)
i, j = 0, n - 1
res = 0
while i <= j:
lft = k - nums[i]
i += 1
if lft >= nums[j]:
j -= 1
res += 1
return res
前缀和与差分¶
一维前缀和¶
问题定义
数据范围
$n \in [1, 10^5], q \in [1, 10^5] $
思路:暴力
- 每次查询显然可以对区间求和,单次最坏复杂度 \(O(n)\);
- 总复杂度为 \(O(q \times n)\),超时。
算法基础:前缀和
- 预处理前缀和,可前缀和之差,以 \(O(1)\) 完成单次区间求和;
- 总复杂度为 \(O(n)\) 预处理 + \(O(q)\) 询问,即 \(O(q + n)\);
\(\text{定义: }p[i] = \sum(a[: i])\)
即 \(p[n] = p[n - 1] + a[n-1],\)
即 \(p[n + 1] = p[n] + a[n],\)
模板
P8218 【深进 1.例 1】求区间和 - 洛谷 (luogu.com.cn)
import sys
input = lambda: sys.stdin.readline().strip()
n = int(input())
a = list(map(int, input().split()))
q = int(input())
# 前缀和模板, p [i] = sum(a [: i])
p = [0] * (n + 1)
for i in range(n):
p[i + 1] = p[i] + a[i]
for _ in range(q):
l, r = map(int, input().split())
# l, r 下标从 1 开始,即求 a [l - 1] + a [l] + ... + a [r - 1]
# 即 sum(a [l - 1: r])
# 即 p [r] - p [l - 1]
print(p[r] - p[l - 1])
一维差分¶
def maximumBeauty(self, nums: List[int], k: int) -> int:
n = len(nums)
d = k - min(nums)
for i in range(n): nums[i] += d
mx = max(nums) + k
a = [0] * (mx + 1)
d = [0] * (mx + 2)
for x in nums:
d[x - k] += 1
d[x + k + 1] -= 1
a[0] = d[0]
for i in range(1, mx + 1):
a[i] = a[i - 1] + d[i]
return max(a)
二维差分¶
d = [[0] * (n + 2) for _ in range(m + 2)]
# 对矩阵中执行操作,使得左上角为(i, j),右下角为(x, y)的矩阵都加 k,等价于如下操作
d[i + 1][j + 1] += k
d[x + 2][y + 2] += k
d[i + 1][y + 2] -= k
d[x + 2][j + 1] -= k
# 还原差分时,直接原地还原
for i in range(m):
for j in range(n):
d[i + 1][j + 1] += d[i][j + 1] + d[i + 1][j] - d[i][j]
二维前缀¶
3070. 元素和小于等于 k 的子矩阵的数目 - 力扣(LeetCode)
class PreSum2d:
# 二维前缀和(支持加法和异或),只能离线使用,用 n*m 时间预处理,用 O1 查询子矩阵的和;op = 0 是加法,op = 1 是异或
def __init__(self,g,op=0):
m,n = len(g),len(g[0])
self.op = op
self.p=p=[[0]*(n+1) for _ in range(m+1)]
if op == 0:
for i in range(m):
for j in range(n):
p[i+1][j+1] = p[i][j+1]+p[i+1][j]-p[i][j]+g[i][j]
elif op==1:
for i in range(m):
for j in range(n):
p[i+1][j+1] = p[i][j+1]^p[i+1][j]^p[i][j]^g[i][j]
# O(1)时间查询闭区间左上(a, b), 右下(c, d)矩形部分的数字和。
def sum_square(self,a,b,c,d):
if self.op == 0:
return self.p[c+1][d+1]+self.p[a][b]-self.p[a][d+1]-self.p[c+1][b]
elif self.op==1:
return self.p[c+1][d+1]^self.p[a][b]^self.p[a][d+1]^self.p[c+1][b]
class NumMatrix:
def __init__(self, mat: List[List[int]]):
self.pre = PreSum2d(mat)
def sumRegion(self, row1: int, col1: int, row2: int, col2: int) -> int:
# pre = self.pre
return self.pre.sum_square(row1,col1,row2,col2)
class Solution:
def countSubmatrices(self, grid: List[List[int]], k: int) -> int:
n = len(grid)
m = len(grid[0])
res = 0
p = NumMatrix(grid)
for i in range(n):
for j in range(m):
if p.sumRegion(0, 0, i, j) <= k:
res += 1
return res
pre[i + 1][j + 1] 是左上角为(0, 0) 右下角为 (i, j)的矩阵的元素和。
如果是前缀异或是:
p[i+1][j+1] = p[i][j+1]^p[i+1][j]^p[i][j]^g[i][j]
def countSubmatrices(self, grid: List[List[int]], k: int) -> int:
m, n = len(grid), len(grid[0])
pre = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(m):
for j in range(n):
pre[i + 1][j + 1] = pre[i][j + 1] + pre[i + 1][j] - pre[i][j] + grid[i][j]
res = 0
for i in range(m):
for j in range(n):
if pre[i + 1][j + 1] <= k:
res += 1
return res
前缀异或 / 自定义前缀操作
二分¶
注:无特殊说明,均为整数。
二分查找¶
基本模型:
- 给定一个单调不减的数组 \(a\), 返回恰好 严格大于 \(x\) 的下标位置;
- 给定一个单调不减的数组 \(a\), 返回第一个 严格大于 \(x\) 的下标位置;
- 给定一个单调不减的数组 \(a\), 返回恰好使得 \(a[i] > x\) 的下标位置。
记为 \(bisect(a, x)\)
a = [1, 9, 9, 9, 200, 500]
\(bisect(a, 3)\) = 1
\(bisect(a, 1)\) = 1
\(bisect(a, -99)\) = 0
\(bisect(a, 9)\) = 4
\(bisect(a, 7000)\) = 6
变形一:给定一个单调不减的数组 \(a\), 返回恰好 大于等于 \(x\) 的下标位置,等价为返回恰好 大于 \(x-1\) 的下标位置,即 \(bisect(a, x - 1)\)。
a = [1, 9, 9, 9, 200, 500]
恰好大于等于 9 的位置
\(bisect(a, 9 -1)\) = 1
恰好大于等于 200 的位置
\(bisect(a, 200 - 1)\) = 4
变形二:给定一个单调不减的数组 \(a\) ,返回恰好 小于等于 \(x\) 的下标位置,等价为返回恰好大于 \(x\) 的下标位置 \(-1\),即 \(bisect(a, x) - 1\)。
a = [1, 9, 9, 9, 200, 500]
恰好小于等于 9 的位置
\(bisect(a, 9) - 1\) = 3
恰好小于等于 500 的位置
\(bisect(a, 500 )\) = 5
变形三:给定一个 单调不增 的数组 \(a\), 返回恰好 小于 \(x\) 的下标位置。处理方法:\(a' = [-x \text{ for } x \text{ in }a]\),等价于 \(bisect(a', -x)\)。
也可以用逆序做,更推荐用相反数做。
\(a = [500, 200, 9, 9, 9, 1]\)
\(a' = [-500, -200, -9, -9, -9, -1]\)
原数组中恰好小于 9 的位置:\(bisect(a', -9)\) = 5。
利用内置库 bisect 进行二分¶
bisect(a, x, lo = 0, hi = len(nums))
- 给定一个单调不减的数组 \(a\), 在其 \([lo, hi]\) 区间中, 返回第一个严格大于 \(x\) 的下标位置
- 时间复杂度 \(O( \log n)\)
bisect.bisect 和 bisect.bisect_right 是完全相同且同时支持的函数,为了方便,我们不写 bisect_right;
同时为了防止混淆,我们也不提 bisect_left,大家需要灵活学习如果用一个 bisect 实现所有变形。
from bisect import *
# 示例数组
# 0 1 2 3 4 5
arr = [1, 9, 9, 9, 200, 500]
# 查找插入位置
print(bisect(arr, 3)) # 输出: 1 (第一个大于 3 的索引)
print(bisect(arr, 1)) # 输出: 1 (第一个大于 1 的索引)
print(bisect(arr, -99)) # 输出: 0 (第一个大于 -99 的索引)
print(bisect(arr, 9)) # 输出: 4 (第一个大于 9 的索引)
print(bisect(arr, 7000)) # 输出: 6 (第一个大于 7000 的索引,此时等于数组长度)
arr = [1, 9, 9, 9, 200, 500]
# 如果需要找第一个大于等于 x 的位置索引
# bisect(nums, x - 1) ?
print(bisect(arr, 9 - 1)) # 输出: 1 (第一个大于等于 9 的索引)
print(bisect(arr, 200 - 1)) # 输出: 4 (第一个大于等于 200 的索引)
# 逆序数组,找到第一个小于 x 的位置索引
# 0 1 2 3 4 5
arr = [500, 200, 9, 9, 9, 1]
arr = [-x for x in arr]
print(bisect(arr, -9)) # 输出: 5 (第一个小于 9 的位置索引)
P2249 【深基 13.例 1】查找 - 洛谷 (luogu.com.cn)
from bisect import *
import sys
input = lambda: sys.stdin.readline().strip()
n, m = map(int, input().split())
nums = list(map(int, input().split()))
Q = list(map(int, input().split()))
s = set(nums) # nums 构成的集合,如果待查询数组 q not in s,直接返回-1
for q in Q:
if q not in s: print(-1, end = " ")
else: # q 一定出现在 nums 中, 利用技巧将“大于等于 x”转化成“大于 x-1”
print(bisect(nums, q - 1) + 1, end = " ")
2563. 统计公平数对的数目 - 力扣(LeetCode) 同 Problem - 1538C - Codeforces
- 由于求符合条件的数对个数,与顺序无关,先排序
- 对 $ \text{lower} \le x + y \le \text{upper}$,变形为 $ \text{lower} - x \le y \le \text{upper} - x$
- 即对一个 \(x\), 区间 \([i + 1, n)\) 中有多少个数出现在 区间 \([\text{lower} - x, \text{upper} - x]\)
- 即求 \(L = bisect(a, \text[lower] - x - 1)\),\(R = bisect(a, \text{upper} - x) - 1\)
- 答案等于 \(R-L + 1\)
from bisect import *
class Solution:
def countFairPairs(self, a: List[int], lower: int, upper: int) -> int:
a.sort()
res = 0
for i, x in enumerate(a):
L = bisect(a, lower - x - 1, i + 1)
R = bisect(a, upper - x, i + 1) - 1
res += R - L + 1
return res
手写更通用的二分算法¶
内置 \(bisect\) 固然好用,当条件从「恰好 \(a[i] > x\) 的下标位置 \(i\)」变成更通用的:
- 给定一个单调不减的数组 \(a\), 和关于 \(a[i]\) 单调不减函数 \(check\), 返回恰好有 \(check(a[i]) > x\) 的下标位置 \(i\)
- 例如,\(check(a[i]) = {(a[i])} ^ 3 + 2 \cdot a[i] + 1\) 希望找到一个位置 \(i\),恰好 \(check(a[i]) > x\);
可以发现,前文提及的数学模型,对应的 \(check(a[i])\) 即 \(a[i]\),是通用表述下的一个特例;
在 python3.8 版本不支持 bisect 传递比较规则,即无法传递 \(check\) 函数,于是我们需要自己实现 bisect 函数。
基本模型实现思路:
- 对于区间 \([lo, hi]\) 上二分,将区间划分为左半部 \([lo, \frac{lo + hi}{2})\), 右半部 \([\frac{lo + hi}{2}, hi)\);
- 区间中点 \(i = \frac{lo + hi}{2}\), 考虑 \(a[i] > x\) 吗?
- 是,区间更新为左半部,\(hi ← i\)
- 否,则由于 \(a[i] \le x\),恰好大于的位置应该不包括 \(i\),故 \(lo ← i + 1\)
- 当 \(lo = hi\) 结束,故 while 条件为 \(lo < hi\)。不会出现 $lo > hi $ 的情况。
# 【朴素二分】实现 bisect
def bisect(a, x, lo=0, hi=None):
if hi is None: hi = len(a)
while lo < hi:
i = (lo + hi) >> 1
if a[i] > x:
hi = i
else:
lo = i + 1
return lo
# 示例用法
a = [1, 9, 9, 9, 200, 500]
print(bisect(a, 9)) # 输出: 4
print(bisect(a, 7000)) # 输出: 6
\(check\) 模型思路:
- 对于区间 \([lo, hi]\) 上二分,将区间划分为左半部 \([lo, \frac{lo + hi}{2})\), 右半部 \([\frac{lo + hi}{2}, hi)\);
- 区间中点 \(i = \frac{lo + hi}{2}\), 考虑 \(check(a[i]) > x\) 吗?
- 是,区间更新为左半部,\(hi ← i\)
- 否,则由于 \(a[i] \le x\),恰好大于的位置应该不包括 \(i\),故 \(lo ← i + 1\)
- 当 \(lo = hi\) 结束,故 while 条件为 \(lo < hi\)
# 【朴素二分】实现 bisect,支持传递 check 函数
def bisect(a, x, lo = 0, hi = None, check = lambda y: y):
if hi is None: hi = len(a)
while lo < hi:
i = (lo + hi) >> 1
if check(a[i]) > x:
hi = i
else:
lo = i + 1
return lo
# 示例用法
a = [1, 9, 9, 9, 200, 500]
# 找到 a [i] ** 3 + a [i] * 2 + 1 恰好大于 x 的位置
x = 1000
print(bisect(a, x, check = lambda y: y ** 3 + y * 2 + 1)) #4
二分答案¶
二分答案一般满足条件:
- 求最值 / 最优问题
- 答案 \(res\) 存在一个确定、连续区间 \([lo, hi]\)
- 对确定的 \(res = i\),能够求出 \(check(i) =False \text{还是 }True\), 即是否满足条件
基本模型:
- 构造 False \(\to\) True 模型;
- 答案具有单调增性,即 \(res\) 是 \(check(i)\) 条件进行 \(False/True\) 切换的临界点。
时间复杂度:\(O(check(n) \cdot \log(L))\),其中 \(check(n)\) 表示进行一次遍历检查的复杂度。
与二分查找区别:
- 二分查找:在一个已知的有序数据集上进行二分地查找;
- 二分答案:答案有一个连续区间,在这个区间上二分,直到找到最优答案。
举个例子,我们需要找到一个单调增函数 \(f(x) = x^3 + x + 1\) 在 区间 \([1, 10^{18}]\) 上恰好满足 $check(x)=f(x) > target $ 为 \(True\) 的临界点。可以通过利用单调性,对 \(x\) 在区间上二分,找出恰好满足 \(check = True\) 的临界点。
def f(x):
return x ** 3 + x + 1
def bisect(lo, hi, target, check):
while lo < hi:
i = (lo + hi) >> 1
if check(i, target):
hi = i
else:
lo = i + 1
return lo
target = 99999
res = bisect(1, 10 ** 18, target, lambda x, target: f(x) > target)
# 找到恰好 f(x) > target 的地方
print(res) # 47
print(f(res)) # 103871
print(f(res - 1)) # 97383
求“区间”问题
给定 \(n\) 组数对 \((a, b)\),要满足对所有数对,都有 $ a // v = b $;求 \(v\) 的可取值区间;
\(\text{对于 }100\%\text{ 的评测用例,}1\leq n\leq10^4\mathrm{,}1\leq b\leq a\leq10^9\);
思考
- 显然 \(v\) 有区间 \([1, 10^9]\);
- 对于给定的 \(v\),可以通过枚举一整遍的方式判断是否符合条件;
- 暴力做法可以达到 \(O(nL)\), \(L = 10^9\) ,会超时;
思路
-
拆开来看 \(v\) 的上下界
-
对于下界,\(check_1(v)\) 表示恰好不存在 $ a// v > b$,即全部满足 \(a // v \le b\),随着 \(v\) 增大,从 不满足 → 满足。
- 对于上界,\(check_2(v)\) 表示恰好存在 $ a// v < b$,随着 \(v\) 增大,从 不满足 → 满足。 二分得到的是 \(M+ 1\),记得减去 1。
import sys
input = lambda: sys.stdin.readline().strip()
n = int(input())
a, b = zip(*[map(int, input().split()) for _ in range(n)])
def bisect(lo, hi, check):
while lo < hi:
i = (lo + hi) // 2
if check(i):
hi = i
else:
lo = i + 1
return lo
m = bisect(1, 10**9, lambda v: all(A // v <= B for A, B in zip(a, b)))
M = bisect(1, 10**9, lambda v: any(A // v < B for A, B in zip(a, b))) - 1
print(m, M)
时间复杂度: \(O(n \log L)\)
本题还有非二分答案做法。
3048. 标记所有下标的最早秒数 I - 力扣(LeetCode)
求“至少”问题
n, m = len(nums), len(changeIndices)
def check(mx): # 给 mx 天是否能顺利考完试
last_day = [-1] * n
for i, x in enumerate(changeIndices[:mx]):
last_day[x - 1] = i + 1
#如果给 mx 不能完成,等价于有为 i 遍历到考试日期的考试
if -1 in last_day:
return False
less_day = 0
for i, x in enumerate(changeIndices[:mx]):
if last_day[x - 1] == i + 1: # 到了考试日期
if less_day >= nums[x - 1]:
less_day -= nums[x - 1]
less_day -= 1 #抵消当天不能复习
else:
return False #寄了
less_day += 1
return True
left = sum(nums) + n # 至少需要的天数, 也是二分的左边界
res = left + bisect.bisect_left(range(left, m + 1), True, key = check)
return -1 if res > m else res
求“最多”问题
2226. 每个小孩最多能分到多少糖果 - 力扣(LeetCode)
给你一个 下标从 0 开始 的整数数组 candies 。数组中的每个元素表示大小为 candies[i] 的一堆糖果。你可以将每堆糖果分成任意数量的 子堆 ,但 无法 再将两堆合并到一起。
另给你一个整数 k 。你需要将这些糖果分配给 k 个小孩,使每个小孩分到 相同 数量的糖果。每个小孩可以拿走 至多一堆 糖果,有些糖果可能会不被分配。
返回每个小孩可以拿走的 最大糖果数目 。
示例 1:
输入:candies = [5,8,6], k = 3
输出:5
解释:可以将 candies[1] 分成大小分别为 5 和 3 的两堆,然后把 candies[2] 分成大小分别为 5 和 1 的两堆。现在就有五堆大小分别为 5、5、3、5 和 1 的糖果。可以把 3 堆大小为 5 的糖果分给 3 个小孩。可以证明无法让每个小孩得到超过 5 颗糖果。
示例 2:
输入:candies = [2,5], k = 11
输出:0
解释:总共有 11 个小孩,但只有 7 颗糖果,但如果要分配糖果的话,必须保证每个小孩至少能得到 1 颗糖果。因此,最后每个小孩都没有得到糖果,答案是 0 。
提示:
1 <= candies.length <= 10 ** 51 <= candies[i] <= 10 ** 71 <= k <= 10 ** 12
语言整理
给定长度为 \(n\) 一组数 \(a\),和 \(k\) 个人。这组数任意拆分的小子堆,每个人一堆,求可以拿走的最大值 \(res\) 。
思路
- 显然,答案有界,界于区间 \([0, \max(a)]\)
- 对 \(res\) 上界,即 \(check\) 表示恰好
sum(x // res for x in range(1, max(a) + 1)) < k - 二分得到的结果 \(-1\) 是答案
写法 1
class Solution:
def maximumCandies(self, a: List[int], k: int) -> int:
if sum(a) < k: return 0
lo, hi = 1, 10 ** 12 + 10
def check(res):
return sum(x // res for x in a) < k
while lo < hi:
i = (lo + hi) >> 1
if check(i): hi = i
else: lo = i + 1
return lo - 1
写法 2
class Solution:
def maximumCandies(self, a: List[int], k: int) -> int:
if sum(a) < k: return 0
lo, hi = 1, max(a) + 1
def check(res):
return sum(x // res for x in a) < k
while lo < hi:
i = (lo + hi) >> 1
if check(i): hi = i
else: lo = i + 1
return lo - 1
1642. 可以到达的最远建筑 - 力扣(LeetCode)
- 记录一下高度差数组
- 对于每次 \(check\),首先对高度差数组的前 \(res\) 项降序排序,得到新数组 \(t\)
- 对 \(t\) 的前 \(ladder\) 项,贪心的用梯子,只需要考虑后
t[ladders:]部分的和 - 这部分的和,如果大于 \(bricks\),说明,不能够到达
- 找出恰好不能到达的位置 \(res\),-1 是答案
class Solution:
def furthestBuilding(self, heights: List[int], bricks: int, ladders: int) -> int:
n = len(heights)
lo, hi = 0, n
d = [max(0, heights[i + 1] - heights[i]) for i in range(n - 1)]
def check(res):
t = sorted(d[:res], reverse = True)
return bricks < sum(t[ladders: ])
while lo < hi:
i = (lo + hi) >> 1
if check(i):
hi = i
else:
lo = i + 1
return lo - 1
def furthestBuilding(self, heights: List[int], bricks: int, ladders: int) -> int:
n = len(heights)
d = [max(0, heights[i + 1] - heights[i]) for i in range(n - 1)]
def check(x):
t = d[:x]
t.sort(reverse = True)
return not (ladders >= x or sum(t[ladders: ]) <= bricks)
return bisect.bisect_left(range(n), True, key = check) - 1
中位数转化为第 \(k\) 小问题
对于一个长度为 \(n\) 的 由于数组,求中位数等价于求数组中第 \(\frac{n-1}{2}\) 小的数问题(广义中位数)。
3134. 找出唯一性数组的中位数 - 力扣(LeetCode)
一共有 \((n + 1) \times n / 2\) 个子数组,其对应的 \(f=len (set(sub))\) 的值按照升序排列后,求其中位数。即转换为求数组中 第 \(k\) 小问题。
转换为二分查找:给定一个 \(x\) ,能得出所有子数组中 \(f\) 值小于等于 \(x\) 的个数 \(res\);可以发现 \(x\) 越大,\(res\) 越大;因此找到恰好让 \(res >k\) 的位置即可。这里需要使用到求 “不同元素个数小于等于 \(k\) 的子数组个数” 问题,这是一共广义上不定长滑动窗口问题。
def get_set_subarrays_lower_k(nums, k):
l = res = 0
freq = Counter()
for r, x in enumerate(nums):
freq[x] += 1
while len(freq) > k:
freq[nums[l]] -= 1
if freq[nums[l]] == 0: freq.pop(nums[l])
l += 1
res += r - l + 1
return res
class Solution:
def medianOfUniquenessArray(self, nums: List[int]) -> int:
n = len(nums)
m = ((n + 1) * n // 2 - 1) // 2
lo, hi = 0, n // 2 + 10
while lo < hi:
mid = (lo + hi) // 2
if get_set_subarrays_lower_k(nums, mid) > m:
hi = mid
else:
lo = mid + 1
return lo
搜索 / DFS / BFS¶
枚举 DFS¶
思路
- 贪心,从左到右,尽可能构造 9。
- 对每一位数字,只会用一种操作。
- 记 \(dfs(i, n, a, b)\) 表示当前考虑到第 \(i\) 位,剩下 \(a\) 次操作 1 和 \(b\) 次操作 2 情况下,当前获得的最大数为 \(n\)
- 对于操作 1,考虑 \(d=min(9-x,a)\),即当前能够执行操作 1 的次数
- 则 \(n \leftarrow n \times 10 + (x + d)\),$a \leftarrow a - d $
- 对于操作 2,考虑 \(b\) 是否大于等于 \(x + 1\),是则可以得到 9,且 \(b \leftarrow b - (x + 1)\)
import sys
input = lambda: sys.stdin.readline().strip()
N, A, B = map(int, input().split())
sN = str(N)
res = 0
def dfs(i, n, a, b):
global res
if i >= len(sN):
res = max(res, n)
return
x = int(sN[i])
d = min(9 - x, a)
dfs(i + 1, n * 10 + (x + d), a - d, b)
if b >= x + 1:
dfs(i + 1, n * 10 + 9, a, b - (x + 1))
dfs(0, 0, A, B)
print(res)
图上 DFS¶
语言整理
一个有向图,每个节点 \(u\) 有且仅有一条出边 \(u→v\)。给定 \(g\) 数组,\(g[u]=v\) 表示这条出边。
求图中最长环的长度。
思路
- \(dfs(u)\) 表示当前访问节点 \(u\),集合 \(s\) 表示访问过的节点集
- 如果 \(u\) 在 \(s\) 中,说明找到了环
- 如果 \(u\) 不在 \(s\) 中,添加到集合中,并且访问后续节点 \(dfs(g[u])\)
- 遍历所有节点,确保考虑到所有连通分量。
类似代码:
s = set
def dfs(u):
if u in s: return # 找到了环
s.add(u)
dfs(g[u])
for u in range(1, n + 1):
dfs(u)
- 那么怎么求解环的长度?
思路
- 时间戳思想,额外记录每次访问节点 \(u\) 的序号 \(idx\),字典 \(d\) 存放 \(u:idx\) 键值对
- 如果 \(u\) 在 \(d\) 中,说明第二次访问到 \(u\) ,构成闭环。两次序号之差 \(idx - d[u]\) 即环的长度
- 如果 \(u\) 不在 \(d\) 中,添加到字典中,并且访问后续节点 \(dfs(g[u], idx + 1)\)
- 遍历所有节点,确保考虑所有连通分量
- 在外层用 \(res\) 记录最大环长。
正解代码:
import sys
input = lambda: sys.stdin.readline().strip()
sys.setrecursionlimit(10000) # 增加递归深度至少大于 n,因为 python 默认为 1000
n = int(input())
g = [0] + list(map(int, input().split()))
res = 0
d = {}
def dfs(u, idx):
global res
if d.get(u) is not None:
res = max(res, idx - d[u]) # 找到闭环,序号差就是环长
return
d[u] = idx
dfs(g[u], idx + 1)
for u in range(1, n + 1): # 确保访问所有连通分量
dfs(u, 1)
print(res)
模拟 BFS¶
网格图模拟 BFS
- 是给定一个二维网格,以及一些初始位置,并说明初始位置的蔓延条件。
- 通过队列 \(q\) 存储位置。初始值即为初始位置
- 每次考虑当前位置 \((x,y)\) 的四周,尝试蔓延
网格图 BFS 模板。
# 设置 q 的初始值,如 q.append(...)
q = deque([(3, 4)]) # 或者 q = deque() 之后,q.append((3, 4))
g[3][4] = 0 # 标记访问过
di = [(0, 1), (0, -1), (1, 0), (-1, 0)]
while q:
x, y = q.popleft() # 弹出队首
for dx, dy in di: # 遍历四个方向,尝试蔓延
nx, ny = x + dx, y + dy
if 0 <= nx < n and 0 <= ny < m and g[nx][ny] == 1: # 判断蔓延是否合法
q.append((nx, ny)) # 加到队尾,表示后续开始蔓延的位置
g[nx][ny] = 0 # 标记访问过
# 其他操作
思路
- 枚举每个连通的岛屿,通过将访问过的位置设置为 0,即 \(grid[i][j] = 0\) 进行记录
- 每个岛屿的“登陆点”即为 \(q\) 的初始内容
- 每次将 \(q\) 的队首弹出,考虑其上下左右是否有陆地,是则加入到队尾,并且标记访问过,更新岛屿面积
- 重复操作,直到 \(q\) 为空。
class Solution:
def maxAreaOfIsland(self, g: List[List[int]]) -> int:
n, m = len(g), len(g[0])
q = deque()
res = 0
di = [(0, 1), (0, -1), (1, 0), (-1, 0)]
def bfs(i, j): # 考虑登陆点为 (i, j)的岛屿
ans = 1
q = deque([(i, j)])
g[i][j] = 0 # 登陆点设置为 0,表示已经访问过
while q:
x, y = q.popleft() # 弹出队首
for dx, dy in di: # 遍历四个方向,考虑是否有陆地
nx, ny = x + dx, y + dy
if 0 <= nx < n and 0 <= ny < m and g[nx][ny]:
q.append((nx, ny)) # 有陆地,加到队尾,表示后续需要考虑的位置
ans += 1
g[nx][ny] = 0 # 标记访问
return ans
for i, row in enumerate(g):
for j, x in enumerate(row):
if x == 1: # 遍历所有连通分量
res = max(res, bfs(i, j))
return res
import sys
from collections import deque
input = lambda: sys.stdin.readline().strip()
n, m = map(int, input().split())
g = [[0] * m for _ in range(n)]
di = [(0, 1), (0, -1), (1, 0), (-1, 0)]
q = deque()
for i in range(n):
r = input()
for j, x in enumerate(r):
if x == 'g':
g[i][j] = 1
q.append((i, j))
k = int(input())
while q and k:
for _ in range(len(q)):
x, y = q.popleft()
for dx, dy in di:
nx, ny = x + dx, y + dy
if 0 <= nx < n and 0 <= ny < m and g[nx][ny] == 0:
g[nx][ny] = 1
q.append((nx, ny))
k -= 1
for row in g:
print(''.join('g' if x else '.' for x in row))
回溯 / 递归 / 分治¶
子集型回溯¶
枚举子集,\(O(n\cdot2^n)\)
回溯方法 1:选 / 不选
def subsets(self, nums: List[int]) -> List[List[int]]:
n = len(nums)
res, path = [], []
def dfs(i):
if i == n:
res.append(path.copy())
return
path.append(nums[i])
dfs(i + 1)
path.pop()
dfs(i + 1)
dfs(0)
return res
方回溯法 2:枚举选哪个数 + 记录可以选的范围
\(dfs(i)\) 表示当前已经有选择了 \(path\) 后,\(path\) 下一个元素可以从 \(i\) 及其往后选。每一个 \(dfs\) 状态都是合法状态,需要记录。
def subsets(self, nums: List[int]) -> List[List[int]]:
res, path = [], []
n = len(nums)
def dfs(i):
res.append(path.copy())
for j in range(i, n):
path.append(nums[j])
dfs(j + 1)
path.pop()
dfs(0)
return res
位运算写法:
def subsets(self, nums: List[int]) -> List[List[int]]:
n = len(nums)
s = (1 << n) - 1
res = [[]]
sub = s
while sub:
res.append([nums[j] for j in range(n) if ((sub >> j) & 1)])
sub = (sub - 1) & s
return res
组合型回溯¶
枚举所有长度为 \(k\) 的组合
回溯方法 1:选 / 不选方法
时间复杂度 :\(O(n \cdot2^n)\)
def combine(self, n: int, k: int) -> List[List[int]]:
nums = list(range(1, n + 1))
res, path = [], []
def dfs(i):
if i == n:
if len(path) == k: res.append(path.copy())
return
# 不选
dfs(i + 1)
# 选
path.append(nums[i])
dfs(i + 1)
path.pop()
dfs(0)
return res
回溯方法 2:枚举当前选哪个数,以及记录可以选择的范围,每一个状态的合法情况需要记录。
剪枝操作 (1):确保所有子集长度不会超过 \(k\)。剪枝操作 (2) :要确保枚举当前选择的数的位置,不会使得最终整个子集长度达不到 \(k\) 。通过两个剪枝操作,确保只会得到长度恰好为 \(k\) 的子集。这里倒序 / 正序枚举在当前选择的数的范围影响下界 / 上界。
倒序枚举时, \(dfs(i)\) 表示当前选择范围为 \(nums[0] \sim nums[i]\),含有 \(i+1\) 个数。由于下一个状态是 \(j-1\),含有 \(j\) 个数,根据 \(j \ge k - len(path)-1\) 计算下界。
时间复杂度:\(O(k \cdot C(n,k))\)。因为总共组合状态个数有 \(C(n,k)\) 个,每个状态记录的长度不超过 \(k\) 。
# 倒叙枚举
def combine(self, n: int, k: int) -> List[List[int]]:
nums = list(range(1, n + 1))
res, path = [], []
def dfs(i):
if k == len(path):
res.append(path.copy())
return
for j in range(i, k - len(path) - 2, -1):
path.append(nums[j])
dfs(j - 1)
path.pop()
dfs(n - 1)
return res
位运算写法 + Gosper's Hack:
时间复杂度:\(O(n \cdot C(n,k))\)
def combine(self, n: int, k: int) -> List[List[int]]:
nums = list(range(1, n + 1))
s = (1 << n) - 1
sub = (1 << k) - 1
res = []
def next_sub(x):
lb = x & -x
left = x + lb
right = ((left ^ x) >> 2) // lb
return left | right
while sub <= s:
res.append([nums[i] for i in range(n) if (sub >> i) & 1])
sub = next_sub(sub)
return res
完全背包型组合
每个元素可以无限重复选择,需要找出目标值等于 \(target\) (或小于等于 \(target\) )的所有可行组合。
先排序,利于提前剪枝优化跳出循环。枚举当前选哪个 + 记录可以选择的范围型回溯,记录当前的和。由于可以重复选择,所以当前选择 \(j\) 以后,下一次的可以选择范围仍然是 \(j\)。
def combinationSum(self, candidates: List[int], target: int) -> List[List[int]]:
candidates.sort()
n, res, path = len(candidates), [], []
# 枚举当前选哪个,以及记录可以选择的范围;以及当前的和
def dfs(i, s):
if s == target:
res.append(path.copy())
return
for j in range(i, n):
x = candidates[j]
if x + s <= target:
path.append(x)
dfs(j, s + x) # 体现可重复选择
path.pop()
else: break
dfs(0, 0)
return res
括号生成问题:带限制组合型回溯
选 / 不选型回溯:枚举当前左 / 右括号,记录当前左括号的个数。记 \(lc\) 表示左括号个数, \(rc\) 表示右括号个数。
限制 1:\(lc\) 不能小于 \(rc\)。
限制 2:$lc $ 不能超过 \(n/2\)。
限制 3:当 \(lc=rc\),只能回溯左括号。
class Solution:
def generateParenthesis(self, n: int) -> List[str]:
# 枚举当前左 / 右括号,记录当前左括号的个数
n <<= 1
path = [None] * n
res = []
def dfs(i, lc):
rc = i - lc
if rc > lc or lc > n // 2: return
if i == n:
res.append(''.join(path))
return
path[i] = '('
dfs(i + 1, lc + 1)
if rc < lc: # 可以选右
path[i] = ')'
dfs(i + 1, lc)
dfs(0, 0)
return res
时间复杂度:由于状态个数是卡特兰数列,即 \(O(C_n) \sim O(\frac{4^n}{n^{\frac{3}{2}} \cdot\sqrt \pi})\)
排列型回溯¶
全排列:排列元素无重复
写法 1:\(dfs(i, S)\) 表示枚举到 第 \(i\) 位,没有枚举过的集合为 \(S\)。外层 \(path\) 表示当前回溯的路径。
其中 \(path\) 可以使用 \(path[i] = j\) 的写法,覆盖当前走到哪一步;也可以使用 \(append / pop\) 写法,覆盖和恢复现场。但是不可以在外层增加哈希集合维护没有枚举过的,这是因为集合添加操作的乱序性,外层的集合无法正确恢复现场(恢复后遍历顺序不正确)。
时间复杂度:当有 \(N\) 个数时,所有状态个数 \(M = A_N^N+A_N^{N-1}+~\cdots~+A_N^0=\sum_{k=0}^{N}\frac{N!}{k!} =N! \cdot \sum_{k=0}^{N}\frac{1}{k!}= e\cdot N!\)。(麦克劳林展开)。每个状态时间复杂度,可以将集合的复制下方到下一个状态,所以是 \(O(n)\)。故总复杂度:\(O(N \cdot N!)\)
def permute(self, nums: List[int]) -> List[List[int]]:
n = len(nums)
path = [0] * n
res = []
# 当前枚举到 位置 i,没有枚举过的集合为 S
def dfs(i, S):
if i == n:
res.append(path.copy())
return
for j in S:
path[i] = j
dfs(i + 1, S - {j})
dfs(0, set(nums))
return res
写法 2:更偏向于回溯。外层 \(path\) 表示当前回溯的路径,外层 \(on\_path\) 维护节点是否已经出现在回溯的路径中。
def permute(self, nums: List[int]) -> List[List[int]]:
n = len(nums)
path = []
on_path = [False] * n
res = []
# 当前枚举到 位置 i,on_path 记录是否已经出现在回溯路径 path 中
def dfs(i):
if i == n:
res.append(path.copy())
return
for pj, on in enumerate(on_path):
if not on:
on_path[pj] = True
path.append(nums[pj])
dfs(i + 1)
on_path[pj] = False
path.pop()
dfs(0)
return res
全排列:排列元素有重复:只能用 \(on\_path\) 回溯 / 位运算压缩
相同元素,在 \(i\) 处视为一个,加一个集合维护已经出现过的数字。
def permuteUnique(self, nums: List[int]) -> List[List[int]]:
n, res = len(nums), []
path, on_path = [0] * n, [0] * n
def dfs(i):
if i == n:
res.append(path.copy())
return
S = set() # 相同元素,在 i 处视为一个
for j, on in enumerate(on_path):
if not on and nums[j] not in S:
S.add(nums[j])
path[i] = nums[j]
on_path[j] = 1
dfs(i + 1)
on_path[j] = 0
dfs(0)
return res
相同值的排列视为同一个:在枚举 \(i\) 位置放谁的时候加一个集合维护已经出现过的数字。
def numSquarefulPerms(self, nums: List[int]) -> int:
n = len(nums)
res = 0
def is_sqr(x):
return x == int(sqrt(x)) ** 2
def dfs(i, S, pre):
nonlocal res
if i == n:
res += 1
return
# i 位置放谁
s = set()
for j in range(n):
x = nums[j]
if (S >> j) & 1 == 1 or x in s: continue
s.add(x)
if pre == None or is_sqr(pre + x):
dfs(i + 1, S | (1 << j), x)
dfs(0, 0, None)
return res
2850. 将石头分散到网格图的最少移动次数 - 力扣(LeetCode)
暴力枚举可重复全排列匹配 + 位运算压缩。用石头个数大于 1 和 没有石头的位置,构造两个列表,进行全排列暴力匹配。
def minimumMoves(self, grid: List[List[int]]) -> int:
frm, to = [], []
for i, row in enumerate(grid):
for j, x in enumerate(grid[i]):
if x == 0: to.append((i, j))
elif x > 1: frm.extend((i, j) for _ in range(x - 1))
res = inf
n = len(frm)
path = [None] * n
def dfs(i, S):
nonlocal res
if i == n:
cst = sum(abs(x1 - x2) + abs(y1 - y2) for (x1, y1), (x2, y2) in zip(path, to))
res = min(res, cst)
return
for j in range(n):
if (S >> j) & 1:
path[i] = frm[j]
dfs(i + 1, S ^ (1 << j))
dfs(0, (1 << n ) - 1)
return res
N 皇后问题
皇后之间不同行,不同列,且不能在同一斜线。如果只满足不同行不同列,等价于每行每列恰好一个皇后。如果用 \(col\) 表示皇后的位置,\(col[i]\) 表示 第 \(i\) 行的皇后在第 \(col[i]\) 列,则 "每行每列恰好一个皇后" 等价于 枚举 \(col\) 的全排列。
加上斜线上不能有皇后的条件,如果从上往下枚举,则左上方向、右上方向不能有皇后。所以问题变成,当前枚举到 第 \(i\) 行,可以枚举的列号的集合 \(S\) 。枚举列 \(j \in S\) ,合法情况即在 \(\forall r \in [0 ,~ i-1]\) ,其列值 \(c = col[r]\) 都不满足 \(i+j=r+c\) 或者 \(i-j=r-c\)。
写法 1:$ dfs(i, S)$ 枚举当前到第 \(i\) 行(选第 \(i\) 个数),可以选择的列号的集合是 \(S\) (没选择过的数字集合 S)
def solveNQueens(self, n: int) -> List[List[str]]:
res = []
path = [0] * n
# 当前枚举到第 i 行,可以继续枚举的列号集合是 S
def valid(i, j):
for r in range(i):
c = path[r]
if r + c == i + j or r - c == i - j:
return False
return True
def dfs(i, S):
if i == n:
res.append(['.' * j + 'Q' + (n - j - 1) * '.' for j in path])
return
for j in S:
if valid(i, j):
path[i] = j
dfs(i + 1, S - {j})
dfs(0, set(range(n)))
return res
写法 2:回溯全排列 + 位运算 + 集合优化 \(O(1)\) 判断斜线方向
由于判断 \(i+j\) 和 \(i - j\) 是否在之前回溯中出现过需要 \(O(n)\) 的时间,实际上只需要用集合记录出现过的 \(i+j\) 和 \(i - j\) 即可。对于出现过 \(i+j\) 和 \(i-j\) 分别(防止相互干扰)放进集合 \(lu\) 和 \(ru\) 中(由于 位运算中 \(i-j\) 可能出现负值,所以存放的元素改成 \(i-j+10\))。
def solveNQueens(self, n: int) -> List[List[str]]:
res = []
path = [0] * n
lu = ru = 0
# 当前枚举到第 i 行,可以继续枚举的列号集合是 S
def dfs(i, S):
nonlocal lu, ru
if i == n:
res.append(['.' * j + 'Q' + (n - j - 1) * '.' for j in path])
return
for j in range(n):
if (S >> j) & 1 and (lu >> (i + j)) & 1 == 0 and (ru >> (i - j + 10)) & 1 == 0:
path[i] = j
lu, ru = lu | (1 << (i + j)), ru | (1 << (i - j + 10))
dfs(i + 1, S & ~(1 << j))
lu, ru = lu ^ (1 << (i + j)), ru ^ (1 << (i - j + 10))
dfs(0, (1 << n) - 1)
return res
回溯分割字符串¶
记录当前切割到的位置,枚举下一个切割位置,判断切割合法性。
def partition(self, s: str) -> List[List[str]]:
n, path, res = len(s), [], []
# 当前分割的位置,枚举下次分割位置
def dfs(i):
if i == n:
res.append(path.copy())
return
for j in range(i + 1, n + 1):
t = s[i: j]
if t == t[::-1]:
path.append(t)
dfs(j)
path.pop()
dfs(0)
return res
增加了字符串段数限制:恰好等于 4。时间复杂度:\(O(n \times C(n,3))\)
def restoreIpAddresses(self, s: str) -> List[str]:
n, path, res = len(s), [], []
# 记录当前分割位置,枚举下一个分割位置
def dfs(i):
if len(path) == 4:
if i == n:
res.append('.'.join(path))
return
for j in range(i + 1, n + 1):
t = s[i: j]
if t == '0' or '0' not in t[0] and int(t) <= 255:
path.append(t)
dfs(j)
path.pop()
dfs(0)
return res
2698. 求一个整数的惩罚数 - 力扣(LeetCode)
判断一个数,其平方是否可能划分成若干字符串,其各段对应数字之和等于本身。例如 \(36\times36=1296,1+29+6=36\)
def check(x):
sx = str(x * x)
n = len(sx)
def dfs(i, s):
if i == n: return s == x
t = 0
for j in range(i + 1, n + 1):
t = t * 10 + int(sx[j - 1])
if t + s <= x and dfs(j, s + t):
return True
return False
return dfs(0, 0)
分治¶
395. 至少有 K 个重复字符的最长子串 - 力扣(LeetCode)
\(f(s)\) 表示字符串 s 中所有字符不少于 \(k\) 个的子串中的最大长度。考察所有在当前 \(s\) 串中出现次数少于 \(k\) 的字符(记作分割字符),最终最大串肯定不包含之。因此,每一层递归,以找到的第一个分割字符作为隔板,将 \(s\) 分成 若干个小的子串,取其中最大值即可。
时间复杂度:\(O(26N)\),这是由于每一层递归必然完全删除一个小写字母,且每一层需要遍历整个字符串,时间复杂度是 \(O(N)\);所以总复杂度是 \(O(26N)\)。
def longestSubstring(self, s1: str, k: int) -> int:
# s1 中所有字符数量 >= k 个最长子串
def f(s1):
cnt = Counter(s1)
for ch, c in cnt.items():
if c < k:
return max(f(sub) for sub in s1.split(ch))
return len(s1)
return f(s1)
1763. 最长的美好子字符串 - 力扣(LeetCode)
\(f(s)\) 表示字符串 \(s\) 中所有字符出现大小写的最长子串。以 \(s1\) 中只出现大写 / 小写的字母作为分割点,将问题分治,返回最大长度中出现最早的字符串。时间复杂度:\(O(26 \times n)\),因为每一层需要 \(O(n)\) 的复杂度,每一层递归至少减少一个字符。
from collections import *
from string import ascii_lowercase, ascii_uppercase
L, U = ascii_lowercase, ascii_uppercase
def f(s1):
s = set(s1)
for l, u in zip(L, U):
if (l in s) != (u in s):
ss = s1.split(l if l in s else u)
res = ''
for sub in ss:
cur = f(sub)
if len(cur) > len(res): res = cur
return res
return s1
class Solution:
def longestNiceSubstring(self, s: str) -> str:
return f(s)
排序¶
自定义排序¶
假设一个数据类型有身高 height、分数 score 和年龄 age 三个字段,现在的排序需求是:分数越高越靠前、若分数相同则年龄越小越靠前。
from typing import Self
class Item:
def __init__(self, height: int, score: int, age: int):
self.height = height
self.score = score
self.age = age
def __lt__(self, other: Self) -> bool:
if self.score == other.score:
return self.age < other.age
return self.score > other.score
a = [Item(180, 90, 21), Item(175, 92, 24), Item(185, 90, 22)]
a.sort()
for item in a:
print(f"height: {item.height}, "
f"score: {item.score}, "
f"age: {item.age}")
""" 输出
height: 175, score: 92, age: 24
height: 180, score: 90, age: 21
height: 185, score: 90, age: 22
"""
from functools import cmp_to_key
class Item:
def __init__(self, height: int, score: int, age: int):
self.height = height
self.score = score
self.age = age
def my_cmp(x: Item, y: Item) -> int:
"""返回负数或 0 表示 (x, y),返回正数表示 (y, x)"""
if x.score == y.score:
return x.age - y.age
return y.score - x.score
a = [Item(180, 90, 21), Item(175, 92, 24), Item(185, 90, 22)]
a.sort(key=cmp_to_key(my_cmp))
for item in a:
print(f"height: {item.height}, "
f"score: {item.score}, "
f"age: {item.age}")
""" 输出
height: 175, score: 92, age: 24
height: 180, score: 90, age: 21
height: 185, score: 90, age: 22
"""
class Item:
def __init__(self, height: int, score: int, age: int):
self.height = height
self.score = score
self.age = age
a = [Item(180, 90, 21), Item(175, 92, 24), Item(185, 90, 22)]
a.sort(key=lambda x: (-x.score, x.age))
for item in a:
print(f"height: {item.height}, "
f"score: {item.score}, "
f"age: {item.age}")
""" 输出
height: 175, score: 92, age: 24
height: 180, score: 90, age: 21
height: 185, score: 90, age: 22
"""
次最值问题¶
1289. 下降路径最小和 II - 力扣(LeetCode)
其中需要维护上一层的最小值、次小值及其对应的坐标。
mn = mn_2 = (inf, -1)
if y < mn[0]:
mn_2, mn = mn, (y, j)
elif y == mn[0]:
mn_2 = (y, j)
elif y < mn_2[0]:
mn_2 = (y, j)
计数排序¶
带修求第 \(k\) 小 / 大的数
使用哈希表维护每个数值出现次数,适用于数值的值域较小的情况。
例如:当 \(nums[i] \in [a,~b]\),可以在 \(O(b-a)\) 的时间复杂度内,找出第 \(k\) 大的数(有时表现会比 SortedList 更好)。
def get_min_k(cnt, k):
cur = 0
for x in range(a, b):
if cnt[x] == 0: continue
cur += cnt[x]
if cur >= k: return x
return b
2653. 滑动子数组的美丽值 - 力扣(LeetCode)
定长滑动窗口 + 哈希维护计数 + 计数排序。时间复杂度:\(O(nU),U\) 是值域极差。
def getSubarrayBeauty(self, nums: List[int], k: int, x: int) -> List[int]:
n = len(nums)
cnt = Counter(nums[:k])
def get_min_k():
cur = 0
for y in range(-50, 0):
cur += cnt[y]
if cur >= x: return y
return 0
res = [get_min_k()]
for r in range(k, n):
nl, nr = nums[r - k], nums[r]
cnt[nr] += 1
cnt[nl] -= 1
if cnt[nl] == 0: cnt.pop(nl)
res.append(get_min_k())
return res
位运算¶
位运算与集合论¶
集合 \(A, B\),最大二进制长度为 \(N\)
| 操作 | 位运算 |
|---|---|
| 全集 | $(1 $ << \(N)-1\) |
| 补集 | \(\sim A\) |
| 添加元素 | \(A \mid (1\) << \(i)\) |
| 删除元素 | \(A \& \sim(1\) << \(i)\) |
| 删除元素(一定在集合中) | \(A \oplus(1\) << \(i)\) |
| 属于 / 不属于 | \((A\) >> $i) \& 1 = 1 / 0 $ |
| 删除最小元素 | \(A\&(A-1)\) |
| 差集 | \(A\& \sim B\) |
| 差集(子集) / 对称差 | \(A\oplus B\) |
| \(A\) 包含于 \(B\) | \(A\&B=A\) |
(1). 把 b 位置为 1
通过 或 实现
(2). 把 b 位置清零
通过 与非 实现
(3). 获得一个数从高到低的每一位的值
1261. 在受污染的二叉树中查找元素 - 力扣(LeetCode)
class FindElements:
def __init__(self, root: Optional[TreeNode]):
self.root = root
def find(self, target: int) -> bool:
target += 1
node = self.root
for b in range(target.bit_length() - 2, -1, -1):
x = (target >> b) & 1
node = node.right if x else node.left
if not node: return False
return True
二维矩阵压缩为一维二进制串
满足 \(num >> x == s[i]\)
s = ["#", ".", ".", "#", ".", "#"]
num = sum((ch == '.') << i for i, ch in enumerate(s)) # 010110
print(bin(num)) # 0b 010110
预处理所有子集的和
时间复杂度:\(O(n\cdot 2^n)\)
sum_ = defaultdict(int)
for i, x in enumerate(nums):
for s in range(1 << i):
sum_[(1 << i) | s] = sum_[s] + x
从大到小枚举一个 \(s\) 的所有非空子集
暴力做法是从 \(s\) 出发,不断减 1。但是中途需要规避不是 \(s\) 子集的情况,相当于做 ”压缩版“ 的二进制减法:普通的二进制减法会把最低位的 1 变成 0,同时 1 右边的 0 变成 1(例如 \(101000 \rightarrow1 00111\));”压缩版“ 的二进制减法只保留原集合中的的、右边的 1,其余仍然是 0。(例如 \(101000 \rightarrow 100101,假设 s = 111101)\)。保留的方法,就是 \(\& s\)。
Gosper's Hack:枚举大小恰好为 \(k\) 的子集
例如当前为 \(0100110\),下一个大小仍然为 \(3\) 的集合是 \(0101001\)(\(left = 0101000\),即 \(sub + lowbit(sub)\);\(right =000001\),即 \(left \oplus sub=0001111\),$right =left \oplus sub $ >> \(2 / lowbit(sub)\)。
时间复杂度:\(O(n \cdot C(n,k))\),实际上优化不大
s = (1 << n) - 1
sub = (1 << k) - 1
def next_sub(x):
lb = x & -x
left = x + lb
right = ((left ^ x) >> 2) // lb
return left | right
while sub <= s:
# 处理 sub 逻辑
sub = next_sub(sub)
判断是否有两个连续(相邻)的 1
十进制长度
二进制长度
二进制中 1 的数量
十进制 int 转换 对应二进制的 int
十进制转 \(-2\) 进制
def baseNeg2(self, n: int) -> str:
if n == 0: return '0'
res = []
while n:
if n & 1: x = 1
else: x = 0
n = (n - x) // -2
res.append(str(x))
return ''.join(res[::-1])
最大异或
def findMaximumXOR(self, nums: List[int]) -> int:
n = max(nums).bit_length()
res = mask = 0
for i in range(n - 1, -1, -1):
mask |= 1 << i
s, tmp = set(), res | (1 << i)
for x in nums: # x ^ a = tmp -> a = tmp ^ x
x &= mask
if tmp ^ x in s:
res = tmp
break
s.add(x)
return res
拆位试填法¶
当发现题目要求所有元素按位运算得到的 最值 问题时,从高位开始考虑是否能为 1/0 。
考虑过的状态记录在 res 中,不考虑的位用 mask 置为 0 表示。
mask = res = 0
for b in range(n, -1, -1):
mask |= 1 << b # 蒙版
for x in nums:
x &= mask
# 最大值 ...
res |= 1 << b # 得到最大值
mask &= ~(1 << b) # 该位自由,不用考虑
mask = res = 0
for b in range(n, -1, -1):
mask |= 1 << b
ans_res = -1 # 初始值全是 1
cnt = 0
for x in nums:
ans_res &= x & mask
if ans_res > 0:
cnt += 1
else:
ans_res = -1 # 重置初始值
if cnt > k: # 说明这一位必然是 1
# mask 这位蒙版就应置为 0,表示后续都不考虑这位
mask &= ~(1 << b)
res |= 1 << b
return res
滑动窗口¶
定长滑动窗口¶
维护定长滑动窗口和
1343. 大小为 K 且平均值大于等于阈值的子数组数目 - 力扣(LeetCode)
def numOfSubarrays(self, nums: List[int], k: int, t: int) -> int:
n = len(nums)
s = sum(nums[ :k])
res = 1 if s / k >= t else 0
for r in range(k, n):
s = s + nums[r] - nums[r - k]
if s / k >= t: res += 1
return res
2134. 最少交换次数来组合所有的 1 II - 力扣(LeetCode)
将环形数组中所有 1 聚集到一起的最小交换位置次数,即长度为 \(k\) (字符串中 1 的个数)的滑动窗口中,0 个数最少的位置处。
def minSwaps(self, nums: List[int]) -> int:
k = nums.count(1)
if k == 0: return 0
nums = nums + nums
n = len(nums)
s = sum(nums[:k])
res = k - s
for r in range(k, n):
s = s + nums[r] - nums[r - k]
res = min(res, k - s)
return res
维护定长滑动窗口 + 字典计数
判断字符串 \(s1\) 中是否包含 \(s2\) 的一个排列:转化为判断是否存在两个字典计数情况相等。
def checkInclusion(self, s1: str, s2: str) -> bool:
k, n = len(s1), len(s2)
target = Counter(s1)
cur = Counter(s2[ :k])
if target == cur: return True
for r in range(k, n):
rch, lch = s2[r], s2[r - k]
cur[rch] += 1
cur[lch] -= 1
if target == cur: return True
return False
438. 找到字符串中所有字母异位词 - 力扣(LeetCode)
找出字符串 \(s1\) 中所有 \(s2\) 的排列的下标。
def findAnagrams(self, s: str, p: str) -> List[int]:
res = []
k, n = len(p), len(s)
target = Counter(p)
cur = Counter(s[ :k])
if cur == target: res.append(0)
for r in range(k, n):
lch, rch = s[r - k], s[r]
cur[rch] += 1
cur[lch] -= 1
if cur == target:
res.append(r - k + 1)
return res
2841. 几乎唯一子数组的最大和 - 力扣(LeetCode)
通过 \(Counter()\) 维护滑动窗口中,不同元素的个数。(即 \(len(set(win))\)
def maxSum(self, nums: List[int], m: int, k: int) -> int:
s = sum(nums[ :k])
cnt = Counter(nums[ :k])
res = s if len(cnt) >= m else 0
n = len(nums)
for r in range(k, n):
nl, nr = nums[r - k], nums[r]
s = s + nr - nl
cnt[nr] += 1
cnt[nl] -= 1
if cnt[nl] == 0: cnt.pop(nl)
if len(cnt) >= m and s > res: res = s
return res
2009. 使数组连续的最少操作数 - 力扣(LeetCode)
定长滑动窗口 + 正难则反:需要操作最少次数 = n - 能够不操作的最多的数字。这些数字显然是不重复的,所以首先去重。对于去重完的元素,每一个左边界 $ nums[left]\(, 在去重数组中 ,\)[nums[left] ,~ nums[left] + n - 1]$ 区间在数组中出现的次数即为当前可以保留的数字的个数。
def minOperations(self, nums: List[int]) -> int:
n = len(nums)
nums = sorted(set(nums))
res = left = 0
for i, x in enumerate(nums):
while x > nums[left] + n - 1:
left += 1
res = max(res, i - left + 1)
return n - res
定长滑动窗口 + 正难则反:要求前 + 后 的个数为定值 \(k\) 个,转换为中间为 \(n-k\) 个。
def maxScore(self, nums: List[int], k: int) -> int:
n, tot = len(nums), sum(nums)
if n == k: return tot
k = n - k
s = sum(nums[:k])
res = s
for r in range(k, n):
s = s + nums[r] - nums[r - k]
res = min(res, s)
return tot - res
不定长滑动窗口¶
和大于等于 \(k\) 的最短数组(最短长度)
def minSubArrayLen(self, k: int, nums: List[int]) -> int:
l = s = 0
res = inf
for r, x in enumerate(nums):
s += x
while s >= k:
res = min(res, r - l + 1)
s, l = s - nums[l], l + 1
return res if res < inf else 0
2904. 最短且字典序最小的美丽子字符串 - 力扣(LeetCode)
求包含恰好 \(k\) 个 1 的所有子串中,最短、字典序最小的:转换为 和大于等于 \(k\) 的最短数组问题。
def shortestBeautifulSubstring(self, s1: str, k: int) -> str:
if s1.count('1') < k: return ''
n = len(s1)
l = s = 0
resl, resr = 0, n
for r, ch in enumerate(s1):
s += int(ch)
while s >= k:
width = r - l + 1
if width < resr - resl + 1 or (width == resr - resl + 1 and s1[l: r + 1] < s1[resl: resr + 1]):
resl, resr = l, r
s, l = s - int(s1[l]), l + 1
return s1[resl: resr + 1]
和小于等于 \(k\) 的最长子数组(最大长度)
1493. 删掉一个元素以后全为 1 的最长子数组 - 力扣(LeetCode)
\(s\) 统计 0 的个数,转换为 $s \le 1 $ 的最长子数组长度。
def longestSubarray(self, nums: List[int]) -> int:
res = 0
l = s = 0
# s <= 1 的最长子数组
for r, x in enumerate(nums):
s += 1 if x == 0 else 0
while s > 1:
s -= 1 if nums[l] == 0 else 0
l += 1
if s <= 1:
res = max(res, r - l)
return res
2730. 找到最长的半重复子字符串 - 力扣(LeetCode)
预处理相邻字符相等情况,转换为 \(s \le 1\) 的最长字串。
def longestSemiRepetitiveSubstring(self, s1: str) -> int:
l = s = 0
res = 0
nums = [1 if s1[i] == s1[i - 1] else 0 for i in range(1, len(s1))]
# s <= 1 的最长子串
for r, x in enumerate(nums):
s += x
while s > 1:
s -= nums[l]
l += 1
res = max(res, r - l + 1)
return res + 1
不包含重复元素的子数组(最长长度 / 最大和)
包含重复元素的条件是 \(len(win) < r - l+1\)。
方法一:使用 \(d\) 字典维护元素出现的最新位置。当 \(d[ch]\) 出现在 \([l,~r]\) 区间内表示需要更新 \(l=d[ch]+1\)。
def lengthOfLongestSubstring(self, s: str) -> int:
d = defaultdict(lambda: -inf)
l = res = 0
for r, ch in enumerate(s):
if d[ch] >= l:
l = d.pop(ch) + 1
d[ch] = r
res = max(res, r - l + 1)
return res
方法二:
不包含重复元素的条件是 len(win) = r - l + 1。
def lengthOfLongestSubstring(self, s: str) -> int:
res = 0
l = 0
cnt = Counter()
for r, ch in enumerate(s):
cnt[ch] += 1
while len(cnt) < r - l + 1:
lch = s[l]
cnt[lch] -= 1
if cnt[lch] == 0: cnt.pop(lch)
l += 1
res = max(res, r - l + 1)
return res
1695. 删除子数组的最大得分 - 力扣(LeetCode)
不包含重复元素的子数组的最大和。不包含重复元素的条件是 \(len(win) = r - l+1\)。
def maximumUniqueSubarray(self, nums: List[int]) -> int:
res = 0
l = s = 0
cnt = Counter()
for r, x in enumerate(nums):
cnt[x] += 1
s += x
while len(cnt) < r - l + 1:
nl = nums[l]
s -= nl
cnt[nl] -= 1
if cnt[nl] == 0: cnt.pop(nl)
l += 1
res = max(res, s)
return res
不定长滑窗 + 位运算。
子数组中所有数两两 \(AND\) 都是 \(0\),要求子数组所有比特位上是 \(1\) 的个数最多是 \(1\) 。用 \(ors\) 维护滑窗中的 \(OR\) 值, \(r\) 能加入子数组需要满足 \(r~ \& ~ors = 0\)。(因为 $a | (b ~\&~ c) = (a \&~b) | (a \&~ c)=0 $,即与子数组中所有数相与都是 0)。而从子数组中删去 \(left\) 只需要做异或操作,因为已经确保了 \(ors\) 比特位上面每个 \(1\) 都是独属的。
def longestNiceSubarray(self, nums: List[int]) -> int:
# nums [i] <= 1e9,30 个二进制 bit 位上
ors = 0
l = 0
res = 0
for r, x in enumerate(nums):
while x & ors > 0:
ors ^= nums[l]
l += 1
ors |= x
res = max(res, r - l + 1)
return res
子数组合法方案数问题
先更新滑动窗口状态(广义上),检查、剔除不合法的元素(如窗口左边界右移、计数器减一等),累积上 \(res\) 的方案数(等于窗口长度)。
乘积小于 \(k\) 的子数组的个数(方案数)
713. 乘积小于 K 的子数组 - 力扣(LeetCode)
def numSubarrayProductLessThanK(self, nums: List[int], k: int) -> int:
if k <= 1: return 0
res = 0
l, prod = 0, 1
for r, x in enumerate(nums):
prod *= x
while prod >= k:
prod, l = prod / nums[l], l + 1
res += r - l + 1
return res
不同值个数小于等于 \(k\) 的子数组个数(方案数)
使用 \(freq\) 计数器维护元素出现个数、左边界 \(l\) 。当
3134. 找出唯一性数组的中位数 - 力扣(LeetCode)
# 计算不同值元素个数小于对于 k 的子数组个数
def get_set_subarrays_lower_k(nums, k):
l = res = 0
freq = Counter()
for r, x in enumerate(nums):
freq[x] += 1
while len(freq) > k:
freq[nums[l]] -= 1
if freq[nums[l]] == 0: freq.pop(nums[l])
l += 1
res += r - l + 1
return res
不定长滑窗 + 哈希计数¶
不定长滑窗哈希表:所有 \(freq[x] \le k\) 的最长子数组,由于单次判断的时间复杂度是 \(O(n)\),总复杂度是 $O(n^2) $。 通过 \(cnt\) 维护 \(freq[x] > k\) 的个数,在边界处增减 1。转换为 \(cnt = 0\) 最长子数组 ,复杂度 $O(n) $
2958. 最多 K 个重复元素的最长子数组 - 力扣(LeetCode)
def maxSubarrayLength(self, nums: List[int], k: int) -> int:
# 所有 freq [x] <= k 的最长子数组
# 转换为 cnt == 0 最长子数组 (cnt 为 freq [x] > k 的个数)
freq = Counter()
l = res = cnt = 0
for r, x in enumerate(nums):
cnt += int(freq[x] == k)
freq[x] += 1
while cnt > 0:
nl = nums[l]
cnt -= int(freq[nl] == k + 1)
freq[nl] -= 1
l += 1
res = max(res, r - l + 1)
return res
不定长滑窗 + 哈希计数。首先对 \(nums\) 数组排序 + 去重,对于 \(m=1\) 的情况直接返回 \(0\) ;否则要求 $2 \sim m $ 中的每一个数,都在滑动窗口中的所有因子构成的集合中出现一次。因为需要维护滑窗集合,所以采用滑窗 + 哈希, \(freq\) 维护滑窗中所有数的因子的出现个数。由于需要 $2 \sim m $ 中均有 \(freq[x] \ge 1\) ,每次判断需要 \(O(n)\) 的复杂度,总体是 \(O(mn)\) ,无法接受。因此使用 \(cnt\) 统计 \(freq[x] \ge 1\) 的个数,$cnt $ 在边界情况时增减 1。
from collections import *
from math import *
import sys
input = lambda: sys.stdin.readline().strip()
# 预处理 2 ~ mx 中所有数的约数的列表
mx = 100001
factors = [[] for _ in range(mx)]
for i in range(2, mx):
for j in range(i, mx, i):
factors[j].append(i)
t = int(input())
def solve():
n, m = map(int, input().split())
nums = list(set(map(int, input().split())))
nums.sort()
if m == 1:
return 0
# 2, 3, ... , m
# 求所有滑窗中,freq [2...m] >= 1 的最小极差。
# 用 cnt 统计 freq [x] >= 1 的个数,求 cnt == m - 1 的滑窗的最小极差
freq = Counter()
l = cnt = 0
res = inf
for r, x in enumerate(nums):
for y in factors[x]:
if y > m: break
if freq[y] == 0: cnt += 1
freq[y] += 1
while cnt == m - 1:
nl = nums[l]
res = min(res, x - nl)
for y in factors[nl]:
if y > m: break
if freq[y] == 1: cnt -= 1
freq[y] -= 1
l += 1
return res if res < inf else -1
for _ in range(t):
print(solve())
枚举型滑窗 / 分组滑窗¶
枚举出现种类个数
枚举 + 不定长滑窗 + 哈希计数
显然 \(k=1\) 时,原串满足条件直接返回其长度。
考虑 \(k>1\),要求子串中所有字符个数 \(\ge k\),假设 \(s1[l\sim r]\) 满足条件,考虑 \(s1[r+1]\),如果是在 \(s1\) 中出现过的字符,一定满足;否则一定不满足条件,此时不确定滑窗的左边界 \(l\) 是否需要右移,例如 \('bbaaacb'\),当 \(r=4\) 时,由于后续的 \(c\) 的出现会导致 无法凑出 \(bbaaab\),因此需要左移。这实际上比较难以确定。
然而,如果限制 / 约束了子串中不同字母的种类个数 \(c\),并依次枚举 \(c \in [1, 26]\),对于每个 \(c\),采用不定长滑窗 + 哈希计数的方式,可以唯一确定是否需要左移左边界。实现时,需要维护子串中每个字符出现的频次 \(freq\)、出现且次数不够 \(k\) 的字符的个数 \(cnt\)、字符的种类个数 \(tcnt\)。
def longestSubstring(self, s1: str, k: int) -> int:
if k == 1: return len(s1)
n = len(s1)
res = 0
# 枚举滑窗
for c in range(1, len(set(s1)) + 1):
# 滑窗中字母种类个数恰好为 c
freq = Counter()
cnt = 0 # 哈希计数
tcnt = 0 # 种类计数
l = 0
for r, ch in enumerate(s1):
if freq[ch] == 0:
cnt += 1
tcnt += 1
if freq[ch] == k - 1:
cnt -= 1
freq[ch] += 1
while tcnt > c:
lch = s1[l]
if freq[lch] == k:
cnt += 1
if freq[lch] == 1:
tcnt -= 1
cnt -= 1
freq[lch] -= 1
l += 1
if tcnt == c and cnt == 0:
res = max(res, r - l + 1)
return res
时间复杂度:\(O(26N)\)
定长滑窗 + 枚举
def countCompleteSubstrings(self, s1: str, k: int) -> int:
res = 0
# s 中每个字符恰好出现 k 次的子串个数
def f(s):
#
# 枚举字符种类个数
m = len(set(s))
n = len(s)
res = 0
# 由于是恰好 k 个,所以滑窗长度为 c * k
for c in range(1, m + 1):
K = c * k
if K > n: break
freq = Counter()
cnt = 0 # 等于 k 的个数
for i in range(K):
ch = s[i]
if freq[ch] == k - 1: cnt += 1
elif freq[ch] == k: cnt -= 1
freq[ch] += 1
res += int(cnt == c)
for r in range(K, n):
ch, lch = s[r], s[r - K]
if lch != ch:
if freq[ch] == k - 1: cnt += 1
elif freq[ch] == k: cnt -= 1
if freq[lch] == k: cnt -= 1
elif freq[lch] == k + 1: cnt += 1
freq[ch] += 1
freq[lch] -= 1
if freq[lch] == 0: freq.pop(lch)
res += int(cnt == c)
return res
sub = ''
for ch in s1:
if sub and abs(ord(ch) - ord(sub[-1])) > 2:
res += f(sub)
sub = ch
else:
sub += ch
res += f(sub)
return res
时间复杂度:\(O(26N)\)
枚举多起点
分组滑窗
2831. 找出最长等值子数组 - 力扣(LeetCode)
按元素对下标分组 + 不定长滑窗
预处理每个元素所有下标构成的数组,对某一元素 \(x\),其数组 \(a\) 上做不定长滑窗,\(a[r] - a[l] + 1\) 为滑窗长度,\(r-l+1\) 为 \(x\) 个数。因为最多能删除 \(k\) 个,所以要求 $ a[r] - a[l] + 1 - (r - l + 1) \le k$。
def longestEqualSubarray(self, nums: List[int], k: int) -> int:
d = defaultdict(list)
for i, x in enumerate(nums):
d[x].append(i)
res = 1
for a in d.values():
m = len(a)
l = 0
for r in range(m):
while a[r] - a[l] + 1 - (r - l + 1) > k:
l += 1
res = max(res, r - l + 1)
return res
双指针¶
相向双指针
def minimumRefill(self, nums: List[int], A: int, B: int) -> int:
n = len(nums)
i, j = 0, n - 1
a, b = A, B
res = 0
while i <= j:
l, r = nums[i], nums[j]
if i < j:
if a < l:
a = A - l
res += 1
else:
a -= l
if b < r:
b = B - r
res += 1
else:
b -= r
else:
x = max(a, b)
if x < l:
res += 1
i, j = i + 1, j - 1
return res
同向双指针
1574. 删除最短的子数组使剩余数组有序 - 力扣(LeetCode)
同向双指针 + 贪心。
def findLengthOfShortestSubarray(self, nums: List[int]) -> int:
n = len(nums)
l, r = 0, n - 1
while l + 1 < n and nums[l] <= nums[l + 1]:
l += 1
if l == n - 1: return 0
while r - 1 >= 0 and nums[r - 1] <= nums[r]:
r -= 1
res, mxl = r, l
for l in range(mxl + 1):
while r < n and nums[r] < nums[l]:
r += 1
res = min(res, r - l - 1)
return res
1163. 按字典序排在最后的子串 - 力扣(LeetCode)
转换问题:子串中字典序最大的串
性质 1: 后缀 \(s[i:]\) 是以 \(s[i]\) 为开头的子串中,字典序最大的串。
性质 2:考察字典序最大的串,记为 \(sub\),它的每个字符都小于等于 \(sub[0]\)。
例如,'ycyba' 是一个可能的最大串,而 'yczba'就不是,因为 'zba' > 'yczba'。
对于两个位置 \(i,j\),用偏移 \(k\) 考察 $s[i:] $ 和 \(s[j:]\) 中各个字符的大小关系:
-
\(s[i+k] = s[j+k]\),则继续往后考察,\(k+1\);
-
\(s[i+k] < s[j+k]\),应让 \(i\) 跳到 \(\max (i+k+1, j)\)。
- 首先可以肯定 \([i+1, i+k]\) 都不需要考虑,这是因为由于性质 1,2,所以以 \([i+1, i+k]\) 开头的子串一定比 \(s[i \sim i+k]\) 小;
- 其次,如果 \(i+k+1 \le j\),则 \(j\) 应该曾经已经考虑过 \([i+k+1, j - 1]\) 区间了,他们都没能更新 \(i\) ,一定比 \(s[i:i+k]\) 小
- 最后,如果 \(i+k+1>j\) ,例如 'yybbyybbyyc' 这种情况,\(s[i:i+k] = yybbyyb\),\(s[j:j+k] = yybbyyc\),有公共前缀 \(yybbyy\) ,以及后缀 \(yyc\),即 \([j, i + k]\) 区间一定回比后缀小。
所以, \(i = \max(i+k+1, j)\), $j = i + 1, k = 0 $。
-
\(s[i+k] > s[j+k]\),让 \(j\) 跳到 \(j+k+1, k = 0\)。
class Solution:
def lastSubstring(self, s: str) -> str:
i, j, k = 0, 1, 0
n = len(s)
while j + k < n:
if s[i + k] == s[j + k]:
k += 1
else:
if s[i + k] < s[j + k]:
i = j if j > i + k + 1 else i + k + 1
j = i + 1
else: # s [i + k] > s [j + k]:
j += k + 1
k = 0
return s[i: ]
分组循环¶
适用于:数组会被分割成若干组,且每一组的判断/处理逻辑是一样的。
核心思想:
- 外层循环负责遍历组之前的准备工作(记录开始位置),和遍历组之后的统计工作(更新答案最大值)。
- 内层循环负责遍历组,找出这一组最远在哪结束。
模板:
n = len(nums)
i = 0
while i < n:
start = i
while i < n and ...:
i += 1
# 从 start 到 i-1 是一组
# 下一组从 i 开始,无需 i += 1
ans = max(ans, i - start)
时间复杂度:\(O(n)\)
def longestAlternatingSubarray(self, nums: List[int], k: int) -> int:
n = len(nums)
i = res = 0
while i < n:
x = nums[i]
if x > k or x & 1:
i += 1
continue
start = i
i += 1
while i < n and nums[i] <= k and nums[i - 1] & 1 != nums[i] & 1:
i += 1
res = max(res, i - start)
return res
Tip
基础算法到此结束,接下来开始数据结构与高级算法的学习。
数据结构¶
双链表¶
class myList:
# 初始化链表,地址从 0 开始赋
def __init__(self, n: int) -> None:
self.val = [0] * (n + 10)
self.left = [0] * (n + 10)
self.right = [0] * (n + 10)
self.idx = 2
self.right[0] = 1
self.left[1] = 0
# 头插入
def push_front(self, x: int):
self.insert_right(0, x)
# 尾插入
def push_back(self, x: int):
self.insert_left(1, x)
# 在第 k 个插入的数左侧插入一个数
def insert_left(self, k: int, x: int):
self.insert_right(self.left[k], x)
# 在第 k 个插入的数右侧插入一个数
def insert_right(self, k: int, x: int):
self.val[self.idx] = x
self.right[self.idx] = self.right[k]
self.left[self.right[k]] = self.idx
self.left[self.idx] = k
self.right[k] = self.idx
self.idx += 1
# 删除第 k 个插入的数
def erase(self, k: int):
self.left[self.right[k]] = self.left[k]
self.right[self.left[k]] = self.right[k]
# 输出整个链表
def output(self) -> None:
i = self.right[0]
while i != 1:
print(self.val[i], end=' ')
i = self.right[i]
单调栈¶
def trap(self, height: List[int]) -> int:
# 单调栈:递减栈
stk, n, res = deque(), len(height), 0
for i in range(n):
# 1.单调栈不为空、且违反单调性
while stk and height[i] > height[stk[-1]]:
# 2.出栈
top = stk.pop()
# 3.特判
if not stk:
break
# 4.获得左边界、宽度
left = stk[-1]
width = i - left - 1
# 5.计算
res += (min(height[left], height[i]) - height[top]) * width
# 6.入栈
stk.append(i)
return res
矩形面积求解:维护单调增栈,同时首尾插入哨兵节点。
def largestRectangleArea(self, heights: List[int]) -> int:
heights.append(-1)
stk = [-1]
res = 0
for i, h in enumerate(heights):
while len(stk) > 1 and h < heights[stk[-1]]:
cur = stk.pop()
l = stk[-1]
width = i - l - 1
s = width * heights[cur]
res = max(res, s)
stk.append(i)
return res
1793. 好子数组的最大分数 - 力扣(LeetCode)
矩形面积求解问题变形:求 \(min(nums[i], ~ \cdots~, nums[j]) \times (j -i+1)\) ,并对 $i, ~j $ 做了范围约束。
def maximumScore(self, nums: List[int], k: int) -> int:
stk = [-1]
nums.append(-1)
res = 0
for i, h in enumerate(nums):
while len(stk) > 1 and h < nums[stk[-1]]:
cur = stk.pop()
l = stk[-1]
if not(l + 1 <= k and i - 1 >= k): continue # 约束范围
width = i - l - 1
res = max(res, width * nums[cur])
stk.append(i)
return res
单调栈维护元素的左右山形边界
对于 \(a[i]=x\) ,希望找到在左侧中的最小 \(l\) 满足 \(a[l + 1] \sim a[i-1] \ge 或者 \le x\);
对于 \(a[i]=x\) ,希望找到在右侧中的最大 \(r\) 满足 \(a[i + 1] \sim a[r-1] \ge 或者 \le x\);
stk, left = [], [-1] * n
for i in range(n):
x = nums[i]
while stk and x <= nums[stk[-1]]: stk.pop()
if stk: left[i] = stk[-1]
stk.append(i)
stk, right = [], [n] * n
for i in range(n - 1, -1, -1):
x = nums[i]
while stk and x <= nums[stk[-1]]: stk.pop()
if stk: right[i] = stk[-1]
stk.append(i)
2334. 元素值大于变化阈值的子数组 - 力扣(LeetCode)
在山形边界开区间所夹的区间内,满足所有元素大于等于山形边界元素 \(x\) ,即 $x = \min(a[l+1] \sim a[r-1]) $。
def validSubarraySize(self, nums: List[int], threshold: int) -> int:
n = len(nums)
# 单调栈解法
stk, left = [], [-1] * n
for i in range(n):
x = nums[i]
while stk and x <= nums[stk[-1]]: stk.pop()
if stk: left[i] = stk[-1]
stk.append(i)
stk, right = [], [n] * n
for i in range(n - 1, -1, -1):
x = nums[i]
while stk and x <= nums[stk[-1]]: stk.pop()
if stk: right[i] = stk[-1]
stk.append(i)
for i, x in enumerate(nums):
l, r = left[i], right[i]
k = r - l - 1
if x > (threshold / k): return k
return -1
单调栈优化 DP¶
2617. 网格图中最少访问的格子数 - 力扣(LeetCode)
暴力 dp 转移做法
class Solution:
def minimumVisitedCells(self, grid: List[List[int]]) -> int:
m, n = len(grid), len(grid[0])
f = [[inf] * n for _ in range(m)]
f[-1][-1] = 0
for i in range(m - 1, -1, -1):
for j in range(n - 1, -1, -1):
g = grid[i][j]
for k in range(1, min(g + 1, m - i)):
f[i][j] = min(f[i][j], f[i + k][j] + 1)
for k in range(1, min(g + 1, n - j)):
f[i][j] = min(f[i][j], f[i][j + k] + 1)
return f[0][0] + 1 if f[0][0] != inf else -1
单调栈 + 二分优化 DP
倒序枚举 \(i,~j\)
可以发现左边界 \(i\) 是递减的,右边界 $ j +g$ 是不确定的。联想到滑动窗口最值问题,维护一个向左增长的栈,栈元素自左向右递减。
由于栈中元素有序,每次查找只需要二分即可找出最值。
def minimumVisitedCells(self, grid: List[List[int]]) -> int:
m, n = len(grid), len(grid[0])
stkyy = [deque() for _ in range(n)] # 列上单调栈
f = 0 # 行上单调栈
for i in range(m - 1, -1, -1):
stkx = deque()
for j in range(n - 1, -1, -1):
g, stky = grid[i][j], stkyy[j]
f = 1 if i == m - 1 and j == n - 1 else inf
if g > 0:
if stkx and j + g >= stkx[0][1]:
mnj = bisect_left(stkx, j + g + 1, key = lambda x: x[1]) - 1
f = stkx[mnj][0] + 1
if stky and i + g >= stky[0][1]:
mni = bisect_left(stky, i + g + 1, key = lambda x: x[1]) - 1
f = min(f, stky[mni][0] + 1)
if f < inf:
while stkx and f <= stkx[0][0]:
stkx.popleft()
stkx.appendleft((f, j))
while stky and f <= stky[0][0]:
stky.popleft()
stky.appendleft((f, i))
return f if f != inf else -1
二分单调栈
class Solution:
def processTasks(self, tasks: List[List[int]]) -> int:
stk = [(-1, -1, 0)]
tasks.sort(key = lambda x: x[1])
for l, r, t in tasks:
p = bisect_left(stk, (l, )) - 1
blue = stk[-1][2] - stk[p][2]
red = max(0, stk[p][1] - l + 1)
t -= blue + red
if t <= 0: continue
nl, nr, nt = r - t + 1, r, stk[-1][2] + t
while stk:
ll, rr, _ = stk[-1]
if nl > rr: break
nl = ll - (rr - nl + 1)
stk.pop()
stk.append((nl, nr, nt))
return stk[-1][2]
单调队列¶
class MonotonicQueue:
def __init__(self, min_queue: bool = True):
self.q = deque()
if min_queue:
# 队头为最小值:不严格单调递增队列
self.compare = lambda a, b: a < b
else:
# 队头为最大值:不严格单调递减队列
self.compare = lambda a, b: a > b
def push_back(self, x):
while len(self.q) and self.compare(x, self.q[-1]):
self.q.pop()
self.q.append(x)
def pop_left(self, x):
if len(self.q) and x == self.q[0]:
self.q.popleft()
def get_extreme_value(self):
return self.q[0]
- 滑窗最大值:维护递减队列;
- 滑窗最小值:维护递增队列。
def maxSlidingWindow(self, nums: List[int], k: int) -> List[int]:
n = len(nums)
res = []
q = deque()
for i, x in enumerate(nums):
# 1.入,需要维护单调减队列的有序性
while q and x >= nums[q[-1]]:
q.pop()
q.append(i)
# 2.出,当滑动窗口区间长度大于 k 的时候,弹出去左端的
if i - q[0] + 1 > k:
q.popleft()
# 记录元素
if i >= k - 1:
res.append(nums[q[0]])
return res
2398. 预算内的最多机器人数目 - 力扣(LeetCode)
def maximumRobots(self, chargeTimes: List[int], runningCosts: List[int], budget: int) -> int:
n = len(chargeTimes)
res = 0
s = l = 0 # 滑窗的和 / 窗口左边界
q = deque() # 单调队列维护最大值
# 滑动窗口
for i, x in enumerate(chargeTimes):
while q and x >= chargeTimes[q[-1]]:
q.pop()
q.append(i)
s += runningCosts[i]
while i - l + 1 > 0 and s * (i - l + 1) + chargeTimes[q[0]] > budget:
s -= runningCosts[l]
l += 1
if l > q[0]:
q.popleft()
res = max(res, i - l + 1)
return res
单调队列优化 DP¶
2944. 购买水果需要的最少金币数 - 力扣(LeetCode)
暴力做法:\(O(n^2)\)
def minimumCoins(self, prices: List[int]) -> int:
n = len(prices)
# f [i] 表示获得 i 及其以后的所有水果的最少开销
f = [inf] * (n + 1)
for i in range(n, 0, -1):
# [i + 1, 2 * i] 免费
if 2 * i >= n:
f[i] = prices[i - 1]
else:
for j in range(i + 1, 2 * i + 2):
f[i] = min(f[i], f[j] + prices[i - 1])
return f[1]
注意到 i 递减,区间 \([i + 1, 2 \times i + 1]\) 是一个长度为为 i + 1 的滑动窗口,转移成滑动窗口最值问题。
def minimumCoins(self, prices: List[int]) -> int:
n = len(prices)
# f [i] 表示获得 i 及其以后的所有水果的最少开销
f = [inf] * (n + 1)
q = deque()
for i in range(n, 0, -1):
# i 递减,区间[i + 1, 2 * i + 1]是一个定长为 i + 1 的滑动窗口
while q and q[-1][1] - (i + 1) + 1 > i + 1:
q.pop()
if 2 * i >= n:
f[i] = prices[i - 1]
else:
f[i] = q[-1][0] + prices[i - 1]
while q and f[i] <= q[0][0]:
q.popleft()
q.appendleft((f[i], i))
return f[1]
倍增 LCA¶
\(f[u][i] 表示 u 节点 向上跳2^i\space 的节点\),\(dep[u] \space 表示深度\)
MX = int(n.bit_length())
f = [[0] * (MX + 1) for _ in range(n)]
dep = [0] * n
def dfs(u, fa):
# father [u] = fa
dep[u] = dep[fa] + 1 # 递归节点深度
f[u][0] = fa
for i in range(1, MX + 1): # 倍增计算向上跳的位置
f[u][i] = f[f[u][i - 1]][i - 1]
for v in g[u]:
if v != fa:
dfs(v, u)
# 假定 0 节点是树根
dep[0] = 1
for v in g[0]:
dfs(v, 0)
def lca(u, v):
if dep[u] < dep[v]:
u, v = v, u
# u 跳到和 v 同一层
for i in range(MX, -1, -1):
if dep[f[u][i]] >= dep[v]:
u = f[u][i]
if u == v:
return u
# 跳到 lca 的下一层
for i in range(MX, -1, -1):
if f[u][i] != f[v][i]:
u, v = f[u][i], f[v][i]
return f[u][0]
P3379 【模板】最近公共祖先(LCA) - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
from math import log
import sys
input = lambda: sys.stdin.readline().strip()
n, m, s = map(int, input().split())
# f [n][mx]
mx = int(log(n, 2))
f = [[0] * (mx + 1) for _ in range(n + 10)]
e = [[] for _ in range(n + 10)]
dep = [0] * (n + 10)
dep[s] = 1
for _ in range(n - 1):
u, v = map(int, input().split())
e[u].append(v)
e[v].append(u)
def dfs(u, fa):
dep[u] = dep[fa] + 1
f[u][0] = fa
for i in range(1, mx + 1):
f[u][i] = f[f[u][i - 1]][i - 1]
for v in e[u]:
if v != fa:
dfs(v, u)
for v in e[s]:
dfs(v, s)
def lca(u, v):
# 让 u 往上跳
if dep[u] < dep[v]: u, v = v, u
for i in range(mx, -1, -1):
if dep[f[u][i]] >= dep[v]:
u = f[u][i]
if u == v: return u
# 一定是在 lca 的下一层
# 一起跳
for i in range(mx, -1, -1):
if f[u][i] != f[v][i]:
u, v = f[u][i], f[v][i]
return f[u][0]
for _ in range(m):
a, b = map(int, input().split())
print(lca(a, b))
树上差分¶
点差分:解决多路径节点计数问题。
\(u \rightarrow v 的路径转化为 u \rightarrow lca左孩子 + lca \rightarrow v\)
# 差分时左闭右开,无需考虑啊 u = a 的情况
for u, v in query:
a = lca(u, v)
diff[u] += 1
diff[a] -= 1
diff[v] += 1
if father[a] != -1:
diff[father[a]] -= 1
树上异或¶
性质 1:对树上一条路径 \(u \rightarrow x_0 \rightarrow x_1 \rightarrow \cdots \rightarrow v\) 进行相邻节点两两异或运算,等价于只对路径起始节点和终止节点异或。
因而树上相邻异或 等价于 树上任意两点进行异或
性质 2:在树上任意相邻异或,总是有 偶数 个节点被异或。
class Solution:
def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int:
res = sum(nums)
delta = sorted([(x ^ k) - x for x in nums], reverse = True)
for du, dv in zip(delta[::2], delta[1::2]):
res = max(res, res + du + dv)
return res
树上直径¶
时间复杂度:\(O(n)\)
定义:树上任意两节点之间最长的简单路径即为树的「直径」。
定理:
- 对于无负边权的树,从树的任意节点出发寻找到距离最远的节点,一定是树直径的一个端点。 (反证)
方法一:两次 dfs
def treeDiameter(self, edges: List[List[int]]) -> int:
n = len(edges) + 1
e = [[] for _ in range(n + 1)]
for u, v in edges:
e[u].append(v)
e[v].append(u)
def dfs(u, fa):
res, mxv = 0, u
for v in e[u]:
if v == fa: continue
a, b = dfs(v, u)
if a + 1 > res:
res, mxv = a + 1, b
return res, mxv
_, s = dfs(0, -1)
res, _ = dfs(s, -1)
return res
方法二:树形 DP
返回每个节点 的最长路径 fst 和 与最长路径没有公共边的次长路径 sec,取 max(fst + sec)
def treeDiameter(self, edges: List[List[int]]) -> int:
n = len(edges) + 1
e = [[] for _ in range(n + 1)]
for u, v in edges:
e[u].append(v)
e[v].append(u)
res = 0
def dfs(u, fa):
nonlocal res
# 找出节点 u 为子树的最长 / 次长路径
fst = sec = -1
for v in e[u]:
if v == fa: continue
a, _ = dfs(v, u)
if a >= fst:
fst, sec = a, fst
else:
sec = max(a, sec)
res = max(fst + sec + 2, res)
return fst + 1, sec + 1
dfs(0, -1)
return res
树的直径问题,最小高度树的根一定在树的直径上。
def findMinHeightTrees(self, n: int, edges: List[List[int]]) -> List[int]:
e = [[] for _ in range(n)]
for u, v in edges:
e[u].append(v)
e[v].append(u)
# 确定以 x 为根
pa = [-1] * n
def dfs(u, fa):
pa[u] = fa
res, mxv = 0, u
for v in e[u]:
if v == fa:
continue
a, b = dfs(v, u)
if a + 1 > res:
res, mxv = a + 1, b
return res, mxv
_, x = dfs(0, -1)
dis, y = dfs(x, -1)
path = []
while y != -1:
path.append(y)
y = pa[y]
res = [path[dis // 2]]
if dis & 1:
res.append(path[dis // 2 + 1])
return res
堆 / 优先队列¶
库导入:from heapq import *
注意:Python 的 heap 默认小堆顶,即二叉堆堆顶元素小于左右孩子。常见操作如下:
heappush(heap, item):将item添加到heap中,同时保持堆的不变性。时间复杂度:\(O(\log n)\);heappop(heap):弹出并返回heap中的最小元素,同时保持堆的不变性。时间复杂度:\(O(\log n)\);heap[0]:取堆顶元素。时间复杂度 \(O(1)\)。
from heapq import *
hq = []
heappush(hq, 5)
heappush(hq, 9)
heappush(hq, 11)
heappush(hq, 12)
heappush(hq, 13)
heappush(hq, 15)
print(hq) # 输出: [5, 9, 11, 12, 13, 15]
# 获取堆顶元素(最小),O(1)
print(hq[0]) # 5
# 弹出堆顶元素(最小),O(logn)
heappop(hq)
print(hq) # 输出 [9, 12, 11, 15, 13]
# 注意:python 中堆默认且只能是小顶堆
# 大顶堆,通过取反实现
nums = [15, 13, 9, 5, 11, 12]
hq = []
for x in nums:
heappush(hq, -x)
print(hq) # [-15, -13, -12, -5, -11, -9]
# 获取堆顶元素(最大)
print(-hq[0]) # 15
# 弹出堆顶元素(最大),O(logn)
heappop(hq)
print(hq) # [-13, -11, -12, -5, -9]
print(-hq[0]) # 13
2530. 执行 K 次操作后的最大分数 - 力扣(LeetCode)
给你一个下标从 0 开始的整数数组 nums 和一个整数 k。你的 起始分数 为 0。
在一步 操作 中:
- 选出一个满足
0 <= i < nums.length的下标i, - 将你的 分数 增加
nums[i],并且 - 将
nums[i]替换为ceil(nums[i] / 3)。
返回在 恰好 执行 k 次操作后,你可能获得的最大分数。
向上取整函数 ceil(val) 的结果是大于或等于 val 的最小整数。
示例 2:
输入:nums = [1,10,3,3,3], k = 3
输出:17
解释:可以执行下述操作:
第 1 步操作:选中 i = 1 ,nums 变为 [1,4,3,3,3] 。分数增加 10 。
第 2 步操作:选中 i = 1 ,nums 变为 [1,2,3,3,3] 。分数增加 4 。
第 3 步操作:选中 i = 2 ,nums 变为 [1,2,1,3,3] 。分数增加 3 。
最后分数是 10 + 4 + 3 = 17 。
from heapq import *
from math import *
class Solution:
def maxKelements(self, nums: List[int], k: int) -> int:
res = 0
hq = []
for i, x in enumerate(nums):
nums[i] = -x
heappush(hq, -x)
for _ in range(k):
x = heappop(hq)
res += x
heappush(hq, -ceil(-x / 3))
return -res
思路
-
转换成从 \(n\) 个字母中,选出 \(n-m\) 个,使得字符串字典序最小
-
贪心:考虑当前选择字母的范围。假设上一处选择的字母下标为 \(L\),左边界 \(L+1\)。右边界 \(R\) 应该保证其右侧还有足够的备选项,在区间内,贪心的选择最小的字母。
-
最小堆,插入可以选择的元素,每次选择完成进行一次 \(pop\)
例如,
12, 5, 7, 6, 20, 1, 4, 15, 21 从 9 个数中,删除 3 个,即选择 6 个。
第一个可以选择的范围是 \([0, 3]\) ,即右边界的右侧有 5 个备选项。
第一处选择了 5,下标为 1,第二个可以选择的范围是 \([2, 4]\),即右边界还有 4 个备选项。
import sys
input = lambda: sys.stdin.readline().strip()
from heapq import *
n, m = map(int, input().split())
s = input()
hq = []
res = ''
L = 0
for i in range(m):
heappush(hq, (s[i], i))
for R in range(m, n):
heappush(hq, (s[R], R))
mn, mni = heappop(hq)
while mni < L:
mn, mni = heappop(hq)
res += mn
L = mni + 1
print(res)
注意:本题还有单调栈解法更为简单。
并查集¶
class DisjointSetUnion:
def __init__(self, n: int):
"""初始化一个含有 n 个元素的并查集,元素下标范围为 [0, n-1]"""
self.p = [i for i in range(n)] # p[i] 表示 i 号点的祖先结点编号
self.cnt = [1] * n # cnt[i] 表示 i 号点所在集合的元素个数
self.set_cnt = n # 集合的个数
def find(self, a: int) -> int:
"""返回 a 号点的祖先结点"""
if self.p[a] != a:
# 路径压缩
self.p[a] = self.find(self.p[a])
return self.p[a]
def merge(self, a: int, b: int) -> None:
"""合并结点 a 和结点 b 所在的集合"""
pa, pb = self.find(a), self.find(b)
if pa == pb:
return
# 按秩合并
self.set_cnt -= 1
if self.cnt[pa] < self.cnt[pb]:
self.p[pa] = pb
self.cnt[pb] += self.cnt[pa]
else:
self.p[pb] = pa
self.cnt[pa] += self.cnt[pb]
def same(self, a: int, b: int) -> bool:
"""判断结点 a 和 结点 b 是否在同一个集合"""
return self.find(a) == self.find(b)
def tree_size(self, a: int) -> int:
"""返回结点 a 所在集合的元素个数"""
return self.cnt[self.find(a)]
def forest_size(self) -> int:
"""返回集合的个数"""
return self.set_cnt
并查集是一种用于管理元素所属集合的数据结构,实现为一个森林,其中每棵树表示一个集合,树中的节点表示对应集合中的元素。集合内的元素可达且连通。
- \(union(u, v)\) 合并两个元素所属集合(合并对应的树)
-
\(find(x)\) 查询某个元素所属集合(查询对应的树的根节点),这可以用于判断两个元素是否属于同一集合
-
合并和查询的时间复杂度: 近似 \(O(1)\)
实现
-
初始化:每个元素都位于一个单独的集合,表示为一棵只有根节点的树。方便起见,我们将根节点的父亲设为自己
-
查询:当 \(fa[x] =x\) 表示 \(x\) 节点即是根节点;否则,通过递归调用 \(find(fa[x])\),沿着树向上移动,直至找到根节点。通常在判断是否可达、连通问题时进行查询,如需要判断 \(u,v\) 是否属于一个集合,通过 \(find(u),find(v)\) 是否相等判断。
-
合并:要合并两棵树,我们只需要将一棵树的根节点 \(find(v)\) 连到另一棵树的根节点 \(find(u)\),即 \(fa[find(v)] \leftarrow find(u)\)。通常在更新可达、连通关系进行合并。
路径压缩
查询过程中,经过的每个元素都属于该集合,我们可直接更新每个元素,让其父节点指向树根。即 \(fa[x] \leftarrow find(fa[x])\) 来减少树根,加快后续查询。最终,我们会将原树压缩成树高越发接近 2 的树。
并查集递归模板
# 并查集模板
fa = list(range(n + 1))
def find(x):
if fa[x] == x: return x
fa[x] = find(fa[x])
return fa[x]
def union(u, v):
if find(u) != find(v):
fa[find(v)] = find(u)
迭代模板
fa = list(range(n))
def find(x):
root = x
while fa[root] != root:
root = fa[root]
while fa[x] != x: # 路径压缩
x, fa[x] = fa[x], root
return root
def union(u, v):
root_u = find(u)
root_v = find(v)
if root_u != root_v:
fa[root_v] = root_u
常见问题
-
求连通块个数
-
记录每个集合大小:绑定到根节点
- 记录每个点到根节点的 距离:绑定到每一个节点上
语言整理
对一个无向图有 \(n\) 个点, \(m\) 条边, \((u,v)\) 表示一条无向边。给定 \(p\) 组查询,每组查询判断 \(p_u,p_v\) 是否可达。
思路
- 可达、连通的连通分量可以看作一个集合。
-
对给定的无向边 \((u,v )\) ,可通过 \(union(u,v)\) 合并,表示可达、连通关系。
-
对每组询问,\(find(p_u)\) 和 \(find(p_v)\) 的关系判断 \(u,v\) 是否可达。
import sys
n, m, p = map(int, input().split())
# 并查集模板
fa = list(range(n + 1))
def find(x):
if fa[x] == x: return x
fa[x] = find(fa[x])
return fa[x]
def union(u, v):
if find(u) != find(v):
fa[find(v)] = find(u)
for _ in range(m):
u, v = map(int, input().split())
# 通过合并,表示可达、连通关系
union(u, v)
for _ in range(p):
u, v = map(int, input().split())
print('Yes' if find(u) == find(v) else 'No')
并查集维护连通分量
语言整理
给定若干组数据,每组给 \(n,m\) 表示有 \(n\) 个节点的无向图,有 \(m\) 条边。求使得任意两个节点都可达还需要添加多少条边。
思路
- 可达、连通问题,想到并查集维护节点之间的可达性,用 \(union\) 操作更新并查集
-
考虑独立连通块的个数,即 \(cnt=len(set(fa))\) ,最少只需要添加 \(cnt - 1\) 条边,能够让任意连通块可达,即所有节点可达。
-
下标从 1 开始时,可定 \(fa = list(range(n+1))\),同时 \(cnt\) 要减去 1
import sys
input = lambda: sys.stdin.readline().strip()
while True:
s = input()
if s == '0': break
n, m = map(int, s.split())
fa = list(range(n + 1))
def find(x):
if fa[x] == x: return x
fa[x] = find(fa[x])
return fa[x]
def union(u, v):
if find(u) != find(v):
fa[find(v)] = find(u)
for _ in range(m):
u, v = map(int, input().split())
union(u, v)
# 压缩成严格菊花集
for x in range(1, n + 1):
fa[x] = find(x)
cnt = len(set(fa)) - 1 # 连通块数量,-1 是减去下标 0
print(cnt - 1) # cnt - 1 是需要修路数量
import sys
input = lambda: sys.stdin.readline().strip()
m, n = map(int, input().split())
k = int(input())
fa = list(range(m * n + 1))
def find(x):
if fa[x] == x: return x
fa[x] = find(fa[x])
return fa[x]
def union(u, v):
if find(u) != find(v):
fa[find(v)] = find(u)
for _ in range(k):
u, v = map(int, input().split())
union(u, v)
for x in range(1, n * m + 1):
fa[x] = find(x)
print(len(set(fa)) - 1)
1998. 数组的最大公因数排序 - 力扣(LeetCode)
质因子分解 + 并查集判断连通分量。
将所有数看成一个图中的节点。任意两个数 \(u, v\) , 如果不互质(gcd > 1) 说明存在一条边 $ u \sim v$。显然一种做法是用 \(O(n^2)\) 的时间维护所有节点对应的连通块。然而,实际上只需要对每个数 \(x\) 和它的所有质因子进行合并,这样可以保证有相同质因子的两个元素,他们可以在同一个连通分量。
记数组中最大值 $ m = max(nums)$, 可以看成一个 有 m 个节点的图。每次质因子分解的时间复杂度是 \(O(\sqrt x)\) ,所以从 \(O(n^2)\) 优化到 \(O(n \sqrt m)\)。最后,将排序好的数组和原数组对应位置上的元素进行对比。判断两个元素是否同属于一个连通分量即可。
时间复杂度:$O\bigg(n\big(\sqrt m \cdot \alpha(m) \big) +n\log n \bigg ) $
def gcdSort(self, nums: List[int]) -> bool:
n = len(nums)
fa = list(range(max(nums) + 1))
def find(x): # x 压缩到 fa [x] 中
if fa[x] != x:
fa[x] = find(fa[x])
return fa[x]
def union(u, v): # u 合并到 v 中
if find(u) != find(v):
fa[find(u)] = find(v)
for i, x in enumerate(nums):
xx = x
for j in range(2, int(sqrt(x)) + 1):
if x % j == 0:
union(j, xx)
while x % j == 0:
x //= j
if x > 1:
union(x, xx)
sorted_nums = sorted(nums)
for u, v in zip(nums, sorted_nums):
if u == v: continue
# 不在位元素,需要看是否在同一连通分量
if find(u) != find(v): return False
return True
952. 按公因数计算最大组件大小 - 力扣(LeetCode)
def largestComponentSize(self, nums: List[int]) -> int:
n = len(nums)
m = max(nums)
fa = list(range(m + 1))
def find(x):
if fa[x] != x:
fa[x] = find(fa[x])
return fa[x]
def union(u, v):
if find(u) != find(v):
fa[find(u)] = find(v)
for x in nums:
xx = x
for j in range(2, int(sqrt(x)) + 1):
if x % j == 0:
union(xx, j)
while x % j == 0:
x //= j
if x > 1:
union(xx, x)
for x in nums:
find(x)
cnt = Counter()
for x in nums: cnt[fa[x]] += 1
return max(cnt.values())
**并查集维护连通块大小 **
模板代码:
fa = list(range(n + 1))
siz = [1] * (n + 1)
def find(x):
if fa[x] != x:
fa[x] = find(fa[x])
return fa[x]
def union(u, v):
if find(u) != find(v):
siz[find(v)] += siz[find(u)]
fa[find(u)] = find(v)
诸如 \((0,1), (1, 2), (2, 3), (3,1)\) 应该视作一个连通块内,其交换次数为连通块大小 - 1。
class Solution:
def minSwapsCouples(self, row: List[int]) -> int:
n = len(row)
fa = list(range(n + 1))
siz = [1] * (n + 1)
def find(x):
if fa[x] != x:
fa[x] = find(fa[x])
return fa[x]
def union(u, v):
if find(u) != find(v):
siz[find(v)] += siz[find(u)]
fa[find(u)] = find(v)
s = set()
for i in range(0, n, 2):
p = i // 2
l, r = row[i] // 2, row[i + 1] // 2
if l == r: continue
union(l, r)
for x in row:
s.add(find(x // 2))
res = 0
for x in s:
res += siz[find(x)] - 1
return res
2867. 统计树中的合法路径数目 - 力扣(LeetCode)
并查集维护所有非质数子连通块的大小。
def countPaths(self, n: int, edges: List[List[int]]) -> int:
primes = []
N = n + 10
is_prime = [True] * N
is_prime[0] = is_prime[1] = False
for i in range(2, N):
if is_prime[i]:
primes.append(i)
for p in primes:
if i * p >= N:
break
is_prime[i * p] = False
if i % p == 0:
break
e = [[] for _ in range(n + 1)]
fa = list(range(n + 1))
siz = [1] * (n + 1)
def find(x):
if fa[x] != x:
fa[x] = find(fa[x])
return fa[x]
def union(u, v): # u 合并到 v
if find(u) != find(v):
siz[find(v)] += siz[find(u)]
fa[find(u)] = find(v)
for u, v in edges:
e[u].append(v)
e[v].append(u)
if not is_prime[u] and not is_prime[v]:
union(u, v)
res = 0
vis = [False] * (n + 1)
for u in range(1, n + 1):
if not vis[u] and is_prime[u]:
# 遍历 u 的所有非质数连通块
vis[u] = True
cur_siz = 0
for v in e[u]:
if not is_prime[v]:
sz = siz[find(v)]
res += sz + sz * cur_siz
cur_siz += sz
return res
并查集维护连通块按位与的值
100244. 带权图里旅途的最小代价 - 力扣(LeetCode)
def minimumCost(self, n: int, edges: List[List[int]], query: List[List[int]]) -> List[int]:
fa = list(range(n))
cc_and = [-1] * n
def find(x):
if fa[x] != x:
fa[x] = find(fa[x])
return fa[x]
def union(u, v, w): # v 合并到 u 中
if find(u) != find(v):
cc_and[find(u)] &= cc_and[find(v)]
fa[find(v)] = find(u)
for u, v, w in edges:
# 各自连通块内更新,只要更新其一即可
cc_and[find(u)] &= w
union(u, v, w)
return [0 if u == v else (-1 if find(u) != find(v) else cc_and[find(u)]) for u, v in query]
并查集维护链
把考虑过的元素串起来,链条的长度就是当前一段数的长度。关键需要高效串联两条链。
2334. 元素值大于变化阈值的子数组 - 力扣(LeetCode)
时间复杂度:\(O(nlogn)\)
对 \(nums\) 中每一个元素和对应的下标 按照降序排序,每次向右合并,当前的 \(x\) 一定是子数组中最小的。返回合并后并查集大小 \(k - 1\)( 减一使得不包含哨兵根节点)。不断向右侧合并直到出现符合的。
def validSubarraySize(self, nums: List[int], threshold: int) -> int:
n = len(nums)
fa = list(range(n + 1))
siz = [1] * (n + 1)
def find(x):
if fa[x] != x: fa[x] = find(fa[x])
return fa[x]
def union(u, v): # v 合并到 u 中
if find(u) != find(v):
siz[find(u)] += siz[find(v)]
fa[find(v)] = find(u)
for x, i in sorted(zip(nums, range(n)), reverse = True):
union(i, i + 1)
k = siz[find(i)] - 1
if x > (threshold / k): return k
return -1
import os
import sys
N = int(input())
a = list(map(int, input().split()))
fa = list(range(10 ** 6 + 1))
def find(x):
if x == fa[x]: return x
fa[x] = find(fa[x])
return fa[x]
def union(u, v):
fa[find(v)] = find(u)
for i in range(N):
a[i] = find(a[i])
union(a[i] + 1, a[i])
print(*a)
树状数组¶
下标从 1 开始,单点修改 + 区间查询
# 下标从 1 开始
class FenwickTree:
def __init__(self, length: int):
self.length = length
self.tree = [0] * (length + 1)
def lowbit(self, x: int) -> int:
return x & (-x)
# 更新自底向上
def update(self, idx: int, val: int) -> None:
while idx <= self.length:
self.tree[idx] += val
idx += self.lowbit(idx)
# 查询自顶向下
def query(self, idx: int) -> int:
res = 0
while idx > 0:
res += self.tree[idx]
idx -= self.lowbit(idx)
return res
class NumArray:
def __init__(self, nums: List[int]):
n = len(nums)
self.nums = nums
self.tree = FenwickTree(n)
for i, x in enumerate(nums):
self.tree.update(i + 1, x)
def update(self, index: int, val: int) -> None:
# 因为这里是更新为 val, 所以节点增加的值应为 val - self.nums [index]
# 同时需要更新 nums [idx]
self.tree.update(index + 1, val - self.nums[index])
self.nums[index] = val
def sumRange(self, left: int, right: int) -> int:
r = self.tree.query(right + 1)
l = self.tree.query(left)
return r - l
# Your NumArray object will be instantiated and called as such:
# obj = NumArray(nums)
# obj.update(index, val)
# param_2 = obj.sumRange(left, right)
离散化树状数组 + 还原:
class FenwickTree:
def __init__(self, length: int):
self.length = length
self.tree = [0] * (length + 1)
def lowbit(self, x: int) -> int:
return x & (-x)
# 更新自底向上
def update(self, idx: int, val: int) -> None:
while idx <= self.length:
self.tree[idx] += val
idx += self.lowbit(idx)
# 查询自顶向下
def query(self, idx: int) -> int:
res = 0
while idx > 0:
res += self.tree[idx]
idx -= self.lowbit(idx)
return res
class Solution:
def resultArray(self, nums: List[int]) -> List[int]:
# 离散化 nums
sorted_nums = sorted(nums)
tmp = nums.copy()
nums = [bisect.bisect_left(sorted_nums, x) + 1 for x in nums]
# 还原
mp_rev = {i: x for i, x in zip(nums, tmp)}
n = len(nums)
t1 = FenwickTree(n)
t2 = FenwickTree(n)
a = [nums[0]]
b = [nums[1]]
t1.update(nums[0], 1)
t2.update(nums[1], 1)
for i in range(2, len(nums)):
x = nums[i]
c = len(a) - t1.query(x)
d = len(b) - t2.query(x)
if c > d or c == d and len(a) <= len(b):
a.append(x)
t1.update(x, 1)
else:
b.append(x)
t2.update(x, 1)
# 还原为原始数据: i 为离散化秩,x 为还原值
return [mp_rev[i] for i in a] + [mp_rev[i] for i in b]
可离散化线段树:
class FenwickTree:
def __lowbit(self, x: int) -> int:
"""
返回x 的二进制中,最低为的1所构成的数。
:param x: 整数
:return: x的二进制中,最低为的1所构成的数
"""
return x & -x
def __init__(self, n: int, discretize: bool = False, nums: [List[int]] = None):
"""
初始化树状数组(Fenwick Tree)数据结构,下标从0开始
:param n: 值域范围
:param discretize: 是否对输入值进行离散化
:param nums: 离散化所需的输入数组
"""
self.__dic = None
self.__discretize = discretize
self.__nums = None
self.__n = n
if discretize:
unique_nums = sorted(set(nums))
self.__dic = {unique_nums[i]: i + 1 for i in range(len(unique_nums))}
self.__n = len(unique_nums)
self.__nums = [0] * (self.__n + 1)
def __query(self, x: int) -> int:
"""
查询小于等于x的个数
:param x: 查询的数
:return: 查询小于等于x的个数
"""
res = 0
while x > 0:
res += self.__nums[x]
x -= self.__lowbit(x)
return res
def update(self, x: int, val: int) -> None:
"""
x处对应的值增加val
:param x: 更新的数
:param val: 变化值
"""
if self.__discretize:
if x not in self.__dic:
raise ValueError(f"值{x} 不在离散化范围内")
x = self.__dic[x]
while x <= self.__n:
self.__nums[x] += val
x += self.__lowbit(x)
def query(self, lx: int, rx: int = None) -> int:
"""
如果只传入一个参数,则查询小于等于lx的个数
如果传入两个参数,则查询大于等于lx, 小于等于rx的个数
:param lx: 查询区间左端点
:param rx: 查询区间右端点
:return: 查询区间内的元素个数
"""
if self.__discretize:
if lx not in self.__dic:
raise ValueError(f"值{lx} 不在离散化范围内")
lx = self.__dic[lx]
if rx is not None:
if rx not in self.__dic:
raise ValueError(f"值{rx} 不在离散化范围内")
rx = self.__dic[rx]
if rx is not None:
if lx > rx:
raise ValueError(f"左边界{lx} 大于右边界{rx}")
return self.__query(rx) - self.__query(lx - 1)
return self.__query(lx)
线段树¶
lazy 线段树:
- 支持单点修改/区间修改/区间查询;
- 支持最小/最大/求和。
class SegmentTree:
__slots__ = ['node', 'lazy', 'n', 'nums', 'op', 'ini', 'ops']
def __init__(self, nums, ops = 'sum'):
n = len(nums)
if ops == 'sum' or ops == 'bin':
op, ini = lambda a, b: a + b, 0
elif ops == 'max':
op, ini = lambda a, b: max(a, b), -inf
elif ops == 'min':
op, ini = lambda a, b: min(a, b), inf
self.nums = nums
self.op = op
self.ini = ini
self.ops = ops
self.node = [ini] * (4 * n)
self.lazy = [None] * (4 * n)
self.n = n
def build(self, idx = 1, l = 1, r = None):
if r is None: r = self.n
if l == r:
self.node[idx] = self.nums[l - 1]
return
mid = (l + r) >> 1
self.build(idx << 1, l, mid)
self.build((idx << 1) + 1, mid + 1, r)
self.node[idx] = self.op(self.node[idx << 1], self.node[(idx << 1) + 1])
def do(self, idx, dl, dr, val = None):
if self.ops == 'bin':
self.node[idx] = dr - dl + 1
self.lazy[idx] = True
elif self.ops == 'sum':
self.node[idx] = self.op(self.node[idx], (dr - dl + 1) * val)
self.lazy[idx] = val
else:
self.node[idx] = self.op(val, self.node[idx])
self.lazy[idx] = val
def pushdown(self, idx, pl, pr):
val = self.lazy[idx]
mid = (pl + pr) >> 1
self.do(idx << 1, pl, mid, val)
self.do((idx << 1) + 1, mid + 1, pr, val)
self.lazy[idx] = None
def update(self, ul, ur, val, idx = 1, l = 1, r = None):
if r is None:r = self.n
if ul <= l and r <= ur:
self.do(idx, l, r, val)
return
if self.lazy[idx]:
self.pushdown(idx, l, r)
mid = (l + r) >> 1
if ul <= mid: self.update(ul, ur, val, idx << 1, l, mid)
if ur > mid: self.update(ul, ur, val, (idx << 1) + 1, mid + 1, r)
self.node[idx] = self.op(self.node[idx << 1], self.node[(idx << 1) + 1])
def query(self, ql, qr, idx = 1, l = 1, r = None):
if r is None: r = self.n
if ql <= l and r <= qr:
return self.node[idx]
if self.lazy[idx]:
self.pushdown(idx, l, r)
mid = (l + r) >> 1
ansl, ansr = self.ini, self.ini
if ql <= mid: ansl = self.query(ql, qr, idx << 1, l, mid)
if qr > mid:ansr = self.query(ql, qr, (idx << 1) + 1, mid + 1, r)
return self.op(ansl, ansr)
tr = SegmentTree([1, 2, 3, 4, 5], 'sum')
tr.build()
print(tr.query(1, 5)) # 15
print(tr.query(2, 5)) # 14
tr.update(2, 4, 2) # 1 4 5 6 5
print(tr.query(2, 5)) # 20
tr = SegmentTree([8, 4, 5, 7, 9], 'min')
tr.build()
print(tr.query(1, 4)) # 4
tr.update(1, 4, 5) # [5, 4, 5, 5, 9]
print(tr.query(4, 5)) # 5
tr.update(3, 5, -10) # [5, 4, -10, -10, -10]
print(tr.query(1, 3)) # -10
lazy 线段树(01 翻转):
class Solution:
def handleQuery(self, nums1: List[int], nums2: List[int], queries: List[List[int]]) -> List[int]:
n = len(nums1)
node = [0] * (4 * n)
# 懒标记:True 表示该节点代表的区间被曾经被修改,但是其子节点尚未更新
lazy = [False] * (4 * n)
# 初始化线段树
def build(i = 1, l = 1, r = n):
if l == r:
node[i] = nums1[l - 1]
return
mid = (l + r) >> 1
build(i * 2, l, mid)
build(i * 2 + 1, mid + 1, r)
# 维护区间 [l, r] 的值
node[i] = node[i * 2] + node[i * 2 + 1]
# 更新节点值,并设置 lazy 标记
def do(i, l, r):
node[i] = r - l + 1 - node[i]
lazy[i] = not lazy[i]
# 区间更新:本题中更新区间 [l, r] 相当于做翻转
def update(L, R, i = 1, l = 1, r = n):
if L <= l and r <= R:
do(i, l, r)
return
mid = (l + r) >> 1
if lazy[i]:
# 根据标记信息更新 p 的两个左右子节点,同时为子节点增加标记
# 然后清除当前节点的标记
do(i * 2, l, mid)
do(i * 2 + 1, mid + 1, r)
lazy[i] = False
if L <= mid:
update(L, R, i * 2, l, mid)
if R > mid:
update(L, R, i * 2 + 1, mid + 1, r)
# 更新节点值
node[i] = node[i * 2] + node[i * 2 + 1]
build()
res, s = [], sum(nums2)
for op, L, R in queries:
if op == 1:
update(L + 1, R + 1)
elif op == 2:
s += node[1] * L
else:
res.append(s)
return res
lazy 线段树(带动态开点):
class Node:
__slots__ = ['l', 'r', 'lazy', 'val']
def __init__(self, val = 0):
self.l = None
self.r = None
self.lazy = None
self.val = val
class SegmentTree:
__slots__ = ['root', 'node', 'op', 'ini', 'ops', 'max_val']
def __init__(self, ops: str = 'sum', max_val: int = int(1e9)):
if ops == 'sum' or ops == 'bin':
op, ini = lambda a, b: a + b if a is not None else b, 0
elif ops == 'max':
op, ini = lambda a, b: max(a, b), -inf
elif ops == 'min':
op, ini = lambda a, b: min(a, b), inf
self.root = Node(ini)
self.op = op
self.ini = ini
self.ops = ops
self.max_val = max_val
def __do(self, node, dl, dr, val = None):
if self.ops == 'bin':
node.val = dr - dl + 1
node.lazy = True
elif self.ops == 'sum':
node.val = self.op(node.val, (dr - dl + 1) * val)
node.lazy = self.op(node.lazy, val)
else:
node.val = self.op(node.val, val)
node.lazy = val
# 下放 lazy 标记。如果是孩子为空,则动态开点
def __pushdown(self, node, pl, pr):
val = node.lazy
# 根据 lazy 标记信息,更新左右节点,然后将 lazy 信息清除
mid = (pl + pr) >> 1
self.__do(node.l, pl, mid, val)
self.__do(node.r, mid + 1, pr, val)
node.lazy = None
def update(self, ul, ur, val, node = None, l = 1, r = None):
# 查询默认从根节点开始
if node is None: node = self.root
if r is None: r = self.max_val
if ul <= l and r <= ur:
self.__do(node, l, r, val)
return
if node.l is None: node.l = Node(self.ini)
if node.r is None: node.r = Node(self.ini)
# 下放标记、根据标记信息更新左右节点,然后清除标记
if node.lazy is not None:
self.__pushdown(node, l, r)
mid = (l + r) >> 1
if ul <= mid: self.update(ul, ur, val, node.l, l, mid)
if ur > mid: self.update(ul, ur, val, node.r, mid + 1, r)
# node.val 为 True 表示这个节点所在区间,均被“跟踪”
node.val = self.op(node.l.val, node.r.val)
def query(self, ql, qr, node = None, l = 1, r = None):
# 查询默认从根节点开始
if node is None: node = self.root
if r is None: r = self.max_val
if ql <= l and r <= qr:
return node.val
if node.l is None: node.l = Node(self.ini)
if node.r is None: node.r = Node(self.ini)
# 下放标记、根据标记信息更新左右节点,然后清除标记
if node.lazy is not None:
self.__pushdown(node, l, r)
mid = (l + r) >> 1
ansl, ansr = self.ini, self.ini
if ql <= mid: ansl = self.query(ql, qr, node.l, l, mid)
if qr > mid: ansr = self.query(ql, qr, node.r, mid + 1, r)
return self.op(ansl, ansr)
tr = SegmentTree('sum')
tr.update(1, 10, 99)
tr.update(1, 4, 1)
tr.update(5, 10, 1)
print(tr.query(1, 1)) # 100
print(tr.query(1, 10)) # 1000
class Solution:
def fallingSquares(self, positions: List[List[int]]) -> List[int]:
tr = SegmentTree('max', int(1e8))
res = []
for l, sz in positions:
r = l + sz - 1
mxh = tr.query(l, r) + sz
tr.update(l, r, mxh)
res.append(tr.query(1, int(1e8)))
return res
https://leetcode.cn/problems/count-integers-in-intervals
递归动态开点(无 lazy)线段树
区间覆盖统计问题,区间覆盖不需要重复操作,不需要进行 lazy 传递
但是数据范围较大,需要动态开点
class CountIntervals:
__slots__ = 'left', 'right', 'l', 'r', 'val'
def __init__(self, l = 1, r = int(1e9)):
self.left = self.right = None
self.l, self.r, self.val = l, r, 0
def add(self, l: int, r: int) -> None:
# 覆盖区间操作,不需要重复覆盖,饱和区间无需任何操作
if self.val == self.r - self.l + 1:
return
if l <= self.l and self.r <= r: # self 已被区间 [l, r] 完整覆盖,不再继续递归
self.val = self.r - self.l + 1
return
mid = (self.l + self.r) >> 1
# 动态开点
if self.left is None:
self.left = CountIntervals(self.l, mid) # 动态开点
if self.right is None:
self.right = CountIntervals(mid + 1, self.r) # 动态开点
if l <= mid:
self.left.add(l, r)
if mid < r:
self.right.add(l, r)
# self.val 的值,表示区间 [self.l, self.r] 中被覆盖的点的个数
self.val = self.left.val + self.right.val
def count(self) -> int:
return self.val
https://leetcode.cn/problems/range-module/
动态开点 + lazy 线段树(旧版):
class Node:
__slots__ = ['l', 'r', 'lazy', 'val']
def __init__(self):
self.l = None
self.r = None
self.lazy = 0
self.val = False
class SegmentTree:
__slots__ = ['root']
def __init__(self):
self.root = Node()
def do(self, node, val):
node.val = val
node.lazy = 1
# 下放 lazy 标记。如果是孩子为空,则动态开点
def pushdown(self, node):
if node.l is None:
node.l = Node()
if node.r is None:
node.r = Node()
# 根据 lazy 标记信息,更新左右节点,然后将 lazy 信息清除
if node.lazy:
self.do(node.l, node.val)
self.do(node.r, node.val)
node.lazy = 0
def query(self, L, R, node = None, l = 1, r = int(1e9)):
# 查询默认从根节点开始
if node is None:
node = self.root
if L <= l and r <= R:
return node.val
# 下放标记、根据标记信息更新左右节点,然后清除标记
self.pushdown(node)
mid = (l + r) >> 1
vl = vr = True
if L <= mid:
vl = self.query(L, R, node.l, l, mid)
if R > mid:
vr = self.query(L, R, node.r, mid + 1, r)
return vl and vr
def update(self, L, R, val, node = None, l = 1, r = int(1e9)):
# 查询默认从根节点开始
if node is None:
node = self.root
if L <= l and r <= R:
self.do(node, val)
return
mid = (l + r) >> 1
# 下放标记、根据标记信息更新左右节点,然后清除标记
self.pushdown(node)
if L <= mid:
self.update(L, R, val, node.l, l, mid)
if R > mid:
self.update(L, R, val, node.r, mid + 1, r)
# node.val 为 True 表示这个节点所在区间,均被“跟踪”
node.val = bool(node.l and node.l.val and node.r and node.r.val)
class RangeModule:
def __init__(self):
self.tree = SegmentTree()
def addRange(self, left: int, right: int) -> None:
self.tree.update(left, right - 1, True)
def queryRange(self, left: int, right: int) -> bool:
return self.tree.query(left, right - 1)
def removeRange(self, left: int, right: int) -> None:
self.tree.update(left, right - 1, False)
# Your RangeModule object will be instantiated and called as such:
# obj = RangeModule()
# obj.addRange(left, right)
# param_2 = obj.queryRange(left, right)
# obj.removeRange(left, right)
class SegmentTree:
__slots__ = ['node', 'lazy', 'n', 'nums', 'op', 'ini', 'ops']
def __init__(self, nums, ops='sum'):
if ops == 'sum' or ops == 'bin':
self.op, self.ini = lambda a, b: a + b, 0
elif ops == 'max':
self.op, self.ini = lambda a, b: max(a, b), -float('inf')
elif ops == 'min':
self.op, self.ini = lambda a, b: min(a, b), float('inf')
self.nums = nums
self.ops = ops
self.n = len(nums)
self.node = {}
self.lazy = {}
def build(self, idx=1, l=1, r=None):
if r is None:
r = self.n
if l == r:
self.node[idx] = self.nums[l - 1]
return
mid = (l + r) >> 1
self.build(idx << 1, l, mid)
self.build((idx << 1) + 1, mid + 1, r)
self.node[idx] = self.op(self.node.get(idx << 1, self.ini), self.node.get((idx << 1) + 1, self.ini))
def do(self, idx, dl, dr, val=None):
if self.ops == 'bin':
self.node[idx] = dr - dl + 1
self.lazy[idx] = True
else:
self.node[idx] = self.op(val, self.node.get(idx, self.ini))
self.lazy[idx] = val
def pushdown(self, idx, pl, pr):
if idx not in self.lazy:
return
val = self.lazy[idx]
mid = (pl + pr) >> 1
if idx << 1 not in self.node:
self.node[idx << 1] = self.ini
if (idx << 1) + 1 not in self.node:
self.node[(idx << 1) + 1] = self.ini
self.do(idx << 1, pl, mid, val)
self.do((idx << 1) + 1, mid + 1, pr, val)
del self.lazy[idx]
def update(self, ul, ur, val, idx=1, l=1, r=None):
if r is None:
r = self.n
if ul <= l and r <= ur:
self.do(idx, l, r, val)
return
self.pushdown(idx, l, r)
mid = (l + r) >> 1
if ul <= mid:
self.update(ul, ur, val, idx << 1, l, mid)
if ur > mid:
self.update(ul, ur, val, (idx << 1) + 1, mid + 1, r)
self.node[idx] = self.op(self.node.get(idx << 1, self.ini), self.node.get((idx << 1) + 1, self.ini))
def query(self, ql, qr, idx=1, l=1, r=None):
if r is None:
r = self.n
if ql <= l and r <= qr:
return self.node.get(idx, self.ini)
self.pushdown(idx, l, r)
mid = (l + r) >> 1
ansl, ansr = self.ini, self.ini
if ql <= mid:
ansl = self.query(ql, qr, idx << 1, l, mid)
if qr > mid:
ansr = self.query(ql, qr, (idx << 1) + 1, mid + 1, r)
return self.op(ansl, ansr)
def update_suffix(self, ul, ur, val, idx=1, l=1, r=None):
if r is None:
r = self.n
siz = r - l + 1
if ul <= l and r <= ur and siz - self.node.get(idx, 0) <= val:
ans = siz - self.node.get(idx, 0)
self.do(idx, l, r)
return ans
mid = (l + r) >> 1
self.pushdown(idx, l, r)
ans = 0
if ur > mid:
ans = self.update_suffix(ul, ur, val, (idx << 1) + 1, mid + 1, r)
val -= ans
if val and ul <= mid:
ans += self.update_suffix(ul, ur, val, idx << 1, l, mid)
self.node[idx] = self.op(self.node.get(idx << 1, self.ini), self.node.get((idx << 1) + 1, self.ini))
return ans
class Solution:
def processTasks(self, nums):
nums.sort(key=lambda x: x[1])
n, m = len(nums), nums[-1][1]
tr = SegmentTree([0] * (m + 1), 'bin')
for l, r, c in nums:
l, r = l + 1, r + 1
c -= tr.query(l, r)
if c > 0:
tr.update_suffix(l, r, c)
return tr.query(0, m + 1)
线段树优化问题:
2617. 网格图中最少访问的格子数 - 力扣(LeetCode)
单点修改 + 区间最小值查询
class SegmentTree:
def __init__(self, n: int):
self.n = n
self.tree = [inf] * (4 * n)
def op(self, a, b):
return min(a, b)
def update(self, ul, ur, val, idx = 1, l = 1, r = None):
if r is None: r = self.n
if ul <= l and r <= ur:
self.tree[idx] = val
return
mid = (l + r) >> 1
if ul <= mid:self.update(ul, ur, val, idx * 2, l, mid)
if ur > mid: self.update(ul, ur, val, idx * 2 + 1, mid + 1, r)
self.tree[idx] = self.op(self.tree[idx * 2], self.tree[idx * 2 + 1]) # 更新当前节点的值
def query(self, ql, qr, idx = 1, l = 1, r = None):
if r is None: r = self.n
if ql <= l and r <= qr:
return self.tree[idx]
mid = (l + r) >> 1
ansl, ansr = inf, inf
if ql <= mid:ansl = self.query(ql, qr, idx * 2, l, mid)
if qr > mid: ansr = self.query(ql, qr, idx * 2 + 1, mid + 1, r)
return self.op(ansl, ansr)
class Solution:
def minimumVisitedCells(self, grid: List[List[int]]) -> int:
m, n = len(grid), len(grid[0])
treey = [SegmentTree(m) for _ in range(n)]
# treey [j] 是第 j 列的线段树
for i in range(m - 1, -1, -1):
treex = SegmentTree(n)
for j in range(n - 1, -1, -1):
if i == m - 1 and j == n - 1:
treex.update(j + 1, j + 1, 1)
treey[j].update(i + 1, i + 1, 1)
continue
g = grid[i][j]
if g == 0: continue
mnx = treex.query(j + 1 + 1, min(g + j, n - 1) + 1) if j < n - 1 else inf
mny = treey[j].query(i + 1 + 1, min(g + i, m - 1) + 1) if i < m - 1 else inf
mn = min(mnx, mny) + 1
treex.update(j + 1, j + 1, mn)
treey[j].update(i + 1, i + 1, mn)
res = treey[0].query(1, 1)
return res if res != inf else -1
最值查询朴素无更新线段树:
class Solution:
def subArrayRanges(self, nums: List[int]) -> int:
class SegmentTree:
def __init__(self, n, flag):
self.n = n
self.tree = [inf * flag] * (4 * n)
self.flag = flag
def op(self, a, b):
if self.flag == 1: return min(a, b)
elif self.flag == -1: return max(a, b)
def build(self, idx = 1, l = 1, r = None):
if not r: r = self.n
if l == r:
self.tree[idx] = nums[l - 1]
return
mid = (l + r) >> 1
self.build(idx * 2, l, mid)
self.build(idx * 2 + 1, mid + 1, r)
self.tree[idx] = self.op(self.tree[idx * 2], self.tree[idx * 2 + 1])
def query(self, ql, qr, idx = 1, l = 1, r = None):
if not r: r = self.n
if ql <= l and r <= qr:
return self.tree[idx]
ansl, ansr = inf * self.flag, inf * self.flag
mid = (l + r) >> 1
if ql <= mid: ansl = self.query(ql, qr, idx * 2, l, mid)
if qr > mid: ansr = self.query(ql, qr, idx * 2 + 1, mid + 1, r)
return self.op(ansl, ansr)
n = len(nums)
mxtr, mntr = SegmentTree(n, -1), SegmentTree(n, 1)
res = 0
mxtr.build()
mntr.build()
for i in range(n):
for j in range(i + 1, n):
res += mxtr.query(i + 1, j + 1) - mntr.query(i + 1, j + 1)
return res
排序 + 贪心 + lazy 线段树二分优化
class SegmentTree:
__slots__ = ['node', 'lazy', 'n', 'nums', 'op', 'ini', 'ops']
def __init__(self, nums, ops = 'sum'):
n = len(nums)
if ops == 'sum' or ops == 'bin':
op, ini = lambda a, b: a + b, 0
elif ops == 'max':
op, ini = lambda a, b: max(a, b), -inf
elif ops == 'min':
op, ini = lambda a, b: min(a, b), inf
self.nums = nums
self.op = op
self.ini = ini
self.ops = ops
self.node = [ini] * (4 * n)
self.lazy = [None] * (4 * n)
self.n = n
def build(self, idx = 1, l = 1, r = None):
if r is None: r = self.n
if l == r:
self.node[idx] = self.nums[l - 1]
return
mid = (l + r) >> 1
self.build(idx << 1, l, mid)
self.build((idx << 1) + 1, mid + 1, r)
self.node[idx] = self.op(self.node[idx << 1], self.node[(idx << 1) + 1])
def do(self, idx, dl, dr, val = None):
if self.ops == 'bin':
self.node[idx] = dr - dl + 1
self.lazy[idx] = True
else:
self.node[idx] = self.op(val, self.node[idx])
self.lazy[idx] = val
def pushdown(self, idx, pl, pr):
val = self.lazy[idx]
mid = (pl + pr) >> 1
self.do(idx << 1, pl, mid, val)
self.do((idx << 1) + 1, mid + 1, pr, val)
self.lazy[idx] = None
def update(self, ul, ur, val, idx = 1, l = 1, r = None):
if r is None:r = self.n
if ul <= l and r <= ur:
self.do(idx, l, r, val)
return
if self.lazy[idx]:
self.pushdown(idx, l, r)
mid = (l + r) >> 1
if ul <= mid: self.update(ul, ur, val, idx << 1, l, mid)
if ur > mid: self.update(ul, ur, val, (idx << 1) + 1, mid + 1, r)
self.node[idx] = self.op(self.node[idx << 1], self.node[(idx << 1) + 1])
def query(self, ql, qr, idx = 1, l = 1, r = None):
if r is None: r = self.n
if ql <= l and r <= qr:
return self.node[idx]
if self.lazy[idx]:
self.pushdown(idx, l, r)
mid = (l + r) >> 1
ansl, ansr = self.ini, self.ini
if ql <= mid: ansl = self.query(ql, qr, idx << 1, l, mid)
if qr > mid:ansr = self.query(ql, qr, (idx << 1) + 1, mid + 1, r)
return self.op(ansl, ansr)
def update_suffix(self, ul, ur, val, idx = 1, l = 1, r = None):
if r is None: r = self.n
siz = r - l + 1
if ul <= l and r <= ur and siz - self.node[idx] <= val:
ans = siz - self.node[idx]
self.do(idx, l, r)
return ans
mid = (l + r) >> 1
if self.lazy[idx]:
self.pushdown(idx, l, r)
ans = 0
if ur > mid: ans += self.update_suffix(ul, ur, val, (idx << 1) + 1, mid + 1, r)
val -= ans
if val and ul <= mid: ans += self.update_suffix(ul, ur, val, idx << 1, l, mid)
self.node[idx] = self.op(self.node[idx << 1], self.node[(idx << 1) + 1])
return ans
class Solution:
def findMinimumTime(self, nums: List[List[int]]) -> int:
nums.sort(key = lambda x: x[1])
n, m = len(nums), nums[-1][1]
tr = SegmentTree([0] * m, 'bin')
for l, r, c in nums:
c -= tr.query(l, r)
if c > 0:
tr.update_suffix(l, r, c)
return tr.query(0, m)
ST 表 / 可重复贡献问题¶
可重复贡献问题:指对于运算 \(opt\), 满足 $ x \space opt \space x = x$。例如区间最值问题,区间 GCD 问题。
ST 表思想基于倍增,不支持修改操作。
预处理:\(O(nlogn)\)
\(f(i, j)\) 表示 \([i, i + 2^j - 1]\) 区间的最值,则将其分为两半:
则:
初始化时:
对于 \(j\) 的上界需要满足 \(i + 2 ^ j - 1\) 能够取到 \(n - 1\),即 \(2 ^ j\) 能够取到 \(n\)。所以外层循环条件:
对于 \(i\) 的上界需要满足:
即:
例如,对于 \(f(4, 3) = opt(f(4, 2), f(8, 2))\)
lenj = math.ceil(math.__log(n, 2)) + 1
f = [[0] * lenj for _ in range(n)]
for i in range(n):
f[i][0] = a[i]
for j in range(1, lenj):
# i + 2 ^ j < n + 1
for i in range(n + 1 - (1 << j)):
f[i][j] = opt(f[i][j - 1], f[i + (1 << (j - 1))][j - 1])
单次询问:\(O(1)\)
例如, 对于 \(qry(5, 10)\),区间长度为 \(6\),\(int(log_2^6) = 2\),只需要 \(k = 2^2\) 的两个区间一定可以覆盖整个区间。
即 \(opt(5, 10) = opt(opt(5, 8), opt(7, 10))\),即分别是 $(l, l + 2^k-1) $ 和 \((r - 2^k+1,r)\)
可以提取预处理一个对数数组。例如 \(int(log(7)) = int(log(3)) + 1 = int(log(1)) + 1 + 1\)
模板
def opt(a, b):
return max(a, b)
lenj = math.ceil(math.__log(n, 2)) + 1
f = [[0] * lenj for _ in range(n)]
log = [0] * (n + 1)
for i in range(2, n + 1):
log[i] = log[i >> 1] + 1
for i in range(n): f[i][0] = a[i]
for j in range(1, lenj):
for i in range(n + 1 - (1 << j)):
f[i][j] = opt(f[i][j - 1], f[i + (1 << (j - 1))][j - 1])
def qry(l, r):
k = log[r - l + 1]
return opt(f[l][k], f[r - (1 << k) + 1][k])
类模板
class ST:
def opt(self, a, b):
return a & b
def __init__(self, nums):
n = len(nums)
log = [0] * (n + 1)
for i in range(2, n + 1):
log[i] = log[i >> 1] + 1
lenj = ceil(math.__log(n, 2)) + 1
f = [[0] * lenj for _ in range(n)]
for i in range(n): f[i][0] = nums[i]
for j in range(1, lenj):
for i in range(n + 1 - (1 << j)):
f[i][j] = self.opt(f[i][j - 1], f[i + (1 << (j - 1))][j - 1])
self.f = f
self.log = log
def qry(self, L, R):
k = self.log[R - L + 1]
return self.opt(self.f[L][k], self.f[R - (1 << k) + 1][k])
def subArrayRanges(self, nums: List[int]) -> int:
# f [i][j] 表示 [i, i + 2^j - 1] 的最值
n = len(nums)
lenj = ceil(math.__log(n, 2)) + 1
log = [0] * (n + 1)
for i in range(2, n + 1):
log[i] = log[i // 2] + 1
class ST:
def __init__(self, n, flag):
self.flag = flag
f = [[inf * flag] * lenj for _ in range(n)]
for i in range(n):
f[i][0] = nums[i]
for j in range(1, lenj):
for i in range(n + 1 - (1 << j)):
f[i][j] = self.op(f[i][j - 1], f[i + (1 << (j - 1))][j - 1])
self.f = f
def op(self, a, b):
if self.flag == 1: return min(a, b)
return max(a, b)
def query(self, l, r):
k = log[(r - l + 1)]
return self.op(self.f[l][k], self.f[r - (1 << k) + 1][k])
n = len(nums)
mxtr, mntr = ST(n, -1), ST(n, 1)
res = 0
for i in range(n):
for j in range(i + 1, n):
res += mxtr.query(i, j) - mntr.query(i, j)
return res
动态规划¶
入门 DP¶
递归写法
class Solution:
def climbStairs(self, n: int) -> int:
def dfs(n):
return 1 if n == 1 or n == 0 else dfs(n - 1) + dfs(n - 2)
return dfs(n)
记忆化搜索:lru_cache
from functools import *
class Solution:
def climbStairs(self, n: int) -> int:
@lru_cache(maxsize = None)
def dfs(n):
return 1 if n == 1 or n == 0 else dfs(n - 1) + dfs(n - 2)
return dfs(n)
s = Solution()
print(s.climbStairs(3)) # 3
选或者不选 \(f[i][1]\) 表示考虑完前 \(i\) 个屋子, 且第 \(i\) 个屋子选的最大金额 \(f[i][0]\) 表示考虑完前 \(i\) 个屋子,且第 \(i\) 个屋子不选的最大金额 不选,\(f[i][0] = \text{max(f[i - 1][1], f[i - 1][0])}\) 选, \(f[i][1] = \text{f[i - 1][0] + nums[i - 1]}\)
class Solution:
def rob(self, nums: List[int]) -> int:
n = len(nums)
f = [[0] * 2 for _ in range(n + 1)]
for i in range(1, n + 1):
f[i][0] = max(f[i - 1][1], f[i - 1][0])
f[i][1] = f[i - 1][0] + nums[i - 1]
return max(f[n][1], f[n][0])
线性 DP¶
最长上升子序列:
\(O(n^2)\) 做法,\(f[i]\) 表示以 \(nums[i]\) 结尾的所有上升子序列中最长的长度。
\(O(nlogn)\) 做法,\(f[i]\) 表示长度为 \(i\) 的所有上升子序列中,子序列末尾的最小值。
正序遍历 \(nums\) 中每一个数 \(x\), 二分找出 \(x\) 在 \(f\) 中的插入位置(恰好大于 $ x$ 的位置)。
# f [i] 表示长度为 i 的子序列的末尾元素的最小值
f = []
# 找到恰好大于 x 的位置
def check(x, mid):
return f[mid] >= x
for x in nums:
lo, hi = 0, len(f)
while lo < hi:
mid = (lo + hi) >> 1
if check(x, mid):
hi = mid
else:
lo = mid + 1
if lo >= len(f):
f.append(x)
else:
f[lo] = x
最长公共子序列:
$ f[i][j] 表示从s[0: i] 和 s2[0: j] 中的最长公共子序列$
时间复杂度:\(O(mn)\)
可以证明:\(f(i-1, j -1)+ 1 \ge \max(f(i-1,j), ~f(i,~j-1))\)
# f [n][m]
f = [[0] * (m + 1) for _ in range(n + 1)]
for i in range(1, n + 1):
for j in range(1, m + 1):
if s1[i - 1] == s2[j - 1]:
f[i][j] = f[i - 1][j - 1] + 1
else:
f[i][j] = max(f[i - 1][j], f[i][j - 1])
编辑距离:
def getEditDist(s1, s2):
m, n = len(s1), len(s2)
f = [[inf] * (n + 1) for _ in range(m + 1)]
for i in range(1, m + 1): f[i][0] = i
for i in range(1, n + 1): f[0][i] = i
f[0][0] = 0
for i in range(1, m + 1):
for j in range(1, n + 1):
a = f[i - 1][j] + 1
b = f[i][j - 1] + 1
c = f[i - 1][j - 1] + (1 if s1[i - 1] != s2[j - 1] else 0)
f[i][j] = min(a, b, c)
return f[m][n]
背包 DP¶
\(N\) 个物品,价值为 \(v_i\),重量为 \(w_i\),背包容量为 \(W\)。挑选物品不超过背包容量下,总价值最大是多少。
- 0 - 1 背包:每个物品用 0 或者 1 次。
- 完全背包:每个物品可以用 0 到 \(+\infty\) 次。
- 多重背包:每个物品最多 \(s_i\) 次。
- 分组背包:物品分为若干组,每一组里面选 0 或者 1 次。
01 背包¶
状态表示:\(f(i, j)\)
- 集合:
- 所有拿物品的选法
- 条件:1. 只从前 \(i\) 个物品中选;2. 总重量 \(\le j\)
- 表示的属性:所有选法的总价值的最大值(\(\max\))
- 最终求解的问题 \(f(N, W)\) 。
状态计算:
集合的划分问题:如何将集合 \(f(i,j)\) 划分成更小的可计算子集。
# f [i][j] 表示用前 i 个物品,在总重量不超过 j 的情况下,所有物品选法构成的集合中,总价值的最大值
# 考虑 f [i][j] 对应集合的完备划分: 选 i ,其子集的最大值是 f [i - 1][j - w[i]] + v [i],需要在 j - w [i] >= 0 满足
# 不选 i, 其子集的最大值是 f [i - 1][j]。一定可以满足
f = [[0] * (W + 1) for _ in range(N + 1)]
for i in range(1, N + 1):
for j in range(1, W + 1):
if j - w[i] >= 0: # 可以选 i,也可以不选 i
f[i][j] = max(f[i - 1][j], f[i - 1][j - w[i]] + v[i])
else:
f[i][j] = f[i - 1][j] # 只能不选 i
print(f[N][W])
滚动数组优化为一维:逆序遍历
由于 \(f(i, j)\) 只和 \(f(i-1, j)\) 有关。如果使用滚动数组 \(f(j)\) 优化,去掉第一维度,在同一个 \(i\) 下,如果正序遍历 \(j\) ,在恰好更新 \(f(j)\) 前所有 \(f(j'< j)\) 存放的是新值 \(f(i,j')\),所有 \(f(j''\geq j)\) 存放的是老值 \(f(i-1,j'')\)。
由于我们希望能够得到 \(f(i-1, j - w[i])\) ,所以我们必须逆序遍历 \(j\) :在恰好更新 \(f(j)\) 前,\(f(j'\leq j)\) 都是老值,表示 \(f(i-1, j')\)。
所以 \(j\) 的枚举为 \(range(W, w[i]-1, -1)\)
f = [0] * (W + 1)
for i in range(1, N + 1):
for j in range(W, w[i] - 1, -1):
f[j] = max(f[j], f[j - w[i]] + v[i])
# 此时 f [j] 就代表 f [i - 1][j], f [j - w[i] 代表 f [i - 1][j - w[i]]
return f[W]
import sys
input = lambda: sys.stdin.readline().strip()
W, N = map(int, input().split())
w, v = [0] * (N + 1), [0] * (N + 1)
for i in range(1, N + 1):
w[i], v[i] = map(int, input().split())
# f [i][j] 表示在前 i 个物品中,重量不超过 j 的所有选择方案的集合中,获得的最大价值
# 最终要求 f [N][W]
f = [[0] * (W + 1) for _ in range(N + 1)]
for i in range(1, N + 1):
for j in range(1, W + 1):
if j - w[i] >= 0: # 可以选 i,也可以不选 i
f[i][j] = max(f[i - 1][j], f[i - 1][j - w[i]] + v[i])
else:
f[i][j] = f[i - 1][j] # 只能不选 i
print(f[N][W])
增加限制条件:不超过 k 次使用折扣券。注意,k 的遍历方向也是逆序。
def maxTastiness(self, price: List[int], tastiness: List[int], maxAmount: int, maxCoupons: int) -> int:
# f [i][j] [k] 从前 i 个物品,不超过容量 j 的情况下,不超过 k 张券的最大价值
# f [i][j] [k] = max(f [i - 1][j] [k], f [i - 1][j - w] [k] + v, f [i - 1][j - w // 2] [k - 1] + v)
f = [[0] * (maxCoupons + 1) for _ in range(maxAmount + 1)]
for w, v in zip(price, tastiness):
for j in range(maxAmount, w // 2 - 1, -1):
for k in range(maxCoupons, -1, -1):
if j - w >= 0:
f[j][k] = max(f[j][k], f[j - w][k] + v)
if k >= 1:
f[j][k] = max(f[j][k], f[j - w // 2][k - 1] + v)
return f[maxAmount][maxCoupons]
恰好装满型 01 背包:
2915. 和为目标值的最长子序列的长度 - 力扣(LeetCode)
第二个转移的条件是:
可以通过初始值修改,将不合法的 \(f[i][j]\) 置为 \(-\infty\),合法的 \(f[i][j] \ge 0\)。则初始值 \(f[0][0] =0\) 得到二维版本:
class Solution:
def lengthOfLongestSubsequence(self, nums: List[int], t: int) -> int:
a = nums
n = len(nums)
dp = [[-inf] * (t + 1) for _ in range(n + 1)]
for i in range(n + 1): dp[i][0] = 0
for i in range(1, n + 1):
w = a[i - 1]
for j in range(t + 1):
if j >= w:
dp[i][j] = max(dp[i - 1][j], dp[i - 1][j - w] + 1)
else:
dp[i][j] = dp[i - 1][j]
if dp[n][t] >= 0:
return dp[n][t]
else:
return -1
优化:\(j\) 的上界可以优化为 \(\min(\text{重量前缀}, target)\)
def lengthOfLongestSubsequence(self, nums: List[int], target: int) -> int:
f = [0] + [-inf] * target
pre = 0
for w in nums:
pre += w
for j in range(min(pre, target), w - 1, -1):
f[j] = max(f[j], f[j - w] + 1)
return f[target] if f[target] >= 0 else -1
定义 \(mx =\max(B)\),由于所有项一定在 \([0, mx]\) 内开始执行,按照 \(B\) 从小到大排序遍历。
定义 \(d [i][j]\) 为考察前 \(i\) 项,恰好花 \(j\) 时间内获得最多的钱。
如果第 \(i\) 项执行,\(d [i][j] = max(d [i - 1][j - b] + c, d [-1i][j])\)。注意需要判断 \(j\) 的界限是在 \([a, b]\) 内。
d 的维度:\((n + 1) \cdot (mx + 1)\)
时间复杂度:\(O(n\times \max(B))\)
import sys
input = lambda: sys.stdin.readline().strip()
n = int(input())
nums = []
for _ in range(n):
a, b, c = map(int, input().split())
nums.append((a, b, c))
# 按照 B 排序
nums.sort(key = lambda x: x[1])
# 所有项一定在 [0, mx] 内开始执行
mx = nums[-1][1]
d = [[0] * (mx + 1) for _ in range(n + 1)]
for i in range(1, n + 1):
a, b, c = nums[i - 1]
for j in range(1, mx + 1):
if a <= j <= b: # “选”
d[i][j] = max(d[i - 1][j - a] + c, d[i - 1][j])
else:
d[i][j] = d[i - 1][j]
print(max(d[n]))
滚动数组优化空间
import sys
input = lambda: sys.stdin.readline().strip()
n = int(input())
nums = []
for _ in range(n):
a, b, c = map(int, input().split())
nums.append((a, b, c))
nums.sort(key=lambda x: x[1])
mx = nums[-1][1]
d = [0] * (mx + 1)
for i in range(1, n + 1):
a, b, c = nums[i - 1]
for j in range(b, a - 1, -1):
d[j] = max(d[j], d[j - a] + c)
print(max(d))
分割等和子集问题:
给定一组数,判断是否可以分割成两个等和子集。
def canPartition(self, nums: List[int]) -> bool:
s, n = sum(nums), len(nums)
if s & 1: return False
# f [i][j] 表示从前 i 个数中,分割成和为 j 是否可能。
# f [n][s // 2]
f = [1] + [0] * (s // 2)
for x in nums:
for j in range(s // 2, x - 1, -1):
f[j] |= f[j - x]
return f[s // 2] == 1
完全背包¶
**状态表示:\(f(i, j)\) ** 同 0 - 1 背包。
状态计算: 对于集合的划分,按照第 \(i\) 个物品选几个($0, 1, ... , $) 划分。
朴素做法:\(O(N\cdot W^2)\)
for i in range(1, N + 1):
for j in range(W + 1):
for k in range(j // w[i] + 1):
f[i][j] = max(f[i][j], f[i - 1][j - k * w[i]] + k * v[i])
return f[N][W]
冗余优化:\(O(N \cdot W)\)
可以发现后面一坨的最大值等价于 \(f(i, j - w)\)
所以 \(f(i, j) = \max \big(f(i - 1, j), f(i, j - w[i]) + v[i] \big)\),
for i in range(1, N + 1):
for j in range(W + 1):
f[i][j] = f[i - 1][j]
if j - w[i] >= 0:
f[i][j] = max(f[i][j], f[i][j - w[i]] + v[i])
# f [i][j - w[i]] 包含了 f [i - 1][j - k * w[i]] 的部分 (k >= 1)
return f[N][W]
优化为一维
完全背包求组合方案数
def change(self, amount: int, coins: List[int]) -> int:
# f [i][j] 表示 前 i 个硬币凑出 j 的方案数
# 状态表示:从前 i 个硬币中组合出 j 的所有方案的集合
# 属性:个数
# 转移:对集合进行划分。
# f [i][j] = f [i - 1][j] + f [i][j - c]
n = len(coins)
f = [[0] * (amount + 1) for _ in range(n + 1)]
# f [i][0] = 1
for i in range(n + 1): f[i][0] = 1
for i in range(1, n + 1):
for j in range(1, amount + 1):
c = coins[i - 1]
f[i][j] = f[i - 1][j]
if j - c >= 0:
f[i][j] += f[i][j - c]
return f[n][amount]
优化成一维:
def change(self, amount: int, coins: List[int]) -> int:
# f [i][j] = f [i - 1][j] + f [i][j - c]
n = len(coins)
# 从前 i 个中构成 j 的方案数
f = [0] * (amount + 1)
f[0] = 1
for c in coins:
for j in range(c, amount + 1):
f[j] += f[j - c]
return f[amount]
求排列方案数:伪完全背包
\(f(i)\) 表示找出总和为 \(i\) 的排列方案数,\(f(i)=\sum f(i-w)\)
def combinationSum4(self, nums: List[int], target: int) -> int:
n = len(nums)
f = [0] * (target + 1)
f[0] = 1
for i in range(1, target + 1):
for j in range(n):
w = nums[j]
if i - w >= 0:
f[i] += f[i - w]
return f[target]
1449. 数位成本和为目标值的最大数字 - 力扣(LeetCode)
每个数字有一个重量,可以无限选,问恰好重量为 target 的最大数字。(类似题目:长度最大的字典序最小串等)
先用完全背包模型求出最长长度,然后贪心的从 9~1 倒序遍历逆序构造。构造的条件是 \(f[target-w]+1 = f[target]\),即通过长度判断是否可以转移。
def largestNumber(self, cost: List[int], target: int) -> str:
# 先求出能构成的最长数串
# 每个物品重量 W, 价值为 1,
# f [i][j] 表示从前 i 个物品中选法中,能够构成的最大价值
# f [i][j] = max(f [i][j], f [i][j - w])
f = [0] + [-inf] * target
for w in cost:
for j in range(w, target + 1):
f[j] = max(f[j], f[j - w] + 1)
mxl = f[target]
if mxl <= 0: return '0'
res = ''
# 贪心的构造,从高位到低位尽可能构造
for x in range(9, -1, -1):
w = cost[x - 1]
while target - w >= 0 and f[target] == f[target - w] + 1:
res += str(x)
target -= w
return res
多重背包¶
在完全背包的基础上,增加每个物品最多选择选择的次数限制 \(s[i]\)。
暴力做法:\(O(N \cdot W ^2)\)
for i in range(1, n + 1):
for j in range(W + 1):
for k in range(min(c[i] + 1, j // w[i] + 1)):
f[i][j] = max(f[i][j], f[i - 1][j - k * w[i]] + k * v[i])
可以发现无法借助完全背包的方法进行优化。
二进制拆分重量为新的包裹:\(O(N \cdot W\cdot log(\sum W) \cdot )\)
思路:将每一件最多能选 \(c\) 个的物品拆分成若干个包裹,大小分别是 $ 1, 2, \cdots, 2^k, c' $ ,例如 \(c=500\), 拆分成 $1, 2, \cdots, 128,245 $,可以证明这些数字可以枚举出 $ 0 \sim 500$ 之间的所有数。将这些包裹看出是新的物品,有其对应的新的 重量 和 价值。
可以估算,总包裹的个数不超过 $ N \cdot log_2{(\sum W)}$ 。
W, V = [], []
for _ in range(N):
ow, ov, oc = map(int, input().split())
k = 1
while oc >= k: # 例如 10, 拆分成 1,2,4 和 3
W, V = W + [ow * k], V + [ov * k]
oc -= k
k <<= 1
if oc > 0:
W, V = W + [ow * oc], V + [ov * oc]
f = [0] * (mxW + 1)
for w, v in zip(W, V):
for j in range(mxW, w - 1, -1):
f[j] = max(f[j], f[j - w] + v)
print(f[mxW])
分组背包¶
有 \(N\) 组物品,容量为 \(mxW\) 的背包,每组物品最多只能选其中一个。 例如,水果(苹果,香蕉,橘子)只能选一个或者不选。
\(f(i, j)\) 从前 \(i\) 组选,总重量不超过 \(j\) 的所有选法方案的价值和的最大值。
状态转移:第 \(i\) 组物品一个都不选 \(f(i-1,j)\),第 \(i\) 组物品选第 \(k\) 个 \(f(i-1,j-w[i][k]) + v[i][k]\)
W, V = [[0] for _ in range(N + 1)], [[0] for _ in range(N + 1)]
for i in range(1, N + 1):
K = int(input())
for k in range(K):
w, v = map(int, input().split())
W[i], V[i] = W[i] + [w], V[i] + [v]
f = [0] * (mxW + 1)
for i in range(1, N + 1):
for j in range(mxW, -1, -1):
for k in range(len(W[i])):
if j - W[i][k] >= 0:
f[j] = max(f[j], f[j - W[i][k]] + V[i][k])
网格图 DP¶
最大 / 最小单趟路径和
LCR 166. 珠宝的最高价值 - 力扣(LeetCode)
求从左上角 \((0,0)\) 到右下角 \((m-1,~n-1)\) 能够获得的最大价值和。转移方程:\(f(i,j)=g(i,j)+\max(f(i-1,j),~f(i,j-1))\)。
def jewelleryValue(self, grid: List[List[int]]) -> int:
# f(i, j) 表示到达 (i, j) 网格的最高价值
m, n = len(grid), len(grid[0])
f = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(1, m + 1):
for j in range(1, n + 1):
x = grid[i - 1][j - 1]
f[i][j] = x + max(f[i - 1][j], f[i][j - 1])
return f[m][n]
巧妙设计初值:\(f(1,0)=f(0,1)=0,其余为 inf\)
def minPathSum(self, grid: List[List[int]]) -> int:
m, n = len(grid), len(grid[0])
f = [[inf] * (n + 1) for _ in range(m + 1)]
f[0][1] = f[1][0] = 0
for i in range(1, m + 1):
for j in range(1, n + 1):
x = grid[i - 1][j - 1]
f[i][j] = x + min(f[i - 1][j], f[i][j - 1])
return f[m][n]
将三角形转换成网格图:约束 \(j \in [1,i]\)。
def minimumTotal(self, triangle: List[List[int]]) -> int:
m = n = len(triangle)
f = [[inf] * (n + 1) for _ in range(m + 1)]
f[0][1] = 0
for i in range(1, m + 1):
for j in range(1, i + 1):
x = triangle[i - 1][j - 1]
f[i][j] = x + min(f[i - 1][j], f[i - 1][j - 1])
return min(f[m][1: ])
1289. 下降路径最小和 II - 力扣(LeetCode)
求从上一层列下标不同位置处转移的路径最小和 。维护上一层的最小值、次小值。时间复杂度从 \(O(n^3)\) 优化到 \(O(n^2)\)。
def minFallingPathSum(self, grid: List[List[int]]) -> int:
m = n = len(grid)
if m == 1: return min(grid[0])
f = [[inf] * (n + 1) for _ in range(m + 1)]
mn = mn_2 = (0, -1)
for i in range(1, m + 1):
pmn, pmn_2 = mn, mn_2
mn = mn_2 = (inf, -1)
for j in range(1, n + 1):
x = grid[i - 1][j - 1]
y = f[i][j] = x + pmn[0] if pmn[1] != j else x + pmn_2[0]
if y < mn[0]: mn_2, mn = mn, (y, j)
elif y == mn[0]: mn_2 = (y, j)
elif y < mn_2[0]: mn_2 = (y, j)
return min(f[m])
求路径方案数
巧妙设计初值,避免特殊边界讨论。
def uniquePathsWithObstacles(self, grid: List[List[int]]) -> int:
m, n = len(grid), len(grid[0])
f = [[0] * (n + 1) for _ in range(m + 1)]
f[0][1] = 1
for i in range(1, m + 1):
for j in range(1, n + 1):
x = grid[i - 1][j - 1]
f[i][j] = 0 if x else f[i - 1][j] + f[i][j - 1]
return f[m][n]
多路线问题¶
双起点 + 双路线问题
\(f(t,i,j)\) 表示共同来到第 \(t\) 行,分别在 \((t,i)\) 和 \((t,j)\)。为了表示出两个初始位置分别在 \((0,0)\) 和 \((0, n- 1)\),所有初始值设置为 \(-inf\),同时将 $ f[0][0][n + 1] = 0$。
def cherryPickup(self, grid: List[List[int]]) -> int:
m, n = len(grid), len(grid[0])
# f(t, i, j)
res = 0
f = [[[-inf] * (n + 2) for _ in range(n + 2)] for _ in range(m + 1)]
f[0][0][n + 1] = 0
for t in range(1, m + 1):
for i in range(1, n + 1):
x = grid[t - 1][i - 1]
for j in range(1, n + 1):
y = 0 if i == j else grid[t - 1][j - 1]
mx = max(f[t - 1][i][j], f[t - 1][i - 1][j], f[t - 1][i + 1][j],\
f[t - 1][i][j - 1], f[t - 1][i - 1][j - 1], f[t - 1][i + 1][j - 1],\
f[t - 1][i][j + 1], f[t - 1][i - 1][j + 1], f[t - 1][i + 1][j + 1])
f[t][i][j] = x + y + mx
if t == m: res = max(res, f[t][i][j])
return res
同起点 + 双路线问题 $ f(t, i, j) $ 表示 \(A, B\) 共同走了 \(t\) 步,分别到达 $(i, t - i), ~ (j, t - j) $。
为了区分没有采摘、不可达,初始值设定为 \(-inf\);当 \(i =j\) 时,表示当前 \(A,B\) 在同一个位置,只能加一份价值 \(grid[i][t-i]\),否则当前价值为 \(grid[j][t - j] + grid[i][t-i]\) 。转移:\(\max \{f(t-1,i,j),~f(t-1,i-1,j), ~f(t-1,i,j-1),~f(t-1,i-1,j-1)\}\)。对于不可达的某个具体位置,检测到价值中有 \(-1\) 时直接返回。
同时需要注意 \(t\) 会对 \(i\) 和 \(j\) 的取值范围约束。最终答案:\(f(2n-2,~n-1,~n-1)\)
def cherryPickup(self, grid: List[List[int]]) -> int:
# f(t, i, j) 表示走了 t 步, 分别到达 (i, t - i) 和 (j, t - j) 获得的最大得分
n = len(grid)
f = [[[-inf] * (n) for _ in range(n)] for _ in range(2 * n - 1)]
f[0][0][0] = grid[0][0]
for t in range(1, 2 * n - 1):
for i in range(max(0, t - n + 1), min(n, t + 1)):
x = grid[i][t - i]
if x < 0: continue
for j in range(max(0, t - n + 1), min(n, t + 1)):
y = grid[j][t - j] if i != j else 0
if y < 0: continue
mx = max(f[t - 1][i][j], f[t - 1][i - 1][j], f[t - 1][i][j - 1], f[t - 1][i - 1][j - 1])
f[t][i][j] = mx + x + y
res = f[2 * n - 2][n - 1][n - 1]
return res if res != -inf else 0
树形 DP¶
题目详情 - Problem 4E. 最大社交深度和 - HydroOJ
1,指定某个节点为根节点。
2,第一次搜索完成预处理(如子树大小等),同时得到该节点的解。
3,第二次搜索进行换根的动态规划,由已知解的节点推出相连节点的解。
def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
g = [[] for _ in range(n)]
dep = [0] * n
siz = [1] * n
res = [0] * n
for u, v in edges:
g[u].append(v)
g[v].append(u)
def dfs1(u, fa): # 预处理深度
dep[u] = dep[fa] + 1 if fa != -1 else 0
for v in g[u]:
if v != fa:
dfs1(v, u)
siz[u] += siz[v]
def dfs2(u, fa):
for v in g[u]:
if v != fa:
res[v] = res[u] - siz[v] + (n - siz[v])
dfs2(v, u)
dfs1(0, -1)
res[0] = sum(dep)
dfs2(0, -1)
return res
\(u\) 剔除 v 子树部分下降 1,深度和增加 \(n - siz[v]\)
\(v\) 子树部分上升 1,深度和减少 \(siz[v]\)
则状态转移方程 \(res[v] = res[u] - siz[v] + (n - siz[v])\)
区间 DP¶
石子合并
s = [0] * (n + 1)
f = [[0] * n for _ in range(n)]
for i in range(n):
s[i + 1] = s[i] + nums[i]
for l in range(2, n + 1):
for i in range(n + 1 - l):
j = i + l - 1
f[i][j] = inf
for k in range(i, j):
f[i][j] = min(f[i][j], f[i][k] + f[k + 1][j] + s[j + 1] - s[i])
长度统一处理:对于 length = 1, \(f[i][i-1]\) 是 0,\(f[j + 1][j]\) 也是 0。等价于没有
对于 length = 2, \(f[i][i+1] 其中一项 [i][i-1] + f[i+1][i+1]+...\) ,因此和长度大于等于 3 统一。
def maxCoins(self, nums: List[int]) -> int:
nums = [1] + nums + [1]
n = len(nums)
f = [[0] * n for _ in range(n)]
for l in range(1, n - 1):
for i in range(1, n - l):
j = i + l - 1
for k in range(i, j + 1):
f[i][j] = max(f[i][j], f[i][k - 1] + f[k + 1][j] + nums[k] * nums[i - 1] * nums[j + 1])
return f[1][n - 2]
$ f[a, b] 表示从[a : b] 一定能获胜的最小金额$。一定制胜的策略是当前位置一定答错,同时选择左右两边较大区间
复杂度:\(O(n^3)\)
def getMoneyAmount(self, n: int) -> int:
# f [a, b] 表示从 [a : b] 一定能获胜的最小金额
# 最多取到 f [n + 1][n]
f = [[0] * (n + 1) for _ in range(n + 2)]
for l in range(2, n + 1):
for i in range(1, n + 2 - l):
j = i + l - 1
f[i][j] = inf
for k in range(i, j + 1):
f[i][j] = min(f[i][j], k + max(f[i][k - 1], f[k + 1][j]))
return f[1][n]
1039. 多边形三角剖分的最低得分 - 力扣(LeetCode)
def minScoreTriangulation(self, values: List[int]) -> int:
# f [i: j] 表示从 [i: j] 的最小得分
# f [0: n - 1]
n = len(values)
f = [[0] * (n + 1) for _ in range(n + 1)]
for l in range(3, n + 1):
for i in range(n + 1 - l):
j = i + l - 1
f[i][j] = inf
for k in range(i + 1, j):
f[i][j] = min(f[i][j], f[i][k] + f[k][j] + values[i] * values[k] * values[j])
return f[0][n - 1]
95. 不同的二叉搜索树 II - 力扣(LeetCode)
卡特兰数 + 区间 dp,\(f[i, j]\) 表示从 \(i,i+1,~\cdots~, j\) 序列中构成的所有二叉搜索树的根节点(对应的列表)。
最终问题:\(f(1,n)\),对于每个区间,枚举中间节点 \(k \in [i,j]\),分别从左右子树对应的列表中(\(f(i,k-1)\) 和 \(f(k+1,j)\)),利用乘法原理进行构造。
def generateTrees(self, n: int) -> List[Optional[TreeNode]]:
# f [i, j] 表示用 1 .. j 构建的二叉搜索树的所有根节点列表
# 枚举树根节点 k in range(i, j + 1)
# f [i, k - 1] 为所有左子树可能的根节点列表
# f [k + 1, j] 为所有右子树可能的根节点列表
f = [[[None] for _ in range(n + 2)] for _ in range(n + 2)]
for l in range(1, n + 1):
for i in range(1, n + 2 - l):
j = i + l - 1
f[i][j] = []
for k in range(i, j + 1):
for left in f[i][k - 1]:
for right in f[k + 1][j]:
f[i][j].append(TreeNode(k, left, right))
return f[1][n]
最长回文子序列¶
求最长回文子序列长度问题
\(f[i: j]~ 表示s[i] \sim s[j] 中的最长回文子序列的长度\)
def mx_pal_subseq(self, s: str) -> int:
# f(i, j) 表示 s [i] ~ s [j] 的最长回文子序列的长度
n = len(s)
f = [[0] * n for _ in range(n)]
for i in range(n): f[i][i] = 1
for l in range(2, n + 1):
# i + l < n + 1
for i in range(n + 1 - l):
j = i + l - 1
if s[i] == s[j]:
f[i][j] = f[i + 1][j - 1] + 2
else:
f[i][j] = max(f[i + 1][j], f[i][j - 1])
return f[0][n - 1]
推论:对于长度为 \(n\) 的字符串,其最长回文子序列长度为 \(L\), 则最少添加 \(n - L\) 个字符可以使原串变成回文串。
1312. 让字符串成为回文串的最少插入次数 - 力扣(LeetCode)
[P1435 IOI2000] 回文字串 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
def minInsertions(self, s: str) -> int:
# f [i: j] 表示从 s [i] ~ s [j] 的 最长回文子序列
n = len(s)
f = [[0] * (n + 1) for _ in range(n + 1)]
for i in range(n):
f[i][i] = 1
for l in range(2, n + 1):
for i in range(n + 1 - l):
j = i + l - 1
if s[i] == s[j]:
f[i][j] = f[i + 1][j - 1] + 2
else:
f[i][j] = max(f[i + 1][j], f[i][j - 1])
return n - f[0][n - 1]
2002. 两个回文子序列长度的最大乘积 - 力扣(LeetCode)
二进制枚举,将集合划分成互不相交的两部分。求各自最长回文子序列长度的乘积。\(O(2^n\cdot n^2)\)。
def maxProduct(self, s1: str) -> int:
n, res = len(s1), 0
s = (1 << n) - 2
sub = s
def mx_pal_subseq(ss):
m = len(ss)
f = [[0] * m for _ in range(m)]
for i in range(m): f[i][i] = 1
for l in range(2, m + 1):
for i in range(m + 1 - l):
j = i + l - 1
if ss[i] == ss[j]:
f[i][j] = f[i + 1][j - 1] + 2
else:
f[i][j] = max(f[i + 1][j], f[i][j - 1])
return f[0][m - 1]
while sub:
s2 = ''.join([s1[j] for j in range(n) if (sub >> j) & 1])
s3 = ''.join([s1[j] for j in range(n) if (sub >> j) & 1 == 0])
cur = mx_pal_subseq(s2) * mx_pal_subseq(s3)
if cur > res: res = cur
sub = (sub - 1) & s
return res
最长回文子串¶
def longestPalindrome(self, s: str) -> str:
# 定义 f [i][j] 表示从 s [i] ~ s [j] 是否是回文字符串
left = right = 0
n = len(s)
f = [[True] * (n + 1) for _ in range(n + 1)]
for l in range(2, n + 1):
for i in range(n + 1 - l):
j = i + l - 1
f[i][j] = s[i] == s[j] and f[i + 1][j - 1]
if f[i][j]:
left, right = i, j
return s[left: right + 1]
数位 DP¶
模板 1:统计各位数字出现次数
统计在 \([a, b]\) 区间各个数字出现的次数。
需要实现 \(count(n, x)\) 函数统计 $[1, n] $ 区间中数字 \(x\) 出现的次数
def count(n, x):
# 在 1 ~ n 中 x 数字出现的次数
# 上界 abcdefg
# yyyizzz , 考虑 i 位上 x 的出现次数
#
# 1.1 如果 x 不为 0 yyy 为 000 ~ abc - 1, zzz 为 000 ~ 999
# 1.2x 为 0,yyy 为 001 ~ abc - 1, zzz 为 000 ~ 999
#
# 2. yyy 为 abc,
# 2.1 d < x 时,0
# 2.2 d = x 时,zzz 为 000 ~ efg
# 2.3 d > x 时,zzz 为 000 ~ 999
s = str(n)
res = 0
n = len(s)
for i in range(n):
pre = 0 if i == 0 else int(s[:i])
suf = s[i + 1:]
if x == 0: res += (pre - 1) * pow(10, len(suf))
else: res += pre * pow(10, len(suf))
d = int(s[i])
if d == x: res += (int(suf) if suf else 0) + 1
elif d > x: res += pow(10, len(suf))
return res
def get(a, b):
for i in range(10):
print(count(b, i) - count(a - 1, i), end = ' ')
print()
简化版:
def count(n, x): # 统计 1 ~ n 中 数字 x 的出现次数
res = 0
s = str(n)
m = len(s)
for i in range(m):
pre = 0 if i == 0 else int(s[: i])
d = int(s[i])
sufs = s[i + 1: ]
if x == 0: pre -= 1
if d > x: pre += 1
if d == x: res += (int(sufs) if sufs else 0) + 1
res += pre * pow(10, len(sufs))
return res
def get(a, b):
for i in range(10):
print(count(b, i) - count(a - 1, i), end = ' ')
print()
def count(n, x): # 统计 1 ~ n 中 数字 x 的出现次数
res = 0
s = str(n)
m = len(s)
for i in range(m):
pre = 0 if i == 0 else int(s[: i])
d = int(s[i])
sufs = s[i + 1: ]
if x == 0: pre -= 1
if d > x: pre += 1
if d == x: res += (int(sufs) if sufs else 0) + 1
res += pre * pow(10, len(sufs))
return res
return count(n, 1)
模板 2:带限制数位 dp 统计问题
通用模板 v1.0:统计 \([1, ~n]\) 区间中,符合限制条件的数字个数。
\(f(i,~mask,~is\_limit,~is\_num)\) 表示 前导数字集合为 \(mask\),从第 \(i\) 位开始往后填,能满足限制条件的数字个数。
其中, \(is\_limit\) 表示前导是否恰好全都取到上界。为 \(True\) 时,\(i\) 的上界 \(hi = int(s[i])\) 否则 为 9;
\(is\_num\) 表示前导是否有数字。为 \(True\) 时,\(i\) 的下界从 \(0\) 开始;否则可以继续不填数字,或者下界从 \(1\) 开始。
@lru_cache(maxsize = None)
def f(i: int, mask: int, is_limit: bool, is_num: bool):
if i == m:
if is_num: return 1
return 0
res = 0
lo, hi = 0, 9
if not is_num:
lo = 1
res += f(i + 1, mask, False, False)
if is_limit:
hi = int(s[i])
for j in range(lo, hi + 1):
# j 没有在 mask 的集合中出现过
if (mask >> j) & 1 == 0:
res += f(i + 1, mask | (1 << j), is_limit and j == hi, True)
return res
return f(0, 0, True, False)
简化版本:
@lru_cache(None)
def f(i, mask, is_limit, is_num):
if i == len(s): return int(is_num)
res = 0 if is_num else f(i + 1, mask, False, False)
lo, hi = 0 if is_num else 1, int(s[i]) if is_limit else 9
for j in range(lo, hi + 1):
if (mask >> j) & 1 == 0:
res += f(i + 1, mask | (1 << j), is_limit and j == hi, True)
return res
时间复杂度:记 \(D = 10\),由于每个状态只会被计算一次,每个状态的复杂度是 \(O(D)\);每一个 \((i, mask)\) 能够唯一确定 \((i,~mask,~is\_limit,~is\_num)\) 四元组(因此在记忆化的时候只需要 \((i, mask)\) 维度),所以状态个数为 \(m\cdot2^D\),其中 \(m\) 表示 \(n\) 的二进制长度。所以复杂度为: \(O(D\cdot m \cdot 2^D)\)
实际上某些问题中, \(is\_num\) 可以 被简化掉,因为 \(not(mask ==0)\) 和 \(is\_num\) 是 等价的。
统计 \(1 \sim n\) 中各个数位都不相同的数字的个数。限制条件:\(mask\) 前导中出现过的数字是不可以填的。
def countSpecialNumbers(self, n: int) -> int:
s = str(n)
@lru_cache(None)
def f(i, mask, is_limit, is_num):
if i == len(s): return int(is_num)
res = 0 if is_num else f(i + 1, mask, False, False)
lo, hi = 0 if is_num else 1, int(s[i]) if is_limit else 9
for j in range(lo, hi + 1):
if (mask >> j) & 1 == 0:
res += f(i + 1, mask | (1 << j), is_limit and j == hi, True)
return res
return f(0, 0, True, False)
统计区间:\(1 \sim N\) 中的所有数字,每个数位都被旋转。
限制条件:旋转后不等于自身,且合法的数字。只需要在数字中包含至少 一个 \([2, 5, 6, 9]\) 且不包含 \([3, 4, 7]\)。
def rotatedDigits(self, n: int) -> int:
s = str(n)
m = len(s)
# 合法情况:包含至少一个 [2, 5, 6, 9] 且 不包含 [3, 4, 7]
nums = [0, 0, 1, -1, -1, 1, 1, -1, 0, 1]
@lru_cache(None)
def f(i, has_mir, is_limit, is_num):
if i == m: return int(has_mir and is_num)
res = 0 if is_num else f(i + 1, has_mir, False, False)
lo, hi = 0 if is_num else 1, int(s[i]) if is_limit else 9
for j in range(lo, hi + 1):
if nums[j] != -1:
res += f(i + 1, has_mir or nums[j] == 1, is_limit and j == hi, True)
return res
return f(0, False, True, False)
902. 最大为 N 的数字组合 - 力扣(LeetCode)
def atMostNGivenDigitSet(self, digits: List[str], n: int) -> int:
s = str(n)
ss = set([int(ch) for ch in digits])
m = len(s)
@lru_cache(None)
def f(i, is_limit, is_num):
if i == m: return int(is_num)
res = 0 if is_num else f(i + 1, False, False)
lo, hi = 0 if is_num else 1, int(s[i]) if is_limit else 9
for j in range(lo, hi + 1):
if j in ss:
res += f(i + 1, is_limit and j == hi, True)
return res
return f(0, True, False)
2827. 范围中美丽整数的数目 - 力扣(LeetCode)
运用模运算的性质:整个数字 模 \(k\) 的结果,比如 \(1234 \bmod 17\) ,可以看成 \((1000 \bmod 17) + (200 \bmod 17)+(30 \bmod 17 )+ (4 \bmod 17)\),所以最后模数 的结果只需要等价成不断 \(mod\_res \times 10 + j\) 即可。
def numberOfBeautifulIntegers(self, low: int, high: int, k: int) -> int:
def cal(x):
s = str(x)
m = len(s)
@lru_cache(None)
def f(i, mod_res, odd_even_delta, is_limit, is_num):
if i == m: return int(odd_even_delta == 0 and mod_res == 0 and is_num)
res = 0 if is_num else f(i + 1, mod_res, odd_even_delta, False, False)
lo, hi = 0 if is_num else 1, int(s[i]) if is_limit else 9
for j in range(lo, hi + 1):
res += f(i + 1, (mod_res * 10 + j) % k, odd_even_delta + (1 if j & 1 else -1), is_limit and j == hi, True)
return res
return f(0, 0, 0, True, False)
return cal(high) - cal(low - 1)
600. 不含连续 1 的非负整数 - 力扣(LeetCode)
二进制数位 dp。上界改为 1
def findIntegers(self, n: int) -> int:
def get_bin(x):
res = i = 0
while x:
res = res + pow(10, i) * (x % 2)
i += 1
x >>= 1
return res
n = get_bin(n)
s = str(n)
m = len(s)
@lru_cache(None)
def f(i, pre, is_limit, is_num):
if i == m: return int(is_num)
res = 0 if is_num else f(i + 1, None, False, False)
lo, hi = 0 if is_num else 1, int(s[i]) if is_limit else 1
for j in range(lo, hi + 1):
if pre == None or (j == 1 and pre != 1) or j == 0:
res += f(i + 1, j, is_limit and j == hi, True)
return res
return f(0, None, True, False) + 1
状态机 DP¶
0 表示当前异或偶数个 k,1 表示当前异或奇数个 k
\(0 \rightarrow 0 或者 1 \rightarrow 1\):\(加上x\)
\(0 \rightarrow 1 或者 1 \rightarrow 0\): $ 加上x \oplus k$
def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int:
n = len(nums)
dp = [[0] * 2 for _ in range(n + 1)]
dp[n][1] = -inf
for i, x in enumerate(nums):
dp[i][0] = max(dp[i - 1][0] + x, dp[i - 1][1] + (x ^ k))
dp[i][1] = max(dp[i - 1][1] + x, dp[i - 1][0] + (x ^ k))
return dp[n - 1][0]
状压 DP¶
竖方块摆放确定时,横方块摆放一定确定(合法或者恰好填充),所以只需要看竖方块的摆放情况。对于 $N\times M $ 的方格,\(f(i, j)\) 表示第 \(i\) 行形态为 \(j\) 时,前 \(i\) 行合法切割方案数。 \(j\) 是用十进制记录 \(M\) 位二进制数,其每位数字为 \(1\) 表示放竖方块上半部分,为 \(0\) 表示其他情况。(竖方块下半部分 / 横方块)
\(f(i,j)\) 能由 \(f(i-1, k)\) 状态转移的充要条件:1. \(j ~\&~ k =0\) ,保证同列上下两行不会同时放竖方块的上半部分。2. \(j ~|~ k\) 的所有连续的 0 的个数必须是偶数。\(j ~|~ k\) 为 0 当且仅当上下两行都是横方块,所以必须是偶数个。
初始状态对于 \(f(0)\) ,不能对第一行产生影响,所以只有 \(f(0,~0) = 1\) ,其余为 0。最终目标:\(f(N, 0)\) ,状态转移方程:\(f(i,j) = \sum valid(f(i-1,k))\)
对于所有 \(M\) 位二进制数,预处理其是否满足所有连续 0 的个数是否是偶数。
N = M = 11
f = [[0] * (1 << M + 1) for _ in range(N + 1)]
def solve(n, m):
# f [n][1 << m]
# 预处理,判断 i 是否含有连续的奇数个 0
s = set()
for i in range(1 << m):
c = 0
for j in range(m):
if i >> j & 1:
if c & 1: break
else: c += 1
if c & 1: s.add(i)
f[0][0] = 1
for i in range(1, n + 1):
for j in range(1 << m):
f[i][j] = 0
for k in range(1 << m):
if (j & k == 0 and (j | k not in s)):
f[i][j] += f[i - 1][k]
return f[n][0]
最短哈密顿回路 / 旅行商问题
哈密顿回路:无向带权图中经过所有顶点的回路。朴素做法对于 \(N\) 个顶点,时间复杂度为 \(O(n!)\) ,是 \(NP-hard\) 问题(无法在多项式时间复杂度内求解)。
实际上,设已经访问过的点集 \(S\),当前节点 \(j\) ,设 \(f(S,j)\) 表示路径已经访问过点集 \(S\) 中的点且当前访问的 \(j\) 时 的最短路径。有状态转移:\(f(S,j) = \min \{f(S-j, k) + w(k,~j) ~, \forall~k \in S-j \}\)。可以用二进制来压缩已经访问的点集 \(S\)。最终问题 \(f(2^N -1, N - 1 )\),初始值 \(f(0, 0) = 0\)。
def solve():
n = int(input())
f = [[inf] * n for _ in range(1 << n)]
w = []
for _ in range(n):
w.append(list(map(int, input().split())))
f[1][0] = 0
for S in range(1, 1 << n):
for j in range(n):
if (S >> j) & 1: # j 在 S 中,
for k in range(n):
if ((S ^ (1 << j)) >> k) & 1: # 且 k 在 S - j 中
f[S][j] = min(f[S ^ (1 << j)][k] + w[k][j], f[S][j])
return f[(1 << n) - 1][n - 1]
全排列型状压¶
朴素 - 全排列状压
1879. 两个数组最小的异或值之和 - 力扣(LeetCode)
\(O(n^2 \times 2^n )\) 做法:\(f(i, s)\) 表示考虑完成 \(nums[0: i]\) ,状态为 \(s\) 之下,最小异或值之和。
class Solution:
def minimumXORSum(self, nums1: List[int], nums2: List[int]) -> int:
n = len(nums2)
f = [[inf] * (1 << n) for _ in range(n + 1)]
for i in range(n + 1): f[i][0] = 0
for i in range(1, n + 1):
x = nums1[i - 1]
for s in range(1, 1 << n):
for j in range(n):
if (s >> j) & 1 == 0: continue
f[i][s] = min(f[i][s], f[i - 1][s ^ (1 << j)] + (x ^ nums2[j]))
return f[n][(1 << n) - 1]
优化:省略前一维度,这是因为 \(i\) 的信息隐含在 \(s\) 所含 1 的个数之中。时间复杂度 \(O(n\times 2^n)\)
class Solution:
def minimumXORSum(self, nums1: List[int], nums2: List[int]) -> int:
n = len(nums2)
f = [inf] * (1 << n)
f[0] = 0
for s in range(1, 1 << n):
x = nums1[s.bit_count() - 1]
for j in range(n):
if (s >> j) & 1 == 0: continue
f[s] = min(f[s], f[s ^ (1 << j)] + (x ^ nums2[j]))
return f[(1 << n) - 1]
约束型 - 全排列状压
对于带有约束的全排列问题,\(f[i][s]\) 表示考虑完全排列 \(p[0: i]\) ,已经选择集合状态为 \(s\) 情况下的合法方案数。一般转移方程:\(f[i][s]=\sum f[i-1][s-\{j\}],~ \forall ~valid(j)\),初始化 \(f[0][0]=1\)。时间复杂度为 \(O(n^2 \times 2^n)\)。
优化思路:由于 \(s\) 中包含了 \(i\) 的信息,即 \(bin(s).count('1')\) ,所以第一维度可以省略。时间复杂度 \(O(n\times 2^n)\)
时间复杂度为 \(O(n^2 \times 2^n)\)。
class Solution:
def countArrangement(self, n: int) -> int:
res = 0
# f [i][s] 考虑完 perm [1] ~ perm [i],已选择状态为 s
m = (1 << n) - 1
f = [[0] * (m + 1) for _ in range(n + 1)]
f[0][0] = 1
for i in range(1, n + 1):
for s in range(m + 1):
for j in range(n):
if (s >> j) & 1 and ((j + 1) % i == 0 or i % (j + 1) == 0):
f[i][s] += f[i - 1][s ^ (1 << j)]
return f[n][m]
优化:省略前一维度。时间复杂度 \(O(n\times 2^n)\)
class Solution:
def countArrangement(self, n: int) -> int:
res = 0
# f [i][s] 考虑完 perm [1] ~ perm [i],已选择状态为 s
m = (1 << n) - 1
f = [1] + [0] * m
for s in range(m + 1):
i = bin(s).count('1')
for j in range(n):
if (s >> j) & 1 and ((j + 1) % i == 0 or i % (j + 1) == 0):
f[s] += f[s ^ (1 << j)]
return f[m]
\(f(s,i)\) 表示当前选择的状态为 \(s\) ,最后一个位置选择的元素为 \(nums[i]\) 。对所有在 \(s\) 中的 \(i\) ,考虑其所有可能的前一个位置的值 \(nums[j]\)。\(f(s,i)=\sum f(s\oplus j,~j),~\forall \text{valid}(j)\)。复杂度:\(O(n^2\cdot 2^n)\)。
moder = 10 ** 9 + 7
class Solution:
def specialPerm(self, nums: List[int]) -> int:
n = len(nums)
f = [[0] * n for _ in range(1 << n)]
f[0][0] = 1
for s in range(1 << n):
for i in range(n):
if (s >> i) & 1 == 0: continue
if (s ^ (1 << i)) == 0:
f[s][i] = 1
continue
for j in range(n):
if i == j or (s >> j) & 1 == 0: continue
x, y = nums[i], nums[j]
if x % y == 0 or y % x == 0:
f[s][i] = (f[s][i] + f[s ^ (1 << i)][j]) % moder
res = 0
for i in range(n):
res = (res + f[s][i]) % moder
return res
划分成 \(k\) 个子集的问题¶
\(f(i,s)\) 表示划分到第 \(i\) 个子集,划分的状态为 \(s\) 情况下的某个值。\(f(i,s)=F((f(i-1,s-sub),~G(sub) ))\)。
时间复杂度: \(O(n\cdot 3^n)\)。由于元素个数为 \(i\) 的集合个数有 \(C(n,i)\) 个,其子集个数为 \(2^i\) ,根据二项式定理 \((a + b) ^n = \sum_{i=0}^n C_n^ia^ib^{n-i}\),所以 \(\sum_{i=0}^{n} C(n,i)\cdot 2^i = (2+1)^n=3^n\),每次需要 \(O(n)\) 时间计算 \(G\) 的情况下,时间复杂度为 \(O(n\cdot 3^n)\)
最小化 \(k\) 个子集和中的最大值问题。\(f(i,s)\) 表示划分到第 \(i\) 个子集,划分的状态为 \(s\) 情况下的 \(i\) 个子集中和最大值的最小值。
考虑 \(s\) 的所有子集 \(sub\) ,\(f(i,s)=\min \{ \max(f(i-1,s-sub),~\sum sub)\}\)。最终答案为 \(f(k,1\) << \(n-1)\),初始值 \(f(0,0)=0\)。
def distributeCookies(self, cookies: List[int], k: int) -> int:
# f [i][s] 表示当前划分状态为 s, s 为 1 表示已经分配
# 划分完第 i 个集合,所有集合的最大值 的最小值
# f [i][s] = min(max(f [i - 1][s ^ sub], sum(sub)))
# f [k][1 << n - 1]
n = len(cookies)
f = [[inf] * (1 << n) for _ in range(k + 1)]
f[0][0] = 0
for i in range(1, k + 1):
for s in range(1, 1 << n):
sub = s
while sub:
tot = sum(cookies[j] for j in range(n) if (sub >> j) & 1)
f[i][s] = min(f[i][s], max(f[i - 1][s ^ sub], tot))
sub = (sub - 1) & s
return f[k][(1 << n) - 1]
1723. 完成所有工作的最短时间 - 力扣(LeetCode)
此题是上一题的数据增强版,优化方法:预处理所有子集的和 + 一维滚动状压 dp。复杂度:\(O(3^n+ n\cdot 2^n)\)
def minimumTimeRequired(self, nums: List[int], k: int) -> int:
n = len(nums)
f = [inf] * (1 << n)
f[0] = 0
sum_ = defaultdict(int)
for i, x in enumerate(nums):
for s in range(1 << i):
sum_[(1 << i) | s] = sum_[s] + x
for _ in range(1, k + 1):
for s in range((1 << n) - 1, 0, -1):
sub = s
while sub:
tot = sum_[sub]
if f[s ^ sub] > tot: tot = f[s ^ sub]
if f[s] > tot: f[s] = tot
sub = (sub - 1) & s
return f[(1 << n) - 1]
划分集合每个和不超过 \(k\) 的最少划分数
1986. 完成任务的最少工作时间段 - 力扣(LeetCode)
\(f(s)\) 表示到达这个状态需要的最少划分段数。枚举每个子集 \(sub\) ,\(f(s)=\min \{f(s-sub)+1\},~\forall \sum sub\le k\)。时间复杂度:\(O(3^n+n\cdot 2^n)\)
def minSessions(self, nums: List[int], k: int) -> int:
n = len(nums)
f = [inf] * (1 << n)
f[0] = 0
# 预处理所有子集的和
sum_ = defaultdict(int)
for i, x in enumerate(nums):
for sub in range(1 << i):
sum_[sub | (1 << i)] = sum_[sub] + x
for s in range(1, 1 << n):
sub = s
while sub:
if sum_[sub] <= k:
f[s] = min(f[s], f[s ^ sub] + 1)
sub = (sub - 1) & s
return f[(1 << n) - 1]
集合是否能划分成 k 个相等子集
698. 划分为 k 个相等的子集 - 力扣(LeetCode)
\(f[s]\) 为在压缩状态 \(s\) 下的余数。考察每一个在集合中的元素 \(nums[j]\),对于删去其的集合 \(ls=f[s \oplus(1\) << \(j)]\),当且仅当 \(ls+nums[j] \le siz\) 的时候可以更新 \(f[s]\),相当于枚举所有删去一个元素的子集向 \(f[s]\) 转移,能否构造出整数倍的集合。
这种方法会有一定的重复,不妨反过来,对于 \(f[s]\) ,考察其没有出现的每一个元素 \(nums[j]\),更新 \(f[s | nums[j]]\)。这样可以大大减少重复。
时间复杂度:不超过 \(O(n \cdot 2^n)\)
class Solution:
def canPartitionKSubsets(self, nums: List[int], k: int) -> bool:
siz = sum(nums) // k
if sum(nums) % k != 0 or any(x > siz for x in nums): return False
# f [s] 表示 在选择状态为 s 的情况下,余数是多少
n = len(nums)
m = (1 << n) - 1
f = [0] + [-inf] * m
for s in range(m):
if f[s] == -inf: continue
for j in range(n):
if (s >> j) & 1 == 0:
nx = s | (1 << j)
if f[nx] == 0: continue
if f[s] + nums[j] <= siz:
f[nx] = (f[s] + nums[j]) % siz
return f[m] == 0
多重状压:记忆化搜索¶
当某些字符、数字可以使用 若干次时,传统的状压不方便表示使用状况。因此可以转用 \(dfs\) 的记忆化搜索方式,配合 \(Counter\) 计数器实现状压的代替品。
其中剪枝部分,要求 \(s[0]\) 一定在 \(word\) 中出现。这是因为如果当前 \(word\) 中存在能消除 \(s[0]\) 的,那么最终解一定至少包含这些部分。否则,没有能消除 \(s[0]\) 的,表示当前 \(dfs(s)\) 的结果不合法,返回 \(inf\)。这样实际上约束了每次一定转移到最优的方案。
例如对于 \(s=the\), 首先寻找所有包含 \(t\) 的方案并向其转移。时间复杂度接近 \(O(n\cdot m)\) ,\(n\) 为原始字符串大小,\(m\) 为可以考虑的字符串数量。
def minStickers(self, words: List[str], target: str) -> int:
words = [Counter(word) for word in words]
# dfs(s) 表示得到 s 的最少数量
@lru_cache(None)
def dfs(s):
if s == '': return 0
cs = Counter(s)
res = inf
for word in words:
# 如果 word 压根无法消除 s [0] 可以直接跳过
# 因为再怎么使用也无法完全消除 s
# 应该首先考虑将 s [0] 能消除的方案
if s[0] not in word: continue
ns = s
for k, v in word.items():
ns = ns.replace(k, '', v)
res = min(res, dfs(ns) + 1)
return res
res = dfs(target)
return res if res < inf else -1
\(k\) 进制状压¶
\(k\) 种颜色染色 \(n \times m\) 网格(不允许出现空着的格子)问题
1931. 用三种不同颜色为网格涂色 - 力扣(LeetCode)
\(k\) 进制预处理 + 合法状态预处理 + 枚举状压
每一行使用长度为 \(m\) 的 \(k\) 进制的串来表示。通过预处理的方式记录在 \(color\) 中,键为 \(k\) 进制的串对应的十进制数,值为对应的 \(k\) 进制串的列表。相邻两行的约束,通过枚举来预处理。转移方程:\(f(i,s)= \sum f(i-1, ~e[s])\)
时间复杂度:\(O(k^{2m}\times n)\)
moder = 10 ** 9 + 7
class Solution:
def colorTheGrid(self, m: int, n: int) -> int:
# 三进制表示每一行的颜色
colors = {}
for b in range(3 ** m):
color = []
x = b
while x:
color.append(x % 3)
x //= 3
color.extend([0] * (m - len(color)))
if any(color[i] == color[i + 1] for i in range(len(color) - 1)):
continue
colors[b] = color[::-1]
e = defaultdict(list)
# 预处理每一种状态可以邻接的状态
for i, u in colors.items():
for j, v in colors.items():
flag = True
for b in range(m):
if u[b] == v[b]:
flag = False
break
if flag: e[i].append(j)
# f [i][s] 表示 i 行为 s 的方案数
f = [[0] * (3 ** m) for _ in range(n)]
for b in colors.keys():
f[0][b] = 1
for i in range(1, n):
for s in colors.keys():
for ps in e[s]:
f[i][s] = (f[i - 1][ps] + f[i][s]) % moder
return (sum(f[n - 1])) % moder
划分 DP¶
约束划分个数¶
将数组分成 (恰好/至多) \(k\) 个连续子数组,计算与这些子数组有关的最优值。
类型 1: \(f[i][j]\) 当前考虑完前缀 \(a[:i]\),且 \(a[:i]\) 恰好划分成 \(j\) 个连续子数组所得到的最优解。 枚举最后一个子数组的左端点 \(L\), 从 \(f[L][j-1]\) 转移到 \(f[i][j]\),并考虑 \(a[L:i]\) 对最优解的影响。 \(f(i,j)=\min(f(L,j-1))\)
类型 2:\(f(i,j, pre)\) 表示当前考虑到 \(a[i]\), 且 \(a[:i]\) 的前缀中包含 \(j\) 个连续子数组所得的最优解,其中 \(pre\) 表示当前待划分的这段的状态。 考虑是否在 \(i\) 处划分,并考虑前一段状态 \(pre\) 是否允许划分。\(f(i,j,pre)=\min \{~f(i+1,j,pre),~f(i+1,j+1,pre')~\}\)
3117. 划分数组得到最小的值之和 - 力扣(LeetCode)
\(f(i, j, pre\_and):\) 表示当前考虑到 \(nums[i]\),且前缀中包含 \(j\) 段,\(pre\_and\) 表示当前待划分的这段的 AND。
def minimumValueSum(self, nums: List[int], andValues: List[int]) -> int:
n, m = len(nums), len(andValues)
@lru_cache(None)
def f(i, j, pre_and):
# 表示当前考虑到 nums [i],且前缀中包含 j 段,pre_and 表示当前待划分的这段的 AND
if i == n and j == m: return 0
if i < n and j == m: return inf
if i == n and j < m: return inf
pre_and &= nums[i]
# 在 i 处不划分,
res = f(i + 1, j, pre_and)
# 在 i 处划分,条件是这一段 pre_and == andValues [j]
if pre_and == andValues[j]:
res = min(res, f(i + 1, j + 1, -1) + nums[i])
return res
res = f(0, 0, -1)
return res if res < inf else -1
时间复杂度:\(O(mn\log U)\),由于 $ pre_and$ 表示当前待划分这段的按位与。记 \(\log U\) 表示最大数对应的二进制位数。对于一个确定的 \(i\),向前 AND 每次不变或者减少比特 1 的个数。所以不同的 \(pre\_and\) 数不超过 \(logU\)。总共有 \(mn \log U\) 个状态,每个状态是 \(O(1)\) 。
不相交区间¶
如果报酬等于区间长度(时间差),是一个贪心 + 排序问题:按照结束时间排序,结束早的优先考虑。本题的报酬与区间长度无关,因此需要 \(dp\) 。
原始问题是考虑 $ 0 \sim \max (endTime)$ 时间段内,能够获得的最大报酬。考虑子问题 \(0 \sim endTime[i]\) 时间段内的最大报酬。
\(f[x]\) 表示 $ 0 \sim endTime[x] $ 时间段内的最多报酬,一种转移是 \(f[x-1]\),表示当前区间 \(x\) 不考虑;另一种转移,考虑结束时间 $ \le startTime[x]$ 的最后一个位置 \(idx\) ,有 \(endTime[idx]\le startTime[x]\)(可以发现 \(f[x]\) 关于 \(x\) 递增,最后一个位置一定是最大值),不妨同样按照结束时间从小到大排序,因此可以使用二分查找。
实现时,在数组前面增加一个 \((0,0,0)\) 用于统一查找。
时间复杂度:\(O(n \log n)\)
def jobScheduling(self, startTime: List[int], endTime: List[int], profit: List[int]) -> int:
# f [x] 表示 0 ~ endTime [x] 时间段内的最多报酬
n = len(startTime)
nums = [(0, 0, 0)] + sorted([(s, e, p) for s, e, p in zip(startTime, endTime, profit)], key = lambda x: x[1])
f = [0] * (n + 1)
def bisect_left(lo, hi, k):
while lo < hi:
mid = (lo + hi) >> 1
if nums[mid][1] > k:
hi = mid
else:
lo = mid + 1
return lo - 1
for i in range(1, n + 1):
idx = bisect_left(0, i, nums[i][0])
f[i] = max(f[i - 1], f[idx] + nums[i][2])
return f[n]
计算几何¶
旋转与向量¶
将点 \((x, ~y)\) 顺时针旋转 \(\alpha\) 后,新的点坐标为 \((x \cos \alpha+y\sin\alpha,~~~ y \cos \alpha~ - x\sin\alpha )\)
证明:
点 \(P(x, y)\) 表示为 半径为 \(r\),极角为 \(\theta\) 的坐标系下。
顺时针旋转 \(\alpha\) 后:
距离¶
\(A(x_1, ~y_1),~ B(x_2, ~y_2)\)
曼哈顿距离 $ = |x_1 - x_2| + |y_1 - y_2|$
切比雪夫距离 $ = \max(|x_1 - x_2| ,~ |y_1 - y_2|)$
曼哈顿距离转切比雪夫:
即将所有点顺时针旋转 45° 后再乘 \(\sqrt{2}\)。将 \(P(x,y)\) 映射到 \(P'(x+y,x-y)\) 坐标系下,\(d_{M} = d'_Q\)。对于三维点 \(P(x, y, z)\) 映射到 \(P''(x+y+z, -x+y+z, x-y+z, x + y -z)\) 坐标系下,\(d_M = d''_Q\)。当需要求若干点之间的最大 \(d_M\) 时,可以转换为:
from sortedcontainers import SortedList
class Solution:
def minimumDistance(self, points: List[List[int]]) -> int:
msx, msy = SortedList(), SortedList()
for x, y in points:
msx.add(x + y)
msy.add(x - y)
res = inf
for x, y in points:
msx.remove(x + y)
msy.remove(x - y)
xmx = msx[-1] - msx[0]
ymx = msy[-1] - msy[0]
res = min(res, max(xmx, ymx))
msx.add(x + y)
msy.add(x - y)
return res
切比雪夫转曼哈顿距离:
将 \(P(x, y)\) 映射到 \(P'(\frac{x + y}{2}, \frac{x-y}{2})\) 坐标系下,\(d_Q = d'_M\)。切比雪夫距离在计算的时候需要取 max,往往不是很好优化,对于一个点,计算其他点到该的距离的复杂度为 O(n)。而曼哈顿距离只有求和以及取绝对值两种运算,我们把坐标排序后可以去掉绝对值的影响,进而用前缀和优化,可以把复杂度降为 O(1)。
[P3964 TJOI2013] 松鼠聚会 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
转换成切比雪夫距离。将 x, y 分离,前缀和维护到各个 xi 和 yi 的距离和,再相加:
def solve():
n = int(input())
points = []
res = inf
for _ in range(n):
x, y = map(int, input().split())
points.append(((x + y) / 2, (x - y) / 2))
numsx = [p[0] for p in points]
numsy = [p[1] for p in points]
def g(nums):
nums.sort()
curx = nums[0]
curd = sum(nums[i] - curx for i in range(1, n))
dic = {nums[0]: curd}
for i in range(1, n):
x = nums[i]
d = x - curx
curd = curd + i * d - (n - i) * d
dic[x] = curd
curx = x
return dic
dicx, dicy = g(numsx), g(numsy)
for x, y in points:
ans = dicx[x] + dicy[y]
res = min(res, ans)
print(int(res))
字符串¶
KMP¶
暴力匹配所有起始位置
时间复杂度:\(O(mn)\)
for i in range(len_s - len_p + 1):
ii, j = i, 0
while j < len_p:
if s[ii] == p[j]: ii, j = ii + 1, j + 1
else: break
if j == len_p: res.append(i)
前缀函数 / next 数组
时间复杂度:\(O(n)\),在线算法
对于一个长度为 \(n\) 的字符串,其前缀函数是一个长度为 \(n\) 的数组 \(\pi\),其中 \(\pi(i)\) 定义:子串 \(s[0] \sim s[i]\) 中存在的、相等的最长真前缀和真后缀的长度。如果不存在则为 0。规定:\(\pi[0]=0\),因为其不存在真前后缀。
例如:'aabaaab' 的 \(\pi\) 数组为 [0, 1, 0, 1, 2, 2, 3]
求解前缀函数:
- 相邻的前缀函数值,至多 + 1。\(\pi(i-1)\) 表示着前一个状态匹配的最长真前后缀,也是下一个待匹配真前缀的最右元素下标。当且仅当 \(s[i]=s[\pi(i-1)]\),有 \(\pi(i) = \pi(i-1)+1\) 。
- 考虑 \(s[i] \ne s[\pi(i-1)]\),失配时,希望找到 \(s[0] \sim s[i-1]\) 中,仅次于 \(\pi[i-1]\) 的第二长度 \(j\),使得在位置 \(i-1\) 的前后缀性质仍然保持,即 \(s[0] \sim s[j-1] = s[i-j] \sim s[i-1]\)。
实际上,第二长真后缀也完整存在于 当前真前缀 \(s[0] \sim s[j-1]\) 中,即有转移方程:\(j^{(n-1)}=\pi(j^n-1)\)。所以如此往复,要么直到 $s[i]=s[j'] $ ,然后转移到第一种情况;要么直到 \(j'=0\) 。两种情况,通过判断 $s[i] $ 是否 \(s[j']\) 来确定要不要让 \(j'+1\) 统一,最后 \(s[i]=s[j']\)。
def get_pi(s):
n = len(s)
pi = [0] * n
for i in range(1, n):
j = pi[i - 1]
while j > 0 and s[i] != s[j]:
j = pi[j - 1]
if s[i] == s[j]: j += 1
pi[i] = j
return pi
KMP 算法:找出 \(p\) 在 \(s\) 中的所有出现
时间复杂度:\(O(n+m)\),其中 \(m=len(p),~n=len(s)\)
构造字符串 \(t=p\#s\),计算其前缀函数 \(\pi\)。考虑前缀函数 \(\pi[m+1]\sim \pi[n+m]\),其中 $\pi(i)=m $ 的地方,一定完成对模式串 \(p\) 的匹配。此时,\(i\) 位于 $t $ 中 \(s\) 的最后位置,所以原始位置为 \(i-m+1-m-1 = i - 2 * m\)。
def kmp(p, s):
res = []
m, n = len(p), len(s)
pi = get_pi(p + '#' + s)
for i in range(m + 1, len(pi)):
if pi[i] == m: res.append(i - 2 * m)
return res
字符串排序¶
sorted(str) #返回按照字典序排序后的列表,如 "eda"-> ['a','d','e']
s_sorted=''.join(sorted(str)) #把字符串列表组合成一个完整的字符串
Z 函数 (扩展 KMP)¶
对于字符串 s,函数 \(z[i]\) 表示 \(s\) 和 \(s[i:]\) 的最长公共前缀 \((LCP)\) 的长度。特别的,定义 \(z[0] = 0\)。即:
例如, \(z(abacaba) = [0, 0, 1, 0, 3, 0, 1]\)
可视化:Z Algorithm (JavaScript Demo) (utdallas.edu)
# s = 'aabcaabxaaaz'
n = len(s)
z = [0] * n
l = r = 0
for i in range(1, n):
if i <= r: # 在 Z-box 范围内
z[i] = min(z[i - l], r - i + 1)
while i + z[i] < n and s[z[i]] == s[i + z[i]]:
l, r = i, i + z[i]
z[i] += 1
# print(z) # [0, 1, 0, 0, 3, 1, 0, 0, 2, 2, 1, 0]
判断子序列¶
判断 p 在删除 ss 中下标元素后,是否仍然满足 s 是 p 的子序列。
ss = set(removable[:x])
i = j = 0
n, m = len(s), len(p)
while i < n and j < m:
if i not in ss and s[i] == p[j]:
j += 1
i += 1
return j == m
字符串哈希¶
2430. 对字母串可执行的最大删除数 - 力扣(LeetCode)
Trie 树 / 字典树¶
26 叉字典树:
class Trie:
def __init__(self):
self.is_end = False
self.next = [None] * 26
def insert(self, word: str) -> None:
node = self
for ch in word:
idx = ord(ch) - ord('a')
if not node.next[idx]:
node.next[idx] = Trie()
node = node.next[idx]
node.is_end = True
def search(self, word: str) -> bool:
node = self
for ch in word:
idx = ord(ch) - ord('a')
if not node.next[idx]:
return False
node = node.next[idx]
return node.is_end
def startsWith(self, prefix: str) -> bool:
node = self
for ch in prefix:
idx = ord(ch) - ord('a')
if not node.next[idx]:
return False
node = node.next[idx]
return True
哈希字典树:
def countPrefixSuffixPairs(self, words: List[str]) -> int:
class Node:
__slots__ = 'children', 'cnt'
def __init__(self):
self.children = {} # 用字典的字典树
self.cnt = 0
res = 0
root = Node() # 树根
for word in words:
cur = root
for p in zip(word, word[::-1]): # (p [i], p [n - i - 1])
if p not in cur.children:
cur.children[p] = Node()
cur = cur.children[p]
res += cur.cnt
cur.cnt += 1
return res
class Trie:
def __init__(self):
self.end = False
self.next = {}
def insert(self, word: str) -> None:
p = self
for ch in word:
if ch not in p.next:
p.next[ch] = Trie()
p = p.next[ch]
p.end = True
def search(self, word: str) -> bool:
p = self
for ch in word:
if ch not in p.next:
return False
p = p.next[ch]
return p.end
def startsWith(self, prefix: str) -> bool:
p = self
for ch in prefix:
if ch not in p.next:
return False
p = p.next[ch]
return True
# Your Trie object will be instantiated and called as such:
# obj = Trie()
# obj.insert(word)
# param_2 = obj.search(word)
# param_3 = obj.startsWith(prefix)
图论¶
建图¶
给定 \(n\) 个节点,\(m\) 条边。我们希望用 \(g(u,v)\) 表示 \(u,v\) 的信息。(如连通性、边权)
邻接矩阵
带权无向图
- 初始值设置为 \(inf\) 表示不连通,即距离无穷大
- \(g(x,x)\) 应为 \(0\)
from math import inf
n, m = map(int, input().split())
g = [[inf] * n for _ in range(n)]
for _ in range(m):
u, v, w = map(int, input().split())
g[u][v] = g[v][u] = w
g[u][u] = g[v][v] = 0 # 原地不动,距离是 0
样例输出
[0, 2, inf, 3, inf]
[2, 0, 5, 4, inf]
[inf, 5, 0, 10, inf]
[3, 4, 10, 0, 7]
[inf, inf, inf, 7, 0]
邻接表
1)带权无向图
n, m = map(int, input().split())
e = [[] for _ in range(n)]
for _ in range(m):
u, v, w = map(int, input().split())
e[u].append((v, w))
e[v].append((u, w))
[(1, 2), (3, 3)]
[(0, 2), (2, 5), (3, 4)]
[(1, 5), (3, 10)]
[(0, 3), (1, 4), (4, 7), (2, 10)]
[(3, 7)]
2)带权有向图
n, m = map(int, input().split())
e = [[] for _ in range(n)]
for _ in range(m):
u, v, w = map(int, input().split())
e[u].append((v, w))
3)不带权有向图
图遍历¶
DFS 序(邻接表版)
n, m = map(int, input().split())
e = [[] for _ in range(n)]
for _ in range(m):
u, v, w = map(int, input().split())
e[u].append((v, w))
e[v].append((u, w))
s = set() # 已经访问的
def dfs(u):
# 遍历当前节点的操作,如输出节点信息等
print(u, end = " ")
s.add(u)
# 遍历邻居
for v, _ in e[u]:
if v not in s:
dfs(v)
dfs(0) # 0 1 2 3 4
print()
s.clear() #已经访问的
dfs(4) # 4 3 0 1 2
去重边建图
100244. 带权图里旅途的最小代价 - 力扣(LeetCode)
这道题需要在建图的时候取 AND 运算的最小值。
e = [defaultdict(lambda: -1) for _ in range(n)]
for u, v, w in edges:
e[v][u] = e[u][v] = e[u][v] & w
Floyd¶
求解带权图上 多源最短路。
给定 \(n\) 个节点,\(m\) 条边的带权无向图,和 \(q\) 组询问,每次需要回答 \(u, v\) 的最短路径长度。
- 考虑 \(u,v\) 路径上的中间节点 \(k\) ,拆分成 \(u \rightarrow k\) 的最短路径长度 + \(k \rightarrow v\) 的最短路径长度
- \(k\) 代表当前允许使用的中间节点,\(k\) 遍历 \([0, n - 1]\) ,对于每个 \(k\) 对整个邻接矩阵进行更新
- \(\text{ g[u][v] = min(g[u][v], g[u][k] + g[k][v])}\)
- 为什么 \(k\) 不能放在其他位置?放在最外层,对于每个 \(k\) 从左上到右下进行一次更新,能够有效利用之前的值。放在最内层,每处 \((u,v)\) 局部更新,不能有效利用之前的值。
时间复杂度:\(O(n^3)\)
for _ in range(m):
u, v, w = map(int, input().split())
g[u][v] = g[v][u] = w
for k in range(n):
for u in range(n):
for v in range(n):
g[u][v] = min(g[u][v], g[u][k] + g[k][v])
from math import inf
n, m = map(int, input().split())
g = [[inf] * n for _ in range(n)]
for _ in range(m):
u, v, w = map(int, input().split())
g[u][v] = g[v][u] = w
for k in range(n):
for u in range(n):
for v in range(n):
g[u][v] = min(g[u][v], g[u][k] + g[k][v])
print(g)
print(g[4][2]) # 16
1334. 阈值距离内邻居最少的城市 - 力扣(LeetCode)
class Solution:
def findTheCity(self, n: int, edges: List[List[int]], distanceThreshold: int) -> int:
res, idx = inf, 0
g = [[inf] * n for _ in range(n)]
for u, v, w in edges:
g[u][v] = g[v][u] = w
for k in range(n):
for u in range(n):
for v in range(n):
g[u][v] = min(g[u][v], g[u][k] + g[k][v])
for u in range(n):
cnt = sum(g[u][v] <= distanceThreshold and u != v for v in range(n))
if cnt <= res:
res, idx = cnt, u
return idx
Dijkstra¶
求解 非负权图 上单源最短路径。
朴素 Dijkstra¶
适用于稠密图,时间复杂度:\(O(n^2)\)
d = [inf] * n
d[0] = 0
s = set() # S 集合为已经确定的节点集合
for _ in range(n - 1):
x = -1
# 从 U - S 中找出距离 S 最近的节点
for u in range(n):
if u not in s and (x < 0 or d[u] < d[x]):
x = u
s.add(x)
# 松弛,对每个节点判断以 x 作为中间节点时,是否距离原点更加
for u in range(n):
d[u] = min(d[u], d[x] + g[u][x])
from math import inf
n, m = map(int, input().split())
g = [[inf] * n for _ in range(n)]
for _ in range(m):
u, v, w = map(int, input().split())
g[u][v] = g[v][u] = w
g[u][u] = g[v][v] = 0
d = [inf] * n
d[0] = 0
s = set() # S 集合为已经确定的节点集合
for _ in range(n - 1):
x = -1
# 从 U - S 中找出距离 S 最近的节点
for u in range(n):
if u not in s and (x < 0 or d[u] < d[x]):
x = u
s.add(x)
# 松弛,对每个节点判断以 x 作为中间节点时,是否距离原点更加
for u in range(n):
d[u] = min(d[u], d[x] + g[u][x])
print(d) # [10, 11, 16, 7, 0]
class Solution:
def networkDelayTime(self, e: List[List[int]], n: int, k: int) -> int:
g = [[inf] * (n + 1) for _ in range(n + 1)]
m = len(e)
for i in range(m):
u, v, w = e[i]
g[u][v] = w
d = [inf] * (n + 1)
d[k] = 0
s = set() # S 集合为已经确定的节点集合
for _ in range(n - 1):
x = -1
# 从 U - S 中找出距离 S 最近的节点
for u in range(1, n + 1):
if u not in s and (x < 0 or d[u] < d[x]):
x = u
s.add(x)
# 松弛,对每个节点判断以 x 作为中间节点时,是否距离原点更加
for u in range(1, n + 1):
d[u] = min(d[u], d[x] + g[x][u])
res = max(d[1: ])
return res if res < inf else -1
1976. 到达目的地的方案数 - 力扣(LeetCode)
最短路 Dijkstra + 最短路 Dp:求源点 0 到任意节点 i 的最短路个数。
def countPaths(self, n: int, roads: List[List[int]]) -> int:
g = [[inf] * n for _ in range(n)]
moder = 10 ** 9 + 7
for u, v, w in roads:
g[u][v] = g[v][u] = w
g[u][u] = g[v][v] = 0
d = [inf] * n # dist 数组, d [i] 表示源点到 i 的最短路径长度
d[0] = 0
v = [False] * n # 节点访问标记
mn, res = inf, 0
f = [0] * n # f [i] 表示源点到 i 节点的最短路个数
f[0] = 1
for _ in range(n - 1):
x = -1
for u in range(n):
if not v[u] and (x < 0 or d[u] < d[x]):
x = u
v[x] = True
for u in range(n):
a = d[x] + g[x][u]
if a < d[u]: # 到 u 的最短路个数 = 经过 x 到 u 的个数 = 到 x 的最短路的个数
d[u], f[u] = a, f[x]
elif a == d[u] and u != x: # 路径一样短,追加
f[u] = (f[u] + f[x]) % moder
return f[n - 1]
2662. 前往目标的最小代价 - 力扣(LeetCode)
将 有向图路径 转换为 节点。不需要建图,但是需要首先对 d 数组进行预处理。
def minimumCost(self, start: List[int], target: List[int], specialRoads: List[List[int]]) -> int:
# 把路径(a, b) -> (c, d) 简化成 (c, d)
t, s = tuple(target), tuple(start)
d, v = defaultdict(lambda: inf), set()
d[s] = 0
def g(p, q):
return abs(p[0] - q[0]) + abs(p[1] - q[1])
# 补充 start 和 target 节点
specialRoads.append([s[0], s[1], t[0], t[1], g(s, t)])
specialRoads.append([s[0], s[1], s[0], s[1], 0])
while True:
x = None
# 找到距离 start 最近的 且 未计算过的节点
for x1, y1, x2, y2, w in specialRoads:
u = (x2, y2)
if u not in v and (not x or d[u] < d[x]):
x = u
v.add(x)
if x == t:
return d[t]
for x1, y1, x2, y2, w in specialRoads:
u0, u = (x1, y1), (x2, y2)
# 两种情况,1. start 经过 x 到达 u
# 2. start 经过 x 再到 u0 从路径到达 u
d1 = d[x] + g(x, u)
d2 = d[x] + g(x, u0) + w
d[u] = min(d[u], d1, d2)
堆优化 Dijkstra¶
适用于稀疏图(\(点个数的平方 远大于 边的个数\)),复杂度为 \(O(mlogm)\),\(m表示边的个数\)。
使用小根堆,存放未确定最短路点集对应的 (d [i], i)。对于同一个 i 可能存放多组不同 d [i] 的元组,因此堆中元素的个数最多是 \(m\) 个。
寻找最小值的过程可以用一个最小堆来快速完成。
e = [[] for _ in range(n)]
for u, v, w in roads:
e[u].append((v, w))
e[v].append((u, w))
d = [inf] * n
d[0] = 0
hq = [(0, 0)] # 小根堆,存放未确定最短路点集对应的 (d [i], i)
while hq:
dx, x = heapq.heappop(hq)
if dx > d[x]: continue # 跳过重复出堆,首次出堆一定是最短路
for u, w in e[x]:
a = d[x] + w
if a < d[u]:
d[u] = a # 同一个节点 u 的最短路 d [u] 在出堆前会被反复更新
heapq.heappush(hq, (a, u))
1976. 到达目的地的方案数 - 力扣(LeetCode)
def countPaths(self, n: int, roads: List[List[int]]) -> int:
e = [[] for _ in range(n)]
for u, v, w in roads:
e[u].append((v, w))
e[v].append((u, w))
moder = 10 ** 9 + 7
f = [0] * n
d = [inf] * n
f[0], d[0] = 1, 0
hq = [(0, 0)] # 小根堆,存放未确定最短路点集对应的 (d [i], i)
while hq:
dx, x = heapq.heappop(hq)
if dx > d[x]: continue # 之前出堆过
for u, w in e[x]:
a = d[x] + w
if a < d[u]:
d[u] = a
f[u] = f[x]
heapq.heappush(hq, (a, u))
elif a == d[u]:
f[u] = (f[u] + f[x]) % moder
return f[n - 1]
有向图 + 邻接矩阵最短路
def networkDelayTime(self, times: List[List[int]], n: int, k: int) -> int:
e = [[] * (n + 1) for _ in range(n + 1)]
for u, v, w in times:
e[u].append((v, w))
d = [inf] * (n + 1)
d[k] = 0
hq = [(0, k)]
while hq:
dx, x = heapq.heappop(hq)
if dx > d[x]: continue
for u, w in e[x]:
a = d[x] + w
if a < d[u]:
d[u] = a
heapq.heappush(hq, (a, u))
res = max(d[1: ])
return res if res < inf else -1
2045. 到达目的地的第二短时间 - 力扣(LeetCode)
使用双列表 d,存放最短和次短。将等红绿灯转换为松弛条件,通过 t 来判断红灯还是绿灯。
def secondMinimum(self, n: int, edges: List[List[int]],
time: int, change: int) -> int:
# 将 节点 (u, t) 即 (节点,时间) 作为新的节点
e = [[] for _ in range(n + 1)]
for u, v in edges:
e[u].append(v)
e[v].append(u)
hq = [(0, 1)]
# (t // change) & 1 == 0 绿色
# (x, t) -> (u, t + time)
# (t // change) & 1 == 1 红色
# 需要 change - t % change 时间进入下一个节点
d, dd = [inf] * (n + 1), [inf] * (n + 1)
d[1] = 0
while hq:
t, x = heapq.heappop(hq)
if d[x] < t and dd[x] < t: # 确认最小的和次小的
continue
for u in e[x]:
nt = inf
if (t // change) & 1 == 0:
nt = t + time
else:
nt = t + change - t % change + time
if nt < d[u]:
d[u] = nt
heapq.heappush(hq, (nt, u))
elif dd[u] > nt > d[u] :
dd[u] = nt
heapq.heappush(hq, (nt, u))
return dd[n]
堆优化 Dijkstra(字典写法)¶
转换建图 + 堆 Dijkstra (字典写法 )
def electricCarPlan(self, paths: List[List[int]], cnt: int, start: int, end: int, charge: List[int]) -> int:
# 将(节点, 电量) 即 (u, c) 看成新的节点
# 将充电等效转换成图
# 则将节点 i 充电消耗时间 charge [u] 看成从(u, c) 到 (u, c + 1) 有 w = 1
n = len(charge)
e = [[] for _ in range(n)]
for u, v, w in paths:
e[u].append((v, w))
e[v].append((u, w))
hq = [(0, start, 0)]
d = {}
while hq:
dx, x, c = heapq.heappop(hq)
if (x, c) in d: # 已经加入到寻找到最短路的集合中
continue
d[(x, c)] = dx
for u, w in e[x]:
if c >= w and (u, c - w) not in d:
heapq.heappush(hq, (w + dx, u, c - w))
if c < cnt:
heapq.heappush(hq, (charge[x] + dx, x, c + 1))
return d[(end, 0)]
def networkDelayTime(self, times: List[List[int]], n: int, k: int) -> int:
e = [[] * (n + 1) for _ in range(n + 1)]
for u, v, w in times:
e[u].append((v, w))
d = {}
hq = [(0, k)]
while hq:
dx, x = heapq.heappop(hq)
if x in d: continue # 跳过非首次出堆
d[x] = dx # 首次出堆一定是最短路
for u, w in e[x]:
a = d[x] + w
if u not in d: # 未确定最短路
heapq.heappush(hq, (a, u)) # 入堆,同一个节点可能用多组
for i in range(1, n + 1):
if i != k and i not in d:
return -1
return max(d.values())
2045. 到达目的地的第二短时间 - 力扣(LeetCode)
求解严格次短路问题:两个 d 字典,一个存放最短,一个存放严格次短
def secondMinimum(self, n: int, edges: List[List[int]], time: int, change: int) -> int:
# 将 节点 (u, t) 即 (节点,时间) 作为新的节点
# (t // change) & 1 == 0 绿色
# (x, t) -> (u, t + time)
# (t // change) & 1 == 1 红色
# 需要 change - t % change 时间进入下一个节点
# (x, t) -> (u, t + change - t % change + time)
e = [[] for _ in range(n + 1)]
for u, v in edges:
e[u].append(v)
e[v].append(u)
hq = [(0, 1)]
d, dd = {}, {} # dd 是确认次短的字典
while hq:
t, x = heapq.heappop(hq)
if x not in d:
d[x] = t
elif t > d[x] and x not in dd:
dd[x] = t
else:
continue
for u in e[x]:
if (t // change) & 1 == 0:
if u not in dd:
heapq.heappush(hq, (t + time, u))
else:
if u not in dd:
heapq.heappush(hq, (t + change - t % change + time, u))
return dd[n]
转换建图问题:可折返图 转换成 到达时间的奇偶问题
2577. 在网格图中访问一个格子的最少时间 - 力扣(LeetCode)
class Solution:
def minimumTime(self, grid: List[List[int]]) -> int:
# (w, x0, x1) 表示到达(x0, x1) 时刻至少为 w
if grid[0][1] > 1 and grid[1][0] > 1: return -1
m, n = len(grid), len(grid[0])
deltas = [(1, 0), (-1, 0), (0, 1), (0, -1)]
target = (m - 1, n - 1)
d = {}
hq = [(0, (0, 0))]
while hq:
dx, x = heappop(hq)
if x in d: continue
d[x] = dx
if x == target: return d[target]
x0, x1 = x[0], x[1]
for u0, u1 in [(x0 + dx, x1 + dy) for dx, dy in deltas]:
if not (0 <= u0 < m and 0 <= u1 < n) or (u0, u1) in d: continue
u, t = (u0, u1), grid[u0][u1]
if dx + 1 >= t:
heappush(hq, (dx + 1, u))
else:
# 例如 3 -> 6,折返一次变成 5 后 + 1 到达 6
du = (t - dx - 1) if (t - dx) & 1 else t - dx
heappush(hq, (dx + du + 1, u))
最短路与子序列「和/积」问题¶
求解一个数组的所有子序列的和 / 积中第 k 小 (大同理) 问题,其中子序列是原数组删去一些元素后剩余元素不改变相对位置的数组。
以和为例,可以转化为最短路问题:
将子序列看成节点 \((s, idx)\), \(s\) 表示序列的和,\(idx\) 表示下一个位置,则 \(idx - 1\) 表示序列最后一个元素的位置。
例如 \([1, 2, 4, 4, 5, 9]\) 的其中一个子序列 \([1,2]\),对应节点 \((3, 2)\)。如果从 \(idx-1\) 位置选或不选来看,可以转换为子序列 \([1, 2, 4]\) 和 \([1, 4]\),则定义节点之间的边权是序列和之差,由于有序数组,边权一定非负。
可以将原问题看成从 \([\space ]\) 为 源节点的,带正权的图。只需要不断求解到源节点的最短路节点,就可以得到所有子序列从小到大的和的值。
假设有 \(n\) 个节点,堆中元素个数不会超过 \(k\),时间复杂度是 \(O(klogk)\)。
注意,如果采用二分答案方式求解,即想求出恰好有 \(k\) 个元素小于等于对应子序列之和 \(s\) 的算法,时间复杂度为 \(O(klogU), U = \sum{a_i}\)
2386. 找出数组的第 K 大和 - 力扣(LeetCode)
def kSum(self, nums: List[int], k: int) -> int:
res = sum(x for x in nums if x > 0)
nums = sorted([abs(x) for x in nums])
# (s, idx) (子序列和, 当前下标)
hq = [(0, 0)]
while k > 1:
# 每一次会将最小的子序列的和 pop 出去
# pop k - 1 次,堆顶就是答案
s, idx = heappop(hq)
# 选 idx - 1
if idx < len(nums):
heappush(hq, (s + nums[idx], idx + 1))
# 不选 idx - 1
if idx:
heappush(hq, (s + (nums[idx] - nums[idx - 1]), idx + 1))
k -= 1
return res - hq[0][0]
动态修改边权¶
- 在邻接表数组中记录原矩阵中边的位置,方便修改
- 记 \(d_{signal, i}\) 表示第 \(signal\) 次得到的节点 \(i\) 到源点的最短路。跑两次 dijkstra 算法
第二次修改边权时,对于特殊边尝试修改条件:
解得:
当这个值大于 1 时,是一个合法的边权,进行修改。
def modifiedGraphEdges(self, n: int, edges: List[List[int]], source: int, destination: int, target: int) -> List[List[int]]:
e = [[] for _ in range(n)]
# 存放边的位置,方便在原矩阵直接修改
for pos, (u, v, w) in enumerate(edges):
e[u].append([v, pos])
e[v].append([u, pos])
total_d = [[inf] * n for _ in range(2)]
total_d[0][source] = total_d[1][source] = 0
def dijkstra(signal):
d = total_d[signal] # 第 signal 次的最短路数组
v = set()
for _ in range(n - 1):
x = -1
for u in range(n):
if u not in v and (x < 0 or d[u] < d[x]):
x = u
v.add(x)
for u, pos in e[x]:
w = edges[pos][2]
w = 1 if w == -1 else w
# d [x] + nw + total_d [0][destination] - total_d [0][u] = target
if signal == 1 and edges[pos][2] == -1:
nw = target - total_d[0][destination] + total_d[0][u] - d[x]
if nw > 1: # 合法修改
w = edges[pos][2] = nw
d[u] = min(d[u], d[x] + w)
return d[destination]
if dijkstra(0) > target: return [] # 全为 1 也会超过 target
if dijkstra(1) < target: return [] # 最短路无法变大
for e in edges:
if e[2] == -1:
e[2] = 1
return edges
最小生成树¶
P3366 【模板】最小生成树 - 洛谷 (luogu.com.cn)
Prim¶
from math import *
def solve():
n, m = map(int, input().split())
d = [inf] * n
g = [[] for _ in range(n)]
for _ in range(m):
u, v, w = map(int, input().split())
u, v = u - 1, v - 1
g[u].append((v, w))
g[v].append((u, w))
d[0] = 0
res = 0
s = set()
for _ in range(n):
dx, x = inf, -1
for i in range(n):
if i not in s and (x < 0 or d[i] < dx):
dx, x = d[i], i
s.add(x)
res += dx
for i, w in g[x]:
if i not in s:
d[i] = min(d[i], w)
if inf not in d:
print(res)
return
print('orz')
solve()
Kruskal¶
import sys
input = lambda: sys.stdin.readline().strip()
def solve():
n, m = map(int, input().split())
edges = []
for _ in range(m):
u, v, w = map(int, input().split())
edges.append((w, u, v))
# 按边的权重排序
edges.sort()
# 并查集初始化
fa = list(range(n + 1))
def find(x):
if fa[x] == x:
return x
fa[x] = find(fa[x])
return fa[x]
def union(u, v):
if find(u) != find(v):
fa[find(v)] = find(u)
return True
return False
res = 0 # 最小生成树的权重和
cnt = 0 # 已选择的边数
for w, u, v in edges:
if union(u, v): # 如果边的两个端点不在同一集合中
res += w
cnt += 1
if cnt == n - 1: # 已经选择了 n-1 条边,最小生成树完成
break
if cnt == n - 1:
print(res)
else:
print('orz') # 无法形成最小生成树
solve()
二分图¶
简单来说,如果图中点可以被分为两组,并且使得所有边都跨越组的边界,则这就是一个二分图。
二分图的最大匹配、完美匹配和匈牙利算法_完美匹配图论-CSDN 博客
定义:无向图 \(G(U,V,E)\) 中节点可以划分成互斥集合 \(U\), \(V\),使得 \(\forall (u, v) \in E\) 的两个端点分属于两个集合。
-
两个互斥点集中的任意两点之间都不存在边
-
任何一条边的两个端点分别来互斥的两个点集 \(U, V\)
-
不存在奇数点的环(不存在奇数条边的环)
证明:因为走过一条边必然从一个集合走到另一个集合,要完成闭环必须走偶数条边(偶数个点)
- 可能存在孤点
二分图判定¶
DFS 染色:
def isBipartite(self, graph: List[List[int]]) -> bool:
n = len(graph)
color = [0] * n
flag = True
def dfs(u, c):
nonlocal flag
color[u] = c
for v in graph[u]:
if color[v] == 0:
dfs(v, -c)
elif color[v] == c:
flag = False
return
for i in range(n):
if color[i] == 0: dfs(i, 1)
if not flag: return False
return True
Bfs 染色:
def isBipartite(self, graph: List[List[int]]) -> bool:
n = len(graph)
q = collections.deque()
color = [0] * n
for i in range(n):
if not color[i]:
q.append(i)
color[i] = 1
while q:
u = q.popleft()
c = color[u]
for v in graph[u]:
if not color[v]:
color[v] = -c
q.append(v)
elif color[v] == c:
return False
return True
并查集做法:
维护两个并查集 \(U, V\) , 分别存储两个互斥点集。
对于每个节点 \(u\) 遍历其所有邻接节点 $v $ 。如果遇到 \(u\) , \(v\) 在同一个并查集,说明不满足二分图。(同一点集中出现连接的边)
否则将所有邻接节点加到另一个并查集中。
def isBipartite(self, graph: List[List[int]]) -> bool:
n = len(graph)
s = set()
pa = list(range(n))
def find(x):
if pa[x] != x:
pa[x] = find(pa[x])
return pa[x]
def union(u, v):
if find(u) != find(v):
pa[find(v)] = find(u)
for u in range(n):
if u not in s:
s.add(u)
p = None
for v in graph[u]:
if find(u) == find(v):
return False
if p: union(p, v)
p = v
return True
二分图最大匹配 / 匈牙利算法¶
二分图的匹配
给定一个二分图 G,在 G 的一个子图 M 中, M 的边集 \(\{E\}\) 中的 任意两条边都没有公共顶点,则称 M 是一 个匹配。
最大匹配:匹配边数最大的匹配。
完美匹配:如果一个图的某个匹配中,所有的顶点都是匹配点,那么它就是一个完美匹配。图 4 是一个完美匹配。
-
完美匹配一定是最大匹配(完美匹配的任何一个点都已经匹配,添加一条新的匹配边一定会冲突)
-
但并非每个图都存在完美匹配。
- 完美匹配的边数 = 左 / 右部的点数
二分图最大权完美匹配:二分图边权和最大的完美匹配。
交替路:从一个未匹配点出发(右),依次经过非匹配边、匹配边、非匹配边…形成的路径叫交替路。
增广路:从一个未匹配点出发(右),走交替路,如果途径另一个未匹配点(出发的点不算),则这条交替路称为增广路(agumenting path)。例如,图 5 中的一条增广路如图 6 所示(图中的匹配点均用红色标出)
- 特点:非匹配边比匹配边多一条。因此,研究增广路的意义是改进匹配。只要把增广路中的匹配边和非匹配边的身份交换即可。由于中间的匹配节点不存在其他相连的匹配边,所以这样做不会破坏匹配的性质。交换后,图中的匹配边数目比原来多了 1 条。
增广路定理
通过不停地找增广路来增加匹配中的匹配边和匹配点。找不到增广路时,达到最大匹配
n1, n2, m = map(int, input().split())
vis = set()
match = {}
e = defaultdict(list)
def dfs(u) -> bool:
for v in e[u]:
if v in vis: continue
vis.add(v)
if v not in match or dfs(match[v]):
match[v] = u
return True
return False
for _ in range(m):
u, v = map(int, input().split())
e[u].append(v)
for u in range(1, n1 + 1):
vis = set()
dfs(u)
print(len(match))
二分图最大权完美匹配 / KM 算法¶
暴力枚举时间复杂度:\(O(m!)\),实际上是二分图的最大权完美匹配问题。复杂度:\(O(m^3)\)
class KM:
def __init__(self, n):
self.maxn = 300 + 10
self.INF = float('inf')
self.wx = [0] * (self.maxn)
self.wy = [0] * (self.maxn)
self.cx = [-1] * (self.maxn)
self.cy = [-1] * (self.maxn)
self.visx = [0] * (self.maxn)
self.visy = [0] * (self.maxn)
self.cntx = n
self.cnty = n
self.Map = [[0] * (self.maxn) for _ in range(self.maxn)]
self.slack = [0] * (self.maxn)
def dfs(self, u):
self.visx[u] = 1
for v in range(1, self.cnty + 1):
if not self.visy[v] and self.Map[u][v] != self.INF:
t = self.wx[u] + self.wy[v] - self.Map[u][v]
if t == 0:
self.visy[v] = 1
if self.cy[v] == -1 or self.dfs(self.cy[v]):
self.cx[u] = v
self.cy[v] = u
return True
elif t > 0:
self.slack[v] = min(self.slack[v], t)
return False
def KM(self):
for i in range(1, self.cntx + 1):
for j in range(1, self.cnty + 1):
if self.Map[i][j] == self.INF:
continue
self.wx[i] = max(self.wx[i], self.Map[i][j])
for i in range(1, self.cntx + 1):
self.slack = [self.INF] * (self.maxn)
while True:
self.visx = [0] * (self.maxn)
self.visy = [0] * (self.maxn)
if self.dfs(i):
break
minz = self.INF
for j in range(1, self.cnty + 1):
if not self.visy[j] and minz > self.slack[j]:
minz = self.slack[j]
for j in range(1, self.cntx + 1):
if self.visx[j]:
self.wx[j] -= minz
for j in range(1, self.cnty + 1):
if self.visy[j]:
self.wy[j] += minz
else:
self.slack[j] -= minz
ans = 0
for i in range(1, self.cntx + 1):
if self.cx[i] != -1:
ans += self.Map[i][self.cx[i]]
return ans
def add_edge(self, u, v, w):
self.Map[u + 1][v + 1] = w
连通块问题¶
2867. 统计树中的合法路径数目 - 力扣(LeetCode)
DFS + 字典维护节点所在连通块大小
cc_siz 用来记录连通块的大小。vis 数组对质数节点进行记录,dfs 的起始节点一定是质数节点的非质数子节点。
使用 cc_node 记录一次连通分量 dfs 得到的节点列表,更新对应 cc_siz 的值。这样后续在遍历到已经遍历过的非质数连通块时,可以直接得到结果。
def countPaths(self, n: int, edges: List[List[int]]) -> int:
primes = []
N = n + 10
is_prime = [True] * N
is_prime[0] = is_prime[1] = False
for i in range(2, N):
if is_prime[i]:
primes.append(i)
for p in primes:
if i * p >= N:
break
is_prime[i * p] = False
if i % p == 0:
break
e = [[] for _ in range(n + 1)]
for u, v in edges:
e[u].append(v)
e[v].append(u)
vis = [False] * (n + 1)
cc_siz = {}
cc_node = []
def dfs(u, fa):
siz = 1
cc_node.append(u)
for v in e[u]:
if v != fa and not is_prime[v]:
siz += dfs(v, u)
return siz
res = 0
for u in range(1, n + 1):
if not vis[u] and is_prime[u]:
vis[u] = True
cur_siz = 0
for v in e[u]:
if is_prime[v]:
continue
# 对于每一个子连通分量
if v in cc_siz:
siz = cc_siz[v]
else:
cc_node.clear()
siz = dfs(v, u)
for node in cc_node:
cc_siz[node] = siz
res += siz + siz * cur_siz
cur_siz += siz
return res
DFS + 字典维护连通块的 AND 值 和 节点对应的连通块下标
100244. 带权图里旅途的最小代价 - 力扣(LeetCode)
通过字典中连通块下标,判断两个节点是否在同一连通块内。
def minimumCost(self, n: int, edges: List[List[int]], query: List[List[int]]) -> List[int]:
cc_and = {} # 键为节点,值为 (cc_cnt, and_ans),即对应的连通块编号 和 连通块的 and 值
cc_cnt = 0 # 计数,记录当前统计到第几个连通块
cc_node = []
e = [{} for _ in range(n)]
for u, v, w in edges:
if v not in e[u]:
e[v][u] = e[u][v] = w
else:
e[v][u] = e[u][v] = e[u][v] & w
vis = [False] * n
def dfs(u):
vis[u] = True
cc_node.append(u)
and_ans = -1
for v in e[u]:
w = e[u][v]
and_ans &= w
if not vis[v]:
and_ans &= dfs(v)
return and_ans
for u in range(n):
if not vis[u]:
and_ans = dfs(u)
for node in cc_node:
cc_and[node] = (cc_cnt, and_ans)
cc_node.clear()
cc_cnt += 1
return [0 if u == v else (cc_and[u][1] if cc_and[u][0] == cc_and[v][0] else -1)
for u, v in query]
并查集维护连通块属性
928. 尽量减少恶意软件的传播 II - 力扣(LeetCode)
题目问从 \(bad\) 选一个节点,删除其和其邻接的边,能得到感染后的最小数量。暴力做法枚举删除的 \(bad\) 点,需要 \(O(n^3)\) 。
逆向思维:枚举所有的 \(good\) 连通块,维护块大小 \(siz\) ,及其邻接的 \(bad\) 节点集合 \(cc\_bad\)。恰好只有一个邻接 \(bad\) 的连通块,由于删除 \(bad\) 后整个连通块不会被感染,所以其对该 \(bad\) 节点的贡献为连通块大小。最后,返回所有 \(bad\) 节点得到的贡献和(来自恰好仅邻接该节点的 \(good\) 连通块)最大、序最小的节点。
def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int:
n = len(graph[0])
fa = list(range(n))
siz = [1] * n
cc_bad = defaultdict(set)
def find(x):
if fa[x] != x: fa[x] = find(fa[x])
return fa[x]
def union(u, v):
if find(u) != find(v):
siz[find(u)] += siz[find(v)]
cc_bad[find(u)] |= cc_bad[find(v)]
fa[find(v)] = find(u)
bad = set(initial)
good = set(range(n)) - bad
for u in good:
for v, con in enumerate(graph[u]):
if not con: continue
if v in bad: cc_bad[find(u)].add(v)
else: union(u, v)
pa = set(find(u) for u in good)
bad_siz = Counter()
for p in pa:
if len(cc_bad[p]) == 1:
bad_siz[list(cc_bad[p])[0]] += siz[p]
mx, res = 0, min(bad)
for u, sz in bad_siz.items():
if sz > mx: mx, res = sz, u
if sz == mx: res = min(res, u)
return res
最小费用最大流¶
TODO
数学¶
数论¶
取整函数¶
上下取整转换
证明:由于有 \(\left\lceil \frac{n}{m} \right\rceil = \left\lfloor \frac{n}{m} \right\rfloor\) 当且仅当 \(n=k \cdot m\) 时成立,当 \(n=k\cdot m-r,r\in(0,m)\),一定有 \(k=\left\lceil \frac{k\cdot m}{m} \right\rceil =\left\lceil \frac{k\cdot m-r}{m} \right\rceil = \left\lceil \frac{k\cdot m-r+1}{m} \right\rceil = 1+ \left\lfloor \frac{k\cdot m-r}{m} \right\rfloor\),显然 \(k\cdot m-r+1 \in R\),令其等于 \(n\) 时得证 \(\left\lceil \frac{n}{m} \right\rceil =1+ \left\lfloor \frac{n-1}{m} \right\rfloor\)。
灵神恒等式*
1553. 吃掉 N 个橘子的最少天数 - 力扣(LeetCode)
实际上这个结论可以推广到任意个数,比如:
不等式
取余性质
幂等律
素数¶
素数计数函数近似值
小于等于 \(x\) 的素数个数记为 \(\pi(x)\),\(\pi (x) 近似于 \frac{x}{\ln x}\)。
(1). 埃氏筛¶
时间复杂度:\(O(nloglogn)\)
primes = []
is_prime = [True] * (n + 1) # MX 为最大可能遇到的质数 + 1
is_prime[1] = is_prime[0] = False
for i in range(2, int(math.sqrt(n)) + 1): # i * i <= n
if is_prime[i]:
for j in range(i * i, n + 1, i):
is_prime[j] = False
for i in range(2, n + 1):
if is_prime[i]: primes.append(i)
时间复杂度证明
对于 2,要在数组中筛大约 \(\frac{n}{2}\) 个数,同理对于素数 \(p\),约要筛去 \(\frac{p}{n}\) 个数。
切片优化
primes = []
is_prime = [True] * (n + 1)
is_prime[0] = is_prime[1] = False
for i in range(2, int(math.sqrt(n)) + 1):
if is_prime[i]:
is_prime[i * i::i] = [False] * ((n - i * i) // i + 1)
for i in range(2, n + 1):
if is_prime[i]: primes.append(i)
(2). 欧拉筛 / 线性筛¶
基本思想:每一个合数一定存在最小的质因子。确保每一个合数只被他的最小质因子筛去。
primes = []
is_prime = [True] * (n + 1)
is_prime[0] = is_prime[1] = False
for i in range(2, n + 1):
if is_prime[i]: primes.append(i)
for p in primes:
if i * p > n: break
is_prime[i * p] = False
if i % p == 0: break
正确性证明:
- 每个合数不会被筛超过一次:
枚举 \(i\) 从小到大的所有质数,在 \(i \% p = 0\) 出现之前,\(p\) 一定小于 \(i\) 的所有质因子,\(p \cdot i\) 的质因子的前缀与 \(i\) 的质因子前缀相同,故 \(p\) 一定是 \(i \cdot p\) 的最小质因子,筛去;在出现 \(i \% p =0\) 时,\(p\) 恰好是 \(i\) 的最小质因子,同理,然后 break。保证每个合数只会被最小的质因子筛去。
- 每个合数都会被筛最少一次:
每个合数 \(x\) 一定存在最小质因子 \(p\),和对应的 $ x / p$。在 \(i\) 枚举到 \(x / p\) 的时候,一定会筛去 \(x\)
由于保证每个合数一定被晒一次,所以是 \(O(n)\)
(3). 分解质因子¶
所谓质因子分解是将一个正整数 \(n\) 写成一个或多个质数的乘积形式。
试除法。复杂度不超过 \(O(\sqrt n )\),实际上是 \(O(logn) \sim O(\sqrt {n})\)
对于一个数 x,最多有一个大于等于 \(\sqrt n\) 的质因子。(可以用反证法,证明)
所以只需要进行特判,在遍历完 \([2, int(\sqrt n)]\) 区间后,如果 x 比 1 大,则 x 就等于那最后一个质因子。
from math import *
def solve(x):
for i in range(2, int(sqrt(x)) + 1): # i = 2; i * i <= x
if x % i == 0:
s = 0
while x % i == 0:
s += 1
x //= i
print(f'{i} {s}') # i 是质因子, s 表示幂次
if x > 1:
print(f'{x} 1')
print()
solve(2 ** 3 * 3 ** 4 * 5 ** 2 * 7 * 14)
Oi Wiki 风格:统计质因子及其出现次数
from math import *
# 统计质因子及其出现次数
def breakdown(x):
res = []
for i in range(2, int(sqrt(x)) + 1):
if x % i == 0:
cnt = 0
while x % i == 0:
x //= i
cnt += 1
res.append((i, cnt))
if x > 1: res.append((x, 1))
return res
print(breakdown(2 ** 3 * 3 ** 4 * 5 ** 2 * 7 * 14)) # [(2, 4), (3, 4), (5, 2), (7, 2)]
from math import *
# 统计质因子及其出现次数
def breakdown(x):
res = set()
for i in range(2, int(sqrt(x)) + 1):
if x % i == 0:
cnt = 0
while x % i == 0:
x //= i
cnt += 1
res.add(i)
if x > 1: res.add(x)
return res
print(len(breakdown(int(input()))))
(4). 乘分解质因子¶
求 \(n!\) 中 质因子 \(p\) 的出现次数问题
例如 \(9!\) 中 \(2\) 的出现次数,\(9!\) 中有 \(2,4,6,8\) 分别提供 \(1,2,1,3\) 个 \(2\) ,出现次数为 \(7\) 。
\(n!\) 中 \(p\) 出现次数,\([1,n]\) 中是 \(p\) 的倍数的数各自提供 1 个 \(p\),有 \(n // p\) 个数;这些数都除 \(p\) 表示各拿走一个 \(p\),之后最大的数是 \(n//p\),更新上界 \(n' \leftarrow n//p\) ,在 \([1,n'\)] 中 \(p\) 的倍数的数各自提供 1 个 \(p\) ,有 $n' // p $ 个数,重复上面操作,直到 \(n=0\) 。
# 统计 n! 中质因子 p 出现的次数
def fpf(n, p): # factorial_prime_factor
res = 0
while n:
res += n // p
n //= p
return res
求 \(n!\) 转换为 \(b\) 进制数字后的末尾 0 的个数。
例如:\((9!)_{10}=(720)_{10}=(880)_{9} = 8 \times 9^2 + 8 \times 9^1+0\times 1=3^4\times 2^3+2^3 \times 3^2=3^2\times 2^4\times 5\);
\(5!=(120)_{10}=1111000_{2}=2^3\times 3^1\times5^1\)。
\(x\) 在 \(b\) 进制下的表示是 \(x = \sum d_i b^{i}\),最低非 0 位 \(k\) 对应 \(d_k\times b^k\), 一定能将 \(x\) 表示成 \(b^k \times (\sum d_i\times b^{i-k})\),实际上由于 \(x\) 可以表示成标准分解式 \(x=p_0^{k_0} \times p_1^{k_1}\cdots\),将 \(b\) 分解成 \(p_0^{k'_0} \times \cdots\),将 \(b\) 代入 \(n!\) 的标准分解式,得:
即, \(p_0^{k_0} \times p_1^{k_1}\cdots = (p_0 ^{k^{'}_{0} \cdot k } \times p_1 ^{k^{'}_{1} \cdot k} \times...) \times (\sum d_i \times b^{i-k})\)
由于 \(k_i = k^{'}_{i} \cdot k\),由于 \(k\) 是最低位,应取 \(k\) 等于 \(\min \frac{k_i}{k_i'}\)。转换为:\(k_i\) 为 \(b\) 的质因子在 \(n!\) 中出现的次数。
import sys
input = lambda: sys.stdin.readline().strip()
from math import *
# 统计 n! 中质因子 p 出现的次数
def fpf(n, p): # factorial_prime_factor
res = 0
while n:
res += n // p
n //= p
return res
# 统计质因子及其出现次数
def breakdown(n):
res = []
for i in range(2, int(sqrt(n)) + 1):
if n % i == 0:
cnt = 0
while n % i == 0:
n //= i
cnt += 1
res.append((i, cnt))
if n > 1: res.append((n, 1))
return res
def solve():
n, b = map(int, input().split())
pf = breakdown(b) # 对 b 进行质因子分解
res = inf
for f, c in pf:
res = min(res, fpf(n, f) // c)
return res
print(solve())
阶乘合并
-
\(A_i!\) 的累加结果,\(A_1! + A_2! + ... + A_n!\),对排序后的 \(A_1 < A_2 < ... < A_n\),由公因式提取可知,则 \(A_1!\) 为最大公因数
-
使用计数器 \(m\) 存储 x! 出现次数。 考虑阶乘的合并情况:考察 \(x!\) 的个数 \(c = m[x]\),若有 \(c = k \times (x + 1)\),则能合并为 $k $ 个 $ (x + 1)!$ 。
-
即合并条件为 \(c % (x + 1) == 0\),合并操作为 \(m[x + 1] \leftarrow m[x + 1] + c // (x + 1)\)
-
从 \(A_1\) 开始合并,当无法合并时,相当于将 \(A_1! + A_2! + ... + A_n!\) 转换成 \(B_1! + B_2! + ... + B_m!\),返回 \(B_1\) 即可
# Ai! 的累加结果,A1! + A2! + ... + An!,
# 对排序后的 A1 < A2 < ... < An,由公因式提取可知,则 A1! 为最大公因数
# 使用计数器 m 存储 x! 出现次数。
# 考虑阶乘的合并情况:考察 x! 的个数 c = m [x],若有 c = k * (x + 1),则能合并为 k 个 (x + 1)!。
# 即合并条件为 c % (x + 1) == 0,合并操作为 m [x + 1] ← m [x + 1] + c // (x + 1)
# 从 A1 开始合并,当无法合并时,相当于将 A1! + A2! + ... + An! 转换成 B1! + B2! + ... + Bm!,返回 B1 即可
from collections import Counter
n = int(input())
a = list(map(int, input().split()))
a.sort()
m = Counter()
for x in a:
m[x] += 1
x = a[0]
while True:
c = m[x] # x! 的个数
if c % (x + 1) == 0: # x! 的个数是 x+1 的倍数
m[x + 1] += c // (x + 1) # c = k * (x + 1)
x += 1
else:
print(x)
break
约数¶
试除法求所有约数¶
复杂度为:\(O(\sqrt{n})\)
from math import *
def solve(x):
res = []
for i in range(2, int(sqrt(x)) + 1):
if x % i == 0:
res.append(i)
if i != x // i:
res.append(x // i)
return res
print(solve(24)) # [2, 12, 3, 8, 4, 6]
from math import *
def solve(x):
res = []
for i in range(2, int(sqrt(x)) + 1):
if x % i == 0:
res.append(i)
if i != x // i:
res.append(x // i)
return len(res) + 2 # 1 和 自身
print(solve(1200000)) # 96
乘积数的约数个数
对于一个以标准分解式给出的数 \(N = \prod_{i = 1}^k p_i^{\alpha_i}\), 其约数个数为 \(\prod_{i = 1} ^k (\alpha_i + 1)\)
例如 \(N = 2^5 \cdot 3^1, 约数个数为(5 + 1) \times (1 + 1) = 12\)
乘积数的所有约数之和
对于一个以标准分解式给出的数 \(N = \prod_{i = 1}^k p_i^{\alpha_i}\), 其约数之和为 \(\prod_{i = 1} ^k (\sum_{j = 0}^{\alpha_i} p_i^j)\)
例如 \(N = 2^5 \cdot 3^1, 约数个数为 (2^0 + 2^1 + \cdots + 2^5) \times (3^0 + 3^1)\)。展开结果实际上,各个互不相同,每一项都是一个约数,总个数就是约数个数。
from collections import Counter
from math import *
moder = 10 ** 9 + 7
res = 1
t = int(input())
cnt = Counter()
for _ in range(t):
x = int(input())
for i in range(2, int(sqrt(x)) + 1):
if x % i == 0:
c = 0
while x % i == 0:
c += 1
x //= i
cnt[i] += c
if x > 1: cnt[x] += 1
def S(a, n):
s0 = 1
for _ in range(n):
s0 = (a * s0 + 1) % moder
return s0
for a, n in cnt.items():
res = (res * S(a, n)) % moder
print(res % moder)
约数筛 / 约数预处理¶
求 \([1:mx]\) 范围内每个数的所有约数(包含 1 和 自身)列表(从小到大排列)。
时间复杂度: \(O(n \sqrt n)\)
# mx = 10 ** 5 + 1
factors = [[] for _ in range(mx)]
for f in range(1, mx):
for x in range(f, mx, f):
factors[x].append(f)
欧拉函数¶
定义:$\phi(n) $ 表示 $1 \sim n $ 中 与 \(n\) 互质(最大公约数为 1)的数的个数。
时间复杂度:\(O(\sqrt n)\) ,同质因数分解。
对于一个以标准分解式给出的数 \(N = \prod_{i = 1}^k p_i^{\alpha_i}\),满足:
证明方法:容斥原理。
减去 $p_1, p_2, \cdots, p_k $ 的所有倍数的个数,这一步会多筛一些数。例如 一个数既是 \(p_1\), 又是 \(p_2\) 的倍数,会删去两次。
加上所有 \(p_i \cdot p_j\) 的倍数
减去所有 \(p_i \cdot p_j \cdot p_u\) 的倍数,以此类推。
最后一步,可以通过观察系数的角度来证明。例如 $\frac{1}{p_i} $ 项的系数是 -1。
证明方法二:
性质:
- 积性函数:对于互质的 \(p, q\), \(\phi(p \times q) = \phi(p) \times \phi(q)\)。 特别的, 对于奇数 \(p\), \(\phi(2p) = \phi(p)\)
证明:互质的数,质因子分解的集合无交集。\(\phi(2) = 1\)
- 对于质数 \(p\) , \(\phi(p^k) = p^k - \frac{p^k}{p} = p^k - p^{k -1}\)
证明:减去是 \(p\) 的倍数的数,剩下不是 p 的倍数的数,一定和 \(p\) 互质。
def solve(n):
res = n
for i in range(2, int(sqrt(n)) + 1):
if n % i == 0:
res = res * (i - 1) // i
while n % i == 0:7
n //= i
if n > 1:
res = res * (n - 1) // n
return res
筛法求欧拉函数¶
对于 \(N\) 的最小质因子 \(p_1\), \(N' = \frac{N}{p_1}\),我们希望筛法中,\(N\) 通过 \(N' \cdot p_1\) 筛掉。
考虑两种情况:
- $ N' \bmod p_1 = 0 $,则 \(N'\) 包含了 \(N\) 的所有质因子。
- \(N' \bmod p_i \ne 0\) ,则 \(N'\) 与 \(p_1\) 互质(证明:质数是因子只有 1 和本身,因此最大公约数是 1,互质)。
由欧拉函数的积性性质,互质的数质因子分解无交集:
在筛质数的同时筛出欧拉函数。
primes = []
is_prime = [True] * (n + 1)
phi = [0] * (n + 1)
phi[1] = 1
for i in range(2, n + 1):
if is_prime[i]:
phi[i] = i - 1
primes.append(i)
for p in primes:
if p * i > n: break
is_prime[i * p] = False
if i % p == 0:
phi[i * p] = p * phi[i]
break
phi[i * p] = (p - 1) * phi[i]
欧拉定理¶
例如:\(5^ {\phi(6)} \bmod 6=5^2 \bmod 6=25 \bmod 6=1\)。
证明:考察 \(1\sim n\) 中与 \(n\) 互质的 \(\phi(n)\) 个数:\(p_1, ~p_2,~\cdots,~p_{\phi(n)}\)。将他们乘上 \(a\) ,再逐个对 \(n\) 取模,得到另一组数 \(ap_1 \bmod n, ~~ap_2 \bmod n,\cdots,~~ap_{\phi(n)} \bmod n\)。
可以证明这一组数两两不相同(反证法,若 \(ap_i \equiv ap_j(\bmod n)\),则 \(ap_i - ap_j \equiv 0(\bmod n)\),由于 \(a,~n\) 互质,则一定有 \(p_i = p_j\),矛盾),同时这一组每个数都和 \(n\) 互质(因为 \(a\) 和 \(p_i\) 都与 \(n\) 互质)。
则可以得到,新的这组数集 和 原先与 \(n\) 互质的数集完全相同。有:\(p_1 \cdot p_2 \cdots p_{\phi(n)} \bmod n = \prod a p_i \bmod n\),即:
费马小定理¶
若 \(a\) 与 素数 \(p\) 互质,则 \(a^{p-1} \equiv 1 (\bmod ~ p)\)。
裴蜀定理¶
对正整数 \(a,b\),记最大公约数 \(d=gcd(a,b)\),
- 对于任意 \(x,y\),\(ax+by=D,~D\) 是 \(d\) 的倍数,即 \(d|(ax+by)\)
-
一定存在 \(x,y\) 使得 \(ax+by=d\) 成立。(例如,一定存在 \(x,y\) 使得 \(12x+8y = 4\))
-
推论:$gcd(a,b )=d \iff $ 存在 \(x,y 使得 ax +by=d\);
\(a,b\) 互质 \(\iff\) 存在 \(x,y\) 使得 \(ax+by=1\)。
推广:对于任意 \(n\) 个数 \(a1,a2,\cdots , a_n\) ,最大公约数为 \(d=gcd(a1, \cdots, a_n)\)
- 一定存在 \(x1, \cdots, x_n\) 使得 $\sum a_i x_i = d $ 成立。(即对于任意 $\sum a_i x_i = k \(,\)k$ 一定是 \(d\) 的整数倍)
- $gcd(a1, \cdots, a_n) = d \iff $ 存在 \(x_1, c\dots, x_n\), 使得 \(\sum a_i x_i = d\)
判断是否能从原给定集合中,选出子集 \(A=\{a1, \cdots,a_n\}\),存在一组 \(X=\{x1, \cdots,x_n\}\),使得 \(AX=1\)。实际只需要整个原集合的 \(gcd\) 值为 1,则一定存在一个最小子集其 \(gcd\) 为 1,由裴蜀定理,能找到存在一组 \(X=\{x1, \cdots,x_n\}\),使得 \(AX=1\)。
def isGoodArray(self, nums: List[int]) -> bool:
res_gcd = nums[0]
for x in nums:
res_gcd = gcd(res_gcd, x)
return res_gcd == 1
1625. 执行操作后字典序最小的字符串 - 力扣(LeetCode)
暴力枚举需要 $O(nC+nC) $,其中当 $b $ 是偶数时,\(C=10\),否则 \(C=10^2\)。其中枚举轮转起点位置阶段可以使用裴蜀定理优化。
对于任意起点 \(i'=(i\times b) \bmod n = ib - kn\),由裴蜀定理 $\sum a_i x_i = k \(,\)k$ 一定是 \(d\) 的整数倍,所以 \(i'=K \times gcd(b,n)\)。则在一个字符串轮转中,只需要通过扩展字符串为两倍长度,枚举起点在 \(gcd(a,b)\) 的整倍数的位置即可。时间复杂度:\(O(nC + \frac{n}{gcd(a,b) }\times C)\),其中当 $b $ 是偶数时,\(C=10\),否则 \(C=10^2\)。
def findLexSmallestString(self, s: str, a: int, b: int) -> str:
res = s
n = len(s)
s = s + s
s_str = set()
e_lim = 10 if b & 1 else 1 # b 为奇数才可以对偶数位置增加
for o_cnt in range(10): # 对奇数位置增加 a 的次数
for e_cnt in range(e_lim):
tmp = list(map(int, s))
for i in range(1, 2 * n, 2): tmp[i] = (tmp[i] + a * o_cnt) % 10
for i in range(0, 2 * n, 2): tmp[i] = (tmp[i] + a * e_cnt) % 10
s_str.add(''.join(map(str, tmp)))
def _gcd(a, b):
return _gcd(b, a % b) if b else a
g = _gcd(n, b)
for i in range(0, n, g): # 裴蜀定理优化
for ss in s_str:
tmp = ss[i: i + n]
if tmp < res: res = tmp
return res
欧几里得算法¶
算法原理:\(gcd(a, b) = gcd(b,a\bmod b)\)
证明:
- 对于任意一个能整除 \(a\) 且 能整除 b 的数 \(d\), $a \bmod b $ 可以写成 \(a - k \cdot b\) ,其中 \(k = a // b\) ,所以 \(d\) 一定能够整除 \(b, a \bmod b\);
- 对于任意一个能整除 \(b\) 且能整除 \(a - k \cdot b\) 的数 \(d\), 一定能整除 \(a-k\cdot b + k\cdot b = a\),所以二者的公约数的集合是等价的。
- 所以二者的最大公约数等价
时间复杂度:\(O(\log (\max(a,~b)))\)
证明:
引理 1: \(a\bmod b \in[0,~ b-1]\)。例如,\(38 \bmod 13 = 12\)
引理 2:取模,余数至少折半。
如果 $ b > a//2,~a \bmod ~b = a - b < a//2$。例如,a = 9, b = 5, a mod b = 9 - 5 = 4
如果 \(b \le a//2, ~ a \bmod b \le b - 1 \le a//2 -1\)。
情况 1:当每次执行 gcd 时,如果 \(a < b\) ,则交换;情况 2:否则 \(a \ge b\),一定发生引理 2 的情况,即对 \(a\) 取模,一定会让 \(a\) 折半。最坏情况下,每两次让 \(a\) 折半,所以时间复杂度为 :
\(O(T) = O(T /2) + 2 = O(T/4) + 4 = O(\frac {T}{2^k}) + k\times2 = 2\log k\),即 \(O(\log(\max(a, b)))\)
扩展欧几里得¶
求解 \(ax+by=gcd(a,b)\) 的一组整数解。
当 \(b=0\),\((x,y)=(1,0)\);否则 \(gcd(a,b)=gcd(b, a \bmod b)\),向下递归一层,得到下一层的 \((x',y')\)。由裴蜀定理,\(bx'+(a \%b)y'=bx'+(a-a//b \times b)y'=ax+by\),所以 \(a(y')+b(x'-a//b \times y') = ax+by\),可得解为 \((y', x'-a//b \times y')\),再回代到上一层,最终可以得到 特解 \((x_0, y_0)\)。构造通解:\((x_0+k\times \frac{b}{gcd(a,b)},~y_0-k\times \frac{a}{gcd(a,b)})\)。
def exgcd(a, b):
if b == 0: return 1, 0, a
x, y, d = exgcd(b, a % b)
return y, x - a // b * y, d
线性同余方程¶
线性同余方程
-
给定 \(a,b, m(a\bmod m \ne0)\), 求一个特解 $ x$, 使得 \(ax \equiv b (\bmod m)\)。
-
求方程 \(ax+my=b\) 的特解。
无解条件:\(b\bmod gcd(a,m) \ne 0\)
例如 \(2x \equiv 3 (\bmod 6)\),无解;\(4x \equiv 3 (\bmod 5)\),解的形式是 \(5k+2\)。
转换为 存在整数 \(y\), 使得 $ax = m(-y) +b $,即 \(ax+my=b\)。由裴蜀定理,这个方程有解的条件是 \(b\) 是 \(gcd(a, m)\) 的倍数。对应特解通过扩展欧几里得构造,求得一组特解 \((x_0, y_0)\),满足 \(ax_0 + my_0 =gcd(a,m)\),将方程除 \(gcd(a,m)\) 再乘 \(b\) ,即最后结果方程一组特解为 \((x_0'=x_0 \times \frac {b}{gcd},~ y_0'= y_0 \times \frac{b}{gcd})\)。则其任意解为:\((x'_0+k\times \frac{m}{gcd},~y'_0-k\times \frac{a}{gcd})\),则通解 \(x' \bmod \frac{m}{gcd} = x'_0\),所以最小特解为 \(x_m = x_0 \times \frac{b}{gcd} \bmod \frac{m}{gcd},~~y_m=y_0\times \frac{b}{gcd} \bmod \frac{a}{gcd}\)。
def exgcd(a, b):
if b == 0: return 1, 0, a
x, y, d = exgcd(b, a % b)
return y, x - a // b * y, d
# 求 ax + by = c 的特解
def liEu(a, b, c):
x, y, d = exgcd(a, b)
if c % d != 0: return None
# x0 + k * (b // d)
# y0 - k * (a // d)
return x * (c // d) % (b // d), y * (c // d) % (a // d)
同余
两个整数 \(a\), \(b\), 若它们除以正整数 \(m\) 所得的余数相等,则称 \(a\), \(b\) 对于模 \(m\) 同余, 读作 \(a\) 同余于 \(b\) 模 \(m\), 或读作 \(a\) 与 \(b\) 关于模 \(m\) 同余。
同余性质
- 保持基本运算:
- 除法原理:若 \(ka\equiv kb\quad({\mathrm{mod}}\quad m)\) 且 \(k,m\) 互质,则 \(a\equiv b\quad({\mathrm{mod}}\quad m)\)
乘法逆元¶
定义:
对 \(\forall a \in \mathbb{Z}\),如果 \(a\) 与 \(p\) 互质,则 \(\exists x \in \mathbb{Z}\) 使得 $ a \times x \equiv 1 (\text{ mod } p)$,则 \(x\) 为 \(a\) 在模 \(p\) 意义下的乘法逆元,记作 \(a^{-1}\)。
乘法逆元满足:对于 \(b \div a \mod p\),如果有 \(a, p\) 互质,则 \(b \div a \mod p = b \times a^{-1} \mod p\)
例如,\(2 \times 5 \equiv 1 (\mod 9)\),则 2和5互相是对方在模9意义下的乘法逆元。所以对于 \(8 \div 5 \mod 9 = 8 \times 2 \mod 9 = 7\)
快速幂求逆元
条件: 1. 模数 \(b\) 是质数。2. \(a \bmod b \ne0\)
\(a\) 模 \(b\) 的乘法逆元是 \(a^{b-2} \bmod b\)
def qmi(a, n, p):
res = 1
while n:
if n & 1: res = res * a % p
a = a * a % p
n >>= 1
return res
def inv(a, b): # 确保 b 是素数且 a % b != 0
return qmi(a, b - 2, b)
扩展欧几里得求逆元
条件:1. \(gcd(a,b)=1\) 。 2. \(a \bmod b \ne0\)
- 求 \(a\) 在模 \(b\) 意义下的逆元
- 求线性同余方程 \(ax\equiv1 (\bmod b)\) 的解 \(a^{-1}\)
转化为方程 \(ax+by=1\) 的解 \(x\), 时间复杂度:$ O(log(min(a, b)))$
def exgcd(a, b):
if b == 0: return 1, 0, a
x, y, d = exgcd(b, a % b)
return y, x - a // b * y, d
# 求 ax + by = c 的特解
def liEu(a, b, c):
x, y, d = exgcd(a, b)
if c % d != 0: return None
return x * (c // d) % (b // d), y * (c // d) % (a // d)
def inv(a, b):
# ax mod b = 1
# ax + by = 1
x, y = liEu(a, b, 1)
return x
线性求逆元¶
P3811 【模板】模意义下的乘法逆元 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
求 \(1,2, \cdots, n\) 中每个数关于 \(p\) 的逆元。
已知 \(1\times 1\equiv1(\bmod b)\) 恒成立。对于 \(i\) 的逆元 \(i^{-1}\),考察模数 \(p\) 关于 \(i\) 的倍数 \(k = \lfloor \frac{p}{i} \rfloor\),\(j=p \bmod i\),有 \(p=ki+j\)。放到模 \(p\) 意义下,\(ki+j \equiv0 (\bmod p)\),同乘 \(i^{-1} \times j ^{-1}\),得 \(kj^{-1} + i^{-1} \equiv 0(\bmod p)\),移项 \(i^{-1} \equiv-kj^{-1} (\bmod p) \equiv - \lfloor \frac{p}{i} \rfloor (p \bmod i)^{-1} (\bmod p)\)
递推式:
逆元解决除法取模
如果 \(b\) 与 \(p\) 互质,对于 \(\forall ~a,\) 如果 \(a /b\) 是整数,则一定存在乘法逆元 \(x\),使得 \(\frac{a}{b} \equiv a \cdot x (mod~p)\)。 \(x\) 是 \(b\) 的乘法逆元,记为 \(b^{-1}\)(\(b\) 模 \(p\) 的逆元)。\(b\) 的乘法逆元存在的充要条件:\(b\) 和 \(p\) 互质。
逆元性质
- \(b \cdot b^{-1} \equiv1(mod~p)\) (证明:对定义式两边同乘 \(b\),得到 \(a\equiv a\cdot b ^{-1} \cdot b (mod ~p)\),由于 \(a\) 是 \(b\) 的倍数,且 \(b\) 与 \(p\) 互质,所以 \(a\) 与 \(p\) 互质,满足同余的除法原理。)
- 当模数 \(p\) 为素数时,\(b^{-1}_{\bmod p} =b^{p-2}\)。(证明:特殊情况下,对于质数 \(p\),由费马小定理得 \(b^{p-1}\equiv1(mod~p)\),则可知,要求与其互质的数 \(b\) 的逆元 \(x\) 满足 \(bx \equiv 1(mod~p)\),\(b^{-1} =x=b^{p-2}\),可以使用快速幂求。
中国剩余定理¶
条件:整数 \(m_1,m_2,\ldots,m_n\) 两两互质
引理:寻找整数 \(y_1\) 满足 \(y_1\) 除以 \(3\) 余 \(1\)、除以 \(5\) 余 \(0\)、除以 \(7\) 余 \(0\)。
\(y_1\) 一定是 \(5 \times 7 = 35\) 的倍数,设 \(y_1 = 35k\) ,则有 \(35k \equiv 1 (\bmod 3)\),此时 \(k\) 是 \(35\) 模 \(3\) 的逆元
对于任意的整数 \(a_1,a_2,\ldots,a_n\) , 方程组
的最小非负整数解:
记:
P1495 【模板】中国剩余定理(CRT)/ 曹冲养猪 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
def exgcd(a, b):
if b == 0: return 1, 0, a
x, y, d = exgcd(b, a % b)
return y, x - a // b * y, d
def liEu(a, b, c):
# ax + by = c 的解
x, y, d = exgcd(a, b)
a //= d
b //= d
c //= d
return x * c % b, y * c % a
def inv(a, b):
x, _ = liEu(a, b, 1)
return x
def CRT(a, m):
M = 1
res = 0
for mi in m: M *= mi
for ai, mi in zip(a, m):
ci = M // mi
res = (res + ai * ci * inv(ci, mi)) % M
return res
扩展中国剩余定理¶
P4777 【模板】扩展中国剩余定理(EXCRT) - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
删去条件 \(m_i\) 两两互质。
前两个方程:\(x\equiv a_1(\bmod m_1),x\equiv a_2(\bmod m_2)\),转换成不定方程 \(x=pm_1 +a_1=qm_2+a_2\),移项得到 \(pm_1-qm_2=a_2-a_1\),由裴蜀定理,当 \(a_2-a_1\) 不是 \(gcd(m_1,m_2)\) 的倍数时,整个方程无解;否则得到 \(gcd\) 方程的特解 \((p_0,q_0)\),由扩展欧几里得,原方程其通解为 $(P=p_0 \times\frac{a_2-a_1}{gcd(m_1, m_2)} \bmod \frac{m_2}{gcd(m_1, m_2)}+ k\frac{m_2}{gcd(m_1, m_2)},~~~~Q=q_0\times \frac{a_2-a_1}{gcd(m_1,m_2)} \bmod \frac{m_1}{gcd(m_1, m_2)} - k\frac{m_1}{gcd(m_1, m_2)}) $。
代入有 \(x=Pm_1+a_1=p'_0m_1+a_1+k\times \frac{m_1m_2}{gcd(m_1,m_2)}=p_0'm_1+a_1+k \cdot \text{lcm}(m_1, m_2)\),所以可以写成 \(x\equiv a_1' (\bmod m_1'),a_1'=p_0'm_1+a_1,m'_1=\text{lcm}(m_1,m_2)\)。每一次合并可以使得方程数量减少 1,只需要合并 \(n-1\) 次。
每一次将 \(m_1 \leftarrow abs(\frac{m_1m_2}{gcd(m_1,m_2)}),~a_1 \leftarrow p'_0m_1 + a1\)。
def exgcd(a, b):
if b == 0: return 1, 0, a
x, y, d = exgcd(b, a % b)
return y, x - a // b * y, d
def liEu(a, b, c):
x, y, g = exgcd(a, b)
a, b, c = a // g, b // g, c // g
return x * c % b, y * c % a, g
def ex_CRT(a, m):
a1, m1 = a[0], m[0]
for i in range(1, len(a)):
a2, m2 = a[i], m[i]
# 构造 p * m1 + a1 = q * m2 + a2
p, q, g = liEu(m1, -m2, a2 - a1)
if (a2 - a1) % g != 0: return -1
a1 = p * m1 + a1
m1 = abs(m1 * m2 // g)
return a1
[P8807 蓝桥杯 2022 国 C] 取模 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
给定 \(n,m\), 问是否存在两个不同的数 \(x,y\) 使得 \(1\leq x<y\leq m\) 且 \(n\bmod x=n\bmod y\)。
考虑反面情况,当且仅当对于任意 \([1,m]\) 的数 \(x\) ,\(n \bmod x\) 两两不相等,则不成立。由于对于任意的 \(n\),\(n \bmod 1=0\),所以 \(n \bmod 2\) 只能取 1,同理 $n \bmod 3 $ 只能取 2,可以得到 \(n \bmod k\) 必须取 \(k-1\)。所以当且仅当 \(\forall k \in [1, m] 有 n \bmod k =k-1\) 恒成立,才不存在。
def solve():
n, m = map(int, input().split())
for k in range(1, m + 1):
if n % k != k - 1:
return 'Yes'
return 'No'
由扩展中国剩余定理,\(m_1 = 1, a_1=0\),每一次将 \(m_1 \leftarrow abs(\frac{m_1m_2}{gcd(m_1,m_2)}),~a_1 \leftarrow p'_0m_1 + a1\)。最终 \(m_1 = lcm(1,2,\cdots,m)\),\(a_1 = -1\),有方程 \(x\equiv -1(\bmod m_1)\),所以必然有当 $ n \bmod L=-1 即L-1$ 时,不存在返回 'NO'。由于当 \(m\) 比较大的时候,其 \(L\) 值增长速度快,值已经超过 \(n\),此时 \(n \bmod L = n < L-1\) 一定返回 'Yes'。
def gcd(a, b):
return gcd(b, a % b) if b else a
def lcm(a, b):
return a * b // gcd(a, b)
Lcm = [1] * 21
for i in range(2, 21):
Lcm[i] = lcm(i, Lcm[i - 1])
def solve():
n, m = map(int, input().split())
if m > 20: return 'Yes'
L = Lcm[m]
if n % L == L - 1: return 'No'
return 'Yes'
离散数学¶
容斥¶
给你一个正整数 n ,请你计算在 [1,n] 范围内能被 3或者5或者7 整除的所有整数之和。
返回一个整数,用于表示给定范围内所有满足约束条件的数字之和。
利用等差数列求和公式:\(1 \sim n 中 能被 x 整除的数之和 = (x + 2\cdot x+ \cdots + n//x \cdot x) = x \cdot(1 + n // x) * (n // x) // 2\)
因而,
class Solution:
def sumOfMultiples(self, n: int) -> int:
# 定义 f(x) 为能被 x 整除的数字之和
def f(x):
return x * (1 + n // x) * (n // x) // 2
return f(3) + f(5) + f(7) - f(15) - f(21) - f(35) + f(105)
3116. 单面值组合的第 K 小金额 - 力扣(LeetCode)
容斥 + 预处理最小公倍数:给定无重复集合 \(coins'\),\(1 \sim x\) 中,能对任意一个 \(coins'\) 中元素整除的个数 为 \(check(x)\)。
将问题转换成,恰好能有 不少于 \(k\) 个数被任意一个 \(coins'\) 中元素整除的 \(x\) 的值,使用二分答案。回溯法枚举子集,预处理所有 \(coins'\) 子集的最小公倍数,所有相同长度、为 \(l\) 的子集的最小公倍数存放在 \(dic[l]\) 中。 对于任意一个数 \(y\) ,\(1 \sim x\) 中能被它整除的个数为 \(int(x/y)\)。
def findKthSmallest(self, coins: List[int], k: int) -> int:
coins.sort()
if coins[0] == 1: return k
c = set(coins)
n = len(coins)
for i in range(n):
for j in range(i + 1, n):
x, y = coins[i], coins[j]
if y % x == 0 and y in c:
n -= 1
c.remove(y)
coins = list(c)
# 预处理:dic [i] 表示 从 coins 中选出 i 个数的子集的最小公倍数
dic = defaultdict(list)
dic[1] = coins
# 回溯枚举子集
path = []
def dfs(i):
l = len(path)
if i == n:
if l >= 2:
lcm_ = path[0]
for j in range(1, l):
lcm_ = lcm(lcm_, path[j])
dic[l].append(lcm_)
return
dfs(i + 1)
path.append(coins[i])
dfs(i + 1)
path.pop()
dfs(0)
def check(x):
# 检查 1 ~ x 中,能被任意一个 c 整除的个数 res 和 k 的关系
res = 0
for l in range(1, n + 1):
plus = l & 1
for d in dic[l]:
res = res + (1 if plus else -1) * (x // d)
return res >= k
lo, hi = 0, 5 * 10 ** 10 + 10
while lo < hi:
mid = (lo + hi) >> 1
if check(mid):
hi = mid
else:
lo = mid + 1
return lo
鸽巢原理 / 抽屉原理¶
常用于求解最坏情况下的解,以及证明不存在解(连最坏情况下,都不存在解,则所有情况不存在解)。
鸽巢原理定理
有 \(n+1\) 只鸽子,飞入 \(n\) 个鸽子巢,则至少有一个巢里有不少于两只鸽子。(反证法:假设没有一个巢中有不少于两只鸽子,则鸽子总数不会超过 \(n\), 矛盾)。即将 \(n+1\) 个物体,划分成 \(n\) 组,至少有一组有不少于两个物体。
推广:将 \(n\) 个物体,划分成 \(k\) 组,至少有一组不少于 \(\left\lceil \frac{n}{k} \right\rceil\)。(证明:反证法,假设所有组少于 \(\left\lceil \frac{n}{k} \right\rceil\),则至多 \((\left\lceil \frac{n}{k} \right\rceil - 1) \times k < (\frac{n}{k}) \times k=n\) 个物体,矛盾)
例如,53 个物体,分成 6 组,最坏情况下是 9,9,9,9,9,8。
简单应用:
- 任意 11 个整数中,至少有 2 个整数之差是 10 的倍数。(证明,从余数角度来看,11 个数对 10 的余数有 11 个,一共有 10 种余数 0 ~ 9,至少有两个数对 10 同余,故其差也对 10 同余)
- 一个人骑车 10 小时内走完了 281 公里路程,已知他第一小时走了 30 公里,最后一小时走了 17 公里。证明:他一定在某相继的两小时中至少走完了 58 公里路程。(证明:8 小时走了 234 公里,234 个物品分到 8 组,最坏情况下,至少有一组是 \(\lceil \frac{234}{8} \rceil=29\) 公里,其余各组是 \(28\) 公里,那么第一小时和第二小时一定至少有 58 公里)
Ramsey 定理 / 拉姆齐定理
任意 \(n\) 个人,必然有 \(\lceil (n-1)/2 \rceil\) 个人相互认识 或者 相互不认识。(证明:考虑其中一个人的视角,剩下 n - 1 个人需要划分成两组,\(k1\) 表示与它认识,\(k2\) 表示与他不认识,其中一组至少为 \(\lceil (n-1)/2 \rceil\) 个人。)
隔板法
Gardon 有 \(1\leq K\leq10^6\) 种糖果,第 \(i\) 种糖果有 \(1\leq a_i\leq10^6\) 个,Gradon 不喜欢连续两次吃同样种类的糖果,问是否存在可行的吃糖方案。
即给定 \(K\) 组物品,各组物品个数有 \(a_i\) 个,每次从任意一组种拿出一个,要求连续两个物品不能属于同一组。是否存在一种方案能拿完所有物品?
找到最大个数的数量 \(N\), 其余物品数量之和为 \(S\), 通过隔板法,当且仅当形成的 \(N-1\) 个右侧隔间每个至少有一个其他物品时满足条件。所以当 \(S \ge N-1\) 成立,反之不成立。
数学公式¶
排序不等式¶
结论:对于两个有序数组的乘积和,顺序和 \(\ge\) 乱序和 \(\ge\) 倒序和。
对于 \(a_1 \le a_2 \le \cdots \le a_n,b_1 \le b_2 \le \cdots \le b_n\),并有 \(c1,c2,\cdots, c_n\) 是 \(b1, b2, \cdots , b_n\) 的乱序排列。有如下关系:
当且仅当 \(a_i = a_j\) 或者 \(b_i = b_j \space (1 \le i, j\le n)\) 时,等号成立。
区间递增 k 个数¶
结论:对于 \(i_0 = a\),每次递增 \(k\),在区间 \([a, b)\) 内的个数是:
平均数不等式¶
即:调和平均数 ,几何平均数,算术平均数,平方平均数 (调几算方)
应用:
例如当算术平均数为定值,\(x_i\) 分布越接近,平方平均数越小,因此可以进行贪心算法:
3081. 替换字符串中的问号使分数最小 - 力扣(LeetCode) 各个字母之间的出现次数的差异越小,越均衡,最终结果越小。可以基于贪心 + 堆进行维护,每次取出出现次数最小中字典序最小的字符。
def minimizeStringValue(self, s: str) -> str:
cnt = Counter(s)
hq = [(cnt[ch], ch) for ch in string.ascii_lowercase]
heapq.heapify(hq)
alp = []
res = list(s)
for i in range(s.count('?')):
v, k = heappop(hq)
v += 1
alp.append(k)
heappush(hq, (v, k))
alp.sort(reverse = True)
for i, x in enumerate(res):
if res[i] == '?':
res[i] = alp.pop()
return ''.join(res)
求和公式¶
取模性质¶
模运算与基本四则运算有些相似,但是除法例外。其规则如下: (a + b) % p = (a % p + b % p) % p (a - b) % p = (a % p - b % p) % p (a * b) % p = (a % p * b % p) % p a ^ b % p = ((a % p)^b) % p 结合律: ((a+b) % p + c) % p = (a + (b+c) % p) % p ((a b) % p * c)% p = (a * (b c) % p) % p 交换律: (a + b) % p = (b+a) % p (a * b) % p = (b * a) % p 分配律: (a+b) % p = ( a % p + b % p ) % p ((a +b)% p * c) % p = ((a * c) % p + (b * c) % p) % p
数列¶
等比数列求和公式
递推方法求等比数列求和(带模运算)
希望求:\(S(a,n) \bmod p=(a^0+a^1+\cdots+a^n) \bmod p\),不难发现 \(S(a,n)=a\cdot \big(S(a,n-1) \big) + 1\)。
时间复杂度:\(O(n)\)
组合数学¶
排列
\(A_m^n = \frac{m!}{n!}\)
\(\mathrm{A}_n^m=n\mathrm{A}_{n-1}^{m-1}\)
递推公式:可理解为“某特定位置”先安排,再安排其余位置。
组合数学
\(~ C_m^n = \frac{m!}{n!(m-n)!}\)
\(C_m^n = C_m^{m-n}\)
递推公式:
\(C_m^n = C_{m -1}^n + C_{m-1}^{n-1}\)
@lru_cache(None)
def C(n, m):
if m == 0 or n == m: return 1
return C(n - 1, m - 1) + C(n - 1, m)
\(C_n^0+C_n^1 + \cdots+ C_n^n = 2 ^ n\)
路径方案数 \(= C(n+m-2,m-1)\)
@lru_cache(None)
def C(n, m):
if m == 0 or n == m: return 1
return C(n - 1, m - 1) + C(n - 1, m)
class Solution:
def uniquePaths(self, m: int, n: int) -> int:
m, n = m - 1, n - 1
return C(m + n, m)
二项式定理¶
卡特兰数¶
5. 卡特兰数(Catalan)公式、证明、代码、典例._c n = n+11 ( n2n )-CSDN 博客
证明方法:
看成从从 \((0,~ 0)\) 到 右上角 \((n, ~n )\),每次只能向右或者向上,向上的次数不超过向右的次数的路径数。
对于不合法的情况,超过 \(y = x\) ,即一定触碰 \(y = x + 1\), 取路径与 \(y = x + 1\) 交点中,距离原点最近的点,将路径远离原点的部分关于 \(y = x + 1\) 翻转。由于原来的终点 \((n,n)\) 关于 \(y = x + 1\) 翻转的点是 $(n - 1, n + 1) $ ,所以不合法的路径数是 \(C_{2n}^{n-1}\)
递推公式 1:
证明方法:从 \((0, 0)\) 到 \((n +1, n+1)\) 的路径数可以看成分三步:
首先从 \((0,0)\) 走到 \((i,i)\) ,其方案数为 \(H(i)\);然后从 \((i,i)\) 走到 \((n,n)\) 方案数为 \(H(n-i)\);最后从 \((n,n)\) 走到 \((n + 1, n + 1)\) 其方案数为 \(H(1)\) = 1。
递推公式 2:
推论:
前几项: 1,1,2,5,14,42,132,429,1430
- \(n\) 个节点可以构造的不同的二叉树的个数。(证明:$F(n) $ 为有 n 个节点的二叉树的所有根节点个数。其左子树的可能情况为 $F(i), i \in [0,n], $ 对应右子树的情况为 \(F(n-i),\) 乘积求和形式即为卡特兰数列的递推式。
- 从 \((0,~ 0)\) 到 右上角 \((n, ~n )\),每次只能向右或者向上,向上的次数不超过向右的次数的路径数。(即不超过 \(y = x\) )
- 一个无穷大栈,进栈顺序为 \(1, 2, ... , n\) 的出栈顺序数
- \(n\) 个左括号和 \(n\) 个右括号构成的括号序列,能够构成的有效括号序列个数。
凸多边形划分问题
在一个 n 边形中,通过不相交于 n 边形内部的对角线,把 n 边形拆分为若干个三角形,问有多少种拆分方案?
以凸多边形的一边为基,设这条边的 2 个顶点为 A 和 B。从剩余顶点中选 1 个,可以将凸多边形分成三个部分,中间是一个三角形,左右两边分别是两个凸多边形,然后求解左右两个凸多边形。
2.设问题的解 f(n),其中 n 表示顶点数,那么 f(n)= f(2)f(n-1)+f(3) f(n-2)+……+f(n-2)f(3)+f(n-1) f(2)。 其中,f(2)f(n-1)表示:三个相邻的顶点构成一个三角形,另外两个部分的顶点数分别为 2(一条直线两个点)和 n-1。 其中,f(3)f(n-2)表示:将凸多边形分为三个部分,左右两边分别是一个有 3 个顶点的三角形和一个有 n-2 个顶点的多边形。
3.设 f(2) = 1,那么 f(3) = 1, f(4) = 2, f(5) = 5。结合递推式,不难发现 f(n) 等于 H(n-2)。
快速幂¶
欧拉降幂 / 快速幂
def pow(a, n, moder):
res = 1
while n:
if n & 1: res = (res * a) % moder
n >>= 1
a = (a * a) % moder
return res
矩阵乘法时间复杂度:\(O(M_1N_2N_1)\)
矩阵乘法
moder = 10**9 + 7
def mul(a, b):
m_a, n_a = len(a), len(a[0])
m_b, n_b = len(b), len(b[0])
c = n_a # 可以加一个 n_a 和 m_b 的判等
res = [[0]*n_b for _ in range(m_a)]
for i in range(m_a):
for j in range(n_b):
tmp = 0
for k in range(c):
# tmp = (tmp + (a [i][k] * b [k][j]) % moder) % moder # 如果需要取模
tmp += a[i][k] * b[k][j]
res[i][j] = tmp
return res
矩阵快速幂¶
moder = 10 ** 9 + 7
def mul(a, b):
ma, na = len(a), len(a[0])
mb, nb = len(b), len(b[0])
# ma * nb
c = na
res = [[0] * nb for _ in range(ma)]
for i in range(ma):
for j in range(nb):
tmp = 0
for k in range(na):
tmp = (tmp + a[i][k] * b[k][j] % moder) % moder
res[i][j] = tmp
return res
def mat_pow(a, n):
res = [[6, 6]]
while n:
if n & 1: res = mul(res, a)
a = mul(a, a)
n >>= 1
return res
1411. 给 N x 3 网格图涂色的方案数 - 力扣(LeetCode)
递推方程:\(f(0)=(6, 6),~f(i)=(f(i-1,0) \times 2+f(i-1,1)\times 2,~~f(i-1,0) \times 2 + f(i-1,1)\times3)\)。
即:
moder = 10 ** 9 + 7
def mul(a, b):
ma, na = len(a), len(a[0])
mb, nb = len(b), len(b[0])
# ma * nb
c = na
res = [[0] * nb for _ in range(ma)]
for i in range(ma):
for j in range(nb):
tmp = 0
for k in range(na):
tmp = (tmp + a[i][k] * b[k][j] % moder) % moder
res[i][j] = tmp
return res
def mat_pow(a, n):
res = [[6, 6]]
while n:
if n & 1: res = mul(res, a)
a = mul(a, a)
n >>= 1
return res
class Solution:
def numOfWays(self, n: int) -> int:
m = [[2, 2],
[2, 3]]
x = mat_pow(m, n - 1)
return sum(x[0]) % moder
高等数学¶
调和级数¶
经典应用:求一个数的约数的个数期望值
- 考虑 1~n 所有的数的约数个数。
- 从筛法的角度来看,拥有约数 2 的所有的数,是 1 ~ n 中所有 2 的倍数,大约是 n // 2 个。
- 所以 1~n 所有的数的约数个数和 可以看成 所有的倍数的个数 = \(n/1 + n / 2 + n /3 + \cdots + n / n = n \sum_{i=1}^n\frac{1}{i} = n \ln n。\)
- 所以 =,从期望角度来讲,一个数 \(n\) 的约束个数的期望约是 \(\ln n\)
泰勒展开式¶
麦克劳林公式:\(x_0 = 0\)
常用展开:
所以有:
Stirling 斯特林公式¶
描述阶乘的近似阶:
所以可以得到卡特兰数列的近似:
杂项¶
区间问题¶
区间选点问题 / 最大不相交区间数量:
- 射气球问题:给定 \(n\) 个气球,用区间表示,选出最少的弓箭将所有气球都能射爆;
- 区间选点问题:给定 \(n\) 个区间,选出最少的点使得每个区间至少包含 1 个点;
- 最大不相交区间问题:给定 \(n\) 个区间,选出最多的区间,使得区间两两不相交(含端点)。
这几个问题是等价的,最大不相交区间问题中,将区间集合 \(U\) 划分成两两不相交的最多的区间集合 \(S\) 和 \(U-S\) 。任意 \(U-S\) 中区间一定和 \(S\) 中的区间相交,它们不需要额外占用更多的选点,因此 \(|S|\) 就等于最少的弓箭数量、选点数量。
对于 \(mxr\) ,考察 \(l, r\) 一共有三种情况:
- \(mxr <l\),区间个数加一,\(mxr = r\);
- \(l \le mxr <r\),不变;
- \(r \le mxr\),\(mxr =r\)。
综上,只需要在 \(l > mxr\) 时,更新 \(res, mxr\),否则在 \(r < mxr\) 时,更新 \(mxr\) 。
复杂度:\(O(n)\)
452. 用最少数量的箭引爆气球 - 力扣(LeetCode)
def solve(nums):
nums.sort()
mxr = -inf
res = 0
for l, r in nums:
if l > mxr:
res, mxr = res + 1, r
elif r < mxr:
mxr = r
# 或者 mxr = min(mxr, r)
return res
区间分组:
给定 \(n\) 个区间,要求将其分成最少的组,每个组各个区间之间两两不相交(含端点)。
维护所有分组的 \(mxr\),对于 \(mxr_{\min}\) ,考察 \(l\) 与之关系。
- \(l > mxr_{\min}\),可以放在这个分组中
- \(l \le mxr_{\min}\),需要开辟一个新的分组,最右端点是 \(r\) 。
时间复杂度:\(O(n \log n)\)
def solve(nums):
nums.sort()
hq = []
for l, r in nums:
if hq and l > hq[0]:
heappop(hq)
heappush(hq, r)
return len(hq)
区间覆盖:
给定 \(n\) 个区间和 一个需要覆盖的线段区间 \([L, R]\),找出最少的区间使它们完全覆盖线段区间 \([L,R]\);如果无法完成输出 -1。
考察当前已经覆盖到的最远端 \(mxr\) ,考察所有 \(l \le mxr\),如果第一个 \(l > mxr\),无法覆盖线段区间 \([mxr+1,]\) 记录这些区间的最远右端点 \(nmxr\),然后更新 \(mxr\) 。一旦超过 \(R\) 表示完成覆盖。
def solve(L, R, nums):
n = len(nums)
nums.sort()
mxr = L
i = res = 0
while i < n:
l, r = nums[i]
if l > mxr: return -1
nmxr = r
while i + 1 < n and nums[i + 1][0] <= mxr:
nmxr = max(nmxr, nums[i + 1][1])
i += 1
res += 1
mxr = nmxr
if mxr >= R: return res
i += 1
return -1
合并区间:
先排序。
class Solution:
def merge(self, intervals: List[List[int]]) -> List[List[int]]:
intervals.sort()
res = []
l, r = intervals[0][0], intervals[0][1]
for interval in intervals:
il, ir = interval[0], interval[1]
if il > r:
res.append([l, r])
l = il
r = max(r, ir)
res.append([l, r])
return res
2580. 统计将重叠区间合并成组的方案数 - 力扣(LeetCode)
def countWays(self, ranges: List[List[int]]) -> int:
ranges.sort(key = lambda x: x[0])
l, r = ranges[0][0], ranges[0][1]
nranges = []
for il, ir in ranges:
if il > r:
nranges.append([l, r])
l = il
r = max(ir, r)
区间交集:
\(Lo,Hi\) 记录当前可变温度区间。每次来到新时刻,更新为 \([Lo-dt,~Hi+dt]\)。判断该区间是否和当前 \([lo, hi]\) 相交。是则求其交集。
def solve():
n, m = map(int, input().split())
tem = [(0, m, m)]
for _ in range(n):
at, lo, hi = map(int, input().split())
tem.append((at, lo, hi))
Lo = Hi = m
for i in range(1, n + 1):
at, lo, hi = tem[i]
dt = at - tem[i - 1][0]
Lo, Hi = Lo - dt, Hi + dt
if Lo > hi or Hi < lo: return 'NO'
Lo, Hi = max(Lo, lo), min(Hi, hi)
return 'YES'
贡献法¶
经典问题:子数组的最小值之和,子数组的最大值之和,子数组的极差之和。
- 套娃式定义,如子数组的子数组,子序列的子序列
- 求某些的和,可以考虑成子子问题对总问题的贡献
考虑每个值对子数组最大值,最小值的贡献情况,用单调栈维护。
最大值用减小栈维护,贡献是 \((i - t) \times (t - stk[-1]) \times nums[t]\)
def subArrayRanges(self, nums: List[int]) -> int:
res = 0
stk = [-1]
total_mx = 0 # 贡献
nums.append(inf)
for i, x in enumerate(nums):
# 单调减
while len(stk) > 1 and x >= nums[stk[-1]]:
t = stk.pop()
total_mx += (i - t) * (t - stk[-1]) * nums[t]
stk.append(i)
stk = [-1]
nums[-1] = -inf
total_mn = 0
for i, x in enumerate(nums):
# 单调增
while len(stk) > 1 and x <= nums[stk[-1]]:
t = stk.pop()
total_mn += (i - t) * (t - stk[-1]) * nums[t]
stk.append(i)
return total_mx - total_mn
IDLE 使用¶
【蓝桥杯】Python 自带编辑器 IDLE 的使用教程_python 蓝桥杯编译器-CSDN 博客
- 新建文件:Ctrl + N
- 运行:F5
- 打开文档:F1(蓝桥杯可用)
- 字体:Options -> Configure IDLE -> Font/Tabs,建议:Consolas, 18
- alt+3 注释代码
- alt+4 取消注释
- Tab 多行缩进,
Ctrl + [取消缩进
ceil 精度处理¶
同时存在除法和 \(ceil\) 运算时, \(ceil(a /b)\) 以及 \(ceil(a/b +x)\) 操作会由于精度问题,导致偏差。
方法 1:\(ceil(x) = math.ceil(x - eps)\),其中 $eps $ 是小常量,如 \(10^{-8}\)
方法 2:所有数乘 \(b\) ,其中 \(b \times ceil(a,b) = ((a-1)//b + 1) \times b\) ,例如 $ ceil(4,3) \times3=2\times 3=6$。
1883. 准时抵达会议现场的最小跳过休息次数 - 力扣(LeetCode)
\(f(i, j) =\min \{f(i - 1, j - 1) + d[i]/s,\ ceil(f(i - 1, j) + d[i]/s)\}\)
方法 1:
eps = 1e-8
def ceil(x):
return math.ceil(x - eps)
def minSkips(self, d: List[int], s: int, hoursBefore: int) -> int:
n = len(d)
if sum(d) > s * hoursBefore: return -1
if n == 1: return 0 if d[0] <= s * hoursBefore else -1
mx = sum(d) + n
f = [[mx] * (n + 1) for _ in range(n + 1)]
d = [D / s for D in d]
f[0][0] = ceil(d[0])
f[0][1] = d[0]
for i in range(1, n - 1):
for j in range(i + 2):
f[i][j] = ceil(f[i - 1][j] + d[i]) # 不休息
if j: f[i][j] = min(f[i][j], f[i - 1][j - 1] + d[i]) # 休息
for k in range(n):
if f[n - 2][k] + d[-1] <= hoursBefore:
return k
方法 2:
def ceil(a, b):
return ((a - 1) // b + 1) * b
def minSkips(self, d: List[int], s: int, hoursBefore: int) -> int:
n = len(d)
if sum(d) > s * hoursBefore: return -1
if n == 1: return 0 if d[0] <= s * hoursBefore else -1
mx = sum(d) + n
f = [[mx] * (n + 1) for _ in range(n + 1)]
f[0][0] = ceil(d[0], s)
f[0][1] = d[0]
for i in range(1, n - 1):
for j in range(i + 2):
f[i][j] = ceil(f[i - 1][j] + d[i], s)
if j: f[i][j] = min(f[i][j], f[i - 1][j - 1] + d[i])
for k in range(n):
if f[n - 2][k] + d[-1] <= hoursBefore * s:
return k
离散化¶
二分写法
字典写法
sorted_nums = sorted(set(nums))
mp = {x: i + 1 for i, x in enumerate(sorted_nums)}
nums = [mp[x] for x in nums]
二分 + 还原
tmp = nums.copy()
sorted_nums = sorted(set(nums))
nums = [bisect.bisect_left(sorted_nums, x) + 1 for x in nums]
mp_rev = {i: x for i, x in zip(nums, tmp)}
日期问题¶
蓝桥杯必备模块——datetime 轻松应对各类时间问题_python 时间差计算: 轻松应对日期和时间的差异-CSDN 博客
Python datetime 模块详解、示例-CSDN 博客
from datetime import *
# 计算日期之差
t1 = date(year = 2025, month = 3, day= 1)
t2 = date(year = 2025, month = 3, day = 17)
print(t2 - t1) # 16 days, 0:00:00
print((t2 - t1).days) # 16
# 参数也可简化
t1 = date(2025, 1, 1)
t2 = date(2025, 3, 17)
print((t2 - t1).days) # 75
# 获取当前日期
print(date.today()) # 2025-03-17
t1 = date(2025, 3, 17)
print(t1 + timedelta(days = 2)) # 2025-03-19
print(t1 + timedelta(1)) # 2025-03-18
print(t1.weekday()) # 0,weekday()从 0~6 对应星期 1~7
from datetime import *
t1 = date(1901, 1, 1)
t2 = date(2000, 12, 31)
res = 0
while t1 <= t2:
if t1.weekday() == 0:
res += 1
t1 += timedelta(days = 1)
print(res)
Python 的
datetime模块支持的日期范围是从0001-01-01到9999-12-31。如果你尝试处理超出这个范围的日期,就会引发这个错误。





















