1. 程式人生 > >hdu3507 斜率優化學習筆記

hdu3507 斜率優化學習筆記

QWQ菜的真實。

首先來看這個題。
很顯然能得到一個樸素的 d p dp 柿子

d p [ i

] = m a x ( d p [ i
] , d p [ j ] + ( s
u m [ i ] s u m [ j ] ) 2 ) dp[i]=max(dp[i],dp[j]+(sum[i]-sum[j])^2)

但是因為 n 500000 n\le 500000 ,所以 n 2 n^2 一定是過不了的。
考慮應該怎麼優化。

考慮什麼時候存在一個 j > k j k j>k且j比k更優秀

d p [ j ] + ( s u m [ i ] s u m [ j ] ) 2 < d p [ k ] + ( s u m [ i ] s u m [ k ] ) 2 dp[j]+(sum[i]-sum[j])^2<dp[k]+(sum[i]-sum[k])^2

我們進行化簡
2 × s [ i ] × ( s [ j ] s [ k ] ) > d p [ j ] + s u m [ j ] 2 d p [ k ] s u m [ k ] 2 2\times s[i] \times (s[j]-s[k]) > dp[j]+sum[j]^2-dp[k]-sum[k]^2

由於權值都是正數,所以 s [ j ] s [ k ] > 0 s[j]-s[k]>0
我們設 f [ x ] = s u m [ x ] 2 + d p [ x ] f[x]=sum[x]^2+dp[x]
則上述柿子等於 2 × s [ i ] > f [ j ] f [ k ] s [ j ] s [ k ] 2\times s[i]>\frac{f[j]-f[k]}{s[j]-s[k]}

觀察到右邊這個柿子是一個斜率的形式。
我們可以直接用單調佇列維護一個下凸殼。

對於每次插入一個點,運用叉積進行 c h e c k check ,保證斜率是單調不降的。

int chacheng(Point x,Point y)
{
	return x.x*y.y-y.x*x.y;
}
bool count(Point i,Point j,Point k)
{
	Point x,y;
	x.x=(k.x-i.x);
	x.y=(k.y-i.y);
	y.x=(k.x-j.x);
	y.y=(k.y-j.y);
	if (chacheng(x,y)<=0) return true; 
	return false;
	// if ((double)(k.y-j.y)/(double)(k.x-j.x)<(double)(j.y-i.y)/(double)(j.x-i.x)) return true;
	//return false;
}
void push(Point x)
{
	while (tail>=head+1 && count(q[tail-1],q[tail],x)) tail--;
	q[++tail]=x;
}

刪除的話,只需要通過上面那個柿子,若存在 q [ h e a d + 1 ] q [ h e a d ] q[head+1]比q[head] 優秀,就彈出隊首元素

void pop(int lim)
{
	while (tail>=head+1 && (q[head+1].y-q[head].y)<=lim*(q[head+1].x-q[head].x)) head++;
}

剩下的就是 d p dp 部分

qwq因為一些奇奇怪怪的問題 W A WA 了一上午
xtbl

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<map>
#include<set>
#define mk make_pair
#define ll long long
#define int long long
using namespace std;
inline int read()
{
  int x=0,f=1;char ch=getchar();
  while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
  while (isdigit(ch)) {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
  return x*f;
}
const int maxn = 1e6+1e2;
struct Point{
	int x,y;
};
Point q[maxn];
int dp[maxn];
int sum[maxn];
int val[maxn];
int n,m;
int head=1,tail=0;
int chacheng(Point x,Point y)
{
	return x.x*y.y-y.x*x.y;
}
bool count(Point i,Point j,Point k)
{
	Point x,y;
	x.x=(k.x-i.x);
	x.y=(k.y-i.y);
	y.x=(k.x-j.x);
	y.y=(k.y-j.y);
	if (chacheng(x,y)<=0) return true; 
	return false;
	// if ((double)(k.y-j.y)/(double)(k.x-j.x)<(double)(j.y-i.y)/(double)(j.x-i.x)) return true;
	//return false;
}
void push(Point x)
{
	while (tail>=head+1 && count(q[tail-1],q[tail],x)) tail--;
	q[++tail]=x;
}
void pop(int lim)
{
	while (tail>=head+1 && (q[head+1].y-q[head].y)<=lim*(q[head+1].x-q[head].x)) head++;
}
signed main()
{
  while (scanf("%lld%lld",&n,&m)!=EOF)
  {
  	memset(q,0,sizeof(q));
  	memset(dp,0,sizeof(dp)); 
  	memset(sum,0,sizeof(sum));
  	head=1,tail=0;
  	//n=read();m=read();
  	for (int i=1;i<=n;i++) val[i]=read();
  	for (int i=1;i<=n;i++) sum[i]=sum[i-1]+val[i];
 	dp[0]=0;
  	push((Point){0,0});
  	for (int i=1;i<=n;i++)
  	{
  	 	pop(2ll*sum[i]);
  	 	dp[i]=q[head].y-q[head].x*q[head].x+m+(sum[i]-q[head].x)*(sum[i]-q[head].x);
  	 	push((Point){sum[i],dp[i]+sum[i]*sum[i]});
  	 //cout<<i<<" "<<dp[i]<<" "<<q[head].x<<" "<<q[head].y<<" "<<head<<" "<<tail<<endl;
  	}
  	cout<<dp[n]<<"\n";
  }
  return 0;
}