00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #include "apa.h"
00026
00027
00028 #include <cmath>
00029 #include <fstream>
00030 #include <iostream>
00031 #include <set>
00032
00033 using namespace std;
00034
00035 bool APA::verbose = false;
00036 float APA::updatePercent = 0.5;
00037
00038 Float APA::C = 100.0;
00039 int APA::kd = 1;
00040 Float APA::ka = 0.0;
00041
00043
00044
00045 APAS::APAS(int k, int d) :
00046 APA(k, d)
00047 {
00048 M.rehash(k * d / 8);
00049 }
00050
00051 APASV::APASV(int k, int d) :
00052 APA(k, d),
00053 M(k)
00054 { }
00055
00056 void APASV::init(int k, int d)
00057 {
00058 this->k = k;
00059 this->d = d;
00060 M.resize(k);
00061 }
00062
00064
00065
00070 Float APAS::score(unsigned r, X& v) {
00071 Float ans = 0;
00072 for (unsigned i = 0; i < v.size(); ++i) {
00073 pair<unsigned, unsigned> p(r, v[i]);
00074 Matrix::const_iterator m_i = M.find(p);
00075 if (m_i != M.end())
00076 ans += m_i->second.alpha;
00077 }
00078 return ka + pow(ans, kd);
00079 }
00080
00081 Float APASV::score(unsigned r, X& v) {
00082 Float ans = 0;
00083 Row& Mr = M[r];
00084 for (unsigned i = 0; i < v.size(); ++i) {
00085 Row::const_iterator r_i = Mr.find(v[i]);
00086 if (r_i != Mr.end())
00087 ans += r_i->second.alpha;
00088 }
00089 return ka + pow(ans, kd);
00090 }
00091
00092
00093 void APAS::update(Y y, Float tau, X& x) {
00094 for (unsigned i = 0; i < x.size(); ++i) {
00095 pair<unsigned, unsigned> p(y, x[i]);
00096 M[p].update(t, tau);
00097 }
00098 }
00099
00100 void APASV::update(Y y, Float tau, X& x) {
00101 for (unsigned i = 0; i < x.size(); ++i) {
00102 unsigned xi = x[i];
00103 M[y][xi].update(t, tau);
00104 }
00105 }
00106
00111 void APA::rand_permutation(vector<int>& OUT) {
00112 size_t N = OUT.size();
00113 for (int i = 0; i < N; i++)
00114 OUT[i] = -1;
00115 for (int i = 0; i < N; ) {
00116 int s = int(float(N) * (rand() / (RAND_MAX + 1.0)));
00117 if (OUT[s] == -1)
00118 OUT[s] = i++;
00119 }
00120 }
00121
00122
00123 void APA::train(Cases& cases, int T) {
00124 unsigned _S_ = cases.size();
00125 if (verbose)
00126 cerr << "APA::APA(" << k << ", " << d << ")"
00127 << " cases = " << _S_ << endl;
00128 vector<int> perm(_S_);
00129 for (int it = 0; it < T; ++it) {
00130 rand_permutation(perm);
00131 int updates = 0;
00132 for (int i = 0; i < _S_; ++i) {
00133 ++t;
00134 Case& c = cases[perm[i]];
00135 X& xt = c.first;
00136 Y yt = c.second;
00137
00138
00139 int s = -1;
00140 Float score_r = score(yt, xt);
00141 Float score_s = score_r;
00142 for (int l = 0; l < k; ++l) {
00143 if (l != yt) {
00144 Float xs = score(l, xt);
00145 if (xs >= score_s) {
00146 s = l;
00147 score_s = xs;
00148 }
00149 }
00150 }
00151 if (s < 0)
00152 continue;
00153
00154 Float loss = margin - (score_r - score_s);
00155 if (loss > 0.0) {
00156 updates++;
00157 int xtNormSq = xt.size();
00158 Float tau = min(C, loss / xtNormSq);
00159 update(yt, tau, xt);
00160 update(s, -tau, xt);
00161 }
00162 }
00163 float updPercent = (100.0 * updates) / _S_;
00164 if (verbose)
00165 cerr << "\tupds_" << t << " = " << updates
00166 << " (" << updPercent << "%)" << endl;
00167 if (updates == 0 || updPercent < updatePercent)
00168 break;
00169 # ifdef DEBUG
00170 save(cerr, true);
00171 # endif
00172 }
00173 }
00174
00176
00177
00178 Y APA::predict(X& x) {
00179 Y best_k = 0;
00180 Float best_score = score(0, x);
00181
00182 for (unsigned r = 1; r < k; r++) {
00183 Float score_r = score(r, x);
00184 if (score_r > best_score) {
00185 best_score = score_r;
00186 best_k = r;
00187 }
00188 }
00189 return best_k;
00190 }
00191
00193
00194
00195 #define MAX_LINE_LEN 8196
00196
00197 APASV::APASV(char const* modelFile)
00198 {
00199 ifstream ifs(modelFile);
00200 if (!ifs) {
00201 cerr << "Missing model file: " << modelFile << endl;
00202 return;
00203 }
00204 char line[MAX_LINE_LEN];
00205 if (!ifs.getline(line, MAX_LINE_LEN)) {
00206 cerr << "Bad model file" << endl;
00207 return;
00208 }
00209
00210 int nc = atoi(line);
00211 int n = nc;
00212 while (n-- && ifs.getline(line, MAX_LINE_LEN)) {
00213 labels.push_back(line);
00214 }
00215
00216 if (!ifs.getline(line, MAX_LINE_LEN)) {
00217 cerr << "Bad model file" << endl;
00218 return;
00219 }
00220 int np = atoi(line);
00221 n = 0;
00222 while (n < np && ifs.getline(line, MAX_LINE_LEN))
00223 predIndex[(char const*)line] = n++;
00224 init(nc, np);
00225 load(ifs);
00226 }
00227
00228 void APAS::save(std::ostream& os)
00229 {
00230 os.precision(20);
00231 for (unsigned i = 0; i < k; ++i) {
00232 os << i;
00233 for (unsigned j = 0; j < d; ++j) {
00234 Matrix::const_iterator mit = M.find(make_pair(i, j));
00235 if (mit != M.end())
00236 os << ' ' << j << ':' << mit->second.alpha;
00237 }
00238 os << endl;
00239 }
00240 }
00241
00242 bool APAS::load(std::istream& is)
00243 {
00244 string line;
00245 while (getline(is, line)) {
00246 char* cline = (char *)line.c_str();
00247 int c = strtol(cline, &cline, 10);
00248 int i;
00249 float a;
00250 while (cline && sscanf(cline, " %d:%f", &i, &a)) {
00251 M[make_pair((unsigned)c, (unsigned)i)].alpha = a;
00252 cline = strchr(cline + 1, ' ');
00253 }
00254 }
00255 return true;
00256 }
00257
00258 void APASV::save(std::ostream& os)
00259 {
00260 #ifdef TEXT
00261 os.precision(3);
00262 for (unsigned i = 0; i < k; ++i) {
00263 os << i;
00264 for (unsigned j = 0; j < d; ++j) {
00265 if (M[i].find(j) != M[i].end()) {
00266 float w = M[i][j];
00267 if (w != 0.0)
00268 os << ' ' << j << ':' << w;
00269 }
00270 }
00271 os << endl;
00272 }
00273 return;
00274 #else
00275 for (unsigned i = 0; i < k; ++i) {
00276
00277 int count = 0;
00278 for (unsigned j = 0; j < d; ++j) {
00279 if (M[i].find(j) != M[i].end()) {
00280 float w = M[i][j].alpha;
00281 if (w != 0.0)
00282 count++;
00283 }
00284 }
00285 if (count == 0)
00286 continue;
00287 os.write((char const*)&i, sizeof(i));
00288 os.write((char const*)&count, sizeof(count));
00289 for (unsigned j = 0; j < d; ++j) {
00290 if (M[i].find(j) != M[i].end()) {
00291 float w = M[i][j].alpha;
00292 if (w != 0.0) {
00293 os.write((char const*)&j, sizeof(j));
00294 os.write((char const*)&w, sizeof(w));
00295 }
00296 }
00297 }
00298 }
00299 #endif
00300 }
00301
00302 bool APASV::load(std::istream& is)
00303 {
00304 int c;
00305 while (is.read((char*)&c, sizeof(c))) {
00306 int count;
00307 if (!is.read((char*)&count, sizeof(count))) {
00308 cerr << "bad file format: count" << endl;
00309 return false;
00310 }
00311 while (count--) {
00312 int i;
00313 if (!is.read((char*)&i, sizeof(i))) {
00314 cerr << "bad file format: i" << endl;
00315 return false;
00316 }
00317 float a;
00318 if (!is.read((char*)&a, sizeof(a))) {
00319 cerr << "bad file format: a" << endl;
00320 return false;
00321 }
00322 M[c][i].alpha = a;
00323 }
00324 }
00325 return true;
00326 }