1   package eu.fbk.dkm.pikes.raid.pipeline;
2   
3   import java.io.IOException;
4   import java.io.Reader;
5   import java.io.Writer;
6   import java.nio.file.Files;
7   import java.util.Collections;
8   import java.util.List;
9   import java.util.Map;
10  import java.util.Properties;
11  import java.util.Set;
12  
13  import javax.annotation.Nullable;
14  
15  import com.google.common.collect.ImmutableSet;
16  import com.google.common.collect.Iterables;
17  import com.google.common.collect.Lists;
18  import com.google.common.collect.MapMaker;
19  import com.google.common.collect.Maps;
20  import com.google.common.collect.Sets;
21  
22  import org.slf4j.Logger;
23  import org.slf4j.LoggerFactory;
24  
25  import ixa.kaflib.Dep;
26  import ixa.kaflib.Dep.Path;
27  import ixa.kaflib.ExternalRef;
28  import ixa.kaflib.KAFDocument;
29  import ixa.kaflib.Predicate;
30  import ixa.kaflib.Predicate.Role;
31  import ixa.kaflib.Term;
32  
33  import eu.fbk.dkm.pikes.resources.NAFUtils;
34  import eu.fbk.dkm.pikes.resources.WordNet;
35  import eu.fbk.utils.core.Graph;
36  import eu.fbk.utils.svm.Util;
37  import eu.fbk.utils.eval.ConfusionMatrix;
38  import eu.fbk.utils.eval.PrecisionRecall;
39  import eu.fbk.utils.svm.Classifier;
40  import eu.fbk.utils.svm.FeatureStats;
41  import eu.fbk.utils.svm.LabelledVector;
42  import eu.fbk.utils.svm.Vector;
43  
44  public class LinkLabeller {
45  
46      private static final Logger LOGGER = LoggerFactory.getLogger(LinkLabeller.class);
47  
48      private static final Map<KAFDocument, Map<Integer, Graph<Term, String>>> GRAPH_CACHE = new MapMaker()
49              .weakKeys().makeMap();
50  
51      private static final int SELECTED = 1;
52  
53      private static final int UNSELECTED = 0;
54  
55      private final Classifier classifier;
56  
57      @Nullable
58      private final Set<String> posPrefixes;
59  
60      private LinkLabeller(final Classifier classifier, @Nullable final Iterable<String> posPrefixes) {
61          this.classifier = classifier;
62          this.posPrefixes = posPrefixes == null ? null : ImmutableSet.copyOf(posPrefixes);
63      }
64  
65      public static LinkLabeller readFrom(final java.nio.file.Path path) throws IOException {
66          final java.nio.file.Path p = Util.openVFS(path, false);
67          try {
68              final Classifier classifier = Classifier.readFrom(p.resolve("model"));
69              final Properties properties = new Properties();
70              try (Reader in = Files.newBufferedReader(p.resolve("properties"))) {
71                  properties.load(in);
72              }
73              final String pos = properties.getProperty("pos");
74              final Set<String> posPrefixes = pos == null ? null : ImmutableSet.copyOf(pos
75                      .split(","));
76              return new LinkLabeller(classifier, posPrefixes);
77          } finally {
78              Util.closeVFS(p);
79          }
80      }
81  
82      public Map<Term, Float> label(final KAFDocument document, final Term root) {
83  
84          
85          
86          
87          final Map<Term, Float> map = Maps.newHashMap();
88          for (final Term term : candidates(document, root, this.posPrefixes)) {
89              final Vector vector = features(document, root, term);
90              final LabelledVector outVector = this.classifier.predict(true, vector);
91              final float p = outVector.getProbability(SELECTED);
92              if (outVector.getLabel() == SELECTED) {
93                  map.put(term, p >= 0.5f ? p : 0.5f);
94              } else {
95                  map.put(term, p < 0.5f ? p : 0.5f);
96              }
97          }
98          return map;
99      }
100 
101     public void writeTo(final java.nio.file.Path path) throws IOException {
102         final java.nio.file.Path p = Util.openVFS(path, false);
103         try {
104             this.classifier.writeTo(p.resolve("model"));
105             final Properties properties = new Properties();
106             if (this.posPrefixes != null) {
107                 properties.setProperty("pos", String.join(",", this.posPrefixes));
108             }
109             try (Writer out = Files.newBufferedWriter(p.resolve("properties"))) {
110                 properties.store(out, null);
111             }
112         } finally {
113             Util.closeVFS(p);
114         }
115     }
116 
117     @Override
118     public boolean equals(final Object object) {
119         if (object == this) {
120             return true;
121         }
122         if (!(object instanceof LinkLabeller)) {
123             return false;
124         }
125         final LinkLabeller other = (LinkLabeller) object;
126         return this.classifier.equals(other.classifier);
127     }
128 
129     @Override
130     public int hashCode() {
131         return this.classifier.hashCode();
132     }
133 
134     @Override
135     public String toString() {
136         return "LinkLabeller (" + this.classifier.toString() + ")";
137     }
138 
139     private static List<Term> candidates(final KAFDocument document, final Term root,
140             @Nullable final Set<String> posPrefixes) {
141 
142         
143         final Set<Term> nonModifierTerms = Sets.newHashSet();
144 
145         
146         
147         for (final Dep dep : document.getDepsFromTerm(root)) {
148             if (!"COORD".equals(dep.getRfunc()) && !"CONJ".equals(dep.getRfunc())) {
149                 candidatesHelper(document, dep.getTo(), nonModifierTerms);
150             }
151         }
152 
153         
154         
155         for (Dep dep = document.getDepToTerm(root); dep != null; dep = document.getDepToTerm(dep
156                 .getFrom())) {
157             if ("COORD".equals(dep.getRfunc()) || "CONJ".equals(dep.getRfunc())) {
158                 continue; 
159             }
160             nonModifierTerms.add(dep.getFrom());
161             final List<Term> queue = Lists.newArrayList(dep.getFrom());
162             while (!queue.isEmpty()) {
163                 final Term t = queue.remove(0);
164                 for (final Dep dep2 : document.getDepsFromTerm(t)) {
165                     if (dep2.getTo().equals(dep.getTo())) {
166                         continue;
167                     } else if (!"COORD".equals(dep2.getRfunc()) && !"CONJ".equals(dep2.getRfunc())) {
168                         candidatesHelper(document, dep2.getTo(), nonModifierTerms);
169                     }
170                 }
171             }
172         }
173 
174         
175         
176         
177         
178         final List<Term> candidates = Lists.newArrayList();
179         final java.util.function.Predicate<Term> matcher = posPrefixes == null ? null 
180                 : NAFUtils.matchExtendedPos(document, posPrefixes.toArray(new String[0]));
181         for (final Term term : document.getTermsBySent(root.getSent())) {
182             if (!term.equals(root)
183                     && nonModifierTerms.contains(term)
184                     && (matcher == null || matcher.test(term))
185                     && (!"V".equals(term.getPos()) || term.equals(NAFUtils.syntacticToSRLHead(
186                             document, term)))) {
187                 final String s = term.getStr();
188                 for (int i = 0; i < s.length(); ++i) {
189                     if (Character.isLetter(s.charAt(i))) {
190                         candidates.add(term);
191                         break;
192                     }
193                 }
194             }
195         }
196         return candidates;
197     }
198 
199     private static void candidatesHelper(final KAFDocument document, final Term term,
200             final Set<Term> nonModifierTerms) {
201 
202         
203         nonModifierTerms.add(term);
204         for (final Dep dep : document.getDepsFromTerm(term)) {
205             final String func = dep.getRfunc();
206             if (!"NMOD".equals(func) && !"AMOD".equals(func)) {
207                 candidatesHelper(document, dep.getTo(), nonModifierTerms);
208             }
209         }
210     }
211 
212     private static Vector features(final KAFDocument document, final Term root, final Term node) {
213 
214         
215         final Vector.Builder builder = Vector.builder();
216 
217         
218         builder.set("_cluster." + document.getPublic().uri);
219 
220         final String rootSST = getReference(root, NAFUtils.RESOURCE_WN_SST, "none") 
221                 .replaceFirst("[^.]*\\.", "");
222         final String nodeSST = getReference(node, NAFUtils.RESOURCE_WN_SST, "none") 
223                 .replaceFirst("[^.]*\\.", "");
224         final Dep rootDep = document.getDepToTerm(root);
225         final Dep nodeDep = document.getDepToTerm(node);
226         final Boolean rootActive = NAFUtils.isActiveForm(document, root);
227 
228         
229         builder.set("root.pos." + root.getMorphofeat()); 
230         builder.set("root.dep." + (rootDep == null ? "none" : rootDep.getRfunc())); 
231         for (final ExternalRef ref : NAFUtils.getRefs(root, NAFUtils.RESOURCE_WN_SYNSET, null)) {
232             
233             for (final String synsetID : WordNet.getHypernyms(ref.getReference(), true)) {
234                 builder.set("root.wn." + synsetID);
235             }
236         }
237         
238         builder.set("root.lemma." + root.getLemma().toLowerCase()); 
239         builder.set("root.form."
240                 + (rootActive == null ? "none" : rootActive ? "active" : "passive"));
241         builder.set("root.sst." + rootSST);
242 
243         
244         
245         
246         
247         
248         
249         
250         
251         
252         
253         
254         
255         
256         
257         
258 
259         
260         builder.set("node.pos." + node.getMorphofeat()); 
261         builder.set("node.dep." + (nodeDep == null ? "none" : nodeDep.getRfunc())); 
262         for (final ExternalRef ref : NAFUtils.getRefs(node, NAFUtils.RESOURCE_WN_SYNSET, null)) {
263             
264             for (final String synsetID : WordNet.getHypernyms(ref.getReference(), true)) {
265                 builder.set("node.wn." + synsetID);
266             }
267         }
268         
269         builder.set("node.lemma." + node.getLemma().toLowerCase());
270         builder.set("node.named", node.getMorphofeat().startsWith("NNP"));
271         builder.set("node.bbn.", getReferences(node, NAFUtils.RESOURCE_BBN));
272         builder.set("node.sst." + nodeSST);
273 
274         
275         
276         
277         
278         
279         
280         
281         
282         
283         
284         
285         
286         
287         
288         
289 
290         
291         
292         
293         
294         
295         
296         
297         
298         
299         
300         
301 
302         
303         final Dep.Path path = Dep.Path.create(root, node, document);
304         builder.set("path." + (path == null ? "none" : getSimplifiedPathLabel(path)));
305         for (int i = 0; i < 10; ++i) {
306             if (path.length() <= i) {
307                 builder.set("path.lenless." + i);
308             }
309         }
310 
311         
312         final Graph<Term, String> graph = getSRLGraph(document, root.getSent());
313         for (final Graph.Path<Term, String> srlPath : graph.getPaths(root, node, false, 2)) {
314             final StringBuilder b = new StringBuilder("srl.path.");
315             for (int i = 0; i < srlPath.length(); ++i) {
316                 final Term term = srlPath.getVertices().get(i);
317                 final Graph.Edge<Term, String> edge = srlPath.getEdges().get(i);
318                 if (i > 0 && term.getMorphofeat().startsWith("VB")
319                         && !document.getPredicatesByTerm(term).isEmpty()) {
320                     b.append("_").append(term.getLemma().toLowerCase());
321                 }
322                 b.append(edge.getSource().equals(term) ? "_F_" : "_B_").append(edge.getLabel());
323             }
324             builder.set(b.toString());
325         }
326 
327         
328         return builder.build();
329     }
330 
331     private static String getSimplifiedPathLabel(final Path path) {
332 
333         
334         
335         final StringBuilder builder = new StringBuilder();
336         Term term = path.getTerms().get(0);
337         for (int i = 0; i < path.getDeps().size(); ++i) {
338             final Dep dep = path.getDeps().get(i);
339             final String func = dep.getRfunc().toLowerCase();
340             if ("coord".equals(func) || "conj".equals(func)) {
341                 continue;
342             }
343             builder.append(func);
344             if (term.equals(dep.getFrom())) {
345                 builder.append('D');
346                 term = dep.getTo();
347             } else {
348                 builder.append('U');
349                 term = dep.getFrom();
350             }
351         }
352         return builder.toString();
353     }
354 
355     @Nullable
356     private static String getReference(final Object annotation, final String resource,
357             @Nullable final String defaultValue) {
358         final List<String> refs = getReferences(annotation, resource);
359         if (refs.isEmpty()) {
360             return defaultValue;
361         } else if (refs.size() == 1) {
362             return refs.get(0);
363         } else {
364             Collections.sort(refs);
365             return refs.get(0);
366         }
367     }
368 
369     private static List<String> getReferences(final Object annotation, final String resource) {
370         final List<String> result = Lists.newArrayList();
371         for (final ExternalRef ref : NAFUtils.getRefs(annotation, resource, null)) {
372             result.add(ref.getReference());
373         }
374         return result;
375     }
376 
377     private static Graph<Term, String> getSRLGraph(final KAFDocument document, final int sentence) {
378 
379         
380         synchronized (GRAPH_CACHE) {
381             final Map<Integer, Graph<Term, String>> map = GRAPH_CACHE.get(document);
382             if (map != null) {
383                 final Graph<Term, String> graph = map.get(sentence);
384                 if (graph != null) {
385                     return graph;
386                 }
387             }
388         }
389 
390         
391         final Graph.Builder<Term, String> builder = Graph.builder();
392         for (final Predicate predicate : document.getPredicates()) {
393             final Term predicateHead = NAFUtils.extractHead(document, predicate.getSpan());
394             if (predicateHead != null) {
395                 for (final Role role : predicate.getRoles()) {
396                     for (final Term argHead : NAFUtils.extractHeads(document, null, role
397                             .getTerms(), NAFUtils.matchExtendedPos(document, "NN", "PRP", "JJP",
398                             "DTP", "WP", "VB"))) {
399                         builder.addEdges(predicateHead, argHead, role.getSemRole());
400                     }
401                 }
402             }
403         }
404 
405         
406         final Graph<Term, String> graph = builder.build();
407         synchronized (GRAPH_CACHE) {
408             Map<Integer, Graph<Term, String>> map = GRAPH_CACHE.get(document);
409             if (map == null) {
410                 map = Maps.newHashMap();
411                 GRAPH_CACHE.put(document, map);
412             }
413             map.put(sentence, graph);
414         }
415         return graph;
416     }
417 
418     public static Trainer train(final String... posPrefixes) {
419         final Set<String> posPrefixesSet = posPrefixes == null || posPrefixes.length == 0 ? null
420                 : ImmutableSet.copyOf(posPrefixes);
421         return new Trainer(posPrefixesSet);
422     }
423 
424     public static final class Trainer {
425 
426         private final List<LabelledVector> trainingSet;
427 
428         private final PrecisionRecall.Evaluator coverageEvaluator;
429 
430         @Nullable
431         private final Set<String> posPrefixes;
432 
433         private Trainer(@Nullable final Set<String> posPrefixes) {
434             this.trainingSet = Lists.newArrayList();
435             this.posPrefixes = posPrefixes;
436             this.coverageEvaluator = PrecisionRecall.evaluator();
437         }
438 
439         public void add(final KAFDocument document, final Term root, final Iterable<Term> selected) {
440 
441             int numSelected = 0;
442             final List<Term> candidates = candidates(document, root, this.posPrefixes);
443             for (final Term term : candidates) {
444                 final int label = Iterables.contains(selected, term) ? SELECTED : UNSELECTED;
445                 numSelected += label == SELECTED ? 1 : 0;
446                 final LabelledVector vector = features(document, root, term).label(label);
447                 this.trainingSet.add(vector);
448             }
449 
450             final int goldSelected = Iterables.size(selected);
451             if (goldSelected > 0) {
452                 this.coverageEvaluator.add(numSelected, 0, goldSelected - numSelected);
453 
454                 if (numSelected < goldSelected && LOGGER.isDebugEnabled()) {
455                     final List<Term> unselected = Lists.newArrayList();
456                     for (final Term term : selected) {
457                         if (!candidates.contains(term)) {
458                             unselected.add(term);
459                         }
460                     }
461                     LOGGER.debug("Missing candidates: "
462                             + unselected
463                             + " (root: "
464                             + root.toString()
465                             + ", sentence: "
466                             + KAFDocument.newTermSpan(document.getTermsBySent(root.getSent()))
467                                     .getStr() + ")");
468                 }
469             }
470         }
471 
472         public LinkLabeller end(final int gridSize, final boolean analyze) throws IOException {
473 
474             
475             if (analyze && LOGGER.isInfoEnabled()) {
476                 LOGGER.info("Feature analysis (top 30 features):\n{}", FeatureStats.toString(
477                         FeatureStats.forVectors(2, this.trainingSet, null).values(), 30));
478             }
479 
480             
481             LOGGER.info("Maximum achievable performances on training set due to candidate "
482                     + "selection criteria: " + this.coverageEvaluator.getResult());
483 
484             
485             final List<Classifier.Parameters> grid = Lists.newArrayList();
486             for (final float weight : new float[] { 0.25f, 0.5f, 1.0f, 2.0f, 4.0f }) {
487                 
488                 
489                 
490                 
491                 
492                 grid.addAll(Classifier.Parameters.forLinearLRLossL1Reg(2,
493                         new float[] { 1f, weight }, 1f, 1f).grid(Math.max(1, gridSize), 10.0f));
494             }
495             final Classifier classifier = Classifier.train(grid, this.trainingSet,
496                     ConfusionMatrix.labelComparator(PrecisionRecall.Measure.F1, 1, true), 100000);
497 
498             
499             LOGGER.info("Best classifier parameters: {}", classifier.getParameters());
500 
501             
502             if (analyze && LOGGER.isInfoEnabled()) {
503                 final List<LabelledVector> trainingPredictions = classifier.predict(false,
504                         this.trainingSet);
505                 final ConfusionMatrix matrix = LabelledVector.evaluate(this.trainingSet,
506                         trainingPredictions, 2);
507                 LOGGER.info("Performances on training set:\n{}", matrix);
508                 final ConfusionMatrix crossMatrix = Classifier.crossValidate(
509                         classifier.getParameters(), this.trainingSet, 5, -1);
510                 LOGGER.info("5-fold cross-validation performances:\n{}", crossMatrix);
511             }
512 
513             
514             return new LinkLabeller(classifier, this.posPrefixes);
515         }
516 
517     }
518 
519 }