Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
steiner_tree.py 3.60 KiB
import graph_tool as gt
import graph_tool.topology as gtt
import itertools as it


def steiner_tree(g, seeds, seed_map, weights, non_zero_hub_penalty):

    node_name_attribute = "internal_id" # nodes in the input network which is created from RepoTrialDB have primaryDomainId as name attribute
    mc = gt.Graph(directed=False)
    eprop_dist = mc.new_edge_property("int")
    mc.ep['dist'] = eprop_dist
    vprop_name = mc.new_vertex_property("string")
    mc.vp[node_name_attribute] = vprop_name

    eprop_path = mc.new_edge_property("object")
    mc.ep['path'] = eprop_path

    mc_vertex_map = dict()
    mc_id_map = dict()
    for i in range(len(seeds)):
        vert = mc.add_vertex()
        vprop_name[i] = seeds[i]
        mc_vertex_map[seeds[i]] = vert
        mc_id_map[vert] = i

    for u, v in it.combinations(seeds, 2):
        _, elist = gtt.shortest_path(g, g.vertex(seed_map[u]), g.vertex(seed_map[v]), weights=weights, negative_weights=False, pred_map=None, dag=False)
        e = mc.add_edge(mc_vertex_map[u], mc_vertex_map[v])
        eprop_dist[e] = len(elist)
        mc.ep.path[e] = list(elist)

    mst = gtt.min_spanning_tree(mc, weights=eprop_dist, root=None, tree_map=None)
    mc.set_edge_filter(mst)

    g2 = gt.Graph(directed=False)
    vprop_name = g2.new_vertex_property("string")
    g2.vp[node_name_attribute] = vprop_name

    g2_vertex_map = dict()
    g2_id_map = dict()
    addedNodes = set()
    for i in range(len(seeds)):
        vert = g2.add_vertex()
        vprop_name[i] = seeds[i]
        g2_vertex_map[seeds[i]] = vert
        g2_id_map[vert] = i
        addedNodes.add(seeds[i])

    allmcedges = []

    for mc_edges in mc.edges():
        path = mc.ep.path[mc_edges]
        allmcedges.extend(path)

    j = len(seeds)
    allmcedges_g2 = []
    for e in allmcedges:
        # sourceName = g.vertex_properties["name"][e.source()]
        # targetName = g.vertex_properties["name"][e.target()]
        sourceName = g.vertex_properties[node_name_attribute][e.source()]
        targetName = g.vertex_properties[node_name_attribute][e.target()]
        if sourceName not in addedNodes:
            vert = g2.add_vertex()
            vprop_name[j] = sourceName
            g2_vertex_map[sourceName] = vert
            g2_id_map[vert] = j
            addedNodes.add(sourceName)
            j += 1
        if targetName not in addedNodes:
            vert = g2.add_vertex()
            vprop_name[j] = targetName
            g2_vertex_map[targetName] = vert
            g2_id_map[vert] = j
            addedNodes.add(targetName)
            j += 1
        allmcedges_g2.append(g2.add_edge(g2_vertex_map[sourceName], g2_vertex_map[targetName]))
    weights_g2 = g2.new_edge_property("double", val=1.0)
    if non_zero_hub_penalty:
        for e, e_g2 in zip(allmcedges, allmcedges_g2):
            weights_g2[e_g2] = weights[e]
    mst2 = gtt.min_spanning_tree(g2, root=None, tree_map=None, weights=weights_g2)
    g2.set_edge_filter(mst2)
    # vw = gt.GraphView(g2, efilt=mst2)
    # g3 = Graph(vw, prune=True)

    # g3 = Graph(g22)

    while True:
        noneSteinerLeaves = []
        for i in range(g2.num_vertices()):
            if g2.vertex(i).out_degree() == 1 and g2.vertex_properties[node_name_attribute][i] not in seeds:
                noneSteinerLeaves.append(i)
        if len(noneSteinerLeaves) == 0:
            break
        noneSteinerLeaves = reversed(sorted(noneSteinerLeaves))
        for node in noneSteinerLeaves:
            # outarray = g3.get_out_edges(node)
            g2.remove_edge(g2.edge(g2.vertex(node), g2.get_all_neighbors(node)[0]))
            g2.remove_vertex(node)
            
    return g2