思路十分简单,答案只有 3 种可能,但是有一些细节需要额外注意一下. code: #include bits/stdc++.h#define N 300002 #define setIO(s) freopen(s".in","r",stdin) using namespace std; int val[N],hd[N],to[N1],nex[N1],d1[N
思路十分简单,答案只有 3 种可能,但是有一些细节需要额外注意一下.
code:
#include <bits/stdc++.h> #define N 300002 #define setIO(s) freopen(s".in","r",stdin) using namespace std; int val[N],hd[N],to[N<<1],nex[N<<1],d1[N],d2[N],n,edges,maxx,mx,m2,cnt,uu; void add(int u,int v) { nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; } void dfs(int u,int ff) { if(val[u]==mx) d1[u]=0, uu=u; for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(v==ff) continue; dfs(v,u); if(d1[v]+1>d1[u]) { d2[u]=d1[u],d1[u]=d1[v]+1; } else if(d1[v]+1>d2[u]) d2[u]=d1[v]+1; } maxx=max(d1[u]+d2[u], maxx); } int main() { int i,j; // setIO("input"); mx=-1000300000; m2=mx; scanf("%d",&n); for(i=1;i<=n;++i) { scanf("%d",&val[i]),mx=max(mx,val[i]); } for(i=1;i<=n;++i) if(val[i]<mx) m2=max(m2, val[i]); for(i=1;i<=n;++i) if(val[i]==m2) ++cnt; for(i=1;i<n;++i) { int u,v; scanf("%d%d",&u,&v),add(u,v),add(v,u); } memset(d1,-0x3f,sizeof(d1)); memset(d2,-0x3f,sizeof(d2)); dfs(1,0); if(maxx==0) { if(m2!=mx-1) printf("%d\n",mx); else { for(int i=hd[uu];i;i=nex[i]) { int v=to[i]; if(val[v]==m2) --cnt; } if(cnt) printf("%d\n",mx+1); else printf("%d\n",mx); } } else if(maxx<=2) printf("%d\n",mx+1); else printf("%d\n",mx+2); return 0; }