1 package eu.fbk.dkm.pikes.raid.pipeline;
2
3 import java.io.IOException;
4 import java.nio.file.Path;
5 import java.util.Collections;
6 import java.util.Comparator;
7 import java.util.List;
8 import java.util.Set;
9
10 import javax.annotation.Nullable;
11
12 import com.google.common.collect.ImmutableList;
13 import com.google.common.collect.ImmutableSet;
14 import com.google.common.collect.Iterables;
15 import com.google.common.collect.Lists;
16 import com.google.common.collect.Ordering;
17 import com.google.common.collect.Sets;
18
19 import eu.fbk.utils.core.Range;
20 import org.slf4j.Logger;
21 import org.slf4j.LoggerFactory;
22
23 import ixa.kaflib.Dep;
24 import ixa.kaflib.Entity;
25 import ixa.kaflib.KAFDocument;
26 import ixa.kaflib.Predicate;
27 import ixa.kaflib.Predicate.Role;
28 import ixa.kaflib.Span;
29 import ixa.kaflib.Term;
30
31 import eu.fbk.dkm.pikes.resources.NAFUtils;
32 import eu.fbk.utils.eval.ConfusionMatrix;
33 import eu.fbk.utils.eval.PrecisionRecall;
34 import eu.fbk.utils.eval.SetPrecisionRecall;
35 import eu.fbk.utils.svm.Classifier;
36 import eu.fbk.utils.svm.FeatureStats;
37 import eu.fbk.utils.svm.LabelledVector;
38 import eu.fbk.utils.svm.Vector;
39
public final class SpanLabeller {

    /** Logger for diagnostic/progress output. */
    private static final Logger LOGGER = LoggerFactory.getLogger(SpanLabeller.class);

    // Trained classifier deciding, for each candidate child span, whether it joins the expansion.
    private final Classifier classifier;

    // Adapter exposing the classifier through the internal Predictor callback; Trainer supplies
    // an oracle implementation of the same interface when generating training data.
    private Predictor predictor;

    /**
     * Creates a labeller wrapping the given trained classifier. Private: obtain instances via
     * {@link #readFrom(Path)} or by training with {@link #train()}.
     */
    private SpanLabeller(final Classifier classifier) {
        this.classifier = classifier;
        this.predictor = new Predictor() {

            @Override
            public LabelledVector predict(final Vector vector) {
                // false: plain label prediction, no probability estimation needed here
                return classifier.predict(false, vector);
            }

        };
    }
59
60 public static SpanLabeller readFrom(final Path path) throws IOException {
61 return new SpanLabeller(Classifier.readFrom(path));
62 }
63
    /**
     * Expands multiple head terms into a single span. Each head is expanded independently,
     * treating the remaining heads as additional excluded terms, and the resulting term sets
     * are unioned. When {@code merge} is requested, single coordination/punctuation terms
     * lying in a gap between selected terms are additionally pulled in to bridge the gap.
     *
     * @param document the NAF document the terms belong to
     * @param heads the head terms to expand
     * @param excludedTerms terms that must not be covered by the expansion
     * @param merge whether to bridge coordination/punctuation gaps between selected terms
     * @return the resulting span (possibly empty), with {@code heads} registered as its heads
     */
    public Span<Term> expand(final KAFDocument document, final Iterable<Term> heads,
            final Iterable<Term> excludedTerms, final boolean merge) {

        // Expand each head on its own, excluding the other heads on top of excludedTerms
        final Set<Term> headSet = ImmutableSet.copyOf(heads);
        final Set<Term> terms = Sets.newHashSet();
        for (final Term head : headSet) {
            final Iterable<Term> exclusion = Iterables.concat(excludedTerms,
                    Sets.difference(headSet, ImmutableSet.of(head)));
            terms.addAll(expand(document, head, exclusion).getTargets());
        }

        // Nothing selected: return an empty span
        if (terms.isEmpty()) {
            return KAFDocument.newTermSpan();
        }

        // Optionally fill one-term holes between selected terms when the hole is a
        // coordination ("COORD" function) or punctuation ("P") and the term right after it
        // (or the one after that, if still in range) is selected too
        if (merge) {
            int startIndex = Integer.MAX_VALUE;
            int endIndex = Integer.MIN_VALUE;
            for (final Term term : terms) {
                final int index = document.getTerms().indexOf(term);
                startIndex = Math.min(startIndex, index);
                endIndex = Math.max(endIndex, index);
            }
            for (int i = startIndex + 1; i < endIndex; ++i) {
                final Term term = document.getTerms().get(i);
                if (!terms.contains(term) && terms.contains(document.getTerms().get(i - 1))) {
                    final Dep dep = document.getDepToTerm(term);
                    if (dep != null) {
                        final String func = dep.getRfunc().toUpperCase();
                        if ((func.contains("COORD") || func.equals("P"))
                                && (terms.contains(document.getTerms().get(i + 1)) || i + 2 <= endIndex
                                && terms.contains(document.getTerms().get(i + 2)))) {
                            terms.add(term);
                        }
                    }
                }
            }
        }

        // Emit the union in document (offset) order and attach all the supplied heads
        final Span<Term> result = KAFDocument.newTermSpan(Ordering.from(Term.OFFSET_COMPARATOR)
                .sortedCopy(terms));
        Iterables.addAll(result.getHeads(), headSet);
        return result;
    }
112
113 public Span<Term> expand(final KAFDocument document, final Term head,
114 final Iterable<Term> excludedTerms) {
115 return expand(document, this.predictor, excludedTerms, getMinimalSpan(document, head));
116 }
117
118 private static Span<Term> expand(final KAFDocument document, final Predictor predictor,
119 @Nullable final Iterable<Term> marked, final Span<Term> span) {
120
121
122 final Set<Term> markedSet = marked == null ? ImmutableSet.of() : ImmutableSet
123 .copyOf(marked);
124
125
126 final Set<Term> selection = Sets.newHashSet();
127 expandRecursive(document, predictor, markedSet, span, span, 0, selection);
128
129
130 final Span<Term> result = KAFDocument.newTermSpan(Ordering.from(Term.OFFSET_COMPARATOR)
131 .sortedCopy(selection), NAFUtils.extractHead(document, span));
132
133 return result;
134 }
135
136 private static void expandRecursive(final KAFDocument document, final Predictor predictor,
137 final Set<Term> marked, final Span<Term> span, final Span<Term> root, final int depth,
138 final Set<Term> selection) {
139
140
141 selection.addAll(span.getTargets());
142
143
144 final List<Span<Term>> children = Lists.newArrayList();
145 for (final Term head : span.getTargets()) {
146 final List<Term> queue = Lists.newLinkedList();
147 queue.add(head);
148 while (!queue.isEmpty()) {
149 final Term term = queue.remove(0);
150 final List<Dep> deps = document.getDepsFromTerm(term);
151 if (!deps.isEmpty()) {
152 for (final Dep dep : deps) {
153 final Term to = dep.getTo();
154 if (!span.getTargets().contains(to)) {
155 if ("PC".contains(to.getPos())) {
156 queue.add(to);
157 } else {
158 children.add(getMinimalSpan(document, to));
159 }
160 }
161 }
162 } else if (!span.getTargets().contains(term)) {
163 children.add(getMinimalSpan(document, term));
164 }
165 }
166 }
167
168
169 final int offset = getSpanOffset(span);
170 Collections.sort(children, new Comparator<Span<Term>>() {
171
172 @Override
173 public int compare(final Span<Term> span1, final Span<Term> span2) {
174 final int distance1 = Math.abs(offset - getSpanOffset(span1));
175 final int distance2 = Math.abs(offset - getSpanOffset(span2));
176 return distance1 - distance2;
177 }
178
179 });
180
181
182 for (final Span<Term> child : children) {
183 final Vector features = features(document, marked, selection, root, span, child, depth);
184 if (predictor.predict(features).getLabel() == 1) {
185 for (final Term term : child.getTargets()) {
186 for (Dep dep = document.getDepToTerm(term); dep != null
187 && !selection.contains(dep.getFrom()); dep = document.getDepToTerm(dep
188 .getFrom())) {
189 selection.add(dep.getFrom());
190 }
191 }
192 expandRecursive(document, predictor, marked, child, root, depth + 1, selection);
193 }
194 }
195 }
196
197
198
199
200
201
202
203
204
205
206 private static List<String> features(final String prefix, final Range parent, final Range child) {
207 final int dist = parent.distance(child);
208 final String direction = child.begin() > parent.begin() ? "after" : "before";
209 final String distance = dist <= 1 ? "adjacent" : dist <= 3 ? "verynear"
210 : dist <= 6 ? "near" : "far";
211 return ImmutableList.of(prefix + distance, prefix + direction, prefix + direction + "."
212 + distance);
213 }
214
    /**
     * Builds the feature vector describing a candidate {@code child} span relative to the
     * {@code parent} span being expanded. The exact feature strings are part of the trained
     * model's contract and must not be altered.
     *
     * @param document the NAF document being processed
     * @param marked externally-marked terms (e.g., exclusions)
     * @param selection terms selected so far by the expansion
     * @param root the original seed span
     * @param parent the span whose dependents are considered
     * @param child the candidate child span
     * @param depth current recursion depth
     * @return the feature vector for the (parent, child) decision
     */
    private static Vector features(final KAFDocument document, final Set<Term> marked,
            final Set<Term> selection, final Span<Term> root, final Span<Term> parent,
            final Span<Term> child, final int depth) {

        // Representative terms: getMainTerm() descends VC/IM verb chains; the child's
        // syntactic head anchors the dependency-path computation below
        final Term parentTerm = getMainTerm(document, parent);
        final Term childTerm = getMainTerm(document, child);
        final Term childHead = NAFUtils.extractHead(document, child);

        // Walk incoming dependencies from the child's head up to the parent span, recording
        // the function path, an extended path with POS/lemmas, and the intermediate
        // 'connective' terms. NOTE(review): assumes childHead has an incoming dependency
        // chain reaching the parent (true for children produced by expandRecursive);
        // otherwise this throws NPE - confirm no other callers exist.
        Dep dep = document.getDepToTerm(childHead);
        String path = dep.getRfunc(); // NOTE(review): built but never emitted as a feature
        String pathex = dep.getRfunc() + "-" + dep.getTo().getMorphofeat().substring(0, 1);
        final Set<Term> connectives = Sets.newHashSet();
        while (!parent.getTargets().contains(dep.getFrom())) {
            connectives.add(dep.getFrom());
            dep = document.getDepToTerm(dep.getFrom());
            path = dep.getRfunc() + "." + path;
            pathex = dep.getRfunc() + "-" + dep.getTo().getLemma().toLowerCase() + "." + pathex;
        }

        final Vector.Builder builder = Vector.builder();

        // Cluster id: groups vectors belonging to the same document
        builder.set("_cluster." + document.getPublic().uri);

        // '_index' doubles as a handle letting Trainer.add() recover the candidate's main term
        builder.set("_index", document.getTerms().indexOf(childTerm));
        builder.set("depth" + depth);

        // Positional features over the ranges of root / parent / child / child's descendants
        // (descendants include the connective terms collected above)
        final Set<Term> descendants = document.getTermsByDepAncestors(Iterables.concat(
                child.getTargets(), connectives));
        final Range rootRange = Range.enclose(NAFUtils.termRangesFor(document, root.getTargets()));
        final Range parentRange = Range.enclose(NAFUtils.termRangesFor(document,
                parent.getTargets()));
        final Range childRange = Range
                .enclose(NAFUtils.termRangesFor(document, child.getTargets()));
        final Range descendantsRange = Range
                .enclose(NAFUtils.termRangesFor(document, descendants));
        builder.set(features("pos.descroot.", rootRange, descendantsRange));
        builder.set(features("pos.descparent.", parentRange, descendantsRange));
        builder.set(features("pos.childparent.", parentRange, childRange));

        // Topological features: overlap/adjacency between selection, parent and descendants
        final List<Range> parentRanges = NAFUtils.termRangesFor(document, parent.getTargets());
        final List<Range> selectionRanges = NAFUtils.termRangesFor(document, selection);
        final List<Range> descendantRanges = NAFUtils.termRangesFor(document, descendants);
        builder.set("span.enclosed", Range.enclose(selectionRanges).overlaps(descendantRanges));
        builder.set("span.connected",
                Range.enclose(selectionRanges).connectedWith(descendantRanges));
        builder.set("span.connected.parent",
                Range.enclose(parentRanges).connectedWith(descendantRanges));
        builder.set("span.marked", marked.contains(childTerm));
        builder.set("span.marked.descendant", !Sets.intersection(marked, descendants).isEmpty());
        builder.set("span.depth", depth);

        // Lexical features of the parent's main term
        builder.set("parent.pos." + parentTerm.getMorphofeat().substring(0, 1));
        builder.set("parent.named", parentTerm.getMorphofeat().startsWith("NNP"));
        builder.set("parent.lemma." + parentTerm.getLemma().toLowerCase());
        builder.set("parent.morph." + parentTerm.getMorphofeat());

        // Lexical features of the child's main term
        builder.set("child.pos." + childTerm.getMorphofeat().substring(0, 1));
        builder.set("child.named", childTerm.getMorphofeat().startsWith("NNP"));
        builder.set("child.lemma." + childTerm.getLemma().toLowerCase());
        builder.set("child.morph." + childTerm.getMorphofeat());

        // Downward dependencies of the child's head: function + coarse POS of each dependent
        for (final Dep childDep : document.getDepsFromTerm(childHead)) {
            final Term to = childDep.getTo();
            builder.set("depdown." + childDep.getRfunc() + "."
                    + to.getMorphofeat().substring(0, 1));
        }

        // Extended dependency path (functions + lemmas) from child head to parent
        builder.set("dep." + pathex);

        // Topmost dependency function, qualified by whether we are at the seed level
        builder.set("dep." + (depth == 0 ? "top." : "nested.") + dep.getRfunc());
        if (dep.getRfunc().equals("COORD")) {
            // NOTE(review): intentionally empty - apparently a disabled COORD feature
        }

        // NOTE(review): intentionally empty - apparently a disabled connective-based feature
        if (!child.getTargets().contains(dep.getTo())) {

        }

        // SRL features: semantic roles linking the parent predicate to the child term
        for (final Predicate predicate : document.getPredicatesByTerm(parentTerm)) {
            for (final Role role : predicate.getRoles()) {
                if (role.getTerms().contains(childTerm)) {
                    builder.set("srl." + role.getSemRole());
                }
            }
        }

        return builder.build();
    }
361
362 private static Span<Term> getMinimalSpan(final KAFDocument document, final Term term) {
363
364
365
366
367 final Set<Term> terms = Sets.newHashSet(term);
368 for (final Entity entity : document.getEntitiesByTerm(term)) {
369 if (document.getTermsHead(Iterables.concat(terms, entity.getTerms())) != null) {
370 terms.addAll(entity.getTerms());
371 }
372 }
373 final List<Term> queue = Lists.newLinkedList(terms);
374 while (!queue.isEmpty()) {
375 final Term t = queue.remove(0);
376 for (final Dep dep : document.getDepsByTerm(t)) {
377 if (dep.getRfunc().equals("VC") || dep.getRfunc().equals("IM")) {
378 if (terms.add(dep.getFrom())) {
379 queue.add(dep.getFrom());
380 }
381 if (terms.add(dep.getTo())) {
382 queue.add(dep.getTo());
383 }
384 }
385 }
386 }
387 final Term head = document.getTermsHead(terms);
388 return KAFDocument.newTermSpan(Ordering.from(Term.OFFSET_COMPARATOR).sortedCopy(terms),
389 head);
390 }
391
392 private static Term getMainTerm(final KAFDocument document, final Span<Term> span) {
393 @SuppressWarnings("deprecation")
394 Term term = span.getHead();
395 if (term == null) {
396 term = document.getTermsHead(span.getTargets());
397 span.setHead(term);
398 }
399 outer: while (true) {
400 for (final Dep dep : document.getDepsFromTerm(term)) {
401 if (dep.getRfunc().equals("VC") || dep.getRfunc().equals("IM")) {
402 term = dep.getTo();
403 continue outer;
404 }
405 }
406 break;
407 }
408 return term;
409 }
410
411 private static int getSpanOffset(final Span<Term> span) {
412 int offset = Integer.MAX_VALUE;
413 for (final Term term : span.getTargets()) {
414 offset = Math.min(term.getOffset(), offset);
415 }
416 return offset;
417 }
418
    /**
     * Serializes the wrapped classifier model to the given path, for later reloading via
     * {@link #readFrom(Path)}.
     *
     * @param path the destination path
     * @throws IOException if the model cannot be written
     */
    public void writeTo(final Path path) throws IOException {
        this.classifier.writeTo(path);
    }
422
423 @Override
424 public boolean equals(final Object object) {
425 if (object == this) {
426 return true;
427 }
428 if (!(object instanceof SpanLabeller)) {
429 return false;
430 }
431 final SpanLabeller other = (SpanLabeller) object;
432 return this.classifier.equals(other.classifier);
433 }
434
    /**
     * {@inheritDoc} Consistent with {@link #equals(Object)}: based on the classifier only.
     */
    @Override
    public int hashCode() {
        return this.classifier.hashCode();
    }
439
440 @Override
441 public String toString() {
442 return "HeadExpander (" + this.classifier.toString() + ")";
443 }
444
    /**
     * Returns a new {@link Trainer} for building a {@code SpanLabeller} from gold spans.
     */
    public static Trainer train() {
        return new Trainer();
    }
448
    /**
     * Accumulates training examples and trains the classifier backing a {@link SpanLabeller}.
     * Examples are generated by replaying the expansion algorithm with an oracle predictor
     * that labels every candidate decision against the gold span.
     */
    public static final class Trainer {

        // Labelled feature vectors collected across all add() calls
        private final List<LabelledVector> trainingSet;

        // Measures the best result the recursive algorithm could reach with a perfect
        // classifier, i.e. an upper bound on achievable performance
        private final SetPrecisionRecall.Evaluator evaluator;

        private Trainer() {
            this.trainingSet = Lists.newArrayList();
            this.evaluator = SetPrecisionRecall.evaluator();
        }

        /**
         * Registers a gold example: expanding {@code head} in {@code document} should yield
         * {@code span} without ever covering {@code excluded} terms.
         *
         * @param document the NAF document the example comes from
         * @param head the head term to expand
         * @param excluded terms the expansion must not cover
         * @param span the gold span expected for the head
         */
        public void add(final KAFDocument document, final Term head,
                final Iterable<Term> excluded, final Span<Term> span) {

            final Set<Term> spanTerms = ImmutableSet.copyOf(span.getTargets());
            final Set<Term> excludedTerms = ImmutableSet.copyOf(excluded);

            // Replay the expansion with an oracle predictor: each candidate is labelled 1 iff
            // its main term belongs to the gold span and is not excluded; every decision is
            // recorded as a training vector.
            final Span<Term> outSpan = expand(document, new Predictor() {

                @Override
                public LabelledVector predict(final Vector vector) {
                    // '_index' encodes the candidate's main-term position (set in features())
                    final Term term = document.getTerms().get((int) vector.getValue("_index"));
                    final boolean included = spanTerms.contains(term)
                            && !excludedTerms.contains(term);
                    final LabelledVector result = vector.label(included ? 1 : 0);
                    Trainer.this.trainingSet.add(result);
                    return result;
                }

            }, excluded, getMinimalSpan(document, head));

            // Compare the oracle-driven output to the gold span to track the structural bound
            this.evaluator.add(ImmutableList.of(spanTerms),
                    ImmutableList.of(ImmutableSet.copyOf(outSpan.getTargets())));
        }

        /**
         * Trains the classifier on all collected examples and returns the labeller.
         *
         * @param gridSize grid-search size per weight configuration (values below 1 are
         *        clamped to 1)
         * @param analyze whether to log feature statistics, training-set performance and
         *        5-fold cross-validation results
         * @param fastClassifier true for linear LR with L1 regularization, false for a
         *        linear-kernel SVM
         * @return the trained {@code SpanLabeller}
         * @throws IOException on failure during training
         */
        public SpanLabeller end(final int gridSize, final boolean analyze,
                final boolean fastClassifier) throws IOException {

            // Optionally log the most informative features of the training set
            if (analyze && LOGGER.isInfoEnabled()) {
                LOGGER.info("Feature analysis (top 30 features):\n{}", FeatureStats.toString(
                        FeatureStats.forVectors(2, this.trainingSet, null).values(), 30));
            }

            // The recursive expansion cannot always reproduce gold spans exactly; log bound
            LOGGER.info("Maximum achievable performances on training set due to recursive "
                    + "algorithm: " + this.evaluator.getResult());

            // Build a hyper-parameter grid varying the positive-class weight
            final List<Classifier.Parameters> grid = Lists.newArrayList();
            for (final float weight : new float[] { 0.25f, 0.5f, 1.0f, 2.0f, 4.0f }) {
                if (fastClassifier) {
                    grid.addAll(Classifier.Parameters.forLinearLRLossL1Reg(2,
                            new float[] { 1f, weight }, 1f, 1f).grid(Math.max(1, gridSize), 10.0f));
                } else {
                    grid.addAll(Classifier.Parameters.forSVMLinearKernel(2,
                            new float[] { 1f, weight }, 1f).grid(Math.max(1, gridSize), 10.0f));
                }
            }
            // Select the configuration maximizing F1 on the positive label (1)
            final Classifier classifier = Classifier.train(grid, this.trainingSet,
                    ConfusionMatrix.labelComparator(PrecisionRecall.Measure.F1, 1, true), 100000);

            LOGGER.info("Best classifier parameters: {}", classifier.getParameters());

            // Optionally evaluate on the training set and via 5-fold cross-validation
            if (analyze && LOGGER.isInfoEnabled()) {
                final List<LabelledVector> trainingPredictions = classifier.predict(false,
                        this.trainingSet);
                final ConfusionMatrix matrix = LabelledVector.evaluate(this.trainingSet,
                        trainingPredictions, 2);
                LOGGER.info("Performances on training set:\n{}", matrix);
                final ConfusionMatrix crossMatrix = Classifier.crossValidate(
                        classifier.getParameters(), this.trainingSet, 5, -1);
                LOGGER.info("5-fold cross-validation performances:\n{}", crossMatrix);
            }

            return new SpanLabeller(classifier);
        }

    }
542
    /**
     * Internal callback abstracting the classification step, so that the same expansion code
     * serves both prediction (classifier-backed implementation) and training (oracle
     * implementation in {@link Trainer#add} that labels and records each vector).
     */
    private interface Predictor {

        // Returns the (possibly oracle-labelled) prediction for the given feature vector
        LabelledVector predict(final Vector vector);

    }
548
549 }