import sys
from collections import defaultdict

import graph_tool as gt
import graph_tool.centrality as gtc

import graph_tool.stats as gts
import graph_tool.topology as gtt
import graph_tool.util as gtu
import itertools as it


def edge_weights(g, base_weigths, hub_penalty=0, inverse=False):
    """Builds the edge weight property map used by the ranking algorithms.

    Every edge starts with the average total vertex degree of ``g`` as its weight. If
    ``base_weigths`` is given, its per-edge values replace that default. If ``hub_penalty``
    is positive, every edge weight is then overwritten with a blend of the average degree
    and the mean out-degree of the edge's two endpoints, so the base weights only survive
    when ``hub_penalty <= 0``.

    :param g: graph-tool graph to compute the weights for.
    :param base_weigths: edge property map of ``g`` with initial weights, or ``None``.
    :param hub_penalty: blending factor; values ``<= 0`` disable the penalty, values
        greater than 1 are rejected.
    :param inverse: if true, the penalized weight is stored as its reciprocal. Only has
        an effect when ``hub_penalty > 0``.
    :return: edge property map of type ``double``.
    :raises ValueError: if ``hub_penalty`` is greater than 1.
    """
    if hub_penalty > 1:
        raise ValueError("Invalid hub penalty {}.".format(hub_penalty))
    avdeg = gts.vertex_average(g, "total")[0]
    weights = g.new_edge_property("double", val=avdeg)
    if base_weigths is not None:
        # Edge property maps are keyed by edge, so the copy must walk the edges. Indexing
        # the underlying arrays by vertex index mixes up or skips edge weights.
        for e in g.edges():
            weights[e] = base_weigths[e]
    if hub_penalty <= 0:
        return weights
    for e in g.edges():
        edge_avdeg = float(e.source().out_degree() + e.target().out_degree()) / 2.0
        penalized_weight = (1.0 - hub_penalty) * avdeg + hub_penalty * edge_avdeg
        weights[e] = 1.0 / penalized_weight if inverse else penalized_weight
    return weights


def steiner_tree(g, seeds, seed_map, weights, non_zero_hub_penalty):
    """Builds an approximate Steiner tree that connects ``seeds`` inside ``g``.

    The construction proceeds in four steps:

    1. Build ``mc``, a complete graph over the seeds in which every edge stores the
       shortest path (with respect to ``weights``) between the two seeds in ``g``.
    2. Keep only a minimum spanning tree of ``mc``.
    3. Expand the retained seed-to-seed paths into a new graph ``g2`` and keep only a
       minimum spanning tree of ``g2``.
    4. Repeatedly remove degree-1 vertices of ``g2`` that are not seeds.

    :param g: graph-tool graph; must carry a vertex property called ``"name"``.
    :param seeds: indexable sequence of seed names.
    :param seed_map: mapping from seed name to the corresponding vertex (index) of ``g``.
    :param weights: edge property map of ``g`` used for the shortest path search and,
        if ``non_zero_hub_penalty`` is true, for the spanning tree of ``g2``.
    :param non_zero_hub_penalty: if false, all edges of ``g2`` are weighted 1.0.
    :return: the graph ``g2`` with a ``"name"`` vertex property; its spanning tree edge
        filter is still active when it is returned.
    """
    node_name_attribute = "name"  # nodes in the input network which is created from RepoTrialDB have primaryDomainId as name attribute
    # Step 1: complete graph over the seeds ("metric closure").
    mc = gt.Graph(directed=False)
    eprop_dist = mc.new_edge_property("int")
    mc.ep['dist'] = eprop_dist
    vprop_name = mc.new_vertex_property("string")
    mc.vp[node_name_attribute] = vprop_name

    # Each closure edge remembers the list of edges of g that forms the path it represents.
    eprop_path = mc.new_edge_property("object")
    mc.ep['path'] = eprop_path

    mc_vertex_map = dict()
    mc_id_map = dict()  # filled below but never read in this function
    for i in range(len(seeds)):
        vert = mc.add_vertex()
        # Relies on the i-th added vertex having index i.
        vprop_name[i] = seeds[i]
        mc_vertex_map[seeds[i]] = vert
        mc_id_map[vert] = i

    for u, v in it.combinations(seeds, 2):
        _, elist = gtt.shortest_path(g, g.vertex(seed_map[u]), g.vertex(seed_map[v]), weights=weights,
                                     negative_weights=False, pred_map=None, dag=False)
        e = mc.add_edge(mc_vertex_map[u], mc_vertex_map[v])
        # NOTE(review): the closure distance is the number of edges on the path, not the
        # sum of `weights` along it; an empty edge list gives distance 0. Confirm that
        # this is intended.
        eprop_dist[e] = len(elist)
        mc.ep.path[e] = list(elist)

    # Step 2: restrict the closure to a minimum spanning tree.
    mst = gtt.min_spanning_tree(mc, weights=eprop_dist, root=None, tree_map=None)
    mc.set_edge_filter(mst)

    # Step 3: expand the retained closure edges back into real paths.
    g2 = gt.Graph(directed=False)
    vprop_name = g2.new_vertex_property("string")
    g2.vp[node_name_attribute] = vprop_name

    g2_vertex_map = dict()
    g2_id_map = dict()  # filled below but never read in this function
    addedNodes = set()
    for i in range(len(seeds)):
        vert = g2.add_vertex()
        vprop_name[i] = seeds[i]
        g2_vertex_map[seeds[i]] = vert
        g2_id_map[vert] = i
        addedNodes.add(seeds[i])

    allmcedges = []

    # Only closure edges that passed the spanning tree filter are iterated here.
    for mc_edges in mc.edges():
        path = mc.ep.path[mc_edges]
        allmcedges.extend(path)

    # j is the index the next newly added (non-seed) vertex of g2 will receive.
    j = len(seeds)
    allmcedges_g2 = []
    for e in allmcedges:
        sourceName = g.vertex_properties[node_name_attribute][e.source()]
        targetName = g.vertex_properties[node_name_attribute][e.target()]
        if sourceName not in addedNodes:
            vert = g2.add_vertex()
            vprop_name[j] = sourceName
            g2_vertex_map[sourceName] = vert
            g2_id_map[vert] = j
            addedNodes.add(sourceName)
            j += 1
        if targetName not in addedNodes:
            vert = g2.add_vertex()
            vprop_name[j] = targetName
            g2_vertex_map[targetName] = vert
            g2_id_map[vert] = j
            addedNodes.add(targetName)
            j += 1
        # Paths are not de-duplicated, so an edge of g shared by several paths is added
        # to g2 once per path.
        allmcedges_g2.append(g2.add_edge(g2_vertex_map[sourceName], g2_vertex_map[targetName]))
    weights_g2 = g2.new_edge_property("double", val=1.0)
    if non_zero_hub_penalty:
        # allmcedges and allmcedges_g2 are parallel lists: same position, same edge.
        for e, e_g2 in zip(allmcedges, allmcedges_g2):
            weights_g2[e_g2] = weights[e]
    mst2 = gtt.min_spanning_tree(g2, root=None, tree_map=None, weights=weights_g2)
    g2.set_edge_filter(mst2)

    # Step 4: prune non-seed leaves until none are left.
    while True:
        noneSteinerLeaves = []
        for i in range(g2.num_vertices()):
            if g2.vertex(i).out_degree() == 1 and g2.vertex_properties[node_name_attribute][i] not in seeds:
                noneSteinerLeaves.append(i)
        if len(noneSteinerLeaves) == 0:
            break
        # Remove from the highest index down.
        # NOTE(review): presumably because removing a vertex renumbers the vertices above
        # it, which would invalidate the remaining collected indices; confirm against the
        # graph-tool remove_vertex documentation.
        noneSteinerLeaves = reversed(sorted(noneSteinerLeaves))
        for node in noneSteinerLeaves:
            # NOTE(review): any failure while removing the leaf's edge is silently
            # ignored; the vertex is removed regardless.
            try:
                g2.remove_edge(g2.edge(g2.vertex(node), g2.get_all_neighbors(node)[0]))
            except:
                pass
            g2.remove_vertex(node)

    return g2


def find_bridges(g):
    r"""Finds all bridges in a graph.

    Runs a recursive depth-first search from every vertex that has not been visited yet
    and marks the edges that the search identifies as bridges.

    The search keeps its discovery counter in the module-level variable ``__time``, so
    two searches must not run at the same time.

    :param g: graph-tool graph to search.
    :return: boolean edge property map of ``g``; true for every edge marked as a bridge.
    """
    global __time
    __time = 0
    # The search recurses once per vertex along a DFS path, on top of the frames that are
    # already on the stack when this function is called. Only ever raise the limit:
    # lowering it (e.g. for a small graph) can break the running interpreter.
    recursion_headroom = 1000
    required_limit = g.num_vertices() + recursion_headroom
    if sys.getrecursionlimit() < required_limit:
        sys.setrecursionlimit(required_limit)
    visited = g.new_vertex_property("boolean", False)
    disc = g.new_vertex_property("float", float("inf"))
    low = g.new_vertex_property("float", float("inf"))
    parent = g.new_vertex_property("int", -1)
    is_bridge = g.new_edge_property("boolean", False)
    for node in range(g.num_vertices()):
        if not visited[node]:
            __dfs_find_bridges(g, node, visited, disc, low, parent, is_bridge)
    return is_bridge


def __dfs_find_bridges(g, node, visited, disc, low, parent, is_bridge):
    """Recursive depth-first search step of the bridge detection.

    Assigns ``node`` its discovery time, visits all unvisited neighbours and marks the
    edge to a child as a bridge if nothing reachable from the child leads back to
    ``node`` or one of its ancestors.

    :param g: graph-tool graph being searched.
    :param node: index of the vertex to expand.
    :param visited: vertex property map; true for vertices already expanded.
    :param disc: vertex property map with discovery times.
    :param low: vertex property map with the lowest discovery time reachable from the
        vertex's DFS subtree.
    :param parent: vertex property map with the DFS parent index (-1 for roots).
    :param is_bridge: edge property map that receives the result.
    """
    global __time
    visited[node] = True
    disc[node] = __time
    low[node] = __time
    __time += 1

    for nb in g.get_all_neighbors(node):
        if not visited[nb]:
            parent[nb] = node
            __dfs_find_bridges(g, int(nb), visited, disc, low, parent, is_bridge)
            low[node] = min(low[node], low[nb])
            if low[nb] > disc[node]:
                # g.edge returns None if there is no edge from node to nb (possible in a
                # directed graph, where nb may only be an in-neighbour); skip those.
                edge = g.edge(node, nb)
                if edge is not None:
                    is_bridge[edge] = True
        elif int(nb) != parent[node]:
            # The parent must be skipped here: otherwise low[node] would always drop to
            # the parent's discovery time and no tree edge would ever be a bridge.
            # NOTE(review): a parallel edge to the parent is skipped as well.
            low[node] = min(low[node], disc[nb])


def _tree_edges_in_graph(g, tree, node_name_attribute):
    """Maps every edge of ``tree`` to a ``(source, target)`` pair of vertices of ``g``.

    Vertices are matched through ``node_name_attribute``; the first vertex of ``g`` that
    carries the name is used.
    """
    graph_names = g.vertex_properties[node_name_attribute]
    tree_names = tree.vertex_properties[node_name_attribute]
    mapped_edges = []
    for tree_edge in tree.edges():
        source_name = tree_names[tree.vertex_index[tree_edge.source()]]
        target_name = tree_names[tree.vertex_index[tree_edge.target()]]
        mapped_edges.append((gtu.find_vertex(g, prop=graph_names, match=source_name)[0],
                             gtu.find_vertex(g, prop=graph_names, match=target_name)[0]))
    return mapped_edges


def must(g, seed_ids, num_trees, hub_penalty, weights=None, tolerance=10):
    """Scores the vertices of ``g`` using multiple Steiner trees over the seeds.

    A first Steiner tree is computed on the full graph. While fewer than ``num_trees``
    trees have been accepted, one edge of the first tree at a time is hidden from ``g``
    (unless it is a bridge) and a new Steiner tree is computed. A new tree is accepted
    if its cost does not exceed the cost of the first tree by more than ``tolerance``
    percent.

    :param g: graph-tool graph with a ``"name"`` vertex property. Its edge filter is
        set temporarily and all filters are cleared again afterwards.
    :param seed_ids: vertex indices of the seeds in ``g``.
    :param num_trees: number of accepted trees after which the search stops.
    :param hub_penalty: hub penalty passed on to ``edge_weights``.
    :param weights: optional base edge weights passed on to ``edge_weights``.
    :param tolerance: allowed cost overhead of later trees in percent.
    :return: vertex property map of type ``float`` holding the scores.
    """
    if gt.openmp_enabled():
        gt.openmp_set_num_threads(6)
    weights = edge_weights(g, weights, hub_penalty, inverse=True)
    scores = defaultdict(int)
    node_name_attribute = 'name'
    seed_map = {g.vertex_properties[node_name_attribute][node]: node for node in seed_ids}
    # From here on the seeds are identified by name.
    seed_ids = list(seed_map.keys())
    first_tree = steiner_tree(g, seed_ids, seed_map, weights, hub_penalty > 0)
    num_found_trees = 1
    tree_edges = _tree_edges_in_graph(g, first_tree, node_name_attribute)
    tree_nodes = set()
    for source, target in tree_edges:
        tree_nodes.add(source)
        tree_nodes.add(target)
    cost_first_tree = sum(weights[g.edge(source, target)] for source, target in tree_edges)
    for vertex in tree_nodes:
        scores[vertex] += 1
    if num_trees > 1:
        max_cost = cost_first_tree * ((100.0 + tolerance) / 100.0)
        is_bridge = find_bridges(g)
        edge_filter = g.new_edge_property("boolean", True)
        while tree_edges:
            tree_edge = tree_edges.pop()
            g_edge = g.edge(tree_edge[0], tree_edge[1])
            if not is_bridge[g_edge]:
                # Hide the edge and look for an alternative tree that avoids it.
                edge_filter[g_edge] = False
                g.set_edge_filter(edge_filter)
                try:
                    next_tree = steiner_tree(g, seed_ids, seed_map, weights, hub_penalty > 0)
                    next_tree_edges = set(_tree_edges_in_graph(g, next_tree, node_name_attribute))
                    for source, target in next_tree_edges:
                        tree_nodes.add(source)
                        tree_nodes.add(target)
                    # NOTE(review): tree_nodes accumulates the vertices of all trees seen
                    # so far and is scored before the tolerance check, so earlier vertices
                    # are counted again for every candidate tree. Kept as is; confirm
                    # that this is the intended scoring.
                    for vertex in tree_nodes:
                        scores[vertex] += 1
                    cost_next_tree = sum(weights[g.edge(source, target)] for source, target in next_tree_edges)
                    if cost_next_tree <= max_cost:
                        num_found_trees += 1
                        # Keep only the candidate edges that also occur in the new tree,
                        # in either orientation.
                        tree_edges = [(source, target) for source, target in tree_edges
                                      if (source, target) in next_tree_edges
                                      or (target, source) in next_tree_edges]
                finally:
                    # Always make the hidden edge visible again, even if the tree
                    # computation failed.
                    g.clear_filters()
                    edge_filter[g_edge] = True
            if num_found_trees >= num_trees:
                break
    score_prop = g.new_vertex_property("float")
    for v, c in scores.items():
        score_prop[int(v)] = c
    return score_prop


def trustrank(g, seed_ids, damping_factor, hub_penalty=0, weights=None):
    """Computes TrustRank scores as a personalized PageRank over ``g``.

    The personalization vector puts ``1 / len(seed_ids)`` on every seed vertex and
    leaves all other vertices at the property map's default value.

    :param g: graph-tool graph.
    :param seed_ids: vertex indices of the seeds.
    :param damping_factor: damping factor handed to graph-tool's PageRank.
    :param hub_penalty: hub penalty passed on to ``edge_weights``.
    :param weights: optional base edge weights passed on to ``edge_weights``.
    :return: the vertex property map returned by ``graph_tool.centrality.pagerank``.
    """
    if gt.openmp_enabled():
        gt.openmp_set_num_threads(6)
    edge_weight_map = edge_weights(g, weights, hub_penalty, inverse=False)

    # Spread the trust evenly over the seeds.
    personalization = g.new_vertex_property("double")
    seed_share = 1.0 / len(seed_ids)
    seed_indices = [int(seed) for seed in seed_ids]
    personalization.a[seed_indices] = seed_share

    return gtc.pagerank(g, damping=damping_factor, pers=personalization, weight=edge_weight_map)


def betweenness(g, hub_penalty, weights=None):
    """Computes the weighted betweenness centrality of the vertices of ``g``.

    :param g: graph-tool graph.
    :param hub_penalty: hub penalty passed on to ``edge_weights``.
    :param weights: optional base edge weights passed on to ``edge_weights``.
    :return: the vertex betweenness property map computed by graph-tool.
    """
    if gt.openmp_enabled():
        gt.openmp_set_num_threads(6)
    edge_weight_map = edge_weights(g, weights, hub_penalty, inverse=True)
    # graph-tool returns a pair; only the first element (the vertex map) is handed back.
    vertex_scores, _edge_scores = gtc.betweenness(g, weight=edge_weight_map)
    return vertex_scores


def closeness(g, hub_penalty, weights=None):
    """Computes the weighted harmonic closeness centrality of the vertices of ``g``.

    :param g: graph-tool graph.
    :param hub_penalty: hub penalty passed on to ``edge_weights``.
    :param weights: optional base edge weights passed on to ``edge_weights``.
    :return: the vertex property map returned by ``graph_tool.centrality.closeness``.
    """
    if gt.openmp_enabled():
        gt.openmp_set_num_threads(6)
    edge_weight_map = edge_weights(g, weights, hub_penalty, inverse=True)
    return gtc.closeness(g, weight=edge_weight_map, harmonic=True)