题目链接

题外话:

数竞党的防御性证明会用严谨的逻辑说明结论的正确性,而对“解题思路是如何想到的”一类的问题没有任何启发;oier/acmer 的题解则大多会提示我们从哪里入手,却不加严谨的证明。而本题解采取折中的方案:既对解题思路没有任何启发,也不加严谨的证明。


考虑 gcd 的动机在于,对于这样的题,$k$ 巨大。
于是考虑循环节。发现循环节是 $\mathrm{lcm}(n, m)$。

这时,我们考虑完整的循环节。
发现在 $n$ 与 $m$ 不互素时这是麻烦的,而 $n$ 与 $m$ 互素时一个循环节内每对 $(a_i, b_j)$ 刚好出现且仅出现一次。

这时考虑 $g = \gcd(n, m)$。
发现,对于任意 $r$,一个模 $g$ 余 $r$ 的 $i$ 只会和模 $g$ 余 $r$ 的 $j$ 配对。
这样就能够把非互素情形转化为互素情形。

对于互素情形,循环节是容易处理的。
考虑多余的一部分。这时,考虑枚举 $i$,寻找能与 $i$ 配对的 $j$。

记 $n' = n / g, m' = m / g$。
这时,对于一个下标 $i$,能够与它配对的任意 $j$ 与 $kn'+i, k \in \N$ 在模 $m'$ 意义下同余。

这样,我们对 $b$ 重排,将 $b_j$ 挪到 $b_{j \times \mathrm{inv}(n') \bmod m'}$。
发现所有 $i$ 对应的 $b$ 中元素是一个完整的区间。那么找出这个区间并计算它的贡献。

区间是容易找到的。
第一个索引是 $i \times \mathrm{inv}(n') \bmod m'$。
长度则是 $\lfloor c/n' \rfloor + [c \bmod n' \geq i]$。
其中方括号是艾弗森括号,$c$ 是循环节之外的“剩余部分”的操作总数。

下一步是利用数据结构维护区间的贡献。
于是转化为给定若干询问 $(\mathrm{range}, \mathrm{lim})$,查询 $\mathrm{range}$ 中小于 $\mathrm{lim}$ 的元素的和与数量。

这是一个二维数点问题。离线使用 BIT 维护值域上前缀计数与前缀和,扫描线即可。

AC 代码:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#include <bits/stdc++.h>

using namespace std;
constexpr int mod = 998244353;
int n, m;
long long k;
array<int, 200005> a, b;

pair<long long, long long> exgcd(long long a, long long b) {
    if (b == 0) {
        return {1, 0};
    } else {
        auto [x, y] = exgcd(b, a % b);
        return {y, x - (a / b) * y};
    }
}

struct question {
    int pos, lim;
    bool op;

    question(int pos, bool op, int lim)
        : pos(pos),
          op(op),
          lim(lim) {}

    friend bool operator<(const question &lhs, const question &rhs) {
        return lhs.pos < rhs.pos;
    }
};

int64_t solve(vector<int> &a, vector<int> &b, const int n, const int m, long long cnt) {
    auto copy = b;
    ranges::sort(copy);
    auto prefix_sum = copy;
    for (int i = 0; i + 1 < m; ++i) {
        prefix_sum[i + 1] = (prefix_sum[i + 1] + prefix_sum[i]) % mod;
    }

    long long res = 0;
    const int64_t num = cnt / ((long long) n * m), rem = cnt - num * n * m;
    for (int v: a) {
        auto it = ranges::upper_bound(copy, v);
        int c = distance(copy.begin(), it);

        res = (res + (c > 0 ? prefix_sum[c - 1] : 0) + (long long) (m - c) * v) % mod;
    }
    res = (res * (num % mod)) % mod;
    // cerr << res << "\n";

    const int64_t inv = (exgcd(n, m).first % m + m) % m;
    // cerr << inv << "\n";
    for (int i = 0; i < m; ++i) {
        copy[i * inv % m] = b[i];
    }

    vector<question> questions;
    for (int i = 0; i < n; ++i) {
        const int start = i * inv % m, len = rem / n + (rem % n > i);
        // cerr << i << " " << start << " " << len << "\n";

        if (start + len <= m) {
            questions.emplace_back(start - 1, false, a[i]);
            questions.emplace_back(start + len - 1, true, a[i]);
        } else {
            questions.emplace_back(start - 1, false, a[i]);
            questions.emplace_back(m - 1, true, a[i]);
            questions.emplace_back(len - (m - start) - 1, true, a[i]);
        }
    }

    vector<int> mapping;
    for (int i = 0; i < n; ++i) {
        mapping.push_back(a[i]);
    }

    for (int i = 0; i < m; ++i) {
        mapping.emplace_back(b[i]);
    }

    ranges::sort(mapping);
    mapping.erase(ranges::unique(mapping).begin(), mapping.end());

    vector<int> bit_cnt(mapping.size() + 2, 0), bit_sum(mapping.size() + 2, 0);
    auto get_id = [&mapping](int x) {
        return distance(mapping.begin(), ranges::upper_bound(mapping, x));
    };

    auto low_bit = [](int x) {
        return x & (-x);
    };

    auto query = [low_bit](int pos, auto &arr) {
        int res = 0;
        while (pos) {
            res = (res + arr[pos]) % mod;
            pos -= low_bit(pos);
        }
        return res;
    };

    auto add = [low_bit](int pos, int val, auto &arr) {
        while (pos < arr.size()) {
            arr[pos] = (arr[pos] + val) % mod;
            pos += low_bit(pos);
        }
    };

    sort(questions.begin(), questions.end());
    int ptr = 0;
    for (const auto [pos, lim, op]: questions) {
        if (pos >= 0) {
            // cerr << pos << " " << lim << " " << op << "\n";

            const int f = (op ? 1 : -1);

            while (ptr <= pos) {
                add(get_id(copy[ptr]), 1, bit_cnt);
                add(get_id(copy[ptr]), copy[ptr], bit_sum);
                ptr++;
            }

            const int rank = get_id(lim);
            const int term1 = query(rank, bit_cnt), term2 = query(rank, bit_sum);
            const int64_t term3 = (term2 + (long long) (pos + 1 - term1) * lim) % mod;
            res += f * term3;
            res %= mod;
            if (res < 0) {
                res += mod;
            }
        }
    }
    return res;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m >> k;
    for (int i = 0; i < n; ++i) {
        cin >> a[i];
    }

    for (int i = 0; i < m; ++i) {
        cin >> b[i];
    }

    const int g = gcd(m, n);
    long long ans = 0;
    for (int i = 0; i < g; ++i) {
        vector<int> tmp1, tmp2;
        for (int j = i; j < n; j += g) {
            tmp1.push_back(a[j]);
        }

        for (int j = i; j < m; j += g) {
            tmp2.push_back(b[j]);
        }

        ans = (ans + solve(tmp1, tmp2, n / g, m / g, k / g + (k % g > i))) % mod;
    }
    cout << ans;
    return 0;
}