-
AndiMajore authoredAndiMajore authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
steiner_tree.py 3.60 KiB
import graph_tool as gt
import graph_tool.topology as gtt
import itertools as it
def steiner_tree(g, seeds, seed_map, weights, non_zero_hub_penalty):
    """Build an approximate Steiner tree that connects ``seeds`` inside ``g``.

    The construction proceeds in five steps:

    1. Build a helper graph ``mc`` with one vertex per seed and one edge per
       pair of seeds that are connected in ``g``; every edge stores the
       shortest path (list of edges of ``g``) between the two seeds.
    2. Reduce ``mc`` to a minimum spanning tree/forest.
    3. Expand the surviving ``mc`` edges back into their stored paths and
       collect them in a new graph ``g2``.
    4. Reduce ``g2`` to a minimum spanning tree/forest.
    5. Repeatedly remove leaves of ``g2`` that are not seeds.

    Args:
        g: graph_tool graph to search in. It must provide a vertex property
            named ``"internal_id"`` holding the node names.
        seeds: Indexable sequence of seed node names.
        seed_map: Mapping from seed name to the value accepted by
            ``g.vertex(...)`` for that seed.
        weights: Edge weights of ``g``; handed to ``shortest_path`` and, if
            ``non_zero_hub_penalty`` is truthy, looked up per edge of ``g``.
        non_zero_hub_penalty: If truthy, the edges of ``g2`` inherit their
            weight from ``weights``; otherwise every edge weighs ``1.0``.

    Returns:
        The graph ``g2`` with the vertex property ``"internal_id"`` and an
        active edge filter selecting the spanning-tree edges. Edges that were
        filtered out are hidden, not deleted.
    """
    # Nodes in the input network, which is created from RepoTrialDB, have
    # primaryDomainId as name attribute.
    node_name_attribute = "internal_id"

    # --- Step 1: complete "distance" graph over the seeds -----------------
    mc = gt.Graph(directed=False)
    eprop_dist = mc.new_edge_property("int")
    mc.ep['dist'] = eprop_dist
    mc_names = mc.new_vertex_property("string")
    mc.vp[node_name_attribute] = mc_names
    eprop_path = mc.new_edge_property("object")
    mc.ep['path'] = eprop_path

    mc_vertex_map = {}
    for seed in seeds:
        vert = mc.add_vertex()
        mc_names[vert] = seed
        mc_vertex_map[seed] = vert

    for u, v in it.combinations(seeds, 2):
        _, elist = gtt.shortest_path(
            g,
            g.vertex(seed_map[u]),
            g.vertex(seed_map[v]),
            weights=weights,
            negative_weights=False,
            pred_map=None,
            dag=False,
        )
        if len(elist) == 0:
            # No path between the two seeds. Adding an edge here would give
            # it distance 0, and the spanning tree below would then prefer
            # these fake zero-cost edges over real paths, splitting seeds
            # that are actually connected in ``g``.
            continue
        e = mc.add_edge(mc_vertex_map[u], mc_vertex_map[v])
        # NOTE(review): the distance is the hop count of the path, not the
        # sum of ``weights`` along it -- confirm this is intended when hub
        # penalties are in use.
        eprop_dist[e] = len(elist)
        eprop_path[e] = list(elist)

    # --- Step 2: spanning tree (forest, if seeds are disconnected) --------
    mst = gtt.min_spanning_tree(mc, weights=eprop_dist, root=None, tree_map=None)
    mc.set_edge_filter(mst)

    # --- Step 3: expand the kept seed-to-seed edges into real paths -------
    g2 = gt.Graph(directed=False)
    g2_names = g2.new_vertex_property("string")
    g2.vp[node_name_attribute] = g2_names

    g2_vertex_map = {}
    for seed in seeds:
        vert = g2.add_vertex()
        g2_names[vert] = seed
        g2_vertex_map[seed] = vert

    # Edges of ``g`` lying on any of the selected shortest paths. The same
    # edge may appear more than once when paths overlap.
    path_edges = []
    for mc_edge in mc.edges():
        path_edges.extend(eprop_path[mc_edge])

    g_names = g.vertex_properties[node_name_attribute]
    path_edges_g2 = []
    for e in path_edges:
        endpoints = []
        for name in (g_names[e.source()], g_names[e.target()]):
            if name not in g2_vertex_map:
                # First time this intermediate node shows up: mirror it.
                vert = g2.add_vertex()
                g2_names[vert] = name
                g2_vertex_map[name] = vert
            endpoints.append(g2_vertex_map[name])
        path_edges_g2.append(g2.add_edge(endpoints[0], endpoints[1]))

    # --- Step 4: spanning tree of the expanded graph ----------------------
    weights_g2 = g2.new_edge_property("double", val=1.0)
    if non_zero_hub_penalty:
        for e, e_g2 in zip(path_edges, path_edges_g2):
            weights_g2[e_g2] = weights[e]
    mst2 = gtt.min_spanning_tree(g2, root=None, tree_map=None, weights=weights_g2)
    g2.set_edge_filter(mst2)

    # --- Step 5: strip leaves that are not seeds --------------------------
    seed_set = set(seeds)  # constant-time membership instead of a list scan
    while True:
        names = g2.vertex_properties[node_name_attribute]
        non_seed_leaves = [
            i
            for i in range(g2.num_vertices())
            if g2.vertex(i).out_degree() == 1 and names[i] not in seed_set
        ]
        if not non_seed_leaves:
            break
        # Removing a vertex shifts all higher indices down, so delete from
        # the highest index to the lowest to keep the pending ones valid.
        for node in sorted(non_seed_leaves, reverse=True):
            g2.remove_edge(g2.edge(g2.vertex(node), g2.get_all_neighbors(node)[0]))
            g2.remove_vertex(node)

    return g2