1. 程式人生 > 實用技巧 >BZOJ-1798 [Ahoi2009]Seq 維護序列seq(雙標記線段樹)

BZOJ-1798 [Ahoi2009]Seq 維護序列seq(雙標記線段樹)

題目描述

  給定長為 \(n(1\leq n\leq 10^5)\) 的序列 \(a\)\(m(1\leq m\leq 10^5)\) 次操作,有三種操作:

  操作 \(1\)1 l r v,把區間 \([l,r]\) 中的 \(a_i\) 修改為 \(a_i\times v\)\(0\leq v\leq 10^9\))。

  操作 \(2\)2 l r v,把區間 \([l,r]\) 中的 \(a_i\) 修改為 \(a_i+v\)($0\leq v\leq 10^9 $)。

  操作 \(3\)3 l r,求 \(\displaystyle\sum_{i=l}^{r}a_i\),答案對 \(p(1\leq p\leq 10^9)\)

取模。

分析

區間加法

tree[p].add=tree[p].add+v;
tree[p].sum=tree[p].sum+v*(r-l+1);

區間乘法

  因為 \((ax+b)\times c+d=acx+bc+d\),這說明加法標記不會影響乘法標記,但是下放乘法標記必須把加法標記也乘一下,所以要先乘後加:在做乘法的時候把加法標記也乘上這個數,在後面做加法的時候直接加即可。

tree[p].add=(tree[p].add*v)%mod;
tree[p].mul=(tree[p].mul*v)%mod;
tree[p].sum=(tree[p].sum*v)%mod;

pushdown的維護

\(\text{mul}\):直接乘。

\(\text{add}\):因為 \(\text{add}\) 的值要包括乘之後的值,所以子節點需要先乘上 \(\text{mul}\)

void pushdown(long long p)
{
	long long mid=(tree[p].l+tree[p].r)>>1;
	if(tree[p].mul!=1)
	{
		tree[p*2].mul=tree[p*2].mul*tree[p].mul%mod;
		tree[p*2+1].mul=tree[p*2+1].mul*tree[p].mul%mod;
		tree[p*2].add=tree[p*2].add*tree[p].mul%mod;
		tree[p*2+1].add=tree[p*2+1].add*tree[p].mul%mod;
		tree[p*2].sum=tree[p*2].sum*tree[p].mul%mod;
		tree[p*2+1].sum=tree[p*2+1].sum*tree[p].mul%mod;
		tree[p].mul=1;
	}
	if(tree[p].add!=0)
	{
		tree[p*2].add=(tree[p*2].add+tree[p].add)%mod;
		tree[p*2+1].add=(tree[p*2+1].add+tree[p].add)%mod;
		tree[p*2].sum=(tree[p*2].sum+tree[p].add*(mid-tree[p].l+1)%mod)%mod;
		tree[p*2+1].sum=(tree[p*2+1].sum+tree[p].add*(tree[p].r-mid)%mod)%mod;
		tree[p].add=0;
	}
}

程式碼

#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10;
int n,m,mod;
long long a[N];
struct SegmentTree
{
    int l,r;
    long long add,mul,sum;
}tree[N<<2];
void build(int p,int l,int r)
{
    tree[p].l=l;tree[p].r=r;
    tree[p].sum=tree[p].add=0;
    tree[p].mul=1;
    if(l==r)
    {
        tree[p].sum=a[l];
        tree[p].mul=1;
        return ;
    }
    int mid=(l+r)/2;
    build(p*2,l,mid);
    build(p*2+1,mid+1,r);
    tree[p].sum=tree[p*2].sum+tree[p*2+1].sum;
}
void pushdown(long long p)
{
    if(tree[p].mul!=1)
    {
        tree[p*2].mul=tree[p*2].mul*tree[p].mul%mod;
        tree[p*2+1].mul=tree[p*2+1].mul*tree[p].mul%mod;
        tree[p*2].add=tree[p*2].add*tree[p].mul%mod;
        tree[p*2+1].add=tree[p*2+1].add*tree[p].mul%mod;
        tree[p*2].sum=tree[p*2].sum*tree[p].mul%mod;
        tree[p*2+1].sum=tree[p*2+1].sum*tree[p].mul%mod;
        tree[p].mul=1;
    }
    if(tree[p].add!=0)
    {
        tree[p*2].sum=(tree[p*2].sum+tree[p].add*(tree[p*2].r-tree[p*2].l+1))%mod;
        tree[p*2+1].sum=(tree[p*2+1].sum+tree[p].add*(tree[p*2+1].r-tree[p*2+1].l+1))%mod;
        tree[p*2].add=(tree[p*2].add+tree[p].add)%mod;
        tree[p*2+1].add=(tree[p*2+1].add+tree[p].add)%mod;
        tree[p].add=0;
    }
}
void update_mul(int p,int l,int r,long long v)
{
    if(l<=tree[p].l&&tree[p].r<=r)
    {
        tree[p].sum=tree[p].sum*v%mod;
        tree[p].add=tree[p].add*v%mod;
        tree[p].mul=tree[p].mul*v%mod;
        return ;
    }
    pushdown(p);
    int mid=(tree[p].l+tree[p].r)/2;
    if(l<=mid)
        update_mul(p*2,l,r,v);
    if(r>mid)
        update_mul(p*2+1,l,r,v);
    tree[p].sum=(tree[p*2].sum+tree[p*2+1].sum)%mod;
}
void update_add(int p,int l,int r,long long v)
{
    if(l<=tree[p].l&&tree[p].r<=r)
    {
        tree[p].sum=(tree[p].sum+(tree[p].r-tree[p].l+1)*v%mod)%mod;
        tree[p].add=(tree[p].add+v)%mod;
        return ;
    }
    pushdown(p);
    int mid=(tree[p].l+tree[p].r)/2;
    if(l<=mid)
        update_add(p*2,l,r,v);
    if(r>mid)
        update_add(p*2+1,l,r,v);
    tree[p].sum=(tree[p*2].sum+tree[p*2+1].sum)%mod;
}
long long query(int p,int l,int r)
{
    if(l<=tree[p].l&&tree[p].r<=r)
        return tree[p].sum;
    pushdown(p);
    int mid=(tree[p].l+tree[p].r)/2;
    long long ans=0;
    if(l<=mid)
        ans=(ans+query(p*2,l,r))%mod;
    if(r>mid)
        ans=(ans+query(p*2+1,l,r))%mod;
    return ans;
}
int main()
{
    cin>>n>>mod;
    for(int i=1;i<=n;i++)
        scanf("%lld",&a[i]);
    build(1,1,n);
    cin>>m;
    while(m--)
    {
        int op;
        scanf("%d",&op);
        int l,r;
        long long v;
        if(op==1)
        {
            scanf("%d %d %lld",&l,&r,&v);
            update_mul(1,l,r,v);
        }
        if(op==2)
        {
            scanf("%d %d %lld",&l,&r,&v);
            update_add(1,l,r,v);
        }
        if(op==3)
        {
            scanf("%d %d",&l,&r);
            printf("%lld\n",query(1,l,r));
        }
    }
    return 0;
}