1 package eu.fbk.dkm.pikes.resources.conllAIDA;
2
import java.io.BufferedReader;
import java.io.Writer;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.BitSet;
import java.util.regex.Pattern;

import javax.annotation.Nullable;

import com.google.common.base.Strings;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import eu.fbk.rdfpro.util.IO;
import eu.fbk.utils.core.CommandLine;
17
18 public final class KeepOverlappingAnnotations {
19
20 private static final Logger LOGGER = LoggerFactory.getLogger(KeepOverlappingAnnotations.class);
21
22 public static void main(final String... args) {
23 try {
24
25 final CommandLine cmd = CommandLine.parser().withName("keep-overlapping-annotations")
26 .withHeader("Filters an input CONLL/AIDA file, keeping only annotations "
27 + "overlapping with the ones in a supplied gold standard CONLL/AIDA file")
28 .withOption("i", "input", "the CONLL/AIDA FILE to filter", "FILE",
29 CommandLine.Type.FILE_EXISTING, true, false, true)
30 .withOption("g", "gold",
31 "the gold standard CONLL/AIDA FILE to use as reference", "FILE",
32 CommandLine.Type.FILE_EXISTING, true, false, true)
33 .withOption("o", "output", "the output filtered CONLL/AIDA FILE to generate",
34 "FILE", CommandLine.Type.FILE, true, false, true)
35 .withOption("d", "dataset", "the dataset format, either conll03 or aida",
36 "FORMAT", CommandLine.Type.STRING, true, false, true)
37 .withLogger(LoggerFactory.getLogger("eu.fbk")).parse(args);
38
39
40 final Path goldPath = cmd.getOptionValue("g", Path.class);
41 final Path inputPath = cmd.getOptionValue("i", Path.class);
42 final Path outputPath = cmd.getOptionValue("o", Path.class);
43 final boolean isAida = cmd.getOptionValue("d", String.class, "conll03")
44 .equalsIgnoreCase("aida");
45
46
47 int tokenIndex = 0;
48 final BitSet goldMask = new BitSet();
49 try (BufferedReader in = new BufferedReader(
50 IO.utf8Reader(IO.buffer(IO.read(goldPath.toString()))))) {
51 String line;
52 while ((line = in.readLine()) != null) {
53 if (isTokenLine(line)) {
54 if (getAnnotation(line, isAida) != null) {
55 goldMask.set(tokenIndex, true);
56 }
57 ++tokenIndex;
58 }
59 }
60 }
61 LOGGER.info("Read gold standard {}: {} tokens, {} tokens annotated", goldPath,
62 tokenIndex, goldMask.cardinality());
63
64
65 final int[] mentionMask = new int[tokenIndex];
66 int annotatedTokenCounter = 0;
67 int mentionIndex = 0;
68 tokenIndex = 0;
69 String previousAnnotation = null;
70 try (BufferedReader in = new BufferedReader(
71 IO.utf8Reader(IO.buffer(IO.read(inputPath.toString()))))) {
72 String line;
73 while ((line = in.readLine()) != null) {
74 String annotation = null;
75 if (isTokenLine(line)) {
76 annotation = getAnnotation(line, isAida);
77 if (annotation != null) {
78 if (tokenIndex == 0 || mentionMask[tokenIndex - 1] == 0
79 || isBeginToken(annotation, previousAnnotation, isAida)) {
80 ++mentionIndex;
81 }
82 mentionMask[tokenIndex] = mentionIndex;
83 ++annotatedTokenCounter;
84 }
85 ++tokenIndex;
86 }
87 previousAnnotation = annotation;
88 }
89 }
90 final int numInputTokens = tokenIndex;
91 final int numInputMentions = mentionIndex;
92 LOGGER.info("Read input {}: {} tokens, {} tokens annotated, {} mentions", inputPath,
93 numInputTokens, annotatedTokenCounter, numInputMentions);
94
95
96 tokenIndex = 0;
97 int removedMentions = 0;
98 int removedTokens = 0;
99 while (tokenIndex < mentionMask.length) {
100 mentionIndex = mentionMask[tokenIndex];
101 if (mentionIndex != 0) {
102 final int start = tokenIndex;
103 while (mentionMask[tokenIndex + 1] == mentionIndex) {
104 ++tokenIndex;
105 }
106 boolean overlaps = false;
107 for (int i = start; i <= tokenIndex; ++i) {
108 if (goldMask.get(i)) {
109 overlaps = true;
110 break;
111 }
112 }
113 if (!overlaps) {
114 for (int i = start; i <= tokenIndex; ++i) {
115 mentionMask[i] = 0;
116 ++removedTokens;
117 }
118 ++removedMentions;
119 }
120 }
121 ++tokenIndex;
122 }
123 LOGGER.info("Kept {} tokens annotated and {} mentions overlapping with gold standard",
124 annotatedTokenCounter - removedTokens, numInputMentions - removedMentions);
125
126
127 tokenIndex = 0;
128 removedTokens = 0;
129 try (BufferedReader in = new BufferedReader(
130 IO.utf8Reader(IO.buffer(IO.read(inputPath.toString()))))) {
131 try (Writer out = IO.utf8Writer(IO.buffer(IO.write(outputPath.toString())))) {
132 String line;
133 while ((line = in.readLine()) != null) {
134 String modifiedLine = line;
135 if (isTokenLine(line)) {
136 final String annotation = getAnnotation(line, isAida);
137 if (annotation != null && mentionMask[tokenIndex] == 0) {
138 modifiedLine = clearAnnotation(line, isAida);
139 ++removedTokens;
140 }
141 ++tokenIndex;
142 }
143 out.append(modifiedLine).append('\n');
144 }
145 }
146 }
147 LOGGER.info("Emitted {}: {} tokens removed", outputPath, removedTokens);
148
149 } catch (final Throwable ex) {
150
151 CommandLine.fail(ex);
152 }
153 }
154
155 private static boolean isTokenLine(final String line) {
156 return !Strings.isNullOrEmpty(line) && !line.startsWith("-DOCSTART-");
157 }
158
159 private static boolean isBeginToken(final String annotation,
160 @Nullable final String previousAnnotation, final boolean isAida) {
161 return previousAnnotation == null || isAida && annotation.equals("B")
162 || !isAida && (annotation.startsWith("B-")
163 || !annotation.substring(2).equals(previousAnnotation.substring(2)));
164 }
165
166 private static String getAnnotation(final String line, final boolean isAida) {
167 if (isAida) {
168 final int index = line.indexOf('\t');
169 if (index > 0) {
170 final int nextIndex = line.indexOf('\t', index + 1);
171 if (nextIndex > index) {
172 final String annotation = line.substring(index, nextIndex).trim();
173 return annotation.isEmpty() ? null : annotation;
174 }
175 }
176 return null;
177 } else {
178 final String tokens[] = line.split("\\s+");
179 return tokens[3].equals("O") ? null : tokens[3];
180 }
181 }
182
183 private static String clearAnnotation(final String line, final boolean isAida) {
184 if (isAida) {
185 final int index = line.indexOf('\t');
186 return index < 0 ? line : line.substring(0, index);
187 } else {
188 final String tokens[] = line.split("\\s+");
189 tokens[3] = "O";
190 return String.join(" ", tokens);
191 }
192 }
193
194 }