from algorithms.AlgorithmWrapper import AlgorithmWrapper
from graph_tool import Graph
import subprocess
import os
import preprocess
import tempfile
#MC:
from configparser import ConfigParser
import ast

# Example invocations used during development (for manual testing):
#   standalone ROBUST run from tools/robust:
# python robust.py data/human_annotated_PPIs_brain.txt data/ms_seeds.txt ms.graphml 0.25 0.9 30 0.1
#   standalone ROBUST run from another directory:
# python ../tools/robust/robust.py ../tools/robust/data/human_annotated_PPIs_brain.txt ../tools/robust/data/ms_seeds.txt ms.graphml 0.25 0.9 30 0.1
#   full CAMI run (optionally under the debugger):
# ./cami.py -n ../tools/robust/data/human_annotated_PPIs_brain.txt -s ../tools/robust/data/ms_seeds.txt -v -id test_run -ncbi
# python -m pdb cami.py -n ../tools/robust/data/human_annotated_PPIs_brain.txt -s ../tools/robust/data/ms_seeds.txt -v -id test_run -ncbi
#           "*${ProgramArgs}",


class RobustWrapper(AlgorithmWrapper):
    """Wrapper that runs the ROBUST disease-module tool as an external script.

    The script is ``tools/robust/robust.py`` below ``self.home_path``. Its
    numeric parameters are read from the ``[robust]`` section of the
    configuration file named by ``self.config``. The tool itself is
    documented at https://github.com/bionetslab/robust.
    """

    def __init__(self):
        super().__init__()
        self.name = 'ROBUST'
        self.code = 3
        config = ConfigParser()
        config.read(self.config)

        # Tool parameters, passed to robust.py in this order (see run_algorithm).
        self.initialFraction = config.getfloat('robust', 'initial_fraction')
        self.reductionFactor = config.getfloat('robust', 'reduction_factor')
        self.numberSteinerTrees = config.getint('robust', 'number_steiner_trees')
        self.threshold = config.getfloat('robust', 'threshold')

    def run_algorithm(self, inputparams):
        """Run robust.py on the prepared input files and return its result.

        robust.py is called with these positional arguments:
          [1] network file as an edge list (tab-separated; columns 1 & 2 used)
          [2] seed gene file (first tab-separated column used)
          [3] path to the output file
          [4] initial fraction
          [5] reduction factor
          [6] number of Steiner trees to be computed
          [7] threshold

        :param inputparams: ``[ppi_file, seed_file]`` as produced by
            ``prepare_input``.
        :return: the list of ints parsed by ``extract_output`` from the
            tool's output file.
        :raises subprocess.CalledProcessError: if robust.py exits with a
            non-zero status.
        """
        robust_dir = os.path.join(self.home_path, 'tools/robust')

        # The tool runs with robust_dir as its working directory. Absolute
        # paths make it read and write the same files this process writes
        # (prepare_input) and reads (extract_output), even when output_dir
        # is relative.
        ppi = os.path.abspath(inputparams[0])
        seeds = os.path.abspath(inputparams[1])
        out_filename = self.name_file('out')
        algo_output = os.path.abspath(os.path.join(self.output_dir, out_filename))

        command = [
            'python', 'robust.py', ppi, seeds, algo_output,
            str(self.initialFraction), str(self.reductionFactor),
            str(self.numberSteinerTrees), str(self.threshold),
        ]
        # An argument list with cwd= (instead of a shell string with `cd`)
        # avoids quoting/injection problems with unusual paths. It also
        # fails loudly if the tool directory is missing.
        # stdout goes to DEVNULL instead of an unread PIPE. An unread PIPE
        # can deadlock the child once the OS pipe buffer is full.
        subprocess.run(command, cwd=robust_dir,
                       stdout=subprocess.DEVNULL, check=True)
        if self.debug:
            print(f"Robust results saved in {algo_output}")
        return self.extract_output(algo_output)

    def prepare_input(self):
        """Write the PPI network and the seed genes to files for robust.py.

        The network becomes a tab-separated edge list with a ``node1``/
        ``node2`` header line. Each edge is written as the ``str()`` of its
        source and target vertices. Seeds are written one per line.

        :return: ``[ppi_file, seed_file]``, both located in ``self.output_dir``.
        """
        # Keep the ppi-then-seeds order of name_file calls.
        ppi_file = os.path.join(self.output_dir, self.name_file('ppi'))
        seed_file = os.path.join(self.output_dir, self.name_file('seeds'))

        with open(ppi_file, "w") as file:
            file.write('node1\tnode2\n')
            file.writelines(
                f"{str(edge.source())}\t{str(edge.target())}\n"
                for edge in self.ppi_network.edges()
            )
        if self.debug:
            print(f'{self.name} ppi is saved in {ppi_file}')

        with open(seed_file, "w") as file:
            file.writelines(f"{seed}\n" for seed in self.seeds)
        if self.debug:
            print(f'{self.name} seeds are saved in {seed_file}')

        return [ppi_file, seed_file]

    def extract_output(self, algo_output):
        """Parse the tool's output file into a list of integer node ids.

        This parser assumes the file holds whitespace-separated integers,
        possibly spread over several lines.

        :param algo_output: path to the output file written by robust.py.
        :return: the ids in file order.
        :raises ValueError: if a token is not an integer.
        """
        with open(algo_output, "r") as output:
            # str.split() with no argument skips runs of whitespace. Trailing
            # spaces, tabs and blank lines therefore produce no empty tokens
            # (int('') would raise).
            return [int(token) for line in output for token in line.split()]