[Luogu] P7077 函式呼叫
Description
某資料庫應用程式提供了若干函式用以維護資料。已知這些函式的功能可分為三類:
\(1.\)將資料中的指定元素加上一個值;
\(2.\)將資料中的每一個元素乘以一個相同值;
\(3.\)依次執行若干次函式呼叫,保證不會出現遞迴(即不會直接或間接地呼叫本身)。
在使用該資料庫應用時,使用者可一次性輸入要呼叫的函式序列(一個函式可能被呼叫多次),在依次執行完序列中的函式後,系統中的資料被加以更新。為了計算出正確資料,小\(A\)查閱了軟體的文件,瞭解到每個函式的具體功能資訊,現在他想請你根據這些資訊幫他計算出更新後的資料應該是多少。
Solution
可以考慮依次執行完操作序列後,所有數先一起被乘上一個數,有一些位置再被加上所加上的數乘上這個數的貢獻。
注意到依次執行,不會出現遞迴等關鍵字眼,可以想到拓撲排序。(一定要注意這裡的\(m\)才是原來拓撲排序的\(n\),因為把\(m\)個函式當作點)
可以新建一個點\(0\),作為主函式,同時將它和輸入的所有呼叫函式連邊。那麼拓撲排序迴圈就要從\(0\sim{m}\)。
前一步是很好做的,可以考慮記憶化搜尋,或者先建反圖,跑一遍拓撲排序。維護一個乘法標記\(mul\),對於\(1\)類函式和\(3\)類函式,它們的\(mul=1\),對於\(2\)類函式,它的\(mul=v_i\)。然後從下到上,累乘\(mul\)即可。然後所有的\(a_i\)就要乘上\(mul[0]\)。
而如何處理加上的數究竟被加上了多少次呢?我們會發現,當操作序列不斷執行\(\times,+,\times,+...\)
對於\(u\rightarrow{v}\),設\(now=\prod\limits_{u\rightarrow{p},t_p>t_v}mul[p]\)(\(t_i\)表示\(i\)被執行的順序先後),那麼\(add[v]=add[v]+add[u]*now\)(\(u\)先執行\(add[u]\)次,每次都使\(add[v]\)乘上\(now\))。算完\(add\)後,處理加法即可。
(這個的確不是很難,但是誰有時間想啊,T1出題人1582.10.5~1582.10.14)
Code
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const ll mod = 998244353;
queue < int > q;
int n, m, t, tot, hd[100005], nxt[1100005], to[1100005], rd[1100005], cnt[1100005];
ll a[100005];
struct node
{
int tp, pos, sz;
ll v, mul, add;
vector < int > w;
}f[100005];
int read()
{
int x = 0, fl = 1; char ch = getchar();
while (ch < '0' || ch > '9') { if (ch == '-') fl = -1; ch = getchar();}
while (ch >= '0' && ch <= '9') {x = (x << 1) + (x << 3) + ch - '0'; ch = getchar();}
return x * fl;
}
void add(int x, int y)
{
tot ++ ;
to[tot] = y;
nxt[tot] = hd[x];
hd[x] = tot;
return;
}
void topo1()
{
for (int i = 0; i <= m; i ++ ) cnt[i] = f[i].sz;
for (int i = 0; i <= m; i ++ ) if (!cnt[i]) q.push(i);
while (q.size())
{
int x = q.front(); q.pop();
for (int i = hd[x]; i; i = nxt[i])
{
int y = to[i];
f[y].mul = f[y].mul * f[x].mul % mod;
cnt[y] -- ;
if (!cnt[y]) q.push(y);
}
}
return;
}
void topo2()
{
for (int i = 0; i <= m; i ++ ) cnt[i] = rd[i];
for (int i = 0; i <= m; i ++ ) if (!cnt[i]) q.push(i);
while (q.size())
{
int x = q.front(); q.pop();
ll now = 1ll;
for (int i = f[x].sz - 1; i >= 0; i -- )
{
int y = f[x].w[i];
f[y].add = (f[y].add + f[x].add * now)% mod;
now = now * f[y].mul % mod;
cnt[y] -- ;
if (!cnt[y]) q.push(y);
}
}
return;
}
int main()
{
n = read();
for (int i = 1; i <= n; i ++ )
a[i] = (ll)read();
m = read();
f[0].mul = 1ll;
for (int i = 1; i <= m; i ++ )
{
f[i].tp = read();
if (f[i].tp == 1)
{
f[i].pos = read();
f[i].v = (ll)read();
f[i].mul = 1ll;
}
else if (f[i].tp == 2)
{
f[i].v = (ll)read();
f[i].mul = f[i].v;
}
else
{
f[i].sz = read();
f[i].mul = 1ll;
for (int j = 1; j <= f[i].sz; j ++ )
{
int x = read();
f[i].w.push_back(x);
add(x, i);
rd[x] ++ ;
}
}
}
f[0].add = 1ll;
t = read();
while (t -- )
{
int x = read();
add(x, 0);
rd[x] ++ ;
f[0].sz ++ ;
f[0].w.push_back(x);
}
topo1(); topo2();
for (int i = 1; i <= n; i ++ )
a[i] = a[i] * f[0].mul % mod;
for (int i = 1; i <= m; i ++ )
if (f[i].tp == 1)
a[f[i].pos] = (a[f[i].pos] + f[i].add * f[i].v % mod) % mod;
for (int i = 1; i <= n; i ++ )
printf("%lld ", a[i]);
puts("");
return 0;
}