# coding: utf-8

import os
import sys
import cluster
import plot

import pandas as pd
from shutil import copyfile
import seaborn as sns
import numpy as np
import calculate_distances as cd
import clustering_util as util


__author__ = "Marcin Kowiel, Dariusz Brzezinski"


# Restraint name -> restraint kind: "angle" for valence angles, "dist" for bond lengths.
RESTRAINT_TYPE_MAPPING = dict.fromkeys(
    ("aO1O2", "aO1O3", "aO1O5", "aO2O3", "aO2O5", "aO3O5", "aP4O5C5"),  # valence angles
    "angle")
RESTRAINT_TYPE_MAPPING.update(dict.fromkeys(
    ("dO1P4", "dO2P4", "dO3P4", "dO5P4"),  # bond lengths
    "dist"))

# Restraint name -> atom names for a phosphate attached through O5' to C5'. Angle entries list three atoms
# (the middle one appears to be the angle vertex); bond-length entries list the two bonded atoms.
RESTRAINT_ATOM_MAPPING_C5 = dict(
    # valence angles
    aO1O2=["OP1", "P", "OP2"],
    aO1O3=["OP1", "P", "OP3"],
    aO1O5=["OP1", "P", "O5'"],
    aO2O3=["OP2", "P", "OP3"],
    aO2O5=["OP2", "P", "O5'"],
    aO3O5=["OP3", "P", "O5'"],
    aP4O5C5=["P", "O5'", "C5'"],
    # bond lengths
    dO1P4=["OP1", "P"],
    dO2P4=["OP2", "P"],
    dO3P4=["OP3", "P"],
    dO5P4=["O5'", "P"],
)

# The same restraints for a phosphate attached through O3' to C3'. The restraint names are kept unchanged
# (their "O5"/"C5" still denote the fourth, ester-side oxygen and its carbon); only the atom names differ:
# O5' -> O3' and C5' -> C3'. Fresh lists are built, so the two mappings never share list objects.
RESTRAINT_ATOM_MAPPING_C3 = {
    restraint: [{"O5'": "O3'", "C5'": "C3'"}.get(atom, atom) for atom in atoms]
    for restraint, atoms in RESTRAINT_ATOM_MAPPING_C5.items()
}

# Condition name -> condition kind. Both condition mappings are currently empty, so no conditions are passed
# on to the clustering summaries; the commented-out entries show the expected format for a torsion condition.
CONDITION_TYPE_MAPPING = {
    # torsion angles
    # "tC5O5P4O3": "torsion"
}

# Condition name -> atom names defining it (four atoms for a torsion angle).
CONDITION_ATOM_MAPPING = {
    # torsion angles
    # "tC5O5P4O3": ["OP3", "P", "O5'", "C5'"],
}

# Maps an atom permutation to the renaming of every restraint/geometry column under that permutation.
# Keys are written like str() of a permutation list (e.g. "[1, 0, 2, 3]") -- presumably that is how they are
# looked up by the summary code; verify in clustering_util before changing the key format.
# The six keys are all permutations of the first three positions; position 3 is fixed in every entry.
# With positions 0..3 standing for O1, O2, O3, O5, permutation p relabels the oxygen at position i as the
# oxygen at position p[i] (e.g. under "[1, 2, 0, 3]": "dO1P4" -> "dO2P4", "aO1O3" -> "aO1O2"). Columns that do
# not involve O1-O3 ("dO5P4", "tC5O5P4O3", "aP4O3C3", "aP4O5C5") map to themselves.
# "torsion_modifier" is +1 or -1 per permutation.
# NOTE(review): presumably a sign factor applied to torsion values downstream. The values do not follow
# permutation parity (e.g. the odd permutation "[0, 2, 1, 3]" has +1), so confirm their meaning before editing.
PERMUTATION_MAPPING = {
    "[0, 1, 2, 3]": {
        "dO1P4": "dO1P4", "dO2P4": "dO2P4", "dO3P4": "dO3P4", "dO5P4": "dO5P4",
        "aO1O2": "aO1O2", "aO1O3": "aO1O3", "aO1O5": "aO1O5", "aO2O3": "aO2O3", "aO2O5": "aO2O5", "aO3O5": "aO3O5",
        "tC5O5P4O3": "tC5O5P4O3",
        "aP4O3C3": "aP4O3C3", "aP4O5C5": "aP4O5C5",
        "torsion_modifier": 1
    },
    "[0, 2, 1, 3]": {
        "dO1P4": "dO1P4", "dO2P4": "dO3P4", "dO3P4": "dO2P4", "dO5P4": "dO5P4",
        "aO1O2": "aO1O3", "aO1O3": "aO1O2", "aO1O5": "aO1O5", "aO2O3": "aO2O3", "aO2O5": "aO3O5", "aO3O5": "aO2O5",
        "tC5O5P4O3": "tC5O5P4O3",
        "aP4O3C3": "aP4O3C3", "aP4O5C5": "aP4O5C5",
        "torsion_modifier": 1
    },
    "[1, 0, 2, 3]": {
        "dO1P4": "dO2P4", "dO2P4": "dO1P4", "dO3P4": "dO3P4", "dO5P4": "dO5P4",
        "aO1O2": "aO1O2", "aO1O3": "aO2O3", "aO1O5": "aO2O5", "aO2O3": "aO1O3", "aO2O5": "aO1O5", "aO3O5": "aO3O5",
        "tC5O5P4O3": "tC5O5P4O3",
        "aP4O3C3": "aP4O3C3", "aP4O5C5": "aP4O5C5",
        "torsion_modifier": -1
    },
    "[1, 2, 0, 3]": {
        "dO1P4": "dO2P4", "dO2P4": "dO3P4", "dO3P4": "dO1P4", "dO5P4": "dO5P4",
        "aO1O2": "aO2O3", "aO1O3": "aO1O2", "aO1O5": "aO2O5", "aO2O3": "aO1O3", "aO2O5": "aO3O5", "aO3O5": "aO1O5",
        "tC5O5P4O3": "tC5O5P4O3",
        "aP4O3C3": "aP4O3C3", "aP4O5C5": "aP4O5C5",
        "torsion_modifier": -1
    },
    "[2, 0, 1, 3]": {
        "dO1P4": "dO3P4", "dO2P4": "dO1P4", "dO3P4": "dO2P4", "dO5P4": "dO5P4",
        "aO1O2": "aO1O3", "aO1O3": "aO2O3", "aO1O5": "aO3O5", "aO2O3": "aO1O2", "aO2O5": "aO1O5", "aO3O5": "aO2O5",
        "tC5O5P4O3": "tC5O5P4O3",
        "aP4O3C3": "aP4O3C3", "aP4O5C5": "aP4O5C5",
        "torsion_modifier": 1
    },
    "[2, 1, 0, 3]": {
        "dO1P4": "dO3P4", "dO2P4": "dO2P4", "dO3P4": "dO1P4", "dO5P4": "dO5P4",
        "aO1O2": "aO2O3", "aO1O3": "aO1O3", "aO1O5": "aO3O5", "aO2O3": "aO1O2", "aO2O5": "aO2O5", "aO3O5": "aO1O5",
        "tC5O5P4O3": "tC5O5P4O3",
        "aP4O3C3": "aP4O3C3", "aP4O5C5": "aP4O5C5",
        "torsion_modifier": -1
    }
}


def join_po4_terminal_data(folder, file_names, combined_file_name):
    """
    Joins several files with PO4 into one dataset for clustering.

    Every source file is read as a space-separated table with a header row. Each of its rows becomes one
    comma-separated output line whose first field is "<short file name>_<NAME>_<P4>" (the short file name is
    the source file name with "PO4_", "_minsq", ".txt" and all underscores removed), followed by the O1, O2,
    O3 and O5 atom labels, the constant 8 for each of the eO1..eO5 columns, and the geometry columns copied
    unchanged under their source names. All rows are collected before anything is written, so the combined
    file is not touched if reading any source file fails.

    :param folder: folder with raw source files; the combined file is written to the same folder
    :param file_names: names of the files (relative to folder) to combine
    :param combined_file_name: name of the combined file with data for clustering
    """
    label_columns = ("O1", "O2", "O3", "O5")
    # Geometry columns are copied verbatim and keep their source column names in the output header.
    value_columns = ("dO1P4", "dO2P4", "dO3P4", "dO5P4",
                     "aO1O2", "aO1O3", "aO1O5", "aO2O3", "aO2O5", "aO3O5",
                     "tC5O5P4O3", "dO5C5", "aP4O5C5", "C5O5P4O3")
    # NOTE(review): the eO* columns are always 8 -- presumably the atomic number of oxygen; confirm.
    element_columns = tuple("e" + column for column in label_columns)
    element_values = (8,) * len(label_columns)

    lines = [("NAME",) + label_columns + element_columns + value_columns]
    for filename in file_names:
        filename_short = filename.replace("PO4_", "").replace("_minsq", "").replace(".txt", "").replace("_", "")
        df = pd.read_csv(os.path.join(folder, filename), sep=" ")
        for _, row in df.iterrows():
            key = "%s_%s_%s" % (filename_short, row["NAME"], row["P4"])
            lines.append((key,)
                         + tuple(row[column] for column in label_columns)
                         + element_values
                         + tuple(row[column] for column in value_columns))

    # The context manager closes the file even if a write fails; file.write() (instead of Python 2's
    # "print >>" statement) behaves the same under Python 2 and Python 3.
    with open(os.path.join(folder, combined_file_name), "w") as out_file:
        for line in lines:
            out_file.write(",".join(str(value) for value in line) + "\n")


def generate_plots(df, result_folder, result_filename, k_clusters, y_min=None, y_max=None, count_max=None,
                   bin_width=1, palette=None, verbose=False):
    """
    Visualizes the clustered data and saves the visualizations to files.
    :param df: data frame with the clustered data; its "group" column is used as the class column of every plot
    :param result_folder: folder to save results to
    :param result_filename: plot file name prefix; each plot appends its own suffix (e.g. ".pair_plot", ".hist")
    :param k_clusters: number of clusters to visualize; for 11 or more clusters no marker list is passed to the
        plotting helpers (markers=None)
    :param y_min: min y-axis tick value for scatter plots
    :param y_max: max y-axis tick value for scatter plots
    :param count_max: max y-axis tick value for histograms
    :param bin_width: width of each histogram bin
    :param palette: color palette to use; defaults to a fixed six-color list
    :param verbose: if True creates additional plots (mainly scatter matrices)
    """
    separator = "----------------------"
    # Single-argument print(...) prints the same text under Python 2's print statement and Python 3's print().
    print(separator)
    print("Generating plots for:  %s" % (result_filename,))
    print(separator)

    if palette is None:
        palette = ["#9b59b6", "#3498db", "#fed82f", "#2ecc71", "#34495e", "#e74c3c"]
    sns.set(style="ticks", palette=palette)
    markers = ["*", "^", "o", "s", "D", "v", "x", "<", ">", "1", "-"]
    # At most the first 10 markers are ever used, so the trailing "-" entry is never selected.
    if k_clusters < 11:
        markers = markers[0:k_clusters]
    else:
        markers = None

    # Column groups and axis ranges shared by several plots. They are kept as tuples and converted with list()
    # at every call, so each plotting helper receives its own fresh list, as with the original literals.
    angle_columns = ("aO1O2", "aO1O3", "aO1O5", "aO2O3", "aO2O5", "aO3O5")
    angle_x_min = (109, 110, 104, 110, 104, 100)
    angle_x_max = (115, 116, 110, 116, 110, 106)
    distance_columns = ("dO1P4", "dO2P4", "dO3P4", "dO5P4")
    torsion_columns = ("tC5O5P4O3",)

    # custom scatter plots
    plot.generate_custom_pair_plot(df, list(angle_columns), list(torsion_columns),
                                   "group", markers, result_folder, result_filename + ".pair_plot",
                                   x_min=list(angle_x_min),
                                   x_max=list(angle_x_max),
                                   y_min=y_min,
                                   y_max=y_max)

    plot.generate_custom_pair_plot(df, list(distance_columns), list(torsion_columns),
                                   "group", markers, result_folder, result_filename + ".dist_pair_plot",
                                   x_min=[1.48, 1.49, 1.49, 1.60],
                                   x_max=[1.54, 1.55, 1.55, 1.66],
                                   y_min=y_min,
                                   y_max=y_max)

    # histograms
    plot.generate_histograms(df, list(angle_columns), list(torsion_columns),
                             "group", markers, result_folder, result_filename + ".hist",
                             x_min=list(angle_x_min),
                             x_max=list(angle_x_max),
                             count_max=count_max,
                             bin_width=bin_width)

    if verbose:
        # angle scatter matrices
        plot.generate_scatter_matrix(df, list(angle_columns + ("group",)), "group", markers, result_folder,
                                     result_filename + ".scatter_matrix")
        plot.generate_scatter_matrix(df, list(torsion_columns + ("group",)), "group", markers, result_folder,
                                     result_filename + ".torsion_matrix")
        plot.generate_scatter_matrix(df, list(angle_columns + torsion_columns + ("group",)), "group", markers,
                                     result_folder, result_filename + ".scatter_matrix_with_torsion")

        # angle parallel coordinates
        plot.generate_parallel_coordinates_plot(df, list(angle_columns + ("group",)), "group", None, markers,
                                                result_folder, result_filename + ".parallel_coordinates",
                                                "angle", "degrees")

        # binding length scatter matrices
        plot.generate_scatter_matrix(df, list(distance_columns + ("group",)), "group", markers, result_folder,
                                     result_filename + ".binding_scatter_matrix")

        # binding length coordinates
        plot.generate_parallel_coordinates_plot(df, list(distance_columns + ("group",)), "group", None, markers,
                                                result_folder, result_filename + ".binding_parallel_coordinates",
                                                "binding", "length")


def manual_categorization_terminal(data_folder, data_file, manual_clustering_folder, manual_clustering, binding_num,
                                   feasible_permutations, mad_outlier_test_columns):
    """
    Performs manual categorization of the CSD data and summarizes the results.

    The manual grouping file is read as two columns ("index", "group") and concatenated column-wise with the
    data (aligned on the default row index). Rows returned by the MAD outlier test are dropped before
    plotting. Two summaries are written: one with the O5'/C5' atom naming (suffix "_C5") and one with the
    O3'/C3' atom naming (suffix "_C3").

    :param data_folder: folder with source CSD data
    :param data_file: CSD data file
    :param manual_clustering_folder: folder to save summaries and visualizations to
    :param manual_clustering: file name prefix to save summaries and visualizations to
    :param binding_num: number of P bindings
    :param feasible_permutations: allowed atom permutations
    :param mad_outlier_test_columns: columns to test for outliers using the MAD test
    """
    data = pd.read_csv(os.path.join(data_folder, data_file), sep=",", header=0, index_col=None)
    manual_groups = pd.read_csv(os.path.join(manual_clustering_folder, manual_clustering), sep=",", header=0,
                                names=["index", "group"])
    group_count = manual_groups.loc[:, "group"].nunique()
    labelled = pd.concat([data, manual_groups], axis=1)
    outlier_rows = cluster.mad_outlier_detection(labelled, manual_groups.loc[:, "group"],
                                                 mad_outlier_test_columns)
    inliers = labelled.drop(outlier_rows)

    # Optional plot tuning used previously: y_min=[-140, -120], y_max=[120, 220],
    # count_max=[4, 18, 25, 7, 30, 35], verbose=False
    generate_plots(inliers, manual_clustering_folder, manual_clustering, group_count)

    # One summary per atom naming; the suffix goes both into the file name and into the last argument.
    summary_variants = (("_C5", RESTRAINT_ATOM_MAPPING_C5),
                        ("_C3", RESTRAINT_ATOM_MAPPING_C3))
    for suffix, atom_mapping in summary_variants:
        util.create_clustering_summary(inliers, manual_clustering_folder, manual_clustering + suffix,
                                       binding_num, feasible_permutations, True, RESTRAINT_TYPE_MAPPING,
                                       atom_mapping, CONDITION_TYPE_MAPPING, CONDITION_ATOM_MAPPING,
                                       PERMUTATION_MAPPING, suffix)


def automatic_clustering_terminal(data_folder, data_file, clustering_results_folder, clustering_results_file,
                                  binding_num, feasible_permutations, mad_outlier_test_columns, prefix):
    """
    Performs automatic clustering of the CSD data and summarizes the results.

    A clustering is chosen from the precomputed results via util.select_best_clustering. The data is aligned
    to it (torsion "tC5O5P4O3", inverse permutation [1, 0, 2, 3]), and rows returned by the MAD outlier test are
    dropped before plotting and summarizing.

    :param data_folder: folder with source CSD data
    :param data_file: CSD data file
    :param clustering_results_folder: folder to save summaries and visualizations to
    :param clustering_results_file: file name prefix to save summaries and visualizations to
    :param binding_num: number of P bindings
    :param feasible_permutations: allowed atom permutations
    :param mad_outlier_test_columns: columns to test for outliers using the MAD test
    :param prefix: automatic clustering result files' prefix.
    """
    best = util.select_best_clustering(prefix, binding_num, clustering_results_folder, clustering_results_file,
                                       max_outliers_pct=0.2)
    best_folder, best_clustering, best_k = best

    aligned_df, clusters = util.align_data(data_folder, data_file, best_folder, best_clustering, binding_num,
                                           torsion_angles=["tC5O5P4O3"],
                                           inverse_permutations=[[1, 0, 2, 3]])
    outlier_rows = cluster.mad_outlier_detection(aligned_df, clusters.loc[:, "group"], mad_outlier_test_columns)
    inliers = aligned_df.drop(outlier_rows)

    # Optional plot tuning used previously: y_min=[-140, -120], y_max=[120, 220], verbose=False
    generate_plots(inliers, best_folder, best_clustering, best_k, palette="Dark2")
    util.create_clustering_summary(inliers, best_folder, best_clustering, binding_num, feasible_permutations)


def main(args):
    """
    Main method. Performs clustering, categorization, and outputs cluster summaries and visualizations.
    :param args: script arguments in sys.argv form (args[0], the script path, is ignored). One argument is
    recognized - "cached" (leading dashes are tolerated, e.g. "--cached"). When "cached" is passed to the script
    it will not recalculate the distance matrix or perform automatic clustering, it will use data from a
    previous uncached run instead and only summarize and visualize the results. Arguments that merely contain
    the word, such as "uncached", do not enable cached mode.
    """
    # Clustering parameters
    seed = 23
    binding_num = 4
    max_jobs = 4
    # Permutations used when computing the distance matrix.
    feasible_permutations = [[0, 1, 2, 3], [1, 0, 2, 3]]
    # Permutations passed to the categorization/clustering summaries.
    feasible_permutations_output = [[0, 1, 2, 3], [0, 2, 1, 3], [1, 0, 2, 3], [1, 2, 0, 3], [2, 0, 1, 3], [2, 1, 0, 3]]
    mad_outlier_test_columns = ["dO1P4", "dO2P4", "dO3P4", "dO5P4",
                                "aO1O2", "aO1O3", "aO1O5", "aO2O3", "aO2O5", "aO3O5",
                                "dO5C5", "aP4O5C5"]
    # Grid-search parameter spaces for each clustering algorithm.
    clusterers = {
        cluster.dbscan:
            [{"max_angle_diff": np.arange(0.05, 1.25, 0.05), "min_samples": [3, 4, 5]}],
        cluster.pam:
            [{"k_clusters": range(2, 13)}],
        cluster.ahc:
            [{"k_clusters": range(2, 13), "linkage": ["complete", "average"], "nc": ["auto"], "t": [2]}],
        cluster.spectral:
            [{"k_clusters": range(2, 13), "eigen_tol": np.arange(0.02, 0.12, 0.02), "seed": [seed],
              "n_neighbors": [5, 7, 10]}]
    }
    sns.set(style="ticks", palette="Set1")

    # Files and folders
    # NOTE(review): data_folder is resolved relative to the current working directory, whereas
    # clustering_results_folder is resolved relative to this script's directory; confirm the script is meant
    # to be run from its own folder.
    prefix = "P_terminal_"
    data_folder = os.path.join("..", "Data", "PO4_terminal")
    data_file = "P_terminal_4.csv"
    source_files = (
        "terminal_P_0_data.txt",
        "terminal_P_180_data.txt",
    )
    clustering_results_folder = os.path.join(os.path.dirname(__file__), prefix + "Results")
    clustering_results_file = os.path.join(clustering_results_folder, prefix + "clustering_results.csv")
    manual_clustering_folder = os.path.join(clustering_results_folder, "manual")
    manual_clustering = "PO4_terminal_manual_clustering.csv"

    # Exact (not substring) match, so e.g. "uncached" does not switch on cached mode; leading dashes are
    # stripped so "--cached" keeps working.
    use_cached_results = any(arg.lstrip("-") == "cached" for arg in args[1:])

    # Cleanup and automatic clustering grid search (parameter optimization)
    if not use_cached_results:
        util.cleanup_and_prepare_folders(data_folder, binding_num, clustering_results_folder, clustering_results_file,
                                         manual_clustering_folder, manual_clustering)
        join_po4_terminal_data(data_folder, source_files, data_file)
        cd.compute_and_save_distance_matrix(prefix, binding_num, data_folder, max_jobs, header=0,
                                            feasible_permutations=feasible_permutations)
        cluster.calculate_all_clusterings(clusterers, data_folder, binding_num, prefix, header=0)

    # Manual categorization and automatic clustering
    manual_categorization_terminal(data_folder, data_file, manual_clustering_folder, manual_clustering, binding_num,
                                   feasible_permutations_output, mad_outlier_test_columns)
    automatic_clustering_terminal(data_folder, data_file, clustering_results_folder, clustering_results_file,
                                  binding_num, feasible_permutations_output, mad_outlier_test_columns, prefix)

###########################
#       Main script       #
###########################
if __name__ == "__main__":
    # Hand over the raw command line; main() inspects the arguments that follow the script path.
    main(args=sys.argv)
