假设你刚刚看过这篇文章:浅谈valarray

现在我们需要模仿 valarray 写一个科学计算库。其中,我们需要支持形如 a + b 的写法,ab 是两个向量,这个表达式的值是两个向量的和,即逐元素相加得到的新向量。

没错。这是运算符重载板子。你也许会写出类似这样的代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
template<typename T>
struct Vector {
    std::vector<T> data;

    // 构造函数,方便初始化
    Vector(size_t n = 0) : data(n) {}
    Vector(std::initializer_list<T> l) : data(l) {}

    T operator[](size_t i) const { return data[i]; }
    T &operator[](size_t i) { return data[i]; }
    size_t size() const { return data.size(); }

    Vector operator+(const Vector &rhs) const {
        Vector res(size()); // 初始化大小
        for (size_t i = 0; i < size(); ++i) {
            res[i] = data[i] + rhs[i];
        }
        return res; // 依赖 NRVO
    }
};

题外话:
注意加法中的 const& 参数类型。它是为了方便接受临时 Vector 对象(右值):普通左值引用会 CE。

而在 C++ 规范中,const 左值引用(const T&)是一个“万能引用”,它可以绑定到临时对象(右值)。

C++ 标准规定:常量左值引用可以延长临时对象的生命周期,直到该引用本身被销毁。这使得 const T& 成了 C++ 历史上最通用的参数传递方式:它既能接左值(变量),也能接右值(临时对象),且保证不会修改它。


看上去非常正确。然而,这个写法在嵌套的表达式(比如 a+b+c 三个向量相加)时,会多次循环遍历数组,并且生成和废弃大量的向量对象(res 的分配,以及右值在表达式中当加法调用完毕生命周期就结束,从而被释放)。

这是不好的。如果手写,我们只需要 res[i] = a[i] + b[i] + c[i]; 就行了,根本不需要这么多中间变量的分配,循环遍历和缓存未命中开销也会更小。

但是手写看上去丑陋和费事的多。能不能使用一些手法,使得我们能够兼顾优雅和高性能呢?

有的,兄弟!有的。

表达式模板

前人想到了一个非常天才的主意:利用模板系统的元编程能力,将表达式信息编码在类型中,自动展开为高效的计算过程。

这样说比较抽象。具体的,a+b 这个表达式不会被立即求值,而是会生成一棵 addexp<Vector, Vector> 类型的表达式树。这个表达式树可以被随机访问,访问它的第 i 个位置时,编译器会自动从它的数据源计算出第 i 个位置的值。

在上面的具体例子中,第 i 个位置的值就是 a[i]+b[i]

而这棵树可以赋值给普通的 Vector。赋值时,我们遍历 $i\in [0, \text{size})$,将目标 Vector 的第 $i$ 个位置赋值为表达式树的第 $i$ 个位置的值。

这就是表达式模板技术。

统一表示

在工程实现中,我们一般会实现一个基类 Exp,它有一个 operator[] 声明。而所有的表达式类,包括二元表达式以及 Vector 自己,都继承它,并实现 operator[] 作为随机访问的接口。

这样做的优势是可扩展性强,表示统一简洁。

奇异递归模板模式

你可能注意到了,这个思想其实要求我们实现多态。多态的一般写法是 virtual 声明虚函数,override 重写实现。然而,这里无法使用这种一般做法。

因为虚函数需要查表和维护 RTTI(运行时类型信息),有内存和时间的双重开销。对于科学计算,这是不可容忍的。

前人又有一种惊世骇俗的方法,专门用于实现零开销的静态多态。也就是“奇异递归模板模式”。

看起来像是这样:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
template<typename T>
struct Exp {
    auto operator[](size_t i) const {
        return static_cast<T>(*this)[i];
    }
};

template<typename T>
struct Vector : public Exp<Vector<T>> {
    std::vector<T> data;

    // 构造函数,方便初始化
    Vector(size_t n = 0) : data(n) {}
    Vector(std::initializer_list<T> l) : data(l) {}

    T operator[](size_t i) const {
        return data[i];
    }

    T &operator[](size_t i) {
        return data[i];
    }

    size_t size() const {
        return data.size();
    }

    // 修正后的加法实现
    Vector operator+(const Vector &rhs) const {
        Vector res(size()); // 初始化大小
        for (size_t i = 0; i < size(); ++i) {
            res[i] = data[i] + rhs[i];
        }
        return res; // 依赖 NRVO
    }
};

我想你应该能看懂。基类是一个模板类,子类则继承自基类,并将基类的模板参数填入自己。基类中的多态函数在被调用时,把基类强转为子类,并调用子类的实现。

本来,从超类型到子类型的转换其实是不安全的(需要 dynamic_cast)。不过,这里我们可以确保强转的那个子类一定是数据的真实类型,不会出锅。

某种意义上,其实是我们手动实现了一个 Union type,在模板参数中记录它的类型信息。

原力闪电

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
template<typename T>
struct Exp {
    auto operator[](size_t i) const {
        return static_cast<const T &>(*this)[i];
    }

    size_t size() const {
        return static_cast<const T &>(*this).size();
    }
};

template<typename L, typename R>
struct Add_exp : public Exp<Add_exp<L, R>> {
    const L &l;
    const R &r;

    size_t size() const {
        return l.size();
    }

    Add_exp(const L &l, const R &r) : l(l), r(r) {}

    auto operator[](size_t i) const {
        return l[i] + r[i];
    }
};

template<typename T>
struct Vector : public Exp<Vector<T>> {
    std::vector<T> data;

    // 构造函数,方便初始化
    Vector(size_t n = 0) : data(n) {}
    Vector(std::initializer_list<T> l) : data(l) {}

    T operator[](size_t i) const {
        return data[i];
    }

    T &operator[](size_t i) {
        return data[i];
    }

    size_t size() const {
        return data.size();
    }

    template<typename T2>
    void operator=(const Exp<T2> &rhs) {
        // clog << rhs.size() << "\n";
        data.resize(rhs.size());
        // clog << data.size() << "\n";
        for (size_t i = 0; i < size(); ++i) {
            data[i] = rhs[i];
        }
    }
};

template<typename L, typename R>
auto operator+(const Exp<L> &l, const Exp<R> &r) {
    return Add_exp<L, R>(static_cast<const L &>(l), static_cast<const R &>(r));
}

一个简陋的表达式模板。

测试:

1
2
3
4
5
6
7
8
int main() {
    Vector<int> res;
    res = (Vector{1, 2, 3} + Vector{4, 5, 6} + Vector{4, 2, 0});
    for (int i = 0; i < res.size(); ++i) {
        cout << res[i] << " ";
    }
    return 0;
}

输出 9 9 9

代码看上去非常简单,实际上也非常简单。但是你在复现的时候就会发现其中隐藏了大量的 C++ 语法和语义细节。这些偶发复杂性极其烦人。

建议自己复现一下。遇到疑难杂症可以询问 AI。

可扩展性的演示

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
template<typename T>
struct Exp {
    auto operator[](size_t i) const {
        return static_cast<const T &>(*this)[i];
    }

    size_t size() const {
        return static_cast<const T &>(*this).size();
    }
};

template<typename L, typename R, typename OP>
struct Bin_exp : public Exp<Bin_exp<L, R, OP>> {
    const L &l;
    const R &r;
    OP op;

    size_t size() const {
        return l.size();
    }

    Bin_exp(const L &l, const R &r) : l(l), r(r), op() {}

    auto operator[](size_t i) const {
        return op(l[i], r[i]);
    }
};

struct Add {
    auto operator()(auto a, auto b) const {
        return a + b;
    }
};

template<typename L, typename R>
using Add_exp = Bin_exp<L, R, Add>;

template<typename T>
struct Vector : public Exp<Vector<T>> {
    std::vector<T> data;

    // 构造函数,方便初始化
    Vector(size_t n = 0) : data(n) {}
    Vector(std::initializer_list<T> l) : data(l) {}

    T operator[](size_t i) const {
        return data[i];
    }

    T &operator[](size_t i) {
        return data[i];
    }

    size_t size() const {
        return data.size();
    }

    template<typename T2>
    void operator=(const Exp<T2> &rhs) {
        // clog << rhs.size() << "\n";
        data.resize(rhs.size());
        // clog << data.size() << "\n";
        for (size_t i = 0; i < size(); ++i) {
            data[i] = rhs[i];
        }
    }
};

template<typename L, typename R>
auto operator+(const Exp<L> &l, const Exp<R> &r) {
    return Add_exp<L, R>(static_cast<const L &>(l), static_cast<const R &>(r));
}

注意到仿函数的使用。C++ 的函数式特性支持奇差无比,这种时候我竟然不能使用 lambda(或者考虑 std::function 包装一层,但是堆分配和类型擦除都有开销)。

当然,想要用 lambda 也是可以的。可以考虑在构造函数中传入 lambda,或者使用更加高级的 TMP 黑魔法。

然后你会惊讶地注意到,如果你写了

1
2
3
4
5
6
7
8
9
int main() {
    Vector<int> res;
    auto tmp = (Vector{1, 2, 3} + Vector{4, 5, 6} + Vector{4, 2, 0});
    res = tmp;
    for (int i = 0; i < res.size(); ++i) {
        cout << res[i] << " ";
    }
    return 0;
}

程序会爆掉,并(可能,实际上此时是未定义的)抛出 std::bad_alloc 异常。

为什么?因为极其不幸的,tmp 的声明那一行结束之后,几个 Vector 临时对象的生命周期就结束了。于是 tmp 中存储的引用变成了悬垂引用。然后 res = tmp; 就寄了。


另一个坑:如果你以后增加了一些不是逐元素进行的操作,比如类似 bitset 的右移,v = (v << 1); 之类的写法可能就会出锅。

因为它会展开为一个 v[i] = v[i - 1] 的循环。这样,就会让原始数据被覆盖掉(类似于 01 背包的转移),于是得到错误的结果。


这些都需要大量复杂的手法来解决(比如你可以试着使用 SFINAE 等手法对于左值和右值进行逻辑分派,分别存储原值和常量引用)。所以我暂时按下。

These are your father’s parens… elegant weapons, belong to a more civilized era

考虑这段 racket 代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#lang racket

(define-syntax get-pos
  (lambda (stx)
    (syntax-case stx ()
      [(_ (op a b) i)
       #'(op (get-pos a i) (get-pos b i))]

      [(_ a i)
       #'(vector-ref a i)])))

(define-syntax get-len
  (lambda (stx)
    (syntax-case stx ()
      [(_ (op a b)) #'(get-len a)]
      [(_ a) #'(vector-length a)])))

(define-syntax vec-binexp
  (lambda (stx)
    (syntax-case stx ()
      [(_ (op a b))
       #'(for/vector ([i (in-range (get-len a))])
           (get-pos (op a b) i))]

      [(_ a i)
       #'a])))

(define a #(1 2 3))
(define b #(4 5 6))

(displayln (vec-binexp (* (+ #(3 4 5) #(2 1 0)) (+ a b))))

事实上,它做到了与上面的 C++ 黑魔法同样的事。

当然,racket 有一些弱点。

比如,如果你想要加入向量数乘,你就会惊讶地发现你不得不使用一个 if 做向量和数值的类型分派。racket 的宏难以利用类型信息,甚至在 typed/racket 中也一样:typed/racket 的实现是类型擦除的,它在普通 racket 上用宏构建了一层类型推导、检查与优化器。但真正运行时,typed/racket 已经被展开为了普通的 racket 代码。而宏展开阶段时 typed/racket 的类型推导和检查机制还没有开始介入,自己写的宏也就无法利用类型信息。

怎么办?一种方法是利用前人的工作,比如 #lang turnstile。它是专门为类型化语言构建的,甚至可以用来实现 haskell(比如著名的 haskett)。但是机制极端复杂,还有 Unicode 字符代码比较超标。

简单的 hack 就是手写类型标注。比如说数值 x 写成 (Number x)。然后宏中进行特判,消除类型分派。本质上是自己维护了类型系统。

说到这里,有没有大神愿意教我实现一个 typed/racket 并且让宏能够利用类型信息啊?求带飞 /bx