Program Listing for File GWDA.h

Return to documentation for file (include/gwmodelpp/GWDA.h)

#ifndef GWDA_H
#define GWDA_H

#include <string>
#include <unordered_map>
#include <vector>

#include "SpatialMonoscaleAlgorithm.h"
#include "IMultivariableAnalysis.h"
#include "IParallelizable.h"


namespace gwm
{

    //template<class T>
    class GWDA : public SpatialMonoscaleAlgorithm, public IMultivariableAnalysis, public IParallelizable, public IParallelOpenmpEnabled
    {
    public:

        static arma::mat covwtmat(const arma::mat &x, const arma::vec &wt);

        static double covwt(const arma::mat &x1, const arma::mat &x2, const arma::vec &w)
        {
            return sum((sqrt(w) % (x1 - sum(x1 % w))) % (sqrt(w) % (x2 - sum(x2 % w)))) / (1 - sum(w % w));
        }

        static double corwt(const arma::mat &x1, const arma::mat &x2, const arma::vec &w)
        {
            return covwt(x1, x2, w) / sqrt(covwt(x1, x1, w) * covwt(x2, x2, w));
        }

        typedef void (GWDA::*DiscriminantAnalysisCalculator)();

    public:
        GWDA() {}

        GWDA(const arma::mat x, const arma::mat coords, const SpatialWeight &spatialWeight)
            : SpatialMonoscaleAlgorithm(spatialWeight, coords)
        {
            mX = x;
        }

        ~GWDA() {}


        bool isWqda() const { return mIsWqda; }

        void setIsWqda(bool iswqda)
        {
            mIsWqda = iswqda;
        }

        bool hasCov() const { return mHascov; }

        void setHascov(bool hascov)
        {
            mHascov = hascov;
        }

        bool hasMean() const { return mHasmean; }

        void setHasmean(bool hasmean)
        {
            mHasmean = hasmean;
        }

        bool hasPrior() const { return mHasprior; }

        void setHasprior(bool hasprior)
        {
            mHasprior = hasprior;
        }

        double correctRate() const { return mCorrectRate; }

        const arma::mat& res() const { return mRes; }

        const std::vector<std::string>& group() const { return mGroup; }

        const arma::mat& probs() const { return mProbs; }

        const arma::mat& pmax() const { return mPmax; }

        const arma::mat& entropy() const { return mEntropy; }

        arma::mat wqda(arma::mat &x, std::vector<std::string> &y, arma::mat &wt, arma::mat &xpr, bool hasCOv, bool hasMean, bool hasPrior);

        arma::mat wlda(arma::mat &x, std::vector<std::string> &y, arma::mat &wt, arma::mat &xpr, bool hasCOv, bool hasMean, bool hasPrior);

        std::vector<arma::mat> splitX(arma::mat &x, std::vector<std::string> &y);

        arma::mat wMean(arma::mat &x, arma::mat &wt);

        arma::cube wVarCov(arma::mat &x, arma::mat &wt);

        arma::vec wPrior(arma::mat &wt, double sumW);

        //arma::mat confusionMatrix(arma::mat &origin, arma::mat &classified);

        std::vector<std::string> levels(std::vector<std::string> &y);

        double shannonEntropy(arma::vec &p);

        arma::uvec findSameString(std::vector<std::string> &y,std::string s);

        std::unordered_map<std::string,arma::uword> ytable(std::vector<std::string> &y);

    public: // SpatialMonoscaleAlgorithm interface
        bool isValid() override;

    public: // IMultivariableAnalysis
        const arma::mat& variables() const override { return mX; }
        void setVariables(const arma::mat &x) override { mX = x; }
        void setGroup(std::vector<std::string> &y) { mY = y; }
        void run() override;

    public: // IParallelizable
        int parallelAbility() const override
        {
            return ParallelType::SerialOnly
#ifdef ENABLE_OPENMP
                   | ParallelType::OpenMP
#endif
                ;
        }
        ParallelType parallelType() const override { return mParallelType; }

        void setParallelType(const ParallelType &type) override;

    public: // IParallelOpenmpEnabled
        void setOmpThreadNum(const int threadNum) override { mOmpThreadNum = threadNum; }

    private:
        void discriminantAnalysisSerial();

#ifdef ENABLE_OPENMP
        void discriminantAnalysisOmp();
#endif

    private:
        bool mIsWqda = false;
        bool mHascov = true;
        bool mHasmean = true;
        bool mHasprior = true;

        double mCorrectRate = 0;

        arma::mat mX;
        std::vector<std::string> mY;
        bool mHasPredict;
        arma::mat mprX;
        std::vector<std::string> mprY;
        arma::mat mRes; // !< \~english the result matrix of geographical weighted discriminant analysis \~chinese 地理加权判别分析结果矩阵
        std::vector<std::string> mGroup;
        arma::mat mProbs;
        arma::mat mPmax;
        arma::mat mEntropy;

        DiscriminantAnalysisCalculator mDiscriminantAnalysisFunction = &GWDA::discriminantAnalysisSerial;

        ParallelType mParallelType = ParallelType::SerialOnly;
        int mOmpThreadNum = 8;
    };

}

#endif // GWDA_H