1 package eu.fbk.dkm.pikes.raid.pipeline;
2
import java.io.IOException;
import java.io.Reader;
import java.io.Writer;
import java.nio.file.Files;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import java.util.Set;

import javax.annotation.Nullable;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.MapMaker;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ixa.kaflib.Dep;
import ixa.kaflib.Dep.Path;
import ixa.kaflib.ExternalRef;
import ixa.kaflib.KAFDocument;
import ixa.kaflib.Predicate;
import ixa.kaflib.Predicate.Role;
import ixa.kaflib.Term;

import eu.fbk.dkm.pikes.resources.NAFUtils;
import eu.fbk.dkm.pikes.resources.WordNet;
import eu.fbk.utils.core.Graph;
import eu.fbk.utils.eval.ConfusionMatrix;
import eu.fbk.utils.eval.PrecisionRecall;
import eu.fbk.utils.svm.Classifier;
import eu.fbk.utils.svm.FeatureStats;
import eu.fbk.utils.svm.LabelledVector;
import eu.fbk.utils.svm.Util;
import eu.fbk.utils.svm.Vector;
43
44 public class LinkLabeller {
45
46 private static final Logger LOGGER = LoggerFactory.getLogger(LinkLabeller.class);
47
48 private static final Map<KAFDocument, Map<Integer, Graph<Term, String>>> GRAPH_CACHE = new MapMaker()
49 .weakKeys().makeMap();
50
51 private static final int SELECTED = 1;
52
53 private static final int UNSELECTED = 0;
54
55 private final Classifier classifier;
56
57 @Nullable
58 private final Set<String> posPrefixes;
59
60 private LinkLabeller(final Classifier classifier, @Nullable final Iterable<String> posPrefixes) {
61 this.classifier = classifier;
62 this.posPrefixes = posPrefixes == null ? null : ImmutableSet.copyOf(posPrefixes);
63 }
64
65 public static LinkLabeller readFrom(final java.nio.file.Path path) throws IOException {
66 final java.nio.file.Path p = Util.openVFS(path, false);
67 try {
68 final Classifier classifier = Classifier.readFrom(p.resolve("model"));
69 final Properties properties = new Properties();
70 try (Reader in = Files.newBufferedReader(p.resolve("properties"))) {
71 properties.load(in);
72 }
73 final String pos = properties.getProperty("pos");
74 final Set<String> posPrefixes = pos == null ? null : ImmutableSet.copyOf(pos
75 .split(","));
76 return new LinkLabeller(classifier, posPrefixes);
77 } finally {
78 Util.closeVFS(p);
79 }
80 }
81
82 public Map<Term, Float> label(final KAFDocument document, final Term root) {
83
84
85
86
87 final Map<Term, Float> map = Maps.newHashMap();
88 for (final Term term : candidates(document, root, this.posPrefixes)) {
89 final Vector vector = features(document, root, term);
90 final LabelledVector outVector = this.classifier.predict(true, vector);
91 final float p = outVector.getProbability(SELECTED);
92 if (outVector.getLabel() == SELECTED) {
93 map.put(term, p >= 0.5f ? p : 0.5f);
94 } else {
95 map.put(term, p < 0.5f ? p : 0.5f);
96 }
97 }
98 return map;
99 }
100
101 public void writeTo(final java.nio.file.Path path) throws IOException {
102 final java.nio.file.Path p = Util.openVFS(path, false);
103 try {
104 this.classifier.writeTo(p.resolve("model"));
105 final Properties properties = new Properties();
106 if (this.posPrefixes != null) {
107 properties.setProperty("pos", String.join(",", this.posPrefixes));
108 }
109 try (Writer out = Files.newBufferedWriter(p.resolve("properties"))) {
110 properties.store(out, null);
111 }
112 } finally {
113 Util.closeVFS(p);
114 }
115 }
116
117 @Override
118 public boolean equals(final Object object) {
119 if (object == this) {
120 return true;
121 }
122 if (!(object instanceof LinkLabeller)) {
123 return false;
124 }
125 final LinkLabeller other = (LinkLabeller) object;
126 return this.classifier.equals(other.classifier);
127 }
128
129 @Override
130 public int hashCode() {
131 return this.classifier.hashCode();
132 }
133
134 @Override
135 public String toString() {
136 return "LinkLabeller (" + this.classifier.toString() + ")";
137 }
138
139 private static List<Term> candidates(final KAFDocument document, final Term root,
140 @Nullable final Set<String> posPrefixes) {
141
142
143 final Set<Term> nonModifierTerms = Sets.newHashSet();
144
145
146
147 for (final Dep dep : document.getDepsFromTerm(root)) {
148 if (!"COORD".equals(dep.getRfunc()) && !"CONJ".equals(dep.getRfunc())) {
149 candidatesHelper(document, dep.getTo(), nonModifierTerms);
150 }
151 }
152
153
154
155 for (Dep dep = document.getDepToTerm(root); dep != null; dep = document.getDepToTerm(dep
156 .getFrom())) {
157 if ("COORD".equals(dep.getRfunc()) || "CONJ".equals(dep.getRfunc())) {
158 continue;
159 }
160 nonModifierTerms.add(dep.getFrom());
161 final List<Term> queue = Lists.newArrayList(dep.getFrom());
162 while (!queue.isEmpty()) {
163 final Term t = queue.remove(0);
164 for (final Dep dep2 : document.getDepsFromTerm(t)) {
165 if (dep2.getTo().equals(dep.getTo())) {
166 continue;
167 } else if (!"COORD".equals(dep2.getRfunc()) && !"CONJ".equals(dep2.getRfunc())) {
168 candidatesHelper(document, dep2.getTo(), nonModifierTerms);
169 }
170 }
171 }
172 }
173
174
175
176
177
178 final List<Term> candidates = Lists.newArrayList();
179 final java.util.function.Predicate<Term> matcher = posPrefixes == null ? null
180 : NAFUtils.matchExtendedPos(document, posPrefixes.toArray(new String[0]));
181 for (final Term term : document.getTermsBySent(root.getSent())) {
182 if (!term.equals(root)
183 && nonModifierTerms.contains(term)
184 && (matcher == null || matcher.test(term))
185 && (!"V".equals(term.getPos()) || term.equals(NAFUtils.syntacticToSRLHead(
186 document, term)))) {
187 final String s = term.getStr();
188 for (int i = 0; i < s.length(); ++i) {
189 if (Character.isLetter(s.charAt(i))) {
190 candidates.add(term);
191 break;
192 }
193 }
194 }
195 }
196 return candidates;
197 }
198
199 private static void candidatesHelper(final KAFDocument document, final Term term,
200 final Set<Term> nonModifierTerms) {
201
202
203 nonModifierTerms.add(term);
204 for (final Dep dep : document.getDepsFromTerm(term)) {
205 final String func = dep.getRfunc();
206 if (!"NMOD".equals(func) && !"AMOD".equals(func)) {
207 candidatesHelper(document, dep.getTo(), nonModifierTerms);
208 }
209 }
210 }
211
212 private static Vector features(final KAFDocument document, final Term root, final Term node) {
213
214
215 final Vector.Builder builder = Vector.builder();
216
217
218 builder.set("_cluster." + document.getPublic().uri);
219
220 final String rootSST = getReference(root, NAFUtils.RESOURCE_WN_SST, "none")
221 .replaceFirst("[^.]*\\.", "");
222 final String nodeSST = getReference(node, NAFUtils.RESOURCE_WN_SST, "none")
223 .replaceFirst("[^.]*\\.", "");
224 final Dep rootDep = document.getDepToTerm(root);
225 final Dep nodeDep = document.getDepToTerm(node);
226 final Boolean rootActive = NAFUtils.isActiveForm(document, root);
227
228
229 builder.set("root.pos." + root.getMorphofeat());
230 builder.set("root.dep." + (rootDep == null ? "none" : rootDep.getRfunc()));
231 for (final ExternalRef ref : NAFUtils.getRefs(root, NAFUtils.RESOURCE_WN_SYNSET, null)) {
232
233 for (final String synsetID : WordNet.getHypernyms(ref.getReference(), true)) {
234 builder.set("root.wn." + synsetID);
235 }
236 }
237
238 builder.set("root.lemma." + root.getLemma().toLowerCase());
239 builder.set("root.form."
240 + (rootActive == null ? "none" : rootActive ? "active" : "passive"));
241 builder.set("root.sst." + rootSST);
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260 builder.set("node.pos." + node.getMorphofeat());
261 builder.set("node.dep." + (nodeDep == null ? "none" : nodeDep.getRfunc()));
262 for (final ExternalRef ref : NAFUtils.getRefs(node, NAFUtils.RESOURCE_WN_SYNSET, null)) {
263
264 for (final String synsetID : WordNet.getHypernyms(ref.getReference(), true)) {
265 builder.set("node.wn." + synsetID);
266 }
267 }
268
269 builder.set("node.lemma." + node.getLemma().toLowerCase());
270 builder.set("node.named", node.getMorphofeat().startsWith("NNP"));
271 builder.set("node.bbn.", getReferences(node, NAFUtils.RESOURCE_BBN));
272 builder.set("node.sst." + nodeSST);
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303 final Dep.Path path = Dep.Path.create(root, node, document);
304 builder.set("path." + (path == null ? "none" : getSimplifiedPathLabel(path)));
305 for (int i = 0; i < 10; ++i) {
306 if (path.length() <= i) {
307 builder.set("path.lenless." + i);
308 }
309 }
310
311
312 final Graph<Term, String> graph = getSRLGraph(document, root.getSent());
313 for (final Graph.Path<Term, String> srlPath : graph.getPaths(root, node, false, 2)) {
314 final StringBuilder b = new StringBuilder("srl.path.");
315 for (int i = 0; i < srlPath.length(); ++i) {
316 final Term term = srlPath.getVertices().get(i);
317 final Graph.Edge<Term, String> edge = srlPath.getEdges().get(i);
318 if (i > 0 && term.getMorphofeat().startsWith("VB")
319 && !document.getPredicatesByTerm(term).isEmpty()) {
320 b.append("_").append(term.getLemma().toLowerCase());
321 }
322 b.append(edge.getSource().equals(term) ? "_F_" : "_B_").append(edge.getLabel());
323 }
324 builder.set(b.toString());
325 }
326
327
328 return builder.build();
329 }
330
331 private static String getSimplifiedPathLabel(final Path path) {
332
333
334
335 final StringBuilder builder = new StringBuilder();
336 Term term = path.getTerms().get(0);
337 for (int i = 0; i < path.getDeps().size(); ++i) {
338 final Dep dep = path.getDeps().get(i);
339 final String func = dep.getRfunc().toLowerCase();
340 if ("coord".equals(func) || "conj".equals(func)) {
341 continue;
342 }
343 builder.append(func);
344 if (term.equals(dep.getFrom())) {
345 builder.append('D');
346 term = dep.getTo();
347 } else {
348 builder.append('U');
349 term = dep.getFrom();
350 }
351 }
352 return builder.toString();
353 }
354
355 @Nullable
356 private static String getReference(final Object annotation, final String resource,
357 @Nullable final String defaultValue) {
358 final List<String> refs = getReferences(annotation, resource);
359 if (refs.isEmpty()) {
360 return defaultValue;
361 } else if (refs.size() == 1) {
362 return refs.get(0);
363 } else {
364 Collections.sort(refs);
365 return refs.get(0);
366 }
367 }
368
369 private static List<String> getReferences(final Object annotation, final String resource) {
370 final List<String> result = Lists.newArrayList();
371 for (final ExternalRef ref : NAFUtils.getRefs(annotation, resource, null)) {
372 result.add(ref.getReference());
373 }
374 return result;
375 }
376
377 private static Graph<Term, String> getSRLGraph(final KAFDocument document, final int sentence) {
378
379
380 synchronized (GRAPH_CACHE) {
381 final Map<Integer, Graph<Term, String>> map = GRAPH_CACHE.get(document);
382 if (map != null) {
383 final Graph<Term, String> graph = map.get(sentence);
384 if (graph != null) {
385 return graph;
386 }
387 }
388 }
389
390
391 final Graph.Builder<Term, String> builder = Graph.builder();
392 for (final Predicate predicate : document.getPredicates()) {
393 final Term predicateHead = NAFUtils.extractHead(document, predicate.getSpan());
394 if (predicateHead != null) {
395 for (final Role role : predicate.getRoles()) {
396 for (final Term argHead : NAFUtils.extractHeads(document, null, role
397 .getTerms(), NAFUtils.matchExtendedPos(document, "NN", "PRP", "JJP",
398 "DTP", "WP", "VB"))) {
399 builder.addEdges(predicateHead, argHead, role.getSemRole());
400 }
401 }
402 }
403 }
404
405
406 final Graph<Term, String> graph = builder.build();
407 synchronized (GRAPH_CACHE) {
408 Map<Integer, Graph<Term, String>> map = GRAPH_CACHE.get(document);
409 if (map == null) {
410 map = Maps.newHashMap();
411 GRAPH_CACHE.put(document, map);
412 }
413 map.put(sentence, graph);
414 }
415 return graph;
416 }
417
418 public static Trainer train(final String... posPrefixes) {
419 final Set<String> posPrefixesSet = posPrefixes == null || posPrefixes.length == 0 ? null
420 : ImmutableSet.copyOf(posPrefixes);
421 return new Trainer(posPrefixesSet);
422 }
423
424 public static final class Trainer {
425
426 private final List<LabelledVector> trainingSet;
427
428 private final PrecisionRecall.Evaluator coverageEvaluator;
429
430 @Nullable
431 private final Set<String> posPrefixes;
432
433 private Trainer(@Nullable final Set<String> posPrefixes) {
434 this.trainingSet = Lists.newArrayList();
435 this.posPrefixes = posPrefixes;
436 this.coverageEvaluator = PrecisionRecall.evaluator();
437 }
438
439 public void add(final KAFDocument document, final Term root, final Iterable<Term> selected) {
440
441 int numSelected = 0;
442 final List<Term> candidates = candidates(document, root, this.posPrefixes);
443 for (final Term term : candidates) {
444 final int label = Iterables.contains(selected, term) ? SELECTED : UNSELECTED;
445 numSelected += label == SELECTED ? 1 : 0;
446 final LabelledVector vector = features(document, root, term).label(label);
447 this.trainingSet.add(vector);
448 }
449
450 final int goldSelected = Iterables.size(selected);
451 if (goldSelected > 0) {
452 this.coverageEvaluator.add(numSelected, 0, goldSelected - numSelected);
453
454 if (numSelected < goldSelected && LOGGER.isDebugEnabled()) {
455 final List<Term> unselected = Lists.newArrayList();
456 for (final Term term : selected) {
457 if (!candidates.contains(term)) {
458 unselected.add(term);
459 }
460 }
461 LOGGER.debug("Missing candidates: "
462 + unselected
463 + " (root: "
464 + root.toString()
465 + ", sentence: "
466 + KAFDocument.newTermSpan(document.getTermsBySent(root.getSent()))
467 .getStr() + ")");
468 }
469 }
470 }
471
472 public LinkLabeller end(final int gridSize, final boolean analyze) throws IOException {
473
474
475 if (analyze && LOGGER.isInfoEnabled()) {
476 LOGGER.info("Feature analysis (top 30 features):\n{}", FeatureStats.toString(
477 FeatureStats.forVectors(2, this.trainingSet, null).values(), 30));
478 }
479
480
481 LOGGER.info("Maximum achievable performances on training set due to candidate "
482 + "selection criteria: " + this.coverageEvaluator.getResult());
483
484
485 final List<Classifier.Parameters> grid = Lists.newArrayList();
486 for (final float weight : new float[] { 0.25f, 0.5f, 1.0f, 2.0f, 4.0f }) {
487
488
489
490
491
492 grid.addAll(Classifier.Parameters.forLinearLRLossL1Reg(2,
493 new float[] { 1f, weight }, 1f, 1f).grid(Math.max(1, gridSize), 10.0f));
494 }
495 final Classifier classifier = Classifier.train(grid, this.trainingSet,
496 ConfusionMatrix.labelComparator(PrecisionRecall.Measure.F1, 1, true), 100000);
497
498
499 LOGGER.info("Best classifier parameters: {}", classifier.getParameters());
500
501
502 if (analyze && LOGGER.isInfoEnabled()) {
503 final List<LabelledVector> trainingPredictions = classifier.predict(false,
504 this.trainingSet);
505 final ConfusionMatrix matrix = LabelledVector.evaluate(this.trainingSet,
506 trainingPredictions, 2);
507 LOGGER.info("Performances on training set:\n{}", matrix);
508 final ConfusionMatrix crossMatrix = Classifier.crossValidate(
509 classifier.getParameters(), this.trainingSet, 5, -1);
510 LOGGER.info("5-fold cross-validation performances:\n{}", crossMatrix);
511 }
512
513
514 return new LinkLabeller(classifier, this.posPrefixes);
515 }
516
517 }
518
519 }