当前位置 : 主页 > 手机开发 > harmonyos >

ZOJ 3649

来源:互联网 收集:自由互联 发布时间:2023-10-08
倍增法DP #includecstdio#includecstring#includealgorithm#define N 50100#define POW 17#define inf 1000000using namespace std;int val[N];int head[N],cnt;int p[N][POW],dep[N],mi[N][POW],mx[N][POW],dp[N][POW],dp1[N][POW],f[N];struct relation{


倍增法DP

#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 50100
#define POW 17
#define inf 1000000
using namespace std;
int val[N];
int head[N],cnt;
int p[N][POW],dep[N],mi[N][POW],mx[N][POW],dp[N][POW],dp1[N][POW],f[N];
struct relation{
    int u,v,w;
}r[N];
struct Edge{
    int v,next;
}edge[N*2];
void addedge(int u,int v){
    edge[cnt].v=v;
    edge[cnt].next=head[u];
    head[u]=cnt++;
    edge[cnt].v=u;
    edge[cnt].next=head[v];
    head[v]=cnt++;
}
void init(int n){
    int i,j;
    memset(head,-1,sizeof(head));
    memset(dep,0,sizeof(dep));
    memset(p,0,sizeof(p));
    for(i=1;i<=n;i++)f[i]=i;
    for(i=1;i<=n;i++)
        for(j=0;j<POW;j++)
            mi[i][j]=inf,dp[i][j]=dp1[i][j]=mx[i][j]=0;
    cnt=0;
}
void dfs(int u,int fa){
    int i,j;
    dep[u]=dep[fa]+1;
    for(i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].v;
        if(v==fa)continue;
        p[v][0]=u;
        mi[v][0]=min(val[u],val[v]);
        mx[v][0]=max(val[u],val[v]);
        dp[v][0]=val[u]-val[v];
        dp1[v][0]=val[v]-val[u];
        for(j=1;(1<<j)<=dep[u]+1;j++){
            p[v][j]=p[p[v][j-1]][j-1];
            mi[v][j]=min(mi[v][j-1],mi[p[v][j-1]][j-1]);
            mx[v][j]=max(mx[v][j-1],mx[p[v][j-1]][j-1]);

            dp[v][j]=max(dp[v][j-1],dp[p[v][j-1]][j-1]);
            dp[v][j]=max(dp[v][j],mx[p[v][j-1]][j-1]-mi[v][j-1]);

            dp1[v][j]=max(dp1[v][j-1],dp1[p[v][j-1]][j-1]);
            dp1[v][j]=max(dp1[v][j],mx[v][j-1]-mi[p[v][j-1]][j-1]);
        }
        dfs(v,u);
    }
}
int LCA(int u,int v){
    int i;
    if(dep[u]>dep[v]) u^=v,v^=u,u^=v;
    if(dep[u]<dep[v]){
        int del=dep[v]-dep[u];
        for(i=0;i<POW;i++)
            if(del & (1<<i))
                v=p[v][i];
    }
    if(u!=v){
        for(i=POW-1;i>=0;i--)
            if(p[u][i]!=p[v][i])
                u=p[u][i],v=p[v][i];
        u=p[u][0],v=p[v][0];
    }
    return u;
}
int getmaxdp(int u,int v){
    int ans=0,i,tmp=inf; // tmp记录最小值
    int del=dep[u]-dep[v];
    for(i=POW-1;i>=0;i--)
        if(del & (1<<i)){
            ans=max(ans,dp[u][i]);
            ans=max(ans,mx[u][i]-tmp);
            tmp=min(tmp,mi[u][i]);
            u=p[u][i];
        }
    return ans;
}
int getmaxdp1(int u,int v){
    int ans=0,i,tmp=0; // tmp记录最大值
    int del=dep[u]-dep[v];
    for(i=POW-1;i>=0;i--)
        if(del & (1<<i)){
            ans=max(ans,dp1[u][i]);
            ans=max(ans,tmp-mi[u][i]);
            tmp=max(tmp,mx[u][i]);
            u=p[u][i];
        }
    return ans;
}
int getmx(int u,int v){
    int ans=0,i;
    int del=dep[u]-dep[v];
    for(i=POW-1;i>=0;i--)
        if(del & (1<<i)){
            ans=max(ans,mx[u][i]);
            u=p[u][i];
        }
    return ans;
}
int getmi(int u,int v){
    int ans=inf,i;
    int del=dep[u]-dep[v];
    for(i=POW-1;i>=0;i--)
        if(del & (1<<i)){
            ans=min(ans,mi[u][i]);
            u=p[u][i];
        }
        return ans;
}
void solve(int u,int v){
    int lca=LCA(u,v);
    int a,b,c,d;
    a=getmaxdp(u,lca);
    b=getmaxdp1(v,lca);
    c=getmi(u,lca);
    d=getmx(v,lca);
    printf("%d\n",max(max(a,b),d-c));
}
bool cmp(struct relation a,struct relation b){
    return a.w>b.w;
}
int find(int u){
    if(f[u]==u)return u;
    return f[u]=find(f[u]);
}
int main(){
    int i,u,v,n,m,q;
    while(scanf("%d",&n)==1){
        init(n);
        for(i=1;i<=n;i++)
            scanf("%d",&val[i]);
        scanf("%d",&m);
        for(i=1;i<=m;i++)
            scanf("%d %d %d",&r[i].u,&r[i].v,&r[i].w);
        sort(r+1,r+1+m,cmp);
        int w=0;
        for(i=1;i<=m;i++){
            int uu=find(r[i].u),vv=find(r[i].v);
            if(uu!=vv){
                addedge(r[i].u,r[i].v);
                w+=r[i].w;
                f[uu]=vv;
            }
        }
        printf("%d\n",w);
        dfs(1,0);
        scanf("%d",&q);
        for(i=1;i<=q;i++){
            scanf("%d %d",&u,&v);
            solve(u,v);
        }
    }
    return 0;
}






上一篇:light oj 1128
下一篇:没有了
网友评论