1   package eu.fbk.dkm.pikes.raid.pipeline;
2   
3   import java.io.IOException;
4   import java.nio.file.Path;
5   import java.util.Collections;
6   import java.util.Comparator;
7   import java.util.List;
8   import java.util.Set;
9   
10  import javax.annotation.Nullable;
11  
12  import com.google.common.collect.ImmutableList;
13  import com.google.common.collect.ImmutableSet;
14  import com.google.common.collect.Iterables;
15  import com.google.common.collect.Lists;
16  import com.google.common.collect.Ordering;
17  import com.google.common.collect.Sets;
18  
19  import eu.fbk.utils.core.Range;
20  import org.slf4j.Logger;
21  import org.slf4j.LoggerFactory;
22  
23  import ixa.kaflib.Dep;
24  import ixa.kaflib.Entity;
25  import ixa.kaflib.KAFDocument;
26  import ixa.kaflib.Predicate;
27  import ixa.kaflib.Predicate.Role;
28  import ixa.kaflib.Span;
29  import ixa.kaflib.Term;
30  
31  import eu.fbk.dkm.pikes.resources.NAFUtils;
32  import eu.fbk.utils.eval.ConfusionMatrix;
33  import eu.fbk.utils.eval.PrecisionRecall;
34  import eu.fbk.utils.eval.SetPrecisionRecall;
35  import eu.fbk.utils.svm.Classifier;
36  import eu.fbk.utils.svm.FeatureStats;
37  import eu.fbk.utils.svm.LabelledVector;
38  import eu.fbk.utils.svm.Vector;
39  
40  public final class SpanLabeller {
41  
42      private static final Logger LOGGER = LoggerFactory.getLogger(SpanLabeller.class);
43  
44      private final Classifier classifier;
45  
46      private Predictor predictor;
47  
48      private SpanLabeller(final Classifier classifier) {
49          this.classifier = classifier;
50          this.predictor = new Predictor() {
51  
52              @Override
53              public LabelledVector predict(final Vector vector) {
54                  return classifier.predict(false, vector);
55              }
56  
57          };
58      }
59  
60      public static SpanLabeller readFrom(final Path path) throws IOException {
61          return new SpanLabeller(Classifier.readFrom(path));
62      }
63  
    /**
     * Expands each head term into a span and returns the union of the resulting spans, sorted by
     * offset, with all heads registered on the returned span.
     *
     * @param document the NAF document providing terms and dependencies
     * @param heads the head terms to expand
     * @param excludedTerms terms that must not be absorbed into any expansion
     * @param merge if true, try to fill single-term gaps between selected ranges
     * @return the resulting span; an empty span (not null) when nothing was selected
     */
    public Span<Term> expand(final KAFDocument document, final Iterable<Term> heads,
            final Iterable<Term> excludedTerms, final boolean merge) {

        // Expand each head separately; the other heads are added to the exclusion set so that
        // the span of one head does not absorb another head's territory
        final Set<Term> headSet = ImmutableSet.copyOf(heads);
        final Set<Term> terms = Sets.newHashSet();
        for (final Term head : headSet) {
            final Iterable<Term> exclusion = Iterables.concat(excludedTerms,
                    Sets.difference(headSet, ImmutableSet.of(head)));
            terms.addAll(expand(document, head, exclusion).getTargets());
        }

        // Return an empty span (not null) if no term was selected
        if (terms.isEmpty()) {
            return KAFDocument.newTermSpan();
        }

        // Merge separate ranges, if requested and possible: a gap term is pulled in when it
        // directly follows a selected term, its incoming dependency is a coordination or
        // punctuation link, and the selection resumes within the next one or two terms
        if (merge) {
            int startIndex = Integer.MAX_VALUE;
            int endIndex = Integer.MIN_VALUE;
            for (final Term term : terms) {
                final int index = document.getTerms().indexOf(term);
                startIndex = Math.min(startIndex, index);
                endIndex = Math.max(endIndex, index);
            }
            for (int i = startIndex + 1; i < endIndex; ++i) {
                final Term term = document.getTerms().get(i);
                if (!terms.contains(term) && terms.contains(document.getTerms().get(i - 1))) {
                    final Dep dep = document.getDepToTerm(term);
                    if (dep != null) {
                        final String func = dep.getRfunc().toUpperCase();
                        // i < endIndex here, so get(i + 1) is in range; get(i + 2) is guarded
                        if ((func.contains("COORD") || func.equals("P"))
                                && (terms.contains(document.getTerms().get(i + 1)) || i + 2 <= endIndex
                                        && terms.contains(document.getTerms().get(i + 2)))) {
                            terms.add(term);
                        }
                    }
                }
            }
        }

        // Create and return the resulting span, sorted by offset, setting all the heads in it
        final Span<Term> result = KAFDocument.newTermSpan(Ordering.from(Term.OFFSET_COMPARATOR)
                .sortedCopy(terms));
        Iterables.addAll(result.getHeads(), headSet);
        return result;
    }
112 
113     public Span<Term> expand(final KAFDocument document, final Term head,
114             final Iterable<Term> excludedTerms) {
115         return expand(document, this.predictor, excludedTerms, getMinimalSpan(document, head));
116     }
117 
118     private static Span<Term> expand(final KAFDocument document, final Predictor predictor,
119             @Nullable final Iterable<Term> marked, final Span<Term> span) {
120 
121         // Build a set of marked term
122         final Set<Term> markedSet = marked == null ? ImmutableSet.of() : ImmutableSet
123                 .copyOf(marked);
124 
125         // Select terms recursively
126         final Set<Term> selection = Sets.newHashSet();
127         expandRecursive(document, predictor, markedSet, span, span, 0, selection);
128 
129         // Create and return resulting span
130         final Span<Term> result = KAFDocument.newTermSpan(Ordering.from(Term.OFFSET_COMPARATOR)
131                 .sortedCopy(selection), NAFUtils.extractHead(document, span));
132         // System.out.println(span.getStr() + " -> " + result.getStr()); // TODO
133         return result;
134     }
135 
    /**
     * Core recursive selection: adds {@code span} to {@code selection}, gathers candidate child
     * spans reachable through dependencies, and includes (and recurses into) every child
     * accepted by the predictor.
     *
     * @param document the NAF document
     * @param predictor decides inclusion of each candidate (label 1 = include)
     * @param marked externally marked terms, exposed to the predictor as features
     * @param span the span currently being expanded
     * @param root the span the whole expansion started from
     * @param depth current recursion depth (0 at the root)
     * @param selection accumulator of selected terms, modified in place
     */
    private static void expandRecursive(final KAFDocument document, final Predictor predictor,
            final Set<Term> marked, final Span<Term> span, final Span<Term> root, final int depth,
            final Set<Term> selection) {

        // Add the supplied span to the selection
        selection.addAll(span.getTargets());

        // Collect candidate child spans: dependents of the span's terms lying outside the span.
        // Terms with POS 'P' or 'C' (presumably prepositions/conjunctions — confirm against the
        // tagset) are treated as transparent and traversed rather than emitted as candidates.
        // NOTE(review): an empty POS string would also satisfy "PC".contains(...) — assumed
        // non-empty here
        final List<Span<Term>> children = Lists.newArrayList();
        for (final Term head : span.getTargets()) {
            final List<Term> queue = Lists.newLinkedList();
            queue.add(head);
            while (!queue.isEmpty()) {
                final Term term = queue.remove(0);
                final List<Dep> deps = document.getDepsFromTerm(term);
                if (!deps.isEmpty()) {
                    for (final Dep dep : deps) {
                        final Term to = dep.getTo();
                        if (!span.getTargets().contains(to)) {
                            if ("PC".contains(to.getPos())) {
                                queue.add(to);
                            } else {
                                children.add(getMinimalSpan(document, to));
                            }
                        }
                    }
                } else if (!span.getTargets().contains(term)) {
                    // Leaf reached through a traversed P/C term: emit it as a candidate itself
                    children.add(getMinimalSpan(document, term));
                }
            }
        }

        // Sort child spans by absolute offset distance w.r.t. reference span, so the nearest
        // candidates are examined first
        final int offset = getSpanOffset(span);
        Collections.sort(children, new Comparator<Span<Term>>() {

            @Override
            public int compare(final Span<Term> span1, final Span<Term> span2) {
                final int distance1 = Math.abs(offset - getSpanOffset(span1));
                final int distance2 = Math.abs(offset - getSpanOffset(span2));
                return distance1 - distance2;
            }

        });

        // For each child accepted by the predictor: also pull in the dependency terms that
        // connect the child upward to the current selection, then recurse into the child
        for (final Span<Term> child : children) {
            final Vector features = features(document, marked, selection, root, span, child, depth);
            if (predictor.predict(features).getLabel() == 1) {
                for (final Term term : child.getTargets()) {
                    for (Dep dep = document.getDepToTerm(term); dep != null
                            && !selection.contains(dep.getFrom()); dep = document.getDepToTerm(dep
                            .getFrom())) {
                        selection.add(dep.getFrom());
                    }
                }
                expandRecursive(document, predictor, marked, child, root, depth + 1, selection);
            }
        }
    }
196 
197     // private static List<String> features(final String prefix, final Term term,
198     // final String externalRefResource) {
199     // final List<String> result = Lists.newArrayList();
200     // for (final ExternalRef ref : NAFUtils.getRefs(term, externalRefResource, null)) {
201     // result.add(prefix + ref.getReference());
202     // }
203     // return result;
204     // }
205 
206     private static List<String> features(final String prefix, final Range parent, final Range child) {
207         final int dist = parent.distance(child);
208         final String direction = child.begin() > parent.begin() ? "after" : "before";
209         final String distance = dist <= 1 ? "adjacent" : dist <= 3 ? "verynear"
210                 : dist <= 6 ? "near" : "far";
211         return ImmutableList.of(prefix + distance, prefix + direction, prefix + direction + "."
212                 + distance);
213     }
214 
    /**
     * Builds the feature vector used to decide whether {@code child} should be attached while
     * expanding {@code parent}. Features cover relative span positions, POS/lemma/morphology of
     * the parent and child main terms, the dependency path between them, and SRL roles linking
     * them.
     *
     * @param document the NAF document providing terms, dependencies and SRL
     * @param marked externally marked terms (other heads / excluded terms)
     * @param selection terms selected so far in the current expansion
     * @param root the span the whole expansion started from
     * @param parent the span currently being expanded
     * @param child the candidate child span
     * @param depth current recursion depth (0 = direct child of the root span)
     * @return the assembled feature vector
     */
    private static Vector features(final KAFDocument document, final Set<Term> marked,
            final Set<Term> selection, final Span<Term> root, final Span<Term> parent,
            final Span<Term> child, final int depth) {

        // Extract main terms
        // final Term rootTerm = getMainTerm(document, root);
        final Term parentTerm = getMainTerm(document, parent);
        final Term childTerm = getMainTerm(document, child);
        final Term childHead = NAFUtils.extractHead(document, child);

        // Compute dependency path data from the child head up to the parent span, collecting the
        // intermediate 'connective' terms traversed on the way.
        // NOTE(review): getDepToTerm(childHead) is assumed non-null here (the child was produced
        // as a dependent of the parent during candidate collection) — an orphan head would NPE;
        // confirm with callers. 'path' is currently only read by commented-out features below.
        Dep dep = document.getDepToTerm(childHead);
        String path = dep.getRfunc();
        String pathex = dep.getRfunc() + "-" + dep.getTo().getMorphofeat().substring(0, 1);
        final Set<Term> connectives = Sets.newHashSet();
        while (!parent.getTargets().contains(dep.getFrom())) {
            connectives.add(dep.getFrom());
            dep = document.getDepToTerm(dep.getFrom());
            path = dep.getRfunc() + "." + path;
            pathex = dep.getRfunc() + "-" + dep.getTo().getLemma().toLowerCase() + "." + pathex;
        }

        // Allocate a builder for constructing the feature vector
        final Vector.Builder builder = Vector.builder();

        // Add document id (not used for training, only for proper CV splitting)
        builder.set("_cluster." + document.getPublic().uri);

        // Add term index (not used for training / classification; read back by the training
        // oracle in Trainer.add to recover the candidate term)
        builder.set("_index", document.getTerms().indexOf(childTerm));
        builder.set("depth" + depth);

        // Add features related to relative span positions
        final Set<Term> descendants = document.getTermsByDepAncestors(Iterables.concat(
                child.getTargets(), connectives));
        final Range rootRange = Range.enclose(NAFUtils.termRangesFor(document, root.getTargets()));
        final Range parentRange = Range.enclose(NAFUtils.termRangesFor(document,
                parent.getTargets()));
        final Range childRange = Range
                .enclose(NAFUtils.termRangesFor(document, child.getTargets()));
        final Range descendantsRange = Range
                .enclose(NAFUtils.termRangesFor(document, descendants));
        builder.set(features("pos.descroot.", rootRange, descendantsRange));
        builder.set(features("pos.descparent.", parentRange, descendantsRange));
        builder.set(features("pos.childparent.", parentRange, childRange));

        // Add features relating the candidate's descendants to the current selection
        final List<Range> parentRanges = NAFUtils.termRangesFor(document, parent.getTargets());
        final List<Range> selectionRanges = NAFUtils.termRangesFor(document, selection);
        final List<Range> descendantRanges = NAFUtils.termRangesFor(document, descendants);
        builder.set("span.enclosed", Range.enclose(selectionRanges).overlaps(descendantRanges));
        builder.set("span.connected",
                Range.enclose(selectionRanges).connectedWith(descendantRanges));
        builder.set("span.connected.parent",
                Range.enclose(parentRanges).connectedWith(descendantRanges));
        builder.set("span.marked", marked.contains(childTerm));
        builder.set("span.marked.descendant", !Sets.intersection(marked, descendants).isEmpty());
        builder.set("span.depth", depth);

        // Add root features
        // builder.set("parent.pos." + parentTerm.getPos());
        // if (parentTerm != rootTerm) {
        // builder.set("root.pos." + rootTerm.getMorphofeat().substring(0, 1));
        // builder.set("root.named", rootTerm.getMorphofeat().startsWith("NNP"));
        // }

        // Add parent features: coarse POS (first morphofeat char), NNP flag, lemma, morphology
        // builder.set("parent.pos." + parentTerm.getPos());
        builder.set("parent.pos." + parentTerm.getMorphofeat().substring(0, 1));
        builder.set("parent.named", parentTerm.getMorphofeat().startsWith("NNP"));
        builder.set("parent.lemma." + parentTerm.getLemma().toLowerCase());
        builder.set("parent.morph." + parentTerm.getMorphofeat());

        // builder.set("parent.entity", !document.getEntitiesByTerm(parentTerm).isEmpty());
        // builder.set("parent.timex",
        // !document.getTimeExsByWF(parentTerm.getWFs().get(0)).isEmpty());
        // builder.set("parent.predicate", !document.getPredicatesByTerm(parentTerm).isEmpty());
        // builder.set(features("parent.sst.", parentTerm, NAFUtils.RESOURCE_WN_SST));

        // Add child features, mirroring the parent features above
        // builder.set("child.pos." + childTerm.getPos());
        builder.set("child.pos." + childTerm.getMorphofeat().substring(0, 1));
        builder.set("child.named", childTerm.getMorphofeat().startsWith("NNP"));
        builder.set("child.lemma." + childTerm.getLemma().toLowerCase());
        builder.set("child.morph." + childTerm.getMorphofeat());
        // builder.set("child.entity", !document.getEntitiesByTerm(childTerm).isEmpty());
        // builder.set("child.timex",
        // !document.getTimeExsByWF(childTerm.getWFs().get(0)).isEmpty());
        // builder.set("child.predicate", !document.getPredicatesByTerm(childTerm).isEmpty());
        // builder.set(features("child.sst.", childTerm, NAFUtils.RESOURCE_WN_SST));

        // for (final Entity entity : document.getEntitiesByTerm(childTerm)) {
        // if (entity.getType() != null) {
        // builder.set("child.entity." + entity.getType().toLowerCase());
        // }
        // }
        // for (final Entity entity : document.getEntitiesBySent(childTerm.getSent())) {
        // if (entity.getType() != null
        // && entity.isNamed()
        // && !Sets.intersection(ImmutableSet.copyOf(entity.getTerms()), descendants)
        // .isEmpty()) {
        // builder.set("child.entity");
        // builder.set("child.entity." + entity.getType().toLowerCase());
        // }
        // }

        // Add features for the dependencies leaving the child head
        for (final Dep childDep : document.getDepsFromTerm(childHead)) {
            final Term to = childDep.getTo();
            // final boolean before = to.getOffset() < childHead.getOffset();
            // builder.set("depdown." + childDep.getRfunc() + "." + (before ? "before" :
            // "after"));
            builder.set("depdown." + childDep.getRfunc() + "."
                    + to.getMorphofeat().substring(0, 1));
        }
        // builder.set("dep.path." + path);
        // builder.set("dep.pospath." + path + "-" + childTerm.getPos());
        // builder.set("dep.pospath." + parentTerm.getPos() + "-" + path + "-" +
        // childTerm.getPos());
        // builder.set("deppos." + dep.getRfunc() + "." + childTerm.getMorphofeat());
        builder.set("dep." + pathex);

        // Dependency function, distinguished by whether we are at the top level or nested
        builder.set("dep." + (depth == 0 ? "top." : "nested.") + dep.getRfunc());
        if (dep.getRfunc().equals("COORD")) {
            // System.out.println("*** " + parent.getStr() + " --> " + child.getStr());
        }
        // builder.set("dep.funcpos." + (childIndex > parentIndex ? "after." : "before.")
        // + dep.getRfunc() + "." + childTerm.getMorphofeat());
        if (!child.getTargets().contains(dep.getTo())) {
            // builder.set("dep.funcposconn." + dep.getRfunc() + "." + childTerm.getMorphofeat()
            // + "." + dep.getTo().getLemma().toLowerCase());
            // builder.set("dep.funcposconn." + (childIndex > parentIndex ? "after." : "before.")
            // + dep.getRfunc() + "." + childTerm.getMorphofeat() + "."
            // + dep.getTo().getLemma().toLowerCase());
        }

        // Emit SRL features: semantic roles of predicates on the parent term covering the child
        for (final Predicate predicate : document.getPredicatesByTerm(parentTerm)) {
            for (final Role role : predicate.getRoles()) {
                if (role.getTerms().contains(childTerm)) {
                    builder.set("srl." + role.getSemRole());
                }
            }
        }

        // Finalize the construction and return the resulting feature vector
        return builder.build();
    }
361 
    /**
     * Returns the minimal span for a term: the term itself, the terms of entities covering it
     * (when the union still has a single dependency head), and all terms reachable via VC/IM
     * dependency links.
     *
     * @param document the NAF document
     * @param term the seed term
     * @return the offset-sorted minimal span, with its head set
     */
    private static Span<Term> getMinimalSpan(final KAFDocument document, final Term term) {

        // The minimal span for a term includes all the terms of a named entity covering the input
        // term, as well as all terms reachable through IM/VC links (this allows keeping together
        // verbal expressions such as would like, going to do...)
        final Set<Term> terms = Sets.newHashSet(term);
        for (final Entity entity : document.getEntitiesByTerm(term)) {
            // Only merge entity terms if the union still forms a single dependency tree
            if (document.getTermsHead(Iterables.concat(terms, entity.getTerms())) != null) {
                terms.addAll(entity.getTerms());
            }
        }
        // Breadth-first closure over VC/IM links, following them in both directions
        final List<Term> queue = Lists.newLinkedList(terms);
        while (!queue.isEmpty()) {
            final Term t = queue.remove(0);
            for (final Dep dep : document.getDepsByTerm(t)) {
                if (dep.getRfunc().equals("VC") || dep.getRfunc().equals("IM")) {
                    if (terms.add(dep.getFrom())) {
                        queue.add(dep.getFrom());
                    }
                    if (terms.add(dep.getTo())) {
                        queue.add(dep.getTo());
                    }
                }
            }
        }
        final Term head = document.getTermsHead(terms);
        return KAFDocument.newTermSpan(Ordering.from(Term.OFFSET_COMPARATOR).sortedCopy(terms),
                head);
    }
391 
392     private static Term getMainTerm(final KAFDocument document, final Span<Term> span) {
393         @SuppressWarnings("deprecation")
394         Term term = span.getHead();
395         if (term == null) {
396             term = document.getTermsHead(span.getTargets());
397             span.setHead(term);
398         }
399         outer: while (true) {
400             for (final Dep dep : document.getDepsFromTerm(term)) {
401                 if (dep.getRfunc().equals("VC") || dep.getRfunc().equals("IM")) {
402                     term = dep.getTo();
403                     continue outer;
404                 }
405             }
406             break;
407         }
408         return term;
409     }
410 
411     private static int getSpanOffset(final Span<Term> span) {
412         int offset = Integer.MAX_VALUE;
413         for (final Term term : span.getTargets()) {
414             offset = Math.min(term.getOffset(), offset);
415         }
416         return offset;
417     }
418 
    /**
     * Persists the wrapped classifier model to the given path, readable back via
     * {@link #readFrom(Path)}.
     *
     * @param path the destination file
     * @throws IOException on write failure
     */
    public void writeTo(final Path path) throws IOException {
        this.classifier.writeTo(path);
    }
422 
423     @Override
424     public boolean equals(final Object object) {
425         if (object == this) {
426             return true;
427         }
428         if (!(object instanceof SpanLabeller)) {
429             return false;
430         }
431         final SpanLabeller other = (SpanLabeller) object;
432         return this.classifier.equals(other.classifier);
433     }
434 
    /** Hash code delegated to the wrapped classifier, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return this.classifier.hashCode();
    }
439 
440     @Override
441     public String toString() {
442         return "HeadExpander (" + this.classifier.toString() + ")";
443     }
444 
    /**
     * Creates a new {@link Trainer} for building a {@code SpanLabeller} from labelled examples.
     */
    public static Trainer train() {
        return new Trainer();
    }
448 
    /**
     * Collects labelled training examples and trains the classifier backing a
     * {@link SpanLabeller}: call {@link #add} once per gold (head, span) example, then
     * {@link #end} to train and obtain the labeller.
     */
    public static final class Trainer {

        // Labelled feature vectors collected while replaying the gold spans
        private final List<LabelledVector> trainingSet;

        // Measures the best result achievable by the recursive candidate-selection algorithm
        private final SetPrecisionRecall.Evaluator evaluator;

        private Trainer() {
            this.trainingSet = Lists.newArrayList();
            this.evaluator = SetPrecisionRecall.evaluator();
        }

        /**
         * Registers a training example: expanding {@code head} should yield {@code span},
         * keeping out the {@code excluded} terms.
         */
        public void add(final KAFDocument document, final Term head,
                final Iterable<Term> excluded, final Span<Term> span) {

            final Set<Term> spanTerms = ImmutableSet.copyOf(span.getTargets());
            final Set<Term> excludedTerms = ImmutableSet.copyOf(excluded);

            // Replay the expansion with an oracle predictor that labels each candidate against
            // the gold span, recording every labelled vector as a training example
            final Span<Term> outSpan = expand(document, new Predictor() {

                @Override
                public LabelledVector predict(final Vector vector) {
                    // "_index" carries the candidate term's position in the document term list
                    final Term term = document.getTerms().get((int) vector.getValue("_index"));
                    final boolean included = spanTerms.contains(term)
                            && !excludedTerms.contains(term);
                    final LabelledVector result = vector.label(included ? 1 : 0);
                    Trainer.this.trainingSet.add(result);
                    return result;
                }

            }, excluded, getMinimalSpan(document, head));

            // Track the ceiling performance a perfect classifier could reach on this example
            this.evaluator.add(ImmutableList.of(spanTerms),
                    ImmutableList.of(ImmutableSet.copyOf(outSpan.getTargets())));
        }

        /**
         * Trains the classifier over the collected examples and returns the resulting labeller.
         *
         * @param gridSize size of the hyper-parameter grid explored per class weight (min 1)
         * @param analyze if true, log feature statistics and training/CV performances
         * @param fastClassifier if true, train linear logistic regression; otherwise linear SVM
         * @throws IOException on classifier training failure
         */
        public SpanLabeller end(final int gridSize, final boolean analyze,
                final boolean fastClassifier) throws IOException {

            // Emit feature stats if enabled
            if (analyze && LOGGER.isInfoEnabled()) {
                LOGGER.info("Feature analysis (top 30 features):\n{}", FeatureStats.toString(
                        FeatureStats.forVectors(2, this.trainingSet, null).values(), 30));
            }

            // Log the performance penalty caused by the candidate selection algorithm
            LOGGER.info("Maximum achievable performances on training set due to recursive "
                    + "algorithm: " + this.evaluator.getResult());

            // Perform training considering a grid of parameters of the size specified (min 1),
            // one sub-grid per positive-class weight
            // final List<Classifier.Parameters> grid =
            // Classifier.Parameters.forLinearLRLossL1Reg(2,
            // new float[] { 1f, 1f }, 1f, 1f).grid(Math.max(1, gridSize), 10.0f);
            final List<Classifier.Parameters> grid = Lists.newArrayList();
            for (final float weight : new float[] { 0.25f, 0.5f, 1.0f, 2.0f, 4.0f }) {
                // grid.addAll(Classifier.Parameters.forSVMPolyKernel(2, new float[] { 1f, weight
                // },
                // 1f, 1f, 0.0f, 3).grid(Math.max(1, gridSize), 10.0f));
                if (fastClassifier) {
                    grid.addAll(Classifier.Parameters.forLinearLRLossL1Reg(2,
                            new float[] { 1f, weight }, 1f, 1f).grid(Math.max(1, gridSize), 10.0f));
                } else {
                    grid.addAll(Classifier.Parameters.forSVMLinearKernel(2,
                            new float[] { 1f, weight }, 1f).grid(Math.max(1, gridSize), 10.0f));
                }
            }
            final Classifier classifier = Classifier.train(grid, this.trainingSet,
                    ConfusionMatrix.labelComparator(PrecisionRecall.Measure.F1, 1, true), 100000);

            // THE SVM BELOW WAS ORIGINALLY USED
            // final Classifier.Parameters parameters = Classifier.Parameters.forSVMRBFKernel(2,
            // new float[] { 1f, 1f }, 1f, .1f);
            // final Classifier classifier = Classifier.train(parameters, this.trainingSet);

            // Log parameters of the best classifier
            LOGGER.info("Best classifier parameters: {}", classifier.getParameters());

            // Perform cross-validation and emit some performance statistics, if enabled
            if (analyze && LOGGER.isInfoEnabled()) {
                final List<LabelledVector> trainingPredictions = classifier.predict(false,
                        this.trainingSet);
                final ConfusionMatrix matrix = LabelledVector.evaluate(this.trainingSet,
                        trainingPredictions, 2);
                LOGGER.info("Performances on training set:\n{}", matrix);
                final ConfusionMatrix crossMatrix = Classifier.crossValidate(
                        classifier.getParameters(), this.trainingSet, 5, -1);
                LOGGER.info("5-fold cross-validation performances:\n{}", crossMatrix);
            }

            // Build and return the created span labeller
            return new SpanLabeller(classifier);
        }

    }
542 
    /**
     * Strategy deciding whether a candidate span is part of an expansion. Implemented by the
     * trained classifier at prediction time, and by a gold-labelling oracle during training.
     */
    private interface Predictor {

        /** Labels the supplied feature vector; label 1 means "include the candidate span". */
        LabelledVector predict(final Vector vector);

    }
548 
549 }