# coding: utf-8

import os

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

__author__ = "Marcin Kowiel, Dariusz Brzezinski"

def pandas_gather(df, key, value, cols):
    """
    Reshapes a data frame from wide to long format by stacking several columns into one key and one value column.

    Every column listed in ``cols`` is stacked: ``key`` receives the original column name and ``value`` the cell
    value. All remaining columns of ``df`` are kept as identifier columns, repeated for each stacked column.
    :param df: data frame to transform
    :param key: name of the new column holding the former column names
    :param value: name of the new column holding the former cell values
    :param cols: columns to stack into the key/value pair
    :return: transformed (long-format) data frame
    """
    identifier_columns = [name for name in df.columns if name not in cols]
    return pd.melt(df, id_vars=identifier_columns, value_vars=cols, var_name=key, value_name=value)


def visualize_clusters(algorithm_name, data, labels, binding_num, prefix, angle_cols, inertia, silhouette,
                       result_folder, result_filename, width=18, height=10):
    """
    Creates a stacked histogram depicting a given clustering and saves it as a png file.

    The values of all columns listed in ``angle_cols`` are pooled into one histogram. Each value is tagged with the
    cluster label of its row, and every cluster label gets its own stacked bar segment.
    :param algorithm_name: Name of the algorithm that produced the results being visualized
    :param data: dataset
    :param labels: cluster labels, one per row of ``data``, paired with the rows by position
    :param binding_num: number of element bindings
    :param prefix: prefix denoting the clustered atom (usually metal)
    :param angle_cols: positional (``iloc``) indices of the ``data`` columns to pool into the histogram
    :param inertia: inertia of the clustering
    :param silhouette: silhouette coefficient of the clustering
    :param result_folder: folder to save plots to
    :param result_filename: plot file name (".png" is appended)
    :param width: plot width (in)
    :param height: plot height (in)
    :raises ValueError: if ``angle_cols`` is empty or ``labels`` does not have one entry per row of ``data``
    """
    label_array = np.asarray(labels)
    # Pair angles with labels by position. Concatenating with pd.Series(labels) would align on the index instead
    # and silently mismatch rows whenever ``data`` does not carry a default 0..n-1 index.
    frames = [pd.DataFrame({"angle": np.asarray(data.iloc[:, column]), "label": label_array},
                           columns=["angle", "label"])
              for column in angle_cols]
    # One concat instead of repeated DataFrame.append (removed in pandas 2.0, and quadratic inside a loop).
    plot_data = pd.concat(frames, ignore_index=True)
    plot_data["label"] = "label: " + plot_data["label"].astype(str)

    # One array of angle values per cluster label, in sorted label order, for the stacked histogram.
    # Every name comes from the column's own unique values, so get_group always finds it.
    hue_names = sorted(plot_data["label"].unique())
    grouped_angles = plot_data["angle"].groupby(plot_data["label"])
    vals = [np.asarray(grouped_angles.get_group(name)) for name in hue_names]
    plt.hist(vals, histtype="barstacked", label=hue_names)
    plt.legend()

    # pyplot is used directly: seaborn no longer re-exports it as ``sns.plt``.
    plt.title(algorithm_name + ": Clusters for " + prefix + " with " + str(binding_num) +
              " bindings\n Silhouette: " + str(silhouette) + ", Inertia: " + str(inertia))
    fig = plt.gcf()
    fig.set_size_inches(width, height)
    fig.savefig(os.path.join(result_folder, result_filename) + ".png", bbox_inches="tight")
    plt.close()


def _save_metric_plot(metric_frame, metric_column, title, file_path, width, height):
    """
    Draws one line per distinct "params" value (k on the x axis, ``metric_column`` on the y axis) and saves it.
    :param metric_frame: data frame with columns "k", "params" and ``metric_column``
    :param metric_column: name of the column plotted on the y axis
    :param title: plot title
    :param file_path: full output path, including the extension
    :param width: plot width (in)
    :param height: plot height (in)
    """
    for param, param_data in metric_frame.groupby("params"):
        plt.plot(param_data["k"], param_data[metric_column], marker="o", label=param)
    plt.legend()
    plt.title(title)

    fig = plt.gcf()
    fig.set_size_inches(width, height)
    fig.savefig(file_path, bbox_inches="tight")
    plt.close()


def create_cluster_evaluation_plots(algorithm_name, binding_num, prefix, inertias, silhouettes, result_folder,
                                    result_filename, width=18, height=10):
    """
    Creates two line plots: one depicting inertia for growing k (number of clusters), one depicting the silhouette
    coefficient for growing k. Each distinct value of the "params" column is drawn as a separate line.
    :param algorithm_name: Name of the algorithm that produced the results being visualized
    :param binding_num: number of element bindings
    :param prefix: prefix denoting the clustered atom (usually metal)
    :param inertias: data frame with columns "k", "inertia" and "params"
    :param silhouettes: data frame with columns "k", "silhouette" and "params"
    :param result_folder: folder to save plots to
    :param result_filename: base plot file name; "_inertia.png" and "_silhouette.png" are appended
    :param width: plot width (in)
    :param height: plot height (in)
    """
    base_path = os.path.join(result_folder, result_filename)
    title_suffix = " for " + prefix + " with " + str(binding_num) + " bindings"

    _save_metric_plot(inertias, "inertia", algorithm_name + ": Inertia" + title_suffix,
                      base_path + "_inertia.png", width, height)
    _save_metric_plot(silhouettes, "silhouette", algorithm_name + ": Silhouette" + title_suffix,
                      base_path + "_silhouette.png", width, height)


def generate_scatter_matrix(df, columns, category_column, markers, result_folder, result_filename):
    """
    Creates a scatter matrix (KDE plots on the diagonal) and saves it as png and svg files.

    Integer-coded (int64) categories are relabelled as letters (0 -> "A", 1 -> "B", ...) before plotting.
    :param df: data frame
    :param columns: data frame columns to plot; must include ``category_column``
    :param category_column: column that defines point colors
    :param markers: markers (passed to seaborn's pairplot)
    :param result_folder: folder to save results to
    :param result_filename: plot file name (".png" / ".svg" are appended)
    """
    # copy() so the relabelling below writes to an independent frame rather than a possible view of ``df``.
    plot_data = df.loc[:, columns].copy()
    if plot_data[category_column].dtype == np.int64:
        plot_data[category_column] = plot_data[category_column].apply(lambda code: chr(code + ord("A")))

    # hue_order is taken from the hue column itself (order of first appearance), so any category_column works.
    fg = sns.pairplot(plot_data, markers=markers, hue=category_column,
                      hue_order=plot_data[category_column].unique(), diag_kind="kde", plot_kws=dict(s=30))

    print("Saving scatter matrix plot to: %s" % (result_filename,))
    fg.savefig(os.path.join(result_folder, result_filename) + ".png", bbox_inches="tight")
    fg.savefig(os.path.join(result_folder, result_filename) + ".svg", format="svg", bbox_inches="tight")
    plt.close()


def generate_torsion_angle_plot(df, x, y, category_column, result_folder, result_filename, k_clusters, x_min=None,
                                x_max=None, y_min=None, y_max=None, palette=None, x_label="$\\tau_{3}$",
                                y_label="$\\tau_{5}$"):
    """
    Creates a scatter plot and saves it as png and svg files. Originally designed to display torsion angles.

    Side effect: calls ``sns.set`` (ticks style, the given palette, Times New Roman), which changes the global
    seaborn theme for plots created afterwards.
    :param df: data frame
    :param x: attribute displayed on the x axis
    :param y: attribute displayed on the y axis
    :param category_column: attribute used to color points; int64 codes are shown as letters (0 -> "A", ...)
    :param result_folder: folder to save results to
    :param result_filename: plot file name (".png" / ".svg" are appended)
    :param k_clusters: number of clusters = markers needed; above 10 clusters no marker list is passed
    :param x_min: min x-axis value (applied only together with x_max)
    :param x_max: max x-axis value
    :param y_min: min y-axis value (applied only together with y_max)
    :param y_max: max y-axis value
    :param palette: color palette; defaults to a six-color palette
    :param x_label: x-axis label (defaults to tau_3)
    :param y_label: y-axis label (defaults to tau_5)
    """
    if palette is None:
        palette = ["#9b59b6", "#3498db", "#fed82f", "#2ecc71", "#34495e", "#e74c3c"]
    sns.set(style="ticks", palette=palette, font_scale=1.5, font="Times New Roman")

    # Ten distinct scatter markers, one per cluster ("-" is a line style, not a matplotlib marker, so it is not
    # offered). With more clusters than markers, markers=None is passed to seaborn.
    available_markers = ["*", "^", "o", "s", "D", "v", "x", "<", ">", "1"]
    if k_clusters <= len(available_markers):
        markers = available_markers[:k_clusters]
    else:
        markers = None

    plot_data = df.loc[:, [x, y, category_column]].copy()
    if plot_data[category_column].dtype == np.int64:
        plot_data[category_column] = plot_data[category_column].apply(lambda code: chr(code + ord("A")))

    # x and y are passed by keyword: recent seaborn releases reject positional data variables in lmplot.
    # hue_order is taken from the hue column itself, so any category_column works.
    fg = sns.lmplot(x=x, y=y, data=plot_data, hue=category_column, hue_order=plot_data[category_column].unique(),
                    fit_reg=False, markers=markers, scatter_kws=dict(s=40, alpha=1))
    fg.set(xlabel=x_label, ylabel=y_label)

    if y_min is not None and y_max is not None:
        for ax in fg.axes.flat:
            ax.set_ylim(y_min, y_max)

    if x_min is not None and x_max is not None:
        for ax in fg.axes.flat:
            ax.set_xlim(x_min, x_max)

    print("Saving torsion angle plot to: %s" % (result_filename,))
    fg.savefig(os.path.join(result_folder, result_filename) + ".png", bbox_inches="tight")
    fg.savefig(os.path.join(result_folder, result_filename) + ".svg", format="svg", bbox_inches="tight")
    plt.close()


def generate_custom_pair_plot(df, x_vars, y_vars, category_column, markers, result_folder, result_filename,
                              x_min=None, x_max=None, y_min=None, y_max=None):
    """
    Creates a scatter matrix based on a set of independent y and x axes attributes and saves it as png and svg files.

    The grid has one row per ``y_vars`` attribute and one column per ``x_vars`` attribute. To improve readability,
    every other x tick label (odd positions) and every other y tick label (even positions) is hidden.
    :param df: data frame
    :param x_vars: attributes displayed on the x-axes (one grid column each)
    :param y_vars: attributes displayed on the y-axes (one grid row each)
    :param category_column: attribute used to color points; int64 codes are shown as letters (0 -> "A", ...)
    :param markers: markers (passed to seaborn's pairplot)
    :param result_folder: folder to save results to
    :param result_filename: plot file name (".png" / ".svg" are appended)
    :param x_min: per-column min x-axis values, one per ``x_vars`` entry (applied only together with x_max)
    :param x_max: per-column max x-axis values, one per ``x_vars`` entry
    :param y_min: per-row min y-axis values, one per ``y_vars`` entry (applied only together with y_max)
    :param y_max: per-row max y-axis values, one per ``y_vars`` entry
    """
    # Ordered de-duplication keeps the column order deterministic (a set union does not).
    columns = []
    for name in list(x_vars) + list(y_vars) + [category_column]:
        if name not in columns:
            columns.append(name)
    plot_data = df.loc[:, columns].copy()
    if plot_data[category_column].dtype == np.int64:
        plot_data[category_column] = plot_data[category_column].apply(lambda code: chr(code + ord("A")))

    # hue_order is taken from the hue column itself, so any category_column works.
    fg = sns.pairplot(plot_data, x_vars=x_vars, y_vars=y_vars, markers=markers, hue=category_column,
                      hue_order=plot_data[category_column].unique(), plot_kws=dict(s=40, alpha=1))
    n_cols = fg.axes.shape[1]

    # fg.axes.flat walks the grid row by row: i // n_cols is the row, i % n_cols the column.
    # Floor division keeps the index an int under Python 3 as well.
    if y_min is not None and y_max is not None:
        for i, ax in enumerate(fg.axes.flat):
            ax.set_ylim(y_min[i // n_cols], y_max[i // n_cols])

    if x_min is not None and x_max is not None:
        for i, ax in enumerate(fg.axes.flat):
            ax.set_xlim(x_min[i % n_cols], x_max[i % n_cols])

    # for improved axis label readability
    for ax in fg.axes.flat:
        for label in ax.xaxis.get_ticklabels()[1::2]:
            label.set_visible(False)
        for label in ax.yaxis.get_ticklabels()[0::2]:
            label.set_visible(False)

    print("Saving custom pair plot to: %s" % (result_filename,))
    fg.savefig(os.path.join(result_folder, result_filename) + ".png", bbox_inches="tight")
    fg.savefig(os.path.join(result_folder, result_filename) + ".svg", format="svg", bbox_inches="tight")
    plt.close()


def generate_histograms(df, x_vars, y_vars, category_column, markers, result_folder, result_filename,
                        x_min=None, x_max=None, bin_width=1, count_max=None):
    """
    Creates a histogram matrix and saves it as png and svg files.

    The grid has one column per ``x_vars`` attribute and one row per category of ``category_column``. Each cell
    is a gray histogram of that attribute's values within that category.
    :param df: data frame
    :param x_vars: attributes to histogram (one grid column each)
    :param y_vars: additional attributes carried along in the long-format frame; not plotted
    :param category_column: attribute defining the grid rows; int64 codes are shown as letters (0 -> "A", ...)
    :param markers: not used by this function
    :param result_folder: folder to save results to
    :param result_filename: plot file name (".png" / ".svg" are appended)
    :param x_min: per-column min x-axis values, one per ``x_vars`` entry (used only together with x_max)
    :param x_max: per-column max x-axis values, one per ``x_vars`` entry; with x_min also defines the bin edges
    :param bin_width: width of each histogram bin (used only when x_min and x_max are given)
    :param count_max: per-row max y-axis values, one per grid row (category)
    """
    columns = []
    for name in list(x_vars) + list(y_vars) + [category_column]:
        if name not in columns:
            columns.append(name)
    plot_data = df.loc[:, columns].copy()
    if plot_data[category_column].dtype == np.int64:
        plot_data[category_column] = plot_data[category_column].apply(lambda code: chr(code + ord("A")))

    # Long format built from plot_data (not df) so the letter relabelling above shows up in the row facets.
    tidy_df = pandas_gather(plot_data, "column", "value", x_vars)

    has_x_range = x_min is not None and x_max is not None
    grid_kws = dict(col="column", row=category_column, sharex=False, sharey=False, hue="column",
                    margin_titles=True)
    if has_x_range:
        # One array of bin edges per hue level, i.e. per x_vars column, in x_vars order.
        grid_kws["hue_kws"] = {"bins": [np.arange(low, high + bin_width, bin_width)
                                        for low, high in zip(x_min, x_max)]}
    fg = sns.FacetGrid(tidy_df, **grid_kws)

    # Draw first, then adjust limits and tick labels, so the adjustments act on the final ticks.
    fg = (fg.map(plt.hist, "value", color="gray")
          .set_axis_labels("", "Count")
          .set_titles(template="", row_template="{row_name}", col_template="{col_name}")
          )

    # fg.axes.flat walks the grid row by row: i // n_cols is the row, i % n_cols the column.
    n_cols = fg.axes.shape[1]
    if count_max is not None:
        for i, ax in enumerate(fg.axes.flat):
            ax.set_ylim(0, count_max[i // n_cols])

    if has_x_range:
        for i, ax in enumerate(fg.axes.flat):
            ax.set_xlim(x_min[i % n_cols], x_max[i % n_cols])

    # for improved axis label readability: hide every other x tick label
    for ax in fg.axes.flat:
        for label in ax.xaxis.get_ticklabels()[1::2]:
            label.set_visible(False)

    print("Saving custom histograms to: %s" % (result_filename,))
    fg.savefig(os.path.join(result_folder, result_filename) + ".png", bbox_inches="tight")
    fg.savefig(os.path.join(result_folder, result_filename) + ".svg", format="svg", bbox_inches="tight")
    plt.close()


def simple_histogram(data, bins, range, colors, labels, x_label, y_label, result_folder, result_filename):
    """
    Creates a single histogram with a legend in a new figure and saves it as png and svg files.

    Side effect: calls ``sns.set`` (ticks style, "Set1" palette, Times New Roman), which changes the global
    seaborn theme for plots created afterwards.
    :param data: values to histogram (passed to ``plt.hist`` as its data argument)
    :param bins: ``bins`` argument of ``plt.hist``
    :param range: ``range`` argument of ``plt.hist`` (the name shadows the builtin but is kept for callers)
    :param colors: ``color`` argument of ``plt.hist``
    :param labels: legend label(s), passed to ``plt.hist`` as ``label``
    :param x_label: x-axis label
    :param y_label: y-axis label
    :param result_folder: folder to save results to
    :param result_filename: plot file name (".png" / ".svg" are appended)
    """
    sns.set(style="ticks", palette="Set1", font_scale=1.5, font="Times New Roman")
    fig = plt.figure()
    plt.hist(data, bins=bins, range=range, color=colors, label=labels)
    plt.legend()
    plt.xlabel(x_label)
    plt.ylabel(y_label)

    print("Saving histogram to: %s" % (result_filename,))
    fig.savefig(os.path.join(result_folder, result_filename) + ".png", bbox_inches="tight")
    fig.savefig(os.path.join(result_folder, result_filename) + ".svg", format="svg", bbox_inches="tight")
    # Close the figure created above explicitly rather than whatever figure happens to be current.
    plt.close(fig)


def generate_parallel_coordinates_plot(df, columns, category_column, medoid_idxs, markers, result_folder,
                                       result_filename, x_label, y_label):
    """
    Creates a parallel coordinates plot and saves it as png and svg files.

    Each category is summarised across the value columns either by the rows selected with ``medoid_idxs`` or,
    when it is None, by the per-category mean of every value column.
    :param df: data frame
    :param columns: attributes to depict for each category; must include ``category_column``. Not modified.
    :param category_column: attribute used to color lines; int64 codes are shown as letters (0 -> "A", ...)
    :param medoid_idxs: index labels (``.loc``) of the summary rows to depict, or None to use per-category means
    :param markers: markers (passed to seaborn's point plot)
    :param result_folder: folder to save results to
    :param result_filename: plot file name (".png" / ".svg" are appended)
    :param x_label: x-axis label; also the name of the long-format column holding the attribute names
    :param y_label: y-axis label; also the name of the long-format column holding the attribute values
    """
    plot_data = df.loc[:, columns].copy()
    if plot_data[category_column].dtype == np.int64:
        plot_data[category_column] = plot_data[category_column].apply(lambda code: chr(code + ord("A")))

    if medoid_idxs is not None:
        parallel_plot_data = plot_data.loc[medoid_idxs, :]
    else:
        parallel_plot_data = plot_data.groupby(category_column).mean()
        # The category becomes the group index; copy it back into a column so it survives the melt below.
        parallel_plot_data[category_column] = parallel_plot_data.index

    # Build a new list instead of removing from ``columns`` in place, so the caller's list is left untouched.
    value_columns = [column for column in columns if column != category_column]
    parallel_plot_data = pandas_gather(parallel_plot_data, x_label, y_label, value_columns)

    # factorplot was deprecated in seaborn 0.9 in favour of catplot; catplot's default kind is "strip",
    # so kind="point" reproduces factorplot's default point plot.
    if hasattr(sns, "catplot"):
        fg = sns.catplot(x=x_label, y=y_label, kind="point", markers=markers, hue=category_column,
                         data=parallel_plot_data)
    else:
        fg = sns.factorplot(x=x_label, y=y_label, markers=markers, hue=category_column, data=parallel_plot_data)

    print("Saving parallel coordinates plot to: %s" % (result_filename,))
    fg.savefig(os.path.join(result_folder, result_filename) + ".png", bbox_inches="tight")
    fg.savefig(os.path.join(result_folder, result_filename) + ".svg", format="svg", bbox_inches="tight")
    plt.close()
