#ifndef WEISZFELD_H
#define WEISZFELD_H

#include <functional>
#include <algorithm>
#include <limits>

#include "Point.hpp"
#include "Metric.hpp"

/**
 * @brief Approximates the 1-median (geometric median) of a weighted point set
 *        with Weiszfeld's fixed-point iteration.
 *
 * All distances are measured with a Metric<Point> obtained from the factory
 * handed to the constructor.
 */
class Weiszfeld
{
private:
    // NOTE(review): raw pointer produced by the constructor's factory. This class
    // declares no destructor, so it never deletes it — confirm who owns the metric.
    Metric<Point>* metric;

public:
    /// Exception type thrown by approximateOneMedian when the iteration cannot be completed.
    class IterationFailed {};

    /**
     * @brief Build the approximator.
     * @param metricFactory Callable returning the metric used for all distance computations.
     */
    Weiszfeld(std::function<Metric<Point>*()> metricFactory);

    /**
     * @brief Approximate the 1-median of the points in [begin, end).
     * @param begin         Iterator to the first point.
     * @param end           Past-the-end iterator.
     * @param max_iteration Upper bound on the number of Weiszfeld steps (default 15).
     * @return The approximated 1-median.
     * @throws IterationFailed if the iteration cannot be completed.
     */
    template<typename ForwardIterator>
    Point approximateOneMedian(ForwardIterator begin, ForwardIterator end, int max_iteration = 15);
};

/**
 * @brief Approximate the weighted 1-median (geometric median) of a point set
 *        using Weiszfeld's iteration.
 *
 * Starts at the weighted center of gravity  sum(w_i * x_i) / sum(w_i)  and
 * repeatedly replaces the current estimate y by
 *     sum_i(w_i * x_i / d(x_i, y)) / sum_i(w_i / d(x_i, y)).
 * Iteration stops when a step length drops to machine epsilon, or when a step
 * is not at least ten times shorter than the previous one (convergence has
 * slowed down), or after @p max_iteration steps.
 *
 * @param begin         Iterator to the first weighted point. The range is
 *                      traversed several times, so it must be multipass.
 * @param end           Past-the-end iterator.
 * @param max_iteration Upper bound on the number of Weiszfeld steps; if it is
 *                      <= 0 the center of gravity is returned.
 * @return The approximated 1-median; for a single-element range, that element.
 * @throws IterationFailed if the range is empty, if the total weight is not
 *         positive, or if an iterate coincides (within machine epsilon) with an
 *         input point, where the Weiszfeld update divides by zero.
 */
template<typename ForwardIterator>
Point Weiszfeld::approximateOneMedian(ForwardIterator begin, ForwardIterator end, int max_iteration)
{
    const double eps = std::numeric_limits<double>::epsilon();
    // Stop once a step is not at least this factor shorter than the previous one.
    const double minStepReduction = 0.10;

    // An empty range has no median, and dereferencing begin would be undefined.
    if (begin == end)
        throw IterationFailed();

    // A single point is its own 1-median.
    ForwardIterator second = begin;
    ++second;
    if (second == end)
        return *begin;

    const int dimension = begin->getDimension();

    // Starting point: weighted center of gravity. Each point contributes with
    // its OWN weight (not the running total).
    Point cog(dimension);
    double totalWeight = 0;
    for (ForwardIterator it = begin; it != end; ++it)
    {
        const double w = it->getWeight();
        totalWeight += w;
        cog += w * (*it);
    }
    // Without positive total weight the center of gravity is undefined.
    if (!(totalWeight > 0))
        throw IterationFailed();
    Point y((1.0 / totalWeight) * cog);

    // Infinity makes the ratio test below a no-op on the first step.
    double lastDist = std::numeric_limits<double>::infinity();
    for (int iteration = 0; iteration < max_iteration; ++iteration)
    {
        // Compute the next Weiszfeld iterate.
        Point numerator(dimension);
        double denominator = 0;
        for (ForwardIterator it = begin; it != end; ++it)
        {
            const double dist = metric->distance(*it, y);
            // The update divides by dist: if y sits on an input point the step is
            // undefined, so fail now instead of continuing with partial sums.
            if (dist <= eps)
                throw IterationFailed();
            const double invdist = 1.0 / dist;
            numerator += it->getWeight() * invdist * (*it);
            denominator += it->getWeight() * invdist;
        }
        Point newPoint((1.0 / denominator) * numerator);

        // Step length decides whether to stop.
        const double currentDist = metric->distance(y, newPoint);
        y = newPoint;
        if (currentDist <= eps || currentDist / lastDist > minStepReduction)
            break;
        lastDist = currentDist;
    }

    return y;
}

#endif
