Program Listing for File GWRBasicGpuTask.h

Return to documentation for file (include/gwmodelcuda/GWRBasicGpuTask.h)

#ifndef GWRBASICGPUTASK
#define GWRBASICGPUTASK

#include <vector>
#include <memory>
#include "armadillo_config.h"

#include "IGWRBasicGpuTask.h"
#include "../gwmodel.h"

class GWRBasicGpuTask : public IGWRBasicGpuTask
{
private:
    arma::mat mX;
    arma::vec mY;
    arma::mat mCoords;
    arma::mat mPredictLocations;
    arma::mat mBetas;
    arma::mat mBetasSE;
    arma::vec mSHat;
    arma::vec mQDiag;
    arma::mat mS;

    gwm::Distance* mDistance = nullptr;
    gwm::Weight* mWeight = nullptr;

    bool mIsOptimizeBandwidth = false;
    gwm::GWRBasic::BandwidthSelectionCriterionType mBandwidthOptimizationCriterion = gwm::GWRBasic::BandwidthSelectionCriterionType::CV;

    bool mIsOptimizeVariables = false;
    double mOptimizeVariablesThreshold = 3.0;

    gwm::RegressionDiagnostic mDiagnostic = {};

    double mOptimizedBandwidth = 0.0;
    std::vector<arma::uword> mSelectedVars;
    gwm::VariablesCriterionList mVariableOptimizationCriterionList;

public:
    GWRBasicGpuTask(int nDp, int nVar, gwm::Distance::DistanceType distanceType) :
        mX(nDp, nVar),
        mY(nDp),
        mCoords(nDp, 2)
    {
        switch (distanceType)
        {
        case gwm::Distance::DistanceType::CRSDistance:
            mDistance = new gwm::CRSDistance();
            break;
        case gwm::Distance::DistanceType::MinkwoskiDistance:
            mDistance = new gwm::MinkwoskiDistance();
        default:
            break;
        }
        mWeight = new gwm::BandwidthWeight();
    }

    GWRBasicGpuTask(int nDp, int nVar, gwm::Distance::DistanceType distanceType, int nPredictPoints) :
        mX(nDp, nVar),
        mY(nDp),
        mCoords(nDp, 2),
        mPredictLocations(nPredictPoints, 2)
    {
        switch (distanceType)
        {
        case gwm::Distance::DistanceType::CRSDistance:
            mDistance = new gwm::CRSDistance();
            break;
        case gwm::Distance::DistanceType::MinkwoskiDistance:
            mDistance = new gwm::MinkwoskiDistance();
        default:
            break;
        }
        mWeight = new gwm::BandwidthWeight();
    }

    GWRBasicGpuTask(const GWRBasicGpuTask& source) :
        mX(source.mX),
        mY(source.mY),
        mCoords(source.mCoords),
        mPredictLocations(source.mPredictLocations),
        mBetas(source.mBetas),
        mBetasSE(source.mBetasSE),
        mSHat(source.mSHat),
        mQDiag(source.mQDiag),
        mS(source.mS),
        mIsOptimizeBandwidth(source.mIsOptimizeBandwidth),
        mBandwidthOptimizationCriterion(source.mBandwidthOptimizationCriterion),
        mIsOptimizeVariables(source.mIsOptimizeVariables),
        mOptimizeVariablesThreshold(source.mOptimizeVariablesThreshold),
        mOptimizedBandwidth(source.mOptimizedBandwidth),
        mSelectedVars(source.mSelectedVars),
        mVariableOptimizationCriterionList(source.mVariableOptimizationCriterionList)
    {
        mDistance = source.mDistance->clone();
        mWeight = source.mWeight->clone();
    }

    ~GWRBasicGpuTask()
    {
        if (mDistance) delete mDistance;
        if (mWeight) delete mWeight;
    }

    GWRBasicGpuTask& operator=(const GWRBasicGpuTask& source)
    {
        mX = source.mX;
        mY = source.mY;
        mCoords = source.mCoords;
        mPredictLocations = source.mPredictLocations;
        mBetas = source.mBetas;
        mBetasSE = source.mBetasSE;
        mSHat = source.mSHat;
        mQDiag = source.mQDiag;
        mS = source.mS;
        mDistance = source.mDistance->clone();
        mWeight = source.mWeight->clone();
        mIsOptimizeBandwidth = source.mIsOptimizeBandwidth;
        mBandwidthOptimizationCriterion = source.mBandwidthOptimizationCriterion;
        mIsOptimizeVariables = source.mIsOptimizeVariables;
        mOptimizeVariablesThreshold = source.mOptimizeVariablesThreshold;
        mOptimizedBandwidth = source.mOptimizedBandwidth;
        mSelectedVars = source.mSelectedVars;
        mVariableOptimizationCriterionList = source.mVariableOptimizationCriterionList;
        return *this;
    }

    void setX(int i, int k, double value) override
    {
        mX(i, k) = value;
    }

    void setY(int i, double value) override
    {
        mY(i) = value;
    }

    void setCoords(int i, double u, double v) override
    {
        mCoords(i, u) = v;
    }

    void setPredictLocations(int i, double u, double v) override
    {
        mPredictLocations(i, u) = v;
    }

    void setDistanceType(int type) override
    {
        switch ((gwm::Distance::DistanceType)type)
        {
        case gwm::Distance::DistanceType::CRSDistance:
            mDistance = new gwm::CRSDistance();
            break;
        case gwm::Distance::DistanceType::MinkwoskiDistance:
            mDistance = new gwm::MinkwoskiDistance();
        default:
            break;
        }
    }

    void setCRSDistanceGergraphic(bool isGeographic) override
    {
        static_cast<gwm::CRSDistance*>(mDistance)->setGeographic(isGeographic);
    }

    void setMinkwoskiDistancePoly(int poly) override
    {
        static_cast<gwm::MinkwoskiDistance*>(mDistance)->setPoly(poly);
    }

    void setMinkwoskiDistanceTheta(double theta) override
    {
        static_cast<gwm::MinkwoskiDistance*>(mDistance)->setTheta(theta);
    }


    void setBandwidthSize(double bw) override
    {
        static_cast<gwm::BandwidthWeight*>(mWeight)->setBandwidth(bw);
    }

    void setBandwidthAdaptive(bool adaptive) override
    {
        static_cast<gwm::BandwidthWeight*>(mWeight)->setAdaptive(adaptive);
    }

    void setBandwidthKernel(int kernel) override
    {
        static_cast<gwm::BandwidthWeight*>(mWeight)->setKernel((gwm::BandwidthWeight::KernelFunctionType)kernel);
    }

    void enableBandwidthOptimization(int criterion) override
    {
        mIsOptimizeBandwidth = true;
        mBandwidthOptimizationCriterion = static_cast<gwm::GWRBasic::BandwidthSelectionCriterionType>(criterion);
    }

    void enableVariablesOptimization(double threshold) override
    {
        mIsOptimizeVariables = true;
        mOptimizeVariablesThreshold = threshold;
    }

    double betas(int i, int k) override { return mBetas(i, k); }

    double betasSE(int i, int k) override { return mBetasSE(i, k); }

    double shat1() override { return mSHat(0); }

    double shat2() override { return mSHat(1); }

    double qDiag(int i) override { return mQDiag(i); }

    unsigned long long sRows() override { return mS.n_rows; }

    double s(int i, int k) override { return mS(i, k); }

    double diagnosticRSS() override { return mDiagnostic.RSS; }

    double diagnosticAIC() override { return mDiagnostic.AIC; }

    double diagnosticAICc() override { return mDiagnostic.AICc; }

    double diagnosticENP() override { return mDiagnostic.ENP; }

    double diagnosticEDF() override { return mDiagnostic.EDF; }

    double diagnosticRSquare() override { return mDiagnostic.RSquare; }

    double diagnosticRSquareAdjust() override { return mDiagnostic.RSquareAdjust; }

    double optimizedBandwidth() override { return mOptimizedBandwidth; }

    unsigned long long selectedVarSize() override { return mSelectedVars.size(); }

    unsigned long long selectedVar(unsigned long long i) override { return mSelectedVars[i]; }

    unsigned long long variableSelectionCriterionSize() override { return mVariableOptimizationCriterionList.size(); }

    unsigned long long variableSelectionCriterionItemVarSize(unsigned long long i) override { return mVariableOptimizationCriterionList[i].first.size(); }

    unsigned long long variableSelectionCriterionItemVar(unsigned long long i, unsigned long long j) override { return mVariableOptimizationCriterionList[i].first[j]; }

    double variableSelectionCriterionItemValue(unsigned long long i) override { return mVariableOptimizationCriterionList[i].second; }

    bool fit(bool hasIntercept = true) override;

    bool predict(bool hasIntercept) override;

};

#endif  // GWRBASICGPUTASK