package eu.fbk.dkm.pikes.raid.pipeline;

import java.io.IOException;
import java.io.Reader;
import java.io.Writer;
import java.nio.file.Files;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;

import javax.annotation.Nullable;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.MapMaker;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ixa.kaflib.Dep;
import ixa.kaflib.Dep.Path;
import ixa.kaflib.ExternalRef;
import ixa.kaflib.KAFDocument;
import ixa.kaflib.Predicate;
import ixa.kaflib.Predicate.Role;
import ixa.kaflib.Term;

import eu.fbk.dkm.pikes.resources.NAFUtils;
import eu.fbk.dkm.pikes.resources.WordNet;
import eu.fbk.utils.core.Graph;
import eu.fbk.utils.eval.ConfusionMatrix;
import eu.fbk.utils.eval.PrecisionRecall;
import eu.fbk.utils.svm.Classifier;
import eu.fbk.utils.svm.FeatureStats;
import eu.fbk.utils.svm.LabelledVector;
import eu.fbk.utils.svm.Util;
import eu.fbk.utils.svm.Vector;

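/**
 * Classifier-based labeller that, given a root term in a NAF document, selects the terms that
 * should be linked to it, returning a probability estimate for each candidate term.
 * <p>
 * Illustrative usage (the model path and the {@code document} / {@code rootTerm} variables are
 * placeholders, not part of this class):
 * </p>
 *
 * <pre>
 * LinkLabeller labeller = LinkLabeller.readFrom(java.nio.file.Paths.get("link-labeller-model"));
 * Map&lt;Term, Float&gt; scores = labeller.label(document, rootTerm);
 * </pre>
 */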
public class LinkLabeller {

    private static final Logger LOGGER = LoggerFactory.getLogger(LinkLabeller.class);

    private static final Map<KAFDocument, Map<Integer, Graph<Term, String>>> GRAPH_CACHE = new MapMaker()
            .weakKeys().makeMap();

    private static final int SELECTED = 1;

    private static final int UNSELECTED = 0;

    private final Classifier classifier;

    @Nullable
    private final Set<String> posPrefixes;

    private LinkLabeller(final Classifier classifier, @Nullable final Iterable<String> posPrefixes) {
        this.classifier = classifier;
        this.posPrefixes = posPrefixes == null ? null : ImmutableSet.copyOf(posPrefixes);
    }

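    /**
     * Reads a {@code LinkLabeller} previously saved via {@link #writeTo(java.nio.file.Path)}.
     * The path is opened through {@code Util.openVFS} and is expected to contain the serialized
     * classifier under {@code model} and the optional POS prefixes under {@code properties}.
     *
     * @param path the path to read the labeller from
     * @return the deserialized {@code LinkLabeller}
     * @throws IOException on failure
     */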
    public static LinkLabeller readFrom(final java.nio.file.Path path) throws IOException {
        final java.nio.file.Path p = Util.openVFS(path, false);
        try {
            final Classifier classifier = Classifier.readFrom(p.resolve("model"));
            final Properties properties = new Properties();
            try (Reader in = Files.newBufferedReader(p.resolve("properties"))) {
                properties.load(in);
            }
            final String pos = properties.getProperty("pos");
            final Set<String> posPrefixes = pos == null ? null : ImmutableSet.copyOf(pos
                    .split(","));
            return new LinkLabeller(classifier, posPrefixes);
        } finally {
            Util.closeVFS(p);
        }
    }

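    /**
     * Scores the candidate terms for the supplied root term. Each candidate in the sentence of
     * {@code root} is mapped to the estimated probability of being linked to the root, clamped
     * so that selected candidates receive at least 0.5 and unselected ones less than 0.5.
     *
     * @param document the NAF document containing the root term
     * @param root the root term whose links should be labelled
     * @return a map from candidate terms to selection probabilities
     */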
    public Map<Term, Float> label(final KAFDocument document, final Term root) {

        // Identify candidate terms. For each of them, compute features and apply the classifier
        // to determine whether the candidate should be selected and with what probability.
        // Return a map from each candidate term to its probability of being selected, clamped
        // so that selected terms get >= 0.5 and unselected terms get < 0.5
        final Map<Term, Float> map = Maps.newHashMap();
        for (final Term term : candidates(document, root, this.posPrefixes)) {
            final Vector vector = features(document, root, term);
            final LabelledVector outVector = this.classifier.predict(true, vector);
            final float p = outVector.getProbability(SELECTED);
            if (outVector.getLabel() == SELECTED) {
                map.put(term, p >= 0.5f ? p : 0.5f);
            } else {
                map.put(term, p < 0.5f ? p : 0.5f);
            }
        }
        return map;
    }

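    /**
     * Writes this {@code LinkLabeller} to the path specified, storing the classifier under
     * {@code model} and the POS prefixes, if configured, under {@code properties}.
     *
     * @param path the path to write the labeller to
     * @throws IOException on failure
     */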
    public void writeTo(final java.nio.file.Path path) throws IOException {
        final java.nio.file.Path p = Util.openVFS(path, false);
        try {
            this.classifier.writeTo(p.resolve("model"));
            final Properties properties = new Properties();
            if (this.posPrefixes != null) {
                properties.setProperty("pos", String.join(",", this.posPrefixes));
            }
            try (Writer out = Files.newBufferedWriter(p.resolve("properties"))) {
                properties.store(out, null);
            }
        } finally {
            Util.closeVFS(p);
        }
    }

    @Override
    public boolean equals(final Object object) {
        if (object == this) {
            return true;
        }
        if (!(object instanceof LinkLabeller)) {
            return false;
        }
        final LinkLabeller other = (LinkLabeller) object;
        return this.classifier.equals(other.classifier);
    }

    @Override
    public int hashCode() {
        return this.classifier.hashCode();
    }

    @Override
    public String toString() {
        return "LinkLabeller (" + this.classifier.toString() + ")";
    }

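    /**
     * Computes the candidate terms that may be linked to {@code root}: the descendants of the
     * root and of its non-coordinated ancestors, pruned at modifier (NMOD/AMOD) edges and at
     * terms directly coordinated (COORD/CONJ) with the root. Candidates are further restricted
     * to the sentence of the root, to the POS prefixes supplied (if any), to SRL predicate
     * heads in the case of verbs, and to terms containing at least one letter.
     */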
    private static List<Term> candidates(final KAFDocument document, final Term root,
            @Nullable final Set<String> posPrefixes) {

        // We build a larger set of candidate terms, always excluding modifiers (AMOD/NMOD)
        final Set<Term> nonModifierTerms = Sets.newHashSet();

        // Extract all the terms dominated by root that are not coordinated (COORD/CONJ) with
        // root, excluding modifiers (AMOD/NMOD) at any level
        for (final Dep dep : document.getDepsFromTerm(root)) {
            if (!"COORD".equals(dep.getRfunc()) && !"CONJ".equals(dep.getRfunc())) {
                candidatesHelper(document, dep.getTo(), nonModifierTerms);
            }
        }

        // Add all the ancestors of root that are not coordinated (COORD/CONJ) with root. Then,
        // add all the descendants of those ancestors, excluding modifiers (AMOD/NMOD) at any level
        for (Dep dep = document.getDepToTerm(root); dep != null; dep = document.getDepToTerm(dep
                .getFrom())) {
            if ("COORD".equals(dep.getRfunc()) || "CONJ".equals(dep.getRfunc())) {
                continue; // exclude terms coordinated with root
            }
            nonModifierTerms.add(dep.getFrom());
            final List<Term> queue = Lists.newArrayList(dep.getFrom());
            while (!queue.isEmpty()) {
                final Term t = queue.remove(0);
                for (final Dep dep2 : document.getDepsFromTerm(t)) {
                    if (dep2.getTo().equals(dep.getTo())) {
                        continue;
                    } else if (!"COORD".equals(dep2.getRfunc()) && !"CONJ".equals(dep2.getRfunc())) {
                        candidatesHelper(document, dep2.getTo(), nonModifierTerms);
                    }
                }
            }
        }

        // Starting from the computed set, we build the final set of candidate terms by
        // considering only terms that match certain POS tags (extended to cover demonstrative
        // pronouns) and contain at least one letter; in case of verbs, we keep only the SRL
        // predicate head
        final List<Term> candidates = Lists.newArrayList();
        final java.util.function.Predicate<Term> matcher = posPrefixes == null ? null //
                : NAFUtils.matchExtendedPos(document, posPrefixes.toArray(new String[0]));
        for (final Term term : document.getTermsBySent(root.getSent())) {
            if (!term.equals(root)
                    && nonModifierTerms.contains(term)
                    && (matcher == null || matcher.test(term))
                    && (!"V".equals(term.getPos()) || term.equals(NAFUtils.syntacticToSRLHead(
                            document, term)))) {
                final String s = term.getStr();
                for (int i = 0; i < s.length(); ++i) {
                    if (Character.isLetter(s.charAt(i))) {
                        candidates.add(term);
                        break;
                    }
                }
            }
        }
        return candidates;
    }

    private static void candidatesHelper(final KAFDocument document, final Term term,
            final Set<Term> nonModifierTerms) {

        // Recursively add term and all its descendants, stopping when an NMOD/AMOD link is found
        nonModifierTerms.add(term);
        for (final Dep dep : document.getDepsFromTerm(term)) {
            final String func = dep.getRfunc();
            if (!"NMOD".equals(func) && !"AMOD".equals(func)) {
                candidatesHelper(document, dep.getTo(), nonModifierTerms);
            }
        }
    }

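    /**
     * Builds the feature vector for a (root, candidate node) pair, combining POS, dependency
     * function, lemma, WordNet (hypernym synsets and supersense), verb form, named-entity and
     * BBN features of the two terms with dependency-path and SRL-path features connecting them.
     */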
    private static Vector features(final KAFDocument document, final Term root, final Term node) {

        // Allocate a builder for constructing the feature vector
        final Vector.Builder builder = Vector.builder();

        // Add document ID (not used for training, just for proper CV splitting)
        builder.set("_cluster." + document.getPublic().uri);

        final String rootSST = getReference(root, NAFUtils.RESOURCE_WN_SST, "none") //
                .replaceFirst("[^.]*\\.", "");
        final String nodeSST = getReference(node, NAFUtils.RESOURCE_WN_SST, "none") //
                .replaceFirst("[^.]*\\.", "");
        final Dep rootDep = document.getDepToTerm(root);
        final Dep nodeDep = document.getDepToTerm(node);
        final Boolean rootActive = NAFUtils.isActiveForm(document, root);

        // Add root features
        builder.set("root.pos." + root.getMorphofeat()); // JM
        builder.set("root.dep." + (rootDep == null ? "none" : rootDep.getRfunc())); // JM
        for (final ExternalRef ref : NAFUtils.getRefs(root, NAFUtils.RESOURCE_WN_SYNSET, null)) {
            // builder.set("root.wn." + ref.getReference());
            for (final String synsetID : WordNet.getHypernyms(ref.getReference(), true)) {
                builder.set("root.wn." + synsetID);
            }
        }
        // builder.set("root.word." + root.getStr().toLowerCase()); // JM
        builder.set("root.lemma." + root.getLemma().toLowerCase()); // JM
        builder.set("root.form."
                + (rootActive == null ? "none" : rootActive ? "active" : "passive"));
        builder.set("root.sst." + rootSST);

        // Experimental, disabled root features: SUMO and YAGO types
        // for (final ExternalRef ref : NAFUtils.getRefs(root, NAFUtils.RESOURCE_SUMO, null)) {
        // builder.set("root.sumo." + ref.getReference());
        // for (final URI uri : Sumo.getSuperClasses(new URIImpl(Sumo.NAMESPACE
        // + ref.getReference()))) {
        // builder.set("root.sumo." + uri.getLocalName());
        // }
        // }
        // for (final ExternalRef ref : NAFUtils.getRefs(root, NAFUtils.RESOURCE_YAGO, null)) {
        // builder.set("root.yago." + ref.getReference());
        // for (final URI uri : YagoTaxonomy.getSuperClasses(new URIImpl(YagoTaxonomy.NAMESPACE
        // + ref.getReference()), true)) {
        // builder.set("root.yago." + uri.getLocalName());
        // }
        // }

        // Add node features
        builder.set("node.pos." + node.getMorphofeat()); // JM
        builder.set("node.dep." + (nodeDep == null ? "none" : nodeDep.getRfunc())); // JM
        for (final ExternalRef ref : NAFUtils.getRefs(node, NAFUtils.RESOURCE_WN_SYNSET, null)) {
            // builder.set("node.wn." + ref.getReference());
            for (final String synsetID : WordNet.getHypernyms(ref.getReference(), true)) {
                builder.set("node.wn." + synsetID);
            }
        }
        // builder.set("node.word." + node.getStr().toLowerCase()); // JM
        builder.set("node.lemma." + node.getLemma().toLowerCase());
        builder.set("node.named", node.getMorphofeat().startsWith("NNP"));
        builder.set("node.bbn.", getReferences(node, NAFUtils.RESOURCE_BBN));
        builder.set("node.sst." + nodeSST);

        // Experimental, disabled node features: SUMO and YAGO types
        // for (final ExternalRef ref : NAFUtils.getRefs(node, NAFUtils.RESOURCE_SUMO, null)) {
        // builder.set("node.sumo." + ref.getReference());
        // for (final URI uri : Sumo.getSuperClasses(new URIImpl(Sumo.NAMESPACE
        // + ref.getReference()))) {
        // builder.set("node.sumo." + uri.getLocalName());
        // }
        // }
        // for (final ExternalRef ref : NAFUtils.getRefs(node, NAFUtils.RESOURCE_YAGO, null)) {
        // builder.set("node.yago." + ref.getReference());
        // for (final URI uri : YagoTaxonomy.getSuperClasses(new URIImpl(YagoTaxonomy.NAMESPACE
        // + ref.getReference()), true)) {
        // builder.set("node.yago." + uri.getLocalName());
        // }
        // }

        // Add context features (their impact is minimal, hence we disabled them)
        // final int index = document.getTerms().indexOf(node);
        // final Term left = index == 0 ? null : document.getTerms().get(index - 1);
        // final Term right = index == document.getTerms().size() - 1 ? null : document.getTerms()
        // .get(index + 1);
        // builder.set("left.pos." + (left == null ? "none" : left.getMorphofeat())); // JM
        // builder.set("left.word." + (left == null ? "none" : left.getStr().toLowerCase())); // JM
        // builder.set("right.pos." + (right == null ? "none" : right.getMorphofeat())); // JM
        // builder.set("right.word." + (right == null ? "none" : right.getStr().toLowerCase())); // JM

        // Add path features
        final Dep.Path path = Dep.Path.create(root, node, document);
        builder.set("path." + (path == null ? "none" : getSimplifiedPathLabel(path)));
        for (int i = 0; i < 10; ++i) {
            if (path != null && path.length() <= i) { // path may be null (handled above)
                builder.set("path.lenless." + i);
            }
        }

        // Add SRL features
        final Graph<Term, String> graph = getSRLGraph(document, root.getSent());
        for (final Graph.Path<Term, String> srlPath : graph.getPaths(root, node, false, 2)) {
            final StringBuilder b = new StringBuilder("srl.path.");
            for (int i = 0; i < srlPath.length(); ++i) {
                final Term term = srlPath.getVertices().get(i);
                final Graph.Edge<Term, String> edge = srlPath.getEdges().get(i);
                if (i > 0 && term.getMorphofeat().startsWith("VB")
                        && !document.getPredicatesByTerm(term).isEmpty()) {
                    b.append("_").append(term.getLemma().toLowerCase());
                }
                b.append(edge.getSource().equals(term) ? "_F_" : "_B_").append(edge.getLabel());
            }
            builder.set(b.toString());
        }

        // Finalize the construction and return the resulting feature vector
        return builder.build();
    }

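    /**
     * Returns a compact label for a dependency path, concatenating the dependency functions of
     * its edges (coordination edges excluded) together with their direction in the parse tree.
     */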
    private static String getSimplifiedPathLabel(final Path path) {

        // Filter out coordination edges (COORD/CONJ) and keep track of remaining edge direction
        // (U = up, D = down in the tree)
        final StringBuilder builder = new StringBuilder();
        Term term = path.getTerms().get(0);
        for (int i = 0; i < path.getDeps().size(); ++i) {
            final Dep dep = path.getDeps().get(i);
            final String func = dep.getRfunc().toLowerCase();
            if ("coord".equals(func) || "conj".equals(func)) {
                continue;
            }
            builder.append(func);
            if (term.equals(dep.getFrom())) {
                builder.append('D');
                term = dep.getTo();
            } else {
                builder.append('U');
                term = dep.getFrom();
            }
        }
        return builder.toString();
    }

    @Nullable
    private static String getReference(final Object annotation, final String resource,
            @Nullable final String defaultValue) {
        final List<String> refs = getReferences(annotation, resource);
        if (refs.isEmpty()) {
            return defaultValue;
        } else if (refs.size() == 1) {
            return refs.get(0);
        } else {
            Collections.sort(refs);
            return refs.get(0);
        }
    }

    private static List<String> getReferences(final Object annotation, final String resource) {
        final List<String> result = Lists.newArrayList();
        for (final ExternalRef ref : NAFUtils.getRefs(annotation, resource, null)) {
            result.add(ref.getReference());
        }
        return result;
    }

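    /**
     * Returns the SRL graph used for path features, linking each predicate head in the document
     * to the heads of its role fillers with edges labelled by semantic role. Built graphs are
     * cached per document and sentence.
     */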
    private static Graph<Term, String> getSRLGraph(final KAFDocument document, final int sentence) {

        // Lookup an existing graph in the cache
        synchronized (GRAPH_CACHE) {
            final Map<Integer, Graph<Term, String>> map = GRAPH_CACHE.get(document);
            if (map != null) {
                final Graph<Term, String> graph = map.get(sentence);
                if (graph != null) {
                    return graph;
                }
            }
        }

        // Build graph
        final Graph.Builder<Term, String> builder = Graph.builder();
        for (final Predicate predicate : document.getPredicates()) {
            final Term predicateHead = NAFUtils.extractHead(document, predicate.getSpan());
            if (predicateHead != null) {
                for (final Role role : predicate.getRoles()) {
                    for (final Term argHead : NAFUtils.extractHeads(document, null, role
                            .getTerms(), NAFUtils.matchExtendedPos(document, "NN", "PRP", "JJP",
                            "DTP", "WP", "VB"))) {
                        builder.addEdges(predicateHead, argHead, role.getSemRole());
                    }
                }
            }
        }

        // Finalize graph construction
        final Graph<Term, String> graph = builder.build();
        synchronized (GRAPH_CACHE) {
            Map<Integer, Graph<Term, String>> map = GRAPH_CACHE.get(document);
            if (map == null) {
                map = Maps.newHashMap();
                GRAPH_CACHE.put(document, map);
            }
            map.put(sentence, graph);
        }
        return graph;
    }

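    /**
     * Returns a {@code Trainer} for building a {@code LinkLabeller}, optionally restricting
     * candidate terms to the (extended) POS prefixes supplied.
     */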
    public static Trainer train(final String... posPrefixes) {
        final Set<String> posPrefixesSet = posPrefixes == null || posPrefixes.length == 0 ? null
                : ImmutableSet.copyOf(posPrefixes);
        return new Trainer(posPrefixesSet);
    }

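    /**
     * Collects training examples and trains the classifier wrapped by a {@code LinkLabeller}.
     * A sketch of the intended workflow (the POS prefixes, grid size and NAF variables below
     * are illustrative placeholders):
     *
     * <pre>
     * LinkLabeller.Trainer trainer = LinkLabeller.train("NN", "PRP");
     * trainer.add(document, rootTerm, selectedTerms); // repeat for each training example
     * LinkLabeller labeller = trainer.end(1, false);
     * </pre>
     */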
    public static final class Trainer {

        private final List<LabelledVector> trainingSet;

        private final PrecisionRecall.Evaluator coverageEvaluator;

        @Nullable
        private final Set<String> posPrefixes;

        private Trainer(@Nullable final Set<String> posPrefixes) {
            this.trainingSet = Lists.newArrayList();
            this.posPrefixes = posPrefixes;
            this.coverageEvaluator = PrecisionRecall.evaluator();
        }

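        /**
         * Adds a training example, consisting of a root term and the terms that should be
         * linked to it. Gold terms that are not generated as candidates are tracked in order
         * to report the coverage of the candidate selection criteria.
         */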
        public void add(final KAFDocument document, final Term root, final Iterable<Term> selected) {

            int numSelected = 0;
            final List<Term> candidates = candidates(document, root, this.posPrefixes);
            for (final Term term : candidates) {
                final int label = Iterables.contains(selected, term) ? SELECTED : UNSELECTED;
                numSelected += label == SELECTED ? 1 : 0;
                final LabelledVector vector = features(document, root, term).label(label);
                this.trainingSet.add(vector);
            }

            final int goldSelected = Iterables.size(selected);
            if (goldSelected > 0) {
                this.coverageEvaluator.add(numSelected, 0, goldSelected - numSelected);

                if (numSelected < goldSelected && LOGGER.isDebugEnabled()) {
                    // Collect gold terms that were not proposed as candidates
                    final List<Term> missing = Lists.newArrayList();
                    for (final Term term : selected) {
                        if (!candidates.contains(term)) {
                            missing.add(term);
                        }
                    }
                    LOGGER.debug("Missing candidates: "
                            + missing
                            + " (root: "
                            + root.toString()
                            + ", sentence: "
                            + KAFDocument.newTermSpan(document.getTermsBySent(root.getSent()))
                                    .getStr() + ")");
                }
            }
        }

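        /**
         * Completes training and returns the resulting {@code LinkLabeller}. Training explores
         * a grid of L1-regularized, logistic-loss linear classifier parameters whose size is
         * controlled by {@code gridSize}; if {@code analyze} is true, feature statistics,
         * training-set performance and 5-fold cross-validation results are also logged.
         */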
        public LinkLabeller end(final int gridSize, final boolean analyze) throws IOException {

            // Emit feature stats if enabled
            if (analyze && LOGGER.isInfoEnabled()) {
                LOGGER.info("Feature analysis (top 30 features):\n{}", FeatureStats.toString(
                        FeatureStats.forVectors(2, this.trainingSet, null).values(), 30));
            }

            // Log the performance penalty caused by the candidate selection algorithm
            LOGGER.info("Maximum achievable performances on training set due to candidate "
                    + "selection criteria: " + this.coverageEvaluator.getResult());

            // Perform training over a grid of parameter settings of the specified size (min 1)
            final List<Classifier.Parameters> grid = Lists.newArrayList();
            for (final float weight : new float[] { 0.25f, 0.5f, 1.0f, 2.0f, 4.0f }) {
                // SVM classifiers produce better CV results on the training set, but overfit on test
                // grid.addAll(Classifier.Parameters.forSVMPolyKernel(2, new float[] { 1f, weight
                // }, 1f, 1f, 0.0f, 3).grid(Math.max(1, gridSize), 10.0f));
                // grid.addAll(Classifier.Parameters.forSVMLinearKernel(2,
                // new float[] { 1f, weight }, 1f).grid(Math.max(1, gridSize), 10.0f));
                grid.addAll(Classifier.Parameters.forLinearLRLossL1Reg(2,
                        new float[] { 1f, weight }, 1f, 1f).grid(Math.max(1, gridSize), 10.0f));
            }
            final Classifier classifier = Classifier.train(grid, this.trainingSet,
                    ConfusionMatrix.labelComparator(PrecisionRecall.Measure.F1, 1, true), 100000);

            // Log parameters of the best classifier
            LOGGER.info("Best classifier parameters: {}", classifier.getParameters());

            // Perform cross-validation and emit some performance statistics, if enabled
            if (analyze && LOGGER.isInfoEnabled()) {
                final List<LabelledVector> trainingPredictions = classifier.predict(false,
                        this.trainingSet);
                final ConfusionMatrix matrix = LabelledVector.evaluate(this.trainingSet,
                        trainingPredictions, 2);
                LOGGER.info("Performances on training set:\n{}", matrix);
                final ConfusionMatrix crossMatrix = Classifier.crossValidate(
                        classifier.getParameters(), this.trainingSet, 5, -1);
                LOGGER.info("5-fold cross-validation performances:\n{}", crossMatrix);
            }

            // Build and return the created link labeller
            return new LinkLabeller(classifier, this.posPrefixes);
        }

    }

}