Attachment 'solution-12-1.cpp'
#include <algorithm>
#include <cctype>
#include <climits>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include <math.h>
#include <ext/hash_map>
#include <ext/hash_set>

using namespace std;
using __gnu_cxx::hash_set;
using __gnu_cxx::hash_map;

// MIN(a, b) yields the smaller of a and b, or a when neither is smaller.
// Being a macro, the argument that is selected gets evaluated twice.
#define MIN(a, b) ( (b) < (a) ? (b) : (a) )

// Hash functor for std::string keys in the __gnu_cxx hash containers.
// It folds every character into a polynomial: h = h * 31 + c, starting from 0.
// The arithmetic is in size_t, so overflow wraps modulo 2^(bits of size_t).
class StringHashFunction
{
public:
    size_t operator()(const std::string& s) const
    {
        size_t accumulated = 0;
        for (std::string::const_iterator ch = s.begin(); ch != s.end(); ++ch)
            accumulated = accumulated * 31 + *ch;
        return accumulated;
    }
};

// Sort comparator for (word, count) pairs: a pair comes first when its count is
// strictly larger, giving descending order by count (ties left unordered).
bool sortOrder(const std::pair<std::string, int>& x, const std::pair<std::string, int>& y)
{
    const int leftCount = x.second;
    const int rightCount = y.second;
    return rightCount < leftCount;
}

// True for characters that belong to a word. std::isalpha has undefined
// behaviour for negative char values (e.g. non-ASCII bytes where char is
// signed), hence the cast to unsigned char.
static bool isWordChar(char c)
{
    return std::isalpha(static_cast<unsigned char>(c)) != 0;
}

// Reads a tab-separated file with one labelled title per line:
//   LABEL<TAB>title text
// For each such line:
// - the title is split into maximal runs of alphabetic characters, and those
//   words are appended to *texts as one vector<string>;
// - the label is appended to *realAssignments as an integer class:
//   "SIGGRAPH" -> 0, "SIGIR" -> 1, anything else -> 2.
// Lines without a TAB are reported on stdout and skipped; they add nothing to
// either output vector. Both vectors are appended to, never cleared. If the
// file cannot be opened, a warning is printed and nothing is appended.
void readInput(const std::string& filename, std::vector<std::vector<std::string> >* texts,
               std::vector<int>* realAssignments)
{
    std::ifstream inputFile(filename.c_str());
    if (!inputFile)
    {
        // Without this check the read loop below would get no lines at all.
        std::cout << "WARNING: cannot open " << filename << ", no input read" << std::endl;
        return;
    }

    std::string line;
    size_t lineCount = 0;
    // std::getline imposes no line-length limit. It also delivers a last line
    // that has no trailing newline, and stops on EOF or any stream error.
    while (std::getline(inputFile, line))
    {
        ++lineCount;
        const size_t pos = line.find('\t');
        if (pos == std::string::npos)
        {
            std::cout << "WARNING: line #" << lineCount
                      << " without TAB, skipping it " << std::endl;
            continue;
        }
        const std::string label = line.substr(0, pos);
        const std::string text = line.substr(pos + 1);

        std::vector<std::string> currentText;
        size_t i = 0;
        while (i < text.size()) // assuming no duplicate words in titles
        {
            while (i < text.size() && !isWordChar(text[i])) ++i;
            const size_t wordStart = i;
            while (i < text.size() && isWordChar(text[i])) ++i;
            if (i > wordStart) currentText.push_back(text.substr(wordStart, i - wordStart));
        }
        texts->push_back(currentText);

        if (label == "SIGGRAPH") realAssignments->push_back(0);
        else if (label == "SIGIR") realAssignments->push_back(1);
        else realAssignments->push_back(2);
    }
}

70 int main(int argc, char** argv)
71 {
72 int M = 100;
73 int nofCentroids = 3;
74 double RSS = 0;
75 vector<hash_set<string, StringHashFunction> > centroids;
76 vector<vector<string> > texts;
77 vector<int> kmeansAssignments;
78 vector<int> realAssignments;
79 vector<int> realCounts;
80 vector<double> jaccardDistances;
81 vector<hash_map<string, int, StringHashFunction> > hashmaps;
82 readInput("dblp.txt", &texts, &realAssignments);
83 jaccardDistances.resize(texts.size());
84 kmeansAssignments.resize(texts.size());
85 realCounts.resize(texts.size());
86 hashmaps.resize(nofCentroids);
87 for(unsigned int i = 0; i < texts.size(); i++)
88 {
89 kmeansAssignments[i] = -1;
90 jaccardDistances[i] = INT_MAX;
91 realCounts[realAssignments[i]]++;
92 }
93 // get random centroids
94 srand(time(NULL));
95 for(int i = 0; i < nofCentroids; i++)
96 {
97 hash_set<string, StringHashFunction> currentCentroid;
98 for(unsigned int j = 0; j < 5; j++) // number of random titles to start with
99 {
100 int r = rand() % texts.size();
101 while(realAssignments[r] != i) r = rand() % texts.size();
102 for(unsigned int j = 0; j < texts[r].size(); j++)
103 currentCentroid.insert(texts[r][j]);
104 }
105 centroids.push_back(currentCentroid);
106 }
107 // do iterations
108 int iterations = 0;
109 bool change = true;
110 while(iterations < 100 && change)
111 {
112 change = false;
113 for(unsigned int i = 0; i < texts.size(); i++)
114 {
115 for(unsigned int j = 0; j < centroids.size(); j++)
116 {
117 int intersectionSize = 0;
118 int unionSize;
119 for(unsigned int k = 0; k < texts[i].size(); k++)
120 if (centroids[j].find(texts[i][k]) != centroids[j].end())
121 intersectionSize++;
122 unionSize = centroids[j].size() + texts[i].size() - intersectionSize;
123 double dist = pow(1 - 1.0 * intersectionSize / unionSize, 2.0);
124 if (dist < jaccardDistances[i])
125 {
126 change = true;
127 jaccardDistances[i] = dist;
128 kmeansAssignments[i] = j;
129 }
130 }
131 }
132 RSS = 0;
133 for(unsigned int i = 0; i < texts.size(); i++)
134 RSS += jaccardDistances[i]; // pow(1 - jaccardSimilarities[i], 2.0);
135 // recalculate centroids
136 for(unsigned int i = 0; i < centroids.size(); i++)
137 hashmaps[i].clear();
138 for(unsigned int i = 0; i < texts.size(); i++)
139 for(unsigned int j = 0; j < texts[i].size(); j++)
140 hashmaps[kmeansAssignments[i]][texts[i][j]]++;
141 hash_map<string, int, StringHashFunction>::iterator it;
142 for(unsigned int i = 0; i < centroids.size(); i++)
143 {
144 vector<pair<string, int> > counts;
145 pair<string, int> tmpPair;
146 for(it = hashmaps[i].begin(); it != hashmaps[i].end(); it++)
147 {
148 tmpPair.first = it->first;
149 tmpPair.second = it->second;
150 counts.push_back(tmpPair);
151 }
152 sort(counts.begin(), counts.end(), sortOrder);
153 int n = MIN(M, (int)counts.size());
154 centroids[i].clear();
155 for(int j = 0; j < n; j++)
156 centroids[i].insert(counts[j].first);
157 }
158 ++iterations;
159 cout << "Iteration " << iterations << ". RSS = " << RSS << endl;
160 }
161 // find which centroid is which
162 vector<vector<int> > counts;
163 vector<int> centroidMap;
164 vector<int> centroidCounts;
165 centroidMap.resize(nofCentroids);
166 counts.resize(nofCentroids);
167 centroidCounts.resize(nofCentroids);
168 for(int i = 0; i < nofCentroids; i++)
169 counts[i].resize(nofCentroids);
170 for(unsigned int i = 0; i < texts.size(); i++)
171 {
172 counts[kmeansAssignments[i]][realAssignments[i]]++;
173 centroidCounts[kmeansAssignments[i]]++;
174 }
175 vector<bool> centroidTaken;
176 centroidTaken.resize(nofCentroids);
177 for(int i = 0; i < nofCentroids; i++)
178 {
179 int maxCount = 0;
180 for(int j = 0; j < nofCentroids; j++)
181 {
182 if (counts[i][j] > maxCount && !centroidTaken[j])
183 {
184 maxCount = counts[i][j];
185 centroidMap[i] = j;
186 }
187 }
188 centroidTaken[centroidMap[i]] = true;
189 }
190 // compute precision and recall
191 double precision = 0;
192 double recall = 0;
193 for(int i = 0; i < nofCentroids; i++)
194 precision += 1.0 * counts[i][centroidMap[i]] / realCounts[centroidMap[i]];
195 precision /= nofCentroids;
196 for(int i = 0; i < nofCentroids; i++)
197 recall += 1.0 * counts[i][centroidMap[i]] / centroidCounts[i];
198 recall /= nofCentroids;
199 cout << "Precision : " << 100 * precision << "%" << endl;
200 cout << "Recal : " << 100 * recall << "%" << endl;
201 }
Attached Files
To refer to attachments on a page, use attachment:filename, as shown below in the list of files. Do NOT use the URL of the [get] link, since it is subject to change and can break easily. You are not allowed to attach a file to this page.