光剑的后日谈 #3. 这次来聊聊什么是许愿机。
If bird born with no shackles#
现在我们需要写一段代码,用矩阵快速幂求解一个递推数列:
$$
\begin{cases}
f(0)=f(1)=f(2)=1 \\
f(n)=f(n-1)+2f(n-2)+5f(n-3), n \geq 3
\end{cases}
$$
我们知道我们将数列中的相邻几项写成一个向量 $(f(i), f(i+1), f(i+2))$,然后找到一个矩阵乘上它转移到下一个向量 $(f(i+1), f(i+2), f(i+3))$,那么这个矩阵怎么算呢?
猜肯定是一个办法,但肯定不好。手动求解,待定系数也是一种方法。
然而,我们还有更加优雅的做法。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
| from z3 import *
def solve_transition_matrix():
# 1. 创建求解器
s = Solver()
# 2. 定义我们需要求解的 3x3 转移矩阵 M
# 变量 m_r_c 代表矩阵第 r 行第 c 列的元素 (全是整数)
M = [[Int(f'm_{r}_{c}') for c in range(3)] for r in range(3)]
# 3. 定义“任意”时刻的输入向量状态
# 设 f0, f1, f2 分别代表 f(i), f(i+1), f(i+2)
# 我们使用 Ints 创建符号变量,这些将作为全称量词的变量
f0, f1, f2 = Ints('f0 f1 f2')
# 构造当前向量 V_i
v_current = [f0, f1, f2]
# 4. 根据递推公式定义期望的输出向量 V_{i+1}
# 题目递推式: f(n) = f(n-1) + 2f(n-2) + 5f(n-3)
# 对应到我们的符号: 下一项 f3 = f2 + 2*f1 + 5*f0
f3 = f2 + 2 * f1 + 5 * f0
# 构造下一刻向量 V_{i+1} = [f(i+1), f(i+2), f(i+3)]
v_next = [f1, f2, f3]
# 5. 核心逻辑:构造约束
# 约束条件:M * v_current == v_next
# 这个等式必须对“所有”的 f0, f1, f2 都成立
constraints = []
for r in range(3):
# 矩阵乘法:第 r 行 点乘 输入向量
row_product = sum(M[r][c] * v_current[c] for c in range(3))
# 结果必须等于输出向量的对应项
constraints.append(row_product == v_next[r])
# 使用 ForAll (全称量词)
# 读作:对于任意整数 f0, f1, f2,上述约束(And(constraints))都必须成立
s.add(ForAll([f0, f1, f2], And(constraints)))
# 6. 求解并打印结果
if s.check() == sat:
m = s.model()
print("成功找到转移矩阵 M:")
print("-" * 15)
# 将 z3 的解转换为 Python 整数并格式化输出
for r in range(3):
row_vals = [m.evaluate(M[r][c]).as_long() for c in range(3)]
print(f"| {row_vals[0]:^3} {row_vals[1]:^3} {row_vals[2]:^3} |")
print("-" * 15)
# 验证一下文章开头的直觉
# 第一行应该是 0 1 0 (输出 f1)
# 第二行应该是 0 0 1 (输出 f2)
# 第三行应该是 5 2 1 (输出 5f0 + 2f1 + 1f2)
else:
print("未找到满足条件的矩阵。")
if __name__ == "__main__":
solve_transition_matrix()
|
用 Python 运行这段代码(当然你需要先使用 pip install z3-solver 来解决库的依赖),你会得到这样的输出:
1
2
3
4
5
6
| 成功找到转移矩阵 M:
---------------
| 0 1 0 |
| 0 0 1 |
| 5 2 1 |
---------------
|
看上去非常对。
当然你可能还会说:“这不就是一个自动解线性方程组的机器吗?没什么特别的。”不过 Z3 的威力不止于此。假如我们要求转移矩阵中不能出现素数,高斯消元等线性代数方法就爆炸了。而求解器依然能够工作。
或者看看下面这个场景:
你正在写一段底层高性能代码,需要计算一个 32 位整数的绝对值 abs(x)。
但是,CPU 的分支预测(Branch Prediction)失败代价很高,你不希望使用 if (x < 0) 这种跳转指令。
你听说高手都用位运算骚操作(Bit Hacks)来实现无分支编程,但你不知道公式是什么。
我们给出一个计算模板,让 Z3 帮我们要找到对应的位移常数,自动“写”出这段黑客代码。
我们猜测公式大概长这样:(x + A) ^ A (这是一种常见的异或技巧)。
我们要让 Z3 帮我们确定 A 到底应该等于多少(A 可能是 x 移位后的结果)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
| from z3 import *
def synthesize_abs_hack():
# 1. 创建求解器
s = Solver()
# 2. 定义输入 x 是一个 32 位的位向量 (BitVector)
x = BitVec('x', 32)
# 3. 定义我们的“目标规范” (Specification)
# 即我们希望代码实现的功能:如果是负数取反,否则不变
# 注意:在位运算世界里,-x 等价于 (~x + 1)
target = If(x < 0, -x, x)
# 4. 定义我们的“实现模板” (Implementation Template)
# 我们猜测无分支绝对值可以通过 (x + mask) ^ mask 来实现
# 其中 mask 是 x 向右移动 k 位得到的
# 我们不知道 k 是多少,让 Z3 去找这个 k
k = BitVec('k', 32) # k 是我们要寻找的“魔法常数”
# 算术右移 (x >> k) 生成掩码
mask = x >> k
# 我们的猜想公式
implementation = (x + mask) ^ mask
# 5. 核心约束:等价性验证
# 我们要求:对于“任意”一个 32 位整数 x,实现必须等于目标
# 注意:k 必须是一个固定的常数,不能随 x 变化,所以 k 在 ForAll 之外
s.add(ForAll([x], implementation == target))
# 6. 添加一些合理的范围约束,加速求解
# 移位位数 k 应该在 0 到 31 之间
s.add(k >= 0, k < 32)
# 7. 求解
if s.check() == sat:
m = s.model()
print("Z3 找到了魔法常数 k!")
k_val = m[k].as_long()
print(f"k = {k_val}")
print("-" * 30)
print("生成的无分支绝对值代码 (C/C++):")
print(f"int abs_hack(int x) {{")
print(f" int mask = x >> {k_val};")
print(f" return (x + mask) ^ mask;")
print(f"}}")
else:
print("也就是个猜想,看来这个模板行不通。")
if __name__ == "__main__":
synthesize_abs_hack()
|
很快,你就会得到这样的输出:
1
2
3
4
5
6
7
8
| Z3 找到了魔法常数 k!
k = 31
------------------------------
生成的无分支绝对值代码 (C/C++):
int abs_hack(int x) {
int mask = x >> 31;
return (x + mask) ^ mask;
}
|
于是我们就解决了这个问题。
What is meant by miracle?#
聊了这么久,你可能会发现我还没有讲 Z3 是什么?
事实上,正如这篇文章的标题,Z3 是一个许愿机。
它是一个 SMT(模理论可满足性)求解器。核心是一个 SAT(布尔可满足性,就是那个著名的 NPC)求解器,外部加上了许多理论的扩展。
理论扩展使得 SMT 求解器远比 SAT 更强大。
SAT 中变量是原子,求解器看不到里面,就像 0 阶逻辑(命题逻辑)。
而 SMT 中,理论让求解器可以拆开一个公式子句,并获取不同子句间更深层的关系(比如 x>0 和 x<0 是矛盾的)。也让我们能够用更丰富的语言写给求解器的公式。
我们用求解器支持的语言给它一个公式,它就自动判定这个公式是否是可满足的,如果可满足,就会给出一组解。而当不满足时,求解器会告诉我们“Unsat core”,即哪些约束矛盾了。在调试时,这是有用的。
就像我们用公式向许愿机许下一个愿望,它就自动显灵。
A word outside my days#
强大的奇迹自然也有它的代价。众所周知,SAT 是一个 NPC 问题。也就是说,SMT 求解器在最坏情况下的复杂度是指数级(或至少超多项式)的。
而 Z3 中的一些理论甚至比 SAT 还要困难,达到了 Pspace-complete 甚至 undecidable 边缘的程度。比如说 NRA(非线性实数算术),求解复杂度是双指数($2^{2^n}$)。而 NIA(非线性整数算术)等一些理论,甚至是不可判定的(停机问题可以规约到 NIA)
SMT 求解器之所以有用,是因为它使用了强大的启发式搜索算法(最重要的是 CDCL,冲突驱动子句学习),利用工业实例的特殊结构化特性达到了非常优秀的平均复杂度。
So now, I’ll make a dream unchained~#
上面已经展示了两个例子。不过这些还不够复杂。并且不是 racket 写的
接下来我们将用求解器解决一些更复杂的问题。这里需要提到 Rosette,一个 racket 项目,也是一个杀手级应用。你可以用 raco pkg install rosette 下载它。
2024 年提高组初赛有这么一道题:
1
2
3
| int logic(int x, int y) {
return (x & y) ^ ((x ^ y) | (~x & y));
}
|
以上函数的功能是什么?
硬分析当然可以,但是太吃操作。大部分人场上可能是代入具体值做的,然而这不能保证 100% 正确,并且场上的 D 选项还是“以上都不是”。
让我们来向许愿机许下这个愿望。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
| #lang rosette
; 1. 定义类型:模拟 C++ 的 int (32位有符号整数)
(define int32? (bitvector 32))
; ---------------------------------------------------------
; 2. 目标函数 (C++ 逻辑的精确翻译)
; ---------------------------------------------------------
; C++: (x & y) ^ ((x ^ y) | (~x & y))
;
; Rosette 映射表:
; & -> bvand
; | -> bvor
; ^ -> bvxor
; ~ -> bvnot
(define (target-logic x y)
(bvxor (bvand x y)
(bvor (bvxor x y)
(bvand (bvnot x) y))))
; ---------------------------------------------------------
; 3. 用户待验证的函数 (在这里填入你的实现)
; ---------------------------------------------------------
(define (user-logic x y)
; 猜测:原逻辑推导后其实等价于 (x | y)
; 你可以试着把这里改成 (bvxor x y) 看看验证失败的效果
(bvor x y)
)
; ---------------------------------------------------------
; 4. 验证过程
; ---------------------------------------------------------
; 定义两个符号变量,类型为 32位位向量
(define-symbolic x-sym int32?)
(define-symbolic y-sym int32?)
; 验证查询:是否存在一组输入 (x, y),使得两函数输出不相等?
; verify 会尝试寻找反例 (counterexample)
(define ce
(verify (assert (equal? (target-logic x-sym y-sym)
(user-logic x-sym y-sym)))))
; ---------------------------------------------------------
; 5. 输出结果处理
; ---------------------------------------------------------
(if (unsat? ce)
(printf "验证成功:两个函数在 32 位整数范围内是完全等价的!\n")
(begin
(printf "验证失败:函数不等价。\n")
(printf "找到反例 (Counterexample):\n")
; 从模型中提取具体的位向量值
(define ce-x (evaluate x-sym ce))
(define ce-y (evaluate y-sym ce))
; 为了方便阅读,将位向量转换为普通整数显示
(printf "x = ~a (Hex: ~a)\n" (bitvector->integer ce-x) ce-x)
(printf "y = ~a (Hex: ~a)\n" (bitvector->integer ce-y) ce-y)
(printf "------------------------\n")
(printf "Target 输出: ~a\n" (target-logic ce-x ce-y))
(printf "User 输出: ~a\n" (user-logic ce-x ce-y))
))
|
我们发现 bvor(即按位或)正是这个函数的等价物。而如果你改成 bvand/bvxor 则会得到 unsat 以及反例。这个问题就解决了。
这类东西在夺旗赛(CTF)中非常有用。你可以猜测一个被混淆了的表达式实际上是什么,然后利用 Rosette 的 verify 去验证是不是这么回事。
在形式化验证领域也非常有用。你可以写好一些规范,然后 verify 一下你的程序是否符合这些规范。一个特例是 verify 你的程序的输出是否和暴力相同,即对拍。不过形式化方法可以确保程序 100% 必定符合规范,而不像对拍可能会遗漏 corner。
到这里我们就展示了 Rosette 的 synthesize 和 verify 能力,它们一个给指定的规范寻找解(exist 量词),另一个判断指定规范是否永远正确(forall 量词)。Rosette 还有一个“天使执行”能力,可以造出让你的代码运行到某个指定 case 的数据,但是太复杂所以我们不讲(在入门中也用不到)。
racket 真是一个秘密武器。在某些领域威力简直无可匹敌。比如说形式化验证。又比如,写一个求解数独的程序。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
| #lang rosette
; 0. 初始时等待完善的数独。0 代表空位。
(define org
#(#(0 0 0 0 0 0 0 0 0)
#(0 0 0 0 0 0 0 0 0)
#(0 0 0 0 0 0 0 0 0)
#(0 0 0 0 0 0 0 0 0)
#(0 0 0 0 0 0 0 0 0)
#(0 0 0 0 0 0 0 0 0)
#(0 0 0 0 0 0 0 0 0)
#(0 0 0 0 0 0 0 0 0)
#(0 0 0 0 0 0 0 0 0)))
; 1. 定义 9x9 的符号变量矩阵
(define a
(for/vector ([i (in-range 9)])
(for/vector ([j (in-range 9)])
(let ([x (vector-ref (vector-ref org i) j)])
(if (zero? x)
(begin
(define-symbolic* sym integer?)
sym)
x)))))
; 2. 定义约束并求解
(define solution
(solve
(begin
; 约束 A: 所有数字必须在 1 到 9 之间
(for* ([row a]
[x row])
(assert (and (>= x 1) (<= x 9))))
; 约束 B: 每一行必须互不相同 (Row restriction)
(for ([row a])
; distinct? 接受变长参数,所以我们要把 vector 转为 list 并 apply
(assert (apply distinct? (vector->list row))))
; 约束 C: 每一列必须互不相同
(for ([j (in-range 9)])
(define col
(for/list ([i (in-range 9)])
(vector-ref (vector-ref a i) j)))
(assert (apply distinct? col)))
; 约束 D: 每个宫格不能有相同数字
(for* ([i '(0 3 6)]
[j '(0 3 6)])
(assert
(apply distinct?
(for*/list ([k (in-range 3)]
[l (in-range 3)])
(vector-ref (vector-ref a (+ i k)) (+ j l)))))))))
; 3. 输出结果
(if (sat? solution)
(let ([concrete-a (evaluate a solution)])
(displayln "Found a solution:")
(displayln concrete-a))
(displayln "No solution!"))
|
求解数独是(迄今为止)不能多项式的(数独是 NPC)。经典算法是 DLX。但假如说我增加一些限制,比如 9*9 网格的两条对角线上也不能有相同元素呢?或者比如说第一列有恰好两个 1,第 i 列有恰好两个 i 呢?DLX 就比较难以拓展了。
SMT 则能够方便地添加更多的约束以适应种种变体。你可以自己试一试。
我想肯定有人会问我怎么不像以前那样自己实现一个求解器……我只能说实在太难了搞不动。SMT 求解器这样的工程智慧,如果只是核心的 CDCL 和 SAT 求解其实还是能玩一玩的。但是外置的算术理论太过高深复杂了,写不动。