1   package eu.fbk.dkm.pikes.raid.sbrs;
2   
3   import com.google.common.collect.*;
4   import eu.fbk.dkm.pikes.raid.Component;
5   import eu.fbk.dkm.pikes.raid.Opinions;
6   import eu.fbk.dkm.pikes.raid.Trainer;
7   import eu.fbk.dkm.pikes.raid.pipeline.LinkLabeller;
8   import eu.fbk.dkm.pikes.raid.pipeline.PipelineExtractor;
9   import eu.fbk.dkm.pikes.raid.pipeline.SpanLabeller;
10  import eu.fbk.dkm.pikes.resources.NAFFilter;
11  import eu.fbk.dkm.pikes.resources.NAFUtils;
12  import ixa.kaflib.KAFDocument;
13  import ixa.kaflib.Opinion;
14  import ixa.kaflib.Span;
15  import ixa.kaflib.Term;
16  
17  import javax.annotation.Nullable;
18  import java.util.List;
19  import java.util.Properties;
20  import java.util.Set;
21  
22  /**
23   * Created by alessio on 20/08/15.
24   */
25  
26  public class SBRSTrainer extends Trainer<PipelineExtractor> {
27  
28  	@Nullable
29  	private final LinkLabeller.Trainer holderLinkTrainer;
30  
31  	@Nullable
32  	private final LinkLabeller.Trainer targetLinkTrainer;
33  
34  	@Nullable
35  	private final SpanLabeller.Trainer holderSpanTrainer;
36  
37  	@Nullable
38  	private final SpanLabeller.Trainer targetSpanTrainer;
39  
40  	private final boolean jointSpan;
41  	private final NAFFilter filter;
42  
43  	public SBRSTrainer(final Properties properties, final Component... components) {
44  		super(components);
45  		final boolean hasHolder = components().contains(Component.HOLDER);
46  		final boolean hasTarget = components().contains(Component.TARGET);
47  		this.jointSpan = Boolean.parseBoolean(properties.getProperty("joint", "true"));
48  		this.holderLinkTrainer = hasHolder ? LinkLabeller.train("NN", "PRP", "JJP", "DTP", "WP") : null;
49  		this.targetLinkTrainer = hasTarget ? LinkLabeller.train("NN", "PRP", "JJP", "DTP", "WP", "VB") : null;
50  
51  		if (this.jointSpan) {
52  			final SpanLabeller.Trainer t = hasHolder || hasTarget ? SpanLabeller.train() : null;
53  			this.holderSpanTrainer = t;
54  			this.targetSpanTrainer = t;
55  		}
56  		else {
57  			this.holderSpanTrainer = hasHolder ? SpanLabeller.train() : null;
58  			this.targetSpanTrainer = hasTarget ? SpanLabeller.train() : null;
59  		}
60  		this.filter = NAFFilter.builder(false).withTermSenseCompletion(true)
61  				.withEntityAddition(true).withEntityRemoveOverlaps(true)
62  				.withEntitySpanFixing(true).withSRLPredicateAddition(true)
63  				.withSRLRemoveWrongRefs(true).withSRLSelfArgFixing(true).build();
64  
65  	}
66  
67  	@Override
68  	protected void doAdd(KAFDocument document, int sentence, Opinion[] opinions) throws Throwable {
69  		addArguments(document, sentence, opinions);
70  
71  	}
72  
73  	@Override
74  	protected PipelineExtractor doTrain() throws Throwable {
75  		return null;
76  	}
77  
78  	private void addArguments(final KAFDocument document, final int sentence,
79  							  final Opinion[] opinions) {
80  
81  		// Index holder and target spans by expression head, keeping track of all exp. heads
82  		final Set<Term> expressionHeads = Sets.newHashSet();
83  		final Multimap<Term, Span<Term>> holderSpans = HashMultimap.create();
84  		final Multimap<Term, Span<Term>> targetSpans = HashMultimap.create();
85  		for (final Opinion opinion : opinions) {
86  			final Set<Term> heads = Opinions.heads(document,
87  					NAFUtils.normalizeSpan(document, opinion.getExpressionSpan()),
88  					Component.EXPRESSION);
89  			if (!heads.isEmpty()) {
90  				final Term head = Ordering.from(Term.OFFSET_COMPARATOR).max(heads);
91  				expressionHeads.add(head);
92  				final Span<Term> holderSpan = opinion.getHolderSpan();
93  				final Span<Term> targetSpan = opinion.getTargetSpan();
94  				if (holderSpan != null) {
95  					holderSpans.putAll(head, NAFUtils.splitSpan(document, holderSpan, //
96  							Opinions.heads(document, holderSpan, Component.HOLDER)));
97  				}
98  				if (targetSpan != null) {
99  					targetSpans.putAll(head, NAFUtils.splitSpan(document, targetSpan, //
100 							Opinions.heads(document, targetSpan, Component.TARGET)));
101 				}
102 			}
103 		}
104 
105 		// Add training samples for holder and target extraction, separately (if enabled)
106 		for (final Term expressionHead : expressionHeads) {
107 			if (components().contains(Component.HOLDER)) {
108 				addArguments(document, sentence, expressionHead, holderSpans.get(expressionHead),
109 						this.holderLinkTrainer, this.holderSpanTrainer);
110 			}
111 			if (components().contains(Component.TARGET)) {
112 				addArguments(document, sentence, expressionHead, targetSpans.get(expressionHead),
113 						this.targetLinkTrainer, this.targetSpanTrainer);
114 			}
115 		}
116 	}
117 
118 	private void addArguments(final KAFDocument document, final int sentence,
119 							  final Term expressionHead, final Iterable<Span<Term>> argSpans,
120 							  final LinkLabeller.Trainer linkTrainer, final SpanLabeller.Trainer spanTrainer) {
121 
122 		// Extract heads and spans of the arguments (only where defined)
123 		final List<Term> heads = Lists.newArrayList();
124 		final List<Span<Term>> spans = Lists.newArrayList();
125 		for (final Span<Term> span : argSpans) {
126 			final Term head = NAFUtils.extractHead(document, span);
127 			if (head != null) {
128 				heads.add(head);
129 				spans.add(span);
130 			}
131 		}
132 
133 		// Add a sample for node labelling
134 		linkTrainer.add(document, expressionHead, heads);
135 
136 		// Add samples for span labelling (one for each argument)
137 		for (int i = 0; i < heads.size(); ++i) {
138 			final List<Term> excludedTerms = Lists.newArrayList(heads);
139 			excludedTerms.remove(heads.get(i));
140 			spanTrainer.add(document, heads.get(i), excludedTerms, spans.get(i));
141 		}
142 	}
143 }