线段树上的均摊分析

发布于

概述

某些问题使用比较暴力的方法也可以通过,这种解法的复杂度均摊之后是对的。分析方式一般需要用到势能分析法。

例子 1

给定长度为 $n$ 的数组 $a_i$($a_i \leq V$),$m$ 次操作。

操作 1:区间对给定数 $x$ 取模。

操作 2:查询区间和。

操作 2:单点加。

$1 \leq n, m \leq 10^5, V \leq 10^9$

不知道这能不能算势能分析。

对于操作 1:如果一个数大于等于 $x$,那么取模至少让它减半。因此一个数最多被取模 $\mathcal O\left(\log V\right)$ 次。

对于操作 3:把一个数单点加会让这个数会满血,但就只有这一个数。所以一次单点加最多会让这个数再能被模 $\mathcal O\left(\log V\right)$ 次,复杂度加上个 $\mathcal O\left(m \log V\right)$。

代码懒得写了,从网上扒了一份。

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cassert>
#include <ctime>
#include <cmath>
#include <algorithm>
#include <string>
#include <vector>
#include <deque>
#include <queue>
#include <list>
#include <set>
#include <map>
#include <iostream>

#define pb push_back
#define mp make_pair
#define TASKNAME ""

#ifdef LOCAL
#define eprintf(...) fprintf(stderr,__VA_ARGS__)
#else
#define eprintf(...)
#endif

#define TIMESTAMP(x) eprintf("[" #x "] Time = %.3lfs\n",clock()*1.0/CLOCKS_PER_SEC)

#ifdef linux
#define LLD "%lld"
#else
#define LLD "%I64d"
#endif

#define sz(x) ((int)(x).size())

using namespace std;

typedef long double ld;
typedef long long ll;
typedef vector<ll> vll;
typedef vector<int> vi;
typedef vector<vi> vvi;
typedef vector<bool> vb;
typedef vector<vb> vvb;
typedef pair<int, int> pii;
typedef pair <ll, ll> pll;
typedef pair <ll, ll> pli;
typedef vector<pii> vpii;

const int inf = 1e9;
const double eps = 1e-9;
const double INF = inf;
const double EPS = eps;

pli operator + (const pli& a, const pli& b)
{
  return mp(a.first+b.first,max(a.second,b.second));
}

pli A[(1<<18)+100];
                     
ll sum (int l, int r)
{
  l+=(1<<17), r+=(1<<17);
  ll res=0;
 // cerr<<l<<"  :::  "<<r<<endl;
  while (l<r)
  {
    //cerr<<r-1<<" "<<A[r-1].first<<endl;
    if (l&1)
      res+=A[l].first, l++;
    if (r&1)
      r--, res+=A[r].first;
    l>>=1, r>>=1;
  }
  return res;
}

void mod (int v, int L, int R, int l, int r, int x)
{
  if (R<=l || r<=L || A[v].second<x)
    return;
  if (R-L==1)
  {
    #ifdef LOCAL
    assert(A[v].second==A[v].first && A[v].second>=x);
    #endif
    //cerr<<"?? "<<L<<" "<<x<<" "<<A[v].first<<"    "<<v<<endl;
    A[v].second%=x;
    A[v].first=A[v].second;
    return;
  }
  mod(v<<1,L,(L+R)>>1,l,r,x);
  mod((v<<1)|1,(L+R)>>1,R,l,r,x);
  A[v]=A[v<<1]+A[(v<<1)|1];
}

int main()
{
  int n, m, l, r, i, x, v, tp;
  #ifdef LOCAL
  freopen(TASKNAME".in","r",stdin);
  freopen(TASKNAME".out","w",stdout);
  #endif
  scanf("%d%d", &n, &m);
  for (i=0; i<n; i++)
    scanf("%d", &v), A[i+(1<<17)]=mp(v,v);
  for (i=(1<<17)-1; i>0; i--)
    A[i]=A[2*i]+A[2*i+1];
  while (m) 
  {
    m--;
    scanf("%d", &tp);
    if (tp==1)
    {
      scanf("%d%d", &l, &r), l--;
      printf(LLD"\n", sum(l,r));
      continue;
    }
    if (tp==2)
    {
      scanf("%d%d%d", &l, &r, &x), l--;
      mod(1,0,(1<<17),l,r,x);      
      continue;
    }
    assert(tp==3);
    scanf("%d%d", &v, &x), v+=(1<<17)-1, A[v]=mp(x,x), v>>=1;
    while (v)
      A[v]=A[2*v]+A[2*v+1], v>>=1;
  }
  TIMESTAMP(end);
  return 0;
}

例子 2

给定长度为 $n$ 的数组 $a_i$($a_i \leq V$),$m$ 次操作。

操作 1:区间加。

操作 2:区间开根号。

操作 3:查询区间和。

$1 \leq n, m \leq 10^5, V \leq 10^9$

势能分析。

定义线段树中极差大于 $1$ 的节点为“关键点”。

区间加 和 区间开根号都只能让“关键点”个数增加 $\mathcal O\left(\log n\right)$,因为只有在区间边缘的点可能变成“关键点”。而一个节点开 $\log \log V$ 次根号就会变成 $1$(主定理解 $T\left(N\right) = T\left(\sqrt N\right) + 1$)。因此对着关键点暴力开根的复杂度就是 $\mathcal O\left(m \log n \log \log V\right)$。

代码懒得写了,从网上扒了一份。

//minamoto
#include<bits/stdc++.h>
#define R register
#define ll long long
#define ls (p<<1)
#define rs (p<<1|1)
#define fp(i,a,b) for(R int i=a,I=b+1;i<I;++i)
#define fd(i,a,b) for(R int i=a,I=b-1;i>I;--i)
#define go(u) for(int i=head[u],v=e[i].v;i;i=e[i].nx,v=e[i].v)
inline ll max(const R ll &x,const R ll &y){return x>y?x:y;}
inline ll min(const R ll &x,const R ll &y){return x<y?x:y;}
using namespace std;
char buf[1<<21],*p1=buf,*p2=buf;
inline char getc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;}
int read(){
    R int res,f=1;R char ch;
    while((ch=getc())>'9'||ch<'0')(ch=='-')&&(f=-1);
    for(res=ch-'0';(ch=getc())>='0'&&ch<='9';res=res*10+ch-'0');
    return res*f;
}
char sr[1<<21],z[20];int C=-1,Z=0;
inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
void print(R ll x){
    if(C>1<<20)Ot();if(x<0)sr[++C]='-',x=-x;
    while(z[++Z]=x%10+48,x/=10);
    while(sr[++C]=z[Z],--Z);sr[++C]='\n';
}
const int N=1e5+5;
struct node{int len;ll sum,tag,mn,mx;}tr[N<<2];
int n,m,a[N],ql,qr,val,op;ll ans;
inline void pd(R node &x,R ll v){x.tag+=v,x.mn+=v,x.mx+=v,x.sum+=v*x.len;}
void upd(R int p){
    tr[p].sum=tr[ls].sum+tr[rs].sum+tr[p].tag*tr[p].len;
    tr[p].mx=max(tr[ls].mx,tr[rs].mx)+tr[p].tag;
    tr[p].mn=min(tr[ls].mn,tr[rs].mn)+tr[p].tag;
}
void build(int p,int l,int r){
    tr[p].len=r-l+1;
    if(l==r)return (void)(tr[p].sum=tr[p].mx=tr[p].mn=a[l]);
    int mid=(l+r)>>1;
    build(ls,l,mid),build(rs,mid+1,r);
    upd(p);
}
void update(int p,int l,int r){
    if(ql<=l&&qr>=r)return pd(tr[p],val);
    int mid=(l+r)>>1;
    if(ql<=mid)update(ls,l,mid);
    if(qr>mid)update(rs,mid+1,r);
    upd(p);
}
void Sqrt(int p,int l,int r,ll tag){
    if(ql<=l&&qr>=r){
        if(tr[p].mx==tr[p].mn){
            ll del=tr[p].mn+tag-(ll)sqrt(tr[p].mn+tag);
            return pd(tr[p],-del);
        }
        ll c1=sqrt(tr[p].mn+tag)+1,c2=sqrt(tr[p].mx+tag);
        if(tr[p].mx==tr[p].mn+1&&c1==c2){
            ll del=tr[p].mn+tag-(ll)sqrt(tr[p].mn+tag);
            return pd(tr[p],-del);
        }
    }
    int mid=(l+r)>>1;
    if(ql<=mid)Sqrt(ls,l,mid,tag+tr[p].tag);
    if(qr>mid)Sqrt(rs,mid+1,r,tag+tr[p].tag);
    upd(p);
}
void query(int p,int l,int r,ll tag){
    if(ql<=l&&qr>=r)return (void)(ans+=tr[p].sum+tr[p].len*tag);
    int mid=(l+r)>>1;
    if(ql<=mid)query(ls,l,mid,tag+tr[p].tag);
    if(qr>mid)query(rs,mid+1,r,tag+tr[p].tag);
}
int main(){
//    freopen("testdata.in","r",stdin);
    n=read(),m=read();
    fp(i,1,n)a[i]=read();
    build(1,1,n);
    while(m--){
        op=read(),ql=read(),qr=read();
        switch(op){
            case 1:val=read(),update(1,1,n);break;
            case 2:Sqrt(1,1,n,0);break;
            case 3:ans=0;query(1,1,n,0);print(ans);break;
        }
    }return Ot(),0;
}

例子 3

给定长度为 $n$ 的数组 $a_i$($a_i \leq V$),$m$ 次操作。

操作 1:区间加。

操作 2:查询区间 $\gcd$。

$1 \leq n, m \leq 10^5, V \leq 10^9$

考虑一段区间 $\left[l, r\right]$ 的 $\gcd$:
$$\gcd \left\{a_l, a_{l + 1}, a_{l + 2}, \ldots, a_r\right\} = \gcd \left\{a_l, a_{l + 1} - a_l, a_{l + 2} - a_{l + 1}, \ldots, a_r - a_{r - 1}\right\}$$

这是差分的形式。因此一次区间加相当于是两个单点改。

代码懒得写了,从网上扒了一份。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 5e5 + 23;
struct Node{
    int l, r;
    ll v, d;
}tr[maxn * 4];
ll a[maxn], b[maxn];
void pushup(Node &u, Node &l, Node &r)
{
    u.v = l.v + r.v;
    u.d = __gcd(l.d, r.d);
}
void pushup(int u)
{
    pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
    //printf("%d[%d,%d]=%d\n",u,tr[u].l,tr[u].r,tr[u].d);
}
void build(int u, int l, int r)
{
    if(l == r) tr[u] = {l, r, b[l], b[l]};
    else 
    {
        tr[u] = {l, r};
        int mid = l + r >> 1;
        build(u << 1, l, mid); build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
}
ll query(int u, int l, int r)
{
    if(tr[u].l >= l && tr[u].r <= r) return tr[u].d;
    else 
    {
        int mid = tr[u].l + tr[u].r >> 1;
        if(r <= mid) return query(u << 1, l, r);
        else if(l > mid) return query(u << 1 | 1, l, r);
        else return __gcd(query(u << 1, l, r), query(u << 1 | 1, l, r));
    }
}
ll query2(int u, int l, int r)
{
    if(tr[u].l >= l && tr[u].r <= r) return tr[u].v;
    else 
    {
        int mid = tr[u].l + tr[u].r >> 1;
        if(r <= mid) return query2(u << 1, l, r);
        else if(l > mid) return query2(u << 1 | 1, l, r);
        else return query2(u << 1, l, r) + query2(u << 1 | 1, l, r);
    }
}
void modify(int u, int p, ll v)
{
    if(tr[u].l == tr[u].r && tr[u].l == p) tr[u].d += v, tr[u].v += v;
    else 
    {
        int mid = tr[u].l + tr[u].r >> 1;
        if(p <= mid) modify(u << 1, p, v);
        else modify(u << 1 | 1, p, v);
        pushup(u);
    }
}
int main()
{
    int n, m;
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; i++) scanf("%lld", &a[i]), b[i] = a[i] - a[i - 1];
    build(1, 1, n);
    char op[4];
    ll op1 = 0, op2 = 0, op3 = 0;
    while(m--)
    {
        scanf("%s%lld%lld", op, &op1, &op2);
        if(*op == 'C') 
        {
            scanf("%lld", &op3);
            modify(1, op1, op3);
            if(op2 + 1 <= n) modify(1, op2 + 1, -op3);
        }
        else 
        {
            ll t = query2(1, 1, op1); //cout << t << endl;
            printf("%lld\n", abs( __gcd(t, query(1, op1 + 1, op2))));
        }
    }
}

// 作者:这个显卡不太冷

暂无评论

发表评论