Program Listing for File SpatialWeight.h
↰ Return to documentation for file (include/gwmodelpp/spatialweight/SpatialWeight.h)
#ifndef SPATIALWEIGHT_H
#define SPATIALWEIGHT_H
#ifdef ENABLE_CUDA
#include <cuda_runtime.h>
#include "gwmodelpp/spatialweight/cuda/ISpatialCudaEnabled.h"
#endif
#include "Weight.h"
#include "Distance.h"
#include "BandwidthWeight.h"
#include "CRSDistance.h"
#include "MinkwoskiDistance.h"
#include "DMatDistance.h"
#include "OneDimDistance.h"
#include "CRSSTDistance.h"
namespace gwm
{
class SpatialWeight
#ifdef ENABLE_CUDA
: public ISpatialCudaEnabled
#endif
{
public:
SpatialWeight() {}
SpatialWeight(const Weight* weight, const Distance* distance)
{
mWeight = weight->clone();
mDistance = distance->clone();
}
SpatialWeight(const Weight& weight, const Distance& distance)
{
mWeight = weight.clone();
mDistance = distance.clone();
}
SpatialWeight(const SpatialWeight& spatialWeight)
{
mWeight = spatialWeight.mWeight->clone();
mDistance = spatialWeight.mDistance->clone();
}
SpatialWeight(SpatialWeight&& other)
{
mWeight = other.mWeight;
mDistance = other.mDistance;
other.mWeight = nullptr;
other.mDistance = nullptr;
}
virtual ~SpatialWeight();
Weight *weight() const
{
return mWeight;
}
void setWeight(Weight *weight)
{
if (weight && weight != mWeight)
{
if (mWeight) delete mWeight;
mWeight = weight->clone();
}
}
void setWeight(Weight& weight)
{
if (mWeight) delete mWeight;
mWeight = weight.clone();
}
void setWeight(Weight&& weight)
{
if (mWeight) delete mWeight;
mWeight = weight.clone();
}
template<typename T>
T* weight() const { return nullptr; }
Distance *distance() const
{
return mDistance;
}
void setDistance(Distance *distance)
{
if (distance && distance != mDistance)
{
if (mDistance) delete mDistance;
mDistance = distance->clone();
}
}
void setDistance(Distance& distance)
{
if (mDistance) delete mDistance;
mDistance = distance.clone();
}
void setDistance(Distance&& distance)
{
if (mDistance) delete mDistance;
mDistance = distance.clone();
}
template<typename T>
T* distance() const { return nullptr; }
public:
SpatialWeight& operator=(const SpatialWeight& spatialWeight);
SpatialWeight& operator=(SpatialWeight&& spatialWeight);
public:
virtual arma::vec weightVector(arma::uword focus) const
{
return mWeight->weight(mDistance->distance(focus));
}
#ifdef ENABLE_CUDA
virtual cudaError_t prepareCuda(size_t gpuId) override
{
cudaError_t error;
error = mDistance->prepareCuda(gpuId);
if (error != cudaSuccess) return error;
error = mWeight->prepareCuda(gpuId);
return error;
}
virtual bool useCuda()
{
return mWeight->useCuda() || mDistance->useCuda();
}
virtual void setUseCuda(bool isUseCuda)
{
mWeight->setUseCuda(isUseCuda);
mDistance->setUseCuda(isUseCuda);
}
virtual cudaError_t weightVector(arma::uword focus, double* d_dists, double* d_weights) const
{
cudaError_t error;
size_t elems = 0;
error = mDistance->distance(focus, d_dists, &elems);
if (error != cudaSuccess) return error;
error = mWeight->weight(d_dists, d_weights, elems);
return error;
}
#endif
virtual bool isValid();
private:
Weight* mWeight = nullptr;
Distance* mDistance = nullptr;
};
/// @brief Returns the held weight as a BandwidthWeight. Returns nullptr if no
/// weight is held or the held weight is not a BandwidthWeight (or derived from it).
/// Checked with dynamic_cast, so a type mismatch yields nullptr instead of an
/// invalid pointer.
template<>
inline BandwidthWeight* SpatialWeight::weight<BandwidthWeight>() const
{
    return dynamic_cast<BandwidthWeight*>(mWeight);
}
/// @brief Returns the held distance as a CRSDistance. Returns nullptr if no
/// distance is held or the held distance is not a CRSDistance (or derived from it).
/// Checked with dynamic_cast, so a type mismatch yields nullptr instead of an
/// invalid pointer.
template<>
inline CRSDistance* SpatialWeight::distance<CRSDistance>() const
{
    return dynamic_cast<CRSDistance*>(mDistance);
}
/// @brief Returns the held distance as a CRSSTDistance. Returns nullptr if no
/// distance is held or the held distance is not a CRSSTDistance (or derived from it).
/// Checked with dynamic_cast, so a type mismatch yields nullptr instead of an
/// invalid pointer.
template<>
inline CRSSTDistance* SpatialWeight::distance<CRSSTDistance>() const
{
    return dynamic_cast<CRSSTDistance*>(mDistance);
}
/// @brief Returns the held distance as a MinkwoskiDistance. Returns nullptr if no
/// distance is held or the held distance is not a MinkwoskiDistance (or derived from it).
/// Checked with dynamic_cast, so a type mismatch yields nullptr instead of an
/// invalid pointer.
template<>
inline MinkwoskiDistance* SpatialWeight::distance<MinkwoskiDistance>() const
{
    return dynamic_cast<MinkwoskiDistance*>(mDistance);
}
/// @brief Returns the held distance as a DMatDistance. Returns nullptr if no
/// distance is held or the held distance is not a DMatDistance (or derived from it).
/// Checked with dynamic_cast, so a type mismatch yields nullptr instead of an
/// invalid pointer.
template<>
inline DMatDistance* SpatialWeight::distance<DMatDistance>() const
{
    return dynamic_cast<DMatDistance*>(mDistance);
}
/// @brief Returns the held distance as a OneDimDistance. Returns nullptr if no
/// distance is held or the held distance is not a OneDimDistance (or derived from it).
/// Checked with dynamic_cast, so a type mismatch yields nullptr instead of an
/// invalid pointer.
template<>
inline OneDimDistance* SpatialWeight::distance<OneDimDistance>() const
{
    return dynamic_cast<OneDimDistance*>(mDistance);
}
}
#endif // SPATIALWEIGHT_H