codeforces 894E Ralph and Mushrooms 強連通dp
E. Ralph and Mushrooms
time limit per test
2.5 seconds
memory limit per test
512 megabytes
input
standard input
output
standard output
Ralph is going to collect mushrooms in the Mushroom Forest.
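There are m directed paths connecting n trees in the Mushroom Forest. On each path grow some mushrooms. When Ralph passes a path, he collects all the mushrooms on the path. The Mushroom Forest has a magical fertile ground where mushrooms grow at a fantastic speed. New mushrooms regrow as soon as Ralph finishes mushroom collection on a path. More specifically, after Ralph passes a path the i-th time, i fewer mushrooms regrow than there were before that pass. That is, if a path initially has x mushrooms, Ralph collects x mushrooms the first time, x − 1 the second time, x − 1 − 2 the third time, and so on; the number of mushrooms can never drop below 0.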
For example, let there be 9 mushrooms on a path initially. The number of mushrooms that can be collected from the path is 9, 8, 6 and 3 when Ralph passes by from first to fourth time. From the fifth time and later Ralph can't collect any mushrooms from the path (but still can pass it).
Ralph decided to start from tree s. How many mushrooms can he collect using only the described paths?
Input
The first line contains two integers n and m (1 ≤ n ≤ 10^6, 0 ≤ m ≤ 10^6), representing the number of trees and the number of directed paths in the Mushroom Forest, respectively.
Each of the following m lines contains three integers x, y and w (1 ≤ x, y ≤ n, 0 ≤ w ≤ 10^8), denoting a path that leads from tree x to tree y with w mushrooms initially. There can be paths that lead from a tree to itself, and multiple paths between the same pair of trees.
The last line contains a single integer s (1 ≤ s ≤ n) — the starting position of Ralph.
Output
Print an integer denoting the maximum number of the mushrooms Ralph can collect during his route.
Examples
input
Copy
2 2
1 2 4
2 1 4
1
output
Copy
16
input
Copy
3 3
1 2 4
2 3 3
1 3 8
1
output
Copy
8
Note
In the first sample Ralph can pass three times on the circle and collect 4 + 4 + 3 + 3 + 1 + 1 = 16 mushrooms. After that there will be no mushrooms for Ralph to collect.
In the second sample, Ralph can go to tree 3 and collect 8 mushrooms on the path from tree 1 to tree 3.
題目大意:
有向圖有n個點m條邊,每條邊上都有蘑菇,第i條邊初始有a[i]個蘑菇。第一次經過這條邊能獲得a[i]個蘑菇,第二次能獲得a[i]-1個,第三次能獲得a[i]-1-2個,依此類推,直到剩餘數量小於0為止(此後採集不到蘑菇,但是還可以經過這條邊)。求從給定起點出發最多能獲得多少蘑菇。
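一個強連通分量中的所有邊都可以無限次經過,所以可以把這些邊上的蘑菇全部採完;其他的邊(連接不同強連通分量的邊)最多只能經過一次,所以強連通縮點後在得到的DAG上dp即可。對初始有w個蘑菇的邊,設t為滿足t(t+1)/2 ≤ w的最大整數,則反覆經過可採得的總數為 (t+1)·w − t(t+1)(t+2)/6(例如w=9時t=3,總數為36−10=26=9+8+6+3)。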
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cstring>
#include <stack>
#include <utility>
using namespace std;
// 64-bit type for mushroom totals: a single path with w = 1e8 already yields
// more than 2^31 mushrooms when fully harvested. A typedef instead of a macro,
// so `ll` obeys normal scoping and type rules.
typedef long long ll;
// Array capacity: n <= 1e6 vertices (and therefore at most 1e6 SCC ids), plus slack.
constexpr int maxn = 1000000 + 10;
// G: input graph as (target, weight) adjacency lists; G2: the same edges reversed.
vector<pair<int, int> > G[maxn], G2[maxn];
// mp: condensed graph indexed by SCC id, holding (target SCC id, weight) for
// edges whose endpoints lie in different SCCs.
vector<pair<int, ll> > mp[maxn];
// S: vertices in DFS completion order on G (filled by the first SCC pass).
vector<int> S;
// vis: visited flags for the first pass; sccno: 1-based SCC id per vertex
// (0 = not yet assigned); scc_cnt: number of SCCs found.
// NOTE(review): sum is not referenced anywhere in the visible code.
int vis[maxn], sccno[maxn], scc_cnt, sum[maxn];
// dp: best total reachable from an SCC; val: total collectable inside an SCC.
ll dp[maxn], val[maxn];
// Number of trees (vertices) and paths (edges) read from the input.
int n, m;
// Returns sum_{k=0}^{x} k(k+1)/2, which equals x(x+1)(x+2)/6.
// calu subtracts this cumulative regrowth deficit from (t+1)*w.
// The product of three consecutive integers is always divisible by 6.
ll cal_sum(ll x) {
    const ll product = x * (x + 1) * (x + 2);
    return product / 6;
}
// Total mushrooms collectable from a path that initially holds x mushrooms and
// can be traversed any number of times. Pass k (0-based) yields x - k(k+1)/2
// while that is non-negative, and nothing afterwards. With
// t = max{k >= 0 : k(k+1)/2 <= x}, the total is
//   sum_{k=0}^{t} (x - k(k+1)/2) = (t+1)*x - cal_sum(t).
// Returns 0 for x < 0 (the previous version left t uninitialized there).
ll calu(ll x) {
    if (x < 0) return 0;
    // Binary search for t. k = 0 always qualifies, so t starts at 0.
    // t <= x holds for x >= 0. Capping the upper bound at 2e9 keeps
    // mid * (mid + 1) inside 64 bits; the cap only matters once x >= ~2e18,
    // where x * (t + 1) would overflow regardless.
    ll lo = 0;
    ll hi = min(x, 2000000000LL);
    ll t = 0;
    while (lo <= hi) {
        const ll mid = lo + (hi - lo) / 2;
        if (mid * (mid + 1) / 2 <= x) {
            t = mid;
            lo = mid + 1;
        } else {
            hi = mid - 1;
        }
    }
    return x * (t + 1) - cal_sum(t);
}
// First SCC pass: depth-first search over G starting at u. Every newly reached
// vertex is appended to S in post-order, i.e. after all of its successors have
// finished. Already-visited vertices are skipped.
// An explicit stack replaces recursion: with n up to 1e6, a long path would
// otherwise need ~1e6 nested call frames and can overflow the call stack.
void dfs1(int u)
{
    if (vis[u]) return;
    // Entries are (vertex, index of the next outgoing edge to examine).
    // The stack is static so its capacity is reused across the many root calls.
    static vector<pair<int, size_t> > stk;
    stk.clear();
    vis[u] = 1;
    stk.emplace_back(u, 0);
    while (!stk.empty()) {
        const int x = stk.back().first;
        const size_t next = stk.back().second;
        if (next < G[x].size()) {
            stk.back().second = next + 1;
            const int v = G[x][next].first;
            if (!vis[v]) {
                vis[v] = 1;            // mark on entry, like the recursive form
                stk.emplace_back(v, 0);
            }
        } else {
            S.push_back(x);            // all successors done: record post-order
            stk.pop_back();
        }
    }
}
// Second SCC pass: assigns the current id scc_cnt to u and to every vertex
// reachable from u in the reversed graph G2 through vertices that have no id
// yet (sccno == 0).
// This is an iterative flood fill. With n up to 1e6, a recursive walk could
// nest ~1e6 calls and overflow the call stack.
void dfs2(int u)
{
    if (sccno[u]) return;
    // Pending vertices; static so its capacity is reused across calls.
    static vector<int> todo;
    todo.clear();
    sccno[u] = scc_cnt;
    todo.push_back(u);
    while (!todo.empty()) {
        const int x = todo.back();
        todo.pop_back();
        for (const auto &edge : G2[x]) {
            const int v = edge.first;
            if (!sccno[v]) {
                sccno[v] = scc_cnt;    // label when discovered so v is queued once
                todo.push_back(v);
            }
        }
    }
}
void find_scc(int n)
{
int i;
scc_cnt = 0;
S.clear();
memset(sccno, 0, sizeof(sccno));
memset(vis, 0, sizeof(vis));
for(i = 0; i < n; i++) dfs1(i);
for(i = n - 1; i >= 0; i--)
if(!sccno[S[i]]) { scc_cnt++; dfs2(S[i]); }
}
// Best total obtainable starting inside SCC u of the condensed graph mp:
//   val[u] + max(0, max over edges (u -> v, w) of (w + best total from v)).
// Results are memoized in dp[].
//
// Completion is tracked with a separate `done` flag. The previous version
// tested dp[u] != 0, so any SCC whose best total is 0 was recomputed on every
// visit. On zero-weight subgraphs that can take exponential time.
//
// The walk uses an explicit stack, because its depth can reach the number of
// SCCs (up to ~1e6).
ll dfs(int u) {
    static bool done[maxn];
    // Entries are (SCC id, index of the next outgoing edge to examine).
    static vector<pair<int, size_t> > stk;
    if (done[u]) return dp[u];
    stk.clear();
    stk.emplace_back(u, 0);
    while (!stk.empty()) {
        const int x = stk.back().first;
        const size_t next = stk.back().second;
        if (next < mp[x].size()) {
            stk.back().second = next + 1;
            const int v = mp[x][next].first;
            // mp joins distinct SCCs, so it is acyclic.
            // An unfinished v can therefore never already be on the stack.
            if (!done[v]) stk.emplace_back(v, 0);
            continue;
        }
        // Every successor of x is finished: combine their results.
        ll best = 0;
        for (const auto &edge : mp[x])
            best = max(best, dp[edge.first] + edge.second);
        dp[x] = best + val[x];
        done[x] = true;
        stk.pop_back();
    }
    return dp[u];
}
// Reads the forest graph (1-based tree numbers, converted to 0-based), condenses
// it into strongly connected components, and prints the maximum number of
// mushrooms Ralph can collect starting from tree s.
int main()
{
    // dp, val, G, G2 and mp are globals, so they start zeroed/empty;
    // no explicit reset is needed.
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= m; i++)
    {
        int x, y, w;
        scanf("%d%d%d", &x, &y, &w);
        x--; y--;
        G[x].push_back({y, w});
        G2[y].push_back({x, w});
    }
    int start;
    scanf("%d", &start);
    start--;

    find_scc(n);

    // Edges inside one SCC can be traversed any number of times, so their fully
    // harvested total (calu) is credited to that SCC. Edges between different
    // SCCs can be used at most once and become edges of the condensed DAG.
    for (int u = 0; u < n; u++) {
        for (const auto &edge : G[u]) {
            const int v = edge.first;
            const int w = edge.second;
            if (sccno[u] == sccno[v])
                val[sccno[u]] += calu(w);
            else
                mp[sccno[u]].push_back({sccno[v], w});
        }
    }

    // %lld is the standard long long conversion.
    // %I64d is MSVC-runtime specific and misparsed by glibc.
    printf("%lld\n", dfs(sccno[start]));
    return 0;
}