00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00026
00027 #include "Parser.h"
00028 #include "EventStream.h"
00029 #include "svm.h"
00030 #include "conf_feature.h"
00031
00032
00033 #include "conf/conf_string.h"
00034 #include "include/unordered_map.h"
00035
00036
00037 #include <algorithm>
00038 #ifdef _WIN32
00039 #include <functional>
00040 #include <algorithm>
00041 #include <stdlib.h>
00042 #include <stdio.h>
00043 #include <io.h>
00044 #include "lib/strtok_r.h"
00045 #else
00046 #include <ext/functional>
00047 #endif
00048
00049 #include <iostream>
00050 #include <list>
00051
00052 using namespace std;
00053
/// Size of the line buffer used when reading a model symbols file with
/// istream::getline (see the SvmParser constructor).
#define MAX_LINE_LEN 8196
00055
00056 namespace Parser {
00057
/// Index of the first model group SvmParser::train() actually trains:
/// groups with a lower index are not trained (their data is still freed).
IXE::conf<int> svmSkip("SvmSkip", 0);

/// libsvm training options in svm-train option syntax, parsed by
/// parseParameters() when training.
IXE::conf<string> svmParams("SvmParams", "-t 1 -d 2 -g 0.2 -c 0.4 -e 0.1");

/// Defined elsewhere; not referenced by the code shown here.
extern IXE::conf<bool> CompositeActions;

/// Feature whose value partitions training events into separate models;
/// an empty value disables splitting (see SvmParser::splitModel()).
extern conf_feature SplitFeature;
00068
00072 struct SvmParser : public Parser
00073 {
00074 SvmParser(char const* modelFile);
00075
00076 ~SvmParser() {
00077 for (unsigned i = 0; i < model.size(); i++)
00078 svm_destroy_model(model[i]);
00079 }
00080
00081 void train(SentenceReader* sentenceReader, char const* modelFile);
00082 Sentence* parse(Sentence* sentence);
00083
00084
00085 WordIndex splits;
00086 vector<string> splitNames;
00087 unordered_map<char, char> splitGroup;
00088
00089 WordIndex predIndex;
00090 WordIndex classIndex;
00091 vector<string> classLabels;
00092 vector<svm_model*> model;
00093
00094 private:
00095 void collectEvents(SentenceReader* sentenceReader, char const* modelFile,
00096 vector<svm_problem>& problem);
00097
00099 bool splitModel() { return !SplitFeature->empty(); }
00100 };
00101
/**
 * Writes into @p ext the file-name suffix of model number @p i: a '.'
 * followed by two letters spelling i in base 26 ("aa" = 0, "ab" = 1, ...,
 * "ba" = 26). @p ext must have room for 4 chars, terminator included.
 * @return @p ext.
 */
static char* mkext(char* ext, int i)
{
  char const high = static_cast<char>('a' + i / 26);
  char const low = static_cast<char>('a' + i % 26);
  ext[0] = '.';
  ext[1] = high;
  ext[2] = low;
  ext[3] = '\0';
  return ext;
}
00107
00108 SvmParser::SvmParser(char const* modelFile) :
00109 Parser(predIndex)
00110 {
00111 if (!modelFile)
00112 return;
00113 ifstream ifs(modelFile);
00114 if (!ifs)
00115 throw IXE::FileError(string("Missing symbols file: ") + modelFile);
00116
00117 readHeader(ifs);
00118
00119 char line[MAX_LINE_LEN];
00120 if (!ifs.getline(line, MAX_LINE_LEN))
00121 throw IXE::FileError(string("Wrong file format: ") + modelFile);
00122 int len = atoi(line);
00123 int n = 0;
00124 while (len--) {
00125 if (!ifs.getline(line, MAX_LINE_LEN))
00126 throw IXE::FileError(string("Wrong file format: ") + modelFile);
00127 classIndex[(char const*)line] = n++;
00128 classLabels.push_back(line);
00129 }
00130
00131 if (!ifs.getline(line, MAX_LINE_LEN))
00132 throw IXE::FileError(string("Wrong file format: ") + modelFile);
00133 len = atoi(line);
00134 n = 0;
00135 while (len--) {
00136 if (!ifs.getline(line, MAX_LINE_LEN))
00137 throw IXE::FileError(string("Wrong file format: ") + modelFile);
00138 predIndex[(char const*)line] = n++;
00139 }
00140
00141 if (!ifs.getline(line, MAX_LINE_LEN))
00142 throw IXE::FileError(string("Wrong file format: ") + modelFile);
00143 len = atoi(line);
00144 n = 0;
00145 int models = 0;
00146 int skipGroup = (len != 0);
00147 while (len--) {
00148 ifs.getline(line, MAX_LINE_LEN);
00149 char* next = line;
00150
00151 char* code = (line[0] == ' ' || line[0] == '\t') ? (char*)"" : strtok_r(0, " \t", &next);
00152 splits.insert(code);
00153 splitNames.push_back(code);
00154 code = strtok_r(0, " \t", &next);
00155 int group = atoi(code);
00156 splitGroup[n] = group;
00157 models = max(models, group);
00158 if (group == 0) skipGroup = 0;
00159 n++;
00160 }
00161 models++;
00162
00163 info.load(ifs);
00164
00165 model.resize(models);
00166 for (int i = skipGroup; i < models; i++) {
00167 char ext[3];
00168 string modeliFile = string(modelFile) + mkext(ext, i);
00169 model[i] = svm_load_model(modeliFile.c_str());
00170 if (!model[i])
00171 throw IXE::FileError(string("can't open model file ") + modeliFile);
00172 }
00173 }
00174
00178 Parser* SvmParserFactory(char const* modelFile = 0)
00179 {
00180 SvmParser* p = new SvmParser(modelFile);
00181 if (modelFile && p->model.empty()) {
00182 delete p;
00183 return 0;
00184 }
00185 return p;
00186 }
00187
// Registers SvmParserFactory under the parser type name SVM. REGISTER_PARSER
// is defined outside this file; presumably it adds the factory to a
// name-keyed registry - confirm in Parser.h.
REGISTER_PARSER(SVM, SvmParserFactory);
00189
00190 void parseParameters(svm_parameter& param, char* parameters)
00191 {
00192
00193 param.svm_type = C_SVC;
00194 param.kernel_type = RBF;
00195 param.degree = 3;
00196 param.gamma = 0;
00197 param.coef0 = 0;
00198 param.nu = 0.5;
00199 param.cache_size = 100;
00200 param.C = 1;
00201 param.eps = 1e-3;
00202 param.p = 0.1;
00203 param.shrinking = 1;
00204 param.probability = 0;
00205 param.nr_weight = 0;
00206 param.weight_label = NULL;
00207 param.weight = NULL;
00208
00209 char const* opt = "";
00210 char* next = parameters;
00211
00212 while (opt = strtok_r(0, " \t", &next)) {
00213 if (opt[0] != '-') {
00214 cerr << "Missing option: " << opt << endl;
00215 return;
00216 }
00217 char* tok = strtok_r(0, " \t", &next);
00218 if (!tok) {
00219 cerr << "Missing option value: " << opt << endl;
00220 return;
00221 }
00222 switch (opt[1]) {
00223 case 's':
00224 param.svm_type = atoi(tok);
00225 break;
00226 case 't':
00227 param.kernel_type = atoi(tok);
00228 break;
00229 case 'd':
00230 param.degree = atoi(tok);
00231 break;
00232 case 'g':
00233 param.gamma = atof(tok);
00234 break;
00235 case 'r':
00236 param.coef0 = atof(tok);
00237 break;
00238 case 'n':
00239 param.nu = atof(tok);
00240 break;
00241 case 'm':
00242 param.cache_size = atof(tok);
00243 break;
00244 case 'c':
00245 param.C = atof(tok);
00246 break;
00247 case 'e':
00248 param.eps = atof(tok);
00249 break;
00250 case 'p':
00251 param.p = atof(tok);
00252 break;
00253 case 'h':
00254 param.shrinking = atoi(tok);
00255 break;
00256 case 'b':
00257 param.probability = atoi(tok);
00258 break;
00259 case 'w':
00260 ++param.nr_weight;
00261 param.weight_label = (int *)realloc(param.weight_label,
00262 sizeof(int)*param.nr_weight);
00263 param.weight = (double *)realloc(param.weight,
00264 sizeof(double)*param.nr_weight);
00265 param.weight_label[param.nr_weight-1] = atoi(opt+2);
00266 param.weight[param.nr_weight-1] = atof(tok);
00267 break;
00268 default:
00269 cerr << "unknown option: " << opt << endl;
00270 return;
00271 }
00272 }
00273 }
00274
00275 int compare_nodes(const void* a, const void* b) {
00276 return ((svm_node const*)a)->index - ((svm_node const*)b)->index;
00277 }
00278
/// Minimum number of training events a split needs to get a model (group)
/// of its own; smaller splits are pooled into group 0 by collectEvents().
int MinimumSvmSize = 5000;
00280
00281 void SvmParser::collectEvents(SentenceReader* sentenceReader,
00282 char const* modelFile,
00283 vector<svm_problem>& problem)
00284 {
00285 WordIndex labelIndex;
00286 vector<string> labels;
00287
00288 vector<string> predLabels;
00289
00290
00291 list<Tanl::Classifier::Event*> events;
00292
00293 WordCounts predCount;
00294
00295 WordCounts splitCount;
00296 vector<char> splitEvents;
00297
00298 int evCount = 0;
00299 Tanl::Classifier::PID pID = 0;
00300
00301
00302 EventStream eventStream(sentenceReader, &info);
00303
00304 bool doSplit = splitModel();
00305
00306 while (eventStream.hasNext()) {
00307 Tanl::Classifier::Event* ev = eventStream.next();
00308 events.push_back(ev);
00309 evCount++;
00310 if (verbose) {
00311 if (evCount % 10000 == 0)
00312 cerr << '+' << flush;
00313 else if (evCount % 1000 == 0)
00314 cerr << '.' << flush;
00315 }
00316 vector<string>& ec = ev->features;
00317
00318 for (unsigned j = 0; j < ec.size(); j++) {
00319 string& pred = ec[j];
00320
00321 if (predIndex.find(pred.c_str()) == predIndex.end()) {
00322
00323
00324 int count = predCount.add(pred);
00325 if (count >= featureCutoff) {
00326 predLabels.push_back(pred);
00327 predIndex[pred.c_str()] = pID++;
00328 predCount.erase(pred);
00329 }
00330 }
00331 }
00332 if (doSplit) {
00333 string code = eventStream.splitFeature();
00334
00335 char const* ccode = code.c_str();
00336 if (splits.index(ccode) == -1) {
00337 splits.insert(ccode);
00338 splitNames.push_back(ccode);
00339 }
00340
00341 splitEvents.push_back((char)splits[ccode]);
00342 splitCount.add(code);
00343 }
00344 }
00345 if (verbose)
00346 cerr << endl;
00347
00348 predCount.clear();
00349 predCount = WordCounts();
00350
00351
00352 if (doSplit) {
00353
00354 vector<int> splitNewSize(max(1, (int)splits.size()));
00355 splitNewSize[0] = 0;
00356 int models = 1;
00357 FOR_EACH (WordCounts, splitCount, sit) {
00358 char splitCode = (char)splits[sit->first.c_str()];
00359 if (sit->second < MinimumSvmSize || splitCount.size() == 1) {
00360 splitGroup[splitCode] = 0;
00361 splitNewSize[0] += sit->second;
00362 } else {
00363 splitGroup[splitCode] = models;
00364 splitNewSize[models] = sit->second;
00365 models++;
00366 }
00367 }
00368
00369 problem.resize(models);
00370 int skipGroup = splitNewSize[0] == 0;
00371 problem[0].l = 0;
00372 for (int i = skipGroup; i < models; i++) {
00373 int size = splitNewSize[i];
00374 problem[i].y = new double[size];
00375 problem[i].x = new svm_node*[size];
00376 problem[i].l = 0;
00377 }
00378 } else {
00379 problem.resize(1);
00380 problem[0].y = new double[evCount];
00381 problem[0].x = new svm_node*[evCount];
00382 problem[0].l = 0;
00383 }
00384 int nTot = 0;
00385 Tanl::Classifier::ClassID oID = 0;
00386 evCount = 0;
00387 while (!events.empty()) {
00388 Tanl::Classifier::Event* ev = events.front();
00389 events.pop_front();
00390 char const* c = ev->className.c_str();
00391
00392 vector<string>& ec = ev->features;
00393 svm_node* preds = new svm_node[ec.size()+1];
00394 unsigned k = 0;
00395 for (unsigned j = 0; j < ec.size(); j++) {
00396 string& pred = ec[j];
00397 WordIndex::const_iterator pit = predIndex.find(pred.c_str());
00398 if (pit != predIndex.end()) {
00399 svm_node& node = preds[k++];
00400 node.index = pit->second + 1;
00401 node.value = 1.0;
00402 }
00403 }
00404 if (k) {
00405
00406 qsort(preds, k, sizeof(svm_node), compare_nodes);
00407
00408 svm_node& node = preds[k++];
00409 node.index = -1;
00410 node.value = 1.0;
00411 if (labelIndex.find(c) == labelIndex.end()) {
00412 labelIndex[c] = oID++;
00413 labels.push_back(c);
00414 }
00415 int i = 0;
00416 if (!splitEvents.empty())
00417 i = splitGroup[splitEvents[evCount]];
00418 int& ni = problem[i].l;
00419 problem[i].y[ni] = labelIndex[c];
00420
00421 preds = (svm_node*)realloc(preds, k * sizeof(svm_node));
00422 problem[i].x[ni] = preds;
00423 ni++;
00424 nTot++;
00425 if (verbose) {
00426 if (nTot % 10000 == 0)
00427 cerr << '+' << flush;
00428 else if (nTot % 1000 == 0)
00429 cerr << '.' << flush;
00430 }
00431 } else {
00432 cerr << "Discarded event" << endl;
00433 delete preds;
00434 }
00435 evCount++;
00436 delete ev;
00437 }
00438
00439 if (verbose)
00440 cerr << endl;
00441
00442
00443 ofstream ofs(modelFile, ios::binary | ios::trunc);
00444
00445 writeHeader(ofs);
00446
00447 ofs << labels.size() << endl;
00448 FOR_EACH (vector<string>, labels, pit)
00449 ofs << *pit << endl;
00450
00451 ofs << predLabels.size() << endl;
00452 FOR_EACH (vector<string>, predLabels, pit)
00453 ofs << *pit << endl;
00454
00455 ofs << splitNames.size() << endl;
00456 FOR_EACH (vector<string>, splitNames, pit)
00457 ofs << *pit << "\t" << (int)splitGroup[(char)splits[pit->c_str()]] << endl;
00458 info.save(ofs);
00459 ofs.close();
00460
00461 labels.clear();
00462 predLabels.clear();
00463 predIndex.clear();
00464 predIndex = WordIndex();
00465 labelIndex.clear();
00466 labelIndex = WordIndex();
00467
00468 info.clearRareEntities();
00469 }
00470
00471 void SvmParser::train(SentenceReader* sentenceReader, char const* modelFile)
00472 {
00473 vector<svm_problem> problem;
00474 collectEvents(sentenceReader, modelFile, problem);
00475
00476
00477 svm_parameter param;
00478 parseParameters(param, svmParams);
00479
00480 if (dup2(fileno(stderr), fileno(stdout)) < 0)
00481 cerr << "could not redirect stdout to stderr" << endl;
00482
00483 int models = problem.size();
00484 int skipGroup = problem[0].l == 0;
00485 for (int i = skipGroup; i < models; i++) {
00486 if (i >= svmSkip) {
00487 struct svm_model* model = svm_train(&problem[i], ¶m);
00488
00489 char ext[4];
00490 string modeliFile = string(modelFile) + mkext(ext, i);
00491 svm_save_model(modeliFile.c_str(), model);
00492
00493 svm_destroy_model(model);
00494 }
00495 for (int j = problem[i].l - 1; j >= 0 ; j--)
00496 delete [] problem[i].x[j];
00497 delete [] problem[i].x;
00498 delete [] problem[i].y;
00499 }
00500 svm_destroy_param(¶m);
00501 }
00502
00503 Sentence* SvmParser::parse(Sentence* sentence)
00504 {
00505 vector<svm_node> nodes(predIndex.size());
00506 ParseState state(*sentence, &info, predIndex);
00507 while (state.hasNext()) {
00508 Tanl::Classifier::Context& preds = *state.next();
00509
00510 sort(preds.begin(), preds.end());
00511 nodes.resize(preds.size() + 1);
00512 int j = 0;
00513 FOR_EACH (vector<Tanl::Classifier::PID>, preds, pit) {
00514 nodes[j].index = *pit + 1;
00515 nodes[j++].value = 1.0;
00516 }
00517 nodes[preds.size()].index = -1;
00518 string code = state.splitFeature;
00519 int i = splitGroup[splits[code.c_str()]];
00520 double prediction = svm_predict(model[i], &nodes[0]);
00521 string& outcome = classLabels[(int)prediction];
00522 # ifdef DUMP
00523 cerr << classIndex[rightOutcome];
00524 FOR_EACH (vector<Tanl::Classifier::PID>, preds, pit)
00525 cerr << " " << *pit << ":1";
00526 cerr << endl;
00527 # endif
00528 if (!state.transition(outcome.c_str())) {
00529 state.transition("S");
00530 }
00531 }
00532 return state.getSentence();
00533 }
00534
00535
00536
00537
00538
00539
00540
00541
00542
00543
00544
00545
00546
00547
00548
00549
00550
00551
00552
00553
00554
00555
00556
00557
00558
00559
00560
00561
00562
00563
00564
00565
00566
00567
00568
00569
00570 }