# coding: utf-8

import datetime

import os
import numpy as np
import itertools as it
import pandas as pd

from sklearn.externals import joblib

__author__ = "Marcin Kowiel, Dariusz Brzezinski"


def angle_cols(binding_num, ordering=None):
    """
    Returns the indexes of columns that are associated with angles. Since these indexes depend on the number of bindings
    a separate function is needed. If the indexes should be reordered, the "ordering" parameter can be set to a
    permutation of the values 0 .. binding_num - 1.

    The angle columns start at index 3*binding_num + 1 and hold one value per unordered pair of bindings (i, j), i < j,
    laid out row by row: (0, 1), (0, 2), ..., (0, binding_num - 1), (1, 2), ...

    :param binding_num: number of bindings
    :param ordering: order of bindings; None keeps the natural order
    :return: list of (possibly reordered) indexes of angle columns
    """
    if ordering is None:
        start = 3*binding_num + 1
        # one angle column per unordered pair of bindings; floor division keeps the bound an int on Python 3
        return list(range(start, start + binding_num*(binding_num - 1)//2))

    two_binding_num = 2*binding_num
    indexes = []
    for ordering_i, ordering_j in it.combinations(ordering, 2):
        first = min(ordering_i, ordering_j)
        second = max(ordering_i, ordering_j)
        # closed form of 3*binding_num + 1 + (position of pair (first, second) in the row-wise upper triangle);
        # (first + 3)*(two_binding_num - first) is always even, so floor division is exact
        indexes.append((first + 3)*(two_binding_num - first)//2 + second)
    return indexes


def binding_cols(binding_num, ordering=None):
    """
    Returns the indexes of the binding length columns, which start at index 2*binding_num + 1. When "ordering" (a
    permutation of 0 .. binding_num - 1) is given, the indexes are returned in that order.
    :param binding_num: number of bindings
    :param ordering: order of bindings; None keeps the natural order
    :return: (possibly reordered) indexes of binding length columns
    """
    first_binding_col = 2*binding_num + 1
    if ordering is None:
        return range(first_binding_col, first_binding_col + binding_num)
    return [first_binding_col + position for position in ordering]


def element_cols(binding_num, ordering=None):
    """
    Returns the indexes of the binding element columns, which start at index binding_num + 1. When "ordering" (a
    permutation of 0 .. binding_num - 1) is given, the indexes are returned in that order.
    :param binding_num: number of bindings
    :param ordering: order of bindings; None keeps the natural order
    :return: (possibly reordered) indexes of binding element columns
    """
    first_element_col = binding_num + 1
    if ordering is None:
        return range(first_element_col, first_element_col + binding_num)
    return [first_element_col + position for position in ordering]


def distance_euclidean(x, y):
    """
    Returns the euclidean distance between two equally long pandas objects. Values are paired by position (via
    ``.values``), so index labels are ignored.
    :param x: first vector
    :param y: second vector
    :return: euclidean distance
    """
    delta = np.subtract(x.values, y.values)
    squared_sum = np.sum(delta * delta)
    return np.sqrt(squared_sum)


def find_best_permutation(x_cols, y, binding_num, feasible_permutations=None):
    """
    Finds the binding permutation of y whose angles are closest (euclidean distance) to the angles in x_cols.
    Ties are resolved in favour of the permutation that comes first.
    :param x_cols: angle values of the first example, in the natural angle_cols order
    :param y: second example (a full row)
    :param binding_num: number of bindings in each example (binding numbers must match for both examples)
    :param feasible_permutations: iterable of feasible permutations; if None all possible permutations will be
        considered
    :return: tuple (best permutation order for y as a list, its coordination distance); y itself is not changed
    :raises ValueError: if feasible_permutations is empty
    """
    if feasible_permutations is None:
        candidates = it.permutations(range(binding_num))
    else:
        candidates = iter(feasible_permutations)

    best_ordering = None
    smallest_distance = None
    for p in candidates:
        p_distance = distance_euclidean(x_cols, y.iloc[angle_cols(binding_num, p)])
        # the first candidate always initialises the optimum; later ones must be strictly better
        if best_ordering is None or p_distance < smallest_distance:
            best_ordering = list(p)
            smallest_distance = p_distance

    if best_ordering is None:
        raise ValueError("feasible_permutations must contain at least one permutation")

    return best_ordering, smallest_distance


def coordination_distance(x_cols, y, binding_num, ordering):
    """
    Calculates the coordination distance, i.e., the euclidean distance between the angles of the first example and the
    angles of the second example reordered according to "ordering".
    :param x_cols: angle values of the first example
    :param y: second example (a full row)
    :param binding_num: number of bindings
    :param ordering: order of bindings in the second example that corresponds to the bindings in the first example
    :return: the coordination distance
    """
    reordered_angles = y.iloc[angle_cols(binding_num, ordering)]
    return distance_euclidean(x_cols, reordered_angles)


def binding_distance(x_cols, y, binding_num, ordering):
    """
    Calculates the euclidean distance between the binding lengths of the first example and the binding lengths of the
    second example reordered according to "ordering".
    :param x_cols: binding length values of the first example
    :param y: second example (a full row)
    :param binding_num: number of bindings
    :param ordering: order of bindings in the second example that corresponds to the bindings in the first example
    :return: the binding distance
    """
    reordered_lengths = y.iloc[binding_cols(binding_num, ordering)]
    return distance_euclidean(x_cols, reordered_lengths)


def element_distance(x_cols, y, binding_num, ordering):
    """
    Calculates the fraction of binding positions whose elements differ between the first example and the second
    example reordered according to "ordering" (a normalized Hamming distance).
    :param x_cols: element values of the first example
    :param y: second example (a full row)
    :param binding_num: number of bindings
    :param ordering: order of bindings in the second example that corresponds to the bindings in the first example
    :return: the element distance; value between 0 and 1
    """
    y_elements = y.iloc[element_cols(binding_num, ordering)]
    # Compare by position. Comparing the two Series directly makes pandas require identical index labels, which
    # raises ValueError whenever ordering is not the identity permutation.
    matches = np.asarray(x_cols) == np.asarray(y_elements)
    return 1 - np.average(matches)


def print_speed_info(start_time, last_time, all_count, all_count_in_row, left, binding_num, i, max_jobs):
    """
    Prints the progress of the distance computation: how many comparisons are left, the throughput measured since
    last_time, and an estimate of the remaining time.
    :param start_time: start time of the computation (not used by this function)
    :param last_time: time from which the throughput of the current row is measured
    :param all_count: total number of comparisons
    :param all_count_in_row: number of comparisons in the current row
    :param left: number of comparisons left
    :param binding_num: number of bindings
    :param i: index of the row
    :param max_jobs: number of concurrent jobs
    :return: None
    """
    row_duration = datetime.datetime.now() - last_time
    # at least one millisecond, so an instantaneous row does not divide by zero
    speed = all_count_in_row/max(row_duration.total_seconds(), 0.001)
    if speed != 0.0:
        seconds_left = left/speed
    else:
        seconds_left = 0
    # the remaining comparisons are shared between the concurrent jobs
    seconds_left = seconds_left/max(max_jobs, 1)

    d = int(seconds_left/(24*3600))
    h = int(seconds_left/3600.0) % 24
    m = int(seconds_left/60.0) % 60
    s = int(seconds_left) % 60
    eta_text = "{0} d {1} h {2} m {3} s".format(d, h, m, s)
    percent_left = float(100.0*left)/max(all_count, 1)

    message = "BIND={0}({8}): {1} of {2} ({3:.2f}%) comparisons left, {4:.2f} comparisons per second {5} elapsed, {6} s left ({7}) with {9} jobs".format(
        binding_num, left, all_count, percent_left, speed, str(row_duration), int(seconds_left), eta_text, i, max_jobs)
    print(message)


def compute_distance_row(i, n, data, c_cols, b_cols, e_cols, binding_num, dataset_path, max_jobs, feasible_permutations=None):
    """
    Computes row i of the coordination, binding and element distance matrices and of the ordering matrix, and saves
    them to files in dataset_path. Only columns j > i are computed; the other entries keep their defaults (0 for
    distances, the identity ordering for orderings). A row whose coordination file (row_<i>.cdr) already exists is
    skipped, so an interrupted run can be resumed.
    :param i: row number
    :param n: number of examples in the dataset (and thus columns in the distance matrices)
    :param data: dataset
    :param c_cols: angle column indices in the dataset
    :param b_cols: binding length column indices in the dataset
    :param e_cols: element column indices in the dataset
    :param binding_num: number of bindings
    :param dataset_path: directory the row files are written to
    :param max_jobs: maximum number of worker jobs (only used for the progress estimate)
    :param feasible_permutations: list of feasible permutations; if None all possible permutations will be considered
    """
    coordination_row_filename = os.path.join(dataset_path, "row_%d.cdr" % i)
    binding_row_filename = os.path.join(dataset_path, "row_%d.bdr" % i)
    element_row_filename = os.path.join(dataset_path, "row_%d.edr" % i)
    order_row_filename = os.path.join(dataset_path, "row_%d.or" % i)

    # the coordination file is the "row done" marker (it is written last, see below)
    if os.path.exists(coordination_row_filename):
        return

    start_time = datetime.datetime.now()

    # progress bookkeeping; assumes all previous rows are done
    all_count = n*(n - 1)//2 if n > 1 else n
    done_count = all_count - (n - i)*(n - i - 1)//2

    # prepare arrays
    coordination_row = np.zeros(n)
    binding_row = np.zeros(n)
    element_row = np.zeros(n)
    order_row = [list(range(binding_num)) for _ in range(n)]

    # extract data of example i once
    example_i_c = data.iloc[i, c_cols]
    example_i_b = data.iloc[i, b_cols]
    example_i_e = data.iloc[i, e_cols]

    # only the upper triangle (j > i) is computed; the lower one is filled by symmetry later
    for j in range(i + 1, n):
        example_j = data.iloc[j]

        ordering, coordination_row[j] = find_best_permutation(example_i_c, example_j, binding_num, feasible_permutations)
        binding_row[j] = binding_distance(example_i_b, example_j, binding_num, ordering)
        element_row[j] = element_distance(example_i_e, example_j, binding_num, ordering)
        order_row[j] = ordering

    # print speed info
    all_count_in_row = n - i - 1
    left = all_count - done_count - all_count_in_row
    print_speed_info(start_time, start_time, all_count, all_count_in_row, left, binding_num, i, max_jobs)

    # dump the results; the coordination row goes last because its existence marks the row as complete, so a crash
    # part-way through leaves the row to be recomputed instead of half-written
    joblib.dump(binding_row, binding_row_filename)
    joblib.dump(element_row, element_row_filename)
    joblib.dump(order_row, order_row_filename)
    joblib.dump(coordination_row, coordination_row_filename)


def join_rows(dataset_path, n):
    """
    Reads the per-row files written by compute_distance_row from the disc and joins them into matrices.
    :param dataset_path: directory containing the row_<i>.cdr/.bdr/.edr/.or files
    :param n: number of rows (examples)
    :return: coordination, binding and element distance matrices (n x n numpy arrays) and the ordering matrix (a list
        of n rows); a row whose ordering file cannot be read is filled with None and a warning is printed once
    """
    coordination_dm = np.zeros((n, n))
    binding_dm = np.zeros((n, n))
    element_dm = np.zeros((n, n))
    order_dm = []
    order_dm_warning_displayed = False

    for i in range(n):
        coordination_row_filename = os.path.join(dataset_path, "row_%d.cdr" % i)
        binding_row_filename = os.path.join(dataset_path, "row_%d.bdr" % i)
        element_row_filename = os.path.join(dataset_path, "row_%d.edr" % i)
        order_row_filename = os.path.join(dataset_path, "row_%d.or" % i)

        coordination_dm[i] = joblib.load(coordination_row_filename)
        binding_dm[i] = joblib.load(binding_row_filename)
        element_dm[i] = joblib.load(element_row_filename)

        try:
            ordering = joblib.load(order_row_filename)
        except (IOError, OSError):
            # tolerate a missing ordering file; other errors (e.g. KeyboardInterrupt) are no longer swallowed
            ordering = [None]*n
            if not order_dm_warning_displayed:
                order_dm_warning_displayed = True
                print("Warning: Missing ordering rows for " + dataset_path)
        order_dm.append(ordering)

    return coordination_dm, binding_dm, element_dm, order_dm


def compute_distance_matrices(data, binding_num, max_jobs=1, dataset_path='', feasible_permutations=None):
    """
    Creates the three distance matrices that can later be used to compute coordination, binding, and element distances,
    respectively, plus a matrix with the binding ordering chosen for each pair of examples.
    :param data: a dataset with one example per row
    :param binding_num: number of bindings in each example
    :param max_jobs: 1 computes everything sequentially in this process; any other value is passed to joblib.Parallel
        as n_jobs, and the rows are computed by compute_distance_row and stored in dataset_path
    :param dataset_path: directory for the intermediate row files (only used when max_jobs != 1)
    :param feasible_permutations: list of feasible permutations; if None all possible permutations will be considered
    :return: the coordination, binding, and element distance matrices (n x n numpy arrays) and the ordering matrix
        (n rows of orderings), all symmetric
    """
    n = data.shape[0]
    coordination_dm = np.zeros((n, n))
    binding_dm = np.zeros((n, n))
    element_dm = np.zeros((n, n))
    order_dm = [[list(range(binding_num)) for _ in range(n)] for _ in range(n)]

    start_time = datetime.datetime.now()
    last_time = start_time
    print("BIND=%s Start: %s" % (binding_num, start_time))
    done_count = 0

    # number of pairwise comparisons (floor division keeps it an int on Python 3)
    all_count = n*(n - 1)//2 if n > 1 else n

    c_cols = angle_cols(binding_num)
    b_cols = binding_cols(binding_num)
    e_cols = element_cols(binding_num)

    if max_jobs == 1:
        # sequential: fill the upper triangle row by row
        for i in range(n):
            example_i_c = data.iloc[i, c_cols]
            example_i_b = data.iloc[i, b_cols]
            example_i_e = data.iloc[i, e_cols]

            for j in range(i + 1, n):
                example_j = data.iloc[j]

                ordering, coordination_dm[i, j] = find_best_permutation(example_i_c, example_j, binding_num,
                                                                        feasible_permutations)
                binding_dm[i, j] = binding_distance(example_i_b, example_j, binding_num, ordering)
                element_dm[i, j] = element_distance(example_i_e, example_j, binding_num, ordering)
                order_dm[i][j] = ordering

            all_count_in_row = n - i - 1
            left = all_count - done_count - all_count_in_row
            print_speed_info(start_time, last_time, all_count, all_count_in_row, left, binding_num, i, max_jobs)
            done_count += all_count_in_row
            last_time = datetime.datetime.now()
    else:
        # parallel: every job writes one row to disc, the rows are joined afterwards
        # (an empty dataset_path means the current directory, which must not be passed to os.makedirs)
        if dataset_path and not os.path.exists(dataset_path):
            os.makedirs(dataset_path)

        joblib.Parallel(n_jobs=max_jobs)(
            joblib.delayed(compute_distance_row)(i, n, data, c_cols, b_cols, e_cols, binding_num, dataset_path,
                                                 max_jobs, feasible_permutations)
            for i in range(n))
        coordination_dm, binding_dm, element_dm, order_dm = join_rows(dataset_path, n)

    # only the upper triangle was computed; mirror it into the lower triangle
    lower = np.tril_indices(n, -1)
    for dm in (coordination_dm, binding_dm, element_dm):
        dm[lower] = dm.T[lower]
    for i in range(n):
        for j in range(i + 1, n):
            order_dm[j][i] = order_dm[i][j]

    return coordination_dm, binding_dm, element_dm, order_dm


def compute_and_save_distance_matrix(prefix, binding_num, data_folder, max_jobs=1, header=None, feasible_permutations=None):
    """
    Reads the dataset <data_folder>/<prefix><binding_num>.csv, computes its distance matrices and saves them to disc in
    <data_folder>/<binding_num>/ as <dataset name>.cdm, .bdm, .edm and .odm files.
    :param prefix: dataset file name prefix
    :param binding_num: number of bindings; also part of the dataset file name
    :param data_folder: folder containing the dataset; results are written to its <binding_num> subfolder
    :param max_jobs: number of parallel jobs (see compute_distance_matrices); intermediate rows of a parallel run are
        stored in <data_folder>/<binding_num>/rows
    :param header: header argument passed to pandas.read_csv (None means the file has no header row)
    :param feasible_permutations: list of feasible permutations; if None all possible permutations will be considered
    """
    dataset_name = prefix + str(binding_num) + ".csv"
    dataset_path = os.path.join(data_folder, dataset_name)
    output_path = os.path.join(data_folder, str(binding_num))
    data_rows_path = os.path.join(output_path, 'rows')

    data = pd.read_csv(dataset_path, sep=",", header=header)

    # the sequential path never creates the output folder, so make sure it exists before the (long) computation
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    c, b, e, o = compute_distance_matrices(data, binding_num, max_jobs, data_rows_path, feasible_permutations)

    joblib.dump(c, os.path.join(output_path, dataset_name + ".cdm"))
    joblib.dump(b, os.path.join(output_path, dataset_name + ".bdm"))
    joblib.dump(e, os.path.join(output_path, dataset_name + ".edm"))
    joblib.dump(o, os.path.join(output_path, dataset_name + ".odm"))
