Program Listing for File GTWR.h
↰ Return to documentation for file (include/gwmodelpp/GTWR.h)
#ifndef GTWR_H
#define GTWR_H
#include <cfloat>
#include <cstddef>
#include <initializer_list>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "GWRBase.h"
#include "RegressionDiagnostic.h"
#include "IBandwidthSelectable.h"
#include "IVarialbeSelectable.h"
#include "IParallelizable.h"
#include "spatialweight/CRSSTDistance.h"
#include <gsl/gsl_multimin.h>
#include <gsl/gsl_min.h>
namespace gwm
{
#define GWM_LOG_TAG_LAMBDA_OPTIMIZATION "#lambda-optimization "
class GTWR : public GWRBase, public IBandwidthSelectable, public IParallelizable, public IParallelOpenmpEnabled
{
public:
enum BandwidthSelectionCriterionType
{
AIC,
CV
};
static std::unordered_map<BandwidthSelectionCriterionType, std::string> BandwidthSelectionCriterionTypeNameMapper;
typedef arma::mat (GTWR::*PredictCalculator)(const arma::mat&, const arma::mat&, const arma::vec&);
typedef arma::mat (GTWR::*FitCalculator)(const arma::mat&, const arma::vec&, arma::mat&, arma::vec&, arma::vec&, arma::mat&);
typedef double (GTWR::*BandwidthSelectionCriterionCalculator)(BandwidthWeight*);
typedef double (GTWR::*IndepVarsSelectCriterionCalculator)(const std::vector<std::size_t>&);
static std::string infoLambdaCriterion()
{
return std::string(GWM_LOG_TAG_LAMBDA_OPTIMIZATION) + "lambda,criterion";
}
static std::string infoLambdaCriterion(const double& lambda, const double& criterion)
{
return std::string(GWM_LOG_TAG_LAMBDA_OPTIMIZATION) + std::to_string(lambda) + "," + std::to_string(criterion);
}
private:
static RegressionDiagnostic CalcDiagnostic(const arma::mat& x, const arma::vec& y, const arma::mat& betas, const arma::vec& shat);
public:
GTWR(){};
GTWR(const arma::mat& x, const arma::vec& y, const arma::mat& coords, const SpatialWeight& spatialWeight, bool hasHatMatrix = true, bool hasIntercept = true)
: GWRBase(x, y, spatialWeight, coords)
{
mHasHatMatrix = hasHatMatrix;
mHasIntercept = hasIntercept;
}
~GTWR(){};
// //also unused
// private:
// Weight* mWeight = nullptr; //!< \~english weight pointer. \~chinese 权重指针。
// Distance* mDistance = nullptr; //!< \~english distance pointer. \~chinese 距离指针。
// public:
// arma::vec weightVector(uword focus);//recalculate weight using spatial temporal distance
public:
bool isAutoselectBandwidth() const { return mIsAutoselectBandwidth; }
void setIsAutoselectBandwidth(bool isAutoSelect) { mIsAutoselectBandwidth = isAutoSelect; }
BandwidthSelectionCriterionType bandwidthSelectionCriterion() const { return mBandwidthSelectionCriterion; }
void setBandwidthSelectionCriterion(const BandwidthSelectionCriterionType& criterion);
BandwidthCriterionList bandwidthSelectionCriterionList() const { return mBandwidthSelectionCriterionList; }
bool hasHatMatrix() const { return mHasHatMatrix; }
void setHasHatMatrix(const bool has) { mHasHatMatrix = has; }
// void setTimes(const arma::vec& times)
// {
// vTimes=times;
// }
void setCoords(const arma::mat& coords, const arma::vec& times)
{
mCoords=coords;
vTimes=times;
}
const arma::mat& betasSE() { return mBetasSE; }
const arma::vec& sHat() { return mSHat; }
const arma::vec& qDiag() { return mQDiag; }
const arma::mat& s() { return mS; }
public: // Implement Algorithm
bool isValid() override;
public: // Implement IRegressionAnalysis
arma::mat predict(const arma::mat& locations) override;
arma::mat fit() override;
private:
arma::mat predictSerial(const arma::mat& locations, const arma::mat& x, const arma::vec& y);
arma::mat fitSerial(const arma::mat& x, const arma::vec& y, arma::mat& betasSE, arma::vec& shat, arma::vec& qDiag, arma::mat& S);
#ifdef ENABLE_OPENMP
arma::mat predictOmp(const arma::mat& locations, const arma::mat& x, const arma::vec& y);
arma::mat fitOmp(const arma::mat& x, const arma::vec& y, arma::mat& betasSE, arma::vec& shat, arma::vec& qDiag, arma::mat& S);
#endif
public: // Implement IBandwidthSelectable
Status getCriterion(BandwidthWeight* weight, double& criterion) override
{
criterion = (this->*mBandwidthSelectionCriterionFunction)(weight);
return mStatus;
}
private:
double bandwidthSizeCriterionCVSerial(BandwidthWeight* bandwidthWeight);
double bandwidthSizeCriterionAICSerial(BandwidthWeight* bandwidthWeight);
#ifdef ENABLE_OPENMP
double bandwidthSizeCriterionCVOmp(BandwidthWeight* bandwidthWeight);
double bandwidthSizeCriterionAICOmp(BandwidthWeight* bandwidthWeight);
#endif
public: // Implement IParallelizable
int parallelAbility() const override
{
return ParallelType::SerialOnly
#ifdef ENABLE_OPENMP
| ParallelType::OpenMP
#endif
;
}
ParallelType parallelType() const override { return mParallelType; }
void setParallelType(const ParallelType& type) override;
public: // Implement IGwmParallelOpenmpEnabled
void setOmpThreadNum(const int threadNum) override { mOmpThreadNum = threadNum; }
protected:
bool isStoreS() { return mHasHatMatrix && (mCoords.n_rows < 8192); }
void createPredictionDistanceParameter(const arma::mat& locations);
void createDistanceParameter();
// void LambdaBwAutoSelection();
double LambdaAutoSelection(BandwidthWeight* bandwidthWeight);
Status RsquareByLambda(BandwidthWeight* bandwidthWeight,double lambda, double& rsquare);
public:
void setIsAutoselectLambda(bool isAutoSelect) { mIsAutoselectLambda = isAutoSelect; }
protected:
bool mHasHatMatrix = true;
bool mHasFTest = false;
bool mHasPredict = false;
bool mIsAutoselectBandwidth = false;
bool mIsAutoselectLambda = false;
BandwidthSelectionCriterionType mBandwidthSelectionCriterion = BandwidthSelectionCriterionType::AIC;
BandwidthSelectionCriterionCalculator mBandwidthSelectionCriterionFunction = >WR::bandwidthSizeCriterionCVSerial;
BandwidthCriterionList mBandwidthSelectionCriterionList;
double mBandwidthLastCriterion = DBL_MAX;
PredictCalculator mPredictFunction = >WR::predictSerial;
FitCalculator mFitFunction = >WR::fitSerial;
ParallelType mParallelType = ParallelType::SerialOnly;
int mOmpThreadNum = 8;
arma::mat mBetasSE;
arma::vec mSHat;
arma::vec mQDiag;
arma::mat mS;
arma::vec vTimes;
CRSSTDistance* mStdistance;//use to change spatial temporal distance including lambda
// gsl_function F;
};
}
#endif // GTWR_H