From e8d6a3cc6692d8e64ce09ec9f0c492b8e3cf8c2d Mon Sep 17 00:00:00 2001
From: AndiMajore <andi.majore@googlemail.com>
Date: Thu, 9 Feb 2023 16:07:27 +0100
Subject: [PATCH] fixed connector node definition

---
 tasks/util/scores_to_results.py | 35 ++++++++++++++++++++++-----------
 1 file changed, 24 insertions(+), 11 deletions(-)

diff --git a/tasks/util/scores_to_results.py b/tasks/util/scores_to_results.py
index b690db5..cb47496 100755
--- a/tasks/util/scores_to_results.py
+++ b/tasks/util/scores_to_results.py
@@ -12,14 +12,14 @@ def scores_to_results(
         pdi_dataset,
         filterPaths
 ):
-
     r"""Transforms the scores to the required result format."""
 
     node_name_attribute = "internal_id"  # nodes in the input network which is created from RepoTrialDB have primaryDomainId as name attribute
     if target == "drug":
         candidates = [(node, scores[node]) for node in drug_ids if scores[node] > 0]
     else:
-        candidates = [(node, scores[node]) for node in range(g.num_vertices()) if scores[node] > 0 and node not in set(seed_ids)]
+        candidates = [(node, scores[node]) for node in range(g.num_vertices()) if
+                      scores[node] > 0 and node not in set(seed_ids)]
     best_candidates = [item[0] for item in sorted(candidates, key=lambda item: item[1], reverse=True)[:result_size]]
     # Concatenate best result candidates with seeds and compute induced subgraph.
     # since the result size filters out nodes, the result network is not complete anymore.
@@ -29,7 +29,7 @@ def scores_to_results(
     intermediate_nodes = set()
 
     returned_edges = set()
-    returned_nodes = set(seed_ids) # return seed_ids in any case
+    returned_nodes = set(seed_ids)  # return seed_ids in any case
 
     # return only the path to a drug with the shortest distance
     accepted_candidates = set()
@@ -44,11 +44,15 @@ def scores_to_results(
                 vertices, edges = gtt.shortest_path(g, candidate, seed_id)
 
                 drug_in_path = False
+                seed_in_path = False
                 for vertex in vertices:
                     if g.vertex_properties["type"][int(vertex)] == "drug" and vertex != candidate:
                         drug_in_path = True
                         break
-                if drug_in_path:
+                    if int(vertex) in seed_ids and int(vertex) != seed_id:
+                        seed_in_path = True
+                        break
+                if drug_in_path or seed_in_path:
                     continue
                 accepted_candidates.add(g.vertex_properties[node_name_attribute][int(candidate)])
                 for vertex in vertices:
@@ -58,7 +62,8 @@ def scores_to_results(
                             intermediate_nodes.add(g.vertex_properties[node_name_attribute][int(vertex)])
                         returned_nodes.add(int(vertex))
                 for edge in edges:
-                    if ((edge.source(), edge.target()) not in returned_edges) or ((edge.target(), edge.source()) not in returned_edges):
+                    if (((edge.source(), edge.target()) not in returned_edges) or (
+                            (edge.target(), edge.source()) not in returned_edges)) and int(edge.target()) in returned_nodes and int(edge.source()) in returned_nodes:
                         returned_edges.add((edge.source(), edge.target()))
     else:
         for candidate in best_candidates:
@@ -66,11 +71,15 @@ def scores_to_results(
                 vertices, edges = gtt.shortest_path(g, candidate, seed_id)
 
                 drug_in_path = False
+                seed_in_path = False
                 for vertex in vertices:
                     if g.vertex_properties["type"][int(vertex)] == "drug" and vertex != candidate:
                         drug_in_path = True
                         break
-                if drug_in_path:
+                    if int(vertex) in seed_ids and int(vertex) != seed_id:
+                        seed_in_path = True
+                        break
+                if drug_in_path or seed_in_path:
                     continue
                 accepted_candidates.add(g.vertex_properties[node_name_attribute][int(candidate)])
                 for vertex in vertices:
@@ -80,18 +89,22 @@ def scores_to_results(
                             intermediate_nodes.add(g.vertex_properties[node_name_attribute][int(vertex)])
                         returned_nodes.add(int(vertex))
                 for edge in edges:
-                    if ((edge.source(), edge.target()) not in returned_edges) or ((edge.target(), edge.source()) not in returned_edges):
+                    if (((edge.source(), edge.target()) not in returned_edges) or (
+                            (edge.target(), edge.source()) not in returned_edges)) and int(
+                        edge.target()) in returned_nodes and int(edge.source()) in returned_nodes:
                         returned_edges.add((edge.source(), edge.target()))
     for node in accepted_candidates:
         if node in intermediate_nodes:
             intermediate_nodes.remove(node)
     subgraph = {
-        "nodes":[g.vertex_properties[node_name_attribute][node] for node in returned_nodes],
-        "edges": [{"from": g.vertex_properties[node_name_attribute][source], "to": g.vertex_properties[node_name_attribute][target]} for source, target in returned_edges],
-        }
+        "nodes": [g.vertex_properties[node_name_attribute][node] for node in returned_nodes],
+        "edges": [{"from": g.vertex_properties[node_name_attribute][source],
+                   "to": g.vertex_properties[node_name_attribute][target]} for source, target in returned_edges],
+    }
 
     # Compute node attributes.
-    node_types = {g.vertex_properties[node_name_attribute][node]: g.vertex_properties["type"][node] for node in returned_nodes}
+    node_types = {g.vertex_properties[node_name_attribute][node]: g.vertex_properties["type"][node] for node in
+                  returned_nodes}
     is_seed = {g.vertex_properties[node_name_attribute][node]: node in set(seed_ids) for node in returned_nodes}
     returned_scores = {g.vertex_properties[node_name_attribute][node]: scores[node] for node in returned_nodes}
 
-- 
GitLab