library(dplyr)

# Weighted k-nearest-neighbour classifier for mixed numeric / categorical
# predictors, written as a list in caret's custom-model format (label,
# library, type, parameters, grid, fit, predict, prob, sort) -- presumably
# used as `caret::train(..., method = wkNN)`.
#
# The dissimilarity between a new row and a training row is the mean, over
# the predictors observed for both rows, of a per-predictor dissimilarity:
#   * numeric predictors: |a - b|, min-max rescaled over all new/train pairs;
#   * other predictors:   0 if the values are equal, 1 otherwise.
# Each new row gets the class with the largest total vote among its k
# nearest training rows (rows tied with the k-th dissimilarity included).
wkNN <- list(
    label = "Weighted k-Nearest Neighbors",
    # fit() and predict() below use base R only, so no package is required.
    library = NULL,
    type = "Classification",
    parameters = data.frame(parameter = c("k"),
                            class = c("numeric"),
                            label = c("#Neighbors")),

    # Candidate values of k.
    #   search == "grid": the first `len` odd numbers from 5 (5, 7, 9, ...).
    #   otherwise:        1, 2, ..., len.
    grid = function(x, y, len = NULL, search = "grid") {
        if (search == "grid") {
            data.frame(k = 5L + 2L * (seq_len(len) - 1L))
        } else {
            data.frame(k = seq_len(len))
        }
    },

    # Lazy learner: "training" only stores the data, the case weights and k.
    #   wts: NULL (every training row gets weight 1), a numeric vector of
    #        per-row case weights, or the single string "distance" to make
    #        each neighbour vote with (1 - its dissimilarity) instead.
    # Returns a list of class "wkNN" with elements x, y, wts and k.
    fit = function(x, y, wts, param, lev, last, weights, classProbs, ...) {
        result <- list(
            x = x,
            y = y,
            wts = if (is.null(wts)) rep(1, nrow(x)) else wts,
            k = param$k
        )
        class(result) <- "wkNN"
        result
    },

    # Predicts the class of every row of `newdata`, whose predictors must be
    # in the same column order as the training x. Returns a vector of the
    # same type as the training y (a factor keeps y's levels), one element
    # per row of newdata; a row with no usable neighbour (every
    # dissimilarity missing) gets NA.
    predict = function(modelFit, newdata, preProc = NULL, submodels = NULL) {
        x <- modelFit$x
        y <- modelFit$y
        wts <- modelFit$wts
        k <- modelFit$k

        n_new <- nrow(newdata)
        n_train <- nrow(x)

        # Dissimilarity of one predictor for every (new row, training row)
        # pair, as an n_new x n_train matrix; NA marks an unusable pair.
        column_dissimilarity <- function(new_col, train_col) {
            if (is.numeric(new_col)) {
                gap <- abs(outer(new_col, train_col, "-"))
                if (all(is.na(gap))) {
                    return(gap)
                }
                # na.rm = TRUE: a single missing value must not turn the
                # whole column into NA. If every gap is equal this is
                # 0/0 = NaN, i.e. the column counts as missing everywhere.
                rng <- range(gap, na.rm = TRUE)
                (gap - rng[1]) / (rng[2] - rng[1])
            } else {
                # Compare as character: factors with different level sets and
                # plain strings compare cleanly (no Ops.factor level error).
                outer(as.character(new_col), as.character(train_col), "!=") * 1
            }
        }

        # Mean dissimilarity over the predictors observed for each pair.
        total <- matrix(0, n_new, n_train)
        n_observed <- matrix(0, n_new, n_train)
        for (i in seq_len(ncol(newdata))) {
            d <- column_dissimilarity(newdata[, i, drop = TRUE],
                                      x[, i, drop = TRUE])
            observed <- !is.na(d)
            d[!observed] <- 0
            total <- total + d
            n_observed <- n_observed + observed
        }
        dissim <- total / n_observed

        # The single string "distance" selects distance-weighted votes;
        # anything else is a vector of per-training-row case weights.
        # (A scalar test: `wts == "distance"` on a numeric vector inside
        # if() is an error on R >= 4.2.)
        by_distance <- is.character(wts) && isTRUE(wts == "distance")

        # Index of a training row carrying the winning class for one new
        # row (given its dissimilarities `d`), or NA if there is none.
        nearest_vote <- function(d) {
            # k nearest rows, keeping every row tied with the k-th smallest
            # dissimilarity; missing dissimilarities never qualify.
            nb <- which(rank(d, ties.method = "min", na.last = "keep") <= k)
            if (length(nb) == 0L) {
                return(NA_integer_)
            }
            votes <- if (by_distance) 1 - d[nb] else wts[nb]
            totals <- tapply(votes, y[nb], sum)
            # which.max takes the first maximum, so score ties go to the
            # class that sorts first (factor level order for factors).
            best <- which.max(totals)
            if (length(best) == 0L) {
                return(NA_integer_)
            }
            nb[match(names(totals)[best], as.character(y[nb]))]
        }

        winners <- vapply(seq_len(n_new),
                          function(j) nearest_vote(dissim[j, ]),
                          integer(1))
        y[winners]
    },
    # Class probabilities are not implemented.
    prob = NULL,
    sort = NULL)
