1. 程式人生 > >Problem 3. Party Invitations 2018 Goldman Sachs Women's CodeSprint

Problem 3. Party Invitations 2018 Goldman Sachs Women's CodeSprint

題意:boss和employee的關係可以構成一棵樹,樹上每個節點是其子樹的boss,和其ascent的employee。現在傳送invitation,要求boss必須比其employee先收到,問有多少種傳送invitaion的方式。

結果要對1e7取模,一看就是排列組合推公式。

推公式也是從樹的遞推關係入手,假設三個子樹a,b,c有同一個parent,那麼a,b,c之間是不會影響的。令其孩子節點個數分別是n1,n2,n3,如果總共有n個位置去放這三個子樹(n=n1+n2+n3),那麼就是n中選n1個位置先放a,剩下的n-n1個位置中選n2個放b,剩下的n-n1-n2個位子中放c。如果子樹a是葉子,返回1。

遞推公式是f(parent)=f(a)*f(b)*f(c)*C(n,n1)*C(n-n1,n2)*C(n-n1-n2,n3)。

如果輸入是dis-connected,相當於加一個virtual parent,將輸入中的幾棵樹作為並列的子樹處理。

這一題必須O(NlogN)才能過。所以需要快速求組合數。按照楊氏三角形預處理會超時,正解是把階乘先存起來,最後用逆元和快速冪求除法。

#include <bits/stdc++.h>

using namespace std;

string ltrim(const string &);
string rtrim(const string &);
vector<string> split(const string &);

/*
 * Complete the 'invitations' function below.
 *
 * The function is expected to return an INTEGER.
 * The function accepts following parameters:
 *  1. INTEGER n
 *  2. 2D_INTEGER_ARRAY pairs
 */
const int maxn=200010;
const int mod=1e9+7;
vector<vector<int> >mp;
long long save[maxn];
//map<pair<int,int>,long long >combmp;
map<pair<int,int>,int >indexmp;
int parent[maxn];//# of parents
int children[maxn];//# of children


long long Pow(long long a,long long b)
{
    long long s=1;
    long long t=1;
    while(b)
    {
        if(b&t)
        {
            s=(s*a)%mod;
        }
        a=(a*a)%mod;
        b=b>>1;
    }
    return s;
}
long long fac[maxn];
void getfac()
{
    fac[0]=1;
    for (int i=1; i<maxn; i++)
        fac[i]=(fac[i-1]*i)%mod;
}

long long inv(long long a) {
    return Pow(a, mod - 2);
}
long long Comb(int n,int m)
{
    if(n==m)
    {
        return 1;
    }
    if(m==0)
    {
        return 1;
    }
    if (n<m)
    {
        return 0;
    }
    return ((fac[n]*inv(fac[m]))%mod*inv(fac[n-m])%mod)%mod;
}
void init()
{
    for(int i=0;i<maxn;i++)
    {
        mp.push_back(vector<int>());
    }
    getfac();
//    memset(arr,false,sizeof(arr));
//    prim=produce_prim_number();
//    combmp.clear();
//    indexmp.clear();
//    memset(comb,0,sizeof(comb));
}

int calc_child(int root)
{
    if(mp[root].size()==0)
    {
        children[root]=0;
        return children[root];
    }
    for(int i=0;i<mp[root].size();i++)
    {
        children[root]+=1+calc_child(mp[root][i]);
        // cout<<"root "<<root<<" child "<<mp[root][i]<<" num "<<children[mp[root][i]]<<endl;
    }
    return children[root];
}
// long long save[maxn][maxn];
long long dfs(int root)
{
    if(save[root]!=0)
    {
        return save[root];
    }
    if(mp[root].size()==0)
    {
        save[root]=1;
        return 1;
    }
    long long ret=0;
    long long tmp=1;
    vector<int>presum=vector<int>();
    for(int j=mp[root].size()-1;j>=0;j--)
    {
        if(j==mp[root].size()-1)
        {
            presum.push_back(children[mp[root][j]]+1);
        }
        else
        {
//            cout<<"add "<<children[mp[root][j]]<<" "<<presum[mp[root].size()-1-j-1]<<endl;
            presum.push_back(children[mp[root][j]]+1+presum[mp[root].size()-1-j-1]);
        }
        
    }
//    for(int j=0;j<presum.size();j++)
//    {
//        cout<<"root "<<root<<" presum "<<presum[j]<<endl;
//    }
    for(int j=0;j<mp[root].size();j++)
    {
//        cout<<"root "<<root<<" "<<mp[root][j]<<" "<<presum[mp[root].size()-1-j]<<endl;
//        tmp*=Combination(presum[mp[root].size()-1-j],1+children[mp[root][j]])%mod;
        tmp*=Comb(presum[mp[root].size()-1-j],1+children[mp[root][j]])%mod;
        tmp%=mod;
    }
    for(int j=0;j<mp[root].size();j++)
    {
        tmp*=dfs(mp[root][j])%mod;
        tmp%=mod;
        // cout<<"root "<<root<<" remain slot "<<i<<" child "<<mp[root][j]<<" prod "<<<<endl;
    }
    ret+=tmp;
    ret%=mod;
    save[root]=ret;
//    cout<<"root "<<root<<" ret "<<ret<<endl;
    return ret;
}
int invitations(int n, vector<vector<int>> pairs) {
    long long ans=1;
    memset(parent,0,sizeof(parent));
    memset(children,0,sizeof(children));
    memset(save,0,sizeof(save));
    for(int i=0;i<mp.size();i++)
    {
        mp[i].clear();
    }
    // mp.clear();
    // for(int i=0;i<=n;i++)
    // {
    //     mp.push_back(vector<int>());
    // }
    
    for(int i=0;i<pairs.size();i++)
    {
        parent[pairs[i][1]]=1;
        // children[vector[i][0]]++;
        mp[pairs[i][0]].push_back(pairs[i][1]);
    }
    for(int i=1;i<=n;i++)
    {
        if(parent[i]==1)
        {
            continue;
        }
        calc_child(i);
    }
    int tot=n;
    for(int i=1;i<=n;i++)
    {
//         cout<<"children "<<i<<" "<<children[i]<<endl;
        if(parent[i]==1)//may be multiple roots
        {
            continue;
        }
        //connected_num++;
        ans=((ans%mod)*dfs(i))%mod;
//        cout<<"add comb "<<tot<<" "<<children[i]+1<<endl;
        ans*=Comb(tot,children[i]+1)%mod;
        ans%=mod;
        tot-=children[i]+1;
    }
    ans%=mod;
    // cout<<"ans "<<ans<<endl;
    return ans;
    
}



int main()
{
    init();
    ofstream fout(getenv("OUTPUT_PATH"));

    string tc_temp;
    getline(cin, tc_temp);

    int tc = stoi(ltrim(rtrim(tc_temp)));

    for (int tc_itr = 0; tc_itr < tc; tc_itr++) {
        string first_multiple_input_temp;
        getline(cin, first_multiple_input_temp);

        vector<string> first_multiple_input = split(rtrim(first_multiple_input_temp));

        int n = stoi(first_multiple_input[0]);

        int m = stoi(first_multiple_input[1]);

        vector<vector<int>> pairs(m);

        for (int i = 0; i < m; i++) {
            pairs[i].resize(2);

            string pairs_row_temp_temp;
            getline(cin, pairs_row_temp_temp);

            vector<string> pairs_row_temp = split(rtrim(pairs_row_temp_temp));

            for (int j = 0; j < 2; j++) {
                int pairs_row_item = stoi(pairs_row_temp[j]);

                pairs[i][j] = pairs_row_item;
            }
        }

        int result = invitations(n, pairs);

        fout << result << "\n";
    }

    fout.close();

    return 0;
}

string ltrim(const string &str) {
    string s(str);

    s.erase(
        s.begin(),
        find_if(s.begin(), s.end(), not1(ptr_fun<int, int>(isspace)))
    );

    return s;
}

string rtrim(const string &str) {
    string s(str);

    s.erase(
        find_if(s.rbegin(), s.rend(), not1(ptr_fun<int, int>(isspace))).base(),
        s.end()
    );

    return s;
}

vector<string> split(const string &str) {
    vector<string> tokens;

    string::size_type start = 0;
    string::size_type end = 0;

    while ((end = str.find(" ", start)) != string::npos) {
        tokens.push_back(str.substr(start, end - start));

        start = end + 1;
    }

    tokens.push_back(str.substr(start));

    return tokens;
}