00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
#include "pa.h"

#ifdef _WIN32
#include <string>
#endif

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <set>
#include <string>

using namespace std;
00038
00039 bool PA::verbose = false;
00040 float PA::updatePercent = 0.5;
00041
00043
00044
00045 PAS::PAS(int k, int d) :
00046 PA(k, d)
00047 {
00048 M.rehash(k * d / 3);
00049 if (verbose)
00050 cerr << "\nPAS::PAS(" << k << ',' << d << ")" << endl;
00051 }
00052
00053 PASV::PASV(int k, int d) :
00054 PA(k, d),
00055 M(k)
00056 {
00057 if (verbose)
00058 cerr << "\nPASV::PASV(" << k << ',' << d << ")" << endl;
00059 }
00060
00061 void PASV::init(int k, int d)
00062 {
00063 this->k = k;
00064 this->d = d;
00065 M.resize(k);
00066 }
00067
00068 PAD::PAD(int k, int d) :
00069 PA(k, d),
00070 M(k)
00071 {
00072 for (int i = 0; i < k; i++)
00073 M[i].resize(d);
00074 if (verbose)
00075 cerr << "\nPAD::PAD(" << k << "," << d << ")" << endl;
00076 }
00077
00079
00080
00085 Float PAS::score(unsigned r, X& v) {
00086 Float ans = 0;
00087 for (unsigned i = 0; i < v.size(); ++i) {
00088 pair<unsigned, unsigned> p(r, v[i]);
00089 Matrix::const_iterator m_i = M.find(p);
00090 if (m_i != M.end())
00091 ans += m_i->second;
00092 }
00093 return ans;
00094 }
00095
00096 Float PASV::score(unsigned r, X& v) {
00097 Float ans = 0;
00098 Row& Mr = M[r];
00099 for (unsigned i = 0; i < v.size(); ++i) {
00100 Row::const_iterator r_i = Mr.find(v[i]);
00101 if (r_i != Mr.end())
00102 ans += r_i->second;
00103 }
00104 return ans;
00105 }
00106
00107 Float PAD::score(unsigned r, X& v) {
00108 vector<Float>& m = M[r];
00109 Float ans = 0;
00110 for (unsigned i = 0; i < v.size(); ++i)
00111 ans += m[v[i]];
00112 return ans;
00113 }
00114
00115
00116 void PAS::update(Y y, Float tau, X& x) {
00117 for (unsigned i = 0; i < x.size(); ++i) {
00118 pair<unsigned, unsigned> p(y, x[i]);
00119 M[p] += tau;
00120 }
00121 }
00122
00123 void PASV::update(Y y, Float tau, X& x) {
00124 for (unsigned i = 0; i < x.size(); ++i) {
00125 unsigned xi = x[i];
00126 M[y][xi] += tau;
00127 }
00128 }
00129
00130 void PAD::update(Y y, Float tau, X& x) {
00131 for (unsigned i = 0; i < x.size(); ++i) {
00132 M[y][x[i]] += tau;
00133 }
00134 }
00135
00140 void PA::rand_permutation(vector<int>& OUT) {
00141 size_t N = OUT.size();
00142 for (int i = 0; i < N; i++)
00143 OUT[i] = -1;
00144 for (int i = 0; i < N; ) {
00145 int s = int(float(N) * (rand() / (RAND_MAX + 1.0)));
00146 if (OUT[s] == -1)
00147 OUT[s] = i++;
00148 }
00149 }
00150
00151
00152 void PA::train(Cases& cases, int T) {
00153 unsigned _S_ = cases.size();
00154 vector<int> perm(_S_);
00155 for (int t = 0; t < T; ++t) {
00156 rand_permutation(perm);
00157 int updates = 0;
00158 for (int i = 0; i < _S_; ++i) {
00159 Case& c = cases[perm[i]];
00160 X& xt = c.first;
00161 Y yt = c.second;
00162
00163
00164 int s = -1;
00165 Float score_r = score(yt, xt);
00166 Float score_s = 0.0;
00167 for (int l = 0; l < k; ++l) {
00168 if (l != yt) {
00169 Float xs = score(l, xt);
00170 if (xs > score_r && xs > score_s) {
00171 s = l;
00172 score_s = xs;
00173 }
00174 }
00175 }
00176
00177 Float loss = margin - (score_r - score_s);
00178
00179 if (yt == 0)
00180 loss = 1.0 - (score_r - score_s);
00181 if (loss > 0.0) {
00182 updates++;
00183 int xtNormSq = xt.size();
00184 Float tau = loss / (2 * xtNormSq);
00185 update(yt, tau, xt);
00186 if (s >= 0)
00187 update(s, -tau, xt);
00188 }
00189 }
00190 float updPercent = (100.0*updates) / _S_;
00191 if (verbose)
00192 cerr << "\tupds_" << t << " = " << updates
00193 << " (" << updPercent << "%)" << endl;
00194 if (updates == 0 || updPercent < updatePercent)
00195 break;
00196 # ifdef DEBUG
00197 save(cerr, true);
00198 # endif
00199 }
00200 }
00201
00203
00204
00205 Y PA::predict(X& x) {
00206 Y best_k = 0;
00207 Float best_score = score(0, x);
00208
00209 for (unsigned r = 1; r < k; r++) {
00210 Float score_r = score(r, x);
00211 if (score_r > best_score) {
00212 best_score = score_r;
00213 best_k = r;
00214 }
00215 }
00216 return best_k;
00217 }
00218
00220
00221
// Size of a fixed-size line buffer, in characters including the terminating NUL.
#define MAX_LINE_LEN (8196)
00223
00224 PASV::PASV(char const* modelFile)
00225 {
00226 ifstream ifs(modelFile);
00227 if (!ifs) {
00228 cerr << "Missing model file: " << modelFile << endl;
00229 return;
00230 }
00231 char line[MAX_LINE_LEN];
00232 if (!ifs.getline(line, MAX_LINE_LEN)) {
00233 cerr << "Bad model file" << endl;
00234 return;
00235 }
00236
00237 int nc = atoi(line);
00238 int n = nc;
00239 while (n-- && ifs.getline(line, MAX_LINE_LEN)) {
00240 labels.push_back(line);
00241 }
00242
00243 if (!ifs.getline(line, MAX_LINE_LEN)) {
00244 cerr << "Bad model file" << endl;
00245 return;
00246 }
00247 int np = atoi(line);
00248 n = 0;
00249 while (n < np && ifs.getline(line, MAX_LINE_LEN))
00250 predIndex[(char const*)line] = n++;
00251 init(nc, np);
00252 load(ifs);
00253 }
00254
00255 void PAS::save(std::ostream& os)
00256 {
00257 os.precision(20);
00258 for (unsigned i = 0; i < k; ++i) {
00259 os << i;
00260 for (unsigned j = 0; j < d; ++j) {
00261 Matrix::const_iterator mit = M.find(make_pair(i, j));
00262 if (mit != M.end())
00263 os << ' ' << j << ':' << mit->second;
00264 }
00265 os << endl;
00266 }
00267 }
00268
00269 bool PAS::load(std::istream& is)
00270 {
00271 string line;
00272 while (std::getline(is, line)) {
00273 char* cline = (char *)line.c_str();
00274 int c = strtol(cline, &cline, 10);
00275 int i;
00276 float a;
00277 while (cline && sscanf(cline, " %d:%f", &i, &a)) {
00278 M[make_pair((unsigned)c, (unsigned)i)] = a;
00279 cline = strchr(cline + 1, ' ');
00280 }
00281 }
00282 return true;
00283 }
00284
00285 void PASV::save(std::ostream& os)
00286 {
00287 #ifdef TEXT
00288 os.precision(3);
00289 for (unsigned i = 0; i < k; ++i) {
00290 os << i;
00291 for (unsigned j = 0; j < d; ++j) {
00292 if (M[i].find(j) != M[i].end()) {
00293 float w = M[i][j];
00294 if (w != 0.0)
00295 os << ' ' << j << ':' << w;
00296 }
00297 }
00298 os << endl;
00299 }
00300 return;
00301 #else
00302 for (unsigned i = 0; i < k; ++i) {
00303
00304 int count = 0;
00305 for (unsigned j = 0; j < d; ++j) {
00306 if (M[i].find(j) != M[i].end()) {
00307 float w = M[i][j];
00308 if (w != 0.0)
00309 count++;
00310 }
00311 }
00312 if (count == 0)
00313 continue;
00314 os.write((char const*)&i, sizeof(i));
00315 os.write((char const*)&count, sizeof(count));
00316 for (unsigned j = 0; j < d; ++j) {
00317 if (M[i].find(j) != M[i].end()) {
00318 float w = M[i][j];
00319 if (w != 0.0) {
00320 os.write((char const*)&j, sizeof(j));
00321 os.write((char const*)&w, sizeof(w));
00322 }
00323 }
00324 }
00325 }
00326 #endif
00327 }
00328
00329 bool PASV::load(std::istream& is)
00330 {
00331 int c;
00332 while (is.read((char*)&c, sizeof(c))) {
00333 int count;
00334 if (!is.read((char*)&count, sizeof(count))) {
00335 cerr << "bad file format: count" << endl;
00336 return false;
00337 }
00338 int i;
00339 float a;
00340 while (count--) {
00341 if (!is.read((char*)&i, sizeof(i))) {
00342 cerr << "bad file format: i" << endl;
00343 return false;
00344 }
00345 if (!is.read((char*)&a, sizeof(a))) {
00346 cerr << "bad file format: a" << endl;
00347 return false;
00348 }
00349 M[c][i] = a;
00350 }
00351 }
00352 return true;
00353 }
00354
00355 void PAD::save(std::ostream& os)
00356 {
00357 os.precision(20);
00358 for (unsigned i = 0; i < k; ++i) {
00359 os << i;
00360 for (unsigned j = 0; j < d; ++j) {
00361 Float w = M[i][j];
00362 if (w != 0.0)
00363 os << ' ' << j << ':' << w;
00364 }
00365 os << endl;
00366 }
00367 }
00368
00369 bool PAD::load(std::istream& is)
00370 {
00371 string line;
00372 while (getline(is, line)) {
00373 char* cline = (char *)line.c_str();
00374 int c = strtol(cline, &cline, 10);
00375 int i;
00376 float a;
00377 while (cline && sscanf(cline, " %d:%f", &i, &a)) {
00378 M[c][i] = a;
00379 cline = strchr(cline + 1, ' ');
00380 }
00381 }
00382 return true;
00383 }
00384
00385 #ifdef MESCHACH
00386
00387 PASM::PASM(int k, int d) :
00388 PA(k, d)
00389 {
00390 M = sp_get(k, d, 32);
00391 if (verbose)
00392 cerr << "\nPASM::PASM(" << k << ',' << d << ")" << endl;
00393 }
00394
00395 void PASM::init(int k, int d)
00396 {
00397 this->k = k;
00398 this->d = d;
00399 sp_resize(M, k, d);
00400 }
00401
00402 Float PASM::score(unsigned r, X& v) {
00403 Float ans = 0;
00404 SPROW& Mr = M->row[r];
00405 for (unsigned i = 0; i < v.size(); ++i) {
00406 int idx = sprow_idx(Mr, v[i]);
00407 if (idx >= 0)
00408 ans += Mr->elt[idx].val;
00409 }
00410 return ans;
00411 }
00412
00413 void PASM::update(Y y, Float tau, X& x) {
00414 for (unsigned i = 0; i < x.size(); ++i) {
00415 unsigned xi = x[i];
00416 M[y][xi] += tau;
00417 }
00418 }
00419 void PASM::save(std::ostream& os)
00420 {
00421 for (unsigned i = 0; i < k; ++i) {
00422
00423 int count = 0;
00424 for (unsigned j = 0; j < d; ++j) {
00425 if (M[i].find(j) != M[i].end()) {
00426 float w = M[i][j].weight(t);
00427 if (w != 0.0)
00428 count++;
00429 }
00430 }
00431 if (count == 0)
00432 continue;
00433 os.write((char const*)&i, sizeof(i));
00434 os.write((char const*)&count, sizeof(count));
00435 for (unsigned j = 0; j < d; ++j) {
00436 if (M[i].find(j) != M[i].end()) {
00437 float w = M[i][j].weight(t);
00438 if (w != 0.0) {
00439 os.write((char const*)&j, sizeof(j));
00440 os.write((char const*)&w, sizeof(w));
00441 }
00442 }
00443 }
00444 }
00445 }
00446
00447 bool PASM::load(std::istream& is)
00448 {
00449 int c;
00450 while (is.read((char*)&c, sizeof(c))) {
00451 int count;
00452 if (!is.read((char*)&count, sizeof(count))) {
00453 cerr << "bad file format: count" << endl;
00454 return false;
00455 }
00456 int i;
00457 float a;
00458 while (count--) {
00459 if (!is.read((char*)&i, sizeof(i))) {
00460 cerr << "bad file format: i" << endl;
00461 return false;
00462 }
00463 if (!is.read((char*)&a, sizeof(a))) {
00464 cerr << "bad file format: a" << endl;
00465 return false;
00466 }
00467 M[c][i] = a;
00468 }
00469 }
00470 return true;
00471 }
00472
00473 #endif