1. 程式人生 > 實用技巧 >線段樹那些事

線段樹那些事

線段樹是學不明白了……

部分指標用法

對於這段程式碼,

struct Node{
    int a, b, c;
}YJH[100], x;
Node *p = YJH, *q = &x;

以下程式碼在使用過程中是等價的:


cout << x.a << endl;
cout << q->a << endl;
cout << (*q).a << endl;

cout << YJH[2].b << endl;
cout << *(p + 2).b << endl;

用指標寫線段樹

相對於上一次線段樹的寫法,我獲得了用指標寫線段樹的本領。

相對於用下標來寫線段樹,指標的程式碼量、可擴充套件性都變好了。各種函式分別用下表和指標的不同寫法在下面會陳列出來。

初始化

陣列:

#define N 1000050//資料邊界
struct tree{
	long long l, r, sum, tag, tag_x;
    //l,r左右端點,sum為結點對應區間和,tag為加法標記,tag_x為乘法標記 
}t[N];//線段樹 
long long a[N];//輸入的數列(1~n) 
long long m, n, p, k;//如題意(k是操作種類) 
long long ls(long long rt){return rt << 1;}//左孩子 
long long rs(long long rt){return rt << 1 | 1;}//右孩子 lcez_yjh

指標:

struct Node{
  int l, r;
  ll v, tag;
  Node *ls, *rs;

  inline void () {} 
  inline void () {} 
  inline void () {} 
  inline void () {} 
    
}Mem[maxn << 1];

Node *pool = Mem;

相對來講,指標的實現方式程式碼更簡潔,而且在最後一層不會浪費過多的空間。線段樹開空間的時候需要對極端情況進行考慮,陣列往往開到maxn << 2,但是在指標這裡只需要開(maxn << 1) - 1

核心函式:FindRange

陣列:

void build(long long rt, long long l, long long r) {
  if (t[rt].l <= l && t[rt].r >= r) { 
	do sth.
	return;
  }
  long long mid = (t[rt].l + t[rt].r) >> 1;
  if (l <= mid)
    build(ls(rt), l, r);
  if (r > mid)
    build(rs(rt), l, r);
  pushup(rt);
}

指標:

void FindRange(L, R) {
  if (InRange(L, R)) {
    do sth.
    return;
  }
  else if (OutofRange(L, R)) return; 
  else {
    ls->FindRange(L, R);
    rs->FindRange(L, R);
  }
  pushup();
}

可以看到,直接用指標指向左右兒子的方法不受到陣列下表的限制,所以就不會有智障的四倍空間,只需要按照實際空間大小開就可以了。

記憶體池

用指標建立線段樹的方法大致有兩種,一種是直接在建樹的時候開一個新空間Node *u = new Node,但是new這個語句常數十分大,執行起來巨慢。所以我們可以提前申請好一個Node Mem[maxn << 1 - 1]的陣列,並用一個指標Node *Pool = Mem進行取出新點,從而實現代替new功能的函式:

Node* New() { return ++Pool; }

完整程式碼:

陣列:

#include <bits/stdc++.h>
#define N 1000050
using namespace std;

struct tree{
  long long l, r, sum, tag, tag_x;
  //l,r左右端點,sum為結點對應區間和,tag為加法標記,tag_x為乘法標記 
}t[N];//線段樹 
long long a[N];//輸入的數列(1~n) 
long long m, n, p = 9223372036854775807, k;//如題意(k是操作種類) 
long long ls(long long rt){return rt << 1;}//左孩子 
long long rs(long long rt){return rt << 1 | 1;}//右孩子 

void build(long long rt, long long l, long long r) {
  t[rt].tag_x = 1; t[rt].tag = 0;//初始化
  t[rt].l = l, t[rt].r = r;//建立一個結點,更新左右端點標記 
  if (l == r) { //如果到了葉子結點 
	t[rt].sum = a[l] % p; //不要忘記取模操作 
	return;
  }
  long long mid = (l + r) >> 1; //中間節點 
  build(ls(rt), l, mid);
  build(rs(rt), mid + 1, r); //如果不是葉子結點,就分別建立左右孩子 
  t[rt].sum = (t[ls(rt)].sum + t[rs(rt)].sum) % p; // 更新sum 
}

void push_down(long long rt) {
  t[ls(rt)].tag_x = (t[ls(rt)].tag_x * t[rt].tag_x) % p;
  t[rs(rt)].tag_x = (t[rs(rt)].tag_x * t[rt].tag_x) % p;//乘法懶標記更新後取模 

  t[ls(rt)].tag = (t[ls(rt)].tag * t[rt].tag_x) % p;
  t[rs(rt)].tag = (t[rs(rt)].tag * t[rt].tag_x) % p;//加法懶標記更新 

  t[ls(rt)].sum = (t[ls(rt)].sum * t[rt].tag_x) % p;
  t[rs(rt)].sum = (t[rs(rt)].sum * t[rt].tag_x) % p;//sum結點對應區間和更新

  t[rt].tag_x = 1; //父親的標記已經下傳,就歸零(因為是乘法,所以要調到1) 
	
  t[ls(rt)].tag = (t[ls(rt)].tag + t[rt].tag) % p;
  t[rs(rt)].tag = (t[rs(rt)].tag + t[rt].tag) % p;//加法懶標記更新 
	
  t[ls(rt)].sum += (t[ls(rt)].r - t[ls(rt)].l + 1) * t[rt].tag;
  t[rs(rt)].sum += (t[rs(rt)].r - t[rs(rt)].l + 1) * t[rt].tag;//sum結點對應區間和更新
    
  t[rt].tag = 0;//父親的標記已經下傳,就歸零 
}

void change(long long rt, long long x, long long y, long long z) {
  if (x <= t[rt].l && y >= t[rt].r) {
	t[rt].tag = (t[rt].tag + z) % p;
	t[rt].sum = (t[rt].sum + (t[rt].r - t[rt].l + 1) * z) % p; //如果修改區間覆蓋了這個節點的區間,就更新 
	return;
  }
  if(t[rt].tag || t[rt].tag_x != 1) push_down(rt);//訪問孩子結點的時候一定先把懶標記 傳下去 
  long long mid = (t[rt].l + t[rt].r) >> 1;
  if (x <= mid) {
	change(ls(rt),x,y,z);
  }
  if(y > mid){
	change(rs(rt),x,y,z);
  }
  //分別往左右兒子傳 

  t[rt].sum = (t[ls(rt)].sum + t[rs(rt)].sum) % p; //維護 
}

void change_x(long long rt, long long x, long long y, long long z) {
  if(x <= t[rt].l && y >= t[rt].r){
	t[rt].tag_x = (t[rt].tag_x * z) % p;
	t[rt].sum = (t[rt].sum * z) % p;
	t[rt].tag = (t[rt].tag * z) % p;//如果修改區間覆蓋了這個節點的區間,就更新 
	return;
  }
  if(t[rt].tag || t[rt].tag_x != 1) push_down(rt);//訪問孩子結點的時候一定先把懶標記 傳下去 
  long long mid = (t[rt].l + t[rt].r) >> 1;
  if (x <= mid) {
	change_x(ls(rt), x, y, z);
  }
  if (y > mid) {
	change_x(rs(rt), x, y, z);
  }
  //分別往左右兒子傳 
  t[rt].sum = (t[ls(rt)].sum + t[rs(rt)].sum) % p; //維護
}

long long getsum(long long rt, long long x, long long y) {
  long long res = 0;
  if (x <= t[rt].l && y >= t[rt].r) {
	return t[rt].sum % p;
  }
  if (t[rt].tag || t[rt].tag_x != 1) push_down(rt);
  long long mid = (t[rt].r + t[rt].l) >> 1;
  if (x <= mid) {
	res += getsum(ls(rt), x, y);
  }
  if (y > mid) {
	res += getsum(rs(rt), x, y);
  }
  return res % p;
}

int main() {
  long long i, j, x, y, z;
  scanf("%lld%lld%lld", &n, &m, &p);
  //scanf("%lld%lld", &n, &m); 
  for (i = 1; i <= n; i++) {
	scanf("%lld", &a[i]);
  }
  build(1, 1, n);
  for (i = 1; i <= m; i++) {
	scanf("%lld", &k);
		
    if (k == 1) {
	  scanf("%lld%lld%lld", &x, &y, &z);
	  change_x(1, x, y, z);
    } else if (k == 2) {
	  scanf("%lld%lld%lld", &x, &y, &z);
	  change(1, x, y, z);
  	} else if (k == 3) {
	  scanf("%lld%lld", &x, &y);
	  printf("%lld\n", getsum(1, x, y));
    }
//	if (k == 1) {
//		scanf("%lld%lld%lld",&x,&y,&z);
//		change(1,x,y,z);
//	}else if(k == 2){
//		scanf("%lld%lld",&x,&y);
//		printf("%lld\n",getsum(1,x,y));
//	}
  }
  return 0;
} 

指標:

#include <cstdio>
using namespace std;
typedef long long int ll;

const int maxn = 100005;
int n,q,p;

ll a[maxn];
struct Node{
  ll tag_a, v, tag_b;
  int l, r;
  Node *ls, *rs;
	
  Node(const int L, const int R) {
    l = L, r = R;
	tag_a = 0, tag_b = 1;
	if (l == r)	{
	  v = a[l];
	  ls = rs = NULL;
	} else {
	  int mid = (L + R) >> 1;
	  ls = new Node(L, mid);
	  rs = new Node(mid + 1, R);
	  push_up(); 
	}
  }
  inline void make_tag_1(ll w) {//+
	(v += (r - l + 1) * w) %= p;
	(tag_a += w) %= p;
  }
  inline void make_tag_2(ll w) {//*
	(v *= w) %= p;
	(tag_a *= w) %= p;
	(tag_b *= w) %= p;
  }
  inline void push_up() {
	v = ls->v + rs->v;
  }
  inline void push_down() {
	if (!tag_a && tag_b == 1) return;
//	if (tag_a && tag_b != 1) {
//	  ls->make_tag_1(tag_a);
//	  rs->make_tag_1(tag_a);
//	  tag_a = 0;
//	}
//	if (!tag_a&&tag_b) {
//	  ls->make_tag_2(tag_b);
//	  rs->make_tag_2(tag_b);
//	  tag_b = 1;
//	}
	if (tag_b != 1) {
	  ls->make_tag_2(tag_b);
	  rs->make_tag_2(tag_b);
	  tag_b = 1;
	}
	if (tag_a) {
	  ls->make_tag_1(tag_a);
	  rs->make_tag_1(tag_a);
	  tag_a = 0;
	}
  }
  inline bool InRange(const int L, const int R) { return (L <= l && r <= R); }
  inline bool Outofrange(const int L, const int R) { return (L > r || l > R); 	}
	
  inline void update_a(const int L, const int R, ll w) {//+
	if (InRange(L, R)) make_tag_1(w);
	else if (!Outofrange(L, R)) {
	  push_down();
	  ls->update_a(L, R, w);	
	  rs->update_a(L, R, w);
	  push_up();
	}
  }

  inline void update_b(const int L,const int R,ll w) {//*
	if(InRange(L, R)) make_tag_2(w);
	else if(!Outofrange(L, R)) {
	  push_down();
	  ls->update_b(L, R, w);
	  rs->update_b(L, R, w);
	  push_up();
	}
  }
  ll query(const int L, const int R) {
	if (InRange(L, R)) { return v; }
	if (Outofrange(L, R)) { return 0; }
	  push_down();
	return ls->query(L, R) + rs->query(L, R);
  }	
};

int main() {
  scanf("%lld%lld%lld", &n, &q, &p);
  for(int i = 1; i <= n; i++)
	scanf("%d", a + i);
  Node *rot = new Node(1, n);
	
  for(ll o, x, y, z; q; q--) {
	scanf("%lld%lld%lld", &o, &x, &y);
	if (o == 1)	{	
  	  scanf("%lld", &z);
	  rot->update_b(x, y, z);
	}
	if (o == 2) {	
	  scanf("%lld", &z);
	  rot->update_a(x, y, z);
	}
	if (o == 3)	{
	  ll m = rot->query(x, y);
	  printf("%lld\n", m % p);
	}
  }
  return 0;
}