1 package eu.fbk.dkm.pikes.raid.pipeline;
2
3 import java.io.IOException;
4 import java.util.List;
5 import java.util.Properties;
6 import java.util.Set;
7
8 import javax.annotation.Nullable;
9
10 import com.google.common.collect.HashMultimap;
11 import com.google.common.collect.Lists;
12 import com.google.common.collect.Multimap;
13 import com.google.common.collect.Ordering;
14 import com.google.common.collect.Sets;
15
16 import org.slf4j.Logger;
17 import org.slf4j.LoggerFactory;
18
19 import ixa.kaflib.KAFDocument;
20 import ixa.kaflib.Opinion;
21 import ixa.kaflib.Span;
22 import ixa.kaflib.Term;
23
24 import eu.fbk.dkm.pikes.raid.Component;
25 import eu.fbk.dkm.pikes.raid.Opinions;
26 import eu.fbk.dkm.pikes.raid.Trainer;
27 import eu.fbk.dkm.pikes.resources.NAFFilter;
28 import eu.fbk.dkm.pikes.resources.NAFUtils;
29
30 public final class PipelineTrainer extends Trainer<PipelineExtractor> {
31
32 private static final Logger LOGGER = LoggerFactory.getLogger(PipelineTrainer.class);
33
34 @Nullable
35 private final LinkLabeller.Trainer holderLinkTrainer;
36
37 @Nullable
38 private final LinkLabeller.Trainer targetLinkTrainer;
39
40 @Nullable
41 private final SpanLabeller.Trainer holderSpanTrainer;
42
43 @Nullable
44 private final SpanLabeller.Trainer targetSpanTrainer;
45
46 private final int linkGridSize;
47
48 private final int spanGridSize;
49
50 private final boolean analyze;
51
52 private final boolean jointSpan;
53
54 private final boolean holderUnique;
55
56 private final boolean targetUnique;
57
58 private final boolean fastTrainer;
59
60 private final NAFFilter filter;
61
62 public PipelineTrainer(final Properties properties, final Component... components) {
63 super(components);
64 final boolean hasHolder = components().contains(Component.HOLDER);
65 final boolean hasTarget = components().contains(Component.TARGET);
66 this.holderUnique = Boolean.parseBoolean(properties.getProperty("holder.unique", "false"));
67 this.targetUnique = Boolean.parseBoolean(properties.getProperty("target.unique", "false"));
68 this.fastTrainer = Boolean.parseBoolean(properties.getProperty("fast", "false"));
69 this.linkGridSize = Integer.parseInt(properties.getProperty("gridsize.link", "25"));
70 this.spanGridSize = Integer.parseInt(properties.getProperty("gridsize.span", "25"));
71 this.analyze = Boolean.parseBoolean(properties.getProperty("analyze", "true"));
72 this.jointSpan = Boolean.parseBoolean(properties.getProperty("joint", "false"));
73 this.holderLinkTrainer = hasHolder ? LinkLabeller.train("NN", "PRP", "JJP", "DTP", "WP")
74 : null;
75 this.targetLinkTrainer = hasTarget ? LinkLabeller.train("NN", "PRP", "JJP", "DTP", "WP",
76 "VB") : null;
77 if (this.jointSpan) {
78 final SpanLabeller.Trainer t = hasHolder || hasTarget ? SpanLabeller.train() : null;
79 this.holderSpanTrainer = t;
80 this.targetSpanTrainer = t;
81 } else {
82 this.holderSpanTrainer = hasHolder ? SpanLabeller.train() : null;
83 this.targetSpanTrainer = hasTarget ? SpanLabeller.train() : null;
84 }
85 this.filter = NAFFilter.builder(false).withTermSenseCompletion(true)
86 .withEntityAddition(true).withEntityRemoveOverlaps(true)
87 .withEntitySpanFixing(true).withSRLPredicateAddition(true)
88 .withSRLRemoveWrongRefs(true).withSRLSelfArgFixing(true).build();
89 }
90
91 @Override
92 protected void doFilter(final KAFDocument document) {
93 this.filter.accept(document);
94 }
95
96 @Override
97 protected synchronized void doAdd(final KAFDocument document, final int sentence,
98 final Opinion[] opinions) {
99 addExpressions(document, sentence, opinions);
100 addArguments(document, sentence, opinions);
101 }
102
103 @Override
104 protected synchronized PipelineExtractor doTrain() throws IOException {
105
106
107
108
109 LinkLabeller holderLinkLabeller = null;
110 LinkLabeller targetLinkLabeller = null;
111 if (components().contains(Component.HOLDER)) {
112 LOGGER.info("====== Training holder link labeller ======");
113 holderLinkLabeller = this.holderLinkTrainer.end(this.linkGridSize, this.analyze);
114 }
115 if (components().contains(Component.TARGET)) {
116 LOGGER.info("====== Training target link labeller ======");
117 targetLinkLabeller = this.targetLinkTrainer.end(this.linkGridSize, this.analyze);
118 }
119
120
121 SpanLabeller holderSpanLabeller = null;
122 SpanLabeller targetSpanLabeller = null;
123 if (this.jointSpan) {
124 if (holderLinkLabeller != null || targetLinkLabeller != null) {
125 LOGGER.info("====== Training joint holder/target span labeller ======");
126 final SpanLabeller labeller = this.holderSpanTrainer.end(this.spanGridSize,
127 this.analyze, this.fastTrainer);
128 holderSpanLabeller = holderLinkLabeller == null ? null : labeller;
129 targetSpanLabeller = targetLinkLabeller == null ? null : labeller;
130 }
131 } else {
132 if (holderLinkLabeller != null) {
133 LOGGER.info("====== Training holder span labeller ======");
134 holderSpanLabeller = this.holderSpanTrainer.end(this.spanGridSize, this.analyze,
135 this.fastTrainer);
136 }
137 if (targetLinkLabeller != null) {
138 LOGGER.info("====== Training target span labeller ======");
139 targetSpanLabeller = this.targetSpanTrainer.end(this.spanGridSize, this.analyze,
140 this.fastTrainer);
141 }
142 }
143
144
145 return new PipelineExtractor(holderLinkLabeller, targetLinkLabeller, holderSpanLabeller,
146 targetSpanLabeller, this.holderUnique, this.targetUnique);
147 }
148
149 private void addExpressions(final KAFDocument document, final int sentence,
150 final Opinion[] opinions) {
151
152 }
153
154 private void addArguments(final KAFDocument document, final int sentence,
155 final Opinion[] opinions) {
156
157
158 final Set<Term> expressionHeads = Sets.newHashSet();
159 final Multimap<Term, Span<Term>> holderSpans = HashMultimap.create();
160 final Multimap<Term, Span<Term>> targetSpans = HashMultimap.create();
161 for (final Opinion opinion : opinions) {
162 final Set<Term> heads = Opinions.heads(document,
163 NAFUtils.normalizeSpan(document, opinion.getExpressionSpan()),
164 Component.EXPRESSION);
165 if (!heads.isEmpty()) {
166 final Term head = Ordering.from(Term.OFFSET_COMPARATOR).max(heads);
167 expressionHeads.add(head);
168 final Span<Term> holderSpan = opinion.getHolderSpan();
169 final Span<Term> targetSpan = opinion.getTargetSpan();
170 if (holderSpan != null) {
171 holderSpans.putAll(head, NAFUtils.splitSpan(document, holderSpan,
172 Opinions.heads(document, holderSpan, Component.HOLDER)));
173 }
174 if (targetSpan != null) {
175 targetSpans.putAll(head, NAFUtils.splitSpan(document, targetSpan,
176 Opinions.heads(document, targetSpan, Component.TARGET)));
177 }
178 }
179 }
180
181
182 for (final Term expressionHead : expressionHeads) {
183 if (components().contains(Component.HOLDER)) {
184 addArguments(document, sentence, expressionHead, holderSpans.get(expressionHead),
185 this.holderLinkTrainer, this.holderSpanTrainer);
186 }
187 if (components().contains(Component.TARGET)) {
188 addArguments(document, sentence, expressionHead, targetSpans.get(expressionHead),
189 this.targetLinkTrainer, this.targetSpanTrainer);
190 }
191 }
192 }
193
194 private void addArguments(final KAFDocument document, final int sentence,
195 final Term expressionHead, final Iterable<Span<Term>> argSpans,
196 final LinkLabeller.Trainer linkTrainer, final SpanLabeller.Trainer spanTrainer) {
197
198
199 final List<Term> heads = Lists.newArrayList();
200 final List<Span<Term>> spans = Lists.newArrayList();
201 for (final Span<Term> span : argSpans) {
202 final Term head = NAFUtils.extractHead(document, span);
203 if (head != null) {
204 heads.add(head);
205 spans.add(span);
206 }
207 }
208
209
210 linkTrainer.add(document, expressionHead, heads);
211
212
213 for (int i = 0; i < heads.size(); ++i) {
214 final List<Term> excludedTerms = Lists.newArrayList(heads);
215 excludedTerms.remove(heads.get(i));
216 spanTrainer.add(document, heads.get(i), excludedTerms, spans.get(i));
217 }
218 }
219
220 }