# coding: utf-8
__author__ = 'Marcin Kowiel, Dariusz Brzezinski'

import os
import numpy as np
import pandas as pd
from unittest import TestCase
import shutil

import calculate_distances as cc

# Fixture directory; it sits next to this test module.
TEST_DATA_FOLDER = os.path.join(os.path.split(__file__)[0], "TestData")
# CSV fixture that TestEvaluator.setUp loads into a DataFrame.
TEST_FILE_PATH_1 = os.path.join(TEST_DATA_FOLDER, "test1.csv")

class TestEvaluator(TestCase):
    """Tests for the column helpers and distance functions of ``calculate_distances``.

    Every test works on the fixture in ``TEST_FILE_PATH_1``, which is loaded
    without a header row in ``setUp``.
    """

    # Expected pairwise matrices for the six fixture rows.  The values are
    # written with 8 decimal places, so the float matrices are compared with
    # an absolute tolerance rather than exact equality.
    C_CORRECT = [
        [0.0, 0.0, 0.0, 19.63102972, 19.58998923, 15.75511298],
        [0.0, 0.0, 0.0, 19.63102972, 19.58998923, 15.75511298],
        [0.0, 0.0, 0.0, 19.63102972, 19.58998923, 15.75511298],
        [19.63102972, 19.63102972, 19.63102972, 0.0, 2.90249823, 9.72263354],
        [19.58998923, 19.58998923, 19.58998923, 2.90249823, 0.0, 7.83013327],
        [15.75511298, 15.75511298, 15.75511298, 9.72263354, 7.83013327, 0.0],
    ]

    B_CORRECT = [
        [0.0, 0.0, 0.0, 0.40435133, 0.36742346, 0.3391165],
        [0.0, 0.0, 0.0, 0.40435133, 0.36742346, 0.3391165],
        [0.0, 0.0, 0.0, 0.40435133, 0.36742346, 0.3391165],
        [0.40435133, 0.40435133, 0.40435133, 0.0, 0.05744563, 0.43600459],
        [0.36742346, 0.36742346, 0.36742346, 0.05744563, 0.0, 0.41448764],
        [0.3391165, 0.3391165, 0.3391165, 0.43600459, 0.41448764, 0.0],
    ]

    E_CORRECT = [
        [0.00, 0.00, 0.00, 0.75, 0.75, 0.75],
        [0.00, 0.00, 0.00, 0.75, 0.75, 0.75],
        [0.00, 0.00, 0.00, 0.75, 0.75, 0.75],
        [0.75, 0.75, 0.75, 0.00, 0.00, 0.75],
        [0.75, 0.75, 0.75, 0.00, 0.00, 0.75],
        [0.75, 0.75, 0.75, 0.75, 0.75, 0.00],
    ]

    # Absolute tolerance matching the 8-decimal precision of the matrices above.
    MATRIX_TOLERANCE = 1e-7

    def setUp(self):
        """Load the CSV fixture into ``self.df1`` (no header row)."""
        self.df1 = pd.read_csv(TEST_FILE_PATH_1, sep=",", header=None)

    def tearDown(self):
        """Nothing to release; per-test cleanup is registered with ``addCleanup``."""
        pass

    def _assert_indices(self, actual, expected):
        """Assert two index sequences hold the same values in the same order.

        Both sides are converted to lists first, because a ``range`` object
        never compares equal to a ``list`` on Python 3.
        """
        self.assertEqual(list(actual), list(expected))

    def _assert_matrix_close(self, actual, expected, tol=MATRIX_TOLERANCE):
        """Assert every element of ``actual`` is within ``tol`` of ``expected``.

        The shapes must match as well.  ``tol=0.0`` demands exact equality.
        """
        actual = np.asarray(actual, dtype=float)
        expected = np.asarray(expected, dtype=float)
        self.assertEqual(actual.shape, expected.shape)
        self.assertTrue(
            np.allclose(actual, expected, rtol=0.0, atol=tol),
            "matrices differ by more than {0}:\n{1}\n!=\n{2}".format(tol, actual, expected),
        )

    def test_angle_cols_3(self):
        """``angle_cols(3)`` for the default and for explicit orderings."""
        self._assert_indices(cc.angle_cols(3), range(10, 13))
        self._assert_indices(cc.angle_cols(3, ordering=range(3)), range(10, 13))
        self._assert_indices(cc.angle_cols(3, ordering=range(2, -1, -1)), [12, 11, 10])
        self._assert_indices(cc.angle_cols(3, ordering=[1, 2, 0]), [12, 10, 11])

    def test_binding_cols_3(self):
        """``binding_cols(3)`` for the default and for explicit orderings."""
        self._assert_indices(cc.binding_cols(3), range(7, 10))
        self._assert_indices(cc.binding_cols(3, ordering=range(3)), range(7, 10))
        self._assert_indices(cc.binding_cols(3, ordering=range(2, -1, -1)), range(9, 6, -1))
        self._assert_indices(cc.binding_cols(3, ordering=[1, 2, 0]), [8, 9, 7])

    def test_element_cols_3(self):
        """``element_cols(3)`` for the default and for explicit orderings."""
        self._assert_indices(cc.element_cols(3), range(4, 7))
        self._assert_indices(cc.element_cols(3, ordering=range(3)), range(4, 7))
        self._assert_indices(cc.element_cols(3, ordering=range(2, -1, -1)), range(6, 3, -1))
        self._assert_indices(cc.element_cols(3, ordering=[1, 2, 0]), [5, 6, 4])

    def test_angle_cols_4(self):
        """``angle_cols(4)`` for the default and for explicit orderings."""
        self._assert_indices(cc.angle_cols(4), range(13, 19))
        self._assert_indices(cc.angle_cols(4, ordering=range(4)), range(13, 19))
        self._assert_indices(cc.angle_cols(4, ordering=range(3, -1, -1)), [18, 17, 15, 16, 14, 13])
        self._assert_indices(cc.angle_cols(4, ordering=[1, 3, 2, 0]), [17, 16, 13, 18, 15, 14])

    def test_binding_cols_4(self):
        """``binding_cols(4)`` for the default and for explicit orderings."""
        self._assert_indices(cc.binding_cols(4), range(9, 13))
        self._assert_indices(cc.binding_cols(4, ordering=range(4)), range(9, 13))
        self._assert_indices(cc.binding_cols(4, ordering=range(3, -1, -1)), range(12, 8, -1))
        self._assert_indices(cc.binding_cols(4, ordering=[1, 3, 2, 0]), [10, 12, 11, 9])

    def test_element_cols_4(self):
        """``element_cols(4)`` for the default and for explicit orderings."""
        self._assert_indices(cc.element_cols(4), range(5, 9))
        self._assert_indices(cc.element_cols(4, ordering=range(4)), range(5, 9))
        self._assert_indices(cc.element_cols(4, ordering=range(3, -1, -1)), range(8, 4, -1))
        self._assert_indices(cc.element_cols(4, ordering=[1, 3, 2, 0]), [6, 8, 7, 5])

    def test_angle_cols_5(self):
        """``angle_cols(5)`` for the default and for explicit orderings."""
        self._assert_indices(cc.angle_cols(5), range(16, 26))
        self._assert_indices(cc.angle_cols(5, ordering=range(5)), range(16, 26))
        self._assert_indices(
            cc.angle_cols(5, ordering=range(4, -1, -1)),
            [25, 24, 22, 19, 23, 21, 18, 20, 17, 16],
        )
        self._assert_indices(
            cc.angle_cols(5, ordering=[1, 3, 4, 2, 0]),
            [21, 22, 20, 16, 25, 23, 18, 24, 19, 17],
        )

    def test_binding_cols_5(self):
        """``binding_cols(5)`` for the default and for explicit orderings."""
        self._assert_indices(cc.binding_cols(5), range(11, 16))
        self._assert_indices(cc.binding_cols(5, ordering=range(5)), range(11, 16))
        self._assert_indices(cc.binding_cols(5, ordering=range(4, -1, -1)), range(15, 10, -1))
        self._assert_indices(cc.binding_cols(5, ordering=[1, 3, 4, 2, 0]), [12, 14, 15, 13, 11])

    def test_element_cols_5(self):
        """``element_cols(5)`` for the default and for explicit orderings."""
        self._assert_indices(cc.element_cols(5), range(6, 11))
        self._assert_indices(cc.element_cols(5, ordering=range(5)), range(6, 11))
        self._assert_indices(cc.element_cols(5, ordering=range(4, -1, -1)), range(10, 5, -1))
        self._assert_indices(cc.element_cols(5, ordering=[1, 3, 4, 2, 0]), [7, 9, 10, 8, 6])

    def test_find_best_permutation(self):
        """First element of ``find_best_permutation``'s result for four row pairs."""
        c_cols = cc.angle_cols(4)
        cases = [
            (0, 1, [0, 1, 2, 3]),
            (0, 2, [3, 0, 1, 2]),
            (0, 3, [1, 3, 2, 0]),
            (3, 4, [0, 1, 2, 3]),
        ]
        for first, second, expected in cases:
            result = cc.find_best_permutation(self.df1.iloc[first, c_cols], self.df1.iloc[second], 4)
            self._assert_indices(result[0], expected)

    def test_binding_distance(self):
        """``binding_distance`` for four row pairs under a given ordering."""
        b_cols = cc.binding_cols(4)
        cases = [
            (0, 1, range(4), 0),
            (0, 2, [3, 0, 1, 2], 0),
            (0, 3, [1, 3, 2, 0], 0.40435133238311455),
            (3, 4, [0, 1, 2, 3], 0.057445626465380449),
        ]
        for first, second, ordering, expected in cases:
            distance = cc.binding_distance(self.df1.iloc[first, b_cols], self.df1.iloc[second], 4, ordering)
            # Floating-point result: compare to 7 decimal places, not bit-for-bit.
            self.assertAlmostEqual(distance, expected)

    def test_element_distance(self):
        """``element_distance`` for four row pairs under a given ordering."""
        e_cols = cc.element_cols(4)
        cases = [
            (0, 1, range(4), 0),
            (0, 2, [3, 0, 1, 2], 0),
            (0, 3, [1, 3, 2, 0], 0.75),
            (3, 4, [0, 1, 2, 3], 0),
        ]
        for first, second, ordering, expected in cases:
            distance = cc.element_distance(self.df1.iloc[first, e_cols], self.df1.iloc[second], 4, ordering)
            self.assertAlmostEqual(distance, expected)

    def test_coordination_distance(self):
        """``coordination_distance`` for four row pairs under a given ordering."""
        c_cols = cc.angle_cols(4)
        cases = [
            (0, 1, range(4), 0),
            (0, 2, [3, 0, 1, 2], 0),
            (0, 3, [1, 3, 2, 0], 19.631029723374176),
            (3, 4, [0, 1, 2, 3], 2.9024982342802463),
        ]
        for first, second, ordering, expected in cases:
            distance = cc.coordination_distance(self.df1.iloc[first, c_cols], self.df1.iloc[second], 4, ordering)
            self.assertAlmostEqual(distance, expected)

    def test_compute_distance_matrices(self):
        """All three matrices from ``compute_distance_matrices`` match the expected values."""
        c, b, e, _ = cc.compute_distance_matrices(self.df1, 4)

        # Compare every element; an ``.any()`` check would pass on a single
        # matching zero and therefore verify nothing.
        self._assert_matrix_close(c, self.C_CORRECT)
        self._assert_matrix_close(b, self.B_CORRECT)
        self._assert_matrix_close(e, self.E_CORRECT, tol=0.0)

    def test_compute_distance_matrices_parallel(self):
        """Same expectations when two extra arguments (``2`` and a data path) are passed."""
        data_path = os.path.join(TEST_DATA_FOLDER, 'TestRows')
        # Start from a clean directory so stale output cannot influence the run.
        if os.path.exists(data_path):
            shutil.rmtree(data_path)
        # Do not leave generated files behind in the fixture folder.
        self.addCleanup(shutil.rmtree, data_path, ignore_errors=True)

        c, b, e, _ = cc.compute_distance_matrices(self.df1, 4, 2, data_path)

        self._assert_matrix_close(c, self.C_CORRECT)
        self._assert_matrix_close(b, self.B_CORRECT)
        self._assert_matrix_close(e, self.E_CORRECT, tol=0.0)
