00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00026
00027 #ifdef _WIN32
00028 #include "lib/strtok_r.h"
00029 #endif
00030
00031 #include "conf/conf_int.h"
00032 #include "Parser.h"
00033 #include "EventStream.h"
00034 #include "State.h"
00035
00036
00037 #include "GIS.h"
00038 #include "LBFGS.h"
00039
00040
00041 #include <math.h>
00042
00043 using namespace std;
00044
00045 namespace Parser {
00046
00047 IXE::conf<int> iterations("MEiter", 60);
00048
00049
00050 #define BEAM
00051
00055 struct MeParser : public Parser
00056 {
00057 MeParser(char const* modelFile);
00058
00059 void train(SentenceReader* sentenceReader, char const* modelFile);
00060
00061 Sentence* parse(Sentence* sentence);
00062
00063 void revise(SentenceReader* sentenceReader, char const* actionFile = 0);
00064
00065 Tanl::Classifier::MaxEnt model;
00066 };
00067
00071 Parser* MeParserFactory(char const* modelFile = 0)
00072 {
00073 return new MeParser(modelFile);
00074 }
00075
00076 REGISTER_PARSER(ME, MeParserFactory);
00077
00078 MeParser::MeParser(char const* modelFile) :
00079 Parser(model.PredIndex())
00080 {
00081 if (!modelFile)
00082 return;
00083 ifstream ifs(modelFile);
00084 if (!ifs)
00085 throw IXE::FileError(string("Missing model file: ") + modelFile);
00086
00087 readHeader(ifs);
00088 model.load(ifs);
00089
00090 info.load(ifs);
00091 ifs.close();
00092 }
00093
00094 void MeParser::train(SentenceReader* sentenceReader, char const* modelFile)
00095 {
00096 EventStream eventStream(sentenceReader, &info);
00097 Tanl::Classifier::LBFGS model(iterations, featureCutoff);
00098 model.verbose = verbose;
00099 model.read(eventStream);
00100 ofstream ofs(modelFile, ios::binary | ios::trunc);
00101
00102 writeHeader(ofs);
00103 model.train();
00104
00105 model.writeHeader(ofs);
00106 model.writeData(ofs);
00107
00108 info.save(ofs);
00109 }
00110
00111 #ifdef BEAM
00112
00113 static double addState(ParseState* s, vector<ParseState*>& states)
00114 {
00115 int size = states.size();
00116 if (size == 0) {
00117 states.push_back(s);
00118 return s->lprob;
00119 }
00120 double worst = states[size-1]->lprob;
00121 if (size == beam && s->lprob < worst)
00122 return worst;
00123 TO_EACH (vector<ParseState*>, states, it)
00124 if (s->lprob > (*it)->lprob) {
00125 if (size == beam) {
00126 delete states.back();
00127 states.pop_back();
00128 }
00129 states.insert(it, s);
00130 return states.back()->lprob;
00131 }
00132 if (size < beam)
00133 states.push_back(s);
00134 return states.back()->lprob;
00135 }
00136 #endif
00137
00138 Sentence* MeParser::parse(Sentence* sentence)
00139 {
00140 int numOutcomes = model.NumOutcomes();
00141 double params[numOutcomes];
00142
00143 # ifdef BEAM
00144 vector<ParseState*> currStates; currStates.reserve(beam);
00145 vector<ParseState*> nextStates; nextStates.reserve(beam);
00146 vector<ParseState*>* bestStates = &currStates;
00147 vector<ParseState*>* bestNextStates = &nextStates;
00148 ParseState* state = new ParseState(*sentence, &info, predIndex);
00149 addState(state, *bestStates);
00150
00151 while (true) {
00152 int finished = 0;
00153 int numBest = bestStates->size();
00154
00155 double worstProb = -numeric_limits<double>::infinity();
00156 for (int i = 0; i < numBest; i++) {
00157 state = (*bestStates)[i];
00158 if (state->hasNext()) {
00159 Tanl::Classifier::Context& context = *state->next();
00160 model.estimate(context, params);
00161 for (int o = 0; o < numOutcomes; o++) {
00162 if (params[o] < 1e-4)
00163 continue;
00164 double lprob = log(params[o]) + state->lprob;
00165 if (bestNextStates->size() == beam && lprob < worstProb)
00166 continue;
00167 char const* outcome = model.OutcomeName(o);
00168 ParseState* next = state->transition(outcome);
00169 if (!next) {
00170
00171 state->dispose();
00172 continue;
00173 }
00174 next->lprob = lprob;
00175 worstProb = addState(next, *bestNextStates);
00176 }
00177 } else {
00178 worstProb = addState(state, *bestNextStates);
00179 finished++;
00180 }
00181 }
00182 if (finished == numBest)
00183 break;
00184
00185 vector<ParseState*>* tmp = bestStates;
00186 bestStates = bestNextStates;
00187 bestNextStates = tmp;
00188 bestNextStates->clear();
00189 }
00190 Sentence* s = (*bestStates)[0]->getSentence();
00191 return s;
00192 # else
00193 ParseState* state = new ParseState(*sentence, &info, predIndex);
00194
00195 while (state->hasNext()) {
00196 Tanl::Classifier::Context& context = *state->next();
00197 model.estimate(context, params);
00198 int best = model.BestOutcome(params);
00199 char const* outcome = model.OutcomeName(best);
00200 cerr << outcome << ' ' << params[best] << endl;
00201 ParseState* next = state->transition(outcome);
00202 if (!next)
00203 next = state->transition("S");
00204
00205 state = next;
00206 }
00207 Sentence* s = state->getSentence();
00208 delete state;
00209 return s;
00210 # endif
00211 }
00212
00213 void MeParser::revise(SentenceReader* sentenceReader, char const* actionFile)
00214 {
00215 if (actionFile) {
00216
00217 ifstream ifs(actionFile);
00218 WordIndex predIndex;
00219
00220 ReviseContextStream contextStream(sentenceReader, predIndex);
00221
00222 char line[4000];
00223 while (contextStream.hasNext()) {
00224 ++contextStream.cur;
00225 ifs.getline(line, sizeof(line));
00226 char* next = line;
00227 char const* outcome = strtok_r(0, " \t", &next);
00228 contextStream.actions.push_back(outcome);
00229 }
00230 } else {
00231 int numOutcomes = model.NumOutcomes();
00232 double* params = new double[numOutcomes];
00233 int correct = 0;
00234 int wrong = 0;
00235
00236 ReviseContextStream contextStream(sentenceReader, model.PredIndex());
00237
00238 while (contextStream.hasNext()) {
00239 Tanl::Classifier::Context& context = *contextStream.next();
00240 model.estimate(context, params);
00241 char const* outcome = model.OutcomeName(model.BestOutcome(params));
00242 contextStream.actions.push_back(outcome);
00243 }
00244 }
00245 }
00246
00247 }