Skip to content
Snippets Groups Projects
Select Git revision
  • cdfa9ef265b54f5fd45da60ea854897c6cf5ec9a
  • main default protected
2 results

demo_final.md

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    steiner_tree.py 3.60 KiB
    import graph_tool as gt
    import graph_tool.topology as gtt
    import itertools as it
    
    
    def steiner_tree(g, seeds, seed_map, weights, non_zero_hub_penalty):
    
        node_name_attribute = "drugstone_id" # nodes in the input network which is created from RepoTrialDB have primaryDomainId as name attribute
        mc = gt.Graph(directed=False)
        eprop_dist = mc.new_edge_property("int")
        mc.ep['dist'] = eprop_dist
        vprop_name = mc.new_vertex_property("string")
        mc.vp[node_name_attribute] = vprop_name
    
        eprop_path = mc.new_edge_property("object")
        mc.ep['path'] = eprop_path
    
        mc_vertex_map = dict()
        mc_id_map = dict()
        for i in range(len(seeds)):
            vert = mc.add_vertex()
            vprop_name[i] = seeds[i]
            mc_vertex_map[seeds[i]] = vert
            mc_id_map[vert] = i
    
        for u, v in it.combinations(seeds, 2):
            _, elist = gtt.shortest_path(g, g.vertex(seed_map[u]), g.vertex(seed_map[v]), weights=weights, negative_weights=False, pred_map=None, dag=False)
            e = mc.add_edge(mc_vertex_map[u], mc_vertex_map[v])
            eprop_dist[e] = len(elist)
            mc.ep.path[e] = list(elist)
    
        mst = gtt.min_spanning_tree(mc, weights=eprop_dist, root=None, tree_map=None)
        mc.set_edge_filter(mst)
    
        g2 = gt.Graph(directed=False)
        vprop_name = g2.new_vertex_property("string")
        g2.vp[node_name_attribute] = vprop_name
    
        g2_vertex_map = dict()
        g2_id_map = dict()
        addedNodes = set()
        for i in range(len(seeds)):
            vert = g2.add_vertex()
            vprop_name[i] = seeds[i]
            g2_vertex_map[seeds[i]] = vert
            g2_id_map[vert] = i
            addedNodes.add(seeds[i])
    
        allmcedges = []
    
        for mc_edges in mc.edges():
            path = mc.ep.path[mc_edges]
            allmcedges.extend(path)
    
        j = len(seeds)
        allmcedges_g2 = []
        for e in allmcedges:
            # sourceName = g.vertex_properties["name"][e.source()]
            # targetName = g.vertex_properties["name"][e.target()]
            sourceName = g.vertex_properties[node_name_attribute][e.source()]
            targetName = g.vertex_properties[node_name_attribute][e.target()]
            if sourceName not in addedNodes:
                vert = g2.add_vertex()
                vprop_name[j] = sourceName
                g2_vertex_map[sourceName] = vert
                g2_id_map[vert] = j
                addedNodes.add(sourceName)
                j += 1
            if targetName not in addedNodes:
                vert = g2.add_vertex()
                vprop_name[j] = targetName
                g2_vertex_map[targetName] = vert
                g2_id_map[vert] = j
                addedNodes.add(targetName)
                j += 1
            allmcedges_g2.append(g2.add_edge(g2_vertex_map[sourceName], g2_vertex_map[targetName]))
        weights_g2 = g2.new_edge_property("double", val=1.0)
        if non_zero_hub_penalty:
            for e, e_g2 in zip(allmcedges, allmcedges_g2):
                weights_g2[e_g2] = weights[e]
        mst2 = gtt.min_spanning_tree(g2, root=None, tree_map=None, weights=weights_g2)
        g2.set_edge_filter(mst2)
        # vw = gt.GraphView(g2, efilt=mst2)
        # g3 = Graph(vw, prune=True)
    
        # g3 = Graph(g22)
    
        while True:
            noneSteinerLeaves = []
            for i in range(g2.num_vertices()):
                if g2.vertex(i).out_degree() == 1 and g2.vertex_properties[node_name_attribute][i] not in seeds:
                    noneSteinerLeaves.append(i)
            if len(noneSteinerLeaves) == 0:
                break
            noneSteinerLeaves = reversed(sorted(noneSteinerLeaves))
            for node in noneSteinerLeaves:
                # outarray = g3.get_out_edges(node)
                g2.remove_edge(g2.edge(g2.vertex(node), g2.get_all_neighbors(node)[0]))
                g2.remove_vertex(node)
                
        return g2