2026-05-04:树组的交互代价总和。用go语言,给定一个整数 n,以及一棵有 n 个节点的无向树,节点编号为 0 到 n-1。树的结构由数组 edges 表示:数组长度为 n-1,其中 edges[i] = [u, v] 表示节点 u 与节点 v 之间有一条无向边。
另给定一个数组 group,长度为 n。group[i] 表示节点 i 所属的组。
如果两个节点 u 和 v 满足 group[u] == group[v],则称它们属于同一组。由于是树结构,任意两个节点之间都存在且仅存在一条唯一路径。所谓交互代价定义为:这条唯一路径上包含的边的数量。
目标:枚举所有无序且不同的节点对 (u, v),要求它们同组(group[u] == group[v])。把这些节点对的交互代价全部累加并返回总和;若不存在满足条件的节点对,则返回 0。
1 <= n <= 100000。
edges.length == n - 1。
edges[i] = [ui, vi]。
0 <= ui, vi <= n - 1。
group.length == n。
1 <= group[i] <= 20。
输入保证 edges 表示一棵有效的树。
输入: n = 3, edges = [[0,1],[1,2]], group = [3,2,3]。
输出: 2。
解释:
节点 0 和节点 2 属于组 3,它们之间的交互代价为 2。
节点 1 属于不同的组,因此没有有效的节点对。
总交互代价为 2。
题目来自力扣3786。
代码执行全流程详细拆解第一步:构建原始树的邻接表
1. 初始化一个长度为
n的二维数组,作为无向树的邻接表;2. 遍历所有边,把每条边的两个节点互相添加到对方的邻接列表中;
3. 作用:让程序可以快速访问每个节点的所有相邻节点,为后续树的遍历做准备。
这一步是为了快速求任意两点的最近公共祖先(LCA)和两点间路径长度,是树问题的核心预处理。
子步骤1:DFS遍历树,记录核心信息
1. 定义递归函数,从根节点0开始遍历整棵树;
2. 记录每个节点的DFS时间戳(dfn):用于后续给节点排序,是构建虚树的关键;
3. 记录每个节点的父节点(第一层祖先);
4. 记录每个节点的深度(dep):根节点深度为0,子节点深度=父节点+1;
5. 深度的作用:两点间路径边数 =
深度[u] + 深度[v] - 2×深度[LCA(u,v)]。
1. 计算树的最大深度对应的二进制位数,确定倍增的层数;
2. 预处理每个节点的
2^i级祖先(2级、4级、8级...祖先);3. 作用:实现O(logn)时间查询任意两个节点的最近公共祖先(LCA)。
1. 封装两个工具函数:
• 把节点向上提升到指定深度;
• 求任意两个节点的最近公共祖先;
2. 这是后续计算路径长度、构建虚树的基础工具。
第三步:按分组归类所有节点
1. 创建一个哈希表(字典),key=组号,value=该组所有节点的列表;
2. 遍历所有节点,把每个节点按照
group数组的值,放入对应组的列表中;3. 作用:后续只需要逐组计算,组号最多20个,极大减少计算量。
因为组号只有20个,我们逐个组处理,每组独立计算贡献:
虚树作用:只保留当前组的节点 + 这些节点之间路径的必要公共祖先,剔除无关节点,把大树压缩成小树,大幅降低计算量。单组构建虚树的完整步骤
1.按DFS时间戳排序:把当前组的所有节点,按照第一步记录的
dfn从小到大排序;2.初始化栈和虚树:用根节点作为栈的初始元素,清空虚树结构;
3.标记关键节点:把当前组的节点标记为「真实关键节点」;
4.栈+LCA构建虚树:
• 遍历排序后的每个节点;
• 计算栈顶节点与当前节点的LCA(路径拐点);
• 不断回溯栈,给虚树添加边,直到栈顶深度小于LCA深度;
• 如果LCA不在栈中,将其加入栈和虚树;
• 最后把当前节点入栈;
5.收尾加边:遍历结束后,把栈中剩余节点依次连边,完成虚树构建。
第五步:在虚树上DFS计算本组的交互代价(贡献法)
这是计算答案的核心步骤,使用贡献法:不枚举所有节点对(会超时),而是计算每条边被多少对节点经过,总代价 = 边数 × 经过的节点对数。
计算步骤
1. 定义递归DFS函数,遍历当前组的虚树;
2. 递归统计每个子树中当前组的节点数量;
3. 对于虚树上的每一条边:
• 边的实际长度 = 子节点深度 - 父节点深度(对应原始树的边数);
• 设子树内有
sz个本组节点,本组总节点数为total;• 这条边会被
sz × (total - sz)对节点经过;• 本组总代价 += 边长度 ×
sz × (total - sz);
4. 把本组的代价累加到全局答案中;
5. 一组计算完成后,重置虚树,开始处理下一组。
第六步:所有组计算完成,返回最终答案
1. 遍历完所有组(最多20组);
2. 全局累加的结果就是所有同组节点对的交互代价总和;
3. 示例中仅组3贡献了2,最终输出2。
O(n × logn + G × k × logk)
拆解说明:
1.预处理LCA:DFS遍历树是
O(n),倍增预处理是O(n × logn);2.分组归类:
O(n);3.构建虚树+计算贡献:
• 组数量
G ≤ 20(题目限定);• 每组节点数
k,排序O(k logk),构建虚树+DFSO(k);• 所有组总耗时
O(n logn);
4. 整体主导项:O(n × logn),完全满足n=1e5的时间要求。
二、总额外空间复杂度
O(n × logn)
拆解说明:
1. 邻接表:
O(n);2. 倍增祖先数组:
O(n × 17)(log₂(1e5)≈17),是核心空间开销;3. DFN、深度数组、虚树、栈、哈希表:均为
O(n);4. 整体空间复杂度由倍增数组主导:
O(n × logn)。
1. 整体流程:建原始树 → 预处理LCA → 按组分节点 → 每组建虚树压缩 → 贡献法算代价 → 累加答案;
2. 核心优化:利用
group[i]≤20的限定,逐组处理+虚树压缩,避免暴力枚举节点对;3. 时间复杂度:O(n logn),高效处理1e5节点;
4. 空间复杂度:O(n logn),符合算法题常规空间要求。
package main
import (
"fmt"
"math/bits"
"slices"
)
func interactionCosts(n int, edges [][]int, group []int) (ans int64) {
g := make([][]int, n)
for _, e := range edges {
v, w := e[0], e[1]
g[v] = append(g[v], w)
g[w] = append(g[w], v)
}
dfn := make([]int, n)
ts := 0
pa := make([][17]int, n)
dep := make([]int, n)
var build func(int, int)
build = func(v, p int) {
dfn[v] = ts
ts++
pa[v][0] = p
for _, w := range g[v] {
if w != p {
dep[w] = dep[v] + 1
build(w, v)
}
}
}
build(0, -1)
mx := bits.Len(uint(n))
for i := range mx - 1 {
for v := range pa {
p := pa[v][i]
if p != -1 {
pa[v][i+1] = pa[p][i]
} else {
pa[v][i+1] = -1
}
}
}
uptoDep := func(v, d int)int {
for k := uint32(dep[v] - d); k > 0; k &= k - 1 {
v = pa[v][bits.TrailingZeros32(k)]
}
return v
}
getLCA := func(v, w int)int {
if dep[v] > dep[w] {
v, w = w, v
}
w = uptoDep(w, dep[v])
if w == v {
return v
}
for i := mx - 1; i >= 0; i-- {
pv, pw := pa[v][i], pa[w][i]
if pv != pw {
v, w = pv, pw
}
}
return pa[v][0]
}
nodesMap := map[int][]int{}
for i, x := range group {
nodesMap[x] = append(nodesMap[x], i)
}
vt := make([][]int, n) // 虚树
isNode := make([]int, n) // 用来区分是关键节点还是 LCA
for i := range isNode {
isNode[i] = -1
}
addVtEdge := func(v, w int) {
vt[v] = append(vt[v], w) // 往虚树上添加一条有向边
}
const root = 0
st := []int{root} // 用根节点作为栈底哨兵
for val, nodes := range nodesMap {
// 对于相同点权的这一组关键节点 nodes,构建虚树
slices.SortFunc(nodes, func(a, b int)int { return dfn[a] - dfn[b] })
vt[root] = vt[root][:0] // 重置虚树
st = st[:1]
for _, v := range nodes {
isNode[v] = val
if v == root {
continue
}
vt[v] = vt[v][:0]
lca := getLCA(st[len(st)-1], v) // 路径的拐点(LCA)也加到虚树中
// 回溯,加边
forlen(st) > 1 && dfn[lca] <= dfn[st[len(st)-2]] {
addVtEdge(st[len(st)-2], st[len(st)-1])
st = st[:len(st)-1]
}
if lca != st[len(st)-1] { // lca 不在栈中(首次遇到)
vt[lca] = vt[lca][:0]
addVtEdge(lca, st[len(st)-1])
st[len(st)-1] = lca // 加到栈中
}
st = append(st, v)
}
// 最后的回溯,加边
for i := 1; i < len(st); i++ {
addVtEdge(st[i-1], st[i])
}
var dfs func(int)int
dfs = func(v int) (size int) {
// 如果 isNode[v] != t,那么 v 只是关键节点之间路径上的「拐点」
if isNode[v] == val {
size = 1
}
for _, w := range vt[v] {
sz := dfs(w)
wt := dep[w] - dep[v] // 虚树边权
// 贡献法
ans += int64(wt) * int64(sz) * int64(len(nodes)-sz)
size += sz
}
return
}
rt := root
if isNode[rt] != val && len(vt[rt]) == 1 {
// 注意 root 只是一个哨兵,不一定在虚树上,得从真正的根节点开始
rt = vt[rt][0]
}
dfs(rt)
}
return
}func main() {
n := 3
edges := [][]int{{0, 1}, {1, 2}}
group := []int{3, 2, 3}
result := interactionCosts(n, edges, group)
fmt.Println(result)
}
Python完整代码如下:
# -*-coding:utf-8-*-
import sys
sys.setrecursionlimit(10**6)
def interactionCosts(n, edges, group):
ans = 0
g = [[] for _ in range(n)]
for v, w in edges:
g[v].append(w)
g[w].append(v)
dfn = [0] * n
ts = 0
pa = [[-1] * 17for _ in range(n)]
dep = [0] * n
def build(v, p):
nonlocal ts
dfn[v] = ts
ts += 1
pa[v][0] = p
for w in g[v]:
if w != p:
dep[w] = dep[v] + 1
build(w, v)
build(0, -1)
mx = n.bit_length()
for i in range(mx - 1):
for v in range(n):
p = pa[v][i]
if p != -1:
pa[v][i+1] = pa[p][i]
else:
pa[v][i+1] = -1
def uptoDep(v, d):
k = dep[v] - d
while k:
step = (k & -k).bit_length() - 1
v = pa[v][step]
k -= (1 << step)
return v
def getLCA(v, w):
if dep[v] > dep[w]:
v, w = w, v
w = uptoDep(w, dep[v])
if w == v:
return v
for i in range(mx-1, -1, -1):
pv, pw = pa[v][i], pa[w][i]
if pv != pw:
v, w = pv, pw
return pa[v][0]
nodesMap = {}
for i, x in enumerate(group):
nodesMap.setdefault(x, []).append(i)
vt = [[] for _ in range(n)]
isNode = [-1] * n
def addVtEdge(v, w):
vt[v].append(w)
root = 0
st = [root]
for val, nodes in nodesMap.items():
nodes.sort(key=lambda x: dfn[x])
vt[root] = []
st = [root]
for v in nodes:
isNode[v] = val
if v == root:
continue
vt[v] = []
lca = getLCA(st[-1], v)
while len(st) > 1 and dfn[lca] <= dfn[st[-2]]:
addVtEdge(st[-2], st[-1])
st.pop()
if lca != st[-1]:
vt[lca] = []
addVtEdge(lca, st[-1])
st[-1] = lca
st.append(v)
for i in range(1, len(st)):
addVtEdge(st[i-1], st[i])
sys.setrecursionlimit(10**6)
def dfs(v):
nonlocal ans
size = 1if isNode[v] == val else0
for w in vt[v]:
sz = dfs(w)
wt = dep[w] - dep[v]
ans += wt * sz * (len(nodes) - sz)
size += sz
return size
rt = root
if isNode[rt] != val and len(vt[rt]) == 1:
rt = vt[rt][0]
dfs(rt)
return ansif __name__ == "__main__":
n = 3
edges = [[0, 1], [1, 2]]
group = [3, 2, 3]
result = interactionCosts(n, edges, group)
print(result)
C++完整代码如下:
using namespace std;class Solution {
public:
long long interactionCosts(int n, vector int >>& edges, vector< int >& group) {
long long ans = 0 ;
// 构建邻接表
vector int >> g(n);
for (auto& e : edges) {
int v = e[ 0 ], w = e[ 1 ];
g[v].push_back(w);
g[w].push_back(v);
}
// 预处理 DFS 序、深度和倍增祖先
vector< int > dfn(n, 0 );
int ts = 0 ;
vector int , 17 >> pa(n);
vector< int > dep(n, 0 );
function int , int )> build = [&]( int v, int p) {
dfn[v] = ts++;
pa[v][ 0 ] = p;
for ( int w : g[v]) {
if (w != p) {
dep[w] = dep[v] + 1 ;
build(w, v);
}
}
};
build( 0 , -1 );
int mx = 32 - __builtin_clz(n); // bits.Len(uint(n))
for ( int i = 0 ; i < mx - 1 ; i++) {
for ( int v = 0 ; v < n; v++) {
int p = pa[v][i];
if (p != -1 ) {
pa[v][i + 1 ] = pa[p][i];
} else {
pa[v][i + 1 ] = -1 ;
}
}
}
// 跳到指定深度
auto uptoDep = [&]( int v, int d) -> int {
int k = dep[v] - d;
while (k > 0 ) {
int step = __builtin_ctz(k);
v = pa[v][step];
k &= k - 1 ;
}
return v;
};
// 获取 LCA
auto getLCA = [&]( int v, int w) -> int {
if (dep[v] > dep[w]) {
swap(v, w);
}
w = uptoDep(w, dep[v]);
if (w == v) return v;
for ( int i = mx - 1 ; i >= 0 ; i--) {
int pv = pa[v][i], pw = pa[w][i];
if (pv != pw) {
v = pv;
w = pw;
}
}
return pa[v][ 0 ];
};
// 按点权分组节点
map < int , vector< int >> nodesMap;
for ( int i = 0 ; i < n; i++) {
nodesMap[group[i]].push_back(i);
}
// 虚树
vector int >> vt(n);
vector< int > isNode(n, -1 );
auto addVtEdge = [&]( int v, int w) {
vt[v].push_back(w);
};
const int root = 0 ;
vector< int > st;
// 处理每个点权组
for (auto& [val, nodes] : nodesMap) {
// 按 DFS 序排序
sort(nodes.begin(), nodes.end(), [&]( int a, int b) {
return dfn[a] < dfn[b];
});
// 清空虚树
for ( int v : nodes) {
vt[v].clear();
}
vt[root].clear();
st.clear();
st.push_back(root);
// 构建虚树
for ( int v : nodes) {
isNode[v] = val;
if (v == root) continue ;
vt[v].clear();
int lca = getLCA(st.back(), v);
// 回溯并加边
while (st.size() > 1 && dfn[lca] <= dfn[st[st.size() - 2 ]]) {
addVtEdge(st[st.size() - 2 ], st.back());
st.pop_back();
}
if (lca != st.back()) {
vt[lca].clear();
addVtEdge(lca, st.back());
st.back() = lca;
}
st.push_back(v);
}
// 添加剩余边
for ( int i = 1 ; i < st.size(); i++) {
addVtEdge(st[i - 1 ], st[i]);
}
// DFS 遍历虚树计算贡献
function< int ( int )> dfs = [&]( int v) -> int {
int size = (isNode[v] == val) ? 1 : 0 ;
for ( int w : vt[v]) {
int sz = dfs(w);
int wt = dep[w] - dep[v];
ans += 1 LL * wt * sz * (nodes.size() - sz);
size += sz;
}
return size;
};
// 找到真正的根节点
int rt = root;
if (isNode[rt] != val && vt[rt].size() == 1 ) {
rt = vt[rt][ 0 ];
}
dfs(rt);
}
return ans;
}
};
int main() {
int n = 3 ;
vector int >> edges = {{ 0 , 1 }, { 1 , 2 }};
vector< int > group = { 3 , 2 , 3 };
Solution solution;
long long result = solution.interactionCosts(n, edges, group);
cout << result << endl;
return 0 ;
}
我们相信人工智能为普通人提供了一种“增强工具”,并致力于分享全方位的AI知识。在这里,您可以找到最新的AI科普文章、工具评测、提升效率的秘籍以及行业洞察。 欢迎关注“福大大架构师每日一题”,发消息可获得面试资料,让AI助力您的未来发展。
特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.