Program Listing for File GTWR.h

Return to documentation for file (include/gwmodelpp/GTWR.h)

#ifndef GTWR_H
#define GTWR_H

#include <cfloat>
#include <initializer_list>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "GWRBase.h"
#include "RegressionDiagnostic.h"
#include "IBandwidthSelectable.h"
#include "IVarialbeSelectable.h"
#include "IParallelizable.h"
#include "spatialweight/CRSSTDistance.h"

#include <gsl/gsl_multimin.h>
#include <gsl/gsl_min.h>

namespace gwm
{

// Prefix prepended to log lines produced during lambda optimization; used by
// GTWR::infoLambdaCriterion(). Note: being a macro, it is not scoped by namespace gwm.
#define GWM_LOG_TAG_LAMBDA_OPTIMIZATION "#lambda-optimization "

class GTWR :  public GWRBase, public IBandwidthSelectable, public IParallelizable, public IParallelOpenmpEnabled
{
public:

    enum BandwidthSelectionCriterionType
    {
        AIC,
        CV
    };

    static std::unordered_map<BandwidthSelectionCriterionType, std::string> BandwidthSelectionCriterionTypeNameMapper;

    typedef arma::mat (GTWR::*PredictCalculator)(const arma::mat&, const arma::mat&, const arma::vec&);
    typedef arma::mat (GTWR::*FitCalculator)(const arma::mat&, const arma::vec&, arma::mat&, arma::vec&, arma::vec&, arma::mat&);

    typedef double (GTWR::*BandwidthSelectionCriterionCalculator)(BandwidthWeight*);
    typedef double (GTWR::*IndepVarsSelectCriterionCalculator)(const std::vector<std::size_t>&);

    static std::string infoLambdaCriterion()
    {
        return std::string(GWM_LOG_TAG_LAMBDA_OPTIMIZATION) + "lambda,criterion";
    }

    static std::string infoLambdaCriterion(const double& lambda, const double& criterion)
    {
        return std::string(GWM_LOG_TAG_LAMBDA_OPTIMIZATION) + std::to_string(lambda) + "," + std::to_string(criterion);
    }

private:

    static RegressionDiagnostic CalcDiagnostic(const arma::mat& x, const arma::vec& y, const arma::mat& betas, const arma::vec& shat);

public:

    GTWR(){};

    GTWR(const arma::mat& x, const arma::vec& y, const arma::mat& coords, const SpatialWeight& spatialWeight, bool hasHatMatrix = true, bool hasIntercept = true)
        : GWRBase(x, y, spatialWeight, coords)
    {
        mHasHatMatrix = hasHatMatrix;
        mHasIntercept = hasIntercept;
    }

    ~GTWR(){};

// //also unused
// private:
//     Weight* mWeight = nullptr;      //!< \~english weight pointer. \~chinese 权重指针。
//     Distance* mDistance = nullptr;  //!< \~english distance pointer. \~chinese 距离指针。
// public:
//    arma::vec weightVector(uword focus);//recalculate weight using spatial temporal distance

public:

    bool isAutoselectBandwidth() const { return mIsAutoselectBandwidth; }

    void setIsAutoselectBandwidth(bool isAutoSelect) { mIsAutoselectBandwidth = isAutoSelect; }

    BandwidthSelectionCriterionType bandwidthSelectionCriterion() const { return mBandwidthSelectionCriterion; }

    void setBandwidthSelectionCriterion(const BandwidthSelectionCriterionType& criterion);

    BandwidthCriterionList bandwidthSelectionCriterionList() const { return mBandwidthSelectionCriterionList; }

    bool hasHatMatrix() const { return mHasHatMatrix; }

    void setHasHatMatrix(const bool has) { mHasHatMatrix = has; }

    // void setTimes(const arma::vec& times)
    // {
    //     vTimes=times;
    // }

    void setCoords(const arma::mat& coords, const arma::vec& times)
    {
        mCoords=coords;
        vTimes=times;
    }

    const arma::mat& betasSE() { return mBetasSE; }

    const arma::vec& sHat() { return mSHat; }

    const arma::vec& qDiag() { return mQDiag; }

    const arma::mat& s() { return mS; }

public:     // Implement Algorithm
    bool isValid() override;

public:     // Implement IRegressionAnalysis
    arma::mat predict(const arma::mat& locations) override;

    arma::mat fit() override;

private:

    arma::mat predictSerial(const arma::mat& locations, const arma::mat& x, const arma::vec& y);

    arma::mat fitSerial(const arma::mat& x, const arma::vec& y, arma::mat& betasSE, arma::vec& shat, arma::vec& qDiag, arma::mat& S);

#ifdef ENABLE_OPENMP
    arma::mat predictOmp(const arma::mat& locations, const arma::mat& x, const arma::vec& y);

    arma::mat fitOmp(const arma::mat& x, const arma::vec& y, arma::mat& betasSE, arma::vec& shat, arma::vec& qDiag, arma::mat& S);
#endif

public:     // Implement IBandwidthSelectable
    Status getCriterion(BandwidthWeight* weight, double& criterion) override
    {
        criterion = (this->*mBandwidthSelectionCriterionFunction)(weight);
        return mStatus;
    }

private:

    double bandwidthSizeCriterionCVSerial(BandwidthWeight* bandwidthWeight);

    double bandwidthSizeCriterionAICSerial(BandwidthWeight* bandwidthWeight);

#ifdef ENABLE_OPENMP
    double bandwidthSizeCriterionCVOmp(BandwidthWeight* bandwidthWeight);

    double bandwidthSizeCriterionAICOmp(BandwidthWeight* bandwidthWeight);
#endif


public:     // Implement IParallelizable
    int parallelAbility() const override
    {
        return ParallelType::SerialOnly
#ifdef ENABLE_OPENMP
            | ParallelType::OpenMP
#endif
        ;
    }

    ParallelType parallelType() const override { return mParallelType; }

    void setParallelType(const ParallelType& type) override;

public:     // Implement IGwmParallelOpenmpEnabled
    void setOmpThreadNum(const int threadNum) override { mOmpThreadNum = threadNum; }

protected:

    bool isStoreS() { return mHasHatMatrix && (mCoords.n_rows < 8192); }

    void createPredictionDistanceParameter(const arma::mat& locations);


    void createDistanceParameter();

    // void LambdaBwAutoSelection();

    double LambdaAutoSelection(BandwidthWeight* bandwidthWeight);

    Status RsquareByLambda(BandwidthWeight* bandwidthWeight,double lambda, double& rsquare);

public:
    void setIsAutoselectLambda(bool isAutoSelect) { mIsAutoselectLambda = isAutoSelect; }

protected:

    bool mHasHatMatrix = true;
    bool mHasFTest = false;
    bool mHasPredict = false;

    bool mIsAutoselectBandwidth = false;
    bool mIsAutoselectLambda = false;

    BandwidthSelectionCriterionType mBandwidthSelectionCriterion = BandwidthSelectionCriterionType::AIC;
    BandwidthSelectionCriterionCalculator mBandwidthSelectionCriterionFunction = &GTWR::bandwidthSizeCriterionCVSerial;
    BandwidthCriterionList mBandwidthSelectionCriterionList;
    double mBandwidthLastCriterion = DBL_MAX;

    PredictCalculator mPredictFunction = &GTWR::predictSerial;
    FitCalculator mFitFunction = &GTWR::fitSerial;

    ParallelType mParallelType = ParallelType::SerialOnly;
    int mOmpThreadNum = 8;

    arma::mat mBetasSE;
    arma::vec mSHat;
    arma::vec mQDiag;
    arma::mat mS;

    arma::vec vTimes;

    CRSSTDistance* mStdistance;//use to change spatial temporal distance including lambda

    // gsl_function F;
};

}

#endif  // GTWR_H