Skip to content
Snippets Groups Projects
multi_steiner.py 11.7 KiB
Newer Older
from tasks.task_hook import TaskHook
from tasks.util.steiner_tree import steiner_tree
from tasks.util.find_bridges import find_bridges
from tasks.util.read_graph_tool_graph import read_graph_tool_graph
from tasks.util.edge_weights import edge_weights
import os.path
import graph_tool as gt
import graph_tool.util as gtu
import sys


def multi_steiner(task_hook: TaskHook):

    # Type: List of str
    # Semantics: Names of the seed proteins. Use UNIPROT AC for host proteins, and
    #            names of the format SARS_CoV2_<IDENTIFIER> (tree_edge.g., SARS_CoV2_ORF6) for
    #            virus proteins.
    # Reasonable default: None, has to be selected by user via "select for analysis"
    #            utility in frontend.
    # Acceptable values: UNIPROT ACs, identifiers of viral proteins.
    seeds = task_hook.parameters["seeds"]
    # Since we get the LCC of the graphs from read_graph_tool_graph, no need for the following two lines
    # nodes_not_in_lcc = {'D3W0D1', 'O15178', 'O60542', 'O60609', 'O60882', 'Q5T4W7', 'Q6UVW9', 'Q6UW32', 'Q6UWQ7', 'Q6UXB1', 'Q7RTX0', 'Q7RTX1', 'Q8TE23', 'Q9GZZ7', 'Q9H2W2', 'Q9H665', 'Q9NZW4'}
    # seeds = [node for node in set(seeds).difference(nodes_not_in_lcc)]
    seeds.sort()

    # Type: str.
    # Semantics: The virus strain for which the analysis should be run.
    # Example: "SARS_CoV2"
    # Reasonable default: None, has to be specified by the caller.
    ### Acceptable values: "SARS_CoV2", ...
    # strain_or_drugs = task_hook.parameters.get("strain_or_drugs", "SARS_CoV2")
    # Acceptable values: "PPI", "PPDr"
    # target_or_drugs = task_hook.parameters.get("target_or_drugs", "PPI")

    # Type: list of str.
    # Semantics: The datasets which should be considered for the analysis.
    # Example: ["Krogan", "TUM"].
    # Note: If empty, all available datasets are used.
    # Reasonable default: [].
    # Acceptable values: "Krogan", "TUM".
    # datasets = task_hook.parameters.get("datasets", [])

    # Type: list of str.
    # Semantics: Virus-host g_edge types which should be ignored for the analysis.
    # Example: ["Overexpression"].
    # Note: If empty, all available g_edge types are used.
    # Reasonable default: [].
    # Acceptable values: "AP-MS", "overexpression".
    # ignored_edge_types = task_hook.parameters.get("ignored_edge_types", [])
    
    # Type bool.
    # Semantics: Ignore viral proteins which are not selected as seeds.
    # Example: False.
    # Reasonable default: False.
    # Has no effect when the algorithm is used for ranking drugs.
    # ignore_non_seed_baits = task_hook.parameters.get("ignore_non_seed_baits", False)

    # Type: int.
    # Semantics: Number of steiner trees to return.
    # Example: 10.
    # Reasonable default: 5.
    # Acceptable values: integers n with 0 < n < 25.
    num_trees = task_hook.parameters.get("num_trees", 5)
    
    # Type: float.
    # Semantics: The error tolerance of the subsequent Steiner trees w.r.t. the first one in percent.
    # Example: 10
    # Reasonable default: 10.
    # Acceptable values: floats x >= 0. 
    tolerance = task_hook.parameters.get("tolerance", 10)
    
    # Type: int.
    # Semantics: All nodes with degree > max_deg * g.num_vertices() are ignored.
    # Example: 39.
    # Reasonable default: sys.maxsize.
    # Acceptable values: Positive integers.
    max_deg = task_hook.parameters.get("max_deg", sys.maxsize)
    
    # Type: float.
    # Semantics: Penalty parameter for hubs. Set edge weight to 1 + (hub_penalty / 2) (e.source.degree + e.target.degree)
    # Example: 0.5.
    # Reasonable default: 0.
    # Acceptable values: Floats between 0 and 1.
    hub_penalty = task_hook.parameters.get("hub_penalty", 0.0)
    
    # Type: int.
    # Semantics: Number of threads used for running the analysis.
    # Example: 1.
    # Reasonable default: 1.
    # Note: We probably do not want to expose this parameter to the user.
    num_threads = task_hook.parameters.get("num_threads", 1)

    ppi_dataset = task_hook.parameters.get("ppi_dataset")

    pdi_dataset = task_hook.parameters.get("pdi_dataset")

    search_target = task_hook.parameters.get("target", "drug-target")

    node_name_attribute = "drugstone_id" # nodes in the input network which is created from RepoTrialDB have primaryDomainId as name attribute

    # Set number of threads if OpenMP support is enabled.
    if gt.openmp_enabled():
        gt.openmp_set_num_threads(num_threads)
    
    # Parsing input file.
    task_hook.set_progress(0 / (float(num_trees + 3)), "Parsing input.")
    file_path = os.path.join(task_hook.data_directory, f"internal_{ppi_dataset['name']}_{pdi_dataset['name']}.gt")
    # g, seed_ids, _, _ = read_graph_tool_graph(file_path, seeds, datasets, ignored_edge_types, max_deg, ignore_non_seed_baits)
    g, seed_ids, _ = read_graph_tool_graph(file_path, seeds, max_deg, target=search_target)
    # seed_map = {g.vertex_properties["name"][node]: node for node in seed_ids}
    seed_map = {g.vertex_properties[node_name_attribute][node]: node for node in seed_ids}
    task_hook.set_progress(1 / (float(num_trees + 3)), "Computing edge weights.")
    weights = edge_weights(g, hub_penalty)
    
    # Find first steiner trees
    seeds = list(filter(lambda s: s in seed_map, seeds))
    print(seeds)
    print(seed_ids)
    task_hook.set_progress(2 / (float(num_trees + 3)), "Computing Steiner tree 1 of {}.".format(num_trees))
    first_tree = steiner_tree(g, seeds, seed_map, weights, hub_penalty > 0)
    num_found_trees = 1
    tree_edges = []
    for tree_edge in first_tree.edges():
        # source_name = first_tree.vertex_properties["name"][first_tree.vertex_index[tree_edge.source()]]
        # target_name = first_tree.vertex_properties["name"][first_tree.vertex_index[tree_edge.target()]]
        # tree_edges.append((gtu.find_vertex(g, prop=g.vertex_properties['name'], match=source_name)[0],gtu.find_vertex(g, prop=g.vertex_properties['name'], match=target_name)[0]))
        source_name = first_tree.vertex_properties[node_name_attribute][first_tree.vertex_index[tree_edge.source()]]
        target_name = first_tree.vertex_properties[node_name_attribute][first_tree.vertex_index[tree_edge.target()]]
        tree_edges.append((gtu.find_vertex(g, prop=g.vertex_properties[node_name_attribute], match=source_name)[0], gtu.find_vertex(g, prop=g.vertex_properties[node_name_attribute], match=target_name)[0]))
    cost_first_tree = sum([weights[g.edge(source, target)] for source, target in tree_edges])
    # returned_nodes = set(int(gtu.find_vertex(g, prop=g.vertex_properties['name'], match=first_tree.vertex_properties["name"][node])[0]) for node in range(first_tree.num_vertices()))
AndiMajore's avatar
AndiMajore committed
    print(f"Before gtu: Costs={cost_first_tree}")
    returned_nodes = set(int(gtu.find_vertex(g, prop=g.vertex_properties[node_name_attribute], match=first_tree.vertex_properties[node_name_attribute][node])[0]) for node in range(first_tree.num_vertices()))
AndiMajore's avatar
AndiMajore committed
    print(f"After gtu: {returned_nodes}")
    print(num_trees)
    if num_trees > 1:
AndiMajore's avatar
AndiMajore committed
        print("num_trees > 1")
        is_bridge = find_bridges(g)
AndiMajore's avatar
AndiMajore committed
        print("found bridges")
        edge_filter = g.new_edge_property("boolean", True)
AndiMajore's avatar
AndiMajore committed
        print("filtered edges")
        found_new_tree = True
        while len(tree_edges) > 0:
AndiMajore's avatar
AndiMajore committed
            print(f"Tree edges length: {len(tree_edges)}")
            if found_new_tree:
                task_hook.set_progress(float(num_found_trees + 2) / (float(num_trees + 3)), "Computing Steiner tree {} of {}.".format(num_found_trees + 1, num_trees))
            found_new_tree = False
            tree_edge = tree_edges.pop()
AndiMajore's avatar
AndiMajore committed
            print("1")
            g_edge = g.edge(tree_edge[0], tree_edge[1])
            if not is_bridge[g_edge]:
AndiMajore's avatar
AndiMajore committed
                print("2")
                edge_filter[g_edge] = False
                g.set_edge_filter(edge_filter)
                next_tree = steiner_tree(g, seeds, seed_map, weights, hub_penalty > 0)
AndiMajore's avatar
AndiMajore committed
                print("3")
                next_tree_edges = set()
                for next_tree_edge in next_tree.edges():
                    # source_name = next_tree.vertex_properties["name"][next_tree.vertex_index[next_tree_edge.source()]]
                    # target_name = next_tree.vertex_properties["name"][next_tree.vertex_index[next_tree_edge.target()]]
                    # next_tree_edges.add((gtu.find_vertex(g, prop=g.vertex_properties['name'], match=source_name)[0], gtu.find_vertex(g, prop=g.vertex_properties['name'], match=target_name)[0]))
AndiMajore's avatar
AndiMajore committed
                    print("4")
                    source_name = next_tree.vertex_properties[node_name_attribute][next_tree.vertex_index[next_tree_edge.source()]]
                    target_name = next_tree.vertex_properties[node_name_attribute][next_tree.vertex_index[next_tree_edge.target()]]
                    next_tree_edges.add((gtu.find_vertex(g, prop=g.vertex_properties[node_name_attribute], match=source_name)[0],gtu.find_vertex(g, prop=g.vertex_properties[node_name_attribute], match=target_name)[0]))
                cost_next_tree = sum([weights[g.edge(source, target)] for source, target in next_tree_edges])
                if cost_next_tree <= cost_first_tree * ((100.0 + tolerance) / 100.0):
AndiMajore's avatar
AndiMajore committed
                    print("5")
                    found_new_tree = True
                    num_found_trees += 1
                    for node in range(next_tree.num_vertices()):
AndiMajore's avatar
AndiMajore committed
                        print("GTU again")
                        # returned_nodes.add(int(gtu.find_vertex(g, prop=g.vertex_properties['name'], match=next_tree.vertex_properties["name"][node])[0]))
                        returned_nodes.add(int(gtu.find_vertex(g, prop=g.vertex_properties[node_name_attribute],match=next_tree.vertex_properties[node_name_attribute][node])[0]))
AndiMajore's avatar
AndiMajore committed
                        print("GTU done")
                    removed_edges = []
AndiMajore's avatar
AndiMajore committed
                    print("6")
                    for source, target in tree_edges:
                        if not ((source, target) in set(next_tree_edges)) or ((target, source) in set(next_tree_edges)):
                            removed_edges.append((source, target))
                    for edge in removed_edges:
                        tree_edges.remove(edge)
                g.clear_filters()
                edge_filter[g_edge] = True
            if num_found_trees >= num_trees:
                break
    
    task_hook.set_progress((float(num_trees + 2)) / (float(num_trees + 3)), "Formatting results")
    returned_edges = []
    for node in returned_nodes:
        for neighbor in g.get_all_neighbors(node):
            if int(neighbor) > node and int(neighbor) in returned_nodes:
                returned_edges.append((node, int(neighbor)))
    # subgraph = {"nodes": [g.vertex_properties["name"][node] for node in returned_nodes],
    #             "edges": [{"from": g.vertex_properties["name"][source], "to": g.vertex_properties["name"][target]} for source, target in returned_edges]}
    # node_types = {g.vertex_properties["name"][node]: g.vertex_properties["type"][node] for node in returned_nodes}
    # is_seed = {g.vertex_properties["name"][node]: node in set(seed_ids) for node in returned_nodes}
    subgraph = {"nodes": [g.vertex_properties[node_name_attribute][node] for node in returned_nodes],
                "edges": [{"from": g.vertex_properties[node_name_attribute][source], "to": g.vertex_properties[node_name_attribute][target]} for
                          source, target in returned_edges]}
    node_types = {g.vertex_properties[node_name_attribute][node]: g.vertex_properties["type"][node] for node in returned_nodes}
    is_seed = {g.vertex_properties[node_name_attribute][node]: node in set(seed_ids) for node in returned_nodes}
    task_hook.set_results({
        "network": subgraph,
        "node_attributes": {"node_types": node_types, "is_seed": is_seed},
        'gene_interaction_dataset': ppi_dataset,
        'drug_interaction_dataset': pdi_dataset
    })