1   package eu.fbk.dkm.pikes.resources;
2   
3   import com.google.common.base.Charsets;
4   import com.google.common.base.Joiner;
5   import com.google.common.base.Strings;
6   import com.google.common.collect.*;
7   import com.google.common.io.Resources;
8   import eu.fbk.utils.core.CommandLine;
9   import eu.fbk.utils.core.CommandLine.Type;
10  import eu.fbk.utils.core.StaxParser;
11  import org.slf4j.LoggerFactory;
12  
13  import javax.annotation.Nullable;
14  import javax.xml.stream.XMLStreamException;
15  import java.io.*;
16  import java.util.*;
17  import java.util.regex.Matcher;
18  import java.util.regex.Pattern;
19  
20  public class NomBank {
21  
22      private static final List<Roleset> ROLESETS;
23  
24      private static final Map<String, Roleset> ID_INDEX;
25  
26      private static final ListMultimap<String, Roleset> LEMMA_INDEX;
27  
28      private static final ListMultimap<String, Roleset> PB_ID_INDEX;
29  
30      static {
31          try {
32              final Map<String, Roleset> idIndex = Maps.newLinkedHashMap();
33              final ListMultimap<String, Roleset> lemmaIndex = ArrayListMultimap.create();
34              final ListMultimap<String, Roleset> pbIdIndex = ArrayListMultimap.create();
35  
36              final BufferedReader reader = Resources.asCharSource(
37                      NomBank.class.getResource("NomBank.tsv"), Charsets.UTF_8).openBufferedStream();
38  
39              String line;
40              while ((line = reader.readLine()) != null) {
41                  final String[] tokens = line.split("\t");
42                  final String id = tokens[0];
43                  final String pbId = Strings.emptyToNull(tokens[1]);
44                  final String lemma = tokens[2];
45                  final String descr = tokens[3];
46                  final String[] argDescr = Arrays.copyOfRange(tokens, 4, 13);
47                  byte[] argPBNums = null;
48                  if (pbId != null) {
49                      argPBNums = new byte[argDescr.length];
50                      for (int i = 0; i < argPBNums.length; ++i) {
51                          argPBNums[i] = Byte.parseByte(tokens[14 + i]);
52                      }
53                  }
54                  final List<Integer> mandatoryArgs = Lists.newArrayList();
55                  final List<Integer> optionalArgs = Lists.newArrayList();
56                  if (tokens.length > 24 && !tokens[24].equals("")) {
57                      for (final String arg : Ordering.natural().sortedCopy(
58                              Arrays.asList(tokens[24].split("\\s+")))) {
59                          mandatoryArgs.add(Integer.parseInt(arg));
60                      }
61                  }
62                  if (tokens.length > 25 && !tokens[25].equals("")) {
63                      for (final String arg : Ordering.natural().sortedCopy(
64                              Arrays.asList(tokens[25].split("\\s+")))) {
65                          optionalArgs.add(Integer.parseInt(arg));
66                      }
67                  }
68                  final Roleset roleset = new Roleset(id, pbId, lemma, descr, argPBNums, argDescr,
69                          mandatoryArgs, optionalArgs);
70                  idIndex.put(id, roleset);
71                  lemmaIndex.put(lemma, roleset);
72                  if (pbId != null) {
73                      pbIdIndex.put(pbId, roleset);
74                  }
75              }
76  
77              reader.close();
78  
79              ROLESETS = ImmutableList.copyOf(idIndex.values());
80              ID_INDEX = ImmutableMap.copyOf(idIndex);
81              LEMMA_INDEX = ImmutableListMultimap.copyOf(lemmaIndex);
82              PB_ID_INDEX = ImmutableListMultimap.copyOf(pbIdIndex);
83  
84          } catch (final IOException ex) {
85              throw new Error("Cannot load eu.fbk.dkm.pikes.resources.PropBank data", ex);
86          }
87      }
88  
89      public static Set<String> getIds() {
90          return ID_INDEX.keySet();
91      }
92  
93      public static Set<String> getLemmas() {
94          return LEMMA_INDEX.keySet();
95      }
96  
97      @Nullable
98      public static Roleset getRoleset(@Nullable final String id) {
99          return ID_INDEX.get(id.toLowerCase());
100     }
101 
102     public static List<Roleset> getRolesetsForLemma(@Nullable final String lemma) {
103         if (lemma == null) {
104             return new ArrayList<>();
105         }
106         return LEMMA_INDEX.get(lemma.toLowerCase());
107     }
108 
109     public static List<Roleset> getRolesetsForPBId(final String pbId) {
110         return PB_ID_INDEX.get(pbId.toLowerCase());
111     }
112 
113     public static List<Roleset> getRolesets() {
114         return ROLESETS;
115     }
116 
117     public static void main(final String[] args) throws IOException, XMLStreamException {
118 
119         try {
120             final CommandLine cmd = CommandLine
121                     .parser()
122                     .withName("eu.fbk.dkm.pikes.resources.NomBank")
123                     .withHeader("Generate a TSV file with indexed eu.fbk.dkm.pikes.resources.NomBank data")
124                     .withOption("f", "frames", "the directory containing frame definitions",
125                             "DIR", Type.DIRECTORY_EXISTING, true, false, true)
126                     .withOption("a", "annotations",
127                             "the eu.fbk.dkm.pikes.resources.NomBank annotation file (e.g., eu.fbk.dkm.pikes.resources.NomBank.1.0)", "FILE",
128                             Type.FILE_EXISTING, true, false, true)
129                     .withOption("o", "output", "output file", "FILE", Type.FILE, true, false, true)
130                     .withLogger(LoggerFactory.getLogger("eu.fbk.nafview")).parse(args);
131 
132             final File dir = cmd.getOptionValue("f", File.class);
133             final File annotations = cmd.getOptionValue("a", File.class);
134             final File output = cmd.getOptionValue("o", File.class);
135 
136             final Writer writer = new OutputStreamWriter(new BufferedOutputStream(
137                     new FileOutputStream(output)), Charsets.UTF_8);
138 
139             final File[] files = dir.listFiles();
140             Arrays.sort(files);
141 
142             final Map<String, Multiset<Integer>> roles = getPredicateRoles(annotations);
143 
144             // Manual corrections due to lack of samples
145             addPredicateRole(roles, "1-slash-10th.01", -1);
146             addPredicateRole(roles, "bagger.01", 0);
147             addPredicateRole(roles, "bearer.01", 0);
148             addPredicateRole(roles, "being.01", 0);
149             addPredicateRole(roles, "being.01", -1);
150             addPredicateRole(roles, "caliber.01", 2);
151             addPredicateRole(roles, "calling.01", -1);
152             addPredicateRole(roles, "clogging.02", -1);
153             addPredicateRole(roles, "counting.01", -1);
154             addPredicateRole(roles, "crusher.01", 0);
155             addPredicateRole(roles, "doer.01", 0);
156             addPredicateRole(roles, "dropper.01", -1);
157             addPredicateRole(roles, "esteem.01", -1);
158             addPredicateRole(roles, "fidelity.01", -1);
159             addPredicateRole(roles, "finder.01", 0);
160             addPredicateRole(roles, "getter.01", 0);
161             addPredicateRole(roles, "goer.01", 0);
162             addPredicateRole(roles, "grinder.01", 0);
163             addPredicateRole(roles, "implant.01", -1);
164             addPredicateRole(roles, "incrimination.01", -1);
165             addPredicateRole(roles, "interdiction.01", -1);
166             addPredicateRole(roles, "kicker.03", 0);
167             addPredicateRole(roles, "purification.01", -1);
168             addPredicateRole(roles, "purity.01", -1);
169             addPredicateRole(roles, "starter.01", 0);
170             addPredicateRole(roles, "stocking.01", -1);
171             addPredicateRole(roles, "tech.01", -1);
172             addPredicateRole(roles, "tilth.01", -1);
173             addPredicateRole(roles, "trick.02", -1);
174             addPredicateRole(roles, "tuning.01", -1);
175 
176             for (final File file : files) {
177                 if (file.getName().endsWith(".xml")) {
178                     System.out.println("Processing " + file);
179                     final Reader reader = new BufferedReader(new FileReader(file));
180                     try {
181                         new Parser(reader, roles).parse(writer);
182                     } finally {
183                         reader.close();
184                     }
185                 }
186             }
187 
188         } catch (final Throwable ex) {
189             CommandLine.fail(ex);
190         }
191     }
192 
193     private static void addPredicateRole(final Map<String, Multiset<Integer>> map,
194             final String sense, final int arg) {
195         Multiset<Integer> set = map.get(sense);
196         if (set == null) {
197             set = HashMultiset.create();
198             map.put(sense, set);
199         }
200         set.add(-1);
201         if (arg != -1) {
202             set.add(arg);
203         }
204     }
205 
206     private static Map<String, Multiset<Integer>> getPredicateRoles(final File annotations)
207             throws IOException {
208         // eg: wsj/22/wsj_2278.mrg 26 18 prisoner 01 9:1-ARG0 12:0,14:0-Support 18:0-ARG1 18:0-rel
209         final Pattern rolePattern = Pattern.compile(".*-ARG(\\d).*");
210         final Map<String, Multiset<Integer>> map = Maps.newHashMap();
211         final BufferedReader reader = new BufferedReader(new FileReader(annotations));
212         try {
213             String line;
214             int count = 0;
215             while ((line = reader.readLine()) != null) {
216                 ++count;
217                 final String[] tokens = line.split("\\s+");
218                 final String sense = tokens[3] + "." + tokens[4];
219                 Multiset<Integer> set = map.get(sense);
220                 if (set == null) {
221                     set = HashMultiset.create();
222                     map.put(sense, set);
223                 }
224                 String relIndex = null;
225                 for (int i = 5; i < tokens.length; ++i) {
226                     final String t = tokens[i].toUpperCase();
227                     if (t.endsWith("-REL")) {
228                         relIndex = t.substring(0, t.length() - 4);
229                     }
230                     if (t.contains("-H")) {
231                         relIndex = null; // multi-words detected: skip sample
232                         break;
233                     }
234                 }
235                 if (relIndex != null) {
236                     set.add(-1); // used for counting total sense occurrences
237                     for (int i = 5; i < tokens.length; ++i) {
238                         final String t = tokens[i].toUpperCase();
239                         if (t.startsWith(relIndex)) {
240                             final Matcher matcher = rolePattern.matcher(t);
241                             if (matcher.matches()) {
242                                 set.add(Integer.parseInt(matcher.group(1)));
243                             }
244                         }
245                     }
246                 }
247             }
248             System.out.println(count + " annotated propositions parsed, " + map.keySet().size()
249                     + " senses found");
250         } finally {
251             reader.close();
252         }
253         return map;
254     }
255 
256     private static class Parser extends StaxParser {
257 
258         private final Map<String, Multiset<Integer>> roles;
259 
260         Parser(final Reader reader, final Map<String, Multiset<Integer>> roles) {
261             super(reader);
262             this.roles = roles;
263         }
264 
265         void parse(final Writer writer) throws IOException, XMLStreamException {
266             final Pattern rolePattern = Pattern.compile("(\\d).*");
267             enter("frameset");
268             while (tryEnter("predicate")) {
269                 final String lemma = attribute("lemma").trim().replace('_', ' ').toLowerCase();
270                 while (tryEnter("roleset")) {
271                     final String id = attribute("id").trim();
272                     Multiset<Integer> set = this.roles.get(id);
273                     if (set == null) {
274                         set = HashMultiset.create();
275                         this.roles.put(id, set);
276                     }
277                     String pbId = attribute("source");
278                     if (pbId != null && pbId.startsWith("verb-")) {
279                         pbId = pbId.substring(5);
280                     }
281                     final String name = attribute("name").trim();
282                     final String[] argDescr = new String[10];
283                     final byte[] argPBNums = new byte[10];
284                     Arrays.fill(argPBNums, (byte) -1);
285                     if (tryEnter("roles")) {
286                         while (tryEnter("role")) {
287                             try {
288                                 final int n = Integer.parseInt(attribute("n"));
289                                 argDescr[n] = attribute("descr").trim();
290                                 if (pbId != null) {
291                                     final String pbNum = attribute("source");
292                                     argPBNums[n] = pbNum == null ? (byte) n : Byte
293                                             .parseByte(pbNum);
294                                 }
295                             } catch (final NumberFormatException ex) {
296                                 // ignore
297                             }
298                             leave();
299                         }
300                         leave();
301                     }
302                     while (tryEnter("example")) {
303                         String rel = null;
304                         final List<String> args = Lists.newArrayList(Collections.<String>nCopies(
305                                 10, null));
306                         while (tryEnter(null)) {
307                             final String num = attribute("n");
308                             if (num == null) {
309                                 rel = content().trim().toLowerCase();
310                             } else if (num != null) {
311                                 final Matcher matcher = rolePattern.matcher(num);
312                                 if (matcher.matches()) {
313                                     args.set(Integer.parseInt(matcher.group(1)), content().trim()
314                                             .toLowerCase());
315                                 }
316                             }
317                             leave();
318                         }
319                         // starts with lemma = constraint to drop multi-words (e.g. spy-chaser)
320                         if (rel != null && rel.startsWith(lemma)) {
321                             set.add(-1);
322                             for (int i = 0; i < args.size(); ++i) {
323                                 if (rel.equals(args.get(i))) {
324                                     set.add(i);
325                                 }
326                             }
327                         }
328                         leave();
329                     }
330 
331                     if (set.isEmpty()) {
332                         System.out.println("WARNING: no predicate roles computed for " + id);
333                     }
334                     final List<Integer> mandatoryArgs = Lists.newArrayList();
335                     final List<Integer> optionalArgs = Lists.newArrayList();
336                     final int sampleCount = set.count(-1);
337                     for (final Integer num : set.elementSet()) {
338                         if (num != -1) {
339                             final int argCount = set.count(num);
340                             if (argCount >= sampleCount) {
341                                 mandatoryArgs.add(num);
342                             } else if (argCount > 0) {
343                                 optionalArgs.add(num);
344                             }
345                         }
346                     }
347                     Collections.sort(mandatoryArgs);
348                     Collections.sort(optionalArgs);
349 
350                     writer.write(id);
351                     writer.write('\t');
352                     if (pbId != null) {
353                         writer.write(pbId);
354                     }
355                     writer.write('\t');
356                     writer.write(lemma);
357                     writer.write('\t');
358                     writer.write(name);
359                     for (int i = 0; i < 10; ++i) {
360                         writer.write('\t');
361                         writer.write(Strings.nullToEmpty(argDescr[i]));
362                     }
363                     for (int i = 0; i < 10; ++i) {
364                         writer.write('\t');
365                         writer.write(Byte.toString(argPBNums[i]));
366                     }
367                     writer.write('\t');
368                     writer.write(Joiner.on(' ').join(mandatoryArgs));
369                     writer.write('\t');
370                     writer.write(Joiner.on(' ').join(optionalArgs));
371                     writer.write('\n');
372                     writer.flush();
373                     leave();
374                 }
375                 leave();
376             }
377             leave();
378         }
379     }
380 
381     public static final class Roleset {
382 
383         private static final Interner<List<Integer>> INTERNER = Interners.newStrongInterner();
384 
385         private final String id;
386 
387         @Nullable
388         private final String pbId;
389 
390         private final String lemma;
391 
392         private final String descr;
393 
394         @Nullable
395         private final List<Integer> argNums;
396 
397         @Nullable
398         private final byte[] argPBNums;
399 
400         private final String[] argDescr;
401 
402         private final List<Integer> predMandatoryArgNums;
403 
404         private final List<Integer> predOptionalArgNums;
405 
406         Roleset(final String id, @Nullable final String pbId, final String lemma,
407                 final String descr, @Nullable final byte[] argPBNums, final String[] argDescr,
408                 final Iterable<Integer> predMandatoryArgNums,
409                 final Iterable<Integer> predOptionalArgNums) {
410 
411             this.id = id;
412             this.pbId = pbId;
413             this.lemma = lemma;
414             this.descr = descr;
415             this.argPBNums = argPBNums;
416             this.argDescr = argDescr;
417 
418             final ImmutableList.Builder<Integer> builder = ImmutableList.builder();
419             for (int i = 0; i < this.argDescr.length; ++i) {
420                 if (this.argDescr[i] != null) {
421                     builder.add(i);
422                 }
423             }
424             this.argNums = INTERNER.intern(builder.build());
425             this.predMandatoryArgNums = INTERNER.intern(Ordering.natural().sortedCopy(
426                     predMandatoryArgNums));
427             this.predOptionalArgNums = INTERNER.intern(Ordering.natural().sortedCopy(
428                     predOptionalArgNums));
429         }
430 
431         public String getId() {
432             return this.id;
433         }
434 
435         @Nullable
436         public String getPBId() {
437             return this.pbId;
438         }
439 
440         public String getLemma() {
441             return this.lemma;
442         }
443 
444         public String getDescr() {
445             return this.descr;
446         }
447 
448         public List<Integer> getArgNums() {
449             return this.argNums;
450         }
451 
452         public String getArgDescr(final int argNum) {
453             return this.argDescr[argNum];
454         }
455 
456         @Nullable
457         public int getArgPBNum(final int argNum) {
458             if (this.argPBNums == null || this.argDescr[argNum] == null) {
459                 return -1;
460             }
461             return this.argPBNums[argNum];
462         }
463 
464         public List<Integer> getPredMandatoryArgNums() {
465             return this.predMandatoryArgNums;
466         }
467 
468         public List<Integer> getPredOptionalArgNums() {
469             return this.predOptionalArgNums;
470         }
471 
472         @Override
473         public String toString() {
474             return this.id;
475         }
476 
477     }
478 
479 }