#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <math.h>
#include<algorithm>
// includes CUDA
#include <cuda_runtime.h>
#include<device_launch_parameters.h>
// includes, project
#include <helper_cuda.h>
#include <helper_functions.h> // helper functions for SDK examples
//include thrust
#include <thrust/device_ptr.h>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <thrust/sort.h>
#include <thrust/copy.h>
using namespace std;
#define N (3) // number of condition attributes per object
#define M (12) // number of objects stored
__device__ __managed__ int d_attrSelect[N];// global (managed, host- and device-visible) attribute-selection mask; an entry of 0 makes Element's comparison operators ignore that attribute
// One object: its N condition-attribute values. Ordering and equality only
// look at the attributes whose d_attrSelect entry is non-zero.
struct Element
{
    int attrval[N]; // condition-attribute values of this object

    // Lexicographic "less than" over the selected attributes only; the first
    // selected attribute that differs decides. All-equal -> false.
    bool __host__ __device__ operator<(const Element& e)const
    {
        for (int k = 0; k < N; ++k)
        {
            const bool selected = (d_attrSelect[k] != 0);
            if (selected && attrval[k] != e.attrval[k])
            {
                return attrval[k] < e.attrval[k];
            }
        }
        return false;
    }

    // Equal when every selected attribute matches; unselected ones are ignored.
    bool __host__ __device__ operator == (const Element & e)const
    {
        for (int k = 0; k < N; ++k)
        {
            const bool selected = (d_attrSelect[k] != 0);
            if (selected && attrval[k] != e.attrval[k])
            {
                return false;
            }
        }
        return true;
    }
};
// Tokenise `src` and APPEND the tokens to `dest` (dest is not cleared).
// Any character contained in `separator` is a delimiter; runs of delimiters
// count as one, and leading/trailing delimiters produce no empty tokens.
// An empty or delimiter-only `src` appends a single empty token (the
// historical behaviour of this function).
void Split(const string& src, const string& separator, vector<string>& dest)
{
    // Skip leading delimiters; previously " 1 2" yielded an empty first token,
    // which made the caller's stoi() throw.
    string::size_type start = src.find_first_not_of(separator);
    if (start == string::npos)
    {
        dest.push_back(string());
        return;
    }
    for (;;)
    {
        const string::size_type stop = src.find_first_of(separator, start);
        if (stop == string::npos)
        {
            // The last token runs to the end of the string.
            dest.push_back(src.substr(start));
            return;
        }
        dest.push_back(src.substr(start, stop - start));
        start = src.find_first_not_of(separator, stop);
        if (start == string::npos)
        {
            return; // only trailing delimiters remain
        }
    }
}
// Read objects from the text file `fname` and append them to `all_elments`.
// Each non-blank line must contain exactly N space-separated integers (the
// condition-attribute values). Blank lines are skipped. On an unreadable file
// or a malformed line the program prints a message, waits for a key and exits.
void InData(const char* fname, thrust::host_vector<Element>&all_elments)
{
#ifdef TIME_TEST
    double timeuse;
    StartTimer();
#endif
    ifstream fin(fname);
    if (!fin)
    {
        // NO, abort program
        cerr << "can't open input file \"" << fname << "\"" << endl;
        getchar();
        exit(EXIT_FAILURE);
    }
    string line;
    string separtor = " ";
    while (getline(fin, line))
    {
        // Skip blank / separator-only lines (e.g. an empty last line) instead
        // of feeding an empty token to stoi().
        if (line.find_first_not_of(separtor) == string::npos)
        {
            continue;
        }
        vector<string> temp_string;
        Split(line, separtor, temp_string);
        // More than N tokens would overflow attrval[N]; fewer would leave
        // attributes uninitialised. Reject both explicitly.
        if (temp_string.size() != static_cast<size_t>(N))
        {
            cerr << "wrong number of attributes: N=" << N
                 << " actual=" << temp_string.size() << endl;
            getchar();
            exit(EXIT_FAILURE);
        }
        Element temp_obj;
        for (int i = 0; i < N; i++)
        {
            temp_obj.attrval[i] = stoi(temp_string[i]);
        }
        all_elments.push_back(temp_obj);
    }
    fin.close();
#ifdef TIME_TEST
    timeuse = GetTimer();
    std::cerr << "[CPU] InData 讀取數據的時間(in second)\t" << timeuse << std::endl;
    StartTimer();
#endif
}
// Demo: read objects, sort them on the CPU and on the GPU using the
// attribute-selection mask d_attrSelect, then report for each sorted row the
// 1-based line number (id) it had in the input file.
// Usage: program [input-file]   (defaults to the original hard-coded path)
int main(int argc, char** argv)
{
    // Attribute 2 is switched off below, so at least three attributes are needed.
    static_assert(N > 2, "main() disables attribute 2; N must be at least 3");

    // Select every condition attribute, then disable attribute 2 so that
    // Element's comparison operators ignore it.
    for (int i = 0; i < N; i++)
    {
        d_attrSelect[i] = 1;
    }
    d_attrSelect[2] = 0;
    for (int i = 0; i < N; i++)
    {
        printf("Host: the value of the global variable is %d\n", d_attrSelect[i]);
    }

    thrust::host_vector <Element> ElementList;
    const char* s = (argc > 1) ? argv[1] : "C://Users//hym//Desktop//test.txt";
    InData(s, ElementList);

    // The file may hold fewer (or more) than M rows: never index past what
    // was actually read. At most M rows are printed, as before.
    const int count = static_cast<int>(ElementList.size());
    const int shown = std::min(count, M);

    // Keep the unsorted input so original row numbers can be recovered later.
    const thrust::host_vector<Element> original = ElementList;

    cout << " CPU: before sort" << endl;
    for (int i = 0; i < shown; i++)
    {
        for (int j = 0; j < N; j++)
        {
            cout << ElementList[i].attrval[j] << " ";
        }
        cout << endl;
    }

    std::sort(ElementList.begin(), ElementList.end());// cpu serial sort
    thrust::device_vector<Element> dev_element = ElementList;
    thrust::sort(dev_element.begin(), dev_element.end()); // GPU sort
    ElementList = dev_element;
    cout << " CPU: After sort" << endl;
    for (int i = 0; i < shown; i++)
    {
        for (int j = 0; j < N; j++)
        {
            cout << ElementList[i].attrval[j] << " ";
        }
        cout << endl;
    }

    // Tag every input row with its 1-based line number and sort the ORIGINAL
    // (unsorted) rows by key. Sorting the already-sorted data, as before,
    // just returned ids 1..count and lost the row mapping. The stable sort
    // keeps rows that compare equal in input order.
    thrust::host_vector<int> h_value(count);
    for (int i = 0; i < count; i++)
    {
        h_value[i] = i + 1;
    }
    thrust::device_vector<int> d_value = h_value;
    dev_element = original;
    thrust::stable_sort_by_key(dev_element.begin(), dev_element.end(), d_value.begin());
    ElementList = dev_element;
    h_value = d_value;
    for (int i = 0; i < shown; i++)
    {
        cout << "id=" << h_value[i] << " ";
        for (int j = 0; j < N; j++)
        {
            cout << ElementList[i].attrval[j] << " ";
        }
        cout << endl;
    }
    getchar();
    return 0;
}
My test input (one object per line, N = 3 condition attributes):
0 0 0
0 0 1
1 0 1
1 1 1
0 0 0
1 1 2
0 0 0
1 1 1
1 1 2
0 1 1
1 0 2
1 1 1
最新的CUDA中,kernel函數中也可以使用thrust了。(In recent CUDA versions, Thrust can also be called from inside kernel functions.)
Thrust inside user written kernels
使用仿函數進行排序 (Sorting with a functor)
////////////////////////////////////////////////////////////////////////////
//
// Copyright 1993-2015 NVIDIA Corporation. All rights reserved.
//
// Please refer to the NVIDIA end user license agreement (EULA) associated
// with this source code for terms and conditions that govern your use of
// this software. Any use, reproduction, disclosure, or distribution of
// this software and related documentation outside the terms of the EULA
// is strictly prohibited.
//
////////////////////////////////////////////////////////////////////////////
/* Template project which demonstrates the basics on how to setup a project
* example application.
* Host code.
*/
///*
// includes, system
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <math.h>
#include<algorithm>
#include<map>
#include <iomanip>//輸出格式限制
// includes CUDA
#include <cuda_runtime.h>
#include<device_launch_parameters.h>
#include<cusparse.h>
// includes, project
#include <helper_cuda.h>
#include <helper_functions.h> // helper functions for SDK examples
//include thrust
#include <thrust/device_ptr.h>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <thrust/sort.h>
#include <thrust/copy.h>
#include <thrust/functional.h>
#include <thrust/sequence.h>
#include "basic.h"
using namespace std;
// Problem-size configuration. The comment toggle below selects one of two
// parameter sets. As written, the small test set (N=2, M=12, D=3) is active
// and the large set (N=16, M=20000, D=26) is commented out. Deleting the two
// leading slashes on the first toggle line swaps which set is active.
//   N: condition attributes per object (InData expects N+1 tokens per line)
//   M: maximum number of objects read
//   D: presumably the number of decision classes (letter.txt: 26) - not used
//      in the visible code; TODO confirm
///*
#define N (2) // condition attributes per object
#define M (12) // maximum number of objects
#define D (3) // presumably number of decision classes - TODO confirm
/*/
//*
#define N (16)
#define M (20000)
#define D (26)
//*/
#define TILE_X 8 // NOTE(review): not used in the visible code
#define TILE_Y 32 // NOTE(review): not used in the visible code
// One object (row) of the input table: N condition-attribute values followed
// by a single decision-attribute value (the order in which InData fills them).
struct Element
{
int attrval[N];// condition-attribute values
int attrdec;// decision-attribute value
};
// Comparison functor for sorting Elements by their selected condition
// attributes (lexicographic, first differing selected attribute decides).
// The selection mask is stored by value in a plain array, not in a
// thrust::device_vector. device_vector is a host-side container: its
// operator[] cannot be called from device code, so the previous version
// could not be used with thrust::sort on the device. A plain array makes the
// functor trivially copyable, so it can be passed by value to kernels.
struct compare_element
{
    // attr[i] == 0 means attribute i is ignored by the comparison.
    int attr[N];

    // Copies the first N entries of `temp` (device -> host). Positions beyond
    // temp.size() are set to 0, i.e. ignored.
    compare_element(const thrust::device_vector<int>& temp)
    {
        thrust::host_vector<int> h_attr(temp);
        for (int i = 0; i < N; i++)
        {
            attr[i] = (static_cast<size_t>(i) < h_attr.size()) ? h_attr[i] : 0;
        }
    }

    // Strict weak ordering: true if `a` precedes `b` on the selected attributes.
    __host__ __device__ bool operator()(const Element &a, const Element &b)const
    {
        for (int i = 0; i < N; i++)
        {
            if (attr[i] == 0)
            {
                continue;
            }
            if (a.attrval[i] != b.attrval[i])
            {
                return a.attrval[i] < b.attrval[i];
            }
        }
        return false;
    }
};
// Tokenise `src` and APPEND the tokens to `dest` (dest is not cleared).
// Any character contained in `separator` is a delimiter; runs of delimiters
// count as one, and leading/trailing delimiters produce no empty tokens.
// An empty or delimiter-only `src` appends a single empty token (the
// historical behaviour of this function).
void Split(const string& src, const string& separator, vector<string>& dest)
{
    // Skip leading delimiters; previously " 1 2" yielded an empty first token,
    // which made the caller's stoi() throw.
    string::size_type start = src.find_first_not_of(separator);
    if (start == string::npos)
    {
        dest.push_back(string());
        return;
    }
    for (;;)
    {
        const string::size_type stop = src.find_first_of(separator, start);
        if (stop == string::npos)
        {
            // The last token runs to the end of the string.
            dest.push_back(src.substr(start));
            return;
        }
        dest.push_back(src.substr(start, stop - start));
        start = src.find_first_not_of(separator, stop);
        if (start == string::npos)
        {
            return; // only trailing delimiters remain
        }
    }
}
// Read up to M objects from the text file `fname` into `all_elments`,
// overwriting elements from index 0. If the vector holds fewer elements than
// the number of objects read, it is grown as needed.
// Each non-blank line must hold N condition-attribute integers followed by
// one decision-attribute integer. Blank lines are skipped. On an unreadable
// file or a malformed line the program prints a message, waits for a key and exits.
void InData(const char* fname, thrust::host_vector<Element>&all_elments)
{
#ifdef TIME_TEST
    double timeuse;
    StartTimer();
#endif
    ifstream fin(fname);
    if (!fin)
    {
        // NO, abort program
        cerr << "can't open input file \"" << fname << "\"" << endl;
        getchar();
        exit(EXIT_FAILURE);
    }
    string line;
    string separtor = " ";
    int obj_cnt = 0;
    // Test the count first so no extra line is consumed once M objects are read.
    while (obj_cnt < M && getline(fin, line))
    {
        // Skip blank / separator-only lines (e.g. an empty last line).
        if (line.find_first_not_of(separtor) == string::npos)
        {
            continue;
        }
        vector<string> temp_string;
        Split(line, separtor, temp_string);
        int length = static_cast<int>(temp_string.size());
        if ((length - 1) != N)
        {
            cerr << "屬性個數不正確:N=" << N << " 實際的屬性個數:" << length - 1 << endl;
            getchar();
            exit(EXIT_FAILURE);
        }
        // Do not index past the end if the caller did not pre-size the vector.
        if (obj_cnt >= static_cast<int>(all_elments.size()))
        {
            all_elments.push_back(Element());
        }
        Element& obj = all_elments[obj_cnt];
        for (int i = 0; i < N; i++) // the first N tokens are condition attributes
        {
            obj.attrval[i] = stoi(temp_string[i]);
        }
        obj.attrdec = stoi(temp_string[N]); // the last token is the decision attribute
        obj_cnt++;
    }
    fin.close();
#ifdef TIME_TEST
    timeuse = GetTimer();
    std::cerr << "[CPU] InData 讀取數據的時間(in second)\t" << timeuse << std::endl;
    StartTimer();
#endif
}
// Print at most M objects: their condition attributes, then "dec=" and the
// decision attribute. Taken by const reference to avoid copying the vector;
// rows beyond all_Elements.size() are never read.
void printData(const thrust::host_vector<Element>& all_Elements)
{
    cout << "-------------Data------" << endl;
    const size_t rows = std::min(static_cast<size_t>(M),
                                 static_cast<size_t>(all_Elements.size()));
    for (size_t i = 0; i < rows; i++)
    {
        for (int j = 0; j < N; j++)
        {
            cout << all_Elements[i].attrval[j] << " ";
        }
        cout << " dec=" << all_Elements[i].attrdec << endl;
    }
    cout << endl;
}
// Demo: load the data set, print it, round-trip it through device memory and
// print it again.
// Usage: program [input-file]   (defaults to the original hard-coded test file)
int main(int argc, char** argv)
{
    // Pre-sized to M so InData can fill elements in place (push_back was considered too slow).
    thrust::host_vector<Element> all_Elements(M);
    //char s[1024] = "C://Users//hym//Desktop//GPU_Data//letter.txt";//20000 16 26
    const char* s = (argc > 1) ? argv[1] : "C://Users//hym//Desktop//GPU_Data//test.txt";//12 2 3
    InData(s, all_Elements);
    printData(all_Elements);
    thrust::device_vector<Element> dev_element = all_Elements;
    // Attribute-selection mask: 1 = compare this attribute, 0 = ignore it.
    thrust::device_vector<int> attr(N, 1);
    compare_element cmp(attr);
    (void)cmp; // only used by the disabled sort below
    // NOTE: this sort requires compare_element::operator() to be callable in device code.
    //thrust::sort(dev_element.begin(), dev_element.end(),cmp);
    all_Elements = dev_element;
    printData(all_Elements);
    getchar();
    return 0;
}