# coding: utf-8

import os
import shutil
import json

import pandas as pd
import uuid
import plot
import itertools as it
import calculate_distances as cd
import numpy as np
from operator import attrgetter


from sklearn.externals import joblib

__author__ = "Marcin Kowiel, Dariusz Brzezinski"


def cleanup_and_prepare_folders(data_folder, binding_num, clustering_results_folder, clustering_results_file,
                                manual_clustering_folder, manual_clustering):
    """
    Removes outdated result folders/files and recreates an empty working layout.

    The per-binding folder inside the data folder, the clustering results folder and the manual clustering folder are
    deleted if they exist; the two result folders are then recreated, the results summary file is removed, and the
    manual clustering file plus the two reference restraint files are copied from the data folder into the manual
    clustering folder.
    :param data_folder: source data folder
    :param binding_num: number of bindings (names the sub-folder of data_folder that gets removed)
    :param clustering_results_folder: clustering results folder
    :param clustering_results_file: clustering results file
    :param manual_clustering_folder: folder with visualizations of manual clustering
    :param manual_clustering: manual cluster assignment csv file (name relative to data_folder)
    """
    print("Performing cleanup...")
    binding_folder = os.path.join(data_folder, str(binding_num))

    # wipe everything that may hold results of a previous run
    for stale_folder in (binding_folder, clustering_results_folder, manual_clustering_folder):
        if os.path.exists(stale_folder):
            print("clean %s folder" % stale_folder)
            shutil.rmtree(stale_folder)

    # recreate the (now empty) output folders
    for output_folder in (clustering_results_folder, manual_clustering_folder):
        if not os.path.exists(output_folder):
            os.makedirs(output_folder)

    if os.path.isfile(clustering_results_file):
        os.remove(clustering_results_file)

    # reference files needed next to the manual clustering visualizations
    for reference_file in (manual_clustering, "gelbin.restraints.json", "kowiel.restraints.json"):
        shutil.copyfile(os.path.join(data_folder, reference_file),
                        os.path.join(manual_clustering_folder, reference_file))


def select_best_clustering(prefix, binding_num, clustering_results_folder, clustering_results_file,
                           max_outliers_pct=0.1):
    """
    Selects the best out of a set of clusterings based on the Silhouette coefficient while restraining the maximum
    number of outliers.

    The results summary is read as a ";"-separated csv and must contain the columns "Silhouette", "Examples",
    "No of outliers", "Algorithm", "Params" and "No of clusters". Rows are visited from the highest Silhouette value
    downwards and the first one with strictly fewer outliers than max_outliers_pct * Examples is selected.
    :param prefix: clustering prefix
    :param binding_num: number of bindings
    :param clustering_results_folder: results folders
    :param clustering_results_file: results summary filename
    :param max_outliers_pct: maximum allowed fraction of outliers (e.g. 0.1 for 10% of the examples)
    :return: folder with best clustering, best clustering csv filename, best number of clusters
    :raises ValueError: if no clustering satisfies the outlier constraint
    """
    clustering_results = pd.read_csv(clustering_results_file, sep=";").sort_values(by="Silhouette", ascending=False)

    for _, row in clustering_results.iterrows():
        max_outliers = max_outliers_pct * row["Examples"]

        if row["No of outliers"] < max_outliers:
            selected_clustering_folder = os.path.join(clustering_results_folder, row["Algorithm"])
            selected_clustering = (prefix + row["Algorithm"] + "_" + str(binding_num) + "bindings_" + row["Params"]
                                   + ".csv")
            return selected_clustering_folder, selected_clustering, row["No of clusters"]

    # previously this fell through to a return of never-assigned variables (UnboundLocalError)
    raise ValueError("No clustering in %s has fewer outliers than %s of its examples"
                     % (clustering_results_file, max_outliers_pct))


def calculate_cluster_statistics(cluster_members, binding_num, label, min_col=None):
    """
    Calculates cluster statistics.

    Column-wise mean and standard deviation are computed over the columns at position min_col and beyond. The mean
    series additionally gets a "Cluster size" entry holding the number of rows in the cluster.
    :param cluster_members: data frame with the cluster members
    :param binding_num: binding num
    :param label: cluster label; used to name the returned series "mean <label>" and "std <label>"
    :param min_col: position of the first column to take into summary; default binding_num+1
    :return: cluster mean values (plus "Cluster size"), cluster standard deviations
    """
    if min_col is None:
        min_col = binding_num + 1

    summarized_columns = cluster_members.iloc[:, min_col:]

    mean = pd.Series(summarized_columns.mean(axis=0), name="mean " + str(label))
    # .loc enlargement replaces Series.set_value, which no longer exists in pandas >= 1.0
    mean.loc["Cluster size"] = len(cluster_members)
    std = pd.Series(summarized_columns.std(axis=0), name="std " + str(label))

    return mean, std


def create_clustering_summary(clustered_data_df, result_folder, result_filename, binding_num, feasible_permutations,
                              serialize_as_restraints=False, restraint_type_mapping=None, restraint_atom_mapping=None,
                              condition_type_mapping=None, condition_atom_mapping=None, permutation_mapping=None,
                              restraint_subname=None, min_col=None):
    """
    Creates a csv summary of a clustering. This method is also capable of serializing clusters as restraints JSON.

    For every cluster label other than -1 the mean and standard deviation of the summarized columns are computed and
    placed side by side; the per-cluster tables are concatenated column-wise and written to
    "<result_folder>/<result_filename>.summary.csv" (";" separator, "," decimal mark). When no cluster is found the
    csv is not written and a notice is printed instead.
    :param clustered_data_df: data frame with the clustered data; must contain a "group" column with cluster labels
    :param result_folder: folder with clustering data
    :param result_filename: clustering name
    :param binding_num: number of bindings
    :param feasible_permutations: allowed permutations
    :param serialize_as_restraints: if True serializes the cluster summaries as JSON to
    "<result_folder>/<result_filename>.restraints.json"
    :param restraint_type_mapping: restraint type mapping for serialization
    :param restraint_atom_mapping: restraint atom mapping for serialization
    :param condition_type_mapping: condition type mapping for serialization
    :param condition_atom_mapping: condition atom mapping for serialization
    :param permutation_mapping: permutation label mapping for serialization
    :param restraint_subname: additional suffix in the restraint name
    :param min_col: min column number to take into summary
    """
    cluster_summaries = list()
    restraints = list()
    labels = clustered_data_df.loc[:, "group"]

    for label in set(labels):
        # label -1 denotes outliers
        if label == -1:
            continue

        class_members = clustered_data_df.loc[labels == label].copy()
        mean, std = calculate_cluster_statistics(class_members, binding_num, label, min_col)
        cluster_summary = pd.concat([mean, std], axis=1)
        cluster_summaries.append(cluster_summary)

        if serialize_as_restraints:
            restraints.extend(create_conditional_restraint(cluster_summary, binding_num, restraint_type_mapping,
                                                           restraint_atom_mapping, condition_type_mapping,
                                                           condition_atom_mapping, permutation_mapping,
                                                           feasible_permutations, restraint_subname))

    # Write the summary once, after all clusters are collected. Writing inside the loop rewrote the file on every
    # iteration and crashed in pd.concat when the outlier label happened to be visited first.
    summary_file_name = os.path.join(result_folder, result_filename) + ".summary.csv"
    if cluster_summaries:
        pd.concat(cluster_summaries, axis=1).to_csv(summary_file_name, sep=";", decimal=",")
    else:
        print("No clusters to summarize, skipping %s" % summary_file_name)

    if serialize_as_restraints:
        serialized_file_name = os.path.join(result_folder, result_filename) + ".restraints.json"
        print("Serializing restraints to %s" % serialized_file_name)
        with open(serialized_file_name, "wt") as restraint_file:
            restraint_file.write(json.dumps(restraints, indent=4, sort_keys=True, cls=RestraintEncoder))
            restraint_file.write("\n")


def align_data(data_folder, data_filename, result_folder, result_filename, binding_num, torsion_angles,
               inverse_permutations):
    """
    Aligns angles and distance columns of clustered data to medoids. This is required to get a meaningful comparison of
    angles within clusters.

    The medoid of a cluster is the member with the smallest sum of squared distances to the other members of that
    cluster (stored in an extra "inertia" column). Rows labelled -1 (outliers) are left out of the result.
    :param data_folder: folder with dataset
    :param data_filename: dataset name
    :param result_folder: folder with clustering data
    :param result_filename: clustering name
    :param binding_num: number of bindings
    :param torsion_angles: list of torsion angle attribute names
    :param inverse_permutations: permutations which require changing the sign of torsion angles
    :return: a data frame with all the clustered data (with additional "index", "group" and "inertia" columns) with
    angles and binding lengths aligned to cluster medoids; the data frame with the cluster assignments as read from
    the clustering result file
    """
    # load the dataset and cluster assignments
    data_df = pd.read_csv(os.path.join(data_folder, data_filename), sep=",", header=0, index_col=None)
    groups_df = pd.read_csv(os.path.join(result_folder, result_filename), sep=",", header=0, names=["index", "group"])
    clustered_data_df = pd.concat([data_df, groups_df], axis=1)

    labels = clustered_data_df.loc[:, "group"]
    columns = list(clustered_data_df.columns.values)
    # NOTE(review): joblib.load unpickles its input - only point this at trusted .odm/.cdm files
    order_dm = joblib.load(os.path.join(data_folder, str(binding_num), data_filename + ".odm"))
    distance_df = pd.DataFrame(joblib.load(os.path.join(data_folder, str(binding_num), data_filename + ".cdm")))

    # column positions in the un-permuted order; assumes cd.angle_cols/cd.binding_cols are pure - TODO confirm
    default_angle_cols = cd.angle_cols(binding_num)
    default_binding_cols = cd.binding_cols(binding_num)

    # align each cluster element to its medoid
    aligned_clusters = []

    for l in set(labels):
        # label -1 denotes outliers
        if l != -1:
            class_members = clustered_data_df.loc[labels == l].copy()
            class_member_distances = distance_df.loc[labels == l, labels == l]
            class_members.loc[:, "inertia"] = (class_member_distances ** 2).sum(axis=1)
            medoid_idx = class_members["inertia"].idxmin()

            for index, row in class_members.iterrows():
                # presumably the permutation that best matches this row to the medoid - verify against the .odm writer
                permutation = order_dm[medoid_idx][index]
                aligned_order = cd.angle_cols(binding_num, permutation)

                # check if angles need reordering
                if aligned_order != default_angle_cols:
                    # take a snapshot first: values are moved between columns of the same row
                    new_angle_values = class_members.loc[index].iloc[aligned_order].copy()
                    for i, col in enumerate(default_angle_cols):
                        class_members.at[index, columns[col]] = new_angle_values.iloc[i]

                    aligned_binding_order = cd.binding_cols(binding_num, permutation)
                    new_binding_values = class_members.loc[index].iloc[aligned_binding_order].copy()
                    for i, col in enumerate(default_binding_cols):
                        class_members.at[index, columns[col]] = new_binding_values.iloc[i]

                    # change sign of the torsion angles if necessary
                    for ip in inverse_permutations:
                        if aligned_order == cd.angle_cols(binding_num, ip):
                            for ta in torsion_angles:
                                class_members.at[index, ta] = -class_members.at[index, ta]

                # align torsion angle value to medoid (to avoid strange mean values)
                for ta in torsion_angles:
                    align_torsion_angle_to_medoid(class_members, medoid_idx, index, ta)
            aligned_clusters.append(class_members)

    return pd.concat(aligned_clusters, axis=0), groups_df


def align_torsion_angle_to_medoid(class_members, medoid_idx, row_idx, torsion_angle_column):
    """
    Aligns torsion angles to medoid torsion angles so that values of different signs don't mess with the torsion angle
    mean. Effectively this function changes angles like -170 to 190 and 170 to -190.

    The value is shifted by 360 only when the medoid angle is positive and the row value is below -135, or the medoid
    angle is negative and the row value is above 135. The data frame is modified in place.
    :param class_members: clustered objects
    :param medoid_idx: medoid index
    :param row_idx: selected object index
    :param torsion_angle_column: name of the torsion angle column in the data frame
    """
    row_value = class_members.loc[row_idx, torsion_angle_column]
    medoid_sign = np.sign(class_members.loc[medoid_idx, torsion_angle_column])

    # .at replaces DataFrame.set_value, which no longer exists in pandas >= 1.0
    if medoid_sign == 1 and row_value < -135:
        class_members.at[row_idx, torsion_angle_column] = row_value + 360
    elif medoid_sign == -1 and row_value > 135:
        class_members.at[row_idx, torsion_angle_column] = row_value - 360


def load_json_restraints(filename):
    """
    Loads a restraints JSON file into a nested dictionary.
    :param filename: path to a JSON file holding a list of entries with "name", "restraints" and "conditions" keys;
    every restraint/condition is a sequence whose element 1 is used as key, element 3 as mean and element 4 as std
    :return: dictionary mapping each entry name to {"restraints": {...}, "conditions": {...}}, where the inner
    dictionaries map a definition key to {"mean": ..., "std": ...}
    """
    def summarize(definitions):
        # later definitions with the same key overwrite earlier ones
        return dict((definition[1], {"mean": definition[3], "std": definition[4]}) for definition in definitions)

    with open(filename) as json_file:
        entries = json.load(json_file)

    restraints_by_name = dict()
    for entry in entries:
        restraint_summary = summarize(entry["restraints"])
        condition_summary = summarize(entry["conditions"])
        restraints_by_name[entry["name"]] = {"restraints": restraint_summary, "conditions": condition_summary}

    return restraints_by_name


def check_std(clustered_data_df, result_folder, po4_restraints, gelbin_restraints, verbose=False):
    """
    Compares clustered observations against two restraint libraries and plots the differences.

    For each library and each row of the data, the candidate library clusters for the row's group are looked up, the
    one with the smallest sum of squared differences over the six O-P-O angle means is chosen, and the absolute
    difference between the row value and that cluster's mean is recorded for every angle and bond length column.
    Two comparison histograms (angles and bond lengths) are then produced through the plot module.
    :param clustered_data_df: data frame with "NAME", "group" and the angle/bond length columns; "group" must hold one
    of the six string labels known to the cluster mappers (surrounding whitespace is ignored for the lookup)
    :param result_folder: folder that holds the restraint JSON files and receives the generated output
    :param po4_restraints: file name of the restraints JSON produced by this work
    :param gelbin_restraints: file name of the Gelbin et al. restraints JSON
    :param verbose: if True, additionally generates per-group histograms and writes "<library>_std.csv" files
    """
    gelbin = load_json_restraints(os.path.join(result_folder, gelbin_restraints))
    po4 = load_json_restraints(os.path.join(result_folder, po4_restraints))

    po4_cluster_mapper = {
        "AS(+sc/+sc)": ["AS_2", "AS_3"],
        "AS(-sc/-sc)": ["AS_0", "AS_1"],
        "AA(-sc/ap)": ["AA_0"],
        "AA(+sc/ap)": ["AA_3"],
        "AA(ap/-sc)": ["AA_1"],
        "AA(ap/+sc)": ["AA_2"],
    }
    # the Gelbin library is not split by conformation: every group may match any of its four clusters
    gelbin_cluster_mapper = dict((group, ["G_0", "G_1", "G_2", "G_3"]) for group in po4_cluster_mapper)

    angle_keys = ["aO1O2", "aO1O3", "aO1O5", "aO2O3", "aO2O5", "aO3O5"]
    angle_columns = angle_keys + ["aP4O3C3", "aP4O5C5"]
    distance_columns = ["dO1P4", "dO2P4", "dO3P4", "dO5P4"]
    torsion_columns = ["tC3O3P4O5", "tC5O5P4O3"]
    column_keys = angle_columns + distance_columns
    result_dfs = []

    for mapper, json_data, prefix in ((po4_cluster_mapper, po4, "po4"), (gelbin_cluster_mapper, gelbin, "gelbin")):
        result_values = []

        for i_row, row in clustered_data_df.iterrows():
            clusters = mapper[row["group"].strip()]
            dist = list()

            # squared euclidean distance (over the O-P-O angles) to every candidate cluster
            for cluster in clusters:
                dist_val = 0
                for angle in angle_keys:
                    diff = row[angle] - json_data[cluster]["restraints"][angle]["mean"]
                    dist_val += diff*diff
                dist.append(dist_val)

            closest_cluster = clusters[dist.index(min(dist))]

            row_values = [row["NAME"], row["group"], prefix]

            for column in column_keys:
                # absolute deviation of the observation from the restraint mean
                mean_diff = abs(row[column] - json_data[closest_cluster]["restraints"][column]["mean"])
                row_values.append(mean_diff)

            result_values.append(row_values)

        column_names = ["NAME", "group", "restraint"]
        column_names.extend(column_keys)
        result_df = pd.DataFrame(result_values, columns=column_names)
        result_dfs.append(result_df)

        if verbose:
            # fresh list copies are passed so the shared column lists cannot be altered by the plotting code
            plot.generate_histograms(result_df,
                                     list(angle_columns),
                                     list(torsion_columns),
                                     "group", None, result_folder, prefix + "_angle_std.hist",
                                     bin_width=1, x_min=[0]*8, x_max=[7]*8, count_max=[40]*8*5)
            plot.generate_histograms(result_df,
                                     list(distance_columns),
                                     list(torsion_columns),
                                     "group", None, result_folder, prefix + "_dist_std.hist",
                                     bin_width=0.01*1, x_min=[0.01*0] * 4, x_max=[0.01*7] * 4, count_max=[30] * 4 * 5)
            result_df.to_csv(os.path.join(result_folder, prefix + "_std.csv"))

    # result_dfs[0] holds the "po4" results, result_dfs[1] the "gelbin" results (loop order above)
    po4_angle_df = plot.pandas_gather(result_dfs[0], "column", "value", list(angle_columns))
    gelbin_angle_df = plot.pandas_gather(result_dfs[1], "column", "value", list(angle_columns))
    po4_dist_df = plot.pandas_gather(result_dfs[0], "column", "value", list(distance_columns))
    gelbin_dist_df = plot.pandas_gather(result_dfs[1], "column", "value", list(distance_columns))

    # raw strings keep the LaTeX backslashes without relying on invalid escape sequences
    plot.simple_histogram([po4_angle_df["value"], gelbin_angle_df["value"]], 12, (0, 6), ["#80B5D6", "#e74c3c"],
                          ["This work", "Gelbin et al."], r"Angle difference [$^\circ$]", "Count", result_folder,
                          "restrain_angle_comparison.hist")
    plot.simple_histogram([po4_dist_df["value"], gelbin_dist_df["value"]], 12, (0, 0.06), ["#80B5D6", "#e74c3c"],
                          ["This work", "Gelbin et al."], r"Bond length difference [$\AA$]", "Count", result_folder,
                          "restrain_dist_comparison.hist")


def create_conditional_restraint(cluster_summary, binding_num, restraint_type_mapping=None, restraint_atom_mapping=None,
                                 condition_type_mapping=None, condition_atom_mapping=None, permutation_mapping=None,
                                 feasible_permutations=None, restraint_subname=""):
    """
    Builds conditional restraint dictionaries (one per permutation) from a cluster summary.

    Every summary row whose label is present in both restraint mappings becomes a restraint; otherwise, if it is
    present in both condition mappings, it becomes a condition; other rows are ignored. Mappings left as None are
    treated as empty, so the corresponding list stays empty.
    :param cluster_summary: cluster summary data frame: first column means (named "mean <label>"), second column
    standard deviations, indexed by attribute name
    :param binding_num: binding num
    :param restraint_type_mapping: restraint type mapping for serialization
    :param restraint_atom_mapping: restraint atom mapping for serialization
    :param condition_type_mapping: condition type mapping for serialization
    :param condition_atom_mapping: condition atom mapping for serialization
    :param permutation_mapping: permutation label mapping for serialization, keyed by repr(permutation)
    :param feasible_permutations: allowed permutations; defaults to all permutations of range(binding_num)
    :param restraint_subname: additional suffix in the restraint name
    :return: a list of dictionaries (keys "name", "conditions", "restraints"), one per feasible permutation; names
    have the form "<label><restraint_subname>_<permutation number>"
    """
    restraint_list = list()

    # the defaults are None, but the mappings are used with "in" below - fall back to empty mappings
    (restraint_type_mapping, restraint_atom_mapping, condition_type_mapping, condition_atom_mapping,
     permutation_mapping) = [dict() if mapping is None else mapping
                             for mapping in (restraint_type_mapping, restraint_atom_mapping, condition_type_mapping,
                                             condition_atom_mapping, permutation_mapping)]

    if feasible_permutations is None:
        feasible_permutations = it.permutations(range(binding_num))

    if restraint_subname is None:
        restraint_subname = ""

    for permutation_idx, permutation in enumerate(feasible_permutations):
        restraints = list()
        conditions = list()

        for row_idx, row in cluster_summary.iterrows():
            if row_idx in restraint_type_mapping and row_idx in restraint_atom_mapping:
                idx, modifier = _permute_row_index(row_idx, restraint_type_mapping, permutation, permutation_mapping)
                restraints.append(RestraintDefinition(restraint_type_mapping[idx], idx, restraint_atom_mapping[idx],
                                                      modifier*row.iloc[0], row.iloc[1]))
            elif row_idx in condition_type_mapping and row_idx in condition_atom_mapping:
                idx, modifier = _permute_row_index(row_idx, condition_type_mapping, permutation, permutation_mapping)
                conditions.append(RestraintDefinition(condition_type_mapping[idx], idx, condition_atom_mapping[idx],
                                                      modifier*row.iloc[0], row.iloc[1]))

        restraint_dict = dict()
        # the first column is named "mean <label>"; drop the "mean" prefix to recover the cluster label
        restraint_dict["name"] = (cluster_summary.iloc[:, 0].name[4:].strip() + restraint_subname + "_"
                                  + str(permutation_idx))
        restraint_dict["conditions"] = sorted(conditions, key=attrgetter("name"))
        restraint_dict["restraints"] = sorted(restraints, key=attrgetter("name"))
        restraint_list.append(restraint_dict)

    return restraint_list


def _permute_row_index(row_idx, type_mapping, permutation, permutation_mapping):
    """
    Translates a row label according to a permutation.
    :param row_idx: original row label
    :param type_mapping: mapping from row label to restraint type
    :param permutation: permutation; repr(permutation) is the lookup key in permutation_mapping
    :param permutation_mapping: mapping from repr(permutation) to a label translation dictionary, which may also
    carry a "torsion_modifier" entry
    :return: the translated label (or row_idx if no translation is defined) and a multiplier for the mean value:
    the "torsion_modifier" of the permutation when the label was translated and its type is "torsion", otherwise 1
    """
    lookup_key = repr(permutation)

    if lookup_key not in permutation_mapping:
        return row_idx, 1

    relabeling = permutation_mapping[lookup_key]
    if row_idx not in relabeling:
        return row_idx, 1

    is_modified_torsion = "torsion_modifier" in relabeling and type_mapping[row_idx] == "torsion"
    multiplier = relabeling["torsion_modifier"] if is_modified_torsion else 1

    return relabeling[row_idx], multiplier


class RestraintDefinition:
    """
    Single restraint (or condition) entry: type, name, atoms, mean value and sigma.

    Mean and sigma are rounded on construction: to 1 decimal place for "angle" restraints and to 3 decimal places
    for every other restraint type.
    """

    def __init__(self, restraint_type, name, atoms, mean, sigma):
        """
        :param restraint_type: restraint type; "angle" selects the coarser rounding
        :param name: restraint name
        :param atoms: atoms the restraint refers to
        :param mean: mean value (rounded before being stored)
        :param sigma: standard deviation (rounded before being stored)
        """
        self.precision = 1 if restraint_type == "angle" else 3

        self.restraint_type = restraint_type
        self.name = name
        self.atoms = atoms

        digits = self.precision
        self.mean = round(mean, digits)
        self.sigma = round(sigma, digits)

    def get_definition_list(self):
        """
        :return: the definition as a list: [restraint type, name, atoms, mean, sigma]
        """
        definition = [self.restraint_type, self.name, self.atoms]
        definition.extend((self.mean, self.sigma))
        return definition


class RestraintEncoder(json.JSONEncoder):
    """
    JSON encoder that writes every RestraintDefinition as a single-line list, even when the surrounding document is
    indented.

    RestraintDefinition objects are first emitted as unique "@@<hex>@@" placeholder strings; encode() then swaps
    each quoted placeholder for the definition list dumped with the encoder's own options minus "indent". The swap
    happens only in encode(), so output obtained without going through encode() keeps the placeholders.
    """

    def __init__(self, *args, **kwargs):
        super(RestraintEncoder, self).__init__(*args, **kwargs)
        # options reused for dumping the definitions; dropping "indent" keeps each definition on one line
        self.kwargs = dict(kwargs)
        self.kwargs.pop('indent', None)  # tolerate construction without an explicit indent
        self._replacement_map = {}

    def default(self, o):
        if isinstance(o, RestraintDefinition):
            key = uuid.uuid4().hex
            self._replacement_map[key] = json.dumps(o.get_definition_list(), **self.kwargs)
            return "@@%s@@" % (key,)
        else:
            return super(RestraintEncoder, self).default(o)

    def encode(self, o):
        # start from a clean map so a reused encoder does not accumulate placeholders of earlier calls
        self._replacement_map = {}
        result = super(RestraintEncoder, self).encode(o)
        # items() works on both Python 2 and Python 3 (iteritems() is Python 2 only)
        for k, v in self._replacement_map.items():
            result = result.replace('"@@%s@@"' % (k,), v)
        return result
