00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00026
00027 #include "platform.h"
00028
00029 #ifdef _WIN32
00030 #include <io.h>
00031 #include <algorithm>
00032 #endif
00033
00034 #include "Parser.h"
00035 #include "EventStream.h"
00036 #include "svm.h"
00037
00038
00039 #include "conf/conf_string.h"
00040
00041
00042 #include <algorithm>
00043 #include <iostream>
00044 #include <list>
00045
00046 using namespace std;
00047
00048 #define MAX_LINE_LEN 8196
00049
00050 namespace Parser {
00051
00053 extern IXE::conf<string> svmParams;
00054
00059 struct MultiSvmParser : public Parser
00060 {
00061 MultiSvmParser(char const* modelFile);
00062
00063 ~MultiSvmParser() {
00064 for (unsigned i = 0; i < model.size(); i++)
00065 svm_destroy_model(model[i]);
00066 }
00067
00068 void train(SentenceReader* sentenceReader, char const* modelFile);
00069 Sentence* parse(Sentence* sentence);
00070
00071
00072 WordIndex classIndex;
00073 vector<string> classLabels;
00074 vector<struct svm_model*> model;
00075 };
00076
00080 Parser* MultiSvmParserFactory(char const* modelFile = 0)
00081 {
00082 return new MultiSvmParser(modelFile);
00083 }
00084
00085 REGISTER_PARSER(MSVM, MultiSvmParserFactory);
00086
00087 extern void parseParameters(svm_parameter& param, char* parameters);
00088
00089 static int compare_nodes(const void* a, const void* b) {
00090 return ((svm_node const*)a)->index - ((svm_node const*)b)->index;
00091 }
00092
/**
 *  Suffix letters of the per-model files: model number i is stored in the
 *  file named after the model file followed by '.' and actType[i].
 */
static char const* actType = "AD";

/**
 *  Type of an action. Its numeric value is used as index to select
 *  the model and the suffix letter in actType.
 */
enum ActionType {
  Shift,
  Reduce
};
00096
00097 MultiSvmParser::MultiSvmParser(char const* modelFile) :
00098 Parser(predIndex)
00099 {
00100 if (!modelFile)
00101 return;
00102 ifstream ifs(modelFile);
00103 if (!ifs)
00104 throw IXE::FileError(string("Missing model file: ") + modelFile);;
00105
00106 readHeader(ifs);
00107
00108 char line[MAX_LINE_LEN];
00109 if (!ifs.getline(line, MAX_LINE_LEN))
00110 throw IXE::FileError(string("Wrong file format: ") + modelFile);
00111 int len = atoi(line);
00112 int n = 0;
00113 while (len--) {
00114 if (!ifs.getline(line, MAX_LINE_LEN))
00115 throw IXE::FileError(string("Wrong file format: ") + modelFile);
00116 classIndex[(char const*)line] = n++;
00117 classLabels.push_back(line);
00118 }
00119
00120 if (!ifs.getline(line, MAX_LINE_LEN))
00121 throw IXE::FileError(string("Wrong file format: ") + modelFile);
00122 len = atoi(line);
00123 n = 0;
00124 while (len--) {
00125 if (!ifs.getline(line, MAX_LINE_LEN))
00126 throw IXE::FileError(string("Wrong file format: ") + modelFile);
00127 predIndex[(char const*)line] = n++;
00128 }
00129
00130 int models = 2;
00131 model.resize(models);
00132 for (int i = 0; i < models; i++) {
00133 string modeliFile = (string(modelFile) + '.') + actType[i];
00134 model[i] = svm_load_model(modeliFile.c_str());
00135 if (!model[i])
00136 throw IXE::FileError(string("can't open model file: ") + modeliFile);
00137 }
00138
00139 ifstream ent(modelFile);
00140 if (!ent)
00141 throw IXE::FileError(string("Missing entities file: ") + modelFile);
00142 info.load(ent);
00143 ent.close();
00144 }
00145
00146 void MultiSvmParser::train(SentenceReader* sentenceReader, char const* modelFile)
00147 {
00148 WordIndex labelIndex;
00149 vector<string> labels;
00150
00151 vector<string> predLabels;
00152
00153
00154 list<Tanl::Classifier::Event*> events;
00155
00156 WordCounts predCount;
00157
00158 int actionCount[2] = {0, 0};
00159 int prevAction = Shift;
00160
00161 int evCount = 0;
00162 Tanl::Classifier::PID pID = 0;
00163
00164
00165 EventStream eventStream(sentenceReader, &info);
00166 while (eventStream.hasNext()) {
00167 Tanl::Classifier::Event* ev = eventStream.next();
00168 events.push_back(ev);
00169 evCount++;
00170 if (verbose) {
00171 if (evCount % 10000 == 0)
00172 cerr << '+' << flush;
00173 else if (evCount % 1000 == 0)
00174 cerr << '.' << flush;
00175 }
00176 vector<string>& ec = ev->features;
00177 for (unsigned j = 0; j < ec.size(); j++) {
00178 string& pred = ec[j];
00179
00180 if (predIndex.find(pred.c_str()) == predIndex.end()) {
00181
00182 WordCounts::iterator wcit = predCount.find(pred);
00183
00184 int count;
00185 if (wcit == predCount.end())
00186 count = predCount[pred] = 1;
00187 else
00188 count = ++wcit->second;
00189 if (count >= featureCutoff) {
00190 predLabels.push_back(pred);
00191 predIndex[pred.c_str()] = pID++;
00192 predCount.erase(pred);
00193 }
00194 }
00195 }
00196 actionCount[prevAction]++;
00197 char a = toupper(ev->className[0]);
00198 prevAction = ActionType(a == 'R' || a == 'L');
00199 }
00200 if (verbose)
00201 cerr << endl;
00202
00203
00204 int models = 2;
00205 vector<svm_problem> problem(models);
00206 for (int i = 0; i < models; i++) {
00207 int size = actionCount[i];
00208 problem[i].y = new double[size];
00209 problem[i].x = new svm_node*[size];
00210 problem[i].l = 0;
00211 }
00212 prevAction = Shift;
00213 int nTot = 0;
00214 Tanl::Classifier::ClassID oID = 0;
00215 while (!events.empty()) {
00216 Tanl::Classifier::Event* ev = events.front();
00217 events.pop_front();
00218 char const* c = ev->className.c_str();
00219
00220 vector<string>& ec = ev->features;
00221 svm_node* preds = new svm_node[ec.size()+1];
00222 unsigned k = 0;
00223 for (unsigned j = 0; j < ec.size(); j++) {
00224 string& pred = ec[j];
00225 WordIndex::const_iterator pit = predIndex.find(pred.c_str());
00226 if (pit != predIndex.end()) {
00227 svm_node& node = preds[k++];
00228 node.index = pit->second + 1;
00229 node.value = 1.0;
00230 }
00231 }
00232 if (k) {
00233
00234 qsort(preds, k, sizeof(svm_node), compare_nodes);
00235
00236 svm_node& node = preds[k++];
00237 node.index = -1;
00238 node.value = 1.0;
00239 if (labelIndex.find(c) == labelIndex.end()) {
00240 labelIndex[c] = oID++;
00241 labels.push_back(c);
00242 }
00243 int i = prevAction;
00244 int& ni = problem[i].l;
00245 problem[i].y[ni] = labelIndex[c];
00246
00247 preds = (svm_node*)realloc(preds, k * sizeof(svm_node));
00248 problem[i].x[ni] = preds;
00249 ni++;
00250 nTot++;
00251 if (verbose) {
00252 if (nTot % 10000 == 0)
00253 cerr << '+' << flush;
00254 else if (nTot % 1000 == 0)
00255 cerr << '.' << flush;
00256 }
00257 } else {
00258 cerr << "Discarded event" << endl;
00259 delete preds;
00260 }
00261 char a = toupper(c[0]);
00262 prevAction = ActionType(a == 'R' || a == 'L');
00263 delete ev;
00264 }
00265
00266 if (verbose)
00267 cerr << endl;
00268
00269 ofstream ofs(modelFile, ios::binary | ios::trunc);
00270
00271 writeHeader(ofs);
00272
00273 ofs << labels.size() << endl;
00274 FOR_EACH (vector<string>, labels, pit)
00275 ofs << *pit << endl;
00276
00277 ofs << predLabels.size() << endl;
00278 FOR_EACH (vector<string>, predLabels, pit)
00279 ofs << *pit << endl;
00280
00281 predIndex.clear();
00282 predIndex = WordIndex();
00283 labelIndex.clear();
00284 labelIndex = WordIndex();
00285
00286 info.clearRareEntities();
00287
00288 svm_parameter param;
00289 parseParameters(param, svmParams);
00290
00291 if (dup2(fileno(stderr), fileno(stdout)) < 0)
00292 cerr << "could not redirect stdout to stderr" << endl;
00293
00294 for (int i = 0; i < models; i++) {
00295 struct svm_model* model = svm_train(&problem[i], ¶m);
00296
00297 string modeliFile = (string(modelFile) + '.') + actType[i];
00298 svm_save_model(modeliFile.c_str(), model);
00299
00300 svm_destroy_model(model);
00301 for (int j = problem[i].l - 1; j >= 0 ; j--)
00302 delete [] problem[i].x[j];
00303 delete [] problem[i].x;
00304 delete [] problem[i].y;
00305 }
00306 svm_destroy_param(¶m);
00307 }
00308
00309 Sentence* MultiSvmParser::parse(Sentence* sentence)
00310 {
00311 int prevAction = Shift;
00312 vector<svm_node> nodes(predIndex.size());
00313 ParseState state(*sentence, &info, predIndex);
00314 while (state.hasNext()) {
00315 Tanl::Classifier::Context& preds = *state.next();
00316
00317 sort(preds.begin(), preds.end());
00318 nodes.resize(preds.size() + 1);
00319 int j = 0;
00320 FOR_EACH (vector<Tanl::Classifier::PID>, preds, pit) {
00321 nodes[j].index = *pit + 1;
00322 nodes[j++].value = 1.0;
00323 }
00324 nodes[preds.size()].index = -1;
00325 int i = prevAction;
00326 double prediction = svm_predict(model[i], &nodes[0]);
00327 string& outcome = classLabels[(int)prediction];
00328 # ifdef DUMP
00329 cerr << classIndex[rightOutcome];
00330 FOR_EACH (vector<Tanl::Classifier::PID>, preds, pit)
00331 cerr << " " << *pit << ":1";
00332 cerr << endl;
00333 # endif
00334
00335 char a = toupper(outcome[0]);
00336 prevAction = ActionType(a == 'R' || a == 'L');
00337 if (!state.transition(outcome.c_str())) {
00338 state.transition("S");
00339 }
00340 }
00341 return state.getSentence();
00342 }
00343
00344 }