1. 程式人生 > >hdu 5877 線段樹+離散化+DFS

hdu 5877 線段樹+離散化+DFS

連結:戳這裡

Weak Pair Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 262144/262144 K (Java/Others)
Problem Description
You are given a rooted tree of N nodes, labeled from 1 to N. To the $i$-th node a non-negative value $a_i$ is assigned. An ordered pair of nodes $(u,v)$ is said to be weak if
  (1) u is an ancestor of v (Note: In this problem a node u is not considered an ancestor of itself);
  (2) $a_u \times a_v \le k$.

Can you find the number of weak pairs in the tree?
 
Input
There are multiple cases in the data set.
  The first line of input contains an integer T denoting number of test cases.
  For each case, the first line contains two space-separated integers, N and k, respectively.
  The second line contains N space-separated integers, denoting a1 to aN.
  Each of the subsequent N−1 lines contains two space-separated integers defining an edge connecting nodes u and v, where node u is the parent of node v.

  Constraints:
  
  $1 \le N \le 10^5$
  
  $0 \le a_i \le 10^9$
  
  $0 \le k \le 10^{18}$
 
Output
For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.
 
Sample Input
1
2 3
1 2
1 2
 
Sample Output
1
 

題意:

n個節點的樹,節點的點權為a[i],要求找出有多少個二元組(u,v)滿足

1:u是v的祖先且u!=v

2:a[u]*a[v]<=K

思路:

跑DFS的過程,其實就是從祖先走到兒子的過程,但是兄弟子樹的節點會造成干擾。想象一下DFS序:進入一個節點時把它加入線段樹,回溯離開時再把它刪掉,這樣兄弟節點就不會被算進去

這樣線段樹裡留下的節點恰好都是當前v的祖先,只需要快速找出有多少個祖先滿足條件

那麼我們考慮a[u]*a[v]<=K:當a[v]>0時,它等價於a[u]<=⌊K/a[v]⌋(整數除法),所以要快速找出當前路徑上有多少個祖先u滿足a[u]<=K/a[v],線段樹維護就可以了;當a[v]=0時,所有祖先都滿足條件,必須特判,否則會除以零

這裡a[i]最大可達10^9,無法直接當作線段樹的下標,所以要先離散化(排序去重後用lower_bound/upper_bound取得下標)

程式碼:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<string>
#include<vector>
#include <ctime>
#include<queue>
#include<set>
#include<map>
#include<list>
#include<stack>
#include<iomanip>
#include<cmath>
#include<bitset>
// Fill every byte of array `ss` with `b` (only meaningful for 0 / -1 style fills).
#define mst(ss,b) memset((ss),(b),sizeof(ss))
///#pragma comment(linker, "/STACK:102400000,102400000")
typedef long long ll;
typedef long double ld;
// Fully parenthesized so the macro is safe inside larger expressions:
// the unparenthesized form made `2*INF` expand to `2*(1ll<<60)-1`.
#define INF ((1ll<<60)-1)
#define Max 1e9
using namespace std;
int T;          // number of test cases
int n,m;        // n: node count; m: number of distinct compressed values in b
// a[1..n]: node values; b[1..m]: sorted, de-duplicated coordinate table
// (sized for up to 2*10^5 entries); K: the product bound.
ll a[200100],b[200100],K;
int deep[100100];   // deep[v]: how many parent edges point at v (0 => root)
int sum[800100];    // segment-tree node counts, 4x the largest compressed range
/// Resets the segment tree rooted at node `root`, which covers positions
/// [l, r]: every leaf becomes 0 and every internal node is recomputed as
/// the sum of its two children (heap numbering: children 2*root, 2*root+1).
void build(int root,int l,int r){
    if(l!=r){
        const int mid=(l+r)/2;
        const int lc=root*2;
        const int rc=root*2+1;
        build(lc,l,mid);
        build(rc,mid+1,r);
        sum[root]=sum[lc]+sum[rc];
    }else{
        sum[root]=0;
    }
}
/// Returns the total of the counts stored at positions [x, y] ∩ [l, r] in the
/// segment tree rooted at node `root`, which covers [l, r] (children of a
/// node are 2*root and 2*root+1).
/// An empty range (x > y) or one disjoint from [l, r] yields 0. The previous
/// version never terminated, or read outside the tree, for such ranges.
int query(int root,int l,int r,int x,int y){
    if(x>y || y<l || x>r) return 0;        // nothing of [x, y] lies in this node
    if(x<=l && r<=y) return sum[root];     // node range fully covered
    int mid=(l+r)/2;
    // Partial overlap: split; the guard above prunes the side that misses.
    return query(root*2,l,mid,x,y)+query(root*2+1,mid+1,r,x,y);
}
/// Adds `v` to the count at position `x` in the segment tree rooted at node
/// `root` covering [l, r], then refreshes every node on the path back up to
/// `root`. Note: `sum` stores int, so the long long `v` is narrowed on store.
void update(int root,int l,int r,int x,ll v){
    // Walk down to the leaf that owns position x.
    int node=root;
    while(l!=r){
        const int mid=(l+r)/2;
        if(x<=mid){
            node=node*2;
            r=mid;
        }else{
            node=node*2+1;
            l=mid+1;
        }
    }
    sum[node]+=v;
    // Rebuild ancestors bottom-up. Nodes below `root` have larger indices than
    // `root`, and nodes above it have smaller ones, so the loop stops at `root`.
    for(node/=2;node>=root;node/=2)
        sum[node]=sum[node*2]+sum[node*2+1];
}
/// One directed edge in the array-backed adjacency lists.
struct edge{
    int v;      // head (child) node of this edge
    int next;   // index of the next edge from the same tail, -1 terminates
} e[200100];
// head[u]: index in `e` of u's most recently added edge (-1 = no edges,
// expected to be set before use); tot: number of edges stored so far.
int head[100100],tot=0;
/// Prepends the directed edge u -> v to u's adjacency list.
void Add(int u,int v){
    const int id=tot++;
    e[id].next=head[u];
    e[id].v=v;
    head[u]=id;
}
/// Total number of weak pairs for the current test case.
ll ans;
/// One entry of the explicit DFS stack used by DFS().
struct DfsFrame{
    int pos;    // compressed index of the node's value in b
    int edge;   // next adjacency-list edge still to visit; -1 when exhausted
};
/// Called when the DFS enters node x. Adds to `ans` the number of nodes on the
/// current root-to-x path whose value is <= floor(K / a[x]). When a[x] == 0,
/// every path node counts, because K >= 0, and this also avoids K/0.
/// Then inserts a[x] into the segment tree.
/// Requires b[1..m] sorted and containing every a[i]; thresholds need not be in b.
/// Returns x's compressed position so the caller can remove it on exit.
static int enterNode(int x){
    int cnt;    // number of distinct values in b that are <= the allowed maximum
    if(a[x]==0) cnt=m;
    else cnt=int(std::upper_bound(b+1,b+m+1,K/a[x])-(b+1));
    if(cnt>0) ans+=query(1,1,m,1,cnt);
    const int pos=int(std::lower_bound(b+1,b+m+1,a[x])-b);
    update(1,1,m,pos,1);
    return pos;
}
/// Adds to `ans` the number of weak pairs (anc, v) inside the subtree rooted
/// at `u`, where anc is a proper ancestor of v.
/// The segment tree holds exactly the nodes of the current root-to-node path.
/// A node is inserted on entry and removed on exit, so siblings never count as
/// ancestors. An explicit stack replaces recursion, because a chain of up to
/// 1e5 nodes could overflow the call stack.
void DFS(int u){
    std::vector<DfsFrame> path;
    DfsFrame root;
    root.pos=enterNode(u);
    root.edge=head[u];
    path.push_back(root);
    while(!path.empty()){
        DfsFrame &top=path.back();
        if(top.edge!=-1){
            const int child=e[top.edge].v;
            top.edge=e[top.edge].next;
            DfsFrame next;
            next.pos=enterNode(child);
            next.edge=head[child];
            path.push_back(next);   // may reallocate; `top` is not used afterwards
        }else{
            update(1,1,m,top.pos,-1);   // leaving this node: drop it from the path
            path.pop_back();
        }
    }
}
int main(){
    scanf("%d",&T);
    for(int cas=1;cas<=T;cas++){
        mst(deep,0);
        mst(head,-1);
        mst(sum,0);
        ans=tot=0;
        // %lld is the standard conversion for long long; %I64d is a Microsoft
        // extension that glibc's scanf/printf do not recognize.
        scanf("%d%lld",&n,&K);
        for(int i=1;i<=n;i++){
            scanf("%lld",&a[i]);
            b[i]=a[i];
        }
        // Compress node values only. DFS finds thresholds with upper_bound,
        // so K/a[i] (a division by zero when a[i]==0) is never computed here.
        sort(b+1,b+n+1);
        m=int(unique(b+1,b+n+1)-(b+1));
        build(1,1,m);
        // Read the n-1 parent->child edges; deep[v] counts v's parents.
        for(int i=1;i<n;i++){
            int u,v;
            scanf("%d%d",&u,&v);
            Add(u,v);
            deep[v]++;
        }
        // Nodes without a parent are roots.
        for(int i=1;i<=n;i++){
            if(deep[i]==0) DFS(i);
        }
        printf("%lld\n",ans);
    }
    return 0;
}