1   package eu.fbk.dkm.pikes.raid;
2   
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.nio.file.Path;
import java.util.Collections;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.List;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.stream.StreamSupport;
15  
16  import javax.annotation.Nullable;
17  
18  import com.google.common.base.Preconditions;
19  import com.google.common.base.Splitter;
20  import com.google.common.base.Throwables;
21  import com.google.common.collect.ArrayListMultimap;
22  import com.google.common.collect.ImmutableList;
23  import com.google.common.collect.ImmutableSet;
24  import com.google.common.collect.Iterables;
25  import com.google.common.collect.ListMultimap;
26  import com.google.common.collect.Lists;
27  
28  import org.slf4j.Logger;
29  import org.slf4j.LoggerFactory;
30  
31  import ixa.kaflib.KAFDocument;
32  import ixa.kaflib.Opinion;
33  import ixa.kaflib.Opinion.OpinionExpression;
34  
35  import eu.fbk.dkm.pikes.naflib.Corpus;
36  import eu.fbk.dkm.pikes.resources.NAFUtils;
37  import eu.fbk.dkm.pikes.resources.WordNet;
38  import eu.fbk.utils.core.CommandLine;
39  import eu.fbk.utils.core.CommandLine.Type;
40  import eu.fbk.utils.svm.Util;
41  import eu.fbk.rdfpro.util.Tracker;
42  
43  public abstract class Trainer<T extends Extractor> {
44  
45      private static final Logger LOGGER = LoggerFactory.getLogger(Trainer.class);
46  
47      private final EnumSet<Component> components;
48  
49      private final ReadWriteLock lock;
50  
51      private final AtomicInteger numOpinions;
52  
53      private boolean trained;
54  
55      protected Trainer(final Component... components) {
56          this.components = Component.toSet(components);
57          this.lock = new ReentrantReadWriteLock(false);
58          this.numOpinions = new AtomicInteger(0);
59          this.trained = false;
60      }
61  
62      public final Set<Component> components() {
63          return this.components;
64      }
65  
66      public final Trainer<T> add(final Iterable<KAFDocument> documents,
67              @Nullable final Iterable<String> goldLabels) {
68  
69          // Process all the documents using parallelization, holding a read lock meanwhile
70          this.lock.readLock().lock();
71          try {
72              checkNotTrained();
73              StreamSupport.stream(documents.spliterator(), true).forEach(document -> {
74                  Preconditions.checkNotNull(document);
75                  doAdd(document, goldLabels);
76              });
77          } finally {
78              this.lock.readLock().unlock();
79          }
80          return this;
81      }
82  
83      public final Trainer<T> add(final KAFDocument document,
84              @Nullable final Iterable<String> goldLabels) {
85  
86          // Validate and process the document, holding a read lock meanwhile
87          Preconditions.checkNotNull(document);
88          this.lock.readLock().lock();
89          try {
90              checkNotTrained();
91              doAdd(document, goldLabels);
92          } finally {
93              this.lock.readLock().unlock();
94          }
95          return this;
96      }
97  
98      public final T train() {
99  
100         // Complete training ensuring that no add() methods are active meanwhile
101         this.lock.writeLock().lock();
102         try {
103             checkNotTrained();
104             LOGGER.info("Extracted {} opinions", this.numOpinions.get());
105             return doTrain();
106         } catch (final Throwable ex) {
107             throw Throwables.propagate(ex);
108         } finally {
109             this.trained = true;
110             this.lock.writeLock().unlock();
111         }
112     }
113 
114     private void doAdd(final KAFDocument document, @Nullable final Iterable<String> goldLabels) {
115 
116         // Filter the document
117         doFilter(document);
118 
119         // Normalize (split) input opinions and index the resulting opinions by sentence
120         final ListMultimap<Integer, Opinion> opinionsBySentence = ArrayListMultimap.create();
121         synchronized (document) {
122             List<Opinion> opinions;
123             if (goldLabels == null || Iterables.isEmpty(goldLabels)) {
124                 opinions = Lists.newArrayList(document.getOpinions());
125             } else {
126                 opinions = Lists.newArrayList();
127                 for (final String goldLabel : goldLabels) {
128                     opinions.addAll(document.getOpinions(goldLabel));
129                 }
130             }
131             // TODO: this is an hack to deal with VUA non-opinionated fake opinions
132             for (final Iterator<Opinion> i = opinions.iterator(); i.hasNext();) {
133                 final Opinion opinion = i.next();
134                 if (opinion.getPolarity() != null
135                         && opinion.getPolarity().equalsIgnoreCase("NON-OPINIONATED")) {
136                     i.remove();
137                     LOGGER.info("Skipping non-opinionated opinion {}", opinion.getId());
138                 }
139             }
140             for (final Opinion opinion : opinions) {
141                 final OpinionExpression exp = opinion.getOpinionExpression();
142                 opinionsBySentence.put(exp.getSpan().getTargets().get(0).getSent(), opinion);
143             }
144         }
145 
146         // Perform training, processing all the sentences in the document even if without opinions
147         final int numSentences = document.getNumSentences();
148         for (int sentID = 1; sentID <= numSentences; ++sentID) {
149 
150             // Extract all the opinions in the sentence
151             final Opinion opinions[] = opinionsBySentence.get(sentID).toArray(new Opinion[0]);
152 
153             // Perform training
154             try {
155                 doAdd(document, sentID, opinions);
156                 this.numOpinions.addAndGet(opinions.length);
157             } catch (final Throwable ex) {
158                 Throwables.propagate(ex);
159             }
160         }
161     }
162 
163     private void checkNotTrained() {
164         Preconditions.checkState(!this.trained, "Training already completed");
165     }
166 
167     protected void doFilter(final KAFDocument document) {
168         // can be overridden by subclasses
169     }
170 
171     protected abstract void doAdd(KAFDocument document, int sentence, Opinion[] opinions)
172             throws Throwable;
173 
174     protected abstract T doTrain() throws Throwable;
175 
176     @SuppressWarnings("unchecked")
177     public static Trainer<? extends Extractor> create(final Properties properties,
178             final Component... components) {
179 
180         // Select the implementation class and delegate to its constructor
181         String implementationName = properties.getProperty("class");
182         if (implementationName == null) {
183             implementationName = "eu.fbk.dkm.pikes.raid.pipeline.PipelineTrainer"; // default
184         }
185         try {
186             final Class<?> implementationClass = Class.forName(implementationName);
187             final Constructor<?> constructor = implementationClass.getDeclaredConstructor(
188                     Properties.class, Component[].class);
189             constructor.setAccessible(true);
190             return (Trainer<? extends Extractor>) constructor.newInstance(properties, components);
191         } catch (final InvocationTargetException ex) {
192             throw Throwables.propagate(ex);
193         } catch (final NoSuchMethodException | ClassNotFoundException | IllegalAccessException
194                 | InstantiationException ex) {
195             throw new IllegalArgumentException("Could not instantiate class " + implementationName);
196         }
197     }
198 
199     public static void main(final String... args) {
200 
201         try {
202             // Parse command line
203             final CommandLine cmd = CommandLine
204                     .parser()
205                     .withName("fssa-train")
206                     .withHeader(
207                             "Train the extractor of opinion expressions, holders and targets "
208                                     + "given a set of annotated NAF files.")
209                     .withOption("p", "properties", "a sequence of key=value properties, used to " //
210                             + "select and configure the trainer", "PROPS", Type.STRING, true,
211                             false, false)
212                     .withOption("c", "components", "the opinion components to consider: " //
213                             + "(e)xpression, (h)older, (t)arget, (p)olarity", "COMP", Type.STRING,
214                             true, false, false)
215                     .withOption("l", "labels", "the labels of gold opinions to consider, comma " //
216                             + "separated  (no spaces)", "LABELS", Type.STRING, true, false, false)
217                     .withOption("r", "recursive",
218                             "recurse into subdirectories of specified input paths")
219                     .withOption("@", "list",
220                             "interprets input as list of file names, one per line")
221                     .withOption(null, "wordnet", "wordnet dict path", "PATH",
222                             Type.DIRECTORY_EXISTING, true, false, false)
223                     .withOption("o", "output", "the output model file", "FILE", Type.FILE, true,
224                             false, false)
225                     .withOption("s", "split",
226                             "splits the supplied NAF files based on the supplied " //
227                                     + "seed:ratio spec, using only the first part for training",
228                             "RATIO", Type.STRING, true, false, false)
229                     .withFooter(
230                             "Zero or more input paths can be specified, corresponding either "
231                                     + "to NAF files or directories that are scanned for NAF "
232                                     + "files. If the list is empty, an input NAF file will be "
233                                     + "read from the standard input. If no output path is "
234                                     + "specified (-o), the model is written to standard output.")
235                     .withLogger(LoggerFactory.getLogger("eu.fbk")) //
236                     .parse(args);
237 
238             // Extract options
239             final Properties properties = Util.parseProperties(cmd.getOptionValue("p",
240                     String.class, ""));
241             final Component[] components = Component.forLetters(
242                     cmd.getOptionValue("c", String.class, "")).toArray(new Component[0]);
243             final Set<String> labels = ImmutableSet.copyOf(Splitter.on(',').omitEmptyStrings()
244                     .split(cmd.getOptionValue("l", String.class, "")));
245             final boolean recursive = cmd.hasOption("r");
246             final boolean list = cmd.hasOption("@");
247             final Path outputPath = cmd.getOptionValue("o", Path.class, null);
248             final String split = cmd.getOptionValue("s", String.class);
249             final List<Path> inputPaths = Lists.newArrayList(cmd.getArgs(Path.class));
250 
251             final String wordnetPath = cmd.getOptionValue("wordnet", String.class);
252             if (wordnetPath != null) {
253                 WordNet.setPath(wordnetPath);
254             }
255 
256             // Setup the trainer
257             final Trainer<? extends Extractor> trainer = create(properties, components);
258 
259             // Identify input
260             final List<Path> files = Util.fileMatch(inputPaths, ImmutableList.of(".naf",
261                     ".naf.gz", ".naf.bz2", ".naf.xz", ".xml", ".xml.gz", ".xml.bz2", ".xml.xz"),
262                     recursive, list);
263             Iterable<KAFDocument> documents = files != null ? Corpus.create(false, files)
264                     : ImmutableList.of(NAFUtils.readDocument(null));
265 
266             // Split training set, if required
267             if (split != null && documents instanceof Corpus) {
268                 final int index = split.indexOf(':');
269                 long seed = 0;
270                 float ratio = 1.0f;
271                 if (index >= 0) {
272                     seed = Long.parseLong(split.substring(0, index));
273                     ratio = Float.parseFloat(split.substring(index + 1));
274                 } else {
275                     ratio = Float.parseFloat(split);
276                 }
277                 final Corpus corpus = (Corpus) documents;
278                 documents = corpus.split(seed, ratio, 1.0f - ratio)[0];
279             }
280 
281             // Perform the extraction
282             final Tracker tracker = new Tracker(LOGGER, null, //
283                     "Processed %d NAF files (%d NAF/s avg)", //
284                     "Processed %d NAF files (%d NAF/s, %d NAF/s avg)");
285             tracker.start();
286             StreamSupport.stream(documents.spliterator(), false).forEach(
287                     (final KAFDocument document) -> {
288                         trainer.add(document, labels);
289                         tracker.increment();
290                     });
291             tracker.end();
292 
293             // Complete the training and save the model
294             final Extractor extractor = trainer.train();
295             extractor.writeTo(outputPath);
296 
297         } catch (final Throwable ex) {
298             CommandLine.fail(ex);
299         }
300     }
301 }