# coding: utf-8

import os
import math
import csv
import datetime
import numpy as np
import pandas as pd
import plot

from collections import namedtuple
from collections import Counter
from sklearn.externals import joblib
from sklearn.grid_search import ParameterGrid
from sklearn import cluster, metrics
from calculate_distances import angle_cols
from k_medoids import KMedoids

__author__ = "Marcin Kowiel, Dariusz Brzezinski"


def dbscan(data, dist, binding_num, max_angle_diff=5, min_samples=7, seed=None):
    """
    Clusters the dataset with DBSCAN run on a precomputed distance matrix.

    The neighborhood radius (eps) is derived as sqrt(binding_num) * max_angle_diff.
    :param data: dataset; rows labeled as noise are selected from it positionally
    :param dist: precomputed distance matrix for the dataset
    :param binding_num: number of bindings each element has
    :param max_angle_diff: maximal feasible angle difference in each dimension between "neighboring" compounds
    :param min_samples: minimum number of neighboring elements needed to form a cluster (elements that do not
    meet this requirement receive the label -1, i.e. they are treated as outliers)
    :param seed: random seed, forwarded to DBSCAN as random_state
    :return: cluster labels for all the examples in the dataset, number of detected clusters (the noise label -1
    is not counted), rows of the dataset labeled as outliers
    """
    neighborhood_radius = math.sqrt(binding_num) * max_angle_diff
    model = cluster.DBSCAN(eps=neighborhood_radius, min_samples=min_samples, metric="precomputed",
                           random_state=seed)
    model.fit(dist)
    labels = model.labels_

    # -1 marks noise and must not be counted as a cluster
    found_labels = set(labels)
    found_labels.discard(-1)

    noise_mask = labels == -1

    return labels, len(found_labels), data.iloc[noise_mask]


def pam(data, dist, binding_num, k_clusters, init="heuristic", max_iter=300, seed=None):
    """
    Performs clustering using the k-medoids algorithm (Partitioning Around Medoids).
    :param data: dataset (not used by this algorithm; kept so that all clusterers share the same signature)
    :param dist: precomputed distance matrix for the dataset
    :param binding_num: number of bindings each element has (not used by this algorithm)
    :param k_clusters: number of expected clusters
    :param init: medoid initialization technique; possible values are "random" and "heuristic".
    :param max_iter: maximum number of medoid repositioning iterations (if medoids stop "moving" earlier, the algorithm
    also stops earlier)
    :param seed: random seed for reproducibility
    :return: cluster labels for all the examples in the dataset, the requested number of clusters, an empty
    DataFrame (this algorithm does not detect outliers)
    """
    # max_iter is forwarded to the clusterer (it used to be overridden by a hard-coded 300)
    clusterer = KMedoids(n_clusters=k_clusters, distance_metric="precomputed", init=init, max_iter=max_iter,
                         random_state=seed).fit(dist)

    labels = clusterer.labels_
    outliers = pd.DataFrame()

    return labels, k_clusters, outliers


def spectral(data, dist, binding_num, k_clusters, eigen_tol=0.0, seed=None, n_neighbors=10):
    """
    Runs spectral clustering with the given matrix passed as a precomputed affinity.
    :param data: dataset (not used by this algorithm)
    :param dist: precomputed distance matrix for the dataset; it is handed to SpectralClustering with
    affinity="precomputed"
    :param binding_num: number of bindings each element has (not used by this algorithm)
    :param k_clusters: number of expected clusters
    :param eigen_tol: stopping criterion for the eigendecomposition of the Laplacian matrix
    :param seed: random seed for reproducibility
    :param n_neighbors: forwarded unchanged to SpectralClustering
    :return: cluster labels for all the examples in the dataset, the requested number of clusters, an empty
    DataFrame (this algorithm does not detect outliers)
    """
    # NOTE(review): a precomputed affinity is normally a similarity matrix, whereas dist holds distances - confirm
    # that this is intended.
    model = cluster.SpectralClustering(
        n_clusters=k_clusters,
        eigen_solver="arpack",
        affinity="precomputed",
        eigen_tol=eigen_tol,
        random_state=seed,
        n_neighbors=n_neighbors,
    )
    model.fit(dist)

    return model.labels_, k_clusters, pd.DataFrame()


def ahc(data, dist, binding_num, k_clusters, linkage="average", nc=None, t=None, mad_outlier_test_columns=False):
    """
    Performs hierarchical clustering (Agglomerative Hierarchical Clustering).

    When both nc and t are given, outliers are detected first (see predetect_outliers_ahc), removed from the
    distance matrix before the final clustering, and reported with the label -1.
    :param data: dataset
    :param dist: precomputed distance matrix for the dataset
    :param binding_num: number of bindings each element has (not used by this algorithm)
    :param k_clusters: number of expected clusters
    :param linkage: method for joining clusters; possible values: "average" and "complete"
    :param nc: number of clusters for outlier detection proposed by Loreiro et al (2004)
    :param t: minimum number of object in a cluster; used only for outlier detection proposed by Loreiro et al (2004)
    :param mad_outlier_test_columns: currently not used by this function
    :return: cluster labels for all the examples in the dataset, number of detected clusters, list of outliers
    """
    predetected_outliers = predetect_outliers_ahc(dist, linkage, nc, t)

    if predetected_outliers is None:
        # outlier detection is switched off: cluster the complete distance matrix
        clusterer = cluster.AgglomerativeClustering(n_clusters=k_clusters, affinity="precomputed",
                                                    linkage=linkage).fit(dist)
        return clusterer.labels_, k_clusters, pd.DataFrame()

    outlier_positions = sorted(predetected_outliers)

    # np.delete returns new arrays, so the caller's distance matrix is left untouched
    filtered_dist = np.delete(dist, outlier_positions, 0)
    filtered_dist = np.delete(filtered_dist, outlier_positions, 1)

    clusterer = cluster.AgglomerativeClustering(n_clusters=k_clusters, affinity="precomputed",
                                                linkage=linkage).fit(filtered_dist)

    # Scatter the labels of the clustered elements back to their original positions in a single pass;
    # the removed elements keep the outlier label -1.
    keep = np.ones(dist.shape[0], dtype=bool)
    keep[np.asarray(outlier_positions, dtype=int)] = False
    labels = np.full(dist.shape[0], -1, dtype=clusterer.labels_.dtype)
    labels[keep] = clusterer.labels_

    # A single positional selection replaces the row-by-row DataFrame.append, which was quadratic and did not
    # preserve column dtypes.
    if outlier_positions:
        outliers = data.iloc[outlier_positions]
    else:
        outliers = pd.DataFrame()

    return labels, k_clusters, outliers


def predetect_outliers_ahc(dist, linkage="average", nc=23, t=4):
    """
    Outlier detection method proposed by Loreiro et al. (2004)

    The data is clustered into nc clusters; every element belonging to a cluster with fewer than t members is
    reported as an outlier.
    :param dist: distance matrix
    :param linkage: linkage for AHC clustering
    :param nc: dendrogram cut-off point for outlier detection; "auto" selects max(2, number of elements // 10)
    :param t: minimum number of objects in cluster; if smaller a cluster is considered as an outlier
    :return: None if parameter nc or t is not set; detected outliers (positions in the distance matrix) otherwise
    """
    if nc is None or t is None:
        return None

    detected_outliers = []
    if nc == "auto":
        # as suggested by Loreiro et al.; floor division keeps the cluster count an integer on Python 2 and 3
        nc = max(2, dist.shape[0] // 10)

    clusterer = cluster.AgglomerativeClustering(n_clusters=nc, affinity="precomputed", linkage=linkage).fit(dist)
    label_counter = Counter(clusterer.labels_)

    # items() is available on both Python 2 and Python 3 (iteritems() is Python 2 only)
    for key, value in label_counter.items():
        if value < t:
            detected_outliers.extend(np.where(clusterer.labels_ == key)[0])

    return detected_outliers


def mad_outlier_detection(data, labels, columns):
    """
    Detects within-cluster outliers with a modified z-score based on the median absolute deviation (MAD).

    For every cluster (the label -1 is skipped) and every requested column the score
    0.6745 * (value - cluster median) / cluster MAD is computed. An element is reported when the absolute score
    exceeds 3.5 in at least one column. A column whose MAD equals 0 within a cluster cannot be scored and therefore
    never flags an element.
    :param data: dataset (pandas DataFrame)
    :param labels: cluster label of each row of the dataset, in row order
    :param columns: list of columns of the dataset on which the test is performed
    :return: list with the index labels of the rows detected as outliers
    """
    # accept plain sequences as well as numpy arrays, so that the comparison below is always element-wise
    labels = np.asarray(labels)
    outliers = []

    for l in set(labels):
        if l == -1:
            continue

        class_angles = data.loc[labels == l, columns]
        class_medians = class_angles.median(axis=0)
        # a MAD of 0 would cause a division by zero; NaN makes the affected scores compare as "not an outlier"
        class_mads = (class_angles - class_medians).abs().median().replace({0: np.nan})

        # score the whole cluster at once instead of row by row
        scores = (0.6745 * (class_angles - class_medians) / class_mads).abs()
        flagged = (scores > 3.5).any(axis=1)
        outliers.extend(scores.index[flagged.values].tolist())

    return outliers

def get_data_and_distance_matrix(data_folder, binding_num, prefix, header=None):
    """
    Loads a dataset together with its precomputed distance matrix from disk.

    The dataset is read from <data_folder>/<prefix><binding_num>.csv and the distance matrix from
    <data_folder>/<binding_num>/<prefix><binding_num>.csv.cdm.
    :param data_folder: folder with the datasets; the precomputed distance matrices for specific binding numbers
    reside in subfolders of this folder
    :param binding_num: number of element bindings
    :param prefix: prefix denoting the clustered atom (usually metal)
    :param header: passed to pandas.read_csv as its header argument
    :return: the dataset and distance matrix for the given number of element bindings
    """
    binding_label = str(binding_num)
    csv_name = "".join([prefix, binding_label, ".csv"])

    # the distance matrix lives in a per-binding-number subfolder and carries an extra ".cdm" suffix
    matrix_file = os.path.join(data_folder, binding_label, csv_name + ".cdm")
    csv_file = os.path.join(data_folder, csv_name)

    return pd.read_csv(csv_file, sep=",", header=header), joblib.load(matrix_file)


def evaluate_clusters(algorithm_name, params, data, dist, binding_num, prefix, labels, k_clusters, outliers):
    """
    Evaluates and summarizes a given clustering, and appends the summary to the results csv file.
    :param algorithm_name: name of the algorithm that produced the results being evaluated
    :param params: algorithm parameters (only written to the results file)
    :param data: dataset
    :param dist: precomputed distance matrix
    :param binding_num: number of element bindings
    :param prefix: prefix denoting the clustered atom (usually metal)
    :param labels: cluster labels for each example (numpy array; -1 denotes an outlier)
    :param k_clusters: number of clusters
    :param outliers: DataFrame with the outliers
    :return: cluster summaries as named tuples (medoid, label, size), the inertia and silhouette coefficient of the
    given clustering
    """
    ClusterSummary = namedtuple("ClusterSummary", "medoid label size")
    n = data.shape[0]
    n_outliers = len(outliers)
    distance_df = pd.DataFrame(dist)

    if k_clusters > 1:
        # we do not want to take outliers into account, therefore we omit labels == -1
        clustered = labels != -1
        # dist already contains pairwise distances, so it has to be declared as precomputed; otherwise its rows
        # would be treated as feature vectors and compared with the Euclidean metric
        silhouette = metrics.silhouette_score(dist[np.ix_(clustered, clustered)], labels[clustered],
                                              metric="precomputed")
    else:
        silhouette = -1

    inertia = 0
    cluster_summaries = list()
    for l in set(labels):
        if l == -1:
            continue

        members_mask = labels == l
        class_members = data.iloc[members_mask].copy()
        class_member_distances = distance_df.iloc[members_mask, members_mask]
        # both selections are positional, so the sums are assigned by position rather than by index label
        class_members.loc[:, "inertia"] = (class_member_distances ** 2).sum(axis=1).values
        # the medoid is the member with the smallest sum of squared distances to the other members
        medoid = class_members.loc[class_members["inertia"].idxmin()]
        inertia += medoid.loc["inertia"]

        cluster_summaries.append(ClusterSummary(medoid, l, len(class_members)))

    medoids_str = "".join(
        summary.medoid.iloc[0] + ": " +
        str(summary.medoid.iloc[angle_cols(binding_num)].apply(np.round, decimals=2).values) +
        "(label: " + str(summary.label) + ", clustered elements: " + str(summary.size) + ")" + "\r\n"
        for summary in cluster_summaries)
    outliers_str = "".join(str(o) + "\r\n" for o in outliers.itertuples())

    # the textual outlier list is written (previously the DataFrame itself was passed and the string was discarded)
    write_evaluation_to_file(binding_num, prefix, algorithm_name, params, n, k_clusters, n_outliers, inertia, silhouette,
                             medoids_str, outliers_str)

    return cluster_summaries, inertia, silhouette


def write_evaluation_to_file(binding_num, prefix, algorithm_name, params, n, k_clusters, n_outliers, inertia,
                             silhouette, medoids, outliers):
    """
    Appends an evaluation summary to the csv file <prefix>Results/<prefix>clustering_results.csv.

    The results folder is created if it does not exist, and a header row is written when the file is created.
    :param binding_num: number of element bindings
    :param prefix: prefix denoting the clustered atom (usually metal)
    :param algorithm_name: name of the algorithm that produced the results being evaluated
    :param params: algorithm parameters
    :param n: number of elements to cluster
    :param k_clusters: number of clusters
    :param n_outliers: number of outliers
    :param inertia: inertia of clustering
    :param silhouette: silhouette coefficient of clustering
    :param medoids: list of medoids
    :param outliers: list of outliers
    """
    import sys

    results_dir = prefix + "Results"
    file_path = os.path.join(results_dir, prefix + "clustering_results.csv")

    if not os.path.isdir(results_dir):
        os.makedirs(results_dir)

    write_header = not os.path.isfile(file_path)
    mode = "w" if write_header else "a"

    # The csv module expects a binary file on Python 2 and a text file opened with newline="" on Python 3.
    if sys.version_info[0] >= 3:
        results_file = open(file_path, mode, newline="")
    else:
        results_file = open(file_path, mode + "b")

    with results_file as f:
        writer = csv.writer(f, delimiter=";", quoting=csv.QUOTE_NONNUMERIC)

        if write_header:
            writer.writerow(["Binding num", "Algorithm", "Params", "Examples", "No of clusters", "No of outliers",
                             "Inertia", "Silhouette", "Medoids", "Outliers"])

        writer.writerow([str(binding_num), algorithm_name, params, n, k_clusters, n_outliers, inertia, silhouette,
                         medoids, outliers])


def grid_search(algorithm, grid_params, data, dist, binding_num, prefix):
    """
    Performs a grid search to find the best clustering according to the silhouette coefficient.

    For every parameter combination the labels are saved to a csv file, the clustering is evaluated and plotted.
    A combination that raises an exception is reported on the console and skipped.
    :param algorithm: algorithm to be evaluated
    :param grid_params: algorithm parameters to be tested
    :param data: dataset
    :param dist: precomputed distance matrix
    :param binding_num: number of element bindings
    :param prefix: prefix denoting the clustered atom (usually metal)
    :return: the best clustering in the form of a named tuple (parameters, labels, cluster_num, outliers,
    cluster_summaries, inertia, silhouette) or None if no combination succeeded, all the calculated inertias and
    silhouette coefficients
    """
    Clustering = namedtuple("Clustering", "parameters labels cluster_num outliers cluster_summaries inertia silhouette")
    algorithm_name = algorithm.__name__
    result_folder = os.path.join(prefix + "Results", algorithm_name)
    inertia_rows = list()
    silhouette_rows = list()
    grid_clusterings = list()

    for parameters in ParameterGrid(grid_params):
        try:
            params_str = str(parameters).replace(": ", "=")
            result_filename = prefix + algorithm_name + "_" + str(binding_num) + "bindings_" + params_str

            labels, cluster_num, outliers = algorithm(data, dist, binding_num, **parameters)

            pd.DataFrame(labels).to_csv(os.path.join(result_folder, result_filename + ".csv"))
            cluster_summaries, inertia, silhouette = evaluate_clusters(algorithm_name, params_str, data, dist,
                                                                       binding_num, prefix, labels, cluster_num,
                                                                       outliers)

            grid_clusterings.append(Clustering(parameters, labels, cluster_num, outliers, cluster_summaries,
                                               inertia, silhouette))
            plot.visualize_clusters(algorithm_name, data, labels, binding_num, prefix, angle_cols(binding_num),
                                    inertia, silhouette, result_folder=result_folder,
                                    result_filename=result_filename)

            # k is reported in its own column, so it is left out of the parameter description. A copy is modified,
            # because the original dictionary is already stored in the Clustering tuple and must stay complete.
            remaining_params = dict(parameters)
            remaining_params.pop("k_clusters", None)
            inertia_rows.append([cluster_num, inertia, str(remaining_params)])
            silhouette_rows.append([cluster_num, silhouette, str(remaining_params)])
        except Exception as e:
            # best effort: a failing parameter combination must not abort the remaining grid
            print(str(e))

    # the rows are collected in lists and converted once, instead of growing the DataFrames row by row
    inertias = pd.DataFrame(inertia_rows, columns=["k", "inertia", "params"])
    silhouettes = pd.DataFrame(silhouette_rows, columns=["k", "silhouette", "params"])

    if grid_clusterings:
        # max returns the first of several equally good clusterings, just like a stable descending sort would
        best = max(grid_clusterings, key=lambda c: c.silhouette)
    else:
        best = None

    return best, inertias, silhouettes


def calculate_all_clusterings(clusterers, data_path, binding_num, prefix, header=None):
    """
    Given a set of clustering algorithms and their parameter grids, selects the best parameters for all algorithms for
    the specified number of element bindings.

    Progress messages with timestamps are printed to the console; plots of the best clustering and of the evaluation
    measures are created in <prefix>Results/<algorithm name>.
    :param clusterers: dictionary with clustering algorithms (functions) as keys, and a dictionary of parameter
    name-value pairs as values
    :param data_path: folder with all precomputed distance matrices (matrices for specific binding numbers should
    reside in subfolders of this folder)
    :param binding_num: number of element bindings for which the algorithms should be evaluated
    :param prefix: prefix denoting the clustered atom (usually metal)
    :param header: determines whether the source csv file has a header row
    """
    time_format = "%Y-%m-%d %H:%M:%S"

    # print(...) with a single, already formatted string behaves identically on Python 2 and Python 3
    print("Calculating all clusterings of %s with %s bindings: %s" % (prefix, binding_num, datetime.datetime.now().strftime(time_format)))
    data, dist = get_data_and_distance_matrix(data_path, binding_num, prefix, header=header)

    # items() is available on both Python 2 and Python 3 (iteritems() is Python 2 only)
    for algorithm, param_grid in clusterers.items():
        algorithm_name = algorithm.__name__
        result_folder = os.path.join(prefix + "Results", algorithm_name)
        result_prefix = prefix + algorithm_name + "_" + str(binding_num)

        print("%s\tfor %s with %s bindings: %s" % (algorithm_name, prefix, binding_num, datetime.datetime.now().strftime(time_format)))
        if not os.path.exists(result_folder):
            os.makedirs(result_folder)

        best, inertias, silhouettes = grid_search(algorithm, param_grid, data, dist, binding_num, prefix)
        if best is not None:
            plot.visualize_clusters(algorithm_name, data, best.labels, binding_num, prefix, angle_cols(binding_num),
                                    best.inertia, best.silhouette, result_folder=result_folder,
                                    result_filename=result_prefix + "bindings_best")
            plot.create_cluster_evaluation_plots(algorithm_name, binding_num, prefix, inertias, silhouettes,
                                                 result_folder=result_folder, result_filename=result_prefix)

    print("Finished clustering %s with %s bindings: %s" % (prefix, binding_num, datetime.datetime.now().strftime(time_format)))
