1. 程式人生 > >[BZOJ3533][Sdoi2014]向量集(凸包+線段樹+二分)

[BZOJ3533][Sdoi2014]向量集(凸包+線段樹+二分)

Address

洛谷P3309
BZOJ3533
LOJ#2197

Solution

先假設詢問物件是所有的向量,並且已經全部加入集合。
發現向量 ( x , y ) (x,y)

和向量 ( a , b ) (a,b) 的點積,就等於過點 ( x
, y ) (x,y)
的、斜率為 a b -\frac ab
的直線在 y y 軸上截距的 b b 倍。
b > 0 b>0 時,要最大化截距,所以要在集合內所有點的上凸殼上二分找到答案。
b < 0 b<0 時,要最小化截距,所以要在集合內所有點的下凸殼上二分。
b = 0 b=0 時,我們只需要最大化 a x ax ,所以在上凸殼和下凸殼都行。
如果加入「第 L L 個到第 R R 個加入的向量」這一限制,又怎麼做呢?
線段樹!!!!!
線段樹每個節點維護對應區間內的點構成的上下凸殼。
但還有一個困難:每次加入的點的 x x 座標沒有單調性,所以不能直接增量維護凸殼,而如果每一次更新都將某個節點 [ l , r ] [l,r] ,合併 [ l , m i d ] [l,mid] [ m i d + 1 , r ] [mid+1,r] 的上下凸殼到 [ l , r ] [l,r] ,那麼插入的複雜度將是 O ( n ) O(n) 的,沒有任何改進。
但繼續思考,我們只需要在末尾加向量而不是修改一個向量,可以怎樣優化?
發現:對於每個線段樹上的節點對應的區間 [ l , r ] [l,r] ,如果已經加入的向量數 T < r T<r ,那麼這個點在加入其他向量之前一定不會被使用到。
所以,我們的策略是:線段樹上的節點 p p ,如果 p p 對應的區間為 [ l , r ] [l,r] ,當且僅當 T = r T=r 時才從 p p 的兩個子節點合併。
這樣,每個節點都只進行了一次合併。
查詢時只需要找到區間 [ L , R ] [L,R] 線上段樹上拆成的不超過 O ( log n ) O(\log n) 個區間後在這些區間的 上 / 下 凸殼上二分查詢最大值即可。
時間複雜度 O ( n log 2 n ) O(n\log^2n)

Code

#include <cmath>
#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
#define For(i, a, b) for (i = a; i <= b; i++)
#define p2 p << 1
#define p3 p << 1 | 1

inline int read()
{
	int res = 0; bool bo = 0; char c;
	while (((c = getchar()) < '0' || c > '9') && c != '-');
	if (c == '-') bo = 1; else res = c - 48;
	while ((c = getchar()) >= '0' && c <= '9')
		res = (res << 3) + (res << 1) + (c - 48);
	return bo ? ~res + 1 : res;
}

inline char get()
{
	char c;
	while ((c = getchar()) != 'A' && c != 'Q');
	return c;
}

template <class T>
inline T Max(const T &a, const T &b) {return a > b ? a : b;}

typedef long long ll;

const int N = 4e5 + 5, M = N << 2;

int n, T;

struct point
{
	int x, y;
	
	friend inline point operator - (point a, point b)
	{
		return (point) {b.x - a.x, b.y - a.y};
	}
	
	friend inline ll operator * (point a, point b)
	{
		return 1ll * a.x * b.y - 1ll * a.y * b.x;
	}
};

std::vector<point> up[M], dn[M];
bool isfull[M];
char s[N];

void add_up(int p, point x)
{
	int top = up[p].size() - 1;
	while (top > 0 && (up[p][top - 1] - up[p][top]) * (up[p][top - 1] - x) >= 0)
		top--, up[p].pop_back();
	up[p].push_back(x);
}

void add_dn(int p, point x)
{
	int top = dn[p].size() - 1;
	while (top > 0 && (dn[p][top - 1] - dn[p][top]) * (dn[p][top - 1] - x) <= 0)
		top--, dn[p].pop_back();
	dn[p].push_back(x);
}

void merge_up(int p)
{
	int i, n1 = up[p2].size(), n2 = up[p3].size(), q1 = 0, q2 = 0;
	For (i, 1, n1 + n2)
		if (q2 == n2 || (q1 < n1 &&
			(up[p2][q1].x < up[p3][q2].x || (up[p2][q1].x == up[p3][q2].x
				&& up[p2][q1].y < up[p3][q2].y))))
					add_up(p, up[p2][q1]), q1++;
		else add_up(p, up[p3][q2]), q2++;
}

void merge_dn(int p)
{
	int i, n1 = dn[p2].size(), n2 = dn[p3].size(), q1 = 0, q2 = 0;
	For (i, 1, n1 + n2)
		if (q2 == n2 || (q1 < n1 &&
			(dn[p2][q1].x < dn[p3][q2].x || (dn[p2][q1].x == dn[p3][q2].x
				&& dn[p2][q1].y < dn[p3][q2].y))))
					add_dn(p, dn[p2][q1]), q1++;
		else add_dn(p, dn[p3][q2]), q2++;
}

ll findmax_up(int p, int a, int b)
{
	int l = 0, r = up[p].size() - 2;
	while (l <= r)
	{
		int mid = l + r >> 1;
		if (1ll * up[p][mid].x * a + 1ll * up[p][mid].y * b
			>= 1ll * up[p][mid + 1].x * a + 1ll * up[p][mid + 1].y * b)
				r = mid - 1;
		else l = mid + 1;
	}
	return 1ll * up[p][l].x * a + 1ll * up[p][l].y * b;
}

ll findmax_dn(int p, int a, int b)
{
	int l = 0, r = dn[p].size() - 2;
	while (l <= r)
	{
		int mid = l + r >> 1;
		if (1ll * dn[p][mid].x * a + 1ll * dn[p][mid].y * b
			>= 1ll * dn[p][mid + 1].x * a + 1ll * dn[p][mid + 1].y * b)
				r = mid - 1;
		else l = mid + 1;
	}
	return 1ll * dn[p][l].x * a + 1ll * dn[p][l].y * b;
}

void addpoint(int l, int r, int pos, point x, int p)
{
	if (l == r)
	{
		up[p].push_back(x); dn[p].push_back(x);
		return (void) (isfull[p] = 1);
	}
	int mid = l + r >> 1;
	if (pos <= mid) addpoint(l, mid, pos, x, p2);
	else addpoint(mid + 1, r, pos, x, p3);
	if (isfull[p2] && isfull[p3])
		merge_up(p), merge_dn(p), isfull[p] = 1;
}

ll querymax(int l, int r, int s, int e, int a, int b, int p)
{
	if (l == s && r == e)
		return b >= 0 ? findmax_up(p, a, b) : findmax_dn(p, a, b);
	int mid = l + r >> 1;
	if (e <= mid) return querymax(l, mid, s, e, a, b, p2);
	else if (s >= mid + 1) return querymax(mid + 1, r, s, e,