1   package eu.fbk.dkm.pikes.resources.conllAIDA;
2   
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Writer;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.BitSet;
import java.util.regex.Pattern;

import javax.annotation.Nullable;

import com.google.common.base.Strings;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import eu.fbk.rdfpro.util.IO;
import eu.fbk.utils.core.CommandLine;
17  
18  public final class KeepOverlappingAnnotations {
19  
20      private static final Logger LOGGER = LoggerFactory.getLogger(KeepOverlappingAnnotations.class);
21  
22      public static void main(final String... args) {
23          try {
24              // Parse command line
25              final CommandLine cmd = CommandLine.parser().withName("keep-overlapping-annotations")
26                      .withHeader("Filters an input CONLL/AIDA file, keeping only annotations "
27                              + "overlapping with the ones in a supplied gold standard CONLL/AIDA file")
28                      .withOption("i", "input", "the CONLL/AIDA FILE to filter", "FILE",
29                              CommandLine.Type.FILE_EXISTING, true, false, true)
30                      .withOption("g", "gold",
31                              "the gold standard CONLL/AIDA FILE to use as reference", "FILE",
32                              CommandLine.Type.FILE_EXISTING, true, false, true)
33                      .withOption("o", "output", "the output filtered CONLL/AIDA FILE to generate",
34                              "FILE", CommandLine.Type.FILE, true, false, true)
35                      .withOption("d", "dataset", "the dataset format, either conll03 or aida",
36                              "FORMAT", CommandLine.Type.STRING, true, false, true)
37                      .withLogger(LoggerFactory.getLogger("eu.fbk")).parse(args);
38  
39              // Read options
40              final Path goldPath = cmd.getOptionValue("g", Path.class);
41              final Path inputPath = cmd.getOptionValue("i", Path.class);
42              final Path outputPath = cmd.getOptionValue("o", Path.class);
43              final boolean isAida = cmd.getOptionValue("d", String.class, "conll03")
44                      .equalsIgnoreCase("aida");
45  
46              // Read gold standard (both CONLL03 and AIDA supported)
47              int tokenIndex = 0;
48              final BitSet goldMask = new BitSet();
49              try (BufferedReader in = new BufferedReader(
50                      IO.utf8Reader(IO.buffer(IO.read(goldPath.toString()))))) {
51                  String line;
52                  while ((line = in.readLine()) != null) {
53                      if (isTokenLine(line)) {
54                          if (getAnnotation(line, isAida) != null) {
55                              goldMask.set(tokenIndex, true);
56                          }
57                          ++tokenIndex;
58                      }
59                  }
60              }
61              LOGGER.info("Read gold standard {}: {} tokens, {} tokens annotated", goldPath,
62                      tokenIndex, goldMask.cardinality());
63  
64              // Read input a first time to detect mentions (both CONLL03 and AIDA supported)
65              final int[] mentionMask = new int[tokenIndex];
66              int annotatedTokenCounter = 0;
67              int mentionIndex = 0;
68              tokenIndex = 0;
69              String previousAnnotation = null;
70              try (BufferedReader in = new BufferedReader(
71                      IO.utf8Reader(IO.buffer(IO.read(inputPath.toString()))))) {
72                  String line;
73                  while ((line = in.readLine()) != null) {
74                      String annotation = null;
75                      if (isTokenLine(line)) {
76                          annotation = getAnnotation(line, isAida);
77                          if (annotation != null) {
78                              if (tokenIndex == 0 || mentionMask[tokenIndex - 1] == 0
79                                      || isBeginToken(annotation, previousAnnotation, isAida)) {
80                                  ++mentionIndex; // start of new entity mention found
81                              }
82                              mentionMask[tokenIndex] = mentionIndex;
83                              ++annotatedTokenCounter;
84                          }
85                          ++tokenIndex;
86                      }
87                      previousAnnotation = annotation;
88                  }
89              }
90              final int numInputTokens = tokenIndex;
91              final int numInputMentions = mentionIndex;
92              LOGGER.info("Read input {}: {} tokens, {} tokens annotated, {} mentions", inputPath,
93                      numInputTokens, annotatedTokenCounter, numInputMentions);
94  
95              // Remove all mentions (in the mask) that do not overlap with a gold mention
96              tokenIndex = 0;
97              int removedMentions = 0;
98              int removedTokens = 0;
99              while (tokenIndex < mentionMask.length) {
100                 mentionIndex = mentionMask[tokenIndex];
101                 if (mentionIndex != 0) {
102                     final int start = tokenIndex;
103                     while (mentionMask[tokenIndex + 1] == mentionIndex) {
104                         ++tokenIndex;
105                     }
106                     boolean overlaps = false;
107                     for (int i = start; i <= tokenIndex; ++i) {
108                         if (goldMask.get(i)) {
109                             overlaps = true;
110                             break;
111                         }
112                     }
113                     if (!overlaps) {
114                         for (int i = start; i <= tokenIndex; ++i) {
115                             mentionMask[i] = 0;
116                             ++removedTokens;
117                         }
118                         ++removedMentions;
119                     }
120                 }
121                 ++tokenIndex;
122             }
123             LOGGER.info("Kept {} tokens annotated and {} mentions overlapping with gold standard",
124                     annotatedTokenCounter - removedTokens, numInputMentions - removedMentions);
125 
126             // Emit output
127             tokenIndex = 0;
128             removedTokens = 0;
129             try (BufferedReader in = new BufferedReader(
130                     IO.utf8Reader(IO.buffer(IO.read(inputPath.toString()))))) {
131                 try (Writer out = IO.utf8Writer(IO.buffer(IO.write(outputPath.toString())))) {
132                     String line;
133                     while ((line = in.readLine()) != null) {
134                         String modifiedLine = line;
135                         if (isTokenLine(line)) {
136                             final String annotation = getAnnotation(line, isAida);
137                             if (annotation != null && mentionMask[tokenIndex] == 0) {
138                                 modifiedLine = clearAnnotation(line, isAida);
139                                 ++removedTokens;
140                             }
141                             ++tokenIndex;
142                         }
143                         out.append(modifiedLine).append('\n');
144                     }
145                 }
146             }
147             LOGGER.info("Emitted {}: {} tokens removed", outputPath, removedTokens);
148 
149         } catch (final Throwable ex) {
150             // Handle failure
151             CommandLine.fail(ex);
152         }
153     }
154 
155     private static boolean isTokenLine(final String line) {
156         return !Strings.isNullOrEmpty(line) && !line.startsWith("-DOCSTART-");
157     }
158 
159     private static boolean isBeginToken(final String annotation,
160             @Nullable final String previousAnnotation, final boolean isAida) {
161         return previousAnnotation == null || isAida && annotation.equals("B")
162                 || !isAida && (annotation.startsWith("B-")
163                         || !annotation.substring(2).equals(previousAnnotation.substring(2)));
164     }
165 
166     private static String getAnnotation(final String line, final boolean isAida) {
167         if (isAida) {
168             final int index = line.indexOf('\t');
169             if (index > 0) {
170                 final int nextIndex = line.indexOf('\t', index + 1);
171                 if (nextIndex > index) {
172                     final String annotation = line.substring(index, nextIndex).trim();
173                     return annotation.isEmpty() ? null : annotation;
174                 }
175             }
176             return null;
177         } else {
178             final String tokens[] = line.split("\\s+");
179             return tokens[3].equals("O") ? null : tokens[3];
180         }
181     }
182 
183     private static String clearAnnotation(final String line, final boolean isAida) {
184         if (isAida) {
185             final int index = line.indexOf('\t');
186             return index < 0 ? line : line.substring(0, index);
187         } else {
188             final String tokens[] = line.split("\\s+");
189             tokens[3] = "O";
190             return String.join(" ", tokens);
191         }
192     }
193 
194 }