光剑的后日谈 #3. 这次来聊聊什么是许愿机。

If bird born with no shackles

现在我们需要写一段代码,用矩阵快速幂求解一个递推数列:

$$ \begin{cases} f(0)=f(1)=f(2)=1 \\ f(n)=f(n-1)+2f(n-2)+5f(n-3), n \geq 3 \end{cases} $$

我们知道我们将数列中的相邻几项写成一个向量 $(f(i), f(i+1), f(i+2))$,然后找到一个矩阵乘上它转移到下一个向量 $(f(i+1), f(i+2), f(i+3))$,那么这个矩阵怎么算呢?

猜肯定是一个办法,但肯定不好。手动求解,待定系数也是一种方法。

然而,我们还有更加优雅的做法。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from z3 import *

def solve_transition_matrix():
    # 1. 创建求解器
    s = Solver()

    # 2. 定义我们需要求解的 3x3 转移矩阵 M
    # 变量 m_r_c 代表矩阵第 r 行第 c 列的元素 (全是整数)
    M = [[Int(f'm_{r}_{c}') for c in range(3)] for r in range(3)]

    # 3. 定义“任意”时刻的输入向量状态
    # 设 f0, f1, f2 分别代表 f(i), f(i+1), f(i+2)
    # 我们使用 Ints 创建符号变量,这些将作为全称量词的变量
    f0, f1, f2 = Ints('f0 f1 f2')

    # 构造当前向量 V_i
    v_current = [f0, f1, f2]

    # 4. 根据递推公式定义期望的输出向量 V_{i+1}
    # 题目递推式: f(n) = f(n-1) + 2f(n-2) + 5f(n-3)
    # 对应到我们的符号: 下一项 f3 = f2 + 2*f1 + 5*f0
    f3 = f2 + 2 * f1 + 5 * f0

    # 构造下一刻向量 V_{i+1} = [f(i+1), f(i+2), f(i+3)]
    v_next = [f1, f2, f3]

    # 5. 核心逻辑:构造约束
    # 约束条件:M * v_current == v_next
    # 这个等式必须对“所有”的 f0, f1, f2 都成立

    constraints = []
    for r in range(3):
        # 矩阵乘法:第 r 行 点乘 输入向量
        row_product = sum(M[r][c] * v_current[c] for c in range(3))
        # 结果必须等于输出向量的对应项
        constraints.append(row_product == v_next[r])

    # 使用 ForAll (全称量词)
    # 读作:对于任意整数 f0, f1, f2,上述约束(And(constraints))都必须成立
    s.add(ForAll([f0, f1, f2], And(constraints)))

    # 6. 求解并打印结果
    if s.check() == sat:
        m = s.model()
        print("成功找到转移矩阵 M:")
        print("-" * 15)
        # 将 z3 的解转换为 Python 整数并格式化输出
        for r in range(3):
            row_vals = [m.evaluate(M[r][c]).as_long() for c in range(3)]
            print(f"| {row_vals[0]:^3} {row_vals[1]:^3} {row_vals[2]:^3} |")
        print("-" * 15)

        # 验证一下文章开头的直觉
        # 第一行应该是 0 1 0 (输出 f1)
        # 第二行应该是 0 0 1 (输出 f2)
        # 第三行应该是 5 2 1 (输出 5f0 + 2f1 + 1f2)
    else:
        print("未找到满足条件的矩阵。")

if __name__ == "__main__":
    solve_transition_matrix()

用 Python 运行这段代码(当然你需要先使用 pip install z3-solver 来解决库的依赖),你会得到这样的输出:

1
2
3
4
5
6
成功找到转移矩阵 M:
---------------
|  0   1   0  |
|  0   0   1  |
|  5   2   1  |
---------------

看上去非常对。

当然你可能还会说:“这不就是一个自动解线性方程组的机器吗?没什么特别的。”不过 Z3 的威力不止于此。假如我们要求转移矩阵中不能出现素数,高斯消元等线性代数方法就爆炸了。而求解器依然能够工作。

或者看看下面这个场景:

你正在写一段底层高性能代码,需要计算一个 32 位整数的绝对值 abs(x)
但是,CPU 的分支预测(Branch Prediction)失败代价很高,你不希望使用 if (x < 0) 这种跳转指令。
你听说高手都用位运算骚操作(Bit Hacks)来实现无分支编程,但你不知道公式是什么。

我们给出一个计算模板,让 Z3 帮我们要找到对应的位移常数,自动“写”出这段黑客代码。
我们猜测公式大概长这样:(x + A) ^ A (这是一种常见的异或技巧)。
我们要让 Z3 帮我们确定 A 到底应该等于多少(A 可能是 x 移位后的结果)。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from z3 import *

def synthesize_abs_hack():
    # 1. 创建求解器
    s = Solver()

    # 2. 定义输入 x 是一个 32 位的位向量 (BitVector)
    x = BitVec('x', 32)

    # 3. 定义我们的“目标规范” (Specification)
    # 即我们希望代码实现的功能:如果是负数取反,否则不变
    # 注意:在位运算世界里,-x 等价于 (~x + 1)
    target = If(x < 0, -x, x)

    # 4. 定义我们的“实现模板” (Implementation Template)
    # 我们猜测无分支绝对值可以通过 (x + mask) ^ mask 来实现
    # 其中 mask 是 x 向右移动 k 位得到的
    # 我们不知道 k 是多少,让 Z3 去找这个 k
    k = BitVec('k', 32) # k 是我们要寻找的“魔法常数”

    # 算术右移 (x >> k) 生成掩码
    mask = x >> k

    # 我们的猜想公式
    implementation = (x + mask) ^ mask

    # 5. 核心约束:等价性验证
    # 我们要求:对于“任意”一个 32 位整数 x,实现必须等于目标
    # 注意:k 必须是一个固定的常数,不能随 x 变化,所以 k 在 ForAll 之外
    s.add(ForAll([x], implementation == target))

    # 6. 添加一些合理的范围约束,加速求解
    # 移位位数 k 应该在 0 到 31 之间
    s.add(k >= 0, k < 32)

    # 7. 求解
    if s.check() == sat:
        m = s.model()
        print("Z3 找到了魔法常数 k!")
        k_val = m[k].as_long()
        print(f"k = {k_val}")
        print("-" * 30)
        print("生成的无分支绝对值代码 (C/C++):")
        print(f"int abs_hack(int x) {{")
        print(f"    int mask = x >> {k_val};")
        print(f"    return (x + mask) ^ mask;")
        print(f"}}")
    else:
        print("也就是个猜想,看来这个模板行不通。")

if __name__ == "__main__":
    synthesize_abs_hack()

很快,你就会得到这样的输出:

1
2
3
4
5
6
7
8
Z3 找到了魔法常数 k!
k = 31
------------------------------
生成的无分支绝对值代码 (C/C++):
int abs_hack(int x) {
    int mask = x >> 31;
    return (x + mask) ^ mask;
}

于是我们就解决了这个问题。

What is meant by miracle?

聊了这么久,你可能会发现我还没有讲 Z3 是什么?

事实上,正如这篇文章的标题,Z3 是一个许愿机。

它是一个 SMT(模理论可满足性)求解器。核心是一个 SAT(布尔可满足性,就是那个著名的 NPC)求解器,外部加上了许多理论的扩展。

理论扩展使得 SMT 求解器远比 SAT 更强大。

SAT 中变量是原子,求解器看不到里面,就像 0 阶逻辑(命题逻辑)。
而 SMT 中,理论让求解器可以拆开一个公式子句,并获取不同子句间更深层的关系(比如 x>0x<0 是矛盾的)。也让我们能够用更丰富的语言写给求解器的公式。

我们用求解器支持的语言给它一个公式,它就自动判定这个公式是否是可满足的,如果可满足,就会给出一组解。而当不满足时,求解器会告诉我们“Unsat core”,即哪些约束矛盾了。在调试时,这是有用的。

就像我们用公式向许愿机许下一个愿望,它就自动显灵。

A word outside my days

强大的奇迹自然也有它的代价。众所周知,SAT 是一个 NPC 问题。也就是说,SMT 求解器在最坏情况下的复杂度是指数级(或至少超多项式)的。

而 Z3 中的一些理论甚至比 SAT 还要困难,达到了 Pspace-complete 甚至 undecidable 边缘的程度。比如说 NRA(非线性实数算术),求解复杂度是双指数($2^{2^n}$)。而 NIA(非线性整数算术)等一些理论,甚至是不可判定的(停机问题可以规约到 NIA)

SMT 求解器之所以有用,是因为它使用了强大的启发式搜索算法(最重要的是 CDCL,冲突驱动子句学习),利用工业实例的特殊结构化特性达到了非常优秀的平均复杂度。

So now, I’ll make a dream unchained~

上面已经展示了两个例子。不过这些还不够复杂。并且不是 racket 写的

接下来我们将用求解器解决一些更复杂的问题。这里需要提到 Rosette,一个 racket 项目,也是一个杀手级应用。你可以用 raco pkg install rosette 下载它。

2024 年提高组初赛有这么一道题:

1
2
3
int logic(int x, int y) {
    return (x & y) ^ ((x ^ y) | (~x & y));
}

以上函数的功能是什么?

硬分析当然可以,但是太吃操作。大部分人场上可能是代入具体值做的,然而这不能保证 100% 正确,并且场上的 D 选项还是“以上都不是”。

让我们来向许愿机许下这个愿望。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#lang rosette

; 1. 定义类型:模拟 C++ 的 int (32位有符号整数)
(define int32? (bitvector 32))

; ---------------------------------------------------------
; 2. 目标函数 (C++ 逻辑的精确翻译)
; ---------------------------------------------------------
; C++: (x & y) ^ ((x ^ y) | (~x & y))
;
; Rosette 映射表:
; &  -> bvand
; |  -> bvor
; ^  -> bvxor
; ~  -> bvnot
(define (target-logic x y)
  (bvxor (bvand x y)
         (bvor (bvxor x y)
               (bvand (bvnot x) y))))

; ---------------------------------------------------------
; 3. 用户待验证的函数 (在这里填入你的实现)
; ---------------------------------------------------------
(define (user-logic x y)
  ; 猜测:原逻辑推导后其实等价于 (x | y)
  ; 你可以试着把这里改成 (bvxor x y) 看看验证失败的效果
  (bvor x y)
)

; ---------------------------------------------------------
; 4. 验证过程
; ---------------------------------------------------------

; 定义两个符号变量,类型为 32位位向量
(define-symbolic x-sym int32?)
(define-symbolic y-sym int32?)

; 验证查询:是否存在一组输入 (x, y),使得两函数输出不相等?
; verify 会尝试寻找反例 (counterexample)
(define ce
  (verify (assert (equal? (target-logic x-sym y-sym)
                          (user-logic x-sym y-sym)))))

; ---------------------------------------------------------
; 5. 输出结果处理
; ---------------------------------------------------------

(if (unsat? ce)
    (printf "验证成功:两个函数在 32 位整数范围内是完全等价的!\n")
    (begin
      (printf "验证失败:函数不等价。\n")
      (printf "找到反例 (Counterexample):\n")

      ; 从模型中提取具体的位向量值
      (define ce-x (evaluate x-sym ce))
      (define ce-y (evaluate y-sym ce))

      ; 为了方便阅读,将位向量转换为普通整数显示
      (printf "x = ~a (Hex: ~a)\n" (bitvector->integer ce-x) ce-x)
      (printf "y = ~a (Hex: ~a)\n" (bitvector->integer ce-y) ce-y)

      (printf "------------------------\n")
      (printf "Target 输出: ~a\n" (target-logic ce-x ce-y))
      (printf "User   输出: ~a\n" (user-logic ce-x ce-y))
    ))

我们发现 bvor(即按位或)正是这个函数的等价物。而如果你改成 bvand/bvxor 则会得到 unsat 以及反例。这个问题就解决了。

这类东西在夺旗赛(CTF)中非常有用。你可以猜测一个被混淆了的表达式实际上是什么,然后利用 Rosette 的 verify 去验证是不是这么回事。

在形式化验证领域也非常有用。你可以写好一些规范,然后 verify 一下你的程序是否符合这些规范。一个特例是 verify 你的程序的输出是否和暴力相同,即对拍。不过形式化方法可以确保程序 100% 必定符合规范,而不像对拍可能会遗漏 corner。

到这里我们就展示了 Rosette 的 synthesize 和 verify 能力,它们一个给指定的规范寻找解(exist 量词),另一个判断指定规范是否永远正确(forall 量词)。Rosette 还有一个“天使执行”能力,可以造出让你的代码运行到某个指定 case 的数据,但是太复杂所以我们不讲(在入门中也用不到)。

racket 真是一个秘密武器。在某些领域威力简直无可匹敌。比如说形式化验证。又比如,写一个求解数独的程序。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#lang rosette

; 0. 初始时等待完善的数独。0 代表空位。
(define org
  #(#(0 0 0 0 0 0 0 0 0)
    #(0 0 0 0 0 0 0 0 0)
    #(0 0 0 0 0 0 0 0 0)
    #(0 0 0 0 0 0 0 0 0)
    #(0 0 0 0 0 0 0 0 0)
    #(0 0 0 0 0 0 0 0 0)
    #(0 0 0 0 0 0 0 0 0)
    #(0 0 0 0 0 0 0 0 0)
    #(0 0 0 0 0 0 0 0 0)))

; 1. 定义 9x9 的符号变量矩阵
(define a
  (for/vector ([i (in-range 9)])
    (for/vector ([j (in-range 9)])
      (let ([x (vector-ref (vector-ref org i) j)])
        (if (zero? x)
           (begin
             (define-symbolic* sym integer?)
             sym)
           x)))))

; 2. 定义约束并求解
(define solution
  (solve
   (begin
     ; 约束 A: 所有数字必须在 1 到 9 之间
     (for* ([row a]
            [x row])
       (assert (and (>= x 1) (<= x 9))))

     ; 约束 B: 每一行必须互不相同 (Row restriction)
     (for ([row a])
       ; distinct? 接受变长参数,所以我们要把 vector 转为 list 并 apply
       (assert (apply distinct? (vector->list row))))

     ; 约束 C: 每一列必须互不相同
     (for ([j (in-range 9)])
       (define col
         (for/list ([i (in-range 9)])
           (vector-ref (vector-ref a i) j)))
       (assert (apply distinct? col)))

     ; 约束 D: 每个宫格不能有相同数字
     (for* ([i '(0 3 6)]
            [j '(0 3 6)])
       (assert
        (apply distinct?
               (for*/list ([k (in-range 3)]
                           [l (in-range 3)])
                 (vector-ref (vector-ref a (+ i k)) (+ j l)))))))))

; 3. 输出结果
(if (sat? solution)
    (let ([concrete-a (evaluate a solution)])
      (displayln "Found a solution:")
      (displayln concrete-a))
    (displayln "No solution!"))

求解数独是(迄今为止)不能多项式的(数独是 NPC)。经典算法是 DLX。但假如说我增加一些限制,比如 9*9 网格的两条对角线上也不能有相同元素呢?或者比如说第一列有恰好两个 1,第 i 列有恰好两个 i 呢?DLX 就比较难以拓展了。

SMT 则能够方便地添加更多的约束以适应种种变体。你可以自己试一试。


我想肯定有人会问我怎么不像以前那样自己实现一个求解器……我只能说实在太难了搞不动。SMT 求解器这样的工程智慧,如果只是核心的 CDCL 和 SAT 求解其实还是能玩一玩的。但是外置的算术理论太过高深复杂了,写不动。