from drugstone.management.includes.DataLoader import DataLoader
import drugstone.models as models
from drugstone.management.includes.NodeCache import NodeCache


class DataPopulator:
    """Writes expression levels, Ensembl genes and interaction edges to the database.

    Every ``populate_*`` method loads one data set through ``DataLoader``,
    resolves the referenced proteins / drugs / disorders through the
    ``NodeCache`` handed to the constructor, builds the matching model
    instances and stores them with ``bulk_create``.

    All ``populate_*`` methods take an ``update`` flag: when it is falsy every
    resolvable entry is created; when it is truthy an entry is only created if
    the cache reports at least one of its endpoints as new.
    """

    # populate_expressions writes the pending ExpressionLevel objects to the
    # database as soon as more than this many have accumulated.
    EXPRESSION_FLUSH_THRESHOLD = 100000

    def __init__(self, cache: NodeCache):
        """Store the node cache used to resolve identifiers to model instances.

        Args:
            cache: cache queried for proteins, drugs and disorders.
        """
        self.cache = cache

    def populate_expressions(self, update) -> int:
        """ Populates the tissue expression levels of the cached proteins.
        Handles loading the data and passing it to the django database

        Every column after the first two of the loaded frame is treated as a
        tissue; the tissue rows are fetched or created by name. Each data row
        is matched to proteins through its 'Description' column.

        Args:
            update: if truthy, only proteins the cache reports as new receive
                expression levels.

        Returns:
            int: Count of how many expression levels were passed to bulk_create
        """
        self.cache.init_proteins()
        df = DataLoader.load_expressions()

        tissues_models = {}
        for tissue_name in df.columns.values[2:]:
            tissue, _ = models.Tissue.objects.get_or_create(name=tissue_name)
            tissues_models[tissue_name] = tissue

        # Hashes of every expression level queued so far. This outlives the
        # individual flushes, so an entry is never queued twice even across
        # batches.
        # NOTE(review): uniqueness is decided on the hash value alone and
        # relies on ExpressionLevel.__hash__, which is defined outside this
        # file - confirm it identifies a (protein, tissue) pair.
        seen_hashes = set()
        # A list is sufficient: the hash filter above already guarantees that
        # every queued object has a distinct hash.
        bulk = []
        size = 0

        for _, row in df.iterrows():
            gene_name = row['Description']

            for protein_model in self.cache.get_proteins_by_gene(gene_name):
                if update and not self.cache.is_new_protein(protein_model):
                    continue
                for tissue_name, tissue_model in tissues_models.items():
                    expr = models.ExpressionLevel(protein=protein_model,
                                                  tissue=tissue_model,
                                                  expression_level=row[tissue_name])
                    expr_hash = hash(expr)
                    if expr_hash in seen_hashes:
                        continue
                    seen_hashes.add(expr_hash)
                    bulk.append(expr)

            # flush between rows to keep the number of pending objects bounded
            if len(bulk) > self.EXPRESSION_FLUSH_THRESHOLD:
                models.ExpressionLevel.objects.bulk_create(bulk)
                size += len(bulk)
                bulk = []

        models.ExpressionLevel.objects.bulk_create(bulk)
        return size + len(bulk)

    def populate_ensg(self, update) -> int:
        """ Populates the Ensembl-Gene table in the django database.
        Also maps the added ensg entries to the corresponding proteins.
        Handles loading the data and passing it to the django database

        Args:
            update: if truthy, only proteins the cache reports as new are
                linked to their ensg entries.

        Returns:
            int: Count of how many ensg-protein relations were added
        """
        self.cache.init_proteins()
        data = DataLoader.load_ensg()
        bulk = []

        for entrez, ensg_list in data.items():
            for protein in self.cache.get_proteins_by_entrez(entrez):
                # the filter only depends on the protein, so it is evaluated
                # once per protein instead of once per ensg id
                if update and not self.cache.is_new_protein(protein):
                    continue
                for ensg in ensg_list:
                    bulk.append(models.EnsemblGene(name=ensg, protein=protein))
        models.EnsemblGene.objects.bulk_create(bulk)
        return len(bulk)

    def populate_ppi_string(self, dataset, update) -> int:
        """ Populates the Protein-Protein-Interactions from STRINGdb
        Handles loading the data and passing it to the django database

        Args:
            dataset: value stored as ``ppi_dataset`` on every created edge.
            update: if truthy, only edges with at least one new protein are
                created.

        Returns:
            int: Count of how many interactions were added
        """
        self.cache.init_proteins()

        df = DataLoader.load_ppi_string()
        return self._populate_ppi_by_entrez(df, dataset, update)

    def populate_ppi_apid(self, dataset, update) -> int:
        """ Populates the Protein-Protein-Interactions from Apid
        Handles loading the data and passing it to the django database

        Args:
            dataset: value stored as ``ppi_dataset`` on every created edge.
            update: if truthy, only edges with at least one new protein are
                created.

        Returns:
            int: Count of how many interactions were added
        """
        self.cache.init_proteins()

        df = DataLoader.load_ppi_apid()
        # Collected in a set, so the result depends on the model's
        # __hash__/__eq__, which are defined outside this file.
        bulk = set()
        for _, row in df.iterrows():
            try:
                # try fetching proteins
                protein_a = self.cache.get_protein_by_uniprot(row['from_protein_ac'])
                protein_b = self.cache.get_protein_by_uniprot(row['to_protein_ac'])
            except KeyError:
                # continue if not found
                continue
            if update and not (self.cache.is_new_protein(protein_a) or self.cache.is_new_protein(protein_b)):
                continue
            bulk.add(models.ProteinProteinInteraction(
                ppi_dataset=dataset,
                from_protein=protein_a,
                to_protein=protein_b
            ))
        models.ProteinProteinInteraction.objects.bulk_create(bulk)
        return len(bulk)

    def populate_ppi_biogrid(self, dataset, update) -> int:
        """ Populates the Protein-Protein-Interactions from BioGRID
        Handles loading the data and passing it to the django database

        Args:
            dataset: value stored as ``ppi_dataset`` on every created edge.
            update: if truthy, only edges with at least one new protein are
                created.

        Returns:
            int: Count of how many interactions were added
        """
        self.cache.init_proteins()

        df = DataLoader.load_ppi_biogrid()
        return self._populate_ppi_by_entrez(df, dataset, update)

    def populate_pdi_chembl(self, dataset, update) -> int:
        """ Populates the Protein-Drug-Interactions from Chembl
        Handles Loading the data and passing it to the django database

        Args:
            dataset: value stored as ``pdi_dataset`` on every created edge.
            update: if truthy, only edges whose protein or drug is new are
                created.

        Returns:
            int: Count of how many interactions were added
        """
        self.cache.init_proteins()
        self.cache.init_drugs()

        df = DataLoader.load_pdi_chembl()
        # Collected in a set, so the result depends on the model's
        # __hash__/__eq__, which are defined outside this file.
        bulk = set()
        for _, row in df.iterrows():
            try:
                # try fetching protein
                protein = self.cache.get_protein_by_uniprot(row['protein_ac'])
            except KeyError:
                # continue if not found
                continue
            try:
                # try fetching drug
                drug = self.cache.get_drug_by_drugbank(row['drug_id'])
            except KeyError:
                # continue if not found
                continue
            if update and not (self.cache.is_new_protein(protein) or self.cache.is_new_drug(drug)):
                continue
            bulk.add(models.ProteinDrugInteraction(
                pdi_dataset=dataset,
                protein=protein,
                drug=drug
            ))
        models.ProteinDrugInteraction.objects.bulk_create(bulk)
        return len(bulk)

    def populate_pdis_disgenet(self, dataset, update) -> int:
        """ Populates the Protein-Disorder-Interactions from DisGeNET
        Handles Loading the data and passing it to the django database

        Args:
            dataset: value stored as ``pdis_dataset`` on every created
                association.
            update: if truthy, only associations whose protein or disorder is
                new are created.

        Returns:
            int: Count of how many interactions were added
        """
        self.cache.init_proteins()
        self.cache.init_disorders()

        df = DataLoader.load_pdis_disgenet()
        # Collected in a set, so the result depends on the model's
        # __hash__/__eq__, which are defined outside this file.
        bulk = set()
        for _, row in df.iterrows():
            try:
                # try fetching protein
                protein = self.cache.get_protein_by_uniprot(row['protein_name'])
            except KeyError:
                # continue if not found
                continue
            try:
                # try fetching disorder
                disorder = self.cache.get_disorder_by_mondo(row['disorder_name'])
            except KeyError:
                # continue if not found
                continue
            if update and not (self.cache.is_new_protein(protein) or self.cache.is_new_disease(disorder)):
                continue
            bulk.add(models.ProteinDisorderAssociation(
                pdis_dataset=dataset,
                protein=protein,
                disorder=disorder,
                score=row['score']
            ))
        models.ProteinDisorderAssociation.objects.bulk_create(bulk)
        return len(bulk)

    def populate_drdis_drugbank(self, dataset, update) -> int:
        """ Populates the Drug-Disorder-Indications from DrugBank
        Handles Loading the data and passing it to the django database

        Args:
            dataset: value stored as ``drdi_dataset`` on every created
                indication.
            update: if truthy, only indications whose drug or disorder is new
                are created.

        Returns:
            int: Count of how many edges were added
        """
        self.cache.init_drugs()
        self.cache.init_disorders()

        df = DataLoader.load_drdis_drugbank()
        # Collected in a set, so the result depends on the model's
        # __hash__/__eq__, which are defined outside this file.
        bulk = set()
        for _, row in df.iterrows():
            try:
                # try fetching drug
                drug = self.cache.get_drug_by_drugbank(row['drugbank_id'])
            except KeyError:
                # continue if not found
                continue
            try:
                # try fetching disorder
                disorder = self.cache.get_disorder_by_mondo(row['mondo_id'])
            except KeyError:
                # continue if not found
                continue
            if update and not (self.cache.is_new_drug(drug) or self.cache.is_new_disease(disorder)):
                continue
            bulk.add(models.DrugDisorderIndication(
                drdi_dataset=dataset,
                drug=drug,
                disorder=disorder,
            ))
        models.DrugDisorderIndication.objects.bulk_create(bulk)
        return len(bulk)

    def populate_pdi_dgidb(self, dataset, update) -> int:
        """ Populates the Protein-Drug-Interactions from DGIdb
        Handles Loading the data and passing it to the django database

        Args:
            dataset: value stored as ``pdi_dataset`` on every created edge.
            update: if truthy, only edges whose protein or drug is new are
                created.

        Returns:
            int: Count of how many interactions were added
        """
        self.cache.init_proteins()
        self.cache.init_drugs()

        df = DataLoader.load_pdi_dgidb()
        return self._populate_pdi_by_entrez(df, dataset, update)

    def populate_pdi_drugbank(self, dataset, update) -> int:
        """ Populates the Protein-Drug-Interactions from Drugbank
        Handles Loading the data and passing it to the django database

        Args:
            dataset: value stored as ``pdi_dataset`` on every created edge.
            update: if truthy, only edges whose protein or drug is new are
                created.

        Returns:
            int: Count of how many interactions were added
        """
        self.cache.init_proteins()
        self.cache.init_drugs()

        df = DataLoader.load_pdi_drugbank()
        return self._populate_pdi_by_entrez(df, dataset, update)

    def _populate_ppi_by_entrez(self, df, dataset, update) -> int:
        """Create protein-protein edges from a frame with 'entrez_a'/'entrez_b' columns.

        Shared by the STRINGdb and BioGRID loaders. Each entrez id may resolve
        to several proteins; one edge is created for every (a, b) combination.
        Edges are collected in a list, so repeated pairs are kept.

        Args:
            df: frame whose rows are iterated with ``iterrows``.
            dataset: value stored as ``ppi_dataset`` on every created edge.
            update: if truthy, only edges with at least one new protein are
                created.

        Returns:
            int: Count of how many interactions were added
        """
        bulk = []
        for _, row in df.iterrows():
            try:
                # try fetching proteins
                proteins_a = self.cache.get_proteins_by_entrez(row['entrez_a'])
                proteins_b = self.cache.get_proteins_by_entrez(row['entrez_b'])
            except KeyError:
                # continue if not found
                continue
            for protein_a in proteins_a:
                for protein_b in proteins_b:
                    if update and not (self.cache.is_new_protein(protein_a) or self.cache.is_new_protein(protein_b)):
                        continue
                    bulk.append(models.ProteinProteinInteraction(
                        ppi_dataset=dataset,
                        from_protein=protein_a,
                        to_protein=protein_b
                    ))
        models.ProteinProteinInteraction.objects.bulk_create(bulk)
        return len(bulk)

    def _populate_pdi_by_entrez(self, df, dataset, update) -> int:
        """Create protein-drug edges from a frame with 'entrez_id'/'drug_id' columns.

        Shared by the DGIdb and DrugBank loaders. The entrez id may resolve to
        several proteins; one edge to the drug is created for each of them.
        Edges are collected in a set, so the result depends on the model's
        __hash__/__eq__, which are defined outside this file.

        Args:
            df: frame whose rows are iterated with ``iterrows``.
            dataset: value stored as ``pdi_dataset`` on every created edge.
            update: if truthy, only edges whose protein or drug is new are
                created.

        Returns:
            int: Count of how many interactions were added
        """
        bulk = set()
        for _, row in df.iterrows():
            try:
                # try fetching proteins
                proteins = self.cache.get_proteins_by_entrez(row['entrez_id'])
            except KeyError:
                # continue if not found
                continue
            try:
                # try fetching drug
                drug = self.cache.get_drug_by_drugbank(row['drug_id'])
            except KeyError:
                # continue if not found
                continue
            for protein in proteins:
                if update and not (self.cache.is_new_protein(protein) or self.cache.is_new_drug(drug)):
                    continue
                bulk.add(models.ProteinDrugInteraction(
                    pdi_dataset=dataset,
                    protein=protein,
                    drug=drug
                ))
        models.ProteinDrugInteraction.objects.bulk_create(bulk)
        return len(bulk)