/*
 * Copyright information removed for SIGGRAPH anonymous review process.
 * This file is licensed under the GNU General Public License.
 *
 * bivariate.h: classes to represent bivariate functions, function bases and
 *   robust fitting (based on M-estimators). Everything is template-parameterized
 *   by the function basis (implemented up to degree-3 bivariate polynomials).
 *   Adding a new function basis to the system is easy: derive from
 *   BivariateBasis<N> and implement a single static function:
 *      BivariateXXXBasis::bases(double x, double y, double* a),
 *   which writes the N basis values at (x,y) into a[0..N-1].
 *   This function corresponds to the vectors Phi in the paper.
 */
 
#ifndef __BIVARIATE__
#define __BIVARIATE__

#include "m_estimator.h"
#include "matrix.h"
#include "geometry.h"

namespace OGF {

    /**
     * Common base of all bivariate function bases. DIM is the number of basis
     * functions; it is exposed as the compile-time constant 'dimension', i.e.
     * the length of the vector written by the derived class' static bases().
     */
    template <int DIM> class BivariateBasis {
    public:
        enum { dimension = DIM } ;
    } ;

    /**
     * Degree-1 (affine) polynomial basis: Phi(x,y) = (1, x, y).
     */
    class BivariateLinearBasis : public BivariateBasis<3> {
    public:
        /**
         * Writes the 3 basis values at (x,y) into a.
         * \param x first coordinate
         * \param y second coordinate
         * \param a output array, must have room for at least 3 doubles
         */
        static void bases(double x, double y, double* a) {
            const double phi[3] = { 1.0, x, y } ;
            for(int k=0; k<3; k++) {
                a[k] = phi[k] ;
            }
        }
    } ;

    /**
     * Degree-2 polynomial basis: Phi(x,y) = (1, x, y, x^2, xy, y^2).
     */
    class BivariateQuadraticBasis : public BivariateBasis<6> {
    public:
        /**
         * Writes the 6 basis values at (x,y) into a.
         * \param x first coordinate
         * \param y second coordinate
         * \param a output array, must have room for at least 6 doubles
         */
        static void bases(double x, double y, double* a) {
            const double xx = x * x ;
            const double xy = x * y ;
            const double yy = y * y ;
            // constant and linear terms
            a[0] = 1.0 ;  a[1] = x ;   a[2] = y ;
            // quadratic terms
            a[3] = xx ;   a[4] = xy ;  a[5] = yy ;
        }
    } ;

    /**
     * Degree-3 polynomial basis:
     * Phi(x,y) = (1, x, y, x^2, xy, y^2, x^3, x^2 y, x y^2, y^3).
     */
    class BivariateCubicBasis : public BivariateBasis<10> {
    public:
        /**
         * Writes the 10 basis values at (x,y) into a.
         * \param x first coordinate
         * \param y second coordinate
         * \param a output array, must have room for at least 10 doubles
         */
        static void bases(double x, double y, double* a) {
            // Quadratic products, reused to build the cubic terms.
            const double xx = x * x ;
            const double xy = x * y ;
            const double yy = y * y ;

            a[0] = 1.0 ;
            a[1] = x ;
            a[2] = y ;

            a[3] = xx ;
            a[4] = xy ;
            a[5] = yy ;

            a[6] = xx * x ;   // x^3
            a[7] = xy * x ;   // x^2 y
            a[8] = xy * y ;   // x y^2
            a[9] = yy * y ;   // y^3
        }
    } ;

    /**
     * A bivariate function expressed in a function basis:
     *     f(x,y) = sum_{i=0}^{dimension-1} coeff[i] * Phi_i(x,y)
     * where (Phi_0(x,y), ..., Phi_{dimension-1}(x,y)) is the vector written
     * by BASIS::bases(x,y,a).
     *
     * BASIS must provide an integral constant 'dimension' and a static function
     * bases(double x, double y, double* a) writing 'dimension' values into a.
     *
     * Copy construction and assignment are compiler-generated: they copy the
     * coefficient array element-wise.
     */
    template <class BASIS> class BivariateFunction {
    public:
        enum { dimension = BASIS::dimension } ;
        typedef BASIS Basis ;
        typedef BivariateFunction<BASIS> thisclass ;

        /**
         * Constructs the zero function. The coefficients are explicitly set to
         * 0.0 so that a default-constructed function (e.g. a fitting result or
         * a reference function that was never assigned) never holds
         * indeterminate values.
         */
        BivariateFunction() {
            for(int i=0; i<dimension; i++) {
                coeff[i] = 0.0 ;
            }
        }

        /** i-th coefficient. No bounds checking: requires 0 <= i < dimension. */
        double& operator[](int i) { return coeff[i] ; }

        /** i-th coefficient (read-only). No bounds checking. */
        const double& operator[](int i) const { return coeff[i] ; }

        /**
         * Clears the coefficient array through Memory::clear (presumably a
         * byte-wise zeroing of the array, i.e. all coefficients become 0.0).
         */
        void clear() { Memory::clear(coeff, dimension * sizeof(double)) ; }

        /**
         * Evaluates f at (x,y), i.e. sum_i coeff[i] * Phi_i(x,y).
         */
        double eval(double x, double y) const {
            double a[dimension] ;
            BASIS::bases(x,y,a) ;
            double result = 0.0 ;
            for(int i=0; i<dimension; i++) {
                result += coeff[i] * a[i] ;
            }
            return result ;
        }

        /** Evaluates f at the point (p.x(), p.y()). */
        double eval(const Vec2& p) const {
            return eval(p.x(), p.y()) ;
        }

        /** Coefficients of the function in the basis BASIS. */
        double coeff[dimension] ;
    } ;

    // Convenience names for bivariate polynomial functions of degree 1, 2 and 3.
    typedef BivariateFunction<BivariateLinearBasis>    BivariateLinearFunction ;    // 3 coefficients
    typedef BivariateFunction<BivariateQuadraticBasis> BivariateQuadraticFunction ; // 6 coefficients
    typedef BivariateFunction<BivariateCubicBasis>     BivariateCubicFunction ;     // 10 coefficients

    //____________________________________________________________________________________

    /**
     * Weighted (optionally robust) least-squares fitting of a bivariate
     * function to scattered samples (x, y, g).
     *
     * For each sample with basis vector v = Phi(x,y) and weight w, the normal
     * equations are accumulated as
     *     M += w * v * v^T      (one triangle only, see add_sample(double*,...))
     *     b += w * g * v
     * and end() solves M c = b with Numeric::solve_SPD_system, storing c as
     * the coefficients of the result function.
     *
     * In M mode (set_M_mode(true)), the weight of each sample is additionally
     * multiplied by weight(reference.eval(x,y) - g), where reference is the
     * function given to set_reference() and weight() is presumably the robust
     * weighting function provided by MEstimator.
     *
     * Typical use: begin(); add_sample(...) for every sample; end(); result().
     */
    template <class FUNC> class BivariateFitting : public MEstimator {
    public:
        typedef FUNC Function ;
        typedef typename FUNC::Basis Basis ;
        enum { dimension = Basis::dimension } ;

        /**
         * Starts in non-M mode with empty normal equations (as after begin())
         * and with zero result and reference coefficients. No member is left
         * holding indeterminate values, even if add_sample() or result() is
         * used before begin() / end() / set_reference().
         */
        BivariateFitting() : M_mode_(false) {
            for(unsigned int i=0; i<Function::dimension; i++) {
                x_.coeff[i] = 0.0 ;
                reference_.coeff[i] = 0.0 ;
            }
            begin() ;
        }

        /** Enables/disables the residual-based re-weighting of the samples. */
        void set_M_mode(bool x) { M_mode_ = x ; }

        /** Sets the function whose residuals drive the weights in M mode. */
        void set_reference(const Function& ref) { reference_ = ref ; }

        /** Resets the normal equations (M = 0, b = 0) before adding samples. */
        void begin() {
            M_.load_zero() ;
            for(unsigned int i=0; i<Function::dimension; i++) {
                b_[i] = 0.0 ;
            }
        }

        /**
         * Solves the accumulated normal equations M c = b and stores c in the
         * coefficients of the result function.
         * NOTE(review): any failure status that Numeric::solve_SPD_system may
         * report (e.g. for a singular M when there are too few samples) is not
         * checked here -- confirm its return type.
         */
        void end() {
            Numeric::solve_SPD_system(M_, b_, x_.coeff) ;
        }

        /**
         * Adds the sample g at (x,y) with the given importance (weight).
         * In M mode the importance is multiplied by weight(residual), with
         * residual = reference.eval(x,y) - g.
         */
        void add_sample(double x, double y, double g, double importance = 1.0) {
            double v[Function::dimension] ;
            Function::Basis::bases(x,y,v) ;
            if(M_mode_) {
                importance *= weight(reference_.eval(x,y) - g) ;
            }
            add_sample(v,g,importance) ;
        }

        /** The fitted function (its coefficients are written by end()). */
        const Function& result() const { return x_ ; }

        /**
         * Normal-equation matrix M. add_sample() only accumulates the entries
         * (i,j) with j >= i.
         */
        Matrix<Function::dimension>& matrix() { return M_ ; }
        const Matrix<Function::dimension>& matrix() const { return M_ ; }

        /** Right-hand side b of the normal equations (Function::dimension entries). */
        double* rhs() { return b_ ; }
        const double* rhs() const { return b_ ; }

    protected:

        /**
         * Accumulates the contribution of one sample, given its basis vector v
         * (Function::dimension values), its value g and its weight importance.
         * Since M is symmetric, only the entries M_(i,j) with j >= i are
         * accumulated; this function never writes the other triangle.
         * NOTE(review): an earlier comment called this the "lower" triangle;
         * whether (i, j >= i) is lower or upper depends on Matrix's (row,col)
         * convention -- presumably it is the triangle read by
         * Numeric::solve_SPD_system. Confirm.
         */
        void add_sample(double* v, double g, double importance) {
            // Same evaluation order as importance * g * v[i] and
            // importance * v[i] * v[j], so results are unchanged.
            const double wg = importance * g ;
            for(unsigned int i=0; i<Function::dimension; i++) {
                b_[i] += wg * v[i] ;
                const double wvi = importance * v[i] ;
                for(unsigned int j=i; j<Function::dimension; j++) {
                    M_(i,j) += wvi * v[j] ;
                }
            }
        }

    private:
        bool M_mode_ ;
        Matrix<Function::dimension> M_ ;
        double b_[Function::dimension] ;
        Function x_ ;
        Function reference_ ;
    } ;

    //____________________________________________________________________________________

}

#endif
