1   package eu.fbk.dkm.pikes.raid.pipeline;
2   
3   import java.io.IOException;
4   import java.util.List;
5   import java.util.Properties;
6   import java.util.Set;
7   
8   import javax.annotation.Nullable;
9   
10  import com.google.common.collect.HashMultimap;
11  import com.google.common.collect.Lists;
12  import com.google.common.collect.Multimap;
13  import com.google.common.collect.Ordering;
14  import com.google.common.collect.Sets;
15  
16  import org.slf4j.Logger;
17  import org.slf4j.LoggerFactory;
18  
19  import ixa.kaflib.KAFDocument;
20  import ixa.kaflib.Opinion;
21  import ixa.kaflib.Span;
22  import ixa.kaflib.Term;
23  
24  import eu.fbk.dkm.pikes.raid.Component;
25  import eu.fbk.dkm.pikes.raid.Opinions;
26  import eu.fbk.dkm.pikes.raid.Trainer;
27  import eu.fbk.dkm.pikes.resources.NAFFilter;
28  import eu.fbk.dkm.pikes.resources.NAFUtils;
29  
30  public final class PipelineTrainer extends Trainer<PipelineExtractor> {
31  
32      private static final Logger LOGGER = LoggerFactory.getLogger(PipelineTrainer.class);
33  
34      @Nullable
35      private final LinkLabeller.Trainer holderLinkTrainer;
36  
37      @Nullable
38      private final LinkLabeller.Trainer targetLinkTrainer;
39  
40      @Nullable
41      private final SpanLabeller.Trainer holderSpanTrainer;
42  
43      @Nullable
44      private final SpanLabeller.Trainer targetSpanTrainer;
45  
46      private final int linkGridSize;
47  
48      private final int spanGridSize;
49  
50      private final boolean analyze;
51  
52      private final boolean jointSpan;
53  
54      private final boolean holderUnique;
55  
56      private final boolean targetUnique;
57  
58      private final boolean fastTrainer;
59  
60      private final NAFFilter filter;
61  
62      public PipelineTrainer(final Properties properties, final Component... components) {
63          super(components);
64          final boolean hasHolder = components().contains(Component.HOLDER);
65          final boolean hasTarget = components().contains(Component.TARGET);
66          this.holderUnique = Boolean.parseBoolean(properties.getProperty("holder.unique", "false"));
67          this.targetUnique = Boolean.parseBoolean(properties.getProperty("target.unique", "false"));
68          this.fastTrainer = Boolean.parseBoolean(properties.getProperty("fast", "false"));
69          this.linkGridSize = Integer.parseInt(properties.getProperty("gridsize.link", "25"));
70          this.spanGridSize = Integer.parseInt(properties.getProperty("gridsize.span", "25"));
71          this.analyze = Boolean.parseBoolean(properties.getProperty("analyze", "true"));
72          this.jointSpan = Boolean.parseBoolean(properties.getProperty("joint", "false"));
73          this.holderLinkTrainer = hasHolder ? LinkLabeller.train("NN", "PRP", "JJP", "DTP", "WP")
74                  : null;
75          this.targetLinkTrainer = hasTarget ? LinkLabeller.train("NN", "PRP", "JJP", "DTP", "WP",
76                  "VB") : null;
77          if (this.jointSpan) {
78              final SpanLabeller.Trainer t = hasHolder || hasTarget ? SpanLabeller.train() : null;
79              this.holderSpanTrainer = t;
80              this.targetSpanTrainer = t;
81          } else {
82              this.holderSpanTrainer = hasHolder ? SpanLabeller.train() : null;
83              this.targetSpanTrainer = hasTarget ? SpanLabeller.train() : null;
84          }
85          this.filter = NAFFilter.builder(false).withTermSenseCompletion(true)
86                  .withEntityAddition(true).withEntityRemoveOverlaps(true)
87                  .withEntitySpanFixing(true).withSRLPredicateAddition(true)
88                  .withSRLRemoveWrongRefs(true).withSRLSelfArgFixing(true).build();
89      }
90  
91      @Override
92      protected void doFilter(final KAFDocument document) {
93          this.filter.accept(document);
94      }
95  
96      @Override
97      protected synchronized void doAdd(final KAFDocument document, final int sentence,
98              final Opinion[] opinions) {
99          addExpressions(document, sentence, opinions);
100         addArguments(document, sentence, opinions);
101     }
102 
103     @Override
104     protected synchronized PipelineExtractor doTrain() throws IOException {
105 
106         // TODO: Alessio
107 
108         // Train link labellers, if enabled
109         LinkLabeller holderLinkLabeller = null;
110         LinkLabeller targetLinkLabeller = null;
111         if (components().contains(Component.HOLDER)) {
112             LOGGER.info("====== Training holder link labeller ======");
113             holderLinkLabeller = this.holderLinkTrainer.end(this.linkGridSize, this.analyze);
114         }
115         if (components().contains(Component.TARGET)) {
116             LOGGER.info("====== Training target link labeller ======");
117             targetLinkLabeller = this.targetLinkTrainer.end(this.linkGridSize, this.analyze);
118         }
119 
120         // Train span labellers, if enabled
121         SpanLabeller holderSpanLabeller = null;
122         SpanLabeller targetSpanLabeller = null;
123         if (this.jointSpan) {
124             if (holderLinkLabeller != null || targetLinkLabeller != null) {
125                 LOGGER.info("====== Training joint holder/target span labeller ======");
126                 final SpanLabeller labeller = this.holderSpanTrainer.end(this.spanGridSize,
127                         this.analyze, this.fastTrainer);
128                 holderSpanLabeller = holderLinkLabeller == null ? null : labeller;
129                 targetSpanLabeller = targetLinkLabeller == null ? null : labeller;
130             }
131         } else {
132             if (holderLinkLabeller != null) {
133                 LOGGER.info("====== Training holder span labeller ======");
134                 holderSpanLabeller = this.holderSpanTrainer.end(this.spanGridSize, this.analyze,
135                         this.fastTrainer);
136             }
137             if (targetLinkLabeller != null) {
138                 LOGGER.info("====== Training target span labeller ======");
139                 targetSpanLabeller = this.targetSpanTrainer.end(this.spanGridSize, this.analyze,
140                         this.fastTrainer);
141             }
142         }
143 
144         // Build and return the resulting opinion extractor
145         return new PipelineExtractor(holderLinkLabeller, targetLinkLabeller, holderSpanLabeller,
146                 targetSpanLabeller, this.holderUnique, this.targetUnique);
147     }
148 
149     private void addExpressions(final KAFDocument document, final int sentence,
150             final Opinion[] opinions) {
151         // TODO: Alessio
152     }
153 
154     private void addArguments(final KAFDocument document, final int sentence,
155             final Opinion[] opinions) {
156 
157         // Index holder and target spans by expression head, keeping track of all exp. heads
158         final Set<Term> expressionHeads = Sets.newHashSet();
159         final Multimap<Term, Span<Term>> holderSpans = HashMultimap.create();
160         final Multimap<Term, Span<Term>> targetSpans = HashMultimap.create();
161         for (final Opinion opinion : opinions) {
162             final Set<Term> heads = Opinions.heads(document,
163                     NAFUtils.normalizeSpan(document, opinion.getExpressionSpan()),
164                     Component.EXPRESSION);
165             if (!heads.isEmpty()) {
166                 final Term head = Ordering.from(Term.OFFSET_COMPARATOR).max(heads);
167                 expressionHeads.add(head);
168                 final Span<Term> holderSpan = opinion.getHolderSpan();
169                 final Span<Term> targetSpan = opinion.getTargetSpan();
170                 if (holderSpan != null) {
171                     holderSpans.putAll(head, NAFUtils.splitSpan(document, holderSpan, //
172                             Opinions.heads(document, holderSpan, Component.HOLDER)));
173                 }
174                 if (targetSpan != null) {
175                     targetSpans.putAll(head, NAFUtils.splitSpan(document, targetSpan, //
176                             Opinions.heads(document, targetSpan, Component.TARGET)));
177                 }
178             }
179         }
180 
181         // Add training samples for holder and target extraction, separately (if enabled)
182         for (final Term expressionHead : expressionHeads) {
183             if (components().contains(Component.HOLDER)) {
184                 addArguments(document, sentence, expressionHead, holderSpans.get(expressionHead),
185                         this.holderLinkTrainer, this.holderSpanTrainer);
186             }
187             if (components().contains(Component.TARGET)) {
188                 addArguments(document, sentence, expressionHead, targetSpans.get(expressionHead),
189                         this.targetLinkTrainer, this.targetSpanTrainer);
190             }
191         }
192     }
193 
194     private void addArguments(final KAFDocument document, final int sentence,
195             final Term expressionHead, final Iterable<Span<Term>> argSpans,
196             final LinkLabeller.Trainer linkTrainer, final SpanLabeller.Trainer spanTrainer) {
197 
198         // Extract heads and spans of the arguments (only where defined)
199         final List<Term> heads = Lists.newArrayList();
200         final List<Span<Term>> spans = Lists.newArrayList();
201         for (final Span<Term> span : argSpans) {
202             final Term head = NAFUtils.extractHead(document, span);
203             if (head != null) {
204                 heads.add(head);
205                 spans.add(span);
206             }
207         }
208 
209         // Add a sample for node labelling
210         linkTrainer.add(document, expressionHead, heads);
211 
212         // Add samples for span labelling (one for each argument)
213         for (int i = 0; i < heads.size(); ++i) {
214             final List<Term> excludedTerms = Lists.newArrayList(heads);
215             excludedTerms.remove(heads.get(i));
216             spanTrainer.add(document, heads.get(i), excludedTerms, spans.get(i));
217         }
218     }
219 
220 }