1 package eu.fbk.dkm.pikes.raid;
2
3 import java.lang.reflect.Constructor;
4 import java.lang.reflect.InvocationTargetException;
5 import java.nio.file.Path;
6 import java.util.EnumSet;
7 import java.util.Iterator;
8 import java.util.List;
9 import java.util.Properties;
10 import java.util.Set;
11 import java.util.concurrent.atomic.AtomicInteger;
12 import java.util.concurrent.locks.ReadWriteLock;
13 import java.util.concurrent.locks.ReentrantReadWriteLock;
14 import java.util.stream.StreamSupport;
15
16 import javax.annotation.Nullable;
17
18 import com.google.common.base.Preconditions;
19 import com.google.common.base.Splitter;
20 import com.google.common.base.Throwables;
21 import com.google.common.collect.ArrayListMultimap;
22 import com.google.common.collect.ImmutableList;
23 import com.google.common.collect.ImmutableSet;
24 import com.google.common.collect.Iterables;
25 import com.google.common.collect.ListMultimap;
26 import com.google.common.collect.Lists;
27
28 import org.slf4j.Logger;
29 import org.slf4j.LoggerFactory;
30
31 import ixa.kaflib.KAFDocument;
32 import ixa.kaflib.Opinion;
33 import ixa.kaflib.Opinion.OpinionExpression;
34
35 import eu.fbk.dkm.pikes.naflib.Corpus;
36 import eu.fbk.dkm.pikes.resources.NAFUtils;
37 import eu.fbk.dkm.pikes.resources.WordNet;
38 import eu.fbk.utils.core.CommandLine;
39 import eu.fbk.utils.core.CommandLine.Type;
40 import eu.fbk.utils.svm.Util;
41 import eu.fbk.rdfpro.util.Tracker;
42
43 public abstract class Trainer<T extends Extractor> {
44
45 private static final Logger LOGGER = LoggerFactory.getLogger(Trainer.class);
46
47 private final EnumSet<Component> components;
48
49 private final ReadWriteLock lock;
50
51 private final AtomicInteger numOpinions;
52
53 private boolean trained;
54
55 protected Trainer(final Component... components) {
56 this.components = Component.toSet(components);
57 this.lock = new ReentrantReadWriteLock(false);
58 this.numOpinions = new AtomicInteger(0);
59 this.trained = false;
60 }
61
62 public final Set<Component> components() {
63 return this.components;
64 }
65
66 public final Trainer<T> add(final Iterable<KAFDocument> documents,
67 @Nullable final Iterable<String> goldLabels) {
68
69
70 this.lock.readLock().lock();
71 try {
72 checkNotTrained();
73 StreamSupport.stream(documents.spliterator(), true).forEach(document -> {
74 Preconditions.checkNotNull(document);
75 doAdd(document, goldLabels);
76 });
77 } finally {
78 this.lock.readLock().unlock();
79 }
80 return this;
81 }
82
83 public final Trainer<T> add(final KAFDocument document,
84 @Nullable final Iterable<String> goldLabels) {
85
86
87 Preconditions.checkNotNull(document);
88 this.lock.readLock().lock();
89 try {
90 checkNotTrained();
91 doAdd(document, goldLabels);
92 } finally {
93 this.lock.readLock().unlock();
94 }
95 return this;
96 }
97
98 public final T train() {
99
100
101 this.lock.writeLock().lock();
102 try {
103 checkNotTrained();
104 LOGGER.info("Extracted {} opinions", this.numOpinions.get());
105 return doTrain();
106 } catch (final Throwable ex) {
107 throw Throwables.propagate(ex);
108 } finally {
109 this.trained = true;
110 this.lock.writeLock().unlock();
111 }
112 }
113
114 private void doAdd(final KAFDocument document, @Nullable final Iterable<String> goldLabels) {
115
116
117 doFilter(document);
118
119
120 final ListMultimap<Integer, Opinion> opinionsBySentence = ArrayListMultimap.create();
121 synchronized (document) {
122 List<Opinion> opinions;
123 if (goldLabels == null || Iterables.isEmpty(goldLabels)) {
124 opinions = Lists.newArrayList(document.getOpinions());
125 } else {
126 opinions = Lists.newArrayList();
127 for (final String goldLabel : goldLabels) {
128 opinions.addAll(document.getOpinions(goldLabel));
129 }
130 }
131
132 for (final Iterator<Opinion> i = opinions.iterator(); i.hasNext();) {
133 final Opinion opinion = i.next();
134 if (opinion.getPolarity() != null
135 && opinion.getPolarity().equalsIgnoreCase("NON-OPINIONATED")) {
136 i.remove();
137 LOGGER.info("Skipping non-opinionated opinion {}", opinion.getId());
138 }
139 }
140 for (final Opinion opinion : opinions) {
141 final OpinionExpression exp = opinion.getOpinionExpression();
142 opinionsBySentence.put(exp.getSpan().getTargets().get(0).getSent(), opinion);
143 }
144 }
145
146
147 final int numSentences = document.getNumSentences();
148 for (int sentID = 1; sentID <= numSentences; ++sentID) {
149
150
151 final Opinion opinions[] = opinionsBySentence.get(sentID).toArray(new Opinion[0]);
152
153
154 try {
155 doAdd(document, sentID, opinions);
156 this.numOpinions.addAndGet(opinions.length);
157 } catch (final Throwable ex) {
158 Throwables.propagate(ex);
159 }
160 }
161 }
162
163 private void checkNotTrained() {
164 Preconditions.checkState(!this.trained, "Training already completed");
165 }
166
167 protected void doFilter(final KAFDocument document) {
168
169 }
170
171 protected abstract void doAdd(KAFDocument document, int sentence, Opinion[] opinions)
172 throws Throwable;
173
174 protected abstract T doTrain() throws Throwable;
175
176 @SuppressWarnings("unchecked")
177 public static Trainer<? extends Extractor> create(final Properties properties,
178 final Component... components) {
179
180
181 String implementationName = properties.getProperty("class");
182 if (implementationName == null) {
183 implementationName = "eu.fbk.dkm.pikes.raid.pipeline.PipelineTrainer";
184 }
185 try {
186 final Class<?> implementationClass = Class.forName(implementationName);
187 final Constructor<?> constructor = implementationClass.getDeclaredConstructor(
188 Properties.class, Component[].class);
189 constructor.setAccessible(true);
190 return (Trainer<? extends Extractor>) constructor.newInstance(properties, components);
191 } catch (final InvocationTargetException ex) {
192 throw Throwables.propagate(ex);
193 } catch (final NoSuchMethodException | ClassNotFoundException | IllegalAccessException
194 | InstantiationException ex) {
195 throw new IllegalArgumentException("Could not instantiate class " + implementationName);
196 }
197 }
198
199 public static void main(final String... args) {
200
201 try {
202
203 final CommandLine cmd = CommandLine
204 .parser()
205 .withName("fssa-train")
206 .withHeader(
207 "Train the extractor of opinion expressions, holders and targets "
208 + "given a set of annotated NAF files.")
209 .withOption("p", "properties", "a sequence of key=value properties, used to "
210 + "select and configure the trainer", "PROPS", Type.STRING, true,
211 false, false)
212 .withOption("c", "components", "the opinion components to consider: "
213 + "(e)xpression, (h)older, (t)arget, (p)olarity", "COMP", Type.STRING,
214 true, false, false)
215 .withOption("l", "labels", "the labels of gold opinions to consider, comma "
216 + "separated (no spaces)", "LABELS", Type.STRING, true, false, false)
217 .withOption("r", "recursive",
218 "recurse into subdirectories of specified input paths")
219 .withOption("@", "list",
220 "interprets input as list of file names, one per line")
221 .withOption(null, "wordnet", "wordnet dict path", "PATH",
222 Type.DIRECTORY_EXISTING, true, false, false)
223 .withOption("o", "output", "the output model file", "FILE", Type.FILE, true,
224 false, false)
225 .withOption("s", "split",
226 "splits the supplied NAF files based on the supplied "
227 + "seed:ratio spec, using only the first part for training",
228 "RATIO", Type.STRING, true, false, false)
229 .withFooter(
230 "Zero or more input paths can be specified, corresponding either "
231 + "to NAF files or directories that are scanned for NAF "
232 + "files. If the list is empty, an input NAF file will be "
233 + "read from the standard input. If no output path is "
234 + "specified (-o), the model is written to standard output.")
235 .withLogger(LoggerFactory.getLogger("eu.fbk"))
236 .parse(args);
237
238
239 final Properties properties = Util.parseProperties(cmd.getOptionValue("p",
240 String.class, ""));
241 final Component[] components = Component.forLetters(
242 cmd.getOptionValue("c", String.class, "")).toArray(new Component[0]);
243 final Set<String> labels = ImmutableSet.copyOf(Splitter.on(',').omitEmptyStrings()
244 .split(cmd.getOptionValue("l", String.class, "")));
245 final boolean recursive = cmd.hasOption("r");
246 final boolean list = cmd.hasOption("@");
247 final Path outputPath = cmd.getOptionValue("o", Path.class, null);
248 final String split = cmd.getOptionValue("s", String.class);
249 final List<Path> inputPaths = Lists.newArrayList(cmd.getArgs(Path.class));
250
251 final String wordnetPath = cmd.getOptionValue("wordnet", String.class);
252 if (wordnetPath != null) {
253 WordNet.setPath(wordnetPath);
254 }
255
256
257 final Trainer<? extends Extractor> trainer = create(properties, components);
258
259
260 final List<Path> files = Util.fileMatch(inputPaths, ImmutableList.of(".naf",
261 ".naf.gz", ".naf.bz2", ".naf.xz", ".xml", ".xml.gz", ".xml.bz2", ".xml.xz"),
262 recursive, list);
263 Iterable<KAFDocument> documents = files != null ? Corpus.create(false, files)
264 : ImmutableList.of(NAFUtils.readDocument(null));
265
266
267 if (split != null && documents instanceof Corpus) {
268 final int index = split.indexOf(':');
269 long seed = 0;
270 float ratio = 1.0f;
271 if (index >= 0) {
272 seed = Long.parseLong(split.substring(0, index));
273 ratio = Float.parseFloat(split.substring(index + 1));
274 } else {
275 ratio = Float.parseFloat(split);
276 }
277 final Corpus corpus = (Corpus) documents;
278 documents = corpus.split(seed, ratio, 1.0f - ratio)[0];
279 }
280
281
282 final Tracker tracker = new Tracker(LOGGER, null,
283 "Processed %d NAF files (%d NAF/s avg)",
284 "Processed %d NAF files (%d NAF/s, %d NAF/s avg)");
285 tracker.start();
286 StreamSupport.stream(documents.spliterator(), false).forEach(
287 (final KAFDocument document) -> {
288 trainer.add(document, labels);
289 tracker.increment();
290 });
291 tracker.end();
292
293
294 final Extractor extractor = trainer.train();
295 extractor.writeTo(outputPath);
296
297 } catch (final Throwable ex) {
298 CommandLine.fail(ex);
299 }
300 }
301 }