antirez 45 days ago. 4778 views.
From a theoretical standpoint, the best reply provided by an LLM is obtained by always picking the token associated with the highest probability. This approach makes the LLM output deterministic, which is not a good property for a number of applications. For this reason, in order to balance LLMs creativity while preserving adherence to the context, different sampling algorithms have been proposed in recent years.

Today one of the most used ones, more or less the default, is called top-p: it is a form of nucleus sampling where top-scoring tokens are collected up to a total probability sum of “p”, then random weighted sampling is performed.

In this blog post I’ll examine why I believe nucleus sampling may not be the best approach, and will show a simple and understandable alternative in order to avoid the issues of nucleus sampling. The algorithm is yet a work in progress, but by publishing it now I hope to stimulate some discussion / hacking.

## There is some gold in the logits

Despite the fact that LLM logits are one of the few completely understandable parts of the LLM inner working, I generally see very little interest in studying their features, investigating more advanced sampling methods, detecting and signaling users uncertainty and likely hallucination. Visualizing the probabilities distribution for successive tokens is a simple and practical exercise in order to gain some insights:

In the image we can see the top 32 candidate tokens colored by probability (white = 0, blue = 1), the selected token and the rank of the selected token (highest probability = 0, the previous one = 1, and so forth).

In the above example, the Mistral base model knows the birth and death dates of Umberto Eco, so it confidently signals the most likely token with most of the total probability. Other times the model is more perplexed because either there are multiple ways to express the continuation of the text, or because it is not certain about certain facts. Asking the date of a name which birthday was not learned during training produces a different distribution of token probabilities.

However in general what we want to avoid is to select suboptimal tokens, putting the LLM generation outside the path of minimum perplexity and hallucination. Nucleus sampling fails at doing this because accumulating tokens up to “p”, depending on the distribution may include tokens that are extremely weaker than the first choice. Consider, for instance, the “writer” token in the image above. Umberto Eco was a writer, indeed. The token associated with writer, while it is not scoring with a very high value, it is still a lot more above the second choice. Yet a p of 0.3 may accumulate the second token with a low value like 0.01, and yield it with some single-digit probability, with the risk of putting the generation in the wrong path.

Instead, in the case of the following token, “who”, there are multiple choices with a relatively similar score, that could be used for alternative generations.

## Making use of the avalanche effect

A key observation here is that to select too weak tokens for the sake of variability is not a good deal: even if we only exploit generations moments where there are a few good candidates in order to diversify the output, the avalanche effect will help us: the input context will change, thus the output of the LLM will be perturbed, with the effect of making it more likely to produce some alternative version of the text.

## First Token Cutoff (FTC) algorithm

DISCLAIMER: The algorithm described here hasn’t undergone any scientific scrutiny. I experimented a couple of days with different sampling algorithms having “bound worst token” properties, and this one looks like the best balance between applicability, results and understandability.

The algorithm described here can be informally stated as follows:

- When the LLM is strongly biased towards a given candidate, select it.
- When there are multiple viable candidates, produce alternatives.
- The selection of the worst possible token should be bounded to a given amount.

Past work, like Tail Free Sampling, also noted that a selection should be made across a small set of high-quality tokens emitted by the LLM. However in TFS such set is identified by performing the derivative to select a cluster corresponding to the tokens that don’t see a steep curve after which the token quality decreases strongly.

In the algorithm proposed here, instead, we want the selection to follow a more bounded and understandable cut-off relative to the highest scoring token T0, attributing to T0 a special meaning compared to all the other tokens: the level of certainty the LLM has during the emission of such token (a proxy of perplexity, basically). Thus the algorithm refuses every token that is worse than a given percentage if compared to T0.

The cut off percentage, that can have a value from 0 to 1, is called “co”.
An example and viable “co” could be 0.5.

This is how the algorithm works:

1. Compute softmax() of logits.
2. Sort tokens by probability.
3. Given T0, the probability of the best token, compute the ratio of all the other tokens as:

    r = 1 - (T[i] / T0)

4. Select only tokens for which r <= co
5. Perform weighted random pick among the selected tokens.

Note that in this way, regardless of the fact that tokens may have a smooth monotonically decreasing value, there is a hard limit to the tokens we can include in the set of possibilities. Instead with other methods that try to identify high-score clusters, this is not the case.

## Practical examples

One reason why nucleus top-p sampling does not look to fail catastrophically in the practice, is that often times the first token probability is very high, thus when the perplexity is low, and also all the times casually we don’t collect and then pick low-quality tokens, the generation continues along a sensible path.

Things are more problematic when there are successive tokens with probabilities like:

0.25, 0.14, 0.01

With a p=0.4, we could collect the third low quality token and yield it ~3% probability.

Now consider First Token Cutoff with a co value of 0.5 (token can be up to 50% worse than first one):

The second token r value is:

r[t1] = 1-(0.14/0.25) = 0.44 # 0.44 <= 0.5, this token is accepted
r[t2] = 1-(0.01/0.25) = 0.96 # 0.96 > 0.5, this token is refused

## Example output

Outputs of Mistral base model (no instruct) with co=0.7 for the prompt “Sorted sets are”.
The outputs are three successful outputs not cherry picked for quality.

1. Sorted sets are a powerful data structure in Redis. They can be used to store sorted lists, to store unique values, to store scores for ranking, and to store a sorted list of sorted sets.

2. Sorted sets are a very powerful data structure. They allow you to store data in a way that makes it easy to find the highest or lowest values in the set, and they also allow you to sort the data. This can be useful for many different tasks, such as ranking users by their score, or finding the most popular items in a database.

3. Sorted sets are a very powerful data structure that can be used to solve many different problems. The most common use case is to store a list of unique elements, each of which has an associated value. For example, you could use a sorted set to store the names of all the people in your family, with their ages as the associated values.

## Why understandability matters

Sampling parameters are among the few things that the end user of LLMs, or the API user, must tune. Trial and error is often needed, however to have a single tunable parameter for which there is an immediate real-world description and intuition helps a lot. Moreover "co" is a linear parameter, so it's particularly simple to reason about VS parameters like temperature, or even "p" of top_p that while linear strongly depends on the distribution shape of the logits.

## Future work

I’m at the start of my investigations, so I’ll study and evaluate better this algorithm. More than anything else, I would love to see more interest in sampling algorithms and more interest in moving forward from top-p alike approaches.

There are perhaps interesting information to collect from the logits distribution. For example it is likely that a linear probe could be able to learn when the hidden layers of an LLM are dealing with some factual information. This, with the token perplexity, could be used in order to show the user of LLMs that some part of the output is likely wrong. In general visualizing tokens probabilities distribution is very informative and in some way allows to touch with bare hands how LLMs work and what are the candidates at each step.

## Reference implementation

    logits = mx.softmax(logits)
    np_logits = np.array(logits) # MX -> NumPy
    np_logits = np_logits.flatten()
    sorted_indices = np.argsort(np_logits)
    sorted_indices = sorted_indices[::-1]

    co = 0.7
    j = 1
    t0 = np_logits[sorted_indices[0]]
    while 1 - (np_logits[sorted_indices[j]] / t0) < co and j < len(np_logits):
        j += 1
    accepted_logits = []
    for i in range(0,j):
    accepted_logits = mx.array(accepted_logits)

    idx = mx.random.categorical(accepted_logits)
    idx = int(np.array(idx)) # Convert zero-dim array to scalar
    token_id = sorted_indices[idx]

## Credits

These experiments were really simple to perform thanks to the MLX library from Apple, and the cool and smart developers working incessantily at it. MLX is extremely accessible, like it should be: after all LLMs are imprescrutable, but the inference itself is a simple process.
🚀 Dear reader, the first six chapters of my AI sci-fi novel, WOHPE, are now available as a free eBook. Click here to get it.
blog comments powered by Disqus