Program Listing for File SpatialWeight.h

Return to documentation for file (include/gwmodelpp/spatialweight/SpatialWeight.h)

#ifndef SPATIALWEIGHT_H
#define SPATIALWEIGHT_H

#include <memory>
#include <type_traits>
#include <utility>

#ifdef ENABLE_CUDA
#include <cuda_runtime.h>
#include "gwmodelpp/spatialweight/cuda/ISpatialCudaEnabled.h"
#endif

#include "Weight.h"
#include "Distance.h"

#include "BandwidthWeight.h"

#include "CRSDistance.h"
#include "MinkwoskiDistance.h"
#include "DMatDistance.h"
#include "OneDimDistance.h"
#include "CRSSTDistance.h"

namespace gwm
{

class SpatialWeight
#ifdef ENABLE_CUDA
    : public ISpatialCudaEnabled
#endif
{
public:

    SpatialWeight() {}

    SpatialWeight(const Weight* weight, const Distance* distance):
        mWeight(weight->clone()),
        mDistance(distance->clone())
    {}

    SpatialWeight(const std::unique_ptr<Weight>& weight, const std::unique_ptr<Distance>& distance):
        mWeight(weight->clone()),
        mDistance(distance->clone())
    {}

    SpatialWeight(std::unique_ptr<Weight>&& weight, std::unique_ptr<Distance>&& distance):
        mWeight(std::move(weight)),
        mDistance(std::move(distance))
    {}

    SpatialWeight(const Weight& weight, const Distance& distance):
        mWeight(weight.clone()),
        mDistance(distance.clone())
    {}

    SpatialWeight(const SpatialWeight& spatialWeight):
        mWeight(spatialWeight.mWeight->clone()),
        mDistance(spatialWeight.mDistance->clone())
    {}

    SpatialWeight(SpatialWeight&& other):
        mWeight(std::move(other.mWeight)),
        mDistance(std::move(other.mDistance))
    {}

    virtual ~SpatialWeight() {};

    const std::unique_ptr<Weight>& weight() const
    {
        return mWeight;
    }

    void setWeight(const Weight *weight)
    {
        if (weight && weight != mWeight.get())
        {
            mWeight = weight->clone();
        }
    }

    void setWeight(const Weight& weight)
    {
        mWeight = weight.clone();
    }

    void setWeight(Weight&& weight)
    {
        mWeight = weight.clone();
    }

    void setWeight(std::unique_ptr<Weight>&& weight)
    {
        mWeight = std::move(weight);
    }

    template<typename T>
    T& weight() const { return nullptr; }

    const std::unique_ptr<Distance>& distance() const
    {
        return mDistance;
    }

    void setDistance(const Distance *distance)
    {
        if (distance && distance != mDistance.get())
        {
            mDistance = distance->clone();
        }
    }

    void setDistance(const Distance& distance)
    {
        mDistance = distance.clone();
    }

    void setDistance(Distance&& distance)
    {
        mDistance = distance.clone();
    }

    void setDistance(std::unique_ptr<Distance>&& distance)
    {
        mDistance = std::move(distance);
    }

    template<typename T>
    T& distance() const { return nullptr; }

public:

    SpatialWeight& operator=(const SpatialWeight& spatialWeight);

    SpatialWeight& operator=(SpatialWeight&& spatialWeight);

public:

    virtual arma::vec weightVector(arma::uword focus) const
    {
        return mWeight->weight(mDistance->distance(focus));
    }

#ifdef ENABLE_CUDA
    virtual cudaError_t prepareCuda(size_t gpuId) override
    {
        cudaError_t error;
        error = mDistance->prepareCuda(gpuId);
        if (error != cudaSuccess) return error;
        error = mWeight->prepareCuda(gpuId);
        return error;
    }

    virtual bool useCuda()
    {
        return mWeight->useCuda() || mDistance->useCuda();
    }

    virtual void setUseCuda(bool isUseCuda)
    {
        mWeight->setUseCuda(isUseCuda);
        mDistance->setUseCuda(isUseCuda);
    }

    virtual cudaError_t weightVector(arma::uword focus, double* d_dists, double* d_weights) const
    {
        cudaError_t error;
        size_t elems = 0;
        error = mDistance->distance(focus, d_dists, &elems);
        if (error != cudaSuccess) return error;
        error = mWeight->weight(d_dists, d_weights, elems);
        return error;
    }
#endif

    virtual bool isValid();

private:
    std::unique_ptr<Weight> mWeight = nullptr;
    std::unique_ptr<Distance> mDistance = nullptr;
};

/// @brief Explicit specialization: view the stored weight as a BandwidthWeight.
/// @note Unchecked downcast (static_cast). The caller must ensure that a weight is
///       set and that it really is a BandwidthWeight; the pointer is not null-checked.
template<>
inline BandwidthWeight& SpatialWeight::weight<BandwidthWeight>() const
{
    Weight& stored = *mWeight;
    return static_cast<BandwidthWeight&>(stored);
}

/// @brief Explicit specialization: view the stored distance as a CRSDistance.
/// @note Unchecked downcast (static_cast). The caller must ensure that a distance is
///       set and that it really is a CRSDistance; the pointer is not null-checked.
template<>
inline CRSDistance& SpatialWeight::distance<CRSDistance>() const
{
    Distance& stored = *mDistance;
    return static_cast<CRSDistance&>(stored);
}

/// @brief Explicit specialization: view the stored distance as a CRSSTDistance.
/// @note Unchecked downcast (static_cast). The caller must ensure that a distance is
///       set and that it really is a CRSSTDistance; the pointer is not null-checked.
template<>
inline CRSSTDistance& SpatialWeight::distance<CRSSTDistance>() const
{
    Distance& stored = *mDistance;
    return static_cast<CRSSTDistance&>(stored);
}

/// @brief Explicit specialization: view the stored distance as a MinkwoskiDistance.
/// @note Unchecked downcast (static_cast). The caller must ensure that a distance is
///       set and that it really is a MinkwoskiDistance; the pointer is not null-checked.
template<>
inline MinkwoskiDistance& SpatialWeight::distance<MinkwoskiDistance>() const
{
    Distance& stored = *mDistance;
    return static_cast<MinkwoskiDistance&>(stored);
}

/// @brief Explicit specialization: view the stored distance as a DMatDistance.
/// @note Unchecked downcast (static_cast). The caller must ensure that a distance is
///       set and that it really is a DMatDistance; the pointer is not null-checked.
template<>
inline DMatDistance& SpatialWeight::distance<DMatDistance>() const
{
    Distance& stored = *mDistance;
    return static_cast<DMatDistance&>(stored);
}

/// @brief Explicit specialization: view the stored distance as a OneDimDistance.
/// @note Unchecked downcast (static_cast). The caller must ensure that a distance is
///       set and that it really is a OneDimDistance; the pointer is not null-checked.
template<>
inline OneDimDistance& SpatialWeight::distance<OneDimDistance>() const
{
    Distance& stored = *mDistance;
    return static_cast<OneDimDistance&>(stored);
}

}

#endif // SPATIALWEIGHT_H