1 package eu.fbk.dkm.pikes.eval;
2
3 import java.util.Collection;
4 import java.util.HashSet;
5 import java.util.List;
6 import java.util.Map;
7 import java.util.Set;
8
9 import com.google.common.base.Preconditions;
10 import com.google.common.collect.HashMultimap;
11 import com.google.common.collect.ImmutableSet;
12 import com.google.common.collect.Lists;
13 import com.google.common.collect.Maps;
14 import com.google.common.collect.Multimap;
15 import com.google.common.collect.Ordering;
16 import com.google.common.collect.Sets;
17
18 import org.eclipse.rdf4j.model.Resource;
19 import org.eclipse.rdf4j.model.Statement;
20 import org.eclipse.rdf4j.model.IRI;
21 import org.eclipse.rdf4j.model.Value;
22 import org.eclipse.rdf4j.model.vocabulary.DCTERMS;
23 import org.eclipse.rdf4j.model.vocabulary.RDF;
24 import org.eclipse.rdf4j.rio.RDFHandler;
25 import org.slf4j.Logger;
26 import org.slf4j.LoggerFactory;
27
28 import eu.fbk.utils.core.CommandLine;
29 import eu.fbk.utils.core.CommandLine.Type;
30 import eu.fbk.utils.eval.PrecisionRecall;
31 import eu.fbk.rdfpro.RDFHandlers;
32 import eu.fbk.rdfpro.RDFSources;
33 import eu.fbk.rdfpro.util.QuadModel;
34 import eu.fbk.rdfpro.util.Statements;
35
36 public class Aligner {
37
38 private static final Logger LOGGER = LoggerFactory.getLogger(Aligner.class);
39
40 public static List<Statement> align(final Collection<Statement> stmts) {
41
42 final QuadModel model = stmts instanceof QuadModel ? (QuadModel) stmts : QuadModel
43 .create(stmts);
44
45 final List<Statement> mappingStmts = Lists.newArrayList();
46
47 for (final Resource sentenceID : model
48 .filter(null, RDF.TYPE, EVAL.SENTENCE, EVAL.METADATA).subjects()) {
49 final Map<String, IRI> graphs = Maps.newHashMap();
50 for (final Resource graphID : model.filter(null, DCTERMS.SOURCE, sentenceID,
51 EVAL.METADATA).subjects()) {
52 if (model.contains(graphID, RDF.TYPE, EVAL.KNOWLEDGE_GRAPH, EVAL.METADATA)) {
53 final String creator = model
54 .filter(graphID, DCTERMS.CREATOR, null, EVAL.METADATA).objectLiteral()
55 .stringValue();
56 graphs.put(creator, (IRI) graphID);
57 }
58 }
59 final IRI goldGraphID = graphs.get("gold");
60 Preconditions.checkNotNull(goldGraphID);
61 final Collection<Statement> goldGraph = model.filter(null, null, null, goldGraphID);
62 LOGGER.info("Processing sentence {}, {} gold statements, {} test graphs", sentenceID,
63 goldGraph.size(), graphs.size() - 1);
64 for (final String creator : graphs.keySet()) {
65 if (!creator.equals("gold")) {
66 final IRI testGraphID = graphs.get(creator);
67 final Collection<Statement> testGraph = model.filter(null, null, null,
68 testGraphID);
69 final Map<IRI, IRI> sentenceMapping = align(goldGraph, testGraph);
70 for (final Map.Entry<IRI, IRI> entry : sentenceMapping.entrySet()) {
71 mappingStmts.add(Statements.VALUE_FACTORY.createStatement(entry.getKey(),
72 EVAL.MAPPED_TO, entry.getValue(), testGraphID));
73 }
74 }
75 }
76 }
77
78 return mappingStmts;
79 }
80
81 public static Map<IRI, IRI> align(final Iterable<Statement> goldStmts,
82 final Iterable<Statement> testStmts) {
83
84 int goldNodesCount = 0;
85 final Multimap<IRI, IRI> goldMap = HashMultimap.create();
86 for (final Statement stmt : goldStmts) {
87 if (stmt.getPredicate().equals(EVAL.DENOTED_BY)) {
88 ++goldNodesCount;
89 goldMap.put((IRI) stmt.getObject(), (IRI) stmt.getSubject());
90 }
91 }
92
93 int testNodesCount = 0;
94 final Multimap<IRI, IRI> testMap = HashMultimap.create();
95 for (final Statement stmt : testStmts) {
96 if (stmt.getPredicate().equals(EVAL.DENOTED_BY)) {
97 ++testNodesCount;
98 if (goldMap.containsKey(stmt.getObject())) {
99 testMap.put((IRI) stmt.getObject(), (IRI) stmt.getSubject());
100 }
101 }
102 }
103
104 final Map<IRI, IRI> baseMapping = Maps.newHashMap();
105 final List<IRI> alternativesTestNodes = Lists.newArrayList();
106 final List<IRI[]> alternativesGoldNodes = Lists.newArrayList();
107 int alternativesCount = 1;
108 for (final IRI term : testMap.keySet()) {
109 final Collection<IRI> testNodes = testMap.get(term);
110 final Collection<IRI> goldNodes = goldMap.get(term);
111 for (final IRI testNode : testNodes) {
112 if (goldNodes.size() == 1) {
113 baseMapping.put(testNode, goldNodes.iterator().next());
114 } else {
115 alternativesTestNodes.add(testNode);
116 alternativesGoldNodes.add(goldNodes.toArray(new IRI[goldNodes.size()]));
117 alternativesCount *= goldNodes.size();
118 }
119 }
120 }
121
122 final Set<Relation> goldRelations = relationsFor(goldStmts);
123 final Set<Relation> testRelations = relationsFor(testStmts);
124
125 Map<IRI, IRI> bestMapping = baseMapping;
126 PrecisionRecall bestPR = null;
127 int bestCount = 0;
128
129 final int[] tps = new int[alternativesCount];
130 if (alternativesCount == 1) {
131 bestPR = evaluate(goldRelations, testRelations, baseMapping);
132
133 } else {
134 for (int i = 0; i < alternativesCount; ++i) {
135 final Map<IRI, IRI> mapping = Maps.newHashMap(baseMapping);
136 int n = i;
137 for (int j = 0; j < alternativesTestNodes.size(); ++j) {
138 final IRI testNode = alternativesTestNodes.get(j);
139 final IRI[] goldNodes = alternativesGoldNodes.get(j);
140 final IRI goldNode = goldNodes[n % goldNodes.length];
141 n = n / goldNodes.length;
142 mapping.put(testNode, goldNode);
143 }
144 final PrecisionRecall pr = evaluate(goldRelations, testRelations, mapping);
145 final int count = ImmutableSet.copyOf(mapping.values()).size();
146 if (bestPR == null || pr.getTP() > bestPR.getTP() || pr.getTP() == bestPR.getTP()
147 && count > bestCount) {
148 bestPR = pr;
149 bestCount = count;
150 bestMapping = mapping;
151 }
152 tps[i] = (int) pr.getTP();
153 }
154 }
155
156 int numOptimalSolutions = 0;
157 for (int i = 0; i < alternativesCount; ++i) {
158 if (tps[i] == (int) bestPR.getTP()) {
159 ++numOptimalSolutions;
160 }
161 }
162 numOptimalSolutions = Math.max(1, numOptimalSolutions);
163
164 if (LOGGER.isInfoEnabled()) {
165 final String creator = ((IRI) testStmts.iterator().next().getContext()).getLocalName();
166 LOGGER.info(
167 "{} - {} gold nodes, {} test nodes, {} mapped nodes, {} alternatives, best PR ({}): {}",
168 creator, goldNodesCount, testNodesCount, bestMapping.size(),
169 alternativesCount, numOptimalSolutions, bestPR);
170 }
171
172 return bestMapping;
173 }
174
175 private static PrecisionRecall evaluate(final Set<Relation> goldRelations,
176 final Set<Relation> testRelations, final Map<IRI, IRI> mapping) {
177
178 final Set<Relation> rewrittenTestRelations = new HashSet<>();
179 for (final Relation relation : testRelations) {
180 final Relation rewrittenRelation = rewrite(relation, mapping);
181 if (!rewrittenRelation.getFirst().equals(rewrittenRelation.getSecond())) {
182 rewrittenTestRelations.add(rewrittenRelation);
183 }
184 }
185
186 final int tp = Sets.intersection(goldRelations, rewrittenTestRelations).size();
187 final int fp = rewrittenTestRelations.size() - tp;
188 final int fn = goldRelations.size() - tp;
189
190 return PrecisionRecall.forCounts(tp, fp, fn);
191 }
192
193 private static Relation rewrite(final Relation relation, final Map<IRI, IRI> mapping) {
194 final IRI first = (IRI) rewrite(relation.getFirst(), mapping);
195 final IRI second = (IRI) rewrite(relation.getSecond(), mapping);
196 return first == relation.getFirst() && second == relation.getSecond() ? relation
197 : new Relation(first, second, relation.isExtra());
198 }
199
200 private static Value rewrite(final Value value, final Map<IRI, IRI> mapping) {
201 if (value instanceof IRI) {
202 final IRI mappedValue = mapping.get(value);
203 if (mappedValue != null) {
204 return mappedValue;
205 }
206 }
207 return value;
208 }
209
210 private static Set<Relation> relationsFor(final Iterable<Statement> stmts) {
211 final Set<IRI> nodes = Sets.newHashSet();
212 for (final Statement stmt : stmts) {
213 if (stmt.getPredicate().equals(EVAL.DENOTED_BY)) {
214 nodes.add((IRI) stmt.getSubject());
215 }
216 }
217 final Set<Relation> relations = Sets.newHashSet();
218 for (final Statement stmt : stmts) {
219 if (!stmt.getPredicate().equals(EVAL.CLASSIFIABLE_AS)
220 && !stmt.getPredicate().equals(EVAL.ASSOCIABLE_TO)
221 && !stmt.getPredicate().equals(EVAL.NOT_ASSOCIABLE_TO)
222 && !stmt.getSubject().equals(stmt.getObject())
223 && nodes.contains(stmt.getSubject())
224 && (nodes.contains(stmt.getObject()) || stmt.getPredicate().equals(RDF.TYPE))) {
225 relations
226 .add(new Relation((IRI) stmt.getSubject(), (IRI) stmt.getObject(), false));
227 }
228 }
229 return relations;
230 }
231
232 public static void main(final String[] args) {
233
234 try {
235
236 final CommandLine cmd = CommandLine
237 .parser()
238 .withName("eval-aligner")
239 .withHeader("Alignes the knowledge graphs produced by different tools "
240 + "againsts a gold graph")
241 .withOption("o", "output", "the output file", "FILE", Type.STRING, true,
242 false, true)
243 .withLogger(LoggerFactory.getLogger("eu.fbk"))
244 .parse(args);
245
246
247 final String outputFile = cmd.getOptionValue("o", String.class);
248 final List<String> inputFiles = cmd.getArgs(String.class);
249
250
251 final Map<String, String> namespaces = Maps.newHashMap();
252 final QuadModel input = QuadModel.create();
253 RDFSources.read(false, false, null, null, null,true,
254 inputFiles.toArray(new String[inputFiles.size()])).emit(
255 RDFHandlers.wrap(input, namespaces), 1);
256
257
258 final List<Statement> mappingStmts = align(input);
259 input.addAll(mappingStmts);
260
261
262 final RDFHandler out = RDFHandlers.write(null, 1000, outputFile);
263 out.startRDF();
264 namespaces.put(DCTERMS.PREFIX, DCTERMS.NAMESPACE);
265 for (final Map.Entry<String, String> entry : namespaces.entrySet()) {
266 if (!entry.getKey().isEmpty()) {
267 out.handleNamespace(entry.getKey(), entry.getValue());
268 }
269 }
270 for (final Statement stmt : Ordering.from(
271 Statements.statementComparator("cspo",
272 Statements.valueComparator(RDF.NAMESPACE))).sortedCopy(input)) {
273 out.handleStatement(stmt);
274 }
275 out.endRDF();
276
277 } catch (final Throwable ex) {
278
279 CommandLine.fail(ex);
280 }
281 }
282
283 }