1   package eu.fbk.dkm.pikes.eval;
2   
3   import java.util.Collection;
4   import java.util.HashSet;
5   import java.util.List;
6   import java.util.Map;
7   import java.util.Set;
8   
9   import com.google.common.base.Preconditions;
10  import com.google.common.collect.HashMultimap;
11  import com.google.common.collect.ImmutableSet;
12  import com.google.common.collect.Lists;
13  import com.google.common.collect.Maps;
14  import com.google.common.collect.Multimap;
15  import com.google.common.collect.Ordering;
16  import com.google.common.collect.Sets;
17  
18  import org.eclipse.rdf4j.model.Resource;
19  import org.eclipse.rdf4j.model.Statement;
20  import org.eclipse.rdf4j.model.IRI;
21  import org.eclipse.rdf4j.model.Value;
22  import org.eclipse.rdf4j.model.vocabulary.DCTERMS;
23  import org.eclipse.rdf4j.model.vocabulary.RDF;
24  import org.eclipse.rdf4j.rio.RDFHandler;
25  import org.slf4j.Logger;
26  import org.slf4j.LoggerFactory;
27  
28  import eu.fbk.utils.core.CommandLine;
29  import eu.fbk.utils.core.CommandLine.Type;
30  import eu.fbk.utils.eval.PrecisionRecall;
31  import eu.fbk.rdfpro.RDFHandlers;
32  import eu.fbk.rdfpro.RDFSources;
33  import eu.fbk.rdfpro.util.QuadModel;
34  import eu.fbk.rdfpro.util.Statements;
35  
36  public class Aligner {
37  
38      private static final Logger LOGGER = LoggerFactory.getLogger(Aligner.class);
39  
40      public static List<Statement> align(final Collection<Statement> stmts) {
41  
42          final QuadModel model = stmts instanceof QuadModel ? (QuadModel) stmts : QuadModel
43                  .create(stmts);
44  
45          final List<Statement> mappingStmts = Lists.newArrayList();
46  
47          for (final Resource sentenceID : model
48                  .filter(null, RDF.TYPE, EVAL.SENTENCE, EVAL.METADATA).subjects()) {
49              final Map<String, IRI> graphs = Maps.newHashMap();
50              for (final Resource graphID : model.filter(null, DCTERMS.SOURCE, sentenceID,
51                      EVAL.METADATA).subjects()) {
52                  if (model.contains(graphID, RDF.TYPE, EVAL.KNOWLEDGE_GRAPH, EVAL.METADATA)) {
53                      final String creator = model
54                              .filter(graphID, DCTERMS.CREATOR, null, EVAL.METADATA).objectLiteral()
55                              .stringValue();
56                      graphs.put(creator, (IRI) graphID);
57                  }
58              }
59              final IRI goldGraphID = graphs.get("gold");
60              Preconditions.checkNotNull(goldGraphID);
61              final Collection<Statement> goldGraph = model.filter(null, null, null, goldGraphID);
62              LOGGER.info("Processing sentence {}, {} gold statements, {} test graphs", sentenceID,
63                      goldGraph.size(), graphs.size() - 1);
64              for (final String creator : graphs.keySet()) {
65                  if (!creator.equals("gold")) {
66                      final IRI testGraphID = graphs.get(creator);
67                      final Collection<Statement> testGraph = model.filter(null, null, null,
68                              testGraphID);
69                      final Map<IRI, IRI> sentenceMapping = align(goldGraph, testGraph);
70                      for (final Map.Entry<IRI, IRI> entry : sentenceMapping.entrySet()) {
71                          mappingStmts.add(Statements.VALUE_FACTORY.createStatement(entry.getKey(),
72                                  EVAL.MAPPED_TO, entry.getValue(), testGraphID));
73                      }
74                  }
75              }
76          }
77  
78          return mappingStmts;
79      }
80  
81      public static Map<IRI, IRI> align(final Iterable<Statement> goldStmts,
82              final Iterable<Statement> testStmts) {
83  
84          int goldNodesCount = 0;
85          final Multimap<IRI, IRI> goldMap = HashMultimap.create();
86          for (final Statement stmt : goldStmts) {
87              if (stmt.getPredicate().equals(EVAL.DENOTED_BY)) {
88                  ++goldNodesCount;
89                  goldMap.put((IRI) stmt.getObject(), (IRI) stmt.getSubject());
90              }
91          }
92  
93          int testNodesCount = 0;
94          final Multimap<IRI, IRI> testMap = HashMultimap.create();
95          for (final Statement stmt : testStmts) {
96              if (stmt.getPredicate().equals(EVAL.DENOTED_BY)) {
97                  ++testNodesCount;
98                  if (goldMap.containsKey(stmt.getObject())) {
99                      testMap.put((IRI) stmt.getObject(), (IRI) stmt.getSubject());
100                 }
101             }
102         }
103 
104         final Map<IRI, IRI> baseMapping = Maps.newHashMap();
105         final List<IRI> alternativesTestNodes = Lists.newArrayList();
106         final List<IRI[]> alternativesGoldNodes = Lists.newArrayList();
107         int alternativesCount = 1;
108         for (final IRI term : testMap.keySet()) {
109             final Collection<IRI> testNodes = testMap.get(term);
110             final Collection<IRI> goldNodes = goldMap.get(term);
111             for (final IRI testNode : testNodes) {
112                 if (goldNodes.size() == 1) {
113                     baseMapping.put(testNode, goldNodes.iterator().next());
114                 } else {
115                     alternativesTestNodes.add(testNode);
116                     alternativesGoldNodes.add(goldNodes.toArray(new IRI[goldNodes.size()]));
117                     alternativesCount *= goldNodes.size();
118                 }
119             }
120         }
121 
122         final Set<Relation> goldRelations = relationsFor(goldStmts);
123         final Set<Relation> testRelations = relationsFor(testStmts);
124 
125         Map<IRI, IRI> bestMapping = baseMapping;
126         PrecisionRecall bestPR = null;
127         int bestCount = 0;
128 
129         final int[] tps = new int[alternativesCount];
130         if (alternativesCount == 1) {
131             bestPR = evaluate(goldRelations, testRelations, baseMapping);
132 
133         } else {
134             for (int i = 0; i < alternativesCount; ++i) {
135                 final Map<IRI, IRI> mapping = Maps.newHashMap(baseMapping);
136                 int n = i;
137                 for (int j = 0; j < alternativesTestNodes.size(); ++j) {
138                     final IRI testNode = alternativesTestNodes.get(j);
139                     final IRI[] goldNodes = alternativesGoldNodes.get(j);
140                     final IRI goldNode = goldNodes[n % goldNodes.length];
141                     n = n / goldNodes.length;
142                     mapping.put(testNode, goldNode);
143                 }
144                 final PrecisionRecall pr = evaluate(goldRelations, testRelations, mapping);
145                 final int count = ImmutableSet.copyOf(mapping.values()).size();
146                 if (bestPR == null || pr.getTP() > bestPR.getTP() || pr.getTP() == bestPR.getTP()
147                         && count > bestCount) {
148                     bestPR = pr;
149                     bestCount = count;
150                     bestMapping = mapping;
151                 }
152                 tps[i] = (int) pr.getTP();
153             }
154         }
155 
156         int numOptimalSolutions = 0;
157         for (int i = 0; i < alternativesCount; ++i) {
158             if (tps[i] == (int) bestPR.getTP()) {
159                 ++numOptimalSolutions;
160             }
161         }
162         numOptimalSolutions = Math.max(1, numOptimalSolutions);
163 
164         if (LOGGER.isInfoEnabled()) {
165             final String creator = ((IRI) testStmts.iterator().next().getContext()).getLocalName();
166             LOGGER.info(
167                     "{} - {} gold nodes, {} test nodes, {} mapped nodes, {} alternatives, best PR ({}): {}",
168                     creator, goldNodesCount, testNodesCount, bestMapping.size(),
169                     alternativesCount, numOptimalSolutions, bestPR);
170         }
171 
172         return bestMapping;
173     }
174 
175     private static PrecisionRecall evaluate(final Set<Relation> goldRelations,
176             final Set<Relation> testRelations, final Map<IRI, IRI> mapping) {
177 
178         final Set<Relation> rewrittenTestRelations = new HashSet<>();
179         for (final Relation relation : testRelations) {
180             final Relation rewrittenRelation = rewrite(relation, mapping);
181             if (!rewrittenRelation.getFirst().equals(rewrittenRelation.getSecond())) {
182                 rewrittenTestRelations.add(rewrittenRelation);
183             }
184         }
185 
186         final int tp = Sets.intersection(goldRelations, rewrittenTestRelations).size();
187         final int fp = rewrittenTestRelations.size() - tp;
188         final int fn = goldRelations.size() - tp;
189 
190         return PrecisionRecall.forCounts(tp, fp, fn);
191     }
192 
193     private static Relation rewrite(final Relation relation, final Map<IRI, IRI> mapping) {
194         final IRI first = (IRI) rewrite(relation.getFirst(), mapping);
195         final IRI second = (IRI) rewrite(relation.getSecond(), mapping);
196         return first == relation.getFirst() && second == relation.getSecond() ? relation
197                 : new Relation(first, second, relation.isExtra());
198     }
199 
200     private static Value rewrite(final Value value, final Map<IRI, IRI> mapping) {
201         if (value instanceof IRI) {
202             final IRI mappedValue = mapping.get(value);
203             if (mappedValue != null) {
204                 return mappedValue;
205             }
206         }
207         return value;
208     }
209 
210     private static Set<Relation> relationsFor(final Iterable<Statement> stmts) {
211         final Set<IRI> nodes = Sets.newHashSet();
212         for (final Statement stmt : stmts) {
213             if (stmt.getPredicate().equals(EVAL.DENOTED_BY)) {
214                 nodes.add((IRI) stmt.getSubject());
215             }
216         }
217         final Set<Relation> relations = Sets.newHashSet();
218         for (final Statement stmt : stmts) {
219             if (!stmt.getPredicate().equals(EVAL.CLASSIFIABLE_AS)
220                     && !stmt.getPredicate().equals(EVAL.ASSOCIABLE_TO)
221                     && !stmt.getPredicate().equals(EVAL.NOT_ASSOCIABLE_TO)
222                     && !stmt.getSubject().equals(stmt.getObject()) //
223                     && nodes.contains(stmt.getSubject()) //
224                     && (nodes.contains(stmt.getObject()) || stmt.getPredicate().equals(RDF.TYPE))) {
225                 relations
226                         .add(new Relation((IRI) stmt.getSubject(), (IRI) stmt.getObject(), false));
227             }
228         }
229         return relations;
230     }
231 
232     public static void main(final String[] args) {
233 
234         try {
235             // Parse command line
236             final CommandLine cmd = CommandLine
237                     .parser()
238                     .withName("eval-aligner")
239                     .withHeader("Alignes the knowledge graphs produced by different tools " //
240                             + "againsts a gold graph")
241                     .withOption("o", "output", "the output file", "FILE", Type.STRING, true,
242                             false, true) //
243                     .withLogger(LoggerFactory.getLogger("eu.fbk")) //
244                     .parse(args);
245 
246             // Extract options
247             final String outputFile = cmd.getOptionValue("o", String.class);
248             final List<String> inputFiles = cmd.getArgs(String.class);
249 
250             // Read the input
251             final Map<String, String> namespaces = Maps.newHashMap();
252             final QuadModel input = QuadModel.create();
253             RDFSources.read(false, false, null, null, null,true,
254                     inputFiles.toArray(new String[inputFiles.size()])).emit(
255                     RDFHandlers.wrap(input, namespaces), 1);
256 
257             // Perform the alignment
258             final List<Statement> mappingStmts = align(input);
259             input.addAll(mappingStmts);
260 
261             // Write the output
262             final RDFHandler out = RDFHandlers.write(null, 1000, outputFile);
263             out.startRDF();
264             namespaces.put(DCTERMS.PREFIX, DCTERMS.NAMESPACE);
265             for (final Map.Entry<String, String> entry : namespaces.entrySet()) {
266                 if (!entry.getKey().isEmpty()) {
267                     out.handleNamespace(entry.getKey(), entry.getValue());
268                 }
269             }
270             for (final Statement stmt : Ordering.from(
271                     Statements.statementComparator("cspo",
272                             Statements.valueComparator(RDF.NAMESPACE))).sortedCopy(input)) {
273                 out.handleStatement(stmt);
274             }
275             out.endRDF();
276 
277         } catch (final Throwable ex) {
278             // Display error information and terminate
279             CommandLine.fail(ex);
280         }
281     }
282 
283 }