1   package eu.fbk.dkm.pikes.resources.ecb;
2   
3   import com.google.common.collect.HashMultimap;
4   import com.google.common.io.Files;
5   import eu.fbk.utils.core.CommandLine;
6   import eu.fbk.utils.eval.PrecisionRecall;
7   import ixa.kaflib.Coref;
8   import ixa.kaflib.KAFDocument;
9   import ixa.kaflib.Span;
10  import ixa.kaflib.Term;
11  import org.apache.commons.csv.CSVFormat;
12  import org.apache.commons.csv.CSVRecord;
13  import org.slf4j.Logger;
14  import org.slf4j.LoggerFactory;
15  
16  import java.io.File;
17  import java.io.FileReader;
18  import java.io.Reader;
19  import java.util.HashMap;
20  import java.util.HashSet;
21  import java.util.Map;
22  import java.util.Set;
23  import java.util.regex.Matcher;
24  import java.util.regex.Pattern;
25  
26  /**
27   * Created by marcorospocher on 12/03/16.
28   */
29  public class ECBevaluator2 {
30  
31      private static final Logger LOGGER = LoggerFactory.getLogger(ECBevaluator2.class);
32      private static final Pattern tokenPattern = Pattern.compile("/([0-9]+)/([0-9])\\.ecb#char=([0-9]+)");
33  //    private static final Boolean removeAloneClusters = false;
34  //    private static final Pattern chainPattern = Pattern.compile("CHAIN=\"([0-9]+)\"");
35  
36      public static void main(String[] args) {
37          try {
38  
39              final CommandLine cmd = CommandLine
40                      .parser()
41                      .withName("./ecb-evaluator")
42                      .withHeader("Evaluator event extractor")
43                      .withOption("n", "input-naf", "Input NAF folder", "FOLDER",
44                              CommandLine.Type.DIRECTORY_EXISTING, true, false, true)
45                      .withOption("i", "input-csv", "Input CSV file", "FILE",
46                              CommandLine.Type.FILE_EXISTING, true, false, true)
47                      .withOption("l", "input-lemmas", "Lemmas CSV file", "FILE",
48                              CommandLine.Type.FILE_EXISTING, true, false, false)
49                      .withOption("r", "remove-alone", "Remove alone clusters")
50                      .withLogger(LoggerFactory.getLogger("eu.fbk")).parse(args);
51  
52              File inputCsv = cmd.getOptionValue("input-csv", File.class);
53              File inputNaf = cmd.getOptionValue("input-naf", File.class);
54              File inputLemmas = cmd.getOptionValue("input-lemmas", File.class);
55  
56              Boolean removeAloneClusters = cmd.hasOption("remove-alone");
57  
58              Reader in;
59              Iterable<CSVRecord> records;
60  
61              HashSet<String> lemmas = null;
62              if (inputLemmas != null) {
63                  lemmas = new HashSet<>();
64                  in = new FileReader(inputLemmas);
65                  records = CSVFormat.EXCEL.withHeader().parse(in);
66                  for (CSVRecord record : records) {
67                      String lemma = record.get(1);
68                      lemma = lemma.replaceAll("\"", "").trim();
69                      if (lemma.length() > 0) {
70                          lemmas.add(lemma);
71                      }
72                  }
73              }
74  
75              HashMultimap<String, String> goldTmpClusters = HashMultimap.create();
76              Set<String> okEvents = new HashSet<>();
77  
78              for (final File file : Files.fileTreeTraverser().preOrderTraversal(inputNaf)) {
79                  if (!file.isFile()) {
80                      continue;
81                  }
82                  if (file.getName().startsWith(".")) {
83                      continue;
84                  }
85  
86                  String path = file.getParentFile().toString();
87                  Integer folder = Integer.parseInt(path.substring(path.lastIndexOf("/")).substring(1));
88                  Integer fileNum = Integer.parseInt(file.getName().substring(0, file.getName().length() - 4));
89  
90                  LOGGER.debug(file.getAbsolutePath());
91                  KAFDocument document = KAFDocument.createFromFile(file);
92                  for (Coref coref : document.getCorefs()) {
93                      if (coref.getType() == null) {
94                          continue;
95                      }
96                      if (!coref.getType().equals("event-gold")) {
97                          continue;
98                      }
99  
100                     Integer cluster = Integer.parseInt(coref.getCluster());
101                     String idCluster = folder + "_" + cluster;
102 
103                     for (Span<Term> termSpan : coref.getSpans()) {
104                         Term term = termSpan.getTargets().get(0);
105                         String lemma = term.getLemma();
106                         if (lemmas == null || lemmas.contains(lemma)) {
107                             String text = folder + "_" + fileNum + "_" + term.getOffset();
108                             goldTmpClusters.put(idCluster, text);
109                             okEvents.add(text);
110                         }
111                     }
112                 }
113             }
114 
115             Set<Set> goldClusters = new HashSet<>();
116             for (String key : goldTmpClusters.keySet()) {
117                 Set<String> cluster = goldTmpClusters.get(key);
118                 if (cluster.size() > 1 || !removeAloneClusters) {
119                     goldClusters.add(cluster);
120                 }
121             }
122 
123             LOGGER.info("Gold clusters: {}", goldClusters.size());
124 
125             in = new FileReader(inputCsv);
126             records = CSVFormat.EXCEL.withHeader().parse(in);
127 
128             // Size must be always 4!
129             int clusterID = 0;
130             HashMap<String, Integer> clusterIndexes = new HashMap<>();
131             HashMultimap<Integer, String> tmpClusters = HashMultimap.create();
132             for (CSVRecord record : records) {
133                 Matcher matcher;
134 
135                 String id1 = null;
136                 String id2 = null;
137                 matcher = tokenPattern.matcher(record.get(1));
138                 if (matcher.find()) {
139                     id1 = matcher.group(1) + "_" + matcher.group(2) + "_" + matcher.group(3);
140                 }
141                 matcher = tokenPattern.matcher(record.get(3));
142                 if (matcher.find()) {
143                     id2 = matcher.group(1) + "_" + matcher.group(2) + "_" + matcher.group(3);
144                 }
145 
146                 Integer index1 = clusterIndexes.get(id1);
147                 Integer index2 = clusterIndexes.get(id2);
148 
149                 if (index1 == null && index2 == null) {
150                     clusterID++;
151                     if (okEvents.contains(id2)) {
152                         tmpClusters.put(clusterID, id2);
153                         clusterIndexes.put(id2, clusterID);
154                     }
155                     if (okEvents.contains(id1)) {
156                         tmpClusters.put(clusterID, id1);
157                         clusterIndexes.put(id1, clusterID);
158                     }
159                 }
160                 if (index1 == null && index2 != null) {
161                     if (okEvents.contains(id1)) {
162                         tmpClusters.put(index2, id1);
163                         clusterIndexes.put(id1, index2);
164                     }
165                 }
166                 if (index2 == null && index1 != null) {
167                     if (okEvents.contains(id2)) {
168                         tmpClusters.put(index1, id2);
169                         clusterIndexes.put(id2, index1);
170                     }
171                 }
172                 if (index2 != null && index1 != null) {
173                     if (!index1.equals(index2)) {
174                         clusterIndexes.put(id2, index1);
175                         tmpClusters.putAll(index1, tmpClusters.get(index2));
176                         tmpClusters.removeAll(index2);
177                     }
178                 }
179             }
180 
181             System.out.println(tmpClusters);
182 
183             Set<Set> clusters = new HashSet<>();
184             for (Integer key : tmpClusters.keySet()) {
185                 Set<String> cluster = tmpClusters.get(key);
186                 if (cluster.size() > 1 || !removeAloneClusters) {
187                     clusters.add(cluster);
188                 }
189             }
190             LOGGER.info("Classification clusters: {}", clusters.size());
191 
192             System.out.println(goldClusters);
193             System.out.println(clusters);
194 
195             Map<PrecisionRecall.Measure, Double> precisionRecall = ClusteringEvaluation
196                     .pairWise(clusters, goldClusters);
197 
198             System.out.println(precisionRecall);
199 
200 //            Map<PrecisionRecall.Measure, Double> measureDoubleMap = ClusteringEvaluation
201 //                    .pairWise(goldClusters, clusters);
202 //            System.out.println(measureDoubleMap);
203 
204         } catch (Exception e) {
205             CommandLine.fail(e);
206         }
207     }
208 
209 }