This paper deals with question answering using information stored in knowledge graphs.
Terminology
A knowledge graph is our knowledge base, represented as a bunch of entities and the relations between them. In computer-science lingo, an entity can be called a vertex and a relation an edge. Here is an example from the paper:
Figure 1 from the paper: A small fragment of a knowledge base represented as a knowledge graph.
Question Answering in this context means taking a graph like the one above and asking, for example, “Who did Malala Yousafzai share her Nobel Peace Prize with?”. The answer can be deduced by putting together Malala Yousafzai → WonAward → Nobel Peace Prize 2014 → AwardedTo → Kailash Satyarthi.
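To make the walk concrete, here is a tiny hand-rolled sketch (mine, not from the paper) of the fragment above as Python adjacency lists, reading off the answer by following the two relations:

```python
# Toy sketch of the knowledge-graph fragment as adjacency lists (entity -> outgoing edges).
graph = {
    "Malala Yousafzai": [("WonAward", "Nobel Peace Prize 2014")],
    "Nobel Peace Prize 2014": [("AwardedTo", "Kailash Satyarthi"),
                               ("AwardedTo", "Malala Yousafzai")],
}

def follow(start, relations):
    """Follow a fixed sequence of relations from `start` and return the reachable entities."""
    frontier = {start}
    for rel in relations:
        frontier = {dst for src in frontier
                    for (r, dst) in graph.get(src, []) if r == rel}
    return frontier

# Who shares the prize? Walk WonAward -> AwardedTo and drop the starting entity.
print(follow("Malala Yousafzai", ["WonAward", "AwardedTo"]) - {"Malala Yousafzai"})
# {'Kailash Satyarthi'}
```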
Approach
Dumb approaches like trying all possible paths are impractical for large knowledge graphs, so the authors frame the reasoning process as a reinforcement learning problem, roughly the way you would expect: our agent starts at an initial entity (like “Malala Yousafzai”) and walks through the graph until it decides that it has reached the desired answer. Each decision is therefore a choice between all the available edges going out of the current vertex (or a decision to stop). To make that decision, the agent is conditioned on the history of all its previous decisions (in some fixed-length representation), on its current position, and on the query relation (like “SharesNobelPrizeWith” - note that “Malala Yousafzai” is already contained in the history, so it need not be here).
To be more precise, instead of a “stop” instruction, the authors add a “no-op” instruction that the agent can use to stay at a node if it thinks it has already arrived at the answer. Also, for every relation, an inverse relation is added to the graph (if it’s not already there) to let the agent potentially recover from a wrong decision. Each simulation then takes a predetermined number of steps, ranging from 2 to 10 depending on the dataset. The agent gets a terminal reward of 1 if it ends up on the correct entity and 0 otherwise.
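As a rough sketch of what one such simulation looks like under this setup (my own reconstruction, not the authors’ code; the no-op name and the uniform-random policy are placeholders, and the added inverse edges are omitted):

```python
import random

NO_OP = "NO_OP"  # placeholder name for the stay-in-place relation

def rollout(graph, start, answer, policy, num_steps):
    node, history = start, []
    for _ in range(num_steps):
        # available actions: every outgoing edge plus staying where we are
        actions = graph.get(node, []) + [(NO_OP, node)]
        relation, next_node = policy(history, node, actions)  # stochastic choice
        history.append((relation, next_node))
        node = next_node
    return 1.0 if node == answer else 0.0  # terminal reward

# e.g. with the toy graph from the Terminology section and a uniform-random policy:
graph = {"Malala Yousafzai": [("WonAward", "Nobel Peace Prize 2014")],
         "Nobel Peace Prize 2014": [("AwardedTo", "Kailash Satyarthi")]}
reward = rollout(graph, "Malala Yousafzai", "Kailash Satyarthi",
                 lambda hist, node, acts: random.choice(acts), num_steps=3)
```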
All the decision-making is done using neural networks, so entities, relations, and history all have fixed-length continuous vector embeddings. For entity and relation embeddings the authors don’t give any details. Judging from the code, the embedding size ranges from 4 (on the COUNTRIES dataset) to 100, uses Glorot initialization, and can either be trained or kept fixed (they use different fine-tuned hyperparameters for each dataset). History is encoded using an LSTM, where at each step we “mix in” the action we took and the state we ended up in (by which I mean the concatenation of their embeddings). The dimensionality of the LSTM is 4 times that of the entity/relation embeddings, and there is only 1 LSTM layer.
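Here is a minimal PyTorch sketch of how I understand the history encoder (the dimensions are illustrative and the details may differ from the authors’ implementation):

```python
import torch
import torch.nn as nn

emb_dim = 100                   # embedding size used on the larger datasets
hist_dim = 4 * emb_dim          # LSTM dimensionality, per the text above
lstm = nn.LSTM(input_size=2 * emb_dim, hidden_size=hist_dim, num_layers=1)

rel_emb = torch.randn(emb_dim)  # embedding of the relation we just followed
ent_emb = torch.randn(emb_dim)  # embedding of the entity we ended up in
step_in = torch.cat([rel_emb, ent_emb]).view(1, 1, -1)  # (seq_len=1, batch=1, features)

state = None                    # (h, c) carried across steps; None means zero init
out, state = lstm(step_in, state)
h_t = out[-1, 0]                # history embedding fed to the policy network
```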
Based on the history embedding and the query relation embedding, the agent decides which outgoing edge to pursue. This is done with a simple 1-hidden-layer feed-forward network with ReLU on the hidden layer. We take its output and dot-product it with the embedding of every available edge. Then we softmax the resulting “preference vector” and sample from the resulting categorical distribution (our policy is therefore stochastic). This gives us our next action.
The authors give the following equations for what we just described (with biases excluded for brevity). Note that \(\mathbf{A_t}\) is a matrix obtained by stacking the embeddings of all the outgoing edges at the current vertex.
\[\mathbf{d_t} = \mathrm{softmax}\left(\mathbf{A_t}\left(\mathbf{W_2}\,\mathrm{ReLU}\left(\mathbf{W_1}\left[\mathbf{h_t}; \mathbf{o_t}; \mathbf{r_q}\right]\right)\right)\right)\]
\[A_t \sim \mathrm{Categorical}\left(\mathbf{d_t}\right)\]
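In code, one step of this decision looks roughly like the following PyTorch sketch (my reconstruction; the hidden-layer size is a guess and biases are omitted as in the equations):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

emb_dim, hist_dim = 100, 400
W1 = nn.Linear(hist_dim + 2 * emb_dim, hist_dim, bias=False)  # hidden layer
W2 = nn.Linear(hist_dim, emb_dim, bias=False)                 # project back to edge space

h_t = torch.randn(hist_dim)      # history embedding from the LSTM
o_t = torch.randn(emb_dim)       # embedding of the current entity
r_q = torch.randn(emb_dim)       # embedding of the query relation
A_t = torch.randn(7, emb_dim)    # one row per available outgoing edge

scores = A_t @ W2(F.relu(W1(torch.cat([h_t, o_t, r_q]))))     # dot-product similarities
d_t = F.softmax(scores, dim=0)                                # the "preference vector"
action = torch.distributions.Categorical(probs=d_t).sample()  # stochastic policy
```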
Why use the dot-product similarity for choosing the best edge? Intuitively, it’s the same as using cosine similarity, but without normalizing both vectors first. This means that the network can encode some information about the prior probability of each edge. Why the output of the decision network is not normalized either is unclear to me, and it does have an effect, since softmax is not invariant under scalar multiplication. Effectively, the network can increase/decrease the temperature of the stochastic sampling by scaling down/up its output.
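A quick numerical check of that last point: scaling the scores acts like an inverse temperature.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

scores = np.array([1.0, 2.0, 3.0])
print(softmax(scores))        # [0.09 0.24 0.67] (roughly)
print(softmax(5.0 * scores))  # [0.00 0.01 0.99] - scaled up, much more peaked
print(softmax(0.2 * scores))  # [0.27 0.33 0.40] - scaled down, much flatter
```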
Training
Training is done using the REINFORCE algorithm with a baseline. For each training example (i.e. training query), results from 20 rollouts are averaged to reduce variance. The baseline (i.e. an additive control variate) is computed as an exponential moving average of previous rewards. The decay parameter is 0.95 for most experiments, but ranges from 0.93 to 1 (with the value 1 effectively turning the baseline off, since it’s initialized to 0). Interestingly, the authors also tried a learned baseline but reported similar performance.
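A sketch of what this looks like, assuming the usual REINFORCE bookkeeping (the names and exact loss handling are mine, not the authors’):

```python
import torch

def reinforce_loss(log_probs_per_rollout, rewards, baseline):
    """log_probs_per_rollout: one list of per-step log pi(a_t) tensors per rollout;
    rewards: the terminal 0/1 reward of each rollout."""
    losses = []
    for log_probs, r in zip(log_probs_per_rollout, rewards):
        advantage = r - baseline                        # additive control variate
        losses.append(-advantage * torch.stack(log_probs).sum())
    return torch.stack(losses).mean()                   # average over the ~20 rollouts

# After every batch the baseline tracks past rewards (it starts at 0):
def update_baseline(baseline, mean_reward, decay=0.95):
    return decay * baseline + (1 - decay) * mean_reward
```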
To encourage exploration during training, the loss function contains an entropy regularization term scaled by a parameter \(\beta\). This is an attempt to prevent the categorical distribution from collapsing to something close to one-hot too early, which would leave some options insufficiently explored. The \(\beta\) parameter decays at a rate of 0.9 every 200 batches, starting from a value between 0.02 and 0.1, depending on the dataset.
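In code, the entropy bonus and its schedule could look like this (again my own sketch, not the authors’ exact bookkeeping):

```python
import torch

def entropy(d_t, eps=1e-12):
    # entropy of the categorical action distribution at one step
    return -(d_t * (d_t + eps).log()).sum()

beta_0, decay_rate = 0.05, 0.9   # beta_0 is 0.02-0.1 depending on the dataset
def beta_at(batch_idx):
    return beta_0 * decay_rate ** (batch_idx // 200)  # decayed every 200 batches

# The per-batch objective is then roughly:
#   loss = reinforce_loss(...) - beta_at(batch_idx) * average_step_entropy
```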
Experiments
The authors conduct experiments on 7 datasets with vastly different numbers of entities, types of
relations, numbers of edges and numbers of available queries, as shown in the following table.
Table 1 from the paper: Statistics of various datasets used in experiments.
Let’s focus on the experiments concerning classic knowledge-base question answering.
When answering a question using the simulation technique described above, the authors use beam search with a beam width of 50. At the end of the search, entities are ranked by the probability of the path the model took to reach them, i.e. how much the model trusts them to be the correct answer. The position of an entity in this sorted list is its rank. The following metrics are then used:
- AUC-PR - Area Under the Precision-Recall Curve. The exact details of how it was computed are not given.
- MRR - Mean Reciprocal Rank: the reciprocal of the rank of the correct answer, averaged over all queries.
- HITS@1,3,10 - How often does the correct answer have rank 1, rank at most 3, or rank at most 10? (A small sketch of MRR and HITS@k follows this list.)
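The paper does not spell these out, but computed from the per-query rank of the correct answer they look like this (standard definitions, toy numbers):

```python
def mrr(ranks):
    return sum(1.0 / r for r in ranks) / len(ranks)

def hits_at(k, ranks):
    return sum(1 for r in ranks if r <= k) / len(ranks)

ranks = [1, 3, 2, 12, 1]   # toy example: rank of the true answer for five queries
print(round(mrr(ranks), 3))                                      # 0.583
print(hits_at(1, ranks), hits_at(3, ranks), hits_at(10, ranks))  # 0.4 0.8 0.8
```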
Results are shown in the paper in tables 2, 3, and 4.
Other interesting things
- The WikiMovies dataset contains questions in natural language (although automatically generated), like “Which is a film written by Herb Freed?”. The authors extract the entity (“Herb Freed”) using simple string matching and encode the query as an average of the embeddings of the question words (see the sketch after this list).
- The authors also test how well MINERVA can navigate a synthetic 16-by-16 grid world and show that it really shines when the reasoning requires many steps (like 10).
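For the WikiMovies query encoding, a tiny sketch of the bag-of-words averaging (the vocabulary and embedding size here are purely illustrative):

```python
import torch
import torch.nn as nn

vocab = {"which": 0, "is": 1, "a": 2, "film": 3, "written": 4, "by": 5}
word_emb = nn.Embedding(len(vocab), 50)   # 50 is an illustrative embedding size

question = ["which", "is", "a", "film", "written", "by"]  # entity already stripped out
ids = torch.tensor([vocab[w] for w in question])
r_q = word_emb(ids).mean(dim=0)           # query embedding fed to the agent
```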