# to be run with cctbx.python
import sys
import math

import iotbx.pdb

from collections import defaultdict

from printer import ShelxPrinter
from printer import RefmacPrinter
from atom import Atom

from lib.PO4 import PO4_ATOM_NAMES
from lib.PO4 import PO4_ATOM_RES
from lib.PO4 import PO4_CONDITION
from lib.PO4 import PO4_RESTRAINS
from lib.PO4 import PO4_DISTANCE_MEASURE
from lib.PO4 import PO4_CONDITION_DISTANCE_MEASURE

from lib.PO4_terminal_C5 import PO4_5_TERMINAL_ATOM_NAMES
from lib.PO4_terminal_C5 import PO4_5_TERMINAL_ATOM_RES
from lib.PO4_terminal_C5 import PO4_5_TERMINAL_CONDITION
from lib.PO4_terminal_C5 import PO4_5_TERMINAL_RESTRAINS
from lib.PO4_terminal_C5 import PO4_5_TERMINAL_DISTANCE_MEASURE
from lib.PO4_terminal_C5 import PO4_5_TERMINAL_CONDITION_DISTANCE_MEASURE

from lib.PO4_terminal_C3 import PO4_3_TERMINAL_ATOM_NAMES
from lib.PO4_terminal_C3 import PO4_3_TERMINAL_ATOM_RES
from lib.PO4_terminal_C3 import PO4_3_TERMINAL_CONDITION
from lib.PO4_terminal_C3 import PO4_3_TERMINAL_RESTRAINS
from lib.PO4_terminal_C3 import PO4_3_TERMINAL_DISTANCE_MEASURE
from lib.PO4_terminal_C3 import PO4_3_TERMINAL_CONDITION_DISTANCE_MEASURE

RES_NAMES = ['A', 'C', 'G', 'U', 'DA', 'DC', 'DG', 'DT']


class DistanceMeasure(object):
    """Compare restraint target values with values measured on atoms.

    ``measure`` is the name of a method of this class (for example
    ``'euclidean'``) that compares two equally long vectors.
    ``restraint_names`` lists the names of the restraints that take part
    in the comparison; all other restraints are ignored by ``distance``.
    """

    def __init__(self, measure, restraint_names):
        self.restraint_names = restraint_names
        # Look the metric up by name; an unknown name raises AttributeError.
        self.measure = getattr(self, measure)

    @staticmethod
    def _select_atoms(restraint, atoms, count):
        """Return the first ``count`` atoms named by ``restraint.atom_names``."""
        names = restraint.atom_names
        return [atoms[names[position]] for position in range(count)]

    @classmethod
    def euclidean(cls, vector1, vector2):
        """Return the Euclidean distance between two vectors.

        Returns None when either vector is empty and raises Exception when
        the vectors differ in length.
        """
        size1, size2 = len(vector1), len(vector2)
        if not size1 or not size2:
            return None
        if size1 != size2:
            raise Exception('uneven number of elements')

        deltas = [left - right for left, right in zip(vector1, vector2)]
        squared_total = 0.0
        for delta in deltas:
            squared_total += delta * delta
        return math.sqrt(squared_total)

    def atoms_dist(self, restraint, atoms):
        """Measure the distance between the two atoms of ``restraint``."""
        first, second = self._select_atoms(restraint, atoms, 2)
        return first.dist(second)

    def atoms_angle(self, restraint, atoms):
        """Measure the angle defined by the three atoms of ``restraint``."""
        first, second, third = self._select_atoms(restraint, atoms, 3)
        return first.angle(second, third)

    def atoms_torsion(self, restraint, atoms):
        """Measure the torsion defined by the four atoms of ``restraint``.

        The measured value is passed through ``ConditionItem.fix_torsion``.
        """
        first, second, third, fourth = self._select_atoms(restraint, atoms, 4)
        return ConditionItem.fix_torsion(first.torsion(second, third, fourth))

    def distance(self, conditional_restraint, atoms):
        """Compare target and measured values of the selected restraints.

        ``conditional_restraint`` is an iterable of items exposing ``name``,
        ``type``, ``value`` and ``atom_names``; ``atoms`` maps atom names to
        atom objects. Only items whose name is in ``self.restraint_names``
        contribute. The result is whatever ``self.measure`` returns.

        NOTE(review): torsion values are compared as plain numbers, so
        periodicity (e.g. 179 vs -179) is not taken into account - confirm
        this is intended.
        """
        expected = []
        measured = []
        for item in conditional_restraint:
            if item.name not in self.restraint_names:
                continue
            expected.append(item.value)
            # Dispatch on the restraint type: atoms_dist / atoms_angle / atoms_torsion.
            evaluate = getattr(self, 'atoms_%s' % item.type)
            measured.append(evaluate(item, atoms))
        return self.measure(expected, measured)

class ConditionalRestraintItem(object):
    """Restraint template that refers to its atoms by name.

    Stores the restraint type, its name, the names of the atoms involved,
    the target value and its sigma.
    """

    def __init__(self, restraint_type, name, atom_names, value, sigma):
        self.name = name
        self.type = restraint_type
        self.atom_names = atom_names
        self.value, self.sigma = value, sigma

    def get_restraint(self, atom_map):
        """Build a ``Restraint`` by looking every atom name up in ``atom_map``.

        Raises KeyError when ``atom_map`` lacks one of the atom names.
        """
        resolved = []
        for label in self.atom_names:
            resolved.append(atom_map[label])
        return Restraint(self.type, resolved, self.value, self.sigma)


class ConditionItem(object):
    """Condition that a measured geometric value must satisfy.

    Only the ``'torsion'`` condition type is implemented: the torsion of
    four named atoms must lie within ``multiplier * sigma`` of ``value``.
    The target value is normalised with ``fix_torsion`` on construction.
    """

    # Half-width of the acceptance window, expressed in sigmas.
    multiplier = 4

    def __init__(self, condition_type, name, atom_names, value, sigma):
        self.type = condition_type
        self.name = name
        self.atom_names = atom_names
        self.value = self.fix_torsion(value)
        self.sigma = sigma

    @classmethod
    def fix_torsion(cls, value):
        """Map an angle in degrees onto the range [-180, 180].

        Values already inside the range (both ends included) are returned
        unchanged. Anything else is wrapped by whole turns, so 370 becomes
        10 and -190 becomes 170. The previous arithmetic was only correct
        within one turn of the range and relied on integer division.
        """
        if -180 <= value <= 180:
            return value
        # Python's % takes the sign of the divisor, so the result of the
        # modulo is in [0, 360) and the final value in [-180, 180).
        return (value + 180) % 360 - 180

    def check_condition(self, atoms):
        """Return True when the measured torsion satisfies this condition.

        ``atoms`` maps atom names to atom objects exposing
        ``torsion(atom1, atom2, atom3)``.

        Raises Exception for a torsion condition that does not name exactly
        four atoms, and for any condition type other than ``'torsion'``.
        """
        if self.type != 'torsion':
            raise Exception('Unknown condition type')
        if len(self.atom_names) != 4:
            raise Exception('Wrong number of atoms for torsion condition')

        atom0 = atoms[self.atom_names[0]]
        atom1 = atoms[self.atom_names[1]]
        atom2 = atoms[self.atom_names[2]]
        atom3 = atoms[self.atom_names[3]]
        torsion = self.fix_torsion(atom0.torsion(atom1, atom2, atom3))

        tolerance = self.multiplier * self.sigma
        lower = self.value - tolerance
        upper = self.value + tolerance
        # Also try the torsion shifted by a full turn so that windows that
        # straddle the +/-180 boundary are matched.
        for shift in (0, 360, -360):
            if lower <= torsion + shift <= upper:
                return True
        return False


class ConditionalRestraint(object):
    """Named set of restraints that applies when all its conditions hold.

    ``conditions`` and ``restraints`` may contain ready-made
    ``ConditionItem`` / ``ConditionalRestraintItem`` objects or argument
    sequences from which those objects are constructed.
    """

    def __init__(self, name, conditions, restraints):
        self.name = name

        self.conditions = []
        for entry in conditions:
            if not isinstance(entry, ConditionItem):
                entry = ConditionItem(*entry)
            self.conditions.append(entry)

        self.restraints = []
        for entry in restraints:
            if not isinstance(entry, ConditionalRestraintItem):
                entry = ConditionalRestraintItem(*entry)
            self.restraints.append(entry)

    def check_conditions(self, atoms):
        """Return True when every condition is satisfied by ``atoms``.

        A restraint named ``'default'`` or one without conditions always
        passes.
        """
        if self.name == 'default' or not self.conditions:
            return True
        return all(condition.check_condition(atoms) is not False
                   for condition in self.conditions)

    def get_restraints(self, atoms):
        """Return the restraints of this set resolved against ``atoms``."""
        return [template.get_restraint(atoms) for template in self.restraints]


class ConditionalRestraintList(list):
    """List of ``ConditionalRestraint`` objects with selection helpers.

    Items passed to the constructor that are not ``ConditionalRestraint``
    instances are treated as keyword-argument mappings and converted.
    """

    def __init__(self, data=()):
        super(ConditionalRestraintList, self).__init__()
        for item in data:
            if not isinstance(item, ConditionalRestraint):
                item = ConditionalRestraint(**item)
            self.append(item)

    def __getslice__(self, i, j):
        # Python 2 slice hook: keep the subclass type for simple slices.
        return ConditionalRestraintList(list.__getslice__(self, i, j))

    def get_fesible(self, atoms):
        """Return a new list with the entries whose conditions hold for ``atoms``."""
        feasible = ConditionalRestraintList()
        for conditional_restraint in self:
            if conditional_restraint.check_conditions(atoms) is True:
                feasible.append(conditional_restraint)
        return feasible

    # Correctly spelled alias; the original name is kept for existing callers.
    get_feasible = get_fesible

    def get_names(self):
        """Return the names of all entries, in list order."""
        return [conditional_restraint.name for conditional_restraint in self]

    def remove_deault(self):
        """Remove every entry named ``'default'`` from this list, in place."""
        # Rebuild instead of popping by index: popping ascending indices
        # shifts the remaining items and removed the wrong entries.
        self[:] = [entry for entry in self if entry.name != 'default']

    # Correctly spelled alias; the original name is kept for existing callers.
    remove_default = remove_deault

    def get_default(self):
        """Return the first entry named ``'default'``, or None."""
        for conditional_restraint in self:
            if conditional_restraint.name == 'default':
                return conditional_restraint
        return None

    def find_closest(self, atoms, distance_measure, variable):
        """Return the entry closest to the measured geometry, or None.

        ``variable`` names the attribute of each entry (``'restraints'`` or
        ``'conditions'``) that is handed to ``distance_measure.distance``
        together with ``atoms``. When the list holds a ``'default'`` entry
        next to other entries, the default is left out of the comparison.
        The list itself is not modified.

        Entries for which no distance can be computed (the measure returns
        None) are only chosen when no entry has a numeric distance; in that
        case the first such entry is returned.
        """
        candidates = list(self)
        if len(candidates) > 1:
            specific = [entry for entry in candidates if entry.name != 'default']
            if specific:
                candidates = specific

        closest = None
        closest_distance = float('inf')
        unmeasured = None
        for candidate in candidates:
            distance = distance_measure.distance(getattr(candidate, variable), atoms)
            if distance is None:
                if unmeasured is None:
                    unmeasured = candidate
                continue
            # Strict comparison: the first of several equally close entries wins.
            if distance < closest_distance:
                closest = candidate
                closest_distance = distance

        if closest is not None:
            return closest
        return unmeasured

    def find_restraint_closest(self, atoms, distance_measure):
        """Return the entry whose restraint targets are closest to ``atoms``."""
        return self.find_closest(atoms, distance_measure, 'restraints')

    def find_condition_closest(self, atoms, distance_measure):
        """Return the entry whose condition targets are closest to ``atoms``."""
        return self.find_closest(atoms, distance_measure, 'conditions')

class Restraint(object):
    """Plain record of one restraint: its type, atoms, target value and sigma."""

    def __init__(self, restraint_type, atoms, value, sigma):
        self.type, self.atoms = restraint_type, atoms
        self.value, self.sigma = value, sigma


class MonomerRestraintGroup(object):
    """Collects atoms of one restraint library and emits their restraints.

    Atoms are added with ``add_atom``, grouped per residue (and alternate
    location) by ``create_res_groups``, filtered by ``validate_links`` and
    finally turned into restraints by ``atom_restraints``.
    ``print_restraints`` runs the whole pipeline.

    Parameters:
        name: label used in the progress messages.
        res_names: residue names accepted by this group.
        atom_labels: mapping whose keys are the accepted atom names.
        res_numbers: mapping from atom name to the offset subtracted from
            the atom's residue id when grouping atoms.
        conditions: iterable of ``(atom_name_1, atom_name_2, max_distance)``
            triples every valid group has to satisfy.
        restraints: a ``ConditionalRestraintList`` or data to build one.
        distance_measure, condition_measure: keyword arguments for
            ``DistanceMeasure``.
    """

    def __init__(self, name, res_names, atom_labels, res_numbers, conditions, restraints, distance_measure, condition_measure):
        self.name = name
        self.res_names = res_names
        self.atom_labels = atom_labels
        self._valid_atom_labels = list(atom_labels.keys())
        self.res_numbers = res_numbers
        self.conditions = conditions
        self.distance_measure = DistanceMeasure(**distance_measure)
        self.condition_measure = DistanceMeasure(**condition_measure)

        if isinstance(restraints, ConditionalRestraintList):
            self.restraints = restraints
        else:
            self.restraints = ConditionalRestraintList(restraints)
        self.atoms = []
        # Maps a group key ("<chain>_<residue id>" with an optional
        # "_<altloc>" suffix) to a list of indexes into self.atoms.
        self.groups = defaultdict(list)

    def is_valid_res_name(self, res_name):
        """Return True when the stripped residue name belongs to this group."""
        return res_name.strip() in self.res_names

    def is_valid_atom_name(self, atom_name):
        """Return True when the stripped atom name belongs to this group."""
        return atom_name.strip() in self._valid_atom_labels

    def add_atom(self, chain_id, res_id, res_name, atom_name, alt_loc, atom_xyz):
        """Store the atom if both its residue name and atom name are accepted."""
        if self.is_valid_res_name(res_name) and self.is_valid_atom_name(atom_name):
            self.atoms.append(Atom(chain_id, res_id, res_name, atom_name, alt_loc, atom_xyz))

    def create_res_groups(self):
        """Fill ``self.groups`` with atom indexes grouped per residue.

        Atoms are first grouped by chain and by residue id minus the
        per-atom offset from ``res_numbers``. When any atom of such a group
        carries an alternate location code, one group per code is created;
        atoms without a code are shared by all of them.
        """
        preliminary_groups = defaultdict(list)

        for i_atom, atom in enumerate(self.atoms):
            res_id = atom.res_id - self.res_numbers[atom.atom_name]
            key = "{}_{:05d}".format(atom.chain_id, res_id)
            preliminary_groups[key].append(i_atom)

        for key, i_atom_list in preliminary_groups.items():
            # Computed once per group (it used to be recomputed per atom).
            locs = set(atom.alt_loc for atom in self.get_atoms(i_atom_list))
            locs.discard('')

            if locs:
                # at least one atom with alt_loc code
                for loc in locs:
                    key_alt = "{}_{}".format(key, loc)
                    for i_atom in i_atom_list:
                        if self.atoms[i_atom].alt_loc in (loc, ''):
                            self.groups[key_alt].append(i_atom)
            else:
                # no alt_loc codes
                self.groups[key].extend(i_atom_list)

    def _print_groups(self):
        """Debug helper: print every atom of every group on its own line."""
        for key, atoms_indices in self.groups.items():
            for i_atom in atoms_indices:
                atom = self.atoms[i_atom]
                fields = (self.name, key, atom.chain_id, atom.res_id, atom.res_name,
                          atom.atom_name, atom.alt_loc, atom.atom_xyz)
                print(" ".join(str(field) for field in fields))

    def get_atoms(self, i_atom_list):
        """Return the atoms for a list of indexes, or for the values of a dict."""
        if isinstance(i_atom_list, dict):
            return [self.atoms[i_atom] for i_atom in i_atom_list.values()]
        return [self.atoms[i_atom] for i_atom in i_atom_list]

    @classmethod
    def find_atom(cls, atoms, atom_name):
        """Return the first atom called ``atom_name``, or None."""
        for atom in atoms:
            if atom.atom_name == atom_name:
                return atom
        return None

    def is_valid_atom_group(self, i_atom_list):
        """Check that the group satisfies every distance condition.

        A condition fails when either atom is missing from the group or
        when the two atoms are further apart than the given distance.
        """
        atoms = self.get_atoms(i_atom_list)

        for atom_name_1, atom_name_2, dist in self.conditions:
            atom_1 = self.find_atom(atoms, atom_name_1)
            atom_2 = self.find_atom(atoms, atom_name_2)

            if atom_1 is None or atom_2 is None or atom_1.dist(atom_2) > dist:
                return False
        return True

    def validate_links(self):
        """
        Deletes groups that are not bonded correctly
        """
        # Iterate over a snapshot of the keys: entries are deleted in the loop.
        for key in list(self.groups.keys()):
            if not self.is_valid_atom_group(self.groups[key]):
                del self.groups[key]
                print("# {} ignoring {}".format(self.name, key))

    def get_atom(self, atom_proxies, atom_name):
        """Return the atom that ``atom_proxies`` maps ``atom_name`` to."""
        return self.atoms[atom_proxies[atom_name]]

    def _map_atom_proxies(self, atom_proxies):
        """Turn a name -> index mapping into a name -> atom mapping."""
        return {key: self.atoms[val] for key, val in atom_proxies.items()}

    def atom_restraints(self):
        """Return the restraints selected for every group, in key order.

        For each group the feasible restraint set closest to the measured
        geometry is used. If none is feasible, the ``'default'`` set is
        used; if there is no default, the set whose conditions are closest
        to the measured geometry is used. A group for which nothing can be
        selected is reported and skipped.
        """
        result_restraints = []
        for atom_group_key in sorted(self.groups.keys()):
            atom_proxies = {self.atoms[i_atom].atom_name: i_atom for i_atom in self.groups[atom_group_key]}

            atoms = self._map_atom_proxies(atom_proxies)

            feasible_restraints = self.restraints.get_fesible(atoms)
            closest_restraint = feasible_restraints.find_restraint_closest(atoms, self.distance_measure)

            if closest_restraint is not None:
                print("# {} group {} recognized as {}".format(self.name, atom_group_key, closest_restraint.name))
            else:
                print("# {} group {} not feasible".format(self.name, atom_group_key))

                closest_restraint = self.restraints.get_default()

                if closest_restraint is None:
                    print("# {} group {} default not provided".format(self.name, atom_group_key))
                    closest_restraint = self.restraints.find_condition_closest(atoms, self.condition_measure)

                if closest_restraint is None:
                    # Nothing to fall back on (e.g. an empty restraint list);
                    # this used to fail with AttributeError on ``.name``.
                    print("# {} group {} skipped, no restraints available".format(self.name, atom_group_key))
                    continue
                print("# {} group {} closest to {}".format(self.name, atom_group_key, closest_restraint.name))

            result_restraints.extend(closest_restraint.get_restraints(atoms))

        return result_restraints

    def print_restraints(self, printer_class=RefmacPrinter):
        """Group the atoms, drop invalid groups and format the restraints.

        Returns whatever ``printer_class.print_restraints`` returns for the
        selected restraints.
        """
        self.create_res_groups()
        self.validate_links()
        print("###############")

        restraints = self.atom_restraints()
        return printer_class.print_restraints(restraints)

def get_info(stream, restraint_list):
    """Write the RestraintLib banner to ``stream``.

    ``stream`` is any object with a ``write`` method; None means standard
    output, which is what the former ``print >> stream`` statements did.
    ``restraint_list`` is an iterable of objects with a ``name`` attribute;
    the names are listed in the banner.

    The text is written with ``stream.write`` because the ``print >>``
    chevron form only exists in Python 2.
    """
    if stream is None:
        stream = sys.stdout

    border = '###########################################################################'
    blank = "#                                                                         #"
    libs = ", ".join(str(res.name) for res in restraint_list)

    lines = [
        border,
        '#                   RestraintLib version 2016.0.1                         #',
        border,
        blank,
        "# Kowiel, M., Brzezinski, D. & Jaskolski, M.(2016)                        #",
        "# Conformation-dependent restraints for polynucleotides I:                #",
        "# Clustering of the geometry of the phosphodiester group, to be published.#",
        blank,
        border,
        blank,
        # Pad the list of libraries so the closing '#' lines up with the frame.
        "# Restraints for: {}{}#".format(libs, ' '*(77-21-len(libs))),
        blank,
        border,
    ]
    stream.write("\n".join(lines) + "\n")


def parse_pdb(in_pdb, restraint_groups, out_filename):
    """Read a PDB file and write the restraints generated for it.

    Parameters:
        in_pdb: path of the PDB file to read.
        restraint_groups: restraint group objects; every atom of a residue
            accepted by a group is offered to it, and each group is then
            asked for its restraint text.
        out_filename: path of the output file (any string), or an object
            with a ``write`` method; None means standard output.

    NOTE(review): atoms of all models in the file are fed to the same
    groups - confirm this is intended for multi-model files.
    """
    data_pdb = iotbx.pdb.input(file_name=in_pdb)
    pdb_hierarchy = data_pdb.construct_hierarchy()
    for model in pdb_hierarchy.models():
        for chain in model.chains():
            for rg in chain.residue_groups():
                for ag in rg.atom_groups():
                    # The residue-name check is only a speed optimization;
                    # add_atom performs the same check again.
                    for restraint in restraint_groups:
                        if restraint.is_valid_res_name(ag.resname) is True:
                            altloc = ag.altloc.strip()
                            for atom in ag.atoms():
                                restraint.add_atom(chain.id, rg.resid(), ag.resname, atom.name, altloc, atom.xyz)

    restraint_chunks = []
    for restraint in restraint_groups:
        restraint_text = restraint.print_restraints(RefmacPrinter)
        if len(restraint_text) > 0:
            restraint_chunks.append(restraint_text)

    if not restraint_chunks:
        restraint_chunks.append("#There were no restraints to be created based on the submitted PDB file")

    restraint_text_all = "\n".join(restraint_chunks)

    # ``basestring`` covers str and unicode on Python 2; it does not exist
    # on Python 3, where str is the only text type.
    try:
        string_types = (basestring,)
    except NameError:
        string_types = (str,)

    if isinstance(out_filename, string_types):
        with open(out_filename, 'w') as res_file:
            get_info(res_file, restraint_groups)
            res_file.write(restraint_text_all + "\n")
    else:
        # ``print >> None`` used to write to standard output; keep that.
        stream = sys.stdout if out_filename is None else out_filename
        get_info(stream, restraint_groups)
        stream.write(restraint_text_all + "\n")


def run():
    """Command-line entry point.

    Uses the first two command-line arguments as input PDB path and output
    path; when fewer than two are given, falls back to ``in.pdb`` and
    ``restraints.txt``. Restraints are generated for the PO4, PO4 C5'
    terminal and PO4 C3' terminal libraries.
    """
    in_pdb, out_filename = 'in.pdb', 'restraints.txt'
    if len(sys.argv) > 2:
        in_pdb, out_filename = sys.argv[1], sys.argv[2]

    # (label, atom names, residue offsets, conditions, restraints,
    #  restraint distance measure, condition distance measure)
    library = (
        ('PO4', PO4_ATOM_NAMES, PO4_ATOM_RES, PO4_CONDITION,
         PO4_RESTRAINS, PO4_DISTANCE_MEASURE, PO4_CONDITION_DISTANCE_MEASURE),
        ('PO4_terminal_C5', PO4_5_TERMINAL_ATOM_NAMES, PO4_5_TERMINAL_ATOM_RES, PO4_5_TERMINAL_CONDITION,
         PO4_5_TERMINAL_RESTRAINS, PO4_5_TERMINAL_DISTANCE_MEASURE, PO4_5_TERMINAL_CONDITION_DISTANCE_MEASURE),
        ('PO4_terminal_C3', PO4_3_TERMINAL_ATOM_NAMES, PO4_3_TERMINAL_ATOM_RES, PO4_3_TERMINAL_CONDITION,
         PO4_3_TERMINAL_RESTRAINS, PO4_3_TERMINAL_DISTANCE_MEASURE, PO4_3_TERMINAL_CONDITION_DISTANCE_MEASURE),
    )

    restraint_list = []
    for label, atom_names, atom_res, condition, restraints, measure, condition_measure in library:
        group = MonomerRestraintGroup(label, RES_NAMES, atom_names, atom_res, condition,
                                      restraints, measure, condition_measure)
        restraint_list.append(group)

    parse_pdb(in_pdb, restraint_list, out_filename)

# NOTE(review): called unconditionally, so importing this module also runs
# the tool; consider an ``if __name__ == '__main__':`` guard.
run()
