Louis Castricato

On Neural Persistence, Improved Language Models, and Narrative Complexity

Updated: Mar 10


Introduction

Last week I got to read an amazing paper by [Rieck19] on the subject of persistent homology. The idea is that by borrowing tools from topological data analysis, we can construct a per-layer complexity metric. This metric can then shed light on how well our learned representation generalizes.


Namely, if we were to train N independent MLPs on MNIST classification and compare their complexity metrics, those with lower layer-wise metrics would tend to generalize better, since their representations require simpler rules/operations to accomplish the classification task.


I will summarize the work of Rieck and suggest future approaches for better distillation of attention-based language models.


Finally, I conclude with a set of conjectures about how neural persistence might be of use in storytelling.


A brief discussion on persistent homology

Consider a neural network with layers L = {L1, L2, ...} and connection weights W = {W[1,2], W[2,3], ...}, where W[i,i+1] holds the weights between layers Li and Li+1.


A filtration function is defined as follows:


f_k(w_ij) = {i,j} if |w_ij| > k, and None otherwise

F = Filter(W, k) = Map(f_k, Normalize(W))


This constructs a set of edges F such that {i,j} is in F if and only if the normalized |w_ij| > k.
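A minimal pure-Python sketch of Filter, assuming W is a list of rows with W[i][j] the weight between neuron i of one layer and neuron j of the next, and assuming Normalize divides every |w_ij| by the largest absolute weight (the post does not pin the normalization down). The names `normalize` and `filter_edges` are hypothetical:

```python
def normalize(W):
    """Divide every |w_ij| by the largest absolute weight, so all values lie in [0, 1]."""
    top = max(abs(w) for row in W for w in row)
    return [[abs(w) / top for w in row] for row in W]


def filter_edges(W, k):
    """Filter(W, k): the edges {i, j} whose normalized |w_ij| is strictly greater than k."""
    return {
        (i, j)
        for i, row in enumerate(normalize(W))
        for j, w in enumerate(row)
        if w > k
    }


W = [[0.5, -2.0], [1.0, 0.1]]
print(sorted(filter_edges(W, 0.4)))  # [(0, 1), (1, 0)]
```

Because the graph is bipartite, the tuple (i, j) always means "input i, output j", so it is enough to stand in for the unordered edge {i, j}.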


For a fixed threshold k, we define the filtration step on a pair of connection matrices Wi, Wj as


F'_ij(k) = DisjointUnion(Filter(Wi, k), Filter(Wj, k))


Sample Unif[0,1] N times to obtain a sorted vector of thresholds K = Sort({k1, k2, ..., kN}), and construct the set


F'_ij = {F'_ij(k1), ..., F'_ij(kN)}.


F'_ij is called our filtration on Wi, Wj. Each element of this filtration is a simplicial complex (here simply a graph) and stores topological information about our network at Wi, Wj. Furthermore, because we normalize our weight matrices so that no |w_ij| exceeds 1, F'_ij(1) contains no edges, while F'_ij(0) contains every edge with a nonzero weight.


These two cases are not really that interesting to us, since neither carries interesting topological information (our MLP is already multipartite). Hence, these are edge cases ;P
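A small sketch of the whole filtration F'_ij, under the same assumptions as before (W[i][j] layout, normalization by the largest absolute weight). The disjoint union is realized by tagging each edge with the matrix it came from; `pair_step`, `pair_filtration`, and the `seed` parameter are hypothetical names. The last line checks the two edge cases above:

```python
import random


def filter_edges(W, k):
    """Edges (i, j) whose |w_ij|, normalized by the largest absolute weight, exceeds k."""
    top = max(abs(w) for row in W for w in row)
    return {(i, j) for i, row in enumerate(W) for j, w in enumerate(row) if abs(w) / top > k}


def pair_step(Wi, Wj, k):
    """F'_ij(k): disjoint union of Filter(Wi, k) and Filter(Wj, k), each edge tagged by its matrix."""
    left = {("Wi", i, j) for i, j in filter_edges(Wi, k)}
    right = {("Wj", i, j) for i, j in filter_edges(Wj, k)}
    return left | right


def pair_filtration(Wi, Wj, n_samples, seed=0):
    """Sample K = Sort({k1, ..., kN}) from Unif[0, 1]; return K and [F'_ij(k1), ..., F'_ij(kN)]."""
    rng = random.Random(seed)
    K = sorted(rng.random() for _ in range(n_samples))
    return K, [pair_step(Wi, Wj, k) for k in K]


Wi = [[0.5, -2.0], [1.0, 0.1]]
Wj = [[3.0], [-1.5]]
K, F = pair_filtration(Wi, Wj, n_samples=5)
# Larger thresholds keep fewer edges, so the complexes are nested.
assert all(later <= earlier for earlier, later in zip(F, F[1:]))
print(len(pair_step(Wi, Wj, 0.0)), len(pair_step(Wi, Wj, 1.0)))  # 6 0
```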



Figure 1: Diagram of filtration on a two layer MLP. Source: [Rieck19]


Let Beta0(k) refer to the number of connected components of F'_ij(k), and let Beta1(k) refer to the number of independent loops in F'_ij(k). There are many libraries for computing these Betti numbers; I recommend either Aleph from ETH or Giotto-TDA from giotto.ai.
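For a graph (a complex with no 2-simplices), both numbers can be computed by hand with a union-find pass; this sketch does not use either library above. It assumes vertices are numbered 0..n-1 and that edges are distinct pairs without self-loops; `betti_numbers` is a hypothetical helper:

```python
def betti_numbers(num_vertices, edges):
    """Return (Beta0, Beta1) for a graph, i.e. a simplicial complex with no 2-simplices.

    Vertices are 0..num_vertices-1; edges are distinct pairs (u, v) with u != v.
    """
    parent = list(range(num_vertices))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving keeps the trees shallow
            x = parent[x]
        return x

    beta0 = num_vertices
    for u, v in edges:
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv
            beta0 -= 1  # this edge merged two connected components
    # Every edge that did not merge two components closes an independent loop.
    beta1 = len(edges) - (num_vertices - beta0)
    return beta0, beta1


# Two input neurons (0, 1) fully connected to two output neurons (2, 3).
print(betti_numbers(4, [(0, 2), (0, 3), (1, 2), (1, 3)]))  # (1, 1)
print(betti_numbers(4, [(0, 2), (1, 3)]))  # (2, 0)
```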


For every connected component and loop that exists for some value of k, let (b,d) denote its birth-death pair, namely the value of k at which it appears and the value of k at which it vanishes. A scatter plot of these (b,d) pairs is called a persistence diagram.
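A sketch of the diagram for a single layer, assuming the descending filtration described above: every vertex is present at k = 1, and an edge with normalized weight w enters once the threshold drops below w. In a graph nothing ever fills a loop, so loops never die; the sketch therefore records only connected components (zero-dimensional persistence), which as far as I can tell is also what [Rieck19] use. Inputs are numbered 0..m-1 and outputs m..m+n-1 (an assumed layout); `persistence_diagram_0d` is a hypothetical name:

```python
def persistence_diagram_0d(W):
    """(birth, death) pairs of connected components in the descending filtration of one layer.

    W[i][j] is the weight from input neuron i to output neuron j. Every vertex is born
    at k = 1; when an edge of normalized weight w merges two components, one dies at w.
    """
    m, n = len(W), len(W[0])
    top = max(abs(w) for row in W for w in row)
    # Edges in order of decreasing normalized weight; output vertices are offset by m.
    edges = sorted(
        ((abs(W[i][j]) / top, i, m + j) for i in range(m) for j in range(n)),
        reverse=True,
    )
    parent = list(range(m + n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    diagram = []
    for w, u, v in edges:
        ru, rv = find(u), find(v)
        if ru != rv:  # merging edge: one component dies at threshold w
            parent[ru] = rv
            diagram.append((1.0, w))
    return diagram  # the single component that never dies is omitted


print(persistence_diagram_0d([[0.5, -2.0], [1.0, 0.1]]))
# [(1.0, 1.0), (1.0, 0.5), (1.0, 0.25)]
```

The edge with the largest weight always yields a pair (1, 1) with zero persistence, because it enters the filtration immediately.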


Neural Persistence

From here, as outlined in Rieck, given that the p-norm has been shown to be a stable summary of topological features, we can apply the p-norm to the persistence values pers(b,d) = |b - d| of our diagram via the following formula.


Figure 2: Neural Persistence formula. Also taken from the same source.
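Since the figure itself is not reproduced here, this is my restatement of the formula in the post's (b, d) notation, where \mathcal{D}_k (my notation) is the persistence diagram of the layer G_k:

```latex
\operatorname{pers}(b,d) = |b - d|, \qquad
\mathrm{NP}(G_k) = \lVert \mathcal{D}_k \rVert_p
  = \Bigg( \sum_{(b,d) \in \mathcal{D}_k} \operatorname{pers}(b,d)^p \Bigg)^{1/p}
```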


Here, G_k denotes the layer under consideration; in our case, if we apply the filtration to (W[1,2], W[2,3]), G_k is L2. More details on this formula are given in Rieck; we won't go over specifics and proofs.


Every layer has a maximum possible neural persistence, as proven in the original paper. We can therefore normalize NP(Li) by this maximum, which lets us compare the complexity of layers of different sizes. From this point on, let NP(Li) refer to the normalized metric on Li.
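A short sketch of where the maximum comes from, as I understand it, assuming weights normalized to [0, 1] and a zero-dimensional diagram with the one never-dying component dropped. A layer G_k with vertex set V_k has at most |V_k| - 1 finite pairs (each is created by an edge merging two components), and every pair has persistence at most 1, which should agree with the bound proven in [Rieck19]:

```latex
\mathrm{NP}(G_k)
  = \Bigg( \sum_{(b,d) \in \mathcal{D}_k} \operatorname{pers}(b,d)^p \Bigg)^{1/p}
  \le \big( (|V_k| - 1) \cdot 1^p \big)^{1/p}
  = (|V_k| - 1)^{1/p},
\qquad
\widetilde{\mathrm{NP}}(G_k) = \frac{\mathrm{NP}(G_k)}{(|V_k| - 1)^{1/p}} \in [0, 1]
```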


Finally, let


NP(L) = {NP(L2), ... , NP(LN)}


refer to the neural persistence vector over our entire neural network.
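A sketch of the full vector, assuming NP(Li) is computed from the single bipartite layer W[i-1,i] (so N layers give the N-1 entries NP(L2), ..., NP(LN)), with weights normalized by their largest absolute value and p = 2 as a default choice of mine. `layer_np` and `np_vector` are hypothetical names:

```python
def layer_np(W, p=2.0):
    """Normalized neural persistence of one bipartite layer W (rows = inputs, cols = outputs)."""
    m, n = len(W), len(W[0])
    top = max(abs(w) for row in W for w in row)
    edges = sorted(
        ((abs(W[i][j]) / top, i, m + j) for i in range(m) for j in range(n)),
        reverse=True,
    )
    parent = list(range(m + n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    total = 0.0
    for w, u, v in edges:
        ru, rv = find(u), find(v)
        if ru != rv:  # a component born at 1 dies at w
            parent[ru] = rv
            total += (1.0 - w) ** p  # pers(b, d) = |1 - w|
    # Divide by the maximum (|V| - 1)^(1/p), so the result lies in [0, 1].
    return (total / (m + n - 1)) ** (1.0 / p)


def np_vector(weights, p=2.0):
    """NP(L) = {NP(L2), ..., NP(LN)}, taking NP(Li) from the matrix W[i-1,i]."""
    return [layer_np(W, p) for W in weights]


W12 = [[0.5, -2.0], [1.0, 0.1]]
W23 = [[3.0], [-1.5]]
print(np_vector([W12, W23]))
```

A layer whose weights are all equal gets NP = 0, since every component dies the moment it is born.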


Structural Dropout


At this point I strongly recommend you read the paper, or at least the latter half, from section 4 to the conclusion. Here is the link.


A major theme of the paper is that dropout noticeably decreases neural persistence. The obvious next step is to exploit this relationship to improve dropout itself. For instance, when training a ResNet, it might be optimal to drop the layers with the highest layer-wise neural persistence. Hence, layer drop might be significantly improved by applying the Gumbel trick to NP(L) to determine which layers should be dropped. More work is needed in this direction.
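The post only says "the Gumbel trick"; one concrete reading is Gumbel-top-k sampling: perturb the logits NP(Li) / temperature with i.i.d. Gumbel(0, 1) noise and keep the k largest, which draws k distinct layers without replacement from softmax(NP(L) / temperature). This favors high-complexity layers while staying stochastic. The function name and the `temperature` parameter are hypothetical:

```python
import math
import random


def sample_layers_to_drop(np_scores, num_drop, temperature=0.1, rng=None):
    """Gumbel-top-k: sample `num_drop` distinct layer indices, favoring high NP."""
    rng = rng or random.Random()
    keys = []
    for i, score in enumerate(np_scores):
        u = rng.random()
        while u == 0.0:  # avoid log(0); random() returns values in [0, 1)
            u = rng.random()
        gumbel = -math.log(-math.log(u))  # a standard Gumbel(0, 1) sample
        keys.append((score / temperature + gumbel, i))
    # The layers with the largest perturbed logits are the ones to drop this step.
    return sorted(i for _, i in sorted(keys, reverse=True)[:num_drop])


print(sample_layers_to_drop([0.1, 0.9, 0.2, 0.8], num_drop=2, rng=random.Random(0)))
```

A low temperature makes the choice nearly deterministic (always the highest-NP layers); a high one approaches uniform layer drop.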

Application to NLP

If we consider for a moment that attention weights are also simply connection weights, an idea that follows from recalling that transformers are a form of graph neural network [Unpublished: Joshi20], it makes immediate sense to include them in our filtration.
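A sketch of reading one attention matrix (queries by keys, rows summing to 1 after a softmax) as a bipartite layer and computing its normalized neural persistence with the same zero-dimensional construction as before. The query-key scores are made-up random values, and `softmax_rows` / `attention_np` are hypothetical helpers:

```python
import math
import random


def softmax_rows(scores):
    """Row-wise softmax, turning raw query-key scores into attention weights."""
    rows = []
    for row in scores:
        top = max(row)
        exps = [math.exp(s - top) for s in row]
        total = sum(exps)
        rows.append([e / total for e in exps])
    return rows


def attention_np(A, p=2.0):
    """Normalized NP of attention matrix A (queries x keys), read as a bipartite layer."""
    m, n = len(A), len(A[0])
    top = max(abs(a) for row in A for a in row)
    edges = sorted(
        ((abs(A[q][t]) / top, q, m + t) for q in range(m) for t in range(n)),
        reverse=True,
    )
    parent = list(range(m + n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    total = 0.0
    for w, u, v in edges:
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv
            total += (1.0 - w) ** p
    return (total / (m + n - 1)) ** (1.0 / p)  # always in [0, 1]


rng = random.Random(0)
for length in (4, 16):
    scores = [[rng.gauss(0.0, 1.0) for _ in range(length)] for _ in range(length)]
    print(length, round(attention_np(softmax_rows(scores)), 3))
```

Uniform attention gives NP = 0 at any sequence length, and the normalization keeps every value in [0, 1], which is what makes scores from different lengths comparable.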


Since neural persistence is normalized per layer, the metric stays comparable regardless of the number of tokens we are attending over. Consider the following seq2seq model.


BiLSTM->(Attention, LSTM)


We can unfold both the encoder and decoder LSTMs over time. Since it is common to apply dropout to the attention weights at every decode step, the obvious application is to freeze the gradient of the attention layer on decode steps of particularly high complexity.
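A framework-free sketch of that freezing rule, assuming we already have one NP value per decode step (for example from that step's attention matrix) and that "particularly high complexity" means at or above a chosen quantile. The quantile rule, the function names, and the list-of-lists gradients are all hypothetical stand-ins; in a real framework the mask would multiply the attention gradients during backpropagation:

```python
def freeze_mask(step_np, quantile=0.9):
    """True for decode steps whose attention NP lies at or above the given quantile."""
    ranked = sorted(step_np)
    cutoff = ranked[min(len(ranked) - 1, int(quantile * len(ranked)))]
    return [value >= cutoff for value in step_np]


def apply_freeze(attn_grads, mask):
    """Zero the attention-gradient rows of frozen decode steps; other rows pass through."""
    return [[0.0] * len(g) if frozen else list(g) for g, frozen in zip(attn_grads, mask)]


step_np = [0.1, 0.5, 0.9, 0.2, 0.3]
mask = freeze_mask(step_np, quantile=0.8)
print(mask)  # [False, False, True, False, False]
print(apply_freeze([[1.0, 2.0]] * 5, mask)[2])  # [0.0, 0.0]
```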


The same idea can be applied to training BERT, namely in deciding which tokens to mask. We can construct a masking curriculum by determining which tokens are the most complex to encode adequately. The attention structural dropout described above still applies. Once again, more work is needed; I will post a new blog once I have completed these experiments.
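One possible shape for such a curriculum, as a sketch: keep BERT's usual 15% masking budget, and let a `progress` value in [0, 1] decide how much of that budget goes to the highest-complexity tokens versus uniformly random ones. The per-token complexity scores are assumed to be given (the post leaves open how a token's neural persistence is defined), and `curriculum_mask` is a hypothetical name:

```python
import random


def curriculum_mask(token_scores, progress, mask_rate=0.15, rng=None):
    """Pick positions to mask: a `progress` share of the budget targets the most complex
    tokens; the remainder is drawn uniformly from the other positions."""
    rng = rng or random.Random()
    n = len(token_scores)
    budget = max(1, round(mask_rate * n))
    n_hard = round(progress * budget)
    by_score = sorted(range(n), key=lambda i: token_scores[i], reverse=True)
    hard = by_score[:n_hard]
    easy = rng.sample(by_score[n_hard:], budget - n_hard)
    return sorted(hard + easy)


scores = [i / 20 for i in range(20)]
print(curriculum_mask(scores, progress=1.0))  # [17, 18, 19]
```

Ramping `progress` from 0 to 1 over training moves from ordinary random masking toward masking only the hardest tokens.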


Narratives

As mentioned in my first blog post here, a major issue in plot hole detection is determining entry points for one's multihop reasoning. My original conjecture, as seen here, relied on sparse recurrent switching dynamics to construct data flow graphs between the latent vectors of different plot points. The complexity of this data flow graph at each plot point then determines the complexity of the narrative at that point.


If instead we ask a language model, say GPT2, to generate a short story, we do not have this latent fabula structure. While such systems can be paired with one (see the recent work of [Ammanabrolu20]), these pairings are still very early in development, and more often than not, providing strong supervision to language models has proven successful in the past. Hence, while a future approach might involve including neural persistence in hybrid symbolic-neural approaches, for now we will confine ourselves to the information GPT2 gives us.


Finally, I present a few conjectures that might prove useful within the realm of story generation. Let S = {S1, ..., SN} refer to a story generated by GPT2. We will unfold GPT2 over time, for every token Si that it generates.


  1. The distribution of token-wise neural persistence should loosely follow a story arc curve specified by the designer. That is to say, the story complexity should roughly follow a story arc. A corresponding loss term can be introduced using KL divergence.

  2. For hybrid symbolic-neural language models, (1) no longer holds. Future work toward a conjecture here might differentiate between fabula and syuzhet story arcs; (1) may only work because the two are effectively combined in a purely neural model. I am sure there is a narratologist out there who knows far more about this than I do.

  3. Rapid spikes in k-window-wise neural persistence should be significantly penalized. That is to say, when windowing over k tokens at a time, there should be no significant spikes. This can be done by estimating the distribution of k-window-wise neural persistence and penalizing outliers by their residual.
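A sketch of loss terms for conjectures (1) and (3), assuming token-wise NP values are already available. For (1), both the NP curve and the designer's arc are turned into distributions over positions and compared with KL(P || Q); the post does not fix the KL direction, so that choice is mine. For (3), windows whose mean NP sits more than `z` standard deviations from the average window are treated as outliers and penalized by their absolute residual, with `z` a hypothetical parameter. All names are illustrative:

```python
import math


def normalize_dist(values, eps=1e-12):
    """Turn non-negative values into a distribution; eps keeps every entry positive."""
    total = sum(values) + eps * len(values)
    return [(v + eps) / total for v in values]


def arc_kl(token_np, target_arc):
    """Conjecture (1): KL(P || Q) between the NP curve P and the designer's arc Q."""
    P, Q = normalize_dist(token_np), normalize_dist(target_arc)
    return sum(p * math.log(p / q) for p, q in zip(P, Q))


def window_spike_penalty(token_np, k, z=2.0):
    """Conjecture (3): sum of |residual| over k-token windows that are outliers."""
    windows = [sum(token_np[i:i + k]) / k for i in range(len(token_np) - k + 1)]
    if not windows:
        return 0.0
    mu = sum(windows) / len(windows)
    sd = (sum((w - mu) ** 2 for w in windows) / len(windows)) ** 0.5
    return sum((abs(w - mu) for w in windows if sd > 0 and abs(w - mu) > z * sd), 0.0)


rising_arc = [1, 2, 3, 4, 5]
print(arc_kl([5, 4, 3, 2, 1], rising_arc) > arc_kl([1, 2, 3, 4, 4], rising_arc))  # True
print(window_spike_penalty([0.1] * 20 + [5.0] + [0.1] * 20, k=1) > 0)  # True
```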


More on this in a few months.


©2019 by Louis Castricato