題目大意
如果一個 N × N 的矩陣滿足:
- 矩陣每行均為 [1, N] 的正整數(shù)的一個排列
- 矩陣內(nèi)所有元素與其上方的元素不同
那么這個矩陣便是美麗的疲吸。
現(xiàn)給定一個 N × N 的美麗矩陣,求有多少個 N × N 的美麗矩陣比它小欺殿。(矩陣從上到下按行比較)
題目保證 N 不超過 2000
分析
這個題的切入點在于美麗矩陣的定義粱腻。如果我們把當前行看作待排序的一個序列商膊,上面一行當成排序基準宾濒,則這個問題可以轉(zhuǎn)化成錯位排序問題。但是不同的是仅淑,在排了一部分數(shù)字以后称勋,剩下的部分的排序標準就不那么嚴苛了(即存在一些可行數(shù)沒有禁止位置)。
若 i 表示序列的長度涯竟, j 表示存在禁止位置的元素個數(shù)赡鲜,則由容斥原理易得:
這個表達式非常優(yōu)美,但是我們需要求 O(N2) 個 dp 值庐船,如果直接計算的話需要 O(N3) 银酬。不能承受】鹬樱考慮到組合遞推關(guān)系:
我們猜想 dp[i][j] 可以由 dp[i][j - 1] 和 dp[i - 1][j - 1] 推出揩瞪。果然,我們有:
現(xiàn)在我們來解決這個問題篓冲。根據(jù)題目的定義李破,兩個矩陣的比較與兩個字符串的比較方式類似,如果 A 矩陣小于 B 矩陣壹将,那么 A 矩陣的任意“前綴”小于等于 B 矩陣的對應(yīng)“前綴”嗤攻。如果兩個矩陣的第一個不相同元素的位置為 (i, j) ,那么對于給定的 B 矩陣瞭恰,這樣的 A 矩陣共有
其中 way0 和 way1 分別表示有多少種選法使得 A[i][j] < B[i][j] 且是否選取 A[i - 1] 中在 j 位置以前出現(xiàn)過的元素屯曹; cnt 表示 A[i - 1] 的前 j 個元素與 A[i] 的前 (j - 1) 個元素的相同個數(shù)。
如果我們用樹狀數(shù)組或名次樹來滑動地維護 way0 和 way1 惊畏,則均攤時間復(fù)雜度可降為每個位置 O(logN) 恶耽。剪枝以后可以接受。
代碼
總復(fù)雜度為 O(n2log(n))
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
template <typename T>
using ordered_set = tree<T, null_type, less<T>, rb_tree_tag,
tree_order_statistics_node_update>;
typedef long long ll;
typedef pair<int, int> pii;
#define FOR(i, a, b) for (int (i) = (a); (i) <= (b); (i)++)
#define ROF(i, a, b) for (int (i) = (a); (i) >= (b); (i)--)
#define REP(i, n) FOR(i, 0, (n)-1)
#define sqr(x) ((x) * (x))
#define all(x) (x).begin(), (x).end()
#define reset(x, y) memset(x, y, sizeof(x))
#define uni(x) (x).erase(unique(all(x)), (x).end());
#define BUG(x) cerr << #x << " = " << (x) << endl
#define pb push_back
#define eb emplace_back
#define mp make_pair
#define _1 first
#define _2 second
const int maxn = 2123;
const ll MOD = 998244353;
ll fac[maxn], dp[maxn][maxn], ans, D[maxn];
int n, a[maxn][maxn];
pii way[maxn][maxn];
int main() {
scanf("%d", &n);
fac[0] = 1;
FOR(i, 1, n) fac[i] = fac[i - 1] * i % MOD;
dp[0][0] = 1;
FOR(i, 1, n) {
dp[i][0] = fac[i];
FOR(j, 1, i) {
dp[i][j] = (dp[i][j - 1] - dp[i - 1][j - 1]) % MOD;
if (dp[i][j] < 0) dp[i][j] += MOD;
}
}
D[0] = 1;
FOR(i, 1, n) D[i] = D[i - 1] * dp[n][n] % MOD;
FOR(i, 1, n) FOR(j, 1, n) scanf("%d", &a[i][j]);
FOR(i, 1, n) {
ordered_set<int> s[2];
FOR(j, 1, n) s[1].insert(j);
FOR(j, 1, n) {
way[i][j]._1 = s[0].order_of_key(a[i][j]);
if (a[i - 1][j] < a[i][j] && s[0].find(a[i - 1][j]) != s[0].end())
way[i][j]._1--;
way[i][j]._2 = s[1].order_of_key(a[i][j]);
if (a[i - 1][j] < a[i][j] && s[1].find(a[i - 1][j]) != s[1].end())
way[i][j]._2--;
s[0].erase(a[i][j]), s[1].erase(a[i][j]);
if (s[1].find(a[i - 1][j]) != s[1].end()) {
s[1].erase(a[i - 1][j]);
s[0].insert(a[i - 1][j]);
}
}
}
FOR(i, 1, n)
ans = (ans + way[1][i]._2 * fac[n - i] % MOD * D[n - 1]) % MOD;
FOR(i, 2, n) {
unordered_map<int, int> m;
FOR(j, 1, n) {
m[a[i - 1][j]]++;
int cnt = 2 * j - 1 - m.size();
ans = (ans + way[i][j]._1 * dp[n - j][n - 2 * j + cnt + 1]
% MOD * D[n - i]) % MOD;
ans = (ans + way[i][j]._2 * dp[n - j][n - 2 * j + cnt]
% MOD * D[n - i]) % MOD;
m[a[i][j]]++;
}
}
printf("%lld", ans);
}