<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://karlosmid.com/feed.xml" rel="self" type="application/atom+xml" /><link href="https://karlosmid.com/" rel="alternate" type="text/html" /><updated>2026-03-29T17:32:51+00:00</updated><id>https://karlosmid.com/feed.xml</id><title type="html">Karlo Šmid Blog</title><subtitle>karlosmid.github.io</subtitle><entry><title type="html">An unexpected display of art</title><link href="https://karlosmid.com/2026/03/an-unexpected-display-of-art/" rel="alternate" type="text/html" title="An unexpected display of art" /><published>2026-03-29T08:00:00+00:00</published><updated>2026-03-29T08:00:00+00:00</updated><id>https://karlosmid.com/2026/03/an-unexpected-display-of-art</id><content type="html" xml:base="https://karlosmid.com/2026/03/an-unexpected-display-of-art/"><![CDATA[<h2 id="tldr">TL;DR</h2>

<p>How we had the privilege of attending a Testival meetup and seeing an unexpected display of art.</p>

<h2 id="the-meetup">The meetup</h2>

<p>Zeljko organized this Testival #79 at short notice. It was our second discussion about the book Taking Testing Seriously. There were three of us, and Bojan was the third attendee. Zeljko was the only one who finished the book. He had 25 pages of Kindle notes. He noticed one interesting thing: all of his highlights were from the first two chapters.</p>

<p>The book also contains interview chapters, and with all respect to the interviewees, Zeljko noticed that this was the weakest part of the book. It seems like James and Michael were missing some pages and then decided to add interviews with people who helped shape the Rapid Software Testing methodology.</p>

<p>Bojan was asking for advice on which MacBook he should buy. For him, the important feature is display size, because he uses a laptop instead of a TV set. The problem is that each extra inch increases the price significantly. My only advice was not to buy a Mac with a Touch Bar, which is no longer in production.</p>

<p>I also learned about a free app <a href="https://1001albumsgenerator.com/">One Album A Day</a> that regularly suggests an essential music album that you should listen to before you die. A very interesting idea presented by Bojan.</p>

<p>The summer concert season in Zagreb is coming on pretty strong. I’ll be attending shows by Sepultura, The Offspring, A Perfect Circle, Anthrax, Voivod, Biohazard, Slaughter to Prevail, and Fear Factory.</p>

<p>And a shout-out to Zeljko, who moved our <a href="https://testival.eu">testival.eu</a> WordPress export to Jekyll on GitHub Pages, without images, which was my mistake. This is a much better option because it can be easily moved to any provider, and hosting on GitHub is free.</p>

<h2 id="the-art">The art</h2>

<p>And now the most interesting part. Bojan had to go home because his wife sent him a message saying that he had a surprise waiting for Father’s Day the next day. Here is the art created by two six-year-olds, in their original font, with some help from their mother.</p>

<p><img src="/assets/fathersday_01.jpeg" alt="Father's Day art 1" /></p>

<p><img src="/assets/fathersday02.jpeg" alt="Father's Day art 2" /></p>

<p><img src="/assets/fathersday03.jpeg" alt="Father's Day art 3" /></p>

<p><img src="/assets/fathersday04.jpeg" alt="Father's Day art 4" /></p>

<h2 id="next-meetup">Next meetup</h2>

<p>The next meetup, <a href="https://www.meetup.com/testival/events/313923239">Testival #80</a>, is scheduled for April 15, 2026, at Super, where they will talk about their QA process and how it is smartly supported by LLMs.</p>]]></content><author><name></name></author><category term="meetup" /><category term="testival" /></entry><entry><title type="html">The Heart of an LLM: Attention Mechanism in Elixir</title><link href="https://karlosmid.com/2026/03/coding-attention-mechanism-in-elixir/" rel="alternate" type="text/html" title="The Heart of an LLM: Attention Mechanism in Elixir" /><published>2026-03-28T09:00:00+00:00</published><updated>2026-03-28T09:00:00+00:00</updated><id>https://karlosmid.com/2026/03/coding-attention-mechanism-in-elixir</id><content type="html" xml:base="https://karlosmid.com/2026/03/coding-attention-mechanism-in-elixir/"><![CDATA[<h2 id="tldr">TL;DR</h2>

<p>This post is based on Chapter 3 of <em>Build a LLM from Scratch</em> by Sebastian Raschka, with one twist: all Python examples are rewritten in Elixir. We are building the LLM's attention mechanism, which is the second part of stage 1. We have already prepared the input text data; the attention mechanism helps the LLM predict the next token. We will implement four attention mechanisms:</p>
<ul>
  <li>simplified self-attention</li>
  <li>self-attention</li>
  <li>causal attention</li>
  <li>multi-head attention</li>
</ul>

<h2 id="the-problem-with-modeling-long-sequences">The problem with modeling long sequences</h2>

<p>The heart of an LLM is its attention mechanism. As an introduction, we should first go back in time to when there was no attention mechanism. Let’s try to translate Croatian to English word by word:</p>

<table>
  <tbody>
    <tr>
      <td>Karlo</td>
      <td>u</td>
      <td>potpunosti</td>
      <td>razumije</td>
      <td>kako</td>
      <td>napraviti</td>
      <td>LLM</td>
      <td>model</td>
      <td>iz</td>
      <td>nule</td>
    </tr>
    <tr>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
    </tr>
    <tr>
      <td>Karlo</td>
      <td>in</td>
      <td>completely</td>
      <td>understands</td>
      <td>how</td>
      <td>make</td>
      <td>LLM</td>
      <td>model</td>
      <td>from</td>
      <td>scratch</td>
    </tr>
  </tbody>
</table>

<p>The words are in the same order in Croatian and English, so that is not the issue. The issue is sentence context, where <code class="language-plaintext highlighter-rouge">u potpunosti</code> translates to one word, <code class="language-plaintext highlighter-rouge">completely</code>:</p>

<table>
  <tbody>
    <tr>
      <td>Karlo</td>
      <td>u potpunosti</td>
      <td>razumije</td>
      <td>kako</td>
      <td>napraviti</td>
      <td>LLM</td>
      <td>model</td>
      <td>iz</td>
      <td>nule</td>
      <td> </td>
    </tr>
    <tr>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
    </tr>
    <tr>
      <td>Karlo</td>
      <td>completely</td>
      <td>understands</td>
      <td>how</td>
      <td>to make</td>
      <td>LLM</td>
      <td>model</td>
      <td>from</td>
      <td>scratch</td>
      <td> </td>
    </tr>
  </tbody>
</table>

<p>Now let’s bring back neural networks with an encoder/decoder architecture built from RNNs (recurrent neural networks). The encoder first processes the input text, and its output becomes the input for the decoder module. The encoder processes text token by token, and at each step it updates its hidden state. The hidden state holds the input context. When it is done, the final hidden state goes to the decoder, which translates token by token using that encoder hidden state, updating its own hidden state on each token. This is the core idea; we do not need full RNN details here. We can think of the hidden state as the encoding vectors from Chapter 2.<br />
Then the encoder/decoder architecture hit a problem that eventually gave birth to LLM transformers. The decoder has access only to the final hidden state, so it does not know anything about earlier parts of the input (or previous sentences). The result is a loss of context during translation, especially in sentences where dependencies span long distances.</p>
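
<p>To make the “updates its hidden state” part concrete, here is a minimal sketch of a recurrent update in Nx. It is my own toy example, not the book’s code: the weight matrices <code class="language-plaintext highlighter-rouge">w_xh</code> and <code class="language-plaintext highlighter-rouge">w_hh</code> and their values are made up purely to illustrate h_t = tanh(x_t * W_xh + h_prev * W_hh).</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># a toy recurrent update: fold over token embeddings, carrying a hidden state
# w_xh and w_hh are made-up weight matrices, only here to make the shapes concrete
w_xh = Nx.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], type: {:f, 32})
w_hh = Nx.tensor([[0.1, 0.0], [0.0, 0.1]], type: {:f, 32})
h0 = Nx.broadcast(Nx.tensor(0.0, type: {:f, 32}), {2})

tokens = [[0.43, 0.15, 0.89], [0.55, 0.87, 0.66], [0.57, 0.85, 0.64]]

final_hidden_state =
  Enum.reduce(tokens, h0, fn x_t, h_prev -&gt;
    x_t = Nx.tensor(x_t, type: {:f, 32})
    # h_t = tanh(x_t . W_xh + h_prev . W_hh)
    Nx.tanh(Nx.add(Nx.dot(x_t, w_xh), Nx.dot(h_prev, w_hh)))
  end)
</code></pre></div></div>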

<h2 id="capturing-data-dependencies-with-attention-mechanisms">Capturing data dependencies with attention mechanisms</h2>

<p>The shortcoming of the encoder/decoder architecture was that we had to store all of the encoded input in a single hidden state. To overcome that, the Bahdanau attention mechanism was developed. The idea was to give the decoder selective access to the encoded input during the decoding step. Selective means that some encoded input tokens are more important for the current output token. This inspired the self-attention mechanism in the transformer architecture. So, the self-attention mechanism helps an LLM determine which part of the input sequence is more relevant for a particular input position.</p>

<h2 id="attending-to-different-parts-of-the-input-with-self-attention">Attending to different parts of the input with self-attention</h2>

<p>This is one of the toughest parts of the book.</p>

<blockquote>
  <p>What does “self” in self-attention mean? It means that a transformer creates attention weights using only the input sequence; the input sequence attends to itself. In encoder-decoder architectures, the output depends on information from the input.</p>
</blockquote>

<p>Let’s break this down into chunks!</p>

<h3 id="a-simple-self-attention-mechanism">A simple self-attention mechanism</h3>

<p>First, we will use no trainable weights; this is our first simplification. Let’s work through this test:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code> <span class="n">test</span> <span class="s2">"simple self-attention mechanism"</span> <span class="k">do</span>

    <span class="c1"># our input has 6 tokens: "Your journey starts with one step". Each token has 3 dimensions.</span>

    <span class="n">inputs</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="c1"># Your (x^1)</span>
          <span class="p">[</span><span class="mf">0.43</span><span class="p">,</span> <span class="mf">0.15</span><span class="p">,</span> <span class="mf">0.89</span><span class="p">],</span>
          <span class="c1"># journey (x^2)</span>
          <span class="p">[</span><span class="mf">0.55</span><span class="p">,</span> <span class="mf">0.87</span><span class="p">,</span> <span class="mf">0.66</span><span class="p">],</span>
          <span class="c1"># starts (x^3)</span>
          <span class="p">[</span><span class="mf">0.57</span><span class="p">,</span> <span class="mf">0.85</span><span class="p">,</span> <span class="mf">0.64</span><span class="p">],</span>
          <span class="c1"># with (x^4)</span>
          <span class="p">[</span><span class="mf">0.22</span><span class="p">,</span> <span class="mf">0.58</span><span class="p">,</span> <span class="mf">0.33</span><span class="p">],</span>
          <span class="c1"># one (x^5)</span>
          <span class="p">[</span><span class="mf">0.77</span><span class="p">,</span> <span class="mf">0.25</span><span class="p">,</span> <span class="mf">0.10</span><span class="p">],</span>
          <span class="c1"># step (x^6)</span>
          <span class="p">[</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.80</span><span class="p">,</span> <span class="mf">0.55</span><span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">6</span><span class="p">,</span> <span class="mi">3</span><span class="p">}</span>

    <span class="c1"># In Nx we read this like this: slice along axis 0 (rows), starting from index 1 with length of 1</span>
    <span class="c1"># which gives us query x2: [0.55, 0.87, 0.66] </span>

    <span class="n">query</span> <span class="o">=</span>
      <span class="n">inputs</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">slice_along_axis</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="ss">axis:</span> <span class="mi">0</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="ss">axes:</span> <span class="p">[</span><span class="mi">0</span><span class="p">])</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">query</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">3</span><span class="p">}</span>

    <span class="c1"># we are dotting {6, 3} with {3}; the condition is that we can only dot over axes of the same size, so 3 with 3</span>
    <span class="c1"># so each row from inputs is multiplied with query: multiply corresponding elements and then sum those products</span>
    <span class="c1"># for first row: score_0 = 0.43*0.55 + 0.15*0.87 + 0.89*0.66</span>

    <span class="n">attn_scores_2</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">query</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">])</span>

    <span class="c1"># result is vector with 6 values</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">attn_scores_2</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">6</span><span class="p">}</span>

    <span class="n">expected_attn_scores_2</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="mf">0.9544000029563904</span><span class="p">,</span>
          <span class="mf">1.4950000047683716</span><span class="p">,</span>
          <span class="mf">1.4754000902175903</span><span class="p">,</span>
          <span class="mf">0.8434000015258789</span><span class="p">,</span>
          <span class="mf">0.7070000171661377</span><span class="p">,</span>
          <span class="mf">1.0865000486373901</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="ss">:f32</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">attn_scores_2</span><span class="p">,</span> <span class="n">expected_attn_scores_2</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span>
             <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"Attention scores should match expected values exactly"</span>

    <span class="c1"># simple normalization: divide each element by the sum of all elements</span>
    
    <span class="n">attn_scores_2_normalized</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">divide</span><span class="p">(</span><span class="n">attn_scores_2</span><span class="p">,</span> <span class="no">Nx</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">attn_scores_2</span><span class="p">,</span> <span class="ss">axes:</span> <span class="p">[</span><span class="mi">0</span><span class="p">]))</span>

    <span class="n">expected_attn_scores_2_normalized</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="mf">0.14545010030269623</span><span class="p">,</span>
          <span class="mf">0.22783729434013367</span><span class="p">,</span>
          <span class="mf">0.22485026717185974</span><span class="p">,</span>
          <span class="mf">0.1285337507724762</span><span class="p">,</span>
          <span class="mf">0.10774646699428558</span><span class="p">,</span>
          <span class="mf">0.1655820906162262</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">attn_scores_2_normalized</span><span class="p">,</span> <span class="n">expected_attn_scores_2_normalized</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span>
           <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"Normalized attention scores should match expected values exactly"</span>

    <span class="n">attn_scores_2_normalized_sum</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">attn_scores_2_normalized</span><span class="p">,</span> <span class="ss">axes:</span> <span class="p">[</span><span class="mi">0</span><span class="p">])</span>

    <span class="c1"># so when we sum the normalized elements, the value is 1.0 (apparently the LLM likes this normalization)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">attn_scores_2_normalized_sum</span><span class="p">,</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">1.0</span><span class="p">],</span> <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}),</span>
             <span class="ss">atol:</span> <span class="mf">1.0e-6</span>
           <span class="p">)</span>
           <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"Sum of normalized attention scores should be 1.0"</span>

    <span class="c1"># naive softmax (the real, optimized one comes next) is a more advanced way of normalization that copes with really big numbers</span>

    <span class="n">attn_scores_2_softmax_naive</span> <span class="o">=</span> <span class="n">softmax_naive</span><span class="p">(</span><span class="n">attn_scores_2</span><span class="p">)</span>

    <span class="n">expected_attn_scores_2_softmax</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="mf">0.13854756951332092</span><span class="p">,</span>
          <span class="mf">0.237891286611557</span><span class="p">,</span>
          <span class="mf">0.23327402770519257</span><span class="p">,</span>
          <span class="mf">0.12399158626794815</span><span class="p">,</span>
          <span class="mf">0.10818187147378922</span><span class="p">,</span>
          <span class="mf">0.15811361372470856</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span>
             <span class="n">attn_scores_2_softmax_naive</span><span class="p">,</span>
             <span class="n">expected_attn_scores_2_softmax</span><span class="p">,</span>
             <span class="ss">atol:</span> <span class="mf">1.0e-6</span>
           <span class="p">)</span>
           <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"Softmax of attention scores should match expected values exactly"</span>

    <span class="c1"># this is Axon softmax with all optimizations</span>

    <span class="n">attn_scores_2_softmax_axon</span> <span class="o">=</span> <span class="no">Axon</span><span class="o">.</span><span class="no">Activations</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">attn_scores_2</span><span class="p">)</span>

    <span class="n">expected_attn_scores_2_softmax_axon</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="mf">0.13854756951332092</span><span class="p">,</span>
          <span class="mf">0.237891286611557</span><span class="p">,</span>
          <span class="mf">0.23327402770519257</span><span class="p">,</span>
          <span class="mf">0.12399158626794815</span><span class="p">,</span>
          <span class="mf">0.10818187147378922</span><span class="p">,</span>
          <span class="mf">0.15811361372470856</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">attn_scores_2_softmax_axon</span><span class="p">,</span> <span class="n">expected_attn_scores_2_softmax_axon</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span>
           <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"Axon Softmax of attention scores should match expected values exactly"</span>

    <span class="c1"># {6} dot {6, 3} = {3}</span>
    <span class="c1"># here we multiply the weight vector (a single row) with the first column of inputs, and repeat that for all columns (there are three)</span>
    <span class="c1"># remember that the dot product of two one-dimensional vectors is the multiplication of corresponding elements and then summing them</span>

    <span class="n">context_vec_2</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">attn_scores_2_softmax_axon</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">inputs</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">])</span>

    <span class="n">expected_context_vec_2</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span><span class="mf">0.4418657422065735</span><span class="p">,</span> <span class="mf">0.6514819860458374</span><span class="p">,</span> <span class="mf">0.5683088898658752</span><span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">context_vec_2</span><span class="p">,</span> <span class="n">expected_context_vec_2</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span>
             <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"Context vector should match expected values exactly"</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">softmax_naive</span><span class="p">(%</span><span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="p">{}</span> <span class="o">=</span> <span class="n">x</span><span class="p">)</span> <span class="k">do</span>
    <span class="n">exp_x</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="no">Nx</span><span class="o">.</span><span class="n">divide</span><span class="p">(</span><span class="n">exp_x</span><span class="p">,</span> <span class="no">Nx</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">exp_x</span><span class="p">,</span> <span class="ss">axes:</span> <span class="p">[</span><span class="mi">0</span><span class="p">]))</span>
  <span class="k">end</span>
</code></pre></div></div>

<p><code class="language-plaintext highlighter-rouge">inputs</code> is a fictional token embedding tensor, the last thing we built in the post <a href="https://karlosmid.com/2026/01/build-llm-from-scratch-chapter-2-working-with-text-data/">Working with text data</a>. It contains vectors for six tokens (six rows in the inputs Tensor), each with dimension 3 (small enough for demonstration purposes, but a big number in practice). These are vectors x1 to x6. Each vector is the embedding of one word, so we have six words of dimension 3.
<code class="language-plaintext highlighter-rouge">query</code> is the second row of <code class="language-plaintext highlighter-rouge">inputs</code>. It could be any row, but Sebastian picks the second one as the row for which we will calculate simple self-attention.<br />
Then we calculate the <code class="language-plaintext highlighter-rouge">dot</code> product of <code class="language-plaintext highlighter-rouge">inputs</code> with shape <code class="language-plaintext highlighter-rouge">{6,3}</code> and <code class="language-plaintext highlighter-rouge">query</code> with shape <code class="language-plaintext highlighter-rouge">{3}</code>, and the result is, ta-da, a simple attention score Tensor with shape <code class="language-plaintext highlighter-rouge">{6}</code>. Each value is an attention score. But what is the <code class="language-plaintext highlighter-rouge">dot</code> product?</p>

<p><img src="/assets/2026/03/matrix-dot-product-6x3-by-3x1.jpg" alt="ASCII diagram of 6x3 matrix dot product with 3x1 vector" />
<em>Matrix dot product: <code class="language-plaintext highlighter-rouge">A(6x3) * b(3x1) = c(6x1)</code>.</em><br />
An important condition for the <code class="language-plaintext highlighter-rouge">dot</code> product is that the number of columns in the first Tensor must match the number of rows in the second Tensor.</p>

<p>Here is how to apply this with the <code class="language-plaintext highlighter-rouge">Nx.dot</code> function:<br />
<code class="language-plaintext highlighter-rouge">attn_scores_2 = Nx.dot(inputs, [1], query, [0])</code><br />
The parameters are: first tensor, contraction axis of the first tensor, second tensor, contraction axis of the second tensor. So what is this axis? We need to tell <code class="language-plaintext highlighter-rouge">Nx.dot</code> over which axis to multiply and sum. In <code class="language-plaintext highlighter-rouge">{6,3}</code> the axes are 0 and 1; in <code class="language-plaintext highlighter-rouge">{3}</code> the only axis is 0. To satisfy the dot condition (the contracted axes must have the same size), we use the 3 from the first shape and the 3 from the second shape.</p>
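
<p>As a quick sanity check (my own addition, not from the book), the first attention score can be computed by hand and compared with what <code class="language-plaintext highlighter-rouge">Nx.dot</code> returns:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># first row of inputs dotted with the query, element by element
x_1 = Nx.tensor([0.43, 0.15, 0.89], type: {:f, 32})
query = Nx.tensor([0.55, 0.87, 0.66], type: {:f, 32})

by_hand = 0.43 * 0.55 + 0.15 * 0.87 + 0.89 * 0.66
with_nx = Nx.dot(x_1, query) |&gt; Nx.to_number()

# both are ~0.9544, the first value of attn_scores_2
Nx.all_close(Nx.tensor(by_hand), Nx.tensor(with_nx), atol: 1.0e-6)
</code></pre></div></div>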

<p>As you will soon learn, the <code class="language-plaintext highlighter-rouge">dot</code> product is the heart of the attention mechanism! This is all that GPUs are doing, over and over: dot, dot, dot, …</p>

<p>But why the <code class="language-plaintext highlighter-rouge">dot</code> product? Because the result measures the similarity of the vectors that we <code class="language-plaintext highlighter-rouge">dot</code>: the higher the number, the greater the similarity. In terms of words (as each vector is the embedding of a word), that roughly means how often those words appear together in text.</p>
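
<p>A tiny illustration with my own toy vectors (not the book’s data): a vector has a bigger dot product with a vector that points roughly the same way than with one that does not.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>a = Nx.tensor([1.0, 0.0, 1.0], type: {:f, 32})
similar = Nx.tensor([0.9, 0.1, 0.8], type: {:f, 32})
different = Nx.tensor([0.0, 1.0, 0.0], type: {:f, 32})

Nx.dot(a, similar) |&gt; Nx.to_number()    # 1.7, high similarity
Nx.dot(a, different) |&gt; Nx.to_number()  # 0.0, no similarity
</code></pre></div></div>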

<p>We are preparing those Tensors for the LLM, so we must please it. It turns out that, to achieve training stability, the LLM likes weights to be &gt; 0 and &lt; 1, so the next step is weight normalization. <code class="language-plaintext highlighter-rouge">attn_scores_2_normalized</code> is that step in the code. In this normalization we first sum all the scores, and then divide each score by that sum. So when we sum the normalized weights, the sum is 1.</p>

<p>But not so fast! Instead of this simple normalization we need to use the <code class="language-plaintext highlighter-rouge">softmax</code> function for normalization. Why? Because it helps to tame big values. <code class="language-plaintext highlighter-rouge">softmax_naive</code> only illustrates the idea and could easily overflow or underflow, while <code class="language-plaintext highlighter-rouge">Axon.Activations.softmax</code> is a well-optimized <code class="language-plaintext highlighter-rouge">softmax</code> that handles overflow and underflow.</p>
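
<p>The standard trick for a numerically stable softmax is to subtract the maximum before exponentiating; the result is mathematically the same, but <code class="language-plaintext highlighter-rouge">Nx.exp</code> never sees a huge argument. Here is a small helper you could drop next to <code class="language-plaintext highlighter-rouge">softmax_naive</code> in the test module; it is my own sketch, and I have not checked how Axon implements it internally.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>defp softmax_stable(%Nx.Tensor{} = x) do
  # subtracting the max keeps the exponents small, avoiding overflow
  shifted = Nx.subtract(x, Nx.reduce_max(x))
  exp_x = Nx.exp(shifted)
  Nx.divide(exp_x, Nx.sum(exp_x))
end

# softmax_naive(Nx.tensor([1000.0, 1000.0])) overflows to NaN,
# while softmax_stable(Nx.tensor([1000.0, 1000.0])) returns [0.5, 0.5]
</code></pre></div></div>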

<p>The last step is to calculate the context vector for the query token from the input. For that query we calculated attention weights. Now comes the hard part: explaining this <code class="language-plaintext highlighter-rouge">Nx.dot</code> calculation. We must sum multiplications over the same axis. As the first Tensor shape is <code class="language-plaintext highlighter-rouge">{6}</code> and the second is <code class="language-plaintext highlighter-rouge">{6, 3}</code>, they match over the first axis (index 0).</p>
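
<p>Put differently, the context vector is just a weighted sum of the input rows. A tiny two-token example with my own numbers makes that visible:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>weights = Nx.tensor([0.25, 0.75], type: {:f, 32})
rows = Nx.tensor([[1.0, 2.0, 3.0], [5.0, 6.0, 7.0]], type: {:f, 32})

# contract over the shared axis of size 2:
# 0.25 * [1, 2, 3] + 0.75 * [5, 6, 7] = [4.0, 5.0, 6.0]
Nx.dot(weights, [0], rows, [0])
</code></pre></div></div>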

<h3 id="computing-attention-weights-for-all-input-tokens">Computing attention weights for all input tokens</h3>

<p>Now we will calculate the context vector for all input words, not just the second one. The code is very similar to the previous one, with a few Nx-specific twists:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="n">test</span> <span class="s2">"attention weights for all tokens"</span> <span class="k">do</span>
    <span class="n">inputs</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="c1"># Your (x^1)</span>
          <span class="p">[</span><span class="mf">0.43</span><span class="p">,</span> <span class="mf">0.15</span><span class="p">,</span> <span class="mf">0.89</span><span class="p">],</span>
          <span class="c1"># journey (x^2)</span>
          <span class="p">[</span><span class="mf">0.55</span><span class="p">,</span> <span class="mf">0.87</span><span class="p">,</span> <span class="mf">0.66</span><span class="p">],</span>
          <span class="c1"># starts (x^3)</span>
          <span class="p">[</span><span class="mf">0.57</span><span class="p">,</span> <span class="mf">0.85</span><span class="p">,</span> <span class="mf">0.64</span><span class="p">],</span>
          <span class="c1"># with (x^4)</span>
          <span class="p">[</span><span class="mf">0.22</span><span class="p">,</span> <span class="mf">0.58</span><span class="p">,</span> <span class="mf">0.33</span><span class="p">],</span>
          <span class="c1"># one (x^5)</span>
          <span class="p">[</span><span class="mf">0.77</span><span class="p">,</span> <span class="mf">0.25</span><span class="p">,</span> <span class="mf">0.10</span><span class="p">],</span>
          <span class="c1"># step (x^6)</span>
          <span class="p">[</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.80</span><span class="p">,</span> <span class="mf">0.55</span><span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="c1"># Scores for every query against every key: Q @ K^T</span>
    <span class="c1"># {6, 3} dot {6, 3} = {6, 3} dot {3, 6} (the second Tensor transposed) = {6, 6}</span>
    <span class="c1"># result is a Tensor {6 x 6} where the first row holds the attention scores of the first token against all tokens, and so on for the second, third, ...</span>
    <span class="c1"># element {1, 1} is special because it is the attention of a token with itself (I apologize for personalizing the token, but I am so excited that I started to understand these calculations!)</span>

    <span class="n">attn_scores</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">inputs</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">])</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">attn_scores</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">6</span><span class="p">,</span> <span class="mi">6</span><span class="p">}</span>

    <span class="n">expected_attn_scores</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span>
            <span class="mf">0.9994999766349792</span><span class="p">,</span>
            <span class="mf">0.9544000029563904</span><span class="p">,</span>
            <span class="mf">0.9422000050544739</span><span class="p">,</span>
            <span class="mf">0.47530001401901245</span><span class="p">,</span>
            <span class="mf">0.4575999975204468</span><span class="p">,</span>
            <span class="mf">0.6309999823570251</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.9544000029563904</span><span class="p">,</span>
            <span class="mf">1.4950000047683716</span><span class="p">,</span>
            <span class="mf">1.4754000902175903</span><span class="p">,</span>
            <span class="mf">0.8434000015258789</span><span class="p">,</span>
            <span class="mf">0.7070000171661377</span><span class="p">,</span>
            <span class="mf">1.0865000486373901</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.9422000050544739</span><span class="p">,</span>
            <span class="mf">1.4754000902175903</span><span class="p">,</span>
            <span class="mf">1.4570000171661377</span><span class="p">,</span>
            <span class="mf">0.8295999765396118</span><span class="p">,</span>
            <span class="mf">0.715399980545044</span><span class="p">,</span>
            <span class="mf">1.0605000257492065</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.47530001401901245</span><span class="p">,</span>
            <span class="mf">0.8434000015258789</span><span class="p">,</span>
            <span class="mf">0.8295999765396118</span><span class="p">,</span>
            <span class="mf">0.4936999976634979</span><span class="p">,</span>
            <span class="mf">0.3473999798297882</span><span class="p">,</span>
            <span class="mf">0.656499981880188</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.4575999975204468</span><span class="p">,</span>
            <span class="mf">0.7070000171661377</span><span class="p">,</span>
            <span class="mf">0.715399980545044</span><span class="p">,</span>
            <span class="mf">0.3473999798297882</span><span class="p">,</span>
            <span class="mf">0.665399968624115</span><span class="p">,</span>
            <span class="mf">0.29350000619888306</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.6309999823570251</span><span class="p">,</span>
            <span class="mf">1.0865000486373901</span><span class="p">,</span>
            <span class="mf">1.0605000257492065</span><span class="p">,</span>
            <span class="mf">0.656499981880188</span><span class="p">,</span>
            <span class="mf">0.29350000619888306</span><span class="p">,</span>
            <span class="mf">0.9450000524520874</span>
          <span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">attn_scores</span><span class="p">,</span> <span class="n">expected_attn_scores</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"Attention scores should match expected values exactly"</span>

    <span class="c1"># Row-wise softmax gives attention weights per query token</span>
    <span class="c1"># do softmax row by row, i.e. over the last axis (the columns)</span>

    <span class="n">attn_weights_softmax</span> <span class="o">=</span> <span class="no">Axon</span><span class="o">.</span><span class="no">Activations</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">attn_scores</span><span class="p">,</span> <span class="ss">axis:</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>

    <span class="n">expected_attn_weights_softmax</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span>
            <span class="mf">0.2098347693681717</span><span class="p">,</span>
            <span class="mf">0.20058146119117737</span><span class="p">,</span>
            <span class="mf">0.19814923405647278</span><span class="p">,</span>
            <span class="mf">0.12422822415828705</span><span class="p">,</span>
            <span class="mf">0.12204873561859131</span><span class="p">,</span>
            <span class="mf">0.14515765011310577</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.13854758441448212</span><span class="p">,</span>
            <span class="mf">0.237891286611557</span><span class="p">,</span>
            <span class="mf">0.23327404260635376</span><span class="p">,</span>
            <span class="mf">0.12399159371852875</span><span class="p">,</span>
            <span class="mf">0.10818187147378922</span><span class="p">,</span>
            <span class="mf">0.15811361372470856</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.1390075981616974</span><span class="p">,</span>
            <span class="mf">0.23692145943641663</span><span class="p">,</span>
            <span class="mf">0.23260195553302765</span><span class="p">,</span>
            <span class="mf">0.12420440465211868</span><span class="p">,</span>
            <span class="mf">0.11080020666122437</span><span class="p">,</span>
            <span class="mf">0.15646442770957947</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.14352688193321228</span><span class="p">,</span>
            <span class="mf">0.20739442110061646</span><span class="p">,</span>
            <span class="mf">0.20455202460289001</span><span class="p">,</span>
            <span class="mf">0.14619223773479462</span><span class="p">,</span>
            <span class="mf">0.12629525363445282</span><span class="p">,</span>
            <span class="mf">0.172039195895195</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.15261085331439972</span><span class="p">,</span>
            <span class="mf">0.19583867490291595</span><span class="p">,</span>
            <span class="mf">0.1974906474351883</span><span class="p">,</span>
            <span class="mf">0.13668666779994965</span><span class="p">,</span>
            <span class="mf">0.18785890936851501</span><span class="p">,</span>
            <span class="mf">0.12951429188251495</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.13847115635871887</span><span class="p">,</span>
            <span class="mf">0.2183637171983719</span><span class="p">,</span>
            <span class="mf">0.21275943517684937</span><span class="p">,</span>
            <span class="mf">0.14204756915569305</span><span class="p">,</span>
            <span class="mf">0.09880637377500534</span><span class="p">,</span>
            <span class="mf">0.18955175578594208</span>
          <span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">attn_weights_softmax</span><span class="p">,</span> <span class="n">expected_attn_weights_softmax</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span>
           <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span>
             <span class="mi">1</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">attn_weights_softmax</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">6</span><span class="p">,</span> <span class="mi">6</span><span class="p">}</span>

    <span class="c1"># [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]</span>

    <span class="n">row_sums</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">attn_weights_softmax</span><span class="p">,</span> <span class="ss">axes:</span> <span class="p">[</span><span class="mi">1</span><span class="p">])</span>

    <span class="n">ones</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}),</span> <span class="p">{</span><span class="mi">6</span><span class="p">})</span>

    <span class="c1"># believe it or not, the sum of each row (over columns) is 1!</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">row_sums</span><span class="p">,</span> <span class="n">ones</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span>

    <span class="c1"># {6, 6} dot {6, 3} = {6, 3}</span>
    <span class="c1"># result is context vector for all tokens!</span>

    <span class="n">context_vecs</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">attn_weights_softmax</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">inputs</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">])</span>
    <span class="n">expected_context_vecs</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span><span class="mf">0.4420594274997711</span><span class="p">,</span> <span class="mf">0.5930986404418945</span><span class="p">,</span> <span class="mf">0.5789890885353088</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.4418657422065735</span><span class="p">,</span> <span class="mf">0.6514819860458374</span><span class="p">,</span> <span class="mf">0.5683088898658752</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.4431275427341461</span><span class="p">,</span> <span class="mf">0.6495946049690247</span><span class="p">,</span> <span class="mf">0.5670731067657471</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.43038973212242126</span><span class="p">,</span> <span class="mf">0.6298280954360962</span><span class="p">,</span> <span class="mf">0.5510270595550537</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.4671017527580261</span><span class="p">,</span> <span class="mf">0.5909927487373352</span><span class="p">,</span> <span class="mf">0.5265965461730957</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.41772449016571045</span><span class="p">,</span> <span class="mf">0.650323212146759</span><span class="p">,</span> <span class="mf">0.5645352005958557</span><span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">context_vecs</span><span class="p">,</span> <span class="n">expected_context_vecs</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span>
             <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"Context vectors should match expected values exactly"</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">context_vecs</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">6</span><span class="p">,</span> <span class="mi">3</span><span class="p">}</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>We start with the same token embeddings for six words in three dimensions. <code class="language-plaintext highlighter-rouge">Nx.dot</code> is a powerful function, so there is no need for any <code class="language-plaintext highlighter-rouge">Enum</code> looping. With this:</p>

<p><code class="language-plaintext highlighter-rouge">attn_scores = Nx.dot(inputs, [1], inputs, [1])</code><br />
<code class="language-plaintext highlighter-rouge">assert Nx.shape(attn_scores) == {6, 6}</code></p>

<p>We multiply <code class="language-plaintext highlighter-rouge">inputs</code> by itself, contracting both tensors over their columns (axis 1). This is the same as multiplying <code class="language-plaintext highlighter-rouge">inputs</code> by its own transpose, which keeps the usual rule that the number of columns in the first tensor must equal the number of rows in the second one.</p>
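
<p>A quick equivalence check (my own addition, reusing <code class="language-plaintext highlighter-rouge">inputs</code> from the test above): contracting both tensors over axis 1 gives the same result as multiplying <code class="language-plaintext highlighter-rouge">inputs</code> by its transpose.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># both spellings produce the same {6, 6} score matrix
scores_a = Nx.dot(inputs, [1], inputs, [1])
scores_b = Nx.dot(inputs, Nx.transpose(inputs))

Nx.all_close(scores_a, scores_b) |&gt; Nx.to_number()  # 1
</code></pre></div></div>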

<p>Then we calculate the <code class="language-plaintext highlighter-rouge">softmax</code>:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>attn_weights_softmax = Axon.Activations.softmax(attn_scores, axis: -1)
</code></pre></div></div>

<p>Here we tell softmax to sum over the last axis. For shape <code class="language-plaintext highlighter-rouge">{6, 6}</code>, that means each row is normalized across columns.</p>
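
<p>To see what the <code class="language-plaintext highlighter-rouge">axis</code> option changes, here is a toy 2x2 example with my own numbers: with <code class="language-plaintext highlighter-rouge">axis: -1</code> each row sums to 1, with <code class="language-plaintext highlighter-rouge">axis: 0</code> each column does.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>t = Nx.tensor([[1.0, 2.0], [3.0, 4.0]], type: {:f, 32})

row_wise = Axon.Activations.softmax(t, axis: -1)
Nx.sum(row_wise, axes: [1])  # ~[1.0, 1.0], each row sums to 1

col_wise = Axon.Activations.softmax(t, axis: 0)
Nx.sum(col_wise, axes: [0])  # ~[1.0, 1.0], each column sums to 1
</code></pre></div></div>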

<p>To calculate the context vector:</p>

<p><code class="language-plaintext highlighter-rouge">context_vecs = Nx.dot(attn_weights_softmax, [1], inputs, [0])</code></p>

<p>We compute the dot product of weights with shape <code class="language-plaintext highlighter-rouge">{6, 6}</code> and inputs with shape <code class="language-plaintext highlighter-rouge">{6, 3}</code> over the second and first axes (columns and rows). You can check that row 2 in <code class="language-plaintext highlighter-rouge">expected_context_vecs</code> is the same as <code class="language-plaintext highlighter-rouge">context_vecs</code> from the previous section.</p>
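
<p>One way to verify that claim inside the test (my own assertion, not from the book) is to slice out row index 1 of <code class="language-plaintext highlighter-rouge">context_vecs</code> and compare it with the values we already asserted for <code class="language-plaintext highlighter-rouge">context_vec_2</code>:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># row index 1 of the all-tokens result should equal the single-query result
row_2 =
  context_vecs
  |&gt; Nx.slice_along_axis(1, 1, axis: 0)
  |&gt; Nx.squeeze(axes: [0])

expected_row_2 =
  Nx.tensor([0.4418657422065735, 0.6514819860458374, 0.5683088898658752], type: {:f, 32})

assert Nx.all_close(row_2, expected_row_2, atol: 1.0e-6) |&gt; Nx.to_number() == 1
</code></pre></div></div>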

<h2 id="implementing-self-attention-with-trainable-weights">Implementing self-attention with trainable weights</h2>

<p>OK, let’s complicate things. We will implement scaled dot-product attention (this sounds scary, but we are building on top of a simple self-attention mechanism). How? We will put trainable weights into the mix.</p>

<h3 id="computing-the-attention-weights-in-a-new-kids-on-the-block-way-step-by-step">Computing the attention weights in a New Kids on the Block way, <a href="https://www.youtube.com/watch?v=VbxJv8MGCbc&amp;list=RDVbxJv8MGCbc&amp;start_radio=1">step by step</a></h3>

<p>As New Kids on the Block got their girls step by step, we will use the same approach for the self-attention mechanism.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="n">test</span> <span class="s2">"self-attention mechanism with trainable weights"</span> <span class="k">do</span>
    <span class="n">inputs</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="c1"># Your (x^1)</span>
          <span class="p">[</span><span class="mf">0.43</span><span class="p">,</span> <span class="mf">0.15</span><span class="p">,</span> <span class="mf">0.89</span><span class="p">],</span>
          <span class="c1"># journey (x^2)</span>
          <span class="p">[</span><span class="mf">0.55</span><span class="p">,</span> <span class="mf">0.87</span><span class="p">,</span> <span class="mf">0.66</span><span class="p">],</span>
          <span class="c1"># starts (x^3)</span>
          <span class="p">[</span><span class="mf">0.57</span><span class="p">,</span> <span class="mf">0.85</span><span class="p">,</span> <span class="mf">0.64</span><span class="p">],</span>
          <span class="c1"># with (x^4)</span>
          <span class="p">[</span><span class="mf">0.22</span><span class="p">,</span> <span class="mf">0.58</span><span class="p">,</span> <span class="mf">0.33</span><span class="p">],</span>
          <span class="c1"># one (x^5)</span>
          <span class="p">[</span><span class="mf">0.77</span><span class="p">,</span> <span class="mf">0.25</span><span class="p">,</span> <span class="mf">0.10</span><span class="p">],</span>
          <span class="c1"># step (x^6)</span>
          <span class="p">[</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.80</span><span class="p">,</span> <span class="mf">0.55</span><span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="c1"># we already know this Nx trick for getting the second row of a Tensor with dimension {6, 3}</span>

    <span class="n">x_2</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">slice_along_axis</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="ss">axis:</span> <span class="mi">0</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="ss">axes:</span> <span class="p">[</span><span class="mi">0</span><span class="p">])</span>

    <span class="c1"># this is first dimension of x_2, = 3</span>
    <span class="n">d_in</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">x_2</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="n">elem</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>

    <span class="c1"># this is Sebastian's simplification, so we can follow the computation more easily.</span>
    <span class="c1"># in real life, d_out = d_in</span>

    <span class="n">d_out</span> <span class="o">=</span> <span class="mi">2</span>
    <span class="c1"># random key for generating the weights</span>

    <span class="n">key</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">Random</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">123</span><span class="p">)</span>
    <span class="n">query_weights</span> <span class="o">=</span> <span class="no">Axon</span><span class="o">.</span><span class="no">Initializers</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="ss">scale:</span> <span class="mf">1.0</span><span class="p">)</span><span class="o">.</span><span class="p">({</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">},</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">},</span> <span class="n">key</span><span class="p">)</span>
    <span class="n">key_weights</span> <span class="o">=</span> <span class="no">Axon</span><span class="o">.</span><span class="no">Initializers</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="ss">scale:</span> <span class="mf">1.0</span><span class="p">)</span><span class="o">.</span><span class="p">({</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">},</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">},</span> <span class="n">key</span><span class="p">)</span>
    <span class="n">value_weights</span> <span class="o">=</span> <span class="no">Axon</span><span class="o">.</span><span class="no">Initializers</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="ss">scale:</span> <span class="mf">1.0</span><span class="p">)</span><span class="o">.</span><span class="p">({</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">},</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">},</span> <span class="n">key</span><span class="p">)</span>

    <span class="c1"># 1x3 dot 3x2 = 1x2</span>
    <span class="c1"># when dimension rule is satisfied, no of columns = no of rows, no need to state this in Nx.dot as Nx.dot(x_2, [1], query_weights, [0])</span>
    <span class="c1"># we can say that we are projecting token from dimension 3 =&gt; 2</span>

    <span class="n">query_2</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">x_2</span><span class="p">,</span> <span class="n">query_weights</span><span class="p">)</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">query_2</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="n">d_out</span><span class="p">}</span>
    <span class="n">expected_query_2</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mf">0.20726783573627472</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.3094936013221741</span><span class="p">],</span> <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">})</span>
    <span class="n">expected_query_2_book</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">0.4306</span><span class="p">,</span> <span class="mf">1.4551</span><span class="p">],</span> <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">})</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">query_2</span><span class="p">,</span> <span class="n">expected_query_2</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"query_2 should match expected values exactly"</span>

    <span class="c1"># if you are also following examples from Sebastian Python examples, numbers are different</span>
    <span class="c1"># reson is different implementation of Pythorch and Nx random generator</span>

    <span class="n">refute</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">query_2</span><span class="p">,</span> <span class="n">expected_query_2_book</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"query_2 should not match expected values exactly due to different random number generators in PyTorch and Nx"</span>

    <span class="n">key_2</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">x_2</span><span class="p">,</span> <span class="n">key_weights</span><span class="p">)</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">key_2</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="n">d_out</span><span class="p">}</span>

    <span class="n">value_2</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">x_2</span><span class="p">,</span> <span class="n">value_weights</span><span class="p">)</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">value_2</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="n">d_out</span><span class="p">}</span>

    <span class="c1"># lets calculate all keys and values</span>
    <span class="c1"># 6x3 dot 3x2 = 6x2</span>
    <span class="n">keys</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">key_weights</span><span class="p">)</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">keys</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">6</span><span class="p">,</span> <span class="n">d_out</span><span class="p">}</span>

    <span class="c1"># 6x3 dot 3x2 = 6x2</span>
    <span class="n">values</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">value_weights</span><span class="p">)</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">values</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">6</span><span class="p">,</span> <span class="n">d_out</span><span class="p">}</span>

    <span class="n">keys_2</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">slice_along_axis</span><span class="p">(</span><span class="n">keys</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="ss">axis:</span> <span class="mi">0</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="ss">axes:</span> <span class="p">[</span><span class="mi">0</span><span class="p">])</span>
    <span class="c1"># 1x2 dot 2x1 = 1x1</span>
    <span class="n">attn_scores_22</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">query_2</span><span class="p">,</span> <span class="n">keys_2</span><span class="p">)</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">attn_scores_22</span><span class="p">)</span> <span class="o">==</span> <span class="p">{}</span>
    <span class="n">expected_attn_scores_22</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">0.1387462466955185</span><span class="p">],</span> <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">})</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">attn_scores_22</span><span class="p">,</span> <span class="n">expected_attn_scores_22</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span>
             <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"attn_scores_22 should match expected values exactly"</span>

    <span class="n">expected_attn_scores_22_book</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">1.8524</span><span class="p">],</span> <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">})</span>

    <span class="n">refute</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">attn_scores_22</span><span class="p">,</span> <span class="n">expected_attn_scores_22_book</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span>
           <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"attn_scores_22 should not match expected values exactly due to different random number generators in PyTorch and Nx"</span>

    <span class="c1"># {2} dot {6,2} over feature dim -&gt; {6}</span>
    <span class="n">attn_scores_2</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">query_2</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">keys</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">])</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">attn_scores_2</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">6</span><span class="p">}</span>

    <span class="n">expected_attn_scores_2</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="o">-</span><span class="mf">0.17058327794075012</span><span class="p">,</span>
          <span class="mf">0.1387462466955185</span><span class="p">,</span>
          <span class="mf">0.14079777896404266</span><span class="p">,</span>
          <span class="mf">0.10855000466108322</span><span class="p">,</span>
          <span class="mf">0.13786746561527252</span><span class="p">,</span>
          <span class="mf">0.09178745746612549</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">attn_scores_2</span><span class="p">,</span> <span class="n">expected_attn_scores_2</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"attn_scores_2 should match expected values"</span>

    <span class="c1"># for normalization we are using last (second) axis of keys</span>

    <span class="n">d_k</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">axis_size</span><span class="p">(</span><span class="n">keys</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">})</span>

    <span class="n">attn_weights_2</span> <span class="o">=</span>
      <span class="n">attn_scores_2</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">divide</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">d_k</span><span class="p">))</span>
      <span class="o">|&gt;</span> <span class="no">Axon</span><span class="o">.</span><span class="no">Activations</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="ss">axis:</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">attn_weights_2</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">6</span><span class="p">}</span>

    <span class="n">expected_attn_weights_2</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="mf">0.1397317796945572</span><span class="p">,</span>
          <span class="mf">0.17389537394046783</span><span class="p">,</span>
          <span class="mf">0.1741478145122528</span><span class="p">,</span>
          <span class="mf">0.1702217161655426</span><span class="p">,</span>
          <span class="mf">0.17378734052181244</span><span class="p">,</span>
          <span class="mf">0.16821600496768951</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">attn_weights_2</span><span class="p">,</span> <span class="n">expected_attn_weights_2</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"attn_weights_2 should match expected values"</span>

    <span class="c1"># {6} dot {6, 2} = {2}</span>
    <span class="n">context_vec_2</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">attn_weights_2</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">values</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">])</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">context_vec_2</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="n">d_out</span><span class="p">}</span>

    <span class="n">expected_context_vec_2</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span><span class="o">-</span><span class="mf">0.11537063866853714</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.18990936875343323</span><span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">context_vec_2</span><span class="p">,</span> <span class="n">expected_context_vec_2</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"context_vec_2 should match expected values"</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>We start with the same <code class="language-plaintext highlighter-rouge">inputs</code> of six words with shape <code class="language-plaintext highlighter-rouge">{6,3}</code>. Then we introduce the three kings of LLMs: query, key, and value trainable matrices. Let’s use some scary math notation: Wq, Wk, and Wv. Again, Sebastian picks x2 from inputs, but that could be any of x1..x6. We set the dimensions of the Wq, Wk, and Wv matrices to <code class="language-plaintext highlighter-rouge">{3, 2}</code>. Three is the second axis of <code class="language-plaintext highlighter-rouge">inputs</code>, and 2 is used to make the computation easier to follow. In the real world, we would likely have <code class="language-plaintext highlighter-rouge">{3,3}</code>.</p>
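
<p>As a minimal illustration of that projection (the numbers below are made up; only the shapes matter), dotting a <code class="language-plaintext highlighter-rouge">{3}</code> token with a <code class="language-plaintext highlighter-rouge">{3, 2}</code> weight matrix gives a <code class="language-plaintext highlighter-rouge">{2}</code> vector:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># hypothetical numbers, chosen only to show the {3} dot {3, 2} = {2} projection
x = Nx.tensor([0.55, 0.87, 0.66], type: {:f, 32})
w_q = Nx.tensor([[0.1, 0.4], [0.2, 0.5], [0.3, 0.6]], type: {:f, 32})

Nx.dot(x, w_q)
# =&gt; [0.55 * 0.1 + 0.87 * 0.2 + 0.66 * 0.3, 0.55 * 0.4 + 0.87 * 0.5 + 0.66 * 0.6]
# =&gt; approximately [0.427, 1.051]
</code></pre></div></div>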

<p>The next step is to fill Wq, Wk, and Wv with random values with the help of the <code class="language-plaintext highlighter-rouge">Axon.Initializers.uniform</code> function, and to compute <code class="language-plaintext highlighter-rouge">query_2</code>, <code class="language-plaintext highlighter-rouge">key_2</code>, and <code class="language-plaintext highlighter-rouge">value_2</code> for <code class="language-plaintext highlighter-rouge">x_2</code> by taking the <code class="language-plaintext highlighter-rouge">dot</code> of <code class="language-plaintext highlighter-rouge">x_2</code> and the corresponding weight matrices. The results are vectors of dimension <code class="language-plaintext highlighter-rouge">{2}</code>.</p>

<p>Sebastian explains the difference between weights and attention weights. Weight matrices are static values learned during model training, while attention weights are dynamic, context-specific values.</p>

<p>Using the <code class="language-plaintext highlighter-rouge">dot</code> product, we can also calculate <code class="language-plaintext highlighter-rouge">keys</code> and <code class="language-plaintext highlighter-rouge">values</code> by dotting <code class="language-plaintext highlighter-rouge">inputs</code> with Wk and Wv.</p>

<p>Now it’s time for <code class="language-plaintext highlighter-rouge">attn_scores_22</code>, where we dot <code class="language-plaintext highlighter-rouge">query_2</code> and <code class="language-plaintext highlighter-rouge">keys_2</code>, and for <code class="language-plaintext highlighter-rouge">attn_scores_2</code>, where we dot <code class="language-plaintext highlighter-rouge">query_2</code> and <code class="language-plaintext highlighter-rouge">keys</code>. We have prepared everything to calculate <code class="language-plaintext highlighter-rouge">attn_weights_2</code>. We first scale <code class="language-plaintext highlighter-rouge">attn_scores_2</code> by dividing each vector element by <code class="language-plaintext highlighter-rouge">Nx.sqrt</code> of the last dimension of <code class="language-plaintext highlighter-rouge">keys</code> <code class="language-plaintext highlighter-rouge">{6,2}</code>, which is 2. Then we apply the <code class="language-plaintext highlighter-rouge">softmax</code> function over the last dimension. Why are we doing this? To avoid small gradients, and we will learn that small gradients are the worst enemies of LLMs. This is called scaled dot-product attention.</p>
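
<p>Here is a minimal, standalone sketch of that scaling and <code class="language-plaintext highlighter-rouge">softmax</code> step; the score values are illustrative and not taken from the test, only the shape <code class="language-plaintext highlighter-rouge">{6}</code> matters:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># toy attention scores for six tokens; d_k is the key dimension, here 2
scores = Nx.tensor([-0.17, 0.14, 0.14, 0.11, 0.14, 0.09], type: {:f, 32})
d_k = Nx.tensor(2.0, type: {:f, 32})

attn_weights =
  scores
  |&gt; Nx.divide(Nx.sqrt(d_k))
  |&gt; Axon.Activations.softmax(axis: -1)

# attn_weights keeps shape {6} and its elements sum to 1.0
</code></pre></div></div>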

<p>Finally, we calculate <code class="language-plaintext highlighter-rouge">context_vec_2</code> (the context vector for the x2 token) with <code class="language-plaintext highlighter-rouge">Nx.dot(attn_weights_2, [0], values, [0])</code>, which dots <code class="language-plaintext highlighter-rouge">attn_weights_2</code> and <code class="language-plaintext highlighter-rouge">values</code>. Note that we need to set the axes explicitly because we are dotting dimensions <code class="language-plaintext highlighter-rouge">{6}</code> and <code class="language-plaintext highlighter-rouge">{6,2}</code>. We use <code class="language-plaintext highlighter-rouge">[0]</code> and <code class="language-plaintext highlighter-rouge">[0]</code> because those are the compatible axes, 6 and 6.</p>
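
<p>If the explicit axes in <code class="language-plaintext highlighter-rouge">Nx.dot/4</code> look unfamiliar, here is a tiny sketch (with made-up values) of contracting a <code class="language-plaintext highlighter-rouge">{6}</code> vector with a <code class="language-plaintext highlighter-rouge">{6, 2}</code> matrix over their shared axis of size 6:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># contract axis 0 of the weights vector with axis 0 of the values matrix
attn_weights = Nx.tensor([0.1, 0.2, 0.1, 0.2, 0.2, 0.2], type: {:f, 32})
values = Nx.iota({6, 2}, type: {:f, 32})

Nx.dot(attn_weights, [0], values, [0])
# =&gt; a {2} context vector: a weighted sum of the six value rows
</code></pre></div></div>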

<p>Why do we use the terms query, key, and value? The analogy comes from data retrieval: the query is the search term, the key is the index entry it is matched against, and the value is the actual stored content. In self-attention, the token we are currently processing provides the query, every input token provides a key it is compared with, and the values of the best-matching tokens contribute most to the context vector.</p>
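
<p>The same analogy in plain Elixir terms, purely illustrative and nothing to do with tensors: the query is what you look up, the keys are what you match against, and the value is what comes back:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># a toy key-value store: matching a query against the keys returns the stored value
store = %{"journey" =&gt; "embedding of journey", "starts" =&gt; "embedding of starts"}
query = "journey"

Map.get(store, query)
# =&gt; "embedding of journey"
</code></pre></div></div>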

<h3 id="implementing-compact-self-attention-modules">Implementing compact self-attention modules</h3>

<p>In this section, we will consolidate code from the previous test, as we will obviously need the self-attention mechanism later in this book. The book uses a class; we will do it via Elixir modules. Version v1 is the same as what we did in the test “self-attention mechanism with trainable weights.” The only difference is that now we calculate the full context vector for all input tokens, so we get a tensor with shape <code class="language-plaintext highlighter-rouge">{6,2}</code>:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="k">defmodule</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SelfAttentionV1</span> <span class="k">do</span>
  <span class="nv">@moduledoc</span> <span class="sd">"""
  Simple self-attention module with trainable projection weights.

  API:

    * `new/2` and `new/3` - initialize `w_q`, `w_k`, and `w_v`
    * `forward/2` - compute full context vectors for all tokens
  """</span>

  <span class="k">defstruct</span> <span class="p">[</span><span class="ss">:w_q</span><span class="p">,</span> <span class="ss">:w_k</span><span class="p">,</span> <span class="ss">:w_v</span><span class="p">,</span> <span class="ss">:d_in</span><span class="p">,</span> <span class="ss">:d_out</span><span class="p">,</span> <span class="ss">:seed</span><span class="p">]</span>

  <span class="nv">@type</span> <span class="n">t</span> <span class="p">::</span> <span class="p">%</span><span class="bp">__MODULE__</span><span class="p">{</span>
          <span class="ss">w_q:</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">(),</span>
          <span class="ss">w_k:</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">(),</span>
          <span class="ss">w_v:</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">(),</span>
          <span class="ss">d_in:</span> <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="ss">d_out:</span> <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="ss">seed:</span> <span class="n">integer</span><span class="p">()</span> <span class="o">|</span> <span class="no">nil</span>
        <span class="p">}</span>

  <span class="nv">@spec</span> <span class="n">new</span><span class="p">(</span><span class="n">pos_integer</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">())</span> <span class="p">::</span> <span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Creates a self-attention module with randomly initialized projection weights.

  ## Arguments

    * `d_in` - input feature size.
    * `d_out` - projection/output feature size.

  ## Returns

    * `%LlmScratch.SelfAttentionV1{}` with initialized `w_q`, `w_k`, and `w_v`.
  """</span>
  <span class="k">def</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">),</span> <span class="k">do</span><span class="p">:</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="p">[])</span>

  <span class="nv">@spec</span> <span class="n">new</span><span class="p">(</span><span class="n">pos_integer</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">(),</span> <span class="n">keyword</span><span class="p">())</span> <span class="p">::</span> <span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Creates a self-attention module with optional `:seed` and optional manual
  projection weights (`:w_q`, `:w_k`, `:w_v`).

  ## Arguments

    * `d_in` - input feature size.
    * `d_out` - projection/output feature size.
    * `opts` - keyword options:
      `:seed`, `:w_q`, `:w_k`, `:w_v`.

  ## Returns

    * `%LlmScratch.SelfAttentionV1{}` with initialized or provided projections.
  """</span>
  <span class="k">def</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">opts</span><span class="p">)</span>
      <span class="ow">when</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">d_in</span><span class="p">)</span> <span class="ow">and</span> <span class="n">d_in</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">d_out</span><span class="p">)</span> <span class="ow">and</span> <span class="n">d_out</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">do</span>
    <span class="n">seed</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:seed</span><span class="p">)</span>

    <span class="n">key</span> <span class="o">=</span>
      <span class="k">case</span> <span class="n">seed</span> <span class="k">do</span>
        <span class="no">nil</span> <span class="o">-&gt;</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">Random</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="no">System</span><span class="o">.</span><span class="n">unique_integer</span><span class="p">([</span><span class="ss">:positive</span><span class="p">]))</span>
        <span class="n">int</span> <span class="ow">when</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">Random</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">int</span><span class="p">)</span>
        <span class="n">other</span> <span class="o">-&gt;</span> <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span> <span class="s2">"seed must be an integer or nil, got: </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="n">other</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
      <span class="k">end</span>

    <span class="c1"># in v1 we uniformly initialize weight tensor</span>

    <span class="n">w_q</span> <span class="o">=</span> <span class="n">init_weight</span><span class="p">({</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">},</span> <span class="n">key</span><span class="p">,</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:w_q</span><span class="p">))</span>
    <span class="n">w_k</span> <span class="o">=</span> <span class="n">init_weight</span><span class="p">({</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">},</span> <span class="n">key</span><span class="p">,</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:w_k</span><span class="p">))</span>
    <span class="n">w_v</span> <span class="o">=</span> <span class="n">init_weight</span><span class="p">({</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">},</span> <span class="n">key</span><span class="p">,</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:w_v</span><span class="p">))</span>

    <span class="p">%</span><span class="bp">__MODULE__</span><span class="p">{</span><span class="ss">w_q:</span> <span class="n">w_q</span><span class="p">,</span> <span class="ss">w_k:</span> <span class="n">w_k</span><span class="p">,</span> <span class="ss">w_v:</span> <span class="n">w_v</span><span class="p">,</span> <span class="ss">d_in:</span> <span class="n">d_in</span><span class="p">,</span> <span class="ss">d_out:</span> <span class="n">d_out</span><span class="p">,</span> <span class="ss">seed:</span> <span class="n">seed</span><span class="p">}</span>
  <span class="k">end</span>

  <span class="nv">@spec</span> <span class="n">forward</span><span class="p">(</span><span class="n">t</span><span class="p">(),</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">())</span> <span class="p">::</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Computes context vectors for all tokens in `inputs`.

  ## Arguments

    * `sa` - `%LlmScratch.SelfAttentionV1{}` module state.
    * `inputs` - tensor of shape `{num_tokens, d_in}`.

  ## Returns

    * context tensor of shape `{num_tokens, d_out}`.
  """</span>
  <span class="k">def</span> <span class="n">forward</span><span class="p">(%</span><span class="bp">__MODULE__</span><span class="p">{}</span> <span class="o">=</span> <span class="n">sa</span><span class="p">,</span> <span class="p">%</span><span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="p">{}</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">)</span> <span class="k">do</span>
    <span class="n">validate_input_shape!</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">sa</span><span class="o">.</span><span class="n">d_in</span><span class="p">)</span>

    <span class="c1"># calculate query, key and values based on inputs and weight Tensors</span>
    <span class="c1"># {no_of_tokens x token_dimension} dot {d_in, d_out} = {no_of_tokens, d_out}</span>
    <span class="c1"># now you know why we validate_input_shape that token_dimension == d_in</span>

    <span class="n">q</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">sa</span><span class="o">.</span><span class="n">w_q</span><span class="p">)</span>
    <span class="n">k</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">sa</span><span class="o">.</span><span class="n">w_k</span><span class="p">)</span>
    <span class="n">v</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">sa</span><span class="o">.</span><span class="n">w_v</span><span class="p">)</span>

    <span class="c1"># result is context vector {num_of_tokens x d_out}</span>

    <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SelfAttentionCore</span><span class="o">.</span><span class="n">context_from_qkv</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">sa</span><span class="o">.</span><span class="n">d_out</span><span class="p">)</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">init_weight</span><span class="p">(</span><span class="n">expected_shape</span><span class="p">,</span> <span class="n">_key</span><span class="p">,</span> <span class="p">%</span><span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="p">{}</span> <span class="o">=</span> <span class="n">provided_weight</span><span class="p">)</span> <span class="k">do</span>
    <span class="k">if</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">provided_weight</span><span class="p">)</span> <span class="o">!=</span> <span class="n">expected_shape</span> <span class="k">do</span>
      <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span>
            <span class="s2">"expected weight shape </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="n">expected_shape</span><span class="p">)</span><span class="si">}</span><span class="s2">, got: </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">provided_weight</span><span class="p">))</span><span class="si">}</span><span class="s2">"</span>
    <span class="k">end</span>

    <span class="no">Nx</span><span class="o">.</span><span class="n">as_type</span><span class="p">(</span><span class="n">provided_weight</span><span class="p">,</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">})</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">init_weight</span><span class="p">(</span><span class="n">expected_shape</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="no">nil</span><span class="p">)</span> <span class="k">do</span>
    <span class="no">Axon</span><span class="o">.</span><span class="no">Initializers</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="ss">scale:</span> <span class="mf">1.0</span><span class="p">)</span><span class="o">.</span><span class="p">(</span><span class="n">expected_shape</span><span class="p">,</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">},</span> <span class="n">key</span><span class="p">)</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">init_weight</span><span class="p">(</span><span class="n">_expected_shape</span><span class="p">,</span> <span class="n">_key</span><span class="p">,</span> <span class="n">other</span><span class="p">)</span> <span class="k">do</span>
    <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span> <span class="s2">"weight must be an Nx.Tensor or nil, got: </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="n">other</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">validate_input_shape!</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">expected_d_in</span><span class="p">)</span> <span class="k">do</span>
    <span class="k">case</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="k">do</span>
      <span class="p">{</span><span class="n">_</span><span class="p">,</span> <span class="o">^</span><span class="n">expected_d_in</span><span class="p">}</span> <span class="o">-&gt;</span>
        <span class="ss">:ok</span>

      <span class="n">shape</span> <span class="o">-&gt;</span>
        <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span>
              <span class="s2">"expected inputs shape {num_tokens, </span><span class="si">#{</span><span class="n">expected_d_in</span><span class="si">}</span><span class="s2">}, got: </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
    <span class="k">end</span>
  <span class="k">end</span>
<span class="k">end</span>
</code></pre></div></div>
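
<p>Before the shared core below, a quick usage sketch of the module above. The shapes follow the test (six tokens of dimension 3 projected to dimension 2); the input values and seed are illustrative:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># hypothetical inputs: six tokens, each a 3-dimensional embedding
inputs = Nx.iota({6, 3}, type: {:f, 32}) |&gt; Nx.divide(10)

sa = LlmScratch.SelfAttentionV1.new(3, 2, seed: 123)
context = LlmScratch.SelfAttentionV1.forward(sa, inputs)

Nx.shape(context)
# =&gt; {6, 2}
</code></pre></div></div>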

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="k">defmodule</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SelfAttentionCore</span> <span class="k">do</span>
  <span class="nv">@moduledoc</span> <span class="sd">"""
  Functions shared by LlmScratch.SelfAttentionV1 and LlmScratch.SelfAttentionV2.
  """</span>

  <span class="nv">@spec</span> <span class="n">context_from_qkv</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">(),</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">(),</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">())</span> <span class="p">::</span>
          <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Computes self-attention context vectors from projected query, key, and value
  tensors.

  ## Arguments

    * `q` - query tensor of shape `{num_tokens, d_out}`.
    * `k` - key tensor of shape `{num_tokens, d_out}`.
    * `v` - value tensor of shape `{num_tokens, d_out}`.
    * `d_out` - feature size used for score scaling by `sqrt(d_out)`.

  ## Returns

    * context tensor of shape `{num_tokens, d_out}`.
  """</span>
  <span class="k">def</span> <span class="n">context_from_qkv</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">d_out</span><span class="p">)</span> <span class="k">do</span>

    <span class="c1"># {no_of_tokens, d_out} dot {no_of_tokens, d_out} = {no_of_tokens, no_of_tokens}</span>
    <span class="c1"># dot over second dimension d_out</span>
    
    <span class="n">attn_scores</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">k</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">])</span>
    <span class="n">d_k</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">d_out</span><span class="p">,</span> <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">})</span>

    <span class="n">attn_weights</span> <span class="o">=</span>
      <span class="n">attn_scores</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">divide</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">d_k</span><span class="p">))</span>
      <span class="o">|&gt;</span> <span class="no">Axon</span><span class="o">.</span><span class="no">Activations</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="ss">axis:</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>

    <span class="c1"># {no_of_tokens, no_of_tokens} dot {no_of_tokens, d_out} = {no_of_tokens, d_out}</span>
    <span class="c1"># because we are dotting row (over second dimension) with column (over first dimension)</span>

    <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">attn_weights</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">v</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">])</span>
  <span class="k">end</span>
<span class="k">end</span>
</code></pre></div></div>

<p>We have <code class="language-plaintext highlighter-rouge">SelfAttentionCore</code> because the <code class="language-plaintext highlighter-rouge">context_from_qkv</code> function is the same for <code class="language-plaintext highlighter-rouge">SelfAttention V1 and V2</code>. V2 differs by using <code class="language-plaintext highlighter-rouge">torch.nn.Linear</code>-style tensor initialization for weights. Why? Because it is optimized and leads to more stable and effective model training:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="k">defmodule</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SelfAttentionV2</span> <span class="k">do</span>
  <span class="nv">@moduledoc</span> <span class="sd">"""
  Self-attention module initialized from Axon dense layers.

  API:

    * `new/2` and `new/3` - initialize `w_q`, `w_k`, and `w_v`
      using `Axon.input/2 |&gt; Axon.dense/3`
    * `forward/2` - compute full context vectors for all tokens
  """</span>

  <span class="k">defstruct</span> <span class="p">[</span><span class="ss">:w_q</span><span class="p">,</span> <span class="ss">:w_k</span><span class="p">,</span> <span class="ss">:w_v</span><span class="p">,</span> <span class="ss">:d_in</span><span class="p">,</span> <span class="ss">:d_out</span><span class="p">,</span> <span class="ss">:seed</span><span class="p">]</span>

  <span class="nv">@type</span> <span class="n">dense_weights</span> <span class="p">::</span> <span class="p">%{</span><span class="ss">kernel:</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">(),</span> <span class="ss">bias:</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">()}</span>

  <span class="nv">@type</span> <span class="n">t</span> <span class="p">::</span> <span class="p">%</span><span class="bp">__MODULE__</span><span class="p">{</span>
          <span class="ss">w_q:</span> <span class="n">dense_weights</span><span class="p">(),</span>
          <span class="ss">w_k:</span> <span class="n">dense_weights</span><span class="p">(),</span>
          <span class="ss">w_v:</span> <span class="n">dense_weights</span><span class="p">(),</span>
          <span class="ss">d_in:</span> <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="ss">d_out:</span> <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="ss">seed:</span> <span class="n">integer</span><span class="p">()</span>
        <span class="p">}</span>

  <span class="nv">@spec</span> <span class="n">new</span><span class="p">(</span><span class="n">pos_integer</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">())</span> <span class="p">::</span> <span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Creates a self-attention module whose projection weights are initialized via
  Axon dense layers.

  ## Arguments

    * `d_in` - input feature size.
    * `d_out` - projection/output feature size.

  ## Returns

    * `%LlmScratch.SelfAttentionV2{}` with dense-initialized `w_q`, `w_k`, and `w_v`.
  """</span>
  <span class="k">def</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">),</span> <span class="k">do</span><span class="p">:</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="p">[])</span>

  <span class="nv">@spec</span> <span class="n">new</span><span class="p">(</span><span class="n">pos_integer</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">(),</span> <span class="n">keyword</span><span class="p">())</span> <span class="p">::</span> <span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Creates a self-attention module with Axon dense-initialized projections.

  ## Arguments

    * `d_in` - input feature size.
    * `d_out` - projection/output feature size.
    * `opts` - keyword options:
      `:seed` (optional, deterministic initialization).

  ## Returns

    * `%LlmScratch.SelfAttentionV2{}` with dense-initialized projections.
  """</span>
  <span class="k">def</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">opts</span><span class="p">)</span>
      <span class="ow">when</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">d_in</span><span class="p">)</span> <span class="ow">and</span> <span class="n">d_in</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">d_out</span><span class="p">)</span> <span class="ow">and</span> <span class="n">d_out</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">do</span>
    <span class="n">seed</span> <span class="o">=</span> <span class="n">normalize_seed</span><span class="p">(</span><span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:seed</span><span class="p">))</span>

    <span class="n">w_q</span> <span class="o">=</span> <span class="n">init_dense_weights</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">seed</span><span class="p">,</span> <span class="s2">"q_proj"</span><span class="p">)</span>
    <span class="n">w_k</span> <span class="o">=</span> <span class="n">init_dense_weights</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">seed</span><span class="p">,</span> <span class="s2">"k_proj"</span><span class="p">)</span>
    <span class="n">w_v</span> <span class="o">=</span> <span class="n">init_dense_weights</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">seed</span><span class="p">,</span> <span class="s2">"v_proj"</span><span class="p">)</span>

    <span class="p">%</span><span class="bp">__MODULE__</span><span class="p">{</span><span class="ss">w_q:</span> <span class="n">w_q</span><span class="p">,</span> <span class="ss">w_k:</span> <span class="n">w_k</span><span class="p">,</span> <span class="ss">w_v:</span> <span class="n">w_v</span><span class="p">,</span> <span class="ss">d_in:</span> <span class="n">d_in</span><span class="p">,</span> <span class="ss">d_out:</span> <span class="n">d_out</span><span class="p">,</span> <span class="ss">seed:</span> <span class="n">seed</span><span class="p">}</span>
  <span class="k">end</span>

  <span class="nv">@spec</span> <span class="n">forward</span><span class="p">(</span><span class="n">t</span><span class="p">(),</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">())</span> <span class="p">::</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Computes context vectors for all tokens in `inputs`.

  ## Arguments

    * `sa` - `%LlmScratch.SelfAttentionV2{}` module state.
    * `inputs` - tensor of shape `{num_tokens, d_in}`.

  ## Returns

    * context tensor of shape `{num_tokens, d_out}`.
  """</span>
  <span class="k">def</span> <span class="n">forward</span><span class="p">(%</span><span class="bp">__MODULE__</span><span class="p">{}</span> <span class="o">=</span> <span class="n">sa</span><span class="p">,</span> <span class="p">%</span><span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="p">{}</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">)</span> <span class="k">do</span>
    <span class="n">validate_input_shape!</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">sa</span><span class="o">.</span><span class="n">d_in</span><span class="p">)</span>

    <span class="n">q</span> <span class="o">=</span> <span class="n">dense_project</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">sa</span><span class="o">.</span><span class="n">w_q</span><span class="p">)</span>
    <span class="n">k</span> <span class="o">=</span> <span class="n">dense_project</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">sa</span><span class="o">.</span><span class="n">w_k</span><span class="p">)</span>
    <span class="n">v</span> <span class="o">=</span> <span class="n">dense_project</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">sa</span><span class="o">.</span><span class="n">w_v</span><span class="p">)</span>
    <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SelfAttentionCore</span><span class="o">.</span><span class="n">context_from_qkv</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">sa</span><span class="o">.</span><span class="n">d_out</span><span class="p">)</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">init_dense_weights</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">seed</span><span class="p">,</span> <span class="n">layer_name</span><span class="p">)</span> <span class="k">do</span>
    <span class="n">model</span> <span class="o">=</span>
      <span class="no">Axon</span><span class="o">.</span><span class="n">input</span><span class="p">(</span><span class="s2">"input"</span><span class="p">,</span> <span class="ss">shape:</span> <span class="p">{</span><span class="no">nil</span><span class="p">,</span> <span class="n">d_in</span><span class="p">})</span>
      <span class="o">|&gt;</span> <span class="no">Axon</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="n">d_out</span><span class="p">,</span> <span class="ss">use_bias:</span> <span class="no">true</span><span class="p">,</span> <span class="ss">name:</span> <span class="n">layer_name</span><span class="p">)</span>

    <span class="p">{</span><span class="n">init_fn</span><span class="p">,</span> <span class="n">_predict_fn</span><span class="p">}</span> <span class="o">=</span> <span class="no">Axon</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="ss">seed:</span> <span class="n">seed</span><span class="p">)</span>
    <span class="n">params</span> <span class="o">=</span> <span class="n">init_fn</span><span class="o">.</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="n">template</span><span class="p">({</span><span class="mi">1</span><span class="p">,</span> <span class="n">d_in</span><span class="p">},</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}),</span> <span class="no">Axon</span><span class="o">.</span><span class="no">ModelState</span><span class="o">.</span><span class="n">empty</span><span class="p">())</span>
    <span class="n">extract_dense_weights!</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">layer_name</span><span class="p">)</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">extract_dense_weights!</span><span class="p">(%</span><span class="no">Axon</span><span class="o">.</span><span class="no">ModelState</span><span class="p">{}</span> <span class="o">=</span> <span class="n">params</span><span class="p">,</span> <span class="n">layer_name</span><span class="p">)</span> <span class="k">do</span>
    <span class="n">params</span>
    <span class="o">|&gt;</span> <span class="no">Axon</span><span class="o">.</span><span class="no">ModelState</span><span class="o">.</span><span class="n">trainable_parameters</span><span class="p">()</span>
    <span class="o">|&gt;</span> <span class="n">extract_dense_weights!</span><span class="p">(</span><span class="n">layer_name</span><span class="p">)</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">extract_dense_weights!</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">layer_name</span><span class="p">)</span> <span class="ow">when</span> <span class="n">is_map</span><span class="p">(</span><span class="n">params</span><span class="p">)</span> <span class="k">do</span>
    <span class="n">layer_params</span> <span class="o">=</span>
      <span class="no">Map</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">layer_name</span><span class="p">)</span> <span class="o">||</span>
        <span class="no">Enum</span><span class="o">.</span><span class="n">find_value</span><span class="p">(</span><span class="no">Map</span><span class="o">.</span><span class="n">values</span><span class="p">(</span><span class="n">params</span><span class="p">),</span> <span class="k">fn</span>
          <span class="n">layer</span> <span class="ow">when</span> <span class="n">is_map</span><span class="p">(</span><span class="n">layer</span><span class="p">)</span> <span class="o">-&gt;</span>
            <span class="n">kernel</span> <span class="o">=</span> <span class="no">Map</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="s2">"kernel"</span><span class="p">)</span> <span class="o">||</span> <span class="no">Map</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="ss">:kernel</span><span class="p">)</span>
            <span class="n">bias</span> <span class="o">=</span> <span class="no">Map</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="s2">"bias"</span><span class="p">)</span> <span class="o">||</span> <span class="no">Map</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="ss">:bias</span><span class="p">)</span>

            <span class="k">if</span> <span class="n">match?</span><span class="p">(%</span><span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="p">{},</span> <span class="n">kernel</span><span class="p">)</span> <span class="ow">and</span> <span class="n">match?</span><span class="p">(%</span><span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="p">{},</span> <span class="n">bias</span><span class="p">)</span> <span class="k">do</span>
              <span class="n">layer</span>
            <span class="k">else</span>
              <span class="no">nil</span>
            <span class="k">end</span>

          <span class="n">_</span> <span class="o">-&gt;</span>
            <span class="no">nil</span>
        <span class="k">end</span><span class="p">)</span>

    <span class="n">kernel</span> <span class="o">=</span> <span class="n">layer_params</span> <span class="o">&amp;&amp;</span> <span class="p">(</span><span class="no">Map</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">layer_params</span><span class="p">,</span> <span class="s2">"kernel"</span><span class="p">)</span> <span class="o">||</span> <span class="no">Map</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">layer_params</span><span class="p">,</span> <span class="ss">:kernel</span><span class="p">))</span>
    <span class="n">bias</span> <span class="o">=</span> <span class="n">layer_params</span> <span class="o">&amp;&amp;</span> <span class="p">(</span><span class="no">Map</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">layer_params</span><span class="p">,</span> <span class="s2">"bias"</span><span class="p">)</span> <span class="o">||</span> <span class="no">Map</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">layer_params</span><span class="p">,</span> <span class="ss">:bias</span><span class="p">))</span>

    <span class="k">if</span> <span class="n">match?</span><span class="p">(%</span><span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="p">{},</span> <span class="n">kernel</span><span class="p">)</span> <span class="ow">and</span> <span class="n">match?</span><span class="p">(%</span><span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="p">{},</span> <span class="n">bias</span><span class="p">)</span> <span class="k">do</span>
      <span class="p">%{</span><span class="ss">kernel:</span> <span class="no">Nx</span><span class="o">.</span><span class="n">as_type</span><span class="p">(</span><span class="n">kernel</span><span class="p">,</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}),</span> <span class="ss">bias:</span> <span class="no">Nx</span><span class="o">.</span><span class="n">as_type</span><span class="p">(</span><span class="n">bias</span><span class="p">,</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">})}</span>
    <span class="k">else</span>
      <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span>
            <span class="s2">"could not extract dense kernel/bias params for layer </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="n">layer_name</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
    <span class="k">end</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">dense_project</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="p">%{</span><span class="ss">kernel:</span> <span class="n">kernel</span><span class="p">,</span> <span class="ss">bias:</span> <span class="n">bias</span><span class="p">})</span> <span class="k">do</span>
    <span class="no">Nx</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">kernel</span><span class="p">),</span> <span class="n">bias</span><span class="p">)</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">normalize_seed</span><span class="p">(</span><span class="no">nil</span><span class="p">),</span> <span class="k">do</span><span class="p">:</span> <span class="no">System</span><span class="o">.</span><span class="n">unique_integer</span><span class="p">([</span><span class="ss">:positive</span><span class="p">])</span>
  <span class="k">defp</span> <span class="n">normalize_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span> <span class="ow">when</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">seed</span><span class="p">),</span> <span class="k">do</span><span class="p">:</span> <span class="n">seed</span>

  <span class="k">defp</span> <span class="n">normalize_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span> <span class="k">do</span>
    <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span> <span class="s2">"seed must be an integer or nil, got: </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">validate_input_shape!</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">expected_d_in</span><span class="p">)</span> <span class="k">do</span>
    <span class="k">case</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="k">do</span>
      <span class="p">{</span><span class="n">_</span><span class="p">,</span> <span class="o">^</span><span class="n">expected_d_in</span><span class="p">}</span> <span class="o">-&gt;</span>
        <span class="ss">:ok</span>

      <span class="n">shape</span> <span class="o">-&gt;</span>
        <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span>
              <span class="s2">"expected inputs shape {num_tokens, </span><span class="si">#{</span><span class="n">expected_d_in</span><span class="si">}</span><span class="s2">}, got: </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
    <span class="k">end</span>
  <span class="k">end</span>
<span class="k">end</span>
</code></pre></div></div>

<p>This initialization is more complex than in v1 because it is based on Axon (the original Python implementation builds it on <code class="language-plaintext highlighter-rouge">torch.nn.Linear</code>). For now, we will not dive deeper into this implementation, as that would require a whole new blog post.</p>
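
<p>As a rough standalone sketch (not part of the module above, with illustrative sizes and names), this is essentially what the Axon-based initialization boils down to: build a single dense layer, initialize it, and read back a kernel of shape <code class="language-plaintext highlighter-rouge">{d_in, d_out}</code> and a bias of shape <code class="language-plaintext highlighter-rouge">{d_out}</code>:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># A minimal sketch, assuming Axon and Nx are installed; the layer name and seed are illustrative.
model =
  Axon.input("input", shape: {nil, 3})
  |&gt; Axon.dense(2, use_bias: true, name: "w_q")

# Building the model gives us an init function; running it against a template
# produces the initialized parameters without ever running a forward pass.
{init_fn, _predict_fn} = Axon.build(model, seed: 123)
params = init_fn.(Nx.template({1, 3}, {:f, 32}), Axon.ModelState.empty())

# The "w_q" entry holds a "kernel" tensor of shape {3, 2} and a "bias" tensor of shape {2}.
params
|&gt; Axon.ModelState.trainable_parameters()
|&gt; IO.inspect()
</code></pre></div></div>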

<p>Here are some tests where we use the v1 and v2 modules, so you can get a feel for how much less code we need to write. The context vector values differ between the two because they use different weight initializations:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">test</span> <span class="s2">"Implementing a compact self-attention module"</span> <span class="k">do</span>
    <span class="n">inputs</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="c1"># Your (x^1)</span>
          <span class="p">[</span><span class="mf">0.43</span><span class="p">,</span> <span class="mf">0.15</span><span class="p">,</span> <span class="mf">0.89</span><span class="p">],</span>
          <span class="c1"># journey (x^2)</span>
          <span class="p">[</span><span class="mf">0.55</span><span class="p">,</span> <span class="mf">0.87</span><span class="p">,</span> <span class="mf">0.66</span><span class="p">],</span>
          <span class="c1"># starts (x^3)</span>
          <span class="p">[</span><span class="mf">0.57</span><span class="p">,</span> <span class="mf">0.85</span><span class="p">,</span> <span class="mf">0.64</span><span class="p">],</span>
          <span class="c1"># with (x^4)</span>
          <span class="p">[</span><span class="mf">0.22</span><span class="p">,</span> <span class="mf">0.58</span><span class="p">,</span> <span class="mf">0.33</span><span class="p">],</span>
          <span class="c1"># one (x^5)</span>
          <span class="p">[</span><span class="mf">0.77</span><span class="p">,</span> <span class="mf">0.25</span><span class="p">,</span> <span class="mf">0.10</span><span class="p">],</span>
          <span class="c1"># step (x^6)</span>
          <span class="p">[</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.80</span><span class="p">,</span> <span class="mf">0.55</span><span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">sa</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SelfAttentionV1</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="ss">seed:</span> <span class="mi">123</span><span class="p">)</span>

    <span class="n">context_vecs</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SelfAttentionV1</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">sa</span><span class="p">,</span> <span class="n">inputs</span><span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">context_vecs</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">6</span><span class="p">,</span> <span class="mi">2</span><span class="p">}</span>

    <span class="n">expected_context_vecs</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span><span class="o">-</span><span class="mf">0.07548463344573975</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.15017275512218475</span><span class="p">],</span>
          <span class="p">[</span><span class="o">-</span><span class="mf">0.11537063866853714</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.18990936875343323</span><span class="p">],</span>
          <span class="p">[</span><span class="o">-</span><span class="mf">0.11561498790979385</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.19015151262283325</span><span class="p">],</span>
          <span class="p">[</span><span class="o">-</span><span class="mf">0.11222726106643677</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.1867683380842209</span><span class="p">],</span>
          <span class="p">[</span><span class="o">-</span><span class="mf">0.11577533185482025</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.1902845799922943</span><span class="p">],</span>
          <span class="p">[</span><span class="o">-</span><span class="mf">0.11009891331195831</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.18466056883335114</span><span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">context_vecs</span><span class="p">,</span> <span class="n">expected_context_vecs</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"context_vecs should match expected values exactly"</span>
  <span class="k">end</span>
</code></pre></div></div>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
  <span class="n">test</span> <span class="s2">"compact self-attention module v2 uses Axon dense initialization"</span> <span class="k">do</span>
    <span class="n">inputs</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span><span class="mf">0.43</span><span class="p">,</span> <span class="mf">0.15</span><span class="p">,</span> <span class="mf">0.89</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.55</span><span class="p">,</span> <span class="mf">0.87</span><span class="p">,</span> <span class="mf">0.66</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.57</span><span class="p">,</span> <span class="mf">0.85</span><span class="p">,</span> <span class="mf">0.64</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.22</span><span class="p">,</span> <span class="mf">0.58</span><span class="p">,</span> <span class="mf">0.33</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.77</span><span class="p">,</span> <span class="mf">0.25</span><span class="p">,</span> <span class="mf">0.10</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.80</span><span class="p">,</span> <span class="mf">0.55</span><span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">sa</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="ss">seed:</span> <span class="mi">789</span><span class="p">)</span>

    <span class="n">context_vecs</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">sa</span><span class="p">,</span> <span class="n">inputs</span><span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">context_vecs</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">6</span><span class="p">,</span> <span class="mi">2</span><span class="p">}</span>

    <span class="n">expected_context_vecs</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span><span class="mf">0.20869487524032593</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.11512904614210129</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.1995905637741089</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.10041604191064835</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.197800412774086</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.09748103469610214</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.20753224194049835</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.11311019212007523</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.16690319776535034</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.04650232568383217</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.22278699278831482</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.1379932463169098</span><span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">context_vecs</span><span class="p">,</span> <span class="n">expected_context_vecs</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"context_vecs should match expected values exactly"</span>
  <span class="k">end</span>
</code></pre></div></div>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
  <span class="n">test</span> <span class="s2">"exercise 3.1 - transfer v2 weights into v1 and match outputs"</span> <span class="k">do</span>
    <span class="n">inputs</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span><span class="mf">0.43</span><span class="p">,</span> <span class="mf">0.15</span><span class="p">,</span> <span class="mf">0.89</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.55</span><span class="p">,</span> <span class="mf">0.87</span><span class="p">,</span> <span class="mf">0.66</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.57</span><span class="p">,</span> <span class="mf">0.85</span><span class="p">,</span> <span class="mf">0.64</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.22</span><span class="p">,</span> <span class="mf">0.58</span><span class="p">,</span> <span class="mf">0.33</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.77</span><span class="p">,</span> <span class="mf">0.25</span><span class="p">,</span> <span class="mf">0.10</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.80</span><span class="p">,</span> <span class="mf">0.55</span><span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">sa_v2</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="ss">seed:</span> <span class="mi">123</span><span class="p">)</span>

    <span class="c1"># Axon dense kernels in this project are already shaped {d_in, d_out},</span>
    <span class="c1"># matching V1's expected projection weight layout.</span>
    <span class="c1"># goal of this exercise is to show that v1 and v2 have same calculations, only difference are weight Tensors</span>

    <span class="n">sa_v1</span> <span class="o">=</span>
      <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SelfAttentionV1</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span>
        <span class="ss">w_q:</span> <span class="n">sa_v2</span><span class="o">.</span><span class="n">w_q</span><span class="o">.</span><span class="n">kernel</span><span class="p">,</span>
        <span class="ss">w_k:</span> <span class="n">sa_v2</span><span class="o">.</span><span class="n">w_k</span><span class="o">.</span><span class="n">kernel</span><span class="p">,</span>
        <span class="ss">w_v:</span> <span class="n">sa_v2</span><span class="o">.</span><span class="n">w_v</span><span class="o">.</span><span class="n">kernel</span>
      <span class="p">)</span>

    <span class="n">context_v2</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">sa_v2</span><span class="p">,</span> <span class="n">inputs</span><span class="p">)</span>
    <span class="n">context_v1</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SelfAttentionV1</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">sa_v1</span><span class="p">,</span> <span class="n">inputs</span><span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">context_v1</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">6</span><span class="p">,</span> <span class="mi">2</span><span class="p">}</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">context_v2</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">6</span><span class="p">,</span> <span class="mi">2</span><span class="p">}</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">context_v1</span><span class="p">,</span> <span class="n">context_v2</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"after copying weights, v1 and v2 should produce the same outputs"</span>
  <span class="k">end</span>
</code></pre></div></div>

<h2 id="hiding-future-words-with-causal-attention">Hiding future words with causal attention</h2>

<p>We are moving from the simple attention mechanism to causal attention. Simple attention considers all tokens, both before and after the current token position, when calculating attention scores. Causal attention considers only the current token and the tokens before it. Why? Because we want to train the LLM to predict the next token using only previous tokens, so it cannot “cheat” by looking up the answer on the right side.</p>

<p>For example, take the six-word sentence <code class="language-plaintext highlighter-rouge">I love learning with tiny models</code>. Before masking, a query token can assign attention scores to every token in the sentence:</p>

<table>
  <thead>
    <tr>
      <th>Query \ Key</th>
      <th style="text-align: right">I</th>
      <th style="text-align: right">love</th>
      <th style="text-align: right">learning</th>
      <th style="text-align: right">with</th>
      <th style="text-align: right">tiny</th>
      <th style="text-align: right">models</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>I</td>
      <td style="text-align: right">0.25</td>
      <td style="text-align: right">0.20</td>
      <td style="text-align: right">0.18</td>
      <td style="text-align: right">0.12</td>
      <td style="text-align: right">0.13</td>
      <td style="text-align: right">0.12</td>
    </tr>
    <tr>
      <td>love</td>
      <td style="text-align: right">0.14</td>
      <td style="text-align: right">0.28</td>
      <td style="text-align: right">0.19</td>
      <td style="text-align: right">0.16</td>
      <td style="text-align: right">0.11</td>
      <td style="text-align: right">0.12</td>
    </tr>
    <tr>
      <td>learning</td>
      <td style="text-align: right">0.10</td>
      <td style="text-align: right">0.16</td>
      <td style="text-align: right">0.29</td>
      <td style="text-align: right">0.18</td>
      <td style="text-align: right">0.15</td>
      <td style="text-align: right">0.12</td>
    </tr>
    <tr>
      <td>with</td>
      <td style="text-align: right">0.09</td>
      <td style="text-align: right">0.13</td>
      <td style="text-align: right">0.19</td>
      <td style="text-align: right">0.28</td>
      <td style="text-align: right">0.17</td>
      <td style="text-align: right">0.14</td>
    </tr>
    <tr>
      <td>tiny</td>
      <td style="text-align: right">0.08</td>
      <td style="text-align: right">0.11</td>
      <td style="text-align: right">0.13</td>
      <td style="text-align: right">0.20</td>
      <td style="text-align: right">0.31</td>
      <td style="text-align: right">0.17</td>
    </tr>
    <tr>
      <td>models</td>
      <td style="text-align: right">0.07</td>
      <td style="text-align: right">0.09</td>
      <td style="text-align: right">0.11</td>
      <td style="text-align: right">0.14</td>
      <td style="text-align: right">0.21</td>
      <td style="text-align: right">0.38</td>
    </tr>
  </tbody>
</table>

<p>With causal masking, we block all attention weights above the diagonal, so each token can only attend to itself and earlier tokens:</p>

<table>
  <thead>
    <tr>
      <th>Query \ Key</th>
      <th style="text-align: right">I</th>
      <th style="text-align: right">love</th>
      <th style="text-align: right">learning</th>
      <th style="text-align: right">with</th>
      <th style="text-align: right">tiny</th>
      <th style="text-align: right">models</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>I</td>
      <td style="text-align: right">1.00</td>
      <td style="text-align: right">X</td>
      <td style="text-align: right">X</td>
      <td style="text-align: right">X</td>
      <td style="text-align: right">X</td>
      <td style="text-align: right">X</td>
    </tr>
    <tr>
      <td>love</td>
      <td style="text-align: right">0.22</td>
      <td style="text-align: right">0.78</td>
      <td style="text-align: right">X</td>
      <td style="text-align: right">X</td>
      <td style="text-align: right">X</td>
      <td style="text-align: right">X</td>
    </tr>
    <tr>
      <td>learning</td>
      <td style="text-align: right">0.18</td>
      <td style="text-align: right">0.27</td>
      <td style="text-align: right">0.55</td>
      <td style="text-align: right">X</td>
      <td style="text-align: right">X</td>
      <td style="text-align: right">X</td>
    </tr>
    <tr>
      <td>with</td>
      <td style="text-align: right">0.12</td>
      <td style="text-align: right">0.17</td>
      <td style="text-align: right">0.26</td>
      <td style="text-align: right">0.45</td>
      <td style="text-align: right">X</td>
      <td style="text-align: right">X</td>
    </tr>
    <tr>
      <td>tiny</td>
      <td style="text-align: right">0.10</td>
      <td style="text-align: right">0.13</td>
      <td style="text-align: right">0.16</td>
      <td style="text-align: right">0.24</td>
      <td style="text-align: right">0.37</td>
      <td style="text-align: right">X</td>
    </tr>
    <tr>
      <td>models</td>
      <td style="text-align: right">0.07</td>
      <td style="text-align: right">0.09</td>
      <td style="text-align: right">0.11</td>
      <td style="text-align: right">0.14</td>
      <td style="text-align: right">0.21</td>
      <td style="text-align: right">0.38</td>
    </tr>
  </tbody>
</table>

<p>Here <code class="language-plaintext highlighter-rouge">X</code> means “masked out”. In practice those entries are set to a very large negative value before softmax, so their final attention probability becomes zero.</p>
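
<p>A quick, self-contained sketch (not part of the post’s test suite, with made-up score values) of why a very large negative value works: softmax exponentiates the scores, and <code class="language-plaintext highlighter-rouge">exp</code> of negative infinity is 0, so the masked positions get zero probability while the remaining weights still sum to 1:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># A minimal sketch, assuming Nx and Axon are available; the scores below are made up.
scores = Nx.tensor([2.0, 1.0, :neg_infinity, :neg_infinity], type: {:f, 32})

# Softmax turns the negative-infinity entries into exactly 0.0,
# and the two unmasked entries sum to 1.
Axon.Activations.softmax(scores, axis: -1)
# =&gt; roughly [0.731, 0.269, 0.0, 0.0]
</code></pre></div></div>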

<p>The same flow can be visualized like this:</p>

<h3 id="applying-a-causal-attention-mask">Applying a causal attention mask</h3>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>+---------------------------+   1) apply softmax   +-----------------------------+
| Attention scores          | -------------------&gt; | Attention weights           |
| (unnormalized)            |                      | (normalized)                |
+---------------------------+                      +-----------------------------+
                                                         |
                                                         | 2) mask 0's above diagonal
                                                         v
                                            +-----------------------------+
                                            | Masked attention scores     |
                                            | (unnormalized)              |
                                            +-----------------------------+
                                                         |
                                                         | 3) normalize rows
                                                         v
                                            +-----------------------------+
                                            | Masked attention weights    |
                                            | (normalized)                |
                                            +-----------------------------+

"Normalized" means that the values in each row sum to 1.
</code></pre></div></div>
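
<p>As a small warm-up before the full test (a standalone sketch, not part of the post’s test suite, with made-up attention weights), here is the “multiply by a lower-triangular mask, then renormalize the rows” idea on a tiny 3×3 matrix:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># A minimal sketch, assuming Nx is installed; the attention weights below are made up.
attn_weights =
  Nx.tensor(
    [
      [0.4, 0.3, 0.3],
      [0.2, 0.5, 0.3],
      [0.1, 0.3, 0.6]
    ],
    type: {:f, 32}
  )

# 1) lower-triangular mask: 1.0 on and below the diagonal, 0.0 above it
mask = Nx.broadcast(1.0, {3, 3}) |&gt; Nx.tril()

# 2) zero out everything above the diagonal
masked = Nx.multiply(attn_weights, mask)

# 3) renormalize each row so it sums to 1 again
row_sums = Nx.sum(masked, axes: [-1], keep_axes: true)
Nx.divide(masked, row_sums)
# first row becomes [1.0, 0.0, 0.0], second roughly [0.286, 0.714, 0.0], third is unchanged
</code></pre></div></div>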

<p>The ASCII diagram above shows what we are going to implement in our codebase. Follow the comments in the code below.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="n">test</span> <span class="s2">"casual attention mask"</span> <span class="k">do</span>
    <span class="n">inputs</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span><span class="mf">0.43</span><span class="p">,</span> <span class="mf">0.15</span><span class="p">,</span> <span class="mf">0.89</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.55</span><span class="p">,</span> <span class="mf">0.87</span><span class="p">,</span> <span class="mf">0.66</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.57</span><span class="p">,</span> <span class="mf">0.85</span><span class="p">,</span> <span class="mf">0.64</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.22</span><span class="p">,</span> <span class="mf">0.58</span><span class="p">,</span> <span class="mf">0.33</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.77</span><span class="p">,</span> <span class="mf">0.25</span><span class="p">,</span> <span class="mf">0.10</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.80</span><span class="p">,</span> <span class="mf">0.55</span><span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="c1"># first step</span>
    <span class="c1"># start with self attention module v2 from our previous chapter</span>

    <span class="n">sa_v2</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="ss">seed:</span> <span class="mi">123</span><span class="p">)</span>

    <span class="c1">#torch like linear projection</span>

    <span class="n">queries</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">dense_project</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">sa_v2</span><span class="o">.</span><span class="n">w_q</span><span class="p">)</span>
    <span class="n">keys</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">dense_project</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">sa_v2</span><span class="o">.</span><span class="n">w_k</span><span class="p">)</span>

    <span class="c1"># {6,2} dot {6,2} = {6,6} where we connect on second axis of first and second Tensor</span>
    <span class="c1"># remember the condition for Tensor dot product!</span>
    <span class="c1"># we are actually transponding second Tensor</span>

    <span class="n">attn_scores</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">queries</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">keys</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">])</span>

    <span class="c1"># normalize attention scores with softmax to get attention weights, sum of each row is 1</span>

    <span class="n">attn_weights</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">divide</span><span class="p">(</span><span class="n">attn_scores</span><span class="p">,</span> <span class="no">Nx</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="n">axis_size</span><span class="p">(</span><span class="n">keys</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)))</span>
      <span class="o">|&gt;</span> <span class="no">Axon</span><span class="o">.</span><span class="no">Activations</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="ss">axis:</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>

    <span class="n">expected_attn_weights</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span>
            <span class="mf">0.1531660407781601</span><span class="p">,</span>
            <span class="mf">0.1543799340724945</span><span class="p">,</span>
            <span class="mf">0.1536969095468521</span><span class="p">,</span>
            <span class="mf">0.1883421689271927</span><span class="p">,</span>
            <span class="mf">0.1566835194826126</span><span class="p">,</span>
            <span class="mf">0.1937314122915268</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.14633320271968842</span><span class="p">,</span>
            <span class="mf">0.14933642745018005</span><span class="p">,</span>
            <span class="mf">0.14835353195667267</span><span class="p">,</span>
            <span class="mf">0.19792792201042175</span><span class="p">,</span>
            <span class="mf">0.1513950079679489</span><span class="p">,</span>
            <span class="mf">0.2066539078950882</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.14623598754405975</span><span class="p">,</span>
            <span class="mf">0.14969584345817566</span><span class="p">,</span>
            <span class="mf">0.1487070620059967</span><span class="p">,</span>
            <span class="mf">0.19760626554489136</span><span class="p">,</span>
            <span class="mf">0.15132319927215576</span><span class="p">,</span>
            <span class="mf">0.20643165707588196</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.15648190677165985</span><span class="p">,</span>
            <span class="mf">0.15691247582435608</span><span class="p">,</span>
            <span class="mf">0.15638375282287598</span><span class="p">,</span>
            <span class="mf">0.1835598349571228</span><span class="p">,</span>
            <span class="mf">0.15919820964336395</span><span class="p">,</span>
            <span class="mf">0.18746380507946014</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.15005655586719513</span><span class="p">,</span>
            <span class="mf">0.1612723171710968</span><span class="p">,</span>
            <span class="mf">0.1604372262954712</span><span class="p">,</span>
            <span class="mf">0.18309712409973145</span><span class="p">,</span>
            <span class="mf">0.15431715548038483</span><span class="p">,</span>
            <span class="mf">0.1908196359872818</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.15678176283836365</span><span class="p">,</span>
            <span class="mf">0.1527770310640335</span><span class="p">,</span>
            <span class="mf">0.15225917100906372</span><span class="p">,</span>
            <span class="mf">0.1877351701259613</span><span class="p">,</span>
            <span class="mf">0.15941700339317322</span><span class="p">,</span>
            <span class="mf">0.1910298764705658</span>
          <span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">attn_weights</span><span class="p">,</span> <span class="n">expected_attn_weights</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"attn_weights should match expected values"</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">attn_weights</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">6</span><span class="p">,</span> <span class="mi">6</span><span class="p">}</span>

    <span class="c1"># second step, zero out values above the diagonal in attention weights</span>

    <span class="n">context_length</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">axis_size</span><span class="p">(</span><span class="n">attn_scores</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>

    <span class="n">mask_simple</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="p">{</span><span class="n">context_length</span><span class="p">,</span> <span class="n">context_length</span><span class="p">})</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tril</span><span class="p">()</span>

    <span class="n">mask_simple_expected</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">mask_simple</span><span class="p">,</span> <span class="n">mask_simple_expected</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"mask_simple should match expected values"</span>

    <span class="c1"># to apply mask we are multiplying attention weights with mask_simple</span>
    <span class="c1"># {6 x 6} x {6, 6} = {6, 6}</span>
    <span class="c1"># multiply first row element in first tensor with first row element in second Tensor.</span>
    <span class="c1"># then continue with second elements in row</span>
    <span class="c1"># effect is that elements multiplied with zero are zeroed out</span>

    <span class="n">masked_attn_weights</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">multiply</span><span class="p">(</span><span class="n">attn_weights</span><span class="p">,</span> <span class="n">mask_simple</span><span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">masked_attn_weights</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">6</span><span class="p">,</span> <span class="mi">6</span><span class="p">}</span>

    <span class="n">masked_attn_weights_expected</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span>
            <span class="p">[</span><span class="mf">0.1531660407781601</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.14633320271968842</span><span class="p">,</span> <span class="mf">0.14933642745018005</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.14623598754405975</span><span class="p">,</span> <span class="mf">0.14969584345817566</span><span class="p">,</span> <span class="mf">0.1487070620059967</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
            <span class="p">[</span>
              <span class="mf">0.15648190677165985</span><span class="p">,</span>
              <span class="mf">0.15691247582435608</span><span class="p">,</span>
              <span class="mf">0.15638375282287598</span><span class="p">,</span>
              <span class="mf">0.1835598349571228</span><span class="p">,</span>
              <span class="mf">0.0</span><span class="p">,</span>
              <span class="mf">0.0</span>
            <span class="p">],</span>
            <span class="p">[</span>
              <span class="mf">0.15005655586719513</span><span class="p">,</span>
              <span class="mf">0.1612723171710968</span><span class="p">,</span>
              <span class="mf">0.1604372262954712</span><span class="p">,</span>
              <span class="mf">0.18309712409973145</span><span class="p">,</span>
              <span class="mf">0.15431715548038483</span><span class="p">,</span>
              <span class="mf">0.0</span>
            <span class="p">],</span>
            <span class="p">[</span>
              <span class="mf">0.15678176283836365</span><span class="p">,</span>
              <span class="mf">0.1527770310640335</span><span class="p">,</span>
              <span class="mf">0.15225917100906372</span><span class="p">,</span>
              <span class="mf">0.1877351701259613</span><span class="p">,</span>
              <span class="mf">0.15941700339317322</span><span class="p">,</span>
              <span class="mf">0.1910298764705658</span>
            <span class="p">]</span>
          <span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">masked_attn_weights</span><span class="p">,</span> <span class="n">masked_attn_weights_expected</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span>
           <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"masked_attn_weights should match expected values"</span>

    <span class="c1"># Third step is row normalization for masked attention weights, sum of each row must be 1</span>
    <span class="c1"># sum of each row</span>

    <span class="n">row_sums</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">masked_attn_weights</span><span class="p">,</span> <span class="ss">axes:</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="ss">keep_axes:</span> <span class="no">true</span><span class="p">)</span>

    <span class="c1"># each row element is divided with row sum</span>

    <span class="n">masked_attn_weights_norm</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">divide</span><span class="p">(</span><span class="n">masked_attn_weights</span><span class="p">,</span> <span class="n">row_sums</span><span class="p">)</span>

    <span class="n">expected_masked_attn_weights_norm</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.49492135643959045</span><span class="p">,</span> <span class="mf">0.5050787329673767</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.32888707518577576</span><span class="p">,</span> <span class="mf">0.33666834235191345</span><span class="p">,</span> <span class="mf">0.3344445526599884</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.2395114302635193</span><span class="p">,</span>
            <span class="mf">0.2401704639196396</span><span class="p">,</span>
            <span class="mf">0.23936119675636292</span><span class="p">,</span>
            <span class="mf">0.2809569537639618</span><span class="p">,</span>
            <span class="mf">0.0</span><span class="p">,</span>
            <span class="mf">0.0</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.18544265627861023</span><span class="p">,</span>
            <span class="mf">0.1993032991886139</span><span class="p">,</span>
            <span class="mf">0.1982712745666504</span><span class="p">,</span>
            <span class="mf">0.2262747883796692</span><span class="p">,</span>
            <span class="mf">0.1907079815864563</span><span class="p">,</span>
            <span class="mf">0.0</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.15678176283836365</span><span class="p">,</span>
            <span class="mf">0.1527770310640335</span><span class="p">,</span>
            <span class="mf">0.15225917100906372</span><span class="p">,</span>
            <span class="mf">0.1877351701259613</span><span class="p">,</span>
            <span class="mf">0.15941700339317322</span><span class="p">,</span>
            <span class="mf">0.1910298764705658</span>
          <span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">masked_attn_weights_norm</span><span class="p">,</span> <span class="n">expected_masked_attn_weights_norm</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span>
           <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"masked_attn_weights_norm should match expected values"</span>

    <span class="c1"># improvement, masking witn negative infinity instead of 1</span>
    <span class="c1"># this is again step 2, we start from attn_scores</span>

    <span class="n">mask</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="p">{</span><span class="n">context_length</span><span class="p">,</span> <span class="n">context_length</span><span class="p">})</span>
      <span class="c1"># upper triangle above diagonal</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">triu</span><span class="p">(</span><span class="ss">k:</span> <span class="mi">1</span><span class="p">)</span>

    <span class="n">mask_bool</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">greater</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">)</span>
    <span class="n">neg_inf</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="ss">:neg_infinity</span><span class="p">,</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">attn_scores</span><span class="p">))</span>
    <span class="n">masked_neg_inf_attn_scores</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="n">mask_bool</span><span class="p">,</span> <span class="n">neg_inf</span><span class="p">,</span> <span class="n">attn_scores</span><span class="p">)</span>

    <span class="n">expected_masked_neg_inf_att_scores</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span>
            <span class="o">-</span><span class="mf">0.5037099719047546</span><span class="p">,</span>
            <span class="ss">:neg_infinity</span><span class="p">,</span>
            <span class="ss">:neg_infinity</span><span class="p">,</span>
            <span class="ss">:neg_infinity</span><span class="p">,</span>
            <span class="ss">:neg_infinity</span><span class="p">,</span>
            <span class="ss">:neg_infinity</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="o">-</span><span class="mf">0.7201937437057495</span><span class="p">,</span>
            <span class="o">-</span><span class="mf">0.69146329164505</span><span class="p">,</span>
            <span class="ss">:neg_infinity</span><span class="p">,</span>
            <span class="ss">:neg_infinity</span><span class="p">,</span>
            <span class="ss">:neg_infinity</span><span class="p">,</span>
            <span class="ss">:neg_infinity</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="o">-</span><span class="mf">0.7123136520385742</span><span class="p">,</span>
            <span class="o">-</span><span class="mf">0.6792438626289368</span><span class="p">,</span>
            <span class="o">-</span><span class="mf">0.6886162161827087</span><span class="p">,</span>
            <span class="ss">:neg_infinity</span><span class="p">,</span>
            <span class="ss">:neg_infinity</span><span class="p">,</span>
            <span class="ss">:neg_infinity</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="o">-</span><span class="mf">0.3948274254798889</span><span class="p">,</span>
            <span class="o">-</span><span class="mf">0.3909415006637573</span><span class="p">,</span>
            <span class="o">-</span><span class="mf">0.39571473002433777</span><span class="p">,</span>
            <span class="o">-</span><span class="mf">0.16911853849887848</span><span class="p">,</span>
            <span class="ss">:neg_infinity</span><span class="p">,</span>
            <span class="ss">:neg_infinity</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="o">-</span><span class="mf">0.3698960840702057</span><span class="p">,</span>
            <span class="o">-</span><span class="mf">0.2679566442966461</span><span class="p">,</span>
            <span class="o">-</span><span class="mf">0.2752986252307892</span><span class="p">,</span>
            <span class="o">-</span><span class="mf">0.08846122026443481</span><span class="p">,</span>
            <span class="o">-</span><span class="mf">0.3303014636039734</span><span class="p">,</span>
            <span class="ss">:neg_infinity</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="o">-</span><span class="mf">0.4973980784416199</span><span class="p">,</span>
            <span class="o">-</span><span class="mf">0.533991277217865</span><span class="p">,</span>
            <span class="o">-</span><span class="mf">0.53879314661026</span><span class="p">,</span>
            <span class="o">-</span><span class="mf">0.24258869886398315</span><span class="p">,</span>
            <span class="o">-</span><span class="mf">0.473825067281723</span><span class="p">,</span>
            <span class="o">-</span><span class="mf">0.21798484027385712</span>
          <span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="n">equal</span><span class="p">(</span><span class="n">masked_neg_inf_attn_scores</span><span class="p">,</span> <span class="n">expected_masked_neg_inf_att_scores</span><span class="p">))</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"masked should match expected values"</span>

    <span class="c1"># softmax normalization, note that rows are summing to one out of the box, thanks to negative infinity trick!</span>

    <span class="n">masked_neg_inf_attn_weights_causal</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">divide</span><span class="p">(</span><span class="n">masked_neg_inf_attn_scores</span><span class="p">,</span> <span class="no">Nx</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="n">axis_size</span><span class="p">(</span><span class="n">keys</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)))</span>
      <span class="o">|&gt;</span> <span class="no">Axon</span><span class="o">.</span><span class="no">Activations</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="ss">axis:</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>

    <span class="n">expected_masked_neg_inf_attn_weights_causal</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.4949212968349457</span><span class="p">,</span> <span class="mf">0.5050786733627319</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.32888710498809814</span><span class="p">,</span> <span class="mf">0.33666837215423584</span><span class="p">,</span> <span class="mf">0.3344445526599884</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.23951144516468048</span><span class="p">,</span>
            <span class="mf">0.2401704639196396</span><span class="p">,</span>
            <span class="mf">0.2393612116575241</span><span class="p">,</span>
            <span class="mf">0.2809569537639618</span><span class="p">,</span>
            <span class="mf">0.0</span><span class="p">,</span>
            <span class="mf">0.0</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.18544265627861023</span><span class="p">,</span>
            <span class="mf">0.1993032991886139</span><span class="p">,</span>
            <span class="mf">0.1982712745666504</span><span class="p">,</span>
            <span class="mf">0.22627480328083038</span><span class="p">,</span>
            <span class="mf">0.1907079666852951</span><span class="p">,</span>
            <span class="mf">0.0</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.15678176283836365</span><span class="p">,</span>
            <span class="mf">0.1527770310640335</span><span class="p">,</span>
            <span class="mf">0.15225917100906372</span><span class="p">,</span>
            <span class="mf">0.1877351701259613</span><span class="p">,</span>
            <span class="mf">0.15941700339317322</span><span class="p">,</span>
            <span class="mf">0.1910298764705658</span>
          <span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">masked_neg_inf_attn_weights_causal</span><span class="p">,</span> <span class="n">expected_masked_neg_inf_attn_weights_causal</span><span class="p">,</span>
             <span class="ss">atol:</span> <span class="mf">1.0e-6</span>
           <span class="p">)</span>
           <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"masked_attn_weights_causal should match expected values"</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>I hope the code comments were descriptive enough to make clear what is going on in causal attention.</p>
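
<p>If the negative-infinity trick itself is not obvious: <code class="language-plaintext highlighter-rouge">exp(-inf)</code> is <code class="language-plaintext highlighter-rouge">0</code>, so masked positions contribute nothing to the softmax denominator and each row sums to one without a separate renormalization step. Here is a minimal, self-contained sketch; the input values are made up for illustration and are not taken from the test above.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># one attention-score row where the last two positions are causally masked
row = Nx.tensor([0.2, 0.5, :neg_infinity, :neg_infinity])

# softmax turns the -inf entries into exact zeros and the rest into a distribution
Axon.Activations.softmax(row, axis: -1)
# approximately [0.4256, 0.5744, 0.0, 0.0] -- already sums to one
</code></pre></div></div>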

<h3 id="masking-additional-attention-weights-with-dropout">Masking additional attention weights with dropout</h3>

<p>In dropout, we randomly “drop” selected values from attention-weight tensors, or from the final context-vector tensor. Dropping a value means setting it to zero. In this example, we will drop values from the attention-weights tensor. Dropout is used only during the training phase. Why do we need to drop values? To prevent model overfitting, where the model could become overly reliant on a specific set of hidden-layer units.</p>

<p>Assume the dropout probability is <code class="language-plaintext highlighter-rouge">p = 0.5</code>, which means:</p>

<ul>
  <li>each visible attention weight has a 50% chance to be dropped</li>
  <li>kept values are scaled by <code class="language-plaintext highlighter-rouge">1 / (1 - p) = 2.0</code></li>
  <li>masked values stay zero (see the sketch after this list)</li>
</ul>
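
<p>As a quick sanity check of the scaling rule, here is a minimal Nx-only sketch of inverted dropout with a hand-picked keep-mask. Both the weights and the mask below are made up for illustration; in the tests we rely on <code class="language-plaintext highlighter-rouge">Axon.Layers.dropout</code> instead.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>p = 0.5

# two rows of (already causally masked) attention weights
weights = Nx.tensor([[1.0, 0.0], [0.4949, 0.5051]])

# hand-picked keep-mask: 1.0 keeps a value, 0.0 drops it
keep = Nx.tensor([[1.0, 0.0], [0.0, 1.0]])

weights
|&gt; Nx.multiply(keep)
|&gt; Nx.divide(1.0 - p)
# [[2.0, 0.0], [0.0, 1.0102]] -- kept values are doubled, dropped values become zero
</code></pre></div></div>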

<p>Below is one possible dropout-mask example. <code class="language-plaintext highlighter-rouge">keep</code> means the value survives dropout, <code class="language-plaintext highlighter-rouge">drop</code> means it becomes <code class="language-plaintext highlighter-rouge">0.0</code>, and <code class="language-plaintext highlighter-rouge">X</code> marks positions that are already zero because of the causal mask.</p>

<table>
  <thead>
    <tr>
      <th>Query \ Key</th>
      <th>I</th>
      <th>love</th>
      <th>learning</th>
      <th>with</th>
      <th>tiny</th>
      <th>models</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>I</td>
      <td>keep</td>
      <td>X</td>
      <td>X</td>
      <td>X</td>
      <td>X</td>
      <td>X</td>
    </tr>
    <tr>
      <td>love</td>
      <td>drop</td>
      <td>keep</td>
      <td>X</td>
      <td>X</td>
      <td>X</td>
      <td>X</td>
    </tr>
    <tr>
      <td>learning</td>
      <td>keep</td>
      <td>drop</td>
      <td>keep</td>
      <td>X</td>
      <td>X</td>
      <td>X</td>
    </tr>
    <tr>
      <td>with</td>
      <td>keep</td>
      <td>keep</td>
      <td>drop</td>
      <td>keep</td>
      <td>X</td>
      <td>X</td>
    </tr>
    <tr>
      <td>tiny</td>
      <td>drop</td>
      <td>keep</td>
      <td>keep</td>
      <td>drop</td>
      <td>keep</td>
      <td>X</td>
    </tr>
    <tr>
      <td>models</td>
      <td>keep</td>
      <td>drop</td>
      <td>keep</td>
      <td>keep</td>
      <td>drop</td>
      <td>keep</td>
    </tr>
  </tbody>
</table>

<p>Applying such a mask to the causal attention weights and scaling the kept values by <code class="language-plaintext highlighter-rouge">2.0</code> gives a final dropout result with this structure (the values here are illustrative):</p>

<table>
  <thead>
    <tr>
      <th>Query \ Key</th>
      <th style="text-align: right">I</th>
      <th style="text-align: right">love</th>
      <th style="text-align: right">learning</th>
      <th style="text-align: right">with</th>
      <th style="text-align: right">tiny</th>
      <th style="text-align: right">models</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>I</td>
      <td style="text-align: right">2.00</td>
      <td style="text-align: right">0.00</td>
      <td style="text-align: right">0.00</td>
      <td style="text-align: right">0.00</td>
      <td style="text-align: right">0.00</td>
      <td style="text-align: right">0.00</td>
    </tr>
    <tr>
      <td>love</td>
      <td style="text-align: right">0.00</td>
      <td style="text-align: right">1.56</td>
      <td style="text-align: right">0.00</td>
      <td style="text-align: right">0.00</td>
      <td style="text-align: right">0.00</td>
      <td style="text-align: right">0.00</td>
    </tr>
    <tr>
      <td>learning</td>
      <td style="text-align: right">0.36</td>
      <td style="text-align: right">0.00</td>
      <td style="text-align: right">1.10</td>
      <td style="text-align: right">0.00</td>
      <td style="text-align: right">0.00</td>
      <td style="text-align: right">0.00</td>
    </tr>
    <tr>
      <td>with</td>
      <td style="text-align: right">0.24</td>
      <td style="text-align: right">0.34</td>
      <td style="text-align: right">0.00</td>
      <td style="text-align: right">0.90</td>
      <td style="text-align: right">0.00</td>
      <td style="text-align: right">0.00</td>
    </tr>
    <tr>
      <td>tiny</td>
      <td style="text-align: right">0.00</td>
      <td style="text-align: right">0.26</td>
      <td style="text-align: right">0.32</td>
      <td style="text-align: right">0.00</td>
      <td style="text-align: right">0.74</td>
      <td style="text-align: right">0.00</td>
    </tr>
    <tr>
      <td>models</td>
      <td style="text-align: right">0.14</td>
      <td style="text-align: right">0.00</td>
      <td style="text-align: right">0.22</td>
      <td style="text-align: right">0.28</td>
      <td style="text-align: right">0.00</td>
      <td style="text-align: right">0.76</td>
    </tr>
  </tbody>
</table>

<p>Here is the test that applies dropout:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="n">test</span> <span class="s2">"dropout"</span> <span class="k">do</span>
    <span class="n">key</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">123</span><span class="p">)</span>

    <span class="c1"># Set a {6 x 6} tensor filled with ones.</span>

    <span class="n">example</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="p">{</span><span class="mi">6</span><span class="p">,</span> <span class="mi">6</span><span class="p">})</span>
    <span class="c1"># Axon provides a dropout function.</span>

    <span class="p">%</span><span class="no">Axon</span><span class="o">.</span><span class="no">StatefulOutput</span><span class="p">{</span><span class="ss">output:</span> <span class="n">dropped</span><span class="p">,</span> <span class="ss">state:</span> <span class="p">%{</span><span class="s2">"key"</span> <span class="o">=&gt;</span> <span class="n">_new_key</span><span class="p">}}</span> <span class="o">=</span>
      <span class="no">Axon</span><span class="o">.</span><span class="no">Layers</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">example</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="ss">rate:</span> <span class="mf">0.5</span><span class="p">,</span> <span class="ss">mode:</span> <span class="ss">:train</span><span class="p">)</span>

    <span class="n">expected_dropped</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">2.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">2.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">dropped</span><span class="p">,</span> <span class="n">expected_dropped</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"dropped should match expected values"</span>

    <span class="c1"># Causal attention weights from the previous test.</span>

    <span class="n">masked_neg_inf_attn_weights_causal</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.4949212968349457</span><span class="p">,</span> <span class="mf">0.5050786733627319</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.32888710498809814</span><span class="p">,</span> <span class="mf">0.33666837215423584</span><span class="p">,</span> <span class="mf">0.3344445526599884</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.23951144516468048</span><span class="p">,</span>
            <span class="mf">0.2401704639196396</span><span class="p">,</span>
            <span class="mf">0.2393612116575241</span><span class="p">,</span>
            <span class="mf">0.2809569537639618</span><span class="p">,</span>
            <span class="mf">0.0</span><span class="p">,</span>
            <span class="mf">0.0</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.18544265627861023</span><span class="p">,</span>
            <span class="mf">0.1993032991886139</span><span class="p">,</span>
            <span class="mf">0.1982712745666504</span><span class="p">,</span>
            <span class="mf">0.22627480328083038</span><span class="p">,</span>
            <span class="mf">0.1907079666852951</span><span class="p">,</span>
            <span class="mf">0.0</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="mf">0.15678176283836365</span><span class="p">,</span>
            <span class="mf">0.1527770310640335</span><span class="p">,</span>
            <span class="mf">0.15225917100906372</span><span class="p">,</span>
            <span class="mf">0.1877351701259613</span><span class="p">,</span>
            <span class="mf">0.15941700339317322</span><span class="p">,</span>
            <span class="mf">0.1910298764705658</span>
          <span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="p">%</span><span class="no">Axon</span><span class="o">.</span><span class="no">StatefulOutput</span><span class="p">{</span><span class="ss">output:</span> <span class="n">masked_attn_weights_causal_dropped</span><span class="p">,</span> <span class="ss">state:</span> <span class="p">%{</span><span class="s2">"key"</span> <span class="o">=&gt;</span> <span class="n">_new_key</span><span class="p">}}</span> <span class="o">=</span>
      <span class="no">Axon</span><span class="o">.</span><span class="no">Layers</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">masked_neg_inf_attn_weights_causal</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="ss">rate:</span> <span class="mf">0.5</span><span class="p">,</span> <span class="ss">mode:</span> <span class="ss">:train</span><span class="p">)</span>

    <span class="n">expected_masked_neg_inf_attn_weights_causal_dropped</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.47902289032936096</span><span class="p">,</span> <span class="mf">0.4803409278392792</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.37088531255722046</span><span class="p">,</span> <span class="mf">0.3986065983772278</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.3814159333705902</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.30451834201812744</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.31883400678634644</span><span class="p">,</span> <span class="mf">0.3820597529411316</span><span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span>
             <span class="n">masked_neg_inf_attn_weights_causal_dropped</span><span class="p">,</span>
             <span class="n">expected_masked_neg_inf_attn_weights_causal_dropped</span><span class="p">,</span>
             <span class="ss">atol:</span> <span class="mf">1.0e-6</span>
           <span class="p">)</span>
           <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span>
           <span class="s2">"masked_attn_weights_causal_dropped should match expected values"</span>

  <span class="k">end</span>
</code></pre></div></div>

<h3 id="implementing-a-compact-causal-attention-elixir-module">Implementing a compact causal attention Elixir module</h3>

<p>Next, we implement an Elixir module that combines the causal-attention and dropout steps, and then exercise it in a test. For the test, we prepare a batch consisting of two identical inputs.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
  <span class="k">defmodule</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">CausalAttention</span> <span class="k">do</span>
  <span class="nv">@moduledoc</span> <span class="sd">"""
  Causal self-attention over batched token sequences.

  Mirrors the PyTorch `CausalAttention` module:

    * query/key/value dense projections
    * causal upper-triangular mask stored in the module state
    * attention dropout
  """</span>
  <span class="n">alias</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SelfAttentionV2</span>

  <span class="k">defstruct</span> <span class="p">[</span>
    <span class="ss">:w_q</span><span class="p">,</span>
    <span class="ss">:w_k</span><span class="p">,</span>
    <span class="ss">:w_v</span><span class="p">,</span>
    <span class="ss">:mask</span><span class="p">,</span>
    <span class="ss">:d_in</span><span class="p">,</span>
    <span class="ss">:d_out</span><span class="p">,</span>
    <span class="ss">:context_length</span><span class="p">,</span>
    <span class="ss">:dropout</span><span class="p">,</span>
    <span class="ss">:qkv_bias</span><span class="p">,</span>
    <span class="ss">:seed</span>
  <span class="p">]</span>

  <span class="nv">@type</span> <span class="n">dense_weights</span> <span class="p">::</span> <span class="p">%{</span><span class="ss">kernel:</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">(),</span> <span class="ss">bias:</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">()}</span>

  <span class="nv">@type</span> <span class="n">t</span> <span class="p">::</span> <span class="p">%</span><span class="bp">__MODULE__</span><span class="p">{</span>
          <span class="ss">w_q:</span> <span class="n">dense_weights</span><span class="p">(),</span>
          <span class="ss">w_k:</span> <span class="n">dense_weights</span><span class="p">(),</span>
          <span class="ss">w_v:</span> <span class="n">dense_weights</span><span class="p">(),</span>
          <span class="ss">mask:</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">(),</span>
          <span class="ss">d_in:</span> <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="ss">d_out:</span> <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="ss">context_length:</span> <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="ss">dropout:</span> <span class="n">float</span><span class="p">(),</span>
          <span class="ss">qkv_bias:</span> <span class="n">boolean</span><span class="p">(),</span>
          <span class="ss">seed:</span> <span class="n">integer</span><span class="p">()</span>
        <span class="p">}</span>

  <span class="nv">@spec</span> <span class="n">new</span><span class="p">(</span><span class="n">pos_integer</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">(),</span> <span class="n">number</span><span class="p">())</span> <span class="p">::</span> <span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Creates a causal attention module with default `qkv_bias: false`.

  Equivalent to `new(d_in, d_out, context_length, dropout, false, [])`.
  """</span>
  <span class="k">def</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">context_length</span><span class="p">,</span> <span class="n">dropout</span><span class="p">),</span>
    <span class="k">do</span><span class="p">:</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">context_length</span><span class="p">,</span> <span class="n">dropout</span><span class="p">,</span> <span class="no">false</span><span class="p">,</span> <span class="p">[])</span>

  <span class="nv">@spec</span> <span class="n">new</span><span class="p">(</span><span class="n">pos_integer</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">(),</span> <span class="n">number</span><span class="p">(),</span> <span class="n">boolean</span><span class="p">())</span> <span class="p">::</span> <span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Creates a causal attention module with explicit `qkv_bias`.

  Equivalent to `new(d_in, d_out, context_length, dropout, qkv_bias, [])`.
  """</span>
  <span class="k">def</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">context_length</span><span class="p">,</span> <span class="n">dropout</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="p">),</span>
    <span class="k">do</span><span class="p">:</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">context_length</span><span class="p">,</span> <span class="n">dropout</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="p">,</span> <span class="p">[])</span>

  <span class="nv">@spec</span> <span class="n">new</span><span class="p">(</span><span class="n">pos_integer</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">(),</span> <span class="n">number</span><span class="p">(),</span> <span class="n">boolean</span><span class="p">(),</span> <span class="n">keyword</span><span class="p">())</span> <span class="p">::</span> <span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Creates a causal attention module.

  ## Arguments

    * `d_in` - input feature size.
    * `d_out` - projection/output feature size.
    * `context_length` - maximum sequence length for the causal mask.
    * `dropout` - dropout rate in `[0, 1)`, applied to attention weights.
    * `qkv_bias` - whether query/key/value dense layers use bias.
    * `opts` - keyword options:
      `:seed` (optional, deterministic initialization).

  ## Notes

  `mask` in the struct is the Elixir/Nx equivalent of PyTorch `register_buffer`
  for the upper-triangular causal mask.
  """</span>
  <span class="k">def</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">context_length</span><span class="p">,</span> <span class="n">dropout</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="p">,</span> <span class="n">opts</span><span class="p">)</span>
      <span class="ow">when</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">d_in</span><span class="p">)</span> <span class="ow">and</span> <span class="n">d_in</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">d_out</span><span class="p">)</span> <span class="ow">and</span> <span class="n">d_out</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span>
             <span class="n">is_integer</span><span class="p">(</span><span class="n">context_length</span><span class="p">)</span> <span class="ow">and</span> <span class="n">context_length</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">do</span>
    <span class="n">seed</span> <span class="o">=</span> <span class="n">normalize_seed</span><span class="p">(</span><span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:seed</span><span class="p">))</span>
    <span class="n">qkv_bias</span> <span class="o">=</span> <span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">normalize_qkv_bias</span><span class="p">(</span><span class="n">qkv_bias</span><span class="p">)</span>
    <span class="n">dropout</span> <span class="o">=</span> <span class="n">normalize_dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span>

    <span class="n">w_q</span> <span class="o">=</span> <span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">init_dense_weights</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">seed</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="p">,</span> <span class="s2">"q_proj"</span><span class="p">)</span>
    <span class="n">w_k</span> <span class="o">=</span> <span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">init_dense_weights</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">seed</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="p">,</span> <span class="s2">"k_proj"</span><span class="p">)</span>
    <span class="n">w_v</span> <span class="o">=</span> <span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">init_dense_weights</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">seed</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="p">,</span> <span class="s2">"v_proj"</span><span class="p">)</span>

    <span class="n">mask</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="p">{</span><span class="n">context_length</span><span class="p">,</span> <span class="n">context_length</span><span class="p">})</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">triu</span><span class="p">(</span><span class="ss">k:</span> <span class="mi">1</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">as_type</span><span class="p">({</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">})</span>

    <span class="p">%</span><span class="bp">__MODULE__</span><span class="p">{</span>
      <span class="ss">w_q:</span> <span class="n">w_q</span><span class="p">,</span>
      <span class="ss">w_k:</span> <span class="n">w_k</span><span class="p">,</span>
      <span class="ss">w_v:</span> <span class="n">w_v</span><span class="p">,</span>
      <span class="ss">mask:</span> <span class="n">mask</span><span class="p">,</span>
      <span class="ss">d_in:</span> <span class="n">d_in</span><span class="p">,</span>
      <span class="ss">d_out:</span> <span class="n">d_out</span><span class="p">,</span>
      <span class="ss">context_length:</span> <span class="n">context_length</span><span class="p">,</span>
      <span class="ss">dropout:</span> <span class="n">dropout</span><span class="p">,</span>
      <span class="ss">qkv_bias:</span> <span class="n">qkv_bias</span><span class="p">,</span>
      <span class="ss">seed:</span> <span class="n">seed</span>
    <span class="p">}</span>
  <span class="k">end</span>

  <span class="nv">@spec</span> <span class="n">forward</span><span class="p">(</span><span class="n">t</span><span class="p">(),</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">(),</span> <span class="n">keyword</span><span class="p">())</span> <span class="p">::</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Computes causal self-attention context vectors for a batched input.

  ## Arguments

    * `ca` - `%LlmScratch.CausalAttention{}` module state.
    * `x` - input tensor of shape `{batch_size, num_tokens, d_in}`.
    * `opts` - keyword options:
      `:mode` (`:train` or `:inference`, default `:train`),
      `:key` (optional Nx random key for dropout when in train mode).

  ## Returns

    * context tensor of shape `{batch_size, num_tokens, d_out}`.
  """</span>
  <span class="k">def</span> <span class="n">forward</span><span class="p">(%</span><span class="bp">__MODULE__</span><span class="p">{}</span> <span class="o">=</span> <span class="n">ca</span><span class="p">,</span> <span class="p">%</span><span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="p">{}</span> <span class="o">=</span> <span class="n">x</span><span class="p">,</span> <span class="n">opts</span> <span class="p">\\</span> <span class="p">[])</span> <span class="k">do</span>
    <span class="p">{</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">num_tokens</span><span class="p">,</span> <span class="n">_</span><span class="p">}</span> <span class="o">=</span> <span class="n">validate_input_shape!</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">ca</span><span class="o">.</span><span class="n">d_in</span><span class="p">)</span>
    <span class="n">validate_context_length!</span><span class="p">(</span><span class="n">num_tokens</span><span class="p">,</span> <span class="n">ca</span><span class="o">.</span><span class="n">context_length</span><span class="p">)</span>

    <span class="n">keys</span> <span class="o">=</span> <span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">dense_project</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">ca</span><span class="o">.</span><span class="n">w_k</span><span class="p">)</span>
    <span class="n">queries</span> <span class="o">=</span> <span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">dense_project</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">ca</span><span class="o">.</span><span class="n">w_q</span><span class="p">)</span>
    <span class="n">values</span> <span class="o">=</span> <span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">dense_project</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">ca</span><span class="o">.</span><span class="n">w_v</span><span class="p">)</span>

    <span class="c1"># {no_of_batches, num_tokens, d_out} dot {no_of_batches, num_tokens, d_out}</span>
    <span class="c1"># we are dotting tensors over axis 2 (last one), [0] means that batch is on axis 0</span>
    <span class="c1"># so for each batch we are dotting {num_tokens, d_out} dot {num_tokens, d_out}</span>

    <span class="n">attn_scores</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">queries</span><span class="p">,</span> <span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">keys</span><span class="p">,</span> <span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">])</span>

    <span class="n">mask</span> <span class="o">=</span>
      <span class="n">ca</span><span class="o">.</span><span class="n">mask</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">slice</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="n">num_tokens</span><span class="p">,</span> <span class="n">num_tokens</span><span class="p">])</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">greater</span><span class="p">(</span><span class="mf">0.0</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">broadcast</span><span class="p">({</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">num_tokens</span><span class="p">,</span> <span class="n">num_tokens</span><span class="p">})</span>

    <span class="n">neg_inf</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="ss">:neg_infinity</span><span class="p">,</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">attn_scores</span><span class="p">))</span>
    <span class="n">masked_scores</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">neg_inf</span><span class="p">,</span> <span class="n">attn_scores</span><span class="p">)</span>

    <span class="n">attn_weights</span> <span class="o">=</span>
      <span class="n">masked_scores</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">divide</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="n">axis_size</span><span class="p">(</span><span class="n">keys</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)))</span>
      <span class="o">|&gt;</span> <span class="no">Axon</span><span class="o">.</span><span class="no">Activations</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="ss">axis:</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="n">maybe_dropout</span><span class="p">(</span><span class="n">ca</span><span class="p">,</span> <span class="n">opts</span><span class="p">)</span>

    <span class="c1"># {no_of_batches, no_of_tokens, no_of_tokens} dot {no_of_batches, no_of_tokens, d_out}</span>
    <span class="c1"># {no_of_batches, no_of_tokens, d_out}</span>

    <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">attn_weights</span><span class="p">,</span> <span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">values</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">])</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">maybe_dropout</span><span class="p">(</span><span class="n">attn_weights</span><span class="p">,</span> <span class="p">%{</span><span class="ss">dropout:</span> <span class="n">dropout</span><span class="p">},</span> <span class="n">_opts</span><span class="p">)</span> <span class="ow">when</span> <span class="n">dropout</span> <span class="o">&lt;=</span> <span class="mf">0.0</span><span class="p">,</span>
    <span class="k">do</span><span class="p">:</span> <span class="n">attn_weights</span>

  <span class="k">defp</span> <span class="n">maybe_dropout</span><span class="p">(</span><span class="n">attn_weights</span><span class="p">,</span> <span class="p">%{</span><span class="ss">dropout:</span> <span class="n">dropout</span><span class="p">,</span> <span class="ss">seed:</span> <span class="n">seed</span><span class="p">},</span> <span class="n">opts</span><span class="p">)</span> <span class="k">do</span>
    <span class="n">mode</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:mode</span><span class="p">,</span> <span class="ss">:train</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="ss">:train</span> <span class="k">do</span>
      <span class="n">key</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:key</span><span class="p">)</span> <span class="o">||</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>

      <span class="p">%</span><span class="no">Axon</span><span class="o">.</span><span class="no">StatefulOutput</span><span class="p">{</span><span class="ss">output:</span> <span class="n">dropped</span><span class="p">,</span> <span class="ss">state:</span> <span class="p">%{</span><span class="s2">"key"</span> <span class="o">=&gt;</span> <span class="n">_new_key</span><span class="p">}}</span> <span class="o">=</span>
        <span class="no">Axon</span><span class="o">.</span><span class="no">Layers</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">attn_weights</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="ss">rate:</span> <span class="n">dropout</span><span class="p">,</span> <span class="ss">mode:</span> <span class="ss">:train</span><span class="p">)</span>

      <span class="n">dropped</span>
    <span class="k">else</span>
      <span class="n">attn_weights</span>
    <span class="k">end</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">normalize_seed</span><span class="p">(</span><span class="no">nil</span><span class="p">),</span> <span class="k">do</span><span class="p">:</span> <span class="no">System</span><span class="o">.</span><span class="n">unique_integer</span><span class="p">([</span><span class="ss">:positive</span><span class="p">])</span>
  <span class="k">defp</span> <span class="n">normalize_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span> <span class="ow">when</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">seed</span><span class="p">),</span> <span class="k">do</span><span class="p">:</span> <span class="n">seed</span>

  <span class="k">defp</span> <span class="n">normalize_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span> <span class="k">do</span>
    <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span> <span class="s2">"seed must be an integer or nil, got: </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">normalize_dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span> <span class="ow">when</span> <span class="n">is_number</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span> <span class="ow">and</span> <span class="n">dropout</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">dropout</span> <span class="o">&lt;</span> <span class="mi">1</span> <span class="k">do</span>
    <span class="n">dropout</span> <span class="o">*</span> <span class="mf">1.0</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">normalize_dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span> <span class="k">do</span>
    <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span> <span class="s2">"dropout must be a number in [0, 1), got: </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">validate_input_shape!</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">expected_d_in</span><span class="p">)</span> <span class="k">do</span>
    <span class="k">case</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="k">do</span>
      <span class="p">{</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">num_tokens</span><span class="p">,</span> <span class="o">^</span><span class="n">expected_d_in</span><span class="p">}</span> <span class="o">-&gt;</span>
        <span class="p">{</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">num_tokens</span><span class="p">,</span> <span class="n">expected_d_in</span><span class="p">}</span>

      <span class="n">shape</span> <span class="o">-&gt;</span>
        <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span>
              <span class="s2">"expected inputs shape {batch_size, num_tokens, </span><span class="si">#{</span><span class="n">expected_d_in</span><span class="si">}</span><span class="s2">}, got: </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
    <span class="k">end</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">validate_context_length!</span><span class="p">(</span><span class="n">num_tokens</span><span class="p">,</span> <span class="n">context_length</span><span class="p">)</span> <span class="ow">when</span> <span class="n">num_tokens</span> <span class="o">&lt;=</span> <span class="n">context_length</span><span class="p">,</span>
    <span class="k">do</span><span class="p">:</span> <span class="ss">:ok</span>

  <span class="k">defp</span> <span class="n">validate_context_length!</span><span class="p">(</span><span class="n">num_tokens</span><span class="p">,</span> <span class="n">context_length</span><span class="p">)</span> <span class="k">do</span>
    <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span>
          <span class="s2">"num_tokens (</span><span class="si">#{</span><span class="n">num_tokens</span><span class="si">}</span><span class="s2">) exceeds context_length (</span><span class="si">#{</span><span class="n">context_length</span><span class="si">}</span><span class="s2">)"</span>
  <span class="k">end</span>
<span class="k">end</span>
</code></pre></div></div>
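<p>The private helpers above make the constructor fail fast on bad configuration. As a quick illustration (my own sketch, assuming <code class="language-plaintext highlighter-rouge">new/6</code> routes <code class="language-plaintext highlighter-rouge">dropout</code> through <code class="language-plaintext highlighter-rouge">normalize_dropout/1</code> as the helpers suggest), an out-of-range dropout raises an <code class="language-plaintext highlighter-rouge">ArgumentError</code>:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># A sketch, not from the book: new/6 is assumed to call normalize_dropout/1,
# so a dropout outside [0, 1) fails fast with an ArgumentError.
test "causal attention rejects dropout outside [0, 1)" do
  assert_raise ArgumentError, fn -&gt;
    LlmScratch.CausalAttention.new(3, 2, 6, 1.5, false, seed: 123)
  end
end
</code></pre></div></div>

<p>Here is a test that runs the causal attention module on a batched input and checks the resulting context vectors:</p>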

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="n">test</span> <span class="s2">"causal attention matches stacked batch example"</span> <span class="k">do</span>
    <span class="n">inputs</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span><span class="mf">0.43</span><span class="p">,</span> <span class="mf">0.15</span><span class="p">,</span> <span class="mf">0.89</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.55</span><span class="p">,</span> <span class="mf">0.87</span><span class="p">,</span> <span class="mf">0.66</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.57</span><span class="p">,</span> <span class="mf">0.85</span><span class="p">,</span> <span class="mf">0.64</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.22</span><span class="p">,</span> <span class="mf">0.58</span><span class="p">,</span> <span class="mf">0.33</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.77</span><span class="p">,</span> <span class="mf">0.25</span><span class="p">,</span> <span class="mf">0.10</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.80</span><span class="p">,</span> <span class="mf">0.55</span><span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">d_in</span> <span class="o">=</span> <span class="mi">3</span>
    <span class="n">d_out</span> <span class="o">=</span> <span class="mi">2</span>

    <span class="c1"># batch dimension is {2,6,3}</span>

    <span class="n">batch</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">inputs</span><span class="p">,</span> <span class="n">inputs</span><span class="p">],</span> <span class="ss">axis:</span> <span class="mi">0</span><span class="p">)</span>

    <span class="c1">#context_length is 6</span>

    <span class="n">context_length</span> <span class="o">=</span> <span class="n">elem</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">batch</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">ca</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">CausalAttention</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">context_length</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="no">false</span><span class="p">,</span> <span class="ss">seed:</span> <span class="mi">123</span><span class="p">)</span>

    <span class="n">context_vecs</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">CausalAttention</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">ca</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="ss">mode:</span> <span class="ss">:inference</span><span class="p">)</span>

    <span class="c1"># we projected dimension third axis 3 =&gt; 2</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">context_vecs</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">2</span><span class="p">}</span>

    <span class="c1"># we batched same input twice, which means that shape on axis: 0 is 2, as first outermost list has two elements</span>

    <span class="n">expected_context_vecs</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span>
            <span class="p">[</span><span class="o">-</span><span class="mf">0.49523380398750305</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.17632800340652466</span><span class="p">],</span>
            <span class="p">[</span><span class="o">-</span><span class="mf">0.07537277787923813</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.13790269196033478</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.06633053719997406</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.12039512395858765</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.11786159127950668</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.10831516981124878</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.1877504140138626</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.04864511638879776</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.1768769919872284</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.08047633618116379</span><span class="p">]</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="p">[</span><span class="o">-</span><span class="mf">0.49523380398750305</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.17632800340652466</span><span class="p">],</span>
            <span class="p">[</span><span class="o">-</span><span class="mf">0.07537277787923813</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.13790269196033478</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.06633053719997406</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.12039512395858765</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.11786159127950668</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.10831516981124878</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.1877504140138626</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.04864511638879776</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.1768769919872284</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.08047633618116379</span><span class="p">]</span>
          <span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">context_vecs</span><span class="p">,</span> <span class="n">expected_context_vecs</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span>
  <span class="k">end</span>
</code></pre></div></div>

<h2 id="extending-single-head-attention-to-multi-head-attention">Extending single-head attention to multi-head attention</h2>

<blockquote>
  <p><strong>Warning:</strong> This part is really hard!<br />
This section covers multi-head attention, one of the most complex and subtle parts of understanding transformers and LLMs. Take your time and revisit it as needed; the payoff is worth it!</p>
</blockquote>

<p>So what are those heads? One head is one causal attention module, created with <code class="language-plaintext highlighter-rouge">LlmScratch.CausalAttention.new</code>, and the wrapper below runs the heads over the input one at a time. Taking a New Kids on the Block approach, step by step, we will first implement a multi-head causal attention module as a <code class="language-plaintext highlighter-rouge">stack</code> of causal attention modules.</p>

<h3 id="stacking-multiple-single-head-attention-layers">Stacking multiple single-head attention layers</h3>

<p>Let’s stack some causal attention modules (additional info is in the comments):</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="k">defmodule</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">MultiheadAttentionWrapper</span> <span class="k">do</span>
  <span class="nv">@moduledoc</span> <span class="sd">"""
  Thin wrapper that builds multiple independent causal-attention heads and
  concatenates their outputs on the last axis.

  Mirrors the PyTorch module:

      class MultiHeadAttentionWrapper(nn.Module):
          def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
              super().__init__()
              self.heads = nn.ModuleList(
                  [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
                   for _ in range(num_heads)]
              )

          def forward(self, x):
              return torch.cat([head(x) for head in self.heads], dim=-1)
  """</span>
  <span class="n">alias</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">CausalAttention</span>

  <span class="k">defstruct</span> <span class="p">[</span>
    <span class="ss">:heads</span><span class="p">,</span>
    <span class="ss">:d_in</span><span class="p">,</span>
    <span class="ss">:d_out</span><span class="p">,</span>
    <span class="ss">:context_length</span><span class="p">,</span>
    <span class="ss">:dropout</span><span class="p">,</span>
    <span class="ss">:num_heads</span><span class="p">,</span>
    <span class="ss">:qkv_bias</span><span class="p">,</span>
    <span class="ss">:seed</span>
  <span class="p">]</span>

  <span class="nv">@type</span> <span class="n">t</span> <span class="p">::</span> <span class="p">%</span><span class="bp">__MODULE__</span><span class="p">{</span>
          <span class="ss">heads:</span> <span class="p">[</span><span class="no">CausalAttention</span><span class="o">.</span><span class="n">t</span><span class="p">()],</span>
          <span class="ss">d_in:</span> <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="ss">d_out:</span> <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="ss">context_length:</span> <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="ss">dropout:</span> <span class="n">float</span><span class="p">(),</span>
          <span class="ss">num_heads:</span> <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="ss">qkv_bias:</span> <span class="n">boolean</span><span class="p">(),</span>
          <span class="ss">seed:</span> <span class="n">integer</span><span class="p">()</span>
        <span class="p">}</span>

  <span class="nv">@spec</span> <span class="n">new</span><span class="p">(</span><span class="n">pos_integer</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">(),</span> <span class="n">number</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">())</span> <span class="p">::</span> <span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Creates a wrapper with default `qkv_bias: false`.

  Equivalent to `new(d_in, d_out, context_length, dropout, num_heads, false, [])`.
  """</span>
  <span class="k">def</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">context_length</span><span class="p">,</span> <span class="n">dropout</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">),</span>
    <span class="k">do</span><span class="p">:</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">context_length</span><span class="p">,</span> <span class="n">dropout</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="no">false</span><span class="p">,</span> <span class="p">[])</span>

  <span class="nv">@spec</span> <span class="n">new</span><span class="p">(</span><span class="n">pos_integer</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">(),</span> <span class="n">number</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">(),</span> <span class="n">boolean</span><span class="p">())</span> <span class="p">::</span>
          <span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Creates a wrapper with explicit `qkv_bias`.

  Equivalent to `new(d_in, d_out, context_length, dropout, num_heads, qkv_bias, [])`.
  """</span>
  <span class="k">def</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">context_length</span><span class="p">,</span> <span class="n">dropout</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="p">),</span>
    <span class="k">do</span><span class="p">:</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">context_length</span><span class="p">,</span> <span class="n">dropout</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="p">,</span> <span class="p">[])</span>

  <span class="nv">@spec</span> <span class="n">new</span><span class="p">(</span>
          <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="n">number</span><span class="p">(),</span>
          <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="n">boolean</span><span class="p">(),</span>
          <span class="n">keyword</span><span class="p">()</span>
        <span class="p">)</span> <span class="p">::</span>
          <span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Creates a wrapper containing `num_heads` independent causal-attention heads.

  Unlike `LlmScratch.MultiheadAttention`, this module does not share Q/K/V
  projections across heads. Each head is a full `LlmScratch.CausalAttention`
  instance with output size `d_out`, so the concatenated output width is
  `num_heads * d_out`.

  ## Arguments

    * `d_in` - input feature size for each token
    * `d_out` - output feature size produced by each head
    * `context_length` - maximum sequence length supported by each head
    * `dropout` - dropout rate in `[0, 1)`, applied inside each head
    * `num_heads` - number of independent causal-attention heads
    * `qkv_bias` - whether each head's query/key/value projections use bias
    * `opts` - keyword options for initialization

  ## Options

    * `:seed` - deterministic base seed. Each head uses `seed + head_index`.
  """</span>
  <span class="k">def</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">context_length</span><span class="p">,</span> <span class="n">dropout</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="p">,</span> <span class="n">opts</span><span class="p">)</span>
      <span class="ow">when</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">d_in</span><span class="p">)</span> <span class="ow">and</span> <span class="n">d_in</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">d_out</span><span class="p">)</span> <span class="ow">and</span> <span class="n">d_out</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span>
             <span class="n">is_integer</span><span class="p">(</span><span class="n">context_length</span><span class="p">)</span> <span class="ow">and</span> <span class="n">context_length</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">num_heads</span><span class="p">)</span> <span class="ow">and</span>
             <span class="n">num_heads</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">do</span>
    <span class="n">seed</span> <span class="o">=</span> <span class="n">normalize_seed</span><span class="p">(</span><span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:seed</span><span class="p">))</span>
    <span class="n">dropout</span> <span class="o">=</span> <span class="n">normalize_dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span>
    <span class="n">qkv_bias</span> <span class="o">=</span> <span class="n">normalize_qkv_bias</span><span class="p">(</span><span class="n">qkv_bias</span><span class="p">)</span>

    <span class="n">heads</span> <span class="o">=</span>
      <span class="n">for</span> <span class="n">index</span> <span class="o">&lt;-</span> <span class="mi">0</span><span class="o">..</span><span class="p">(</span><span class="n">num_heads</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="k">do</span>
        <span class="no">CausalAttention</span><span class="o">.</span><span class="n">new</span><span class="p">(</span>
          <span class="n">d_in</span><span class="p">,</span>
          <span class="n">d_out</span><span class="p">,</span>
          <span class="n">context_length</span><span class="p">,</span>
          <span class="n">dropout</span><span class="p">,</span>
          <span class="n">qkv_bias</span><span class="p">,</span>
          <span class="ss">seed:</span> <span class="n">seed</span> <span class="o">+</span> <span class="n">index</span>
        <span class="p">)</span>
      <span class="k">end</span>

    <span class="p">%</span><span class="bp">__MODULE__</span><span class="p">{</span>
      <span class="ss">heads:</span> <span class="n">heads</span><span class="p">,</span>
      <span class="ss">d_in:</span> <span class="n">d_in</span><span class="p">,</span>
      <span class="ss">d_out:</span> <span class="n">d_out</span><span class="p">,</span>
      <span class="ss">context_length:</span> <span class="n">context_length</span><span class="p">,</span>
      <span class="ss">dropout:</span> <span class="n">dropout</span><span class="p">,</span>
      <span class="ss">num_heads:</span> <span class="n">num_heads</span><span class="p">,</span>
      <span class="ss">qkv_bias:</span> <span class="n">qkv_bias</span><span class="p">,</span>
      <span class="ss">seed:</span> <span class="n">seed</span>
    <span class="p">}</span>
  <span class="k">end</span>

  <span class="nv">@spec</span> <span class="n">forward</span><span class="p">(</span><span class="n">t</span><span class="p">(),</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">(),</span> <span class="n">keyword</span><span class="p">())</span> <span class="p">::</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Runs all heads on the same input and concatenates their context vectors on the
  last axis.

  ## Arguments

    * `wrapper` - `%LlmScratch.MultiheadAttentionWrapper{}`
    * `x` - input tensor of shape `{batch_size, num_tokens, d_in}`
    * `opts` - same forward options accepted by
      `LlmScratch.CausalAttention.forward/3`

  ## Returns

    * tensor of shape `{batch_size, num_tokens, num_heads * d_out}`
  """</span>
  <span class="k">def</span> <span class="n">forward</span><span class="p">(%</span><span class="bp">__MODULE__</span><span class="p">{</span><span class="ss">heads:</span> <span class="n">heads</span><span class="p">},</span> <span class="p">%</span><span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="p">{}</span> <span class="o">=</span> <span class="n">x</span><span class="p">,</span> <span class="n">opts</span> <span class="p">\\</span> <span class="p">[])</span> <span class="k">do</span>
    <span class="n">heads</span>
    <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="o">&amp;</span><span class="no">CausalAttention</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="nv">&amp;1</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">opts</span><span class="p">))</span>
    <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span><span class="ss">axis:</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">normalize_seed</span><span class="p">(</span><span class="no">nil</span><span class="p">),</span> <span class="k">do</span><span class="p">:</span> <span class="no">System</span><span class="o">.</span><span class="n">unique_integer</span><span class="p">([</span><span class="ss">:positive</span><span class="p">])</span>
  <span class="k">defp</span> <span class="n">normalize_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span> <span class="ow">when</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">seed</span><span class="p">),</span> <span class="k">do</span><span class="p">:</span> <span class="n">seed</span>

  <span class="k">defp</span> <span class="n">normalize_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span> <span class="k">do</span>
    <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span> <span class="s2">"seed must be an integer or nil, got: </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">normalize_dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span> <span class="ow">when</span> <span class="n">is_number</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span> <span class="ow">and</span> <span class="n">dropout</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">dropout</span> <span class="o">&lt;</span> <span class="mi">1</span> <span class="k">do</span>
    <span class="n">dropout</span> <span class="o">*</span> <span class="mf">1.0</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">normalize_dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span> <span class="k">do</span>
    <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span> <span class="s2">"dropout must be a number in [0, 1), got: </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">normalize_qkv_bias</span><span class="p">(</span><span class="n">qkv_bias</span><span class="p">)</span> <span class="ow">when</span> <span class="n">is_boolean</span><span class="p">(</span><span class="n">qkv_bias</span><span class="p">),</span> <span class="k">do</span><span class="p">:</span> <span class="n">qkv_bias</span>

  <span class="k">defp</span> <span class="n">normalize_qkv_bias</span><span class="p">(</span><span class="n">qkv_bias</span><span class="p">)</span> <span class="k">do</span>
    <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span> <span class="s2">"qkv_bias must be a boolean, got: </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="n">qkv_bias</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
  <span class="k">end</span>
<span class="k">end</span>

</code></pre></div></div>

<p>Nothing fancy here. We have a new input attribute, <code class="language-plaintext highlighter-rouge">num_heads</code>. In <code class="language-plaintext highlighter-rouge">forward</code>, we sequentially calculate causal attention outputs and concatenate the results over the last dimension:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="n">a</span> <span class="o">=</span>
  <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span>
    <span class="p">[[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">]],</span>
    <span class="p">[[</span><span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">],</span> <span class="p">[</span><span class="mi">7</span><span class="p">,</span> <span class="mi">8</span><span class="p">]]</span>
  <span class="p">])</span>

<span class="n">b</span> <span class="o">=</span>
  <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span>
    <span class="p">[[</span><span class="mi">10</span><span class="p">,</span> <span class="mi">20</span><span class="p">],</span> <span class="p">[</span><span class="mi">30</span><span class="p">,</span> <span class="mi">40</span><span class="p">]],</span>
    <span class="p">[[</span><span class="mi">50</span><span class="p">,</span> <span class="mi">60</span><span class="p">],</span> <span class="p">[</span><span class="mi">70</span><span class="p">,</span> <span class="mi">80</span><span class="p">]]</span>
  <span class="p">])</span>

<span class="n">result</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">],</span> <span class="ss">axis:</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>

<span class="c1"># Shapes:</span>

<span class="c1"># a: {2, 2, 2}</span>
<span class="c1"># b: {2, 2, 2}</span>
<span class="c1"># result: {2, 2, 4}</span>

<span class="c1">#Nx.Tensor&lt;</span>
  <span class="n">s64</span><span class="p">[</span><span class="mi">2</span><span class="p">][</span><span class="mi">2</span><span class="p">][</span><span class="mi">4</span><span class="p">]</span>
  <span class="p">[</span>
    <span class="p">[</span>
      <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">20</span><span class="p">],</span>
      <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">30</span><span class="p">,</span> <span class="mi">40</span><span class="p">]</span>
    <span class="p">],</span>
    <span class="p">[</span>
      <span class="p">[</span><span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">60</span><span class="p">],</span>
      <span class="p">[</span><span class="mi">7</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">70</span><span class="p">,</span> <span class="mi">80</span><span class="p">]</span>
    <span class="p">]</span>
  <span class="p">]</span>
<span class="o">&gt;</span>

</code></pre></div></div>

<p>There is room for improvement in that sequential part. Here is a test with this wrapper in action:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="n">test</span> <span class="s2">"multihead attention wrapper concatenates head outputs on the last axis"</span> <span class="k">do</span>
    <span class="c1"># six tokens with dimension 3</span>
    
    <span class="n">input</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span><span class="mf">0.43</span><span class="p">,</span> <span class="mf">0.15</span><span class="p">,</span> <span class="mf">0.89</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.55</span><span class="p">,</span> <span class="mf">0.87</span><span class="p">,</span> <span class="mf">0.66</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.57</span><span class="p">,</span> <span class="mf">0.85</span><span class="p">,</span> <span class="mf">0.64</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.22</span><span class="p">,</span> <span class="mf">0.58</span><span class="p">,</span> <span class="mf">0.33</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.77</span><span class="p">,</span> <span class="mf">0.25</span><span class="p">,</span> <span class="mf">0.10</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.80</span><span class="p">,</span> <span class="mf">0.55</span><span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="c1"># make a batch of two identical token sequences: {2x6x3}</span>

    <span class="n">inputs</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">input</span><span class="p">,</span> <span class="n">input</span><span class="p">],</span> <span class="ss">axis:</span> <span class="mi">0</span><span class="p">)</span>
    <span class="n">context_length</span> <span class="o">=</span> <span class="n">elem</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span>
    <span class="c1"># number of heads is 2</span>

    <span class="n">mha</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">MultiheadAttentionWrapper</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">context_length</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="no">false</span><span class="p">,</span> <span class="ss">seed:</span> <span class="mi">123</span><span class="p">)</span>

    <span class="c1"># {2, 6, 4}</span>
    <span class="n">result</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">MultiheadAttentionWrapper</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">mha</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="ss">mode:</span> <span class="ss">:inference</span><span class="p">)</span>

    <span class="n">expected_result</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span>
            <span class="p">[</span><span class="o">-</span><span class="mf">0.49523380398750305</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.17632800340652466</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.2037302404642105</span><span class="p">,</span> <span class="mf">0.2859067916870117</span><span class="p">],</span>
            <span class="p">[</span>
              <span class="o">-</span><span class="mf">0.07537277787923813</span><span class="p">,</span>
              <span class="o">-</span><span class="mf">0.13790269196033478</span><span class="p">,</span>
              <span class="o">-</span><span class="mf">0.04826965555548668</span><span class="p">,</span>
              <span class="mf">0.23433709144592285</span>
            <span class="p">],</span>
            <span class="p">[</span>
              <span class="mf">0.06633053719997406</span><span class="p">,</span>
              <span class="o">-</span><span class="mf">0.12039512395858765</span><span class="p">,</span>
              <span class="mf">0.029345838353037834</span><span class="p">,</span>
              <span class="mf">0.21299715340137482</span>
            <span class="p">],</span>
            <span class="p">[</span><span class="mf">0.11786159127950668</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.10831516981124878</span><span class="p">,</span> <span class="mf">0.08328337967395782</span><span class="p">,</span> <span class="mf">0.15773256123065948</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.1877504140138626</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.04864511638879776</span><span class="p">,</span> <span class="mf">0.12333470582962036</span><span class="p">,</span> <span class="mf">0.1985599249601364</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.1768769919872284</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.08047633618116379</span><span class="p">,</span> <span class="mf">0.12170650064945221</span><span class="p">,</span> <span class="mf">0.146833136677742</span><span class="p">]</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="p">[</span><span class="o">-</span><span class="mf">0.49523380398750305</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.17632800340652466</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.2037302404642105</span><span class="p">,</span> <span class="mf">0.2859067916870117</span><span class="p">],</span>
            <span class="p">[</span>
              <span class="o">-</span><span class="mf">0.07537277787923813</span><span class="p">,</span>
              <span class="o">-</span><span class="mf">0.13790269196033478</span><span class="p">,</span>
              <span class="o">-</span><span class="mf">0.04826965555548668</span><span class="p">,</span>
              <span class="mf">0.23433709144592285</span>
            <span class="p">],</span>
            <span class="p">[</span>
              <span class="mf">0.06633053719997406</span><span class="p">,</span>
              <span class="o">-</span><span class="mf">0.12039512395858765</span><span class="p">,</span>
              <span class="mf">0.029345838353037834</span><span class="p">,</span>
              <span class="mf">0.21299715340137482</span>
            <span class="p">],</span>
            <span class="p">[</span><span class="mf">0.11786159127950668</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.10831516981124878</span><span class="p">,</span> <span class="mf">0.08328337967395782</span><span class="p">,</span> <span class="mf">0.15773256123065948</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.1877504140138626</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.04864511638879776</span><span class="p">,</span> <span class="mf">0.12333470582962036</span><span class="p">,</span> <span class="mf">0.1985599249601364</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.1768769919872284</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.08047633618116379</span><span class="p">,</span> <span class="mf">0.12170650064945221</span><span class="p">,</span> <span class="mf">0.146833136677742</span><span class="p">]</span>
          <span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="c1"># We have two inputs, six tokens, and four dimensions for each token.</span>
    <span class="c1"># The input has three dimensions because we project inputs to out_dim = 2,</span>
    <span class="c1"># and we have two heads that we concatenate over the last dimension: 2 + 2 = 4.</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">result</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">4</span><span class="p">}</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="n">expected</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>Then, in the book, we get an exercise:</p>

<p><em>Change the input arguments for the MultiHeadAttentionWrapper(…, num_heads=2) call so that the output context vectors are two-dimensional instead of
four dimensional while keeping the setting num_heads=2. Hint: You don’t have to
modify the class implementation; you just have to change one of the other input
arguments.</em></p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="n">test</span> <span class="s2">"exercise 3.2 returns two-dimensional embedding vectors with two heads"</span> <span class="k">do</span>
    <span class="n">input</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span><span class="mf">0.43</span><span class="p">,</span> <span class="mf">0.15</span><span class="p">,</span> <span class="mf">0.89</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.55</span><span class="p">,</span> <span class="mf">0.87</span><span class="p">,</span> <span class="mf">0.66</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.57</span><span class="p">,</span> <span class="mf">0.85</span><span class="p">,</span> <span class="mf">0.64</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.22</span><span class="p">,</span> <span class="mf">0.58</span><span class="p">,</span> <span class="mf">0.33</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.77</span><span class="p">,</span> <span class="mf">0.25</span><span class="p">,</span> <span class="mf">0.10</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.80</span><span class="p">,</span> <span class="mf">0.55</span><span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">inputs</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">input</span><span class="p">,</span> <span class="n">input</span><span class="p">],</span> <span class="ss">axis:</span> <span class="mi">0</span><span class="p">)</span>
    <span class="n">context_length</span> <span class="o">=</span> <span class="n">elem</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span>

    <span class="c1"># We set d_out = 1.</span>

    <span class="n">mha</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">MultiheadAttentionWrapper</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">context_length</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="no">false</span><span class="p">,</span> <span class="ss">seed:</span> <span class="mi">123</span><span class="p">)</span>

    <span class="n">result</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">MultiheadAttentionWrapper</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">mha</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="ss">mode:</span> <span class="ss">:inference</span><span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">result</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">2</span><span class="p">}</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>As we mentioned, the improvement will be in how we handle the sequential heads. But we will not use Elixir concurrency; there is a smarter way to do it.</p>
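<p>To make the road not taken concrete, here is a minimal sketch (my own, not the book’s code) of running the wrapper’s heads concurrently with <code class="language-plaintext highlighter-rouge">Task.async_stream</code>, assuming each head’s <code class="language-plaintext highlighter-rouge">forward</code> is independent and side-effect free:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Sketch only: run the independent heads concurrently instead of Enum.map,
# then concatenate the per-head context vectors on the last axis as before.
forward_concurrent = fn %LlmScratch.MultiheadAttentionWrapper{heads: heads}, x, opts -&gt;
  heads
  |&gt; Task.async_stream(fn head -&gt; LlmScratch.CausalAttention.forward(head, x, opts) end,
    ordered: true
  )
  |&gt; Enum.map(fn {:ok, context_vecs} -&gt; context_vecs end)
  |&gt; Nx.concatenate(axis: -1)
end
</code></pre></div></div>

<p>Spawning a BEAM process per head would work, but the smarter way is to keep the whole computation inside one set of tensor operations, which is what the next module does. Let’s dig into the hardest part of the book so far.</p>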

<h3 id="implementing-multi-head-attention-with-weight-splits">Implementing multi-head attention with weight splits</h3>

<p>The first step is to combine the two modules we have: <code class="language-plaintext highlighter-rouge">LlmScratch.MultiheadAttentionWrapper</code> and <code class="language-plaintext highlighter-rouge">LlmScratch.CausalAttention</code>:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="k">defmodule</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">MultiheadAttention</span> <span class="k">do</span>
  <span class="nv">@moduledoc</span> <span class="sd">"""
  Multi-head causal self-attention with shared Q/K/V projections and an output
  projection.

  Mirrors the PyTorch `MultiHeadAttention` module from chapter 3:

    * one query/key/value projection each with total size `d_out`
    * head splitting into `num_heads` chunks of size `head_dim`
    * causal masking over token positions
    * attention dropout
    * output projection that mixes the concatenated head outputs

  This module is different from
  `LlmScratch.MultiheadAttentionWrapper`:

    * `MultiheadAttentionWrapper` builds `num_heads` separate
      `LlmScratch.CausalAttention` modules and concatenates their outputs
    * `MultiheadAttention` uses one set of Q/K/V projections of size `d_out`,
      splits those projections into heads, runs attention per head, then applies
      one final output projection

  The expected input shape is:

      {batch_size, num_tokens, d_in}

  and the returned output shape is:

      {batch_size, num_tokens, d_out}

  Internally, projected queries, keys, and values move through these shapes:

      {batch_size, num_tokens, d_out}
      -&gt; {batch_size, num_tokens, num_heads, head_dim}
      -&gt; {batch_size, num_heads, num_tokens, head_dim}

  where `head_dim = div(d_out, num_heads)`.
  """</span>

  <span class="n">alias</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SelfAttentionV2</span>

  <span class="k">defstruct</span> <span class="p">[</span>
    <span class="ss">:w_q</span><span class="p">,</span>
    <span class="ss">:w_k</span><span class="p">,</span>
    <span class="ss">:w_v</span><span class="p">,</span>
    <span class="ss">:out_proj</span><span class="p">,</span>
    <span class="ss">:mask</span><span class="p">,</span>
    <span class="ss">:d_in</span><span class="p">,</span>
    <span class="ss">:d_out</span><span class="p">,</span>
    <span class="ss">:context_length</span><span class="p">,</span>
    <span class="ss">:dropout</span><span class="p">,</span>
    <span class="ss">:num_heads</span><span class="p">,</span>
    <span class="ss">:head_dim</span><span class="p">,</span>
    <span class="ss">:qkv_bias</span><span class="p">,</span>
    <span class="ss">:seed</span>
  <span class="p">]</span>

  <span class="nv">@type</span> <span class="n">dense_weights</span> <span class="p">::</span> <span class="p">%{</span><span class="ss">kernel:</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">(),</span> <span class="ss">bias:</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">()}</span>

  <span class="nv">@type</span> <span class="n">t</span> <span class="p">::</span> <span class="p">%</span><span class="bp">__MODULE__</span><span class="p">{</span>
          <span class="ss">w_q:</span> <span class="n">dense_weights</span><span class="p">(),</span>
          <span class="ss">w_k:</span> <span class="n">dense_weights</span><span class="p">(),</span>
          <span class="ss">w_v:</span> <span class="n">dense_weights</span><span class="p">(),</span>
          <span class="ss">out_proj:</span> <span class="n">dense_weights</span><span class="p">(),</span>
          <span class="ss">mask:</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">(),</span>
          <span class="ss">d_in:</span> <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="ss">d_out:</span> <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="ss">context_length:</span> <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="ss">dropout:</span> <span class="n">float</span><span class="p">(),</span>
          <span class="ss">num_heads:</span> <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="ss">head_dim:</span> <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="ss">qkv_bias:</span> <span class="n">boolean</span><span class="p">(),</span>
          <span class="ss">seed:</span> <span class="n">integer</span><span class="p">()</span>
        <span class="p">}</span>

  <span class="nv">@spec</span> <span class="n">new</span><span class="p">(</span><span class="n">pos_integer</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">(),</span> <span class="n">number</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">())</span> <span class="p">::</span> <span class="n">t</span><span class="p">()</span>
  <span class="k">def</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">context_length</span><span class="p">,</span> <span class="n">dropout</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">),</span>
    <span class="k">do</span><span class="p">:</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">context_length</span><span class="p">,</span> <span class="n">dropout</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="no">false</span><span class="p">,</span> <span class="p">[])</span>

  <span class="nv">@spec</span> <span class="n">new</span><span class="p">(</span><span class="n">pos_integer</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">(),</span> <span class="n">number</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">(),</span> <span class="n">boolean</span><span class="p">())</span> <span class="p">::</span>
          <span class="n">t</span><span class="p">()</span>
  <span class="k">def</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">context_length</span><span class="p">,</span> <span class="n">dropout</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="p">),</span>
    <span class="k">do</span><span class="p">:</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">context_length</span><span class="p">,</span> <span class="n">dropout</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="p">,</span> <span class="p">[])</span>

  <span class="nv">@spec</span> <span class="n">new</span><span class="p">(</span>
          <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="n">number</span><span class="p">(),</span>
          <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="n">boolean</span><span class="p">(),</span>
          <span class="n">keyword</span><span class="p">()</span>
        <span class="p">)</span> <span class="p">::</span>
          <span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Creates a multi-head causal attention module.

  `d_out` must be divisible by `num_heads`, because each head receives an equal
  slice of the projected feature dimension.

  The Q/K/V projections are initialized with output size `d_out`, the
  per-head size is stored in `head_dim`, and `out_proj` maps the concatenated
  head outputs back into `d_out`.

  ## Arguments

    * `d_in` - input embedding dimension for each token
    * `d_out` - total projected output dimension across all heads
    * `context_length` - maximum sequence length supported by the causal mask
    * `dropout` - dropout rate applied to attention weights during training
    * `num_heads` - number of attention heads
    * `qkv_bias` - whether the query, key, and value projections use bias
    * `opts` - keyword options for initialization

  ## Options

    * `:seed` - deterministic initialization seed
  """</span>
  <span class="k">def</span> <span class="n">new</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">context_length</span><span class="p">,</span> <span class="n">dropout</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="p">,</span> <span class="n">opts</span><span class="p">)</span>
      <span class="ow">when</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">d_in</span><span class="p">)</span> <span class="ow">and</span> <span class="n">d_in</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">d_out</span><span class="p">)</span> <span class="ow">and</span> <span class="n">d_out</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span>
             <span class="n">is_integer</span><span class="p">(</span><span class="n">context_length</span><span class="p">)</span> <span class="ow">and</span> <span class="n">context_length</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">num_heads</span><span class="p">)</span> <span class="ow">and</span>
             <span class="n">num_heads</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">do</span>
    <span class="k">if</span> <span class="n">rem</span><span class="p">(</span><span class="n">d_out</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">0</span> <span class="k">do</span>
      <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span> <span class="s2">"d_out must be divisible by num_heads"</span>
    <span class="k">end</span>

    <span class="n">seed</span> <span class="o">=</span> <span class="n">normalize_seed</span><span class="p">(</span><span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:seed</span><span class="p">))</span>
    <span class="n">qkv_bias</span> <span class="o">=</span> <span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">normalize_qkv_bias</span><span class="p">(</span><span class="n">qkv_bias</span><span class="p">)</span>
    <span class="n">dropout</span> <span class="o">=</span> <span class="n">normalize_dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span>
    <span class="n">head_dim</span> <span class="o">=</span> <span class="n">div</span><span class="p">(</span><span class="n">d_out</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">)</span>

    <span class="n">w_q</span> <span class="o">=</span> <span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">init_dense_weights</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">seed</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="p">,</span> <span class="s2">"q_proj"</span><span class="p">)</span>
    <span class="n">w_k</span> <span class="o">=</span> <span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">init_dense_weights</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">seed</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="p">,</span> <span class="s2">"k_proj"</span><span class="p">)</span>
    <span class="n">w_v</span> <span class="o">=</span> <span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">init_dense_weights</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">seed</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="p">,</span> <span class="s2">"v_proj"</span><span class="p">)</span>
    <span class="n">out_proj</span> <span class="o">=</span> <span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">init_dense_weights</span><span class="p">(</span><span class="n">d_out</span><span class="p">,</span> <span class="n">d_out</span><span class="p">,</span> <span class="n">seed</span><span class="p">,</span> <span class="no">true</span><span class="p">,</span> <span class="s2">"out_proj"</span><span class="p">)</span>

    <span class="n">mask</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="p">{</span><span class="n">context_length</span><span class="p">,</span> <span class="n">context_length</span><span class="p">})</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">triu</span><span class="p">(</span><span class="ss">k:</span> <span class="mi">1</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">as_type</span><span class="p">({</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">})</span>

    <span class="p">%</span><span class="bp">__MODULE__</span><span class="p">{</span>
      <span class="ss">w_q:</span> <span class="n">w_q</span><span class="p">,</span>
      <span class="ss">w_k:</span> <span class="n">w_k</span><span class="p">,</span>
      <span class="ss">w_v:</span> <span class="n">w_v</span><span class="p">,</span>
      <span class="ss">out_proj:</span> <span class="n">out_proj</span><span class="p">,</span>
      <span class="ss">mask:</span> <span class="n">mask</span><span class="p">,</span>
      <span class="ss">d_in:</span> <span class="n">d_in</span><span class="p">,</span>
      <span class="ss">d_out:</span> <span class="n">d_out</span><span class="p">,</span>
      <span class="ss">context_length:</span> <span class="n">context_length</span><span class="p">,</span>
      <span class="ss">dropout:</span> <span class="n">dropout</span><span class="p">,</span>
      <span class="ss">num_heads:</span> <span class="n">num_heads</span><span class="p">,</span>
      <span class="ss">head_dim:</span> <span class="n">head_dim</span><span class="p">,</span>
      <span class="ss">qkv_bias:</span> <span class="n">qkv_bias</span><span class="p">,</span>
      <span class="ss">seed:</span> <span class="n">seed</span>
    <span class="p">}</span>
  <span class="k">end</span>

  <span class="nv">@spec</span> <span class="n">forward</span><span class="p">(</span><span class="n">t</span><span class="p">(),</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">(),</span> <span class="n">keyword</span><span class="p">())</span> <span class="p">::</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Computes causal multi-head self-attention for batched token sequences.

  ## Arguments

    * `mha` - `%LlmScratch.MultiheadAttention{}`
    * `x` - input tensor of shape `{batch_size, num_tokens, d_in}`
    * `opts` - forward options

  ## Forward Options

    * `:mode` - `:train` or `:inference`; dropout is only applied in train mode
    * `:key` - optional `Nx.Random` key used by dropout in train mode

  ## Steps

    * project `x` into queries, keys, and values of shape
      `{batch_size, num_tokens, d_out}`
    * split the last dimension into `num_heads * head_dim`
    * compute attention scores independently for each head
    * apply the causal mask so tokens cannot attend to future tokens
    * softmax and optionally apply dropout to the attention weights
    * combine attended values across heads and reshape back to
      `{batch_size, num_tokens, d_out}`
    * apply `out_proj`

  ## Returns

    * tensor of shape `{batch_size, num_tokens, d_out}`
  """</span>
  <span class="k">def</span> <span class="n">forward</span><span class="p">(%</span><span class="bp">__MODULE__</span><span class="p">{}</span> <span class="o">=</span> <span class="n">mha</span><span class="p">,</span> <span class="p">%</span><span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="p">{}</span> <span class="o">=</span> <span class="n">x</span><span class="p">,</span> <span class="n">opts</span> <span class="p">\\</span> <span class="p">[])</span> <span class="k">do</span>
    <span class="p">{</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">num_tokens</span><span class="p">,</span> <span class="n">_</span><span class="p">}</span> <span class="o">=</span> <span class="n">validate_input_shape!</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">mha</span><span class="o">.</span><span class="n">d_in</span><span class="p">)</span>
    <span class="n">validate_context_length!</span><span class="p">(</span><span class="n">num_tokens</span><span class="p">,</span> <span class="n">mha</span><span class="o">.</span><span class="n">context_length</span><span class="p">)</span>

    <span class="n">keys</span> <span class="o">=</span>
      <span class="n">x</span>
      <span class="o">|&gt;</span> <span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">dense_project</span><span class="p">(</span><span class="n">mha</span><span class="o">.</span><span class="n">w_k</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="n">split_heads</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">num_tokens</span><span class="p">,</span> <span class="n">mha</span><span class="o">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">mha</span><span class="o">.</span><span class="n">head_dim</span><span class="p">)</span>

    <span class="n">queries</span> <span class="o">=</span>
      <span class="n">x</span>
      <span class="o">|&gt;</span> <span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">dense_project</span><span class="p">(</span><span class="n">mha</span><span class="o">.</span><span class="n">w_q</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="n">split_heads</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">num_tokens</span><span class="p">,</span> <span class="n">mha</span><span class="o">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">mha</span><span class="o">.</span><span class="n">head_dim</span><span class="p">)</span>

    <span class="n">values</span> <span class="o">=</span>
      <span class="n">x</span>
      <span class="o">|&gt;</span> <span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">dense_project</span><span class="p">(</span><span class="n">mha</span><span class="o">.</span><span class="n">w_v</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="n">split_heads</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">num_tokens</span><span class="p">,</span> <span class="n">mha</span><span class="o">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">mha</span><span class="o">.</span><span class="n">head_dim</span><span class="p">)</span>

    <span class="n">attn_scores</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">queries</span><span class="p">,</span> <span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">keys</span><span class="p">,</span> <span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>

    <span class="n">mask</span> <span class="o">=</span>
      <span class="n">mha</span><span class="o">.</span><span class="n">mask</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">slice</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="n">num_tokens</span><span class="p">,</span> <span class="n">num_tokens</span><span class="p">])</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">greater</span><span class="p">(</span><span class="mf">0.0</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">new_axis</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">new_axis</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">broadcast</span><span class="p">({</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">mha</span><span class="o">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">num_tokens</span><span class="p">,</span> <span class="n">num_tokens</span><span class="p">})</span>

    <span class="n">neg_inf</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="ss">:neg_infinity</span><span class="p">,</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">attn_scores</span><span class="p">))</span>
    <span class="n">masked_scores</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">neg_inf</span><span class="p">,</span> <span class="n">attn_scores</span><span class="p">)</span>

    <span class="n">context</span> <span class="o">=</span>
      <span class="n">masked_scores</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">divide</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">mha</span><span class="o">.</span><span class="n">head_dim</span><span class="p">))</span>
      <span class="o">|&gt;</span> <span class="no">Axon</span><span class="o">.</span><span class="no">Activations</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="ss">axis:</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="n">maybe_dropout</span><span class="p">(</span><span class="n">mha</span><span class="p">,</span> <span class="n">opts</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">([</span><span class="mi">3</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">values</span><span class="p">,</span> <span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="ss">axes:</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">reshape</span><span class="p">({</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">num_tokens</span><span class="p">,</span> <span class="n">mha</span><span class="o">.</span><span class="n">d_out</span><span class="p">})</span>

    <span class="no">SelfAttentionV2</span><span class="o">.</span><span class="n">dense_project</span><span class="p">(</span><span class="n">context</span><span class="p">,</span> <span class="n">mha</span><span class="o">.</span><span class="n">out_proj</span><span class="p">)</span>
  <span class="k">end</span>

  <span class="nv">@doc</span> <span class="no">false</span>
  <span class="k">defp</span> <span class="n">split_heads</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">num_tokens</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">head_dim</span><span class="p">)</span> <span class="k">do</span>
    <span class="n">tensor</span>
    <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">reshape</span><span class="p">({</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">num_tokens</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">head_dim</span><span class="p">})</span>
    <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="ss">axes:</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">maybe_dropout</span><span class="p">(</span><span class="n">attn_weights</span><span class="p">,</span> <span class="p">%{</span><span class="ss">dropout:</span> <span class="n">dropout</span><span class="p">},</span> <span class="n">_opts</span><span class="p">)</span> <span class="ow">when</span> <span class="n">dropout</span> <span class="o">&lt;=</span> <span class="mf">0.0</span><span class="p">,</span>
    <span class="k">do</span><span class="p">:</span> <span class="n">attn_weights</span>

  <span class="k">defp</span> <span class="n">maybe_dropout</span><span class="p">(</span><span class="n">attn_weights</span><span class="p">,</span> <span class="p">%{</span><span class="ss">dropout:</span> <span class="n">dropout</span><span class="p">,</span> <span class="ss">seed:</span> <span class="n">seed</span><span class="p">},</span> <span class="n">opts</span><span class="p">)</span> <span class="k">do</span>
    <span class="n">mode</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:mode</span><span class="p">,</span> <span class="ss">:train</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="ss">:train</span> <span class="k">do</span>
      <span class="n">key</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:key</span><span class="p">)</span> <span class="o">||</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>

      <span class="p">%</span><span class="no">Axon</span><span class="o">.</span><span class="no">StatefulOutput</span><span class="p">{</span><span class="ss">output:</span> <span class="n">dropped</span><span class="p">,</span> <span class="ss">state:</span> <span class="p">%{</span><span class="s2">"key"</span> <span class="o">=&gt;</span> <span class="n">_new_key</span><span class="p">}}</span> <span class="o">=</span>
        <span class="no">Axon</span><span class="o">.</span><span class="no">Layers</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">attn_weights</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="ss">rate:</span> <span class="n">dropout</span><span class="p">,</span> <span class="ss">mode:</span> <span class="ss">:train</span><span class="p">)</span>

      <span class="n">dropped</span>
    <span class="k">else</span>
      <span class="n">attn_weights</span>
    <span class="k">end</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">normalize_seed</span><span class="p">(</span><span class="no">nil</span><span class="p">),</span> <span class="k">do</span><span class="p">:</span> <span class="no">System</span><span class="o">.</span><span class="n">unique_integer</span><span class="p">([</span><span class="ss">:positive</span><span class="p">])</span>
  <span class="k">defp</span> <span class="n">normalize_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span> <span class="ow">when</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">seed</span><span class="p">),</span> <span class="k">do</span><span class="p">:</span> <span class="n">seed</span>

  <span class="k">defp</span> <span class="n">normalize_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span> <span class="k">do</span>
    <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span> <span class="s2">"seed must be an integer or nil, got: </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">normalize_dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span> <span class="ow">when</span> <span class="n">is_number</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span> <span class="ow">and</span> <span class="n">dropout</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">dropout</span> <span class="o">&lt;</span> <span class="mi">1</span> <span class="k">do</span>
    <span class="n">dropout</span> <span class="o">*</span> <span class="mf">1.0</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">normalize_dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span> <span class="k">do</span>
    <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span> <span class="s2">"dropout must be a number in [0, 1), got: </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">validate_input_shape!</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">expected_d_in</span><span class="p">)</span> <span class="k">do</span>
    <span class="k">case</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="k">do</span>
      <span class="p">{</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">num_tokens</span><span class="p">,</span> <span class="o">^</span><span class="n">expected_d_in</span><span class="p">}</span> <span class="o">-&gt;</span>
        <span class="p">{</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">num_tokens</span><span class="p">,</span> <span class="n">expected_d_in</span><span class="p">}</span>

      <span class="n">shape</span> <span class="o">-&gt;</span>
        <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span>
              <span class="s2">"expected inputs shape {batch_size, num_tokens, </span><span class="si">#{</span><span class="n">expected_d_in</span><span class="si">}</span><span class="s2">}, got: </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
    <span class="k">end</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">validate_context_length!</span><span class="p">(</span><span class="n">num_tokens</span><span class="p">,</span> <span class="n">context_length</span><span class="p">)</span> <span class="ow">when</span> <span class="n">num_tokens</span> <span class="o">&lt;=</span> <span class="n">context_length</span><span class="p">,</span>
    <span class="k">do</span><span class="p">:</span> <span class="ss">:ok</span>

  <span class="k">defp</span> <span class="n">validate_context_length!</span><span class="p">(</span><span class="n">num_tokens</span><span class="p">,</span> <span class="n">context_length</span><span class="p">)</span> <span class="k">do</span>
    <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span>
          <span class="s2">"num_tokens (</span><span class="si">#{</span><span class="n">num_tokens</span><span class="si">}</span><span class="s2">) exceeds context_length (</span><span class="si">#{</span><span class="n">context_length</span><span class="si">}</span><span class="s2">)"</span>
  <span class="k">end</span>
<span class="k">end</span>
</code></pre></div></div>

<p>In the <code class="language-plaintext highlighter-rouge">forward</code> function, the first step is familiar from <code class="language-plaintext highlighter-rouge">SelfAttentionV2</code>; there is nothing new here. Step 2 is new, so what does it mean to split the last dimension? Here is an explanation with an example.<br />
<code class="language-plaintext highlighter-rouge">split_heads/5</code> does two operations:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">tensor</span>
<span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">reshape</span><span class="p">({</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">num_tokens</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">head_dim</span><span class="p">})</span>
<span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="ss">axes:</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span>
</code></pre></div></div>

<p>Its job is to take the last dimension, which is <code class="language-plaintext highlighter-rouge">d_out</code>, and split it into:</p>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>d_out = num_heads * head_dim
</code></pre></div></div>

<h2 id="example">Example</h2>

<p>Assume:</p>

<ul>
  <li><code class="language-plaintext highlighter-rouge">batch_size = 1</code></li>
  <li><code class="language-plaintext highlighter-rouge">num_tokens = 2</code></li>
  <li><code class="language-plaintext highlighter-rouge">num_heads = 2</code></li>
  <li><code class="language-plaintext highlighter-rouge">head_dim = 2</code></li>
</ul>

<p>So input tensor shape is:</p>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>{1, 2, 4}
</code></pre></div></div>

<p>Example tensor:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">tensor</span> <span class="o">=</span>
  <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span>
    <span class="p">[</span>
      <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">,</span> <span class="mf">3.0</span><span class="p">,</span> <span class="mf">4.0</span><span class="p">],</span>
      <span class="p">[</span><span class="mf">5.0</span><span class="p">,</span> <span class="mf">6.0</span><span class="p">,</span> <span class="mf">7.0</span><span class="p">,</span> <span class="mf">8.0</span><span class="p">]</span>
    <span class="p">]</span>
  <span class="p">])</span>
</code></pre></div></div>

<p>This means:</p>
<ul>
  <li>batch 0</li>
  <li>token 0 has features <code class="language-plaintext highlighter-rouge">[1, 2, 3, 4]</code></li>
  <li>token 1 has features <code class="language-plaintext highlighter-rouge">[5, 6, 7, 8]</code></li>
</ul>

<h2 id="step-1-reshape">Step 1: reshape</h2>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="no">Nx</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="p">{</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">})</span>
</code></pre></div></div>

<p>Result:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">[</span>
  <span class="p">[</span>
    <span class="p">[[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">3.0</span><span class="p">,</span> <span class="mf">4.0</span><span class="p">]],</span>
    <span class="p">[[</span><span class="mf">5.0</span><span class="p">,</span> <span class="mf">6.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">7.0</span><span class="p">,</span> <span class="mf">8.0</span><span class="p">]]</span>
  <span class="p">]</span>
<span class="p">]</span>
</code></pre></div></div>

<p>So we have four levels of nesting: the outermost list has one element, and each list at the next three levels has two elements. This gives us the shape <code class="language-plaintext highlighter-rouge">{1, 2, 2, 2}</code>.</p>

<p>Shape is now:</p>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>{batch_size, num_tokens, num_heads, head_dim}
= {1, 2, 2, 2}
</code></pre></div></div>

<p>Interpretation:</p>
<ul>
  <li>token 0:
    <ul>
      <li>head 0 gets <code class="language-plaintext highlighter-rouge">[1, 2]</code></li>
      <li>head 1 gets <code class="language-plaintext highlighter-rouge">[3, 4]</code></li>
    </ul>
  </li>
  <li>token 1:
    <ul>
      <li>head 0 gets <code class="language-plaintext highlighter-rouge">[5, 6]</code></li>
      <li>head 1 gets <code class="language-plaintext highlighter-rouge">[7, 8]</code></li>
    </ul>
  </li>
</ul>

<h2 id="step-2-transpose">Step 2: transpose</h2>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="no">Nx</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="ss">axes:</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span>
</code></pre></div></div>

<p>This changes dimension order from:</p>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>{batch, tokens, heads, head_dim}
</code></pre></div></div>

<p>to:</p>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>{batch, heads, tokens, head_dim}
</code></pre></div></div>

<p>Result:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">[</span>
  <span class="p">[</span>
    <span class="p">[[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">5.0</span><span class="p">,</span> <span class="mf">6.0</span><span class="p">]],</span>
    <span class="p">[[</span><span class="mf">3.0</span><span class="p">,</span> <span class="mf">4.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">7.0</span><span class="p">,</span> <span class="mf">8.0</span><span class="p">]]</span>
  <span class="p">]</span>
<span class="p">]</span>
</code></pre></div></div>

<p>The shape numbers are still <code class="language-plaintext highlighter-rouge">{1, 2, 2, 2}</code>, but the meaning is different:</p>
<ul>
  <li>first head now contains all tokens:
    <ul>
      <li><code class="language-plaintext highlighter-rouge">[[1, 2], [5, 6]]</code></li>
    </ul>
  </li>
  <li>second head now contains all tokens:
    <ul>
      <li><code class="language-plaintext highlighter-rouge">[[3, 4], [7, 8]]</code></li>
    </ul>
  </li>
</ul>

<h2 id="why-transpose">Why transpose?</h2>

<p>Because attention is computed per head. After transpose, each head has its own sequence of token vectors grouped together, which is the layout needed for the later <code class="language-plaintext highlighter-rouge">Nx.dot</code> call.</p>

<p>So <code class="language-plaintext highlighter-rouge">split_heads</code> does this:</p>

<ol>
  <li>Break each token’s full feature vector into head-sized chunks</li>
  <li>Rearrange so data is grouped by head instead of by token</li>
</ol>
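
<p>Putting both steps together, here is a minimal runnable sketch of the same reshape and transpose on the example tensor above; the hard-coded shape values come from this small example, not from the general module code:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># The example {1, 2, 4} tensor from above: one batch, two tokens, d_out = 4.
tensor = Nx.tensor([[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]])

tensor
|&gt; Nx.reshape({1, 2, 2, 2})          # {batch, tokens, heads, head_dim}
|&gt; Nx.transpose(axes: [0, 2, 1, 3])  # {batch, heads, tokens, head_dim}
# =&gt; [[[[1.0, 2.0], [5.0, 6.0]], [[3.0, 4.0], [7.0, 8.0]]]]
</code></pre></div></div>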

<p>Then we calculate:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="n">attn_scores</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">queries</span><span class="p">,</span> <span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">keys</span><span class="p">,</span> <span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
</code></pre></div></div>

<p>This is the dot product of two <code class="language-plaintext highlighter-rouge">{1, 2, 2, 2}</code> tensors. We contract over the last axis, <code class="language-plaintext highlighter-rouge">3</code> (the <code class="language-plaintext highlighter-rouge">head_dim</code> axis), and treat axes <code class="language-plaintext highlighter-rouge">0</code> and <code class="language-plaintext highlighter-rouge">1</code> (batch and head) as batch axes, so each head multiplies its own token matrix with its transposed counterpart:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>    <span class="p">[[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">5.0</span><span class="p">,</span> <span class="mf">6.0</span><span class="p">]],</span>
    <span class="p">[[</span><span class="mf">3.0</span><span class="p">,</span> <span class="mf">4.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">7.0</span><span class="p">,</span> <span class="mf">8.0</span><span class="p">]]</span>
</code></pre></div></div>
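
<p>To make the batched dot product concrete, here is a small sketch that, purely for illustration, skips the <code class="language-plaintext highlighter-rouge">w_q</code>/<code class="language-plaintext highlighter-rouge">w_k</code> projections and uses the split tensor above directly as both <code class="language-plaintext highlighter-rouge">queries</code> and <code class="language-plaintext highlighter-rouge">keys</code>:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Illustration only: use the split_heads output from above directly as queries
# and keys, without applying the w_q/w_k projections first.
q = Nx.tensor([[[[1.0, 2.0], [5.0, 6.0]], [[3.0, 4.0], [7.0, 8.0]]]])
k = q

# Contract head_dim (axis 3), batching over batch (axis 0) and head (axis 1).
Nx.dot(q, [3], [0, 1], k, [3], [0, 1])
# =&gt; shape {1, 2, 2, 2} = {batch, heads, tokens, tokens}
# head 0: [[5.0, 17.0], [17.0, 61.0]]
# head 1: [[25.0, 53.0], [53.0, 113.0]]
</code></pre></div></div>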

<p>Then we compute <code class="language-plaintext highlighter-rouge">masked_scores</code>, just as we did in <code class="language-plaintext highlighter-rouge">CausalAttention</code>.</p>

<p>Suppose one head’s <code class="language-plaintext highlighter-rouge">attn_scores</code> is:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">[</span>
  <span class="p">[</span><span class="mf">1.2</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">],</span>
  <span class="p">[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">1.1</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">],</span>
  <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">,</span> <span class="mf">1.5</span><span class="p">,</span> <span class="mf">0.9</span><span class="p">],</span>
  <span class="p">[</span><span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">]</span>
<span class="p">]</span>
</code></pre></div></div>

<p>Then we get:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="p">[</span>
  <span class="p">[</span><span class="mf">1.2</span><span class="p">,</span> <span class="o">-</span><span class="n">inf</span><span class="p">,</span> <span class="o">-</span><span class="n">inf</span><span class="p">,</span> <span class="o">-</span><span class="n">inf</span><span class="p">],</span>
  <span class="p">[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">1.1</span><span class="p">,</span> <span class="o">-</span><span class="n">inf</span><span class="p">,</span> <span class="o">-</span><span class="n">inf</span><span class="p">],</span>
  <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">,</span> <span class="mf">1.5</span><span class="p">,</span> <span class="o">-</span><span class="n">inf</span><span class="p">],</span>
  <span class="p">[</span><span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">]</span>
<span class="p">]</span>
</code></pre></div></div>

<p>That is the input for the normalization, softmax, and dropout steps we know from <code class="language-plaintext highlighter-rouge">CausalAttention</code>, which produce the attention weights; a dot product of those weights with <code class="language-plaintext highlighter-rouge">values</code> then gives the per-head context vectors. The final step is the opposite of what we did in <code class="language-plaintext highlighter-rouge">split_heads</code>: transpose and reshape to combine the heads back into the original three-axis shape. After that, we apply the <code class="language-plaintext highlighter-rouge">out_proj</code> linear projection.</p>
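
<p>As a quick sanity check, here is a sketch of the scaling and softmax applied to the second masked row above, assuming <code class="language-plaintext highlighter-rouge">head_dim = 2</code> for the scaling; the <code class="language-plaintext highlighter-rouge">-inf</code> entries end up with weight zero, so a token cannot attend to future tokens:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># One masked row from the example above; -inf marks future tokens.
row = Nx.tensor([0.3, 1.1, :neg_infinity, :neg_infinity])

row
|&gt; Nx.divide(Nx.sqrt(2))                 # scale by sqrt(head_dim), assuming head_dim = 2
|&gt; Axon.Activations.softmax(axis: -1)
# =&gt; approximately [0.36, 0.64, 0.0, 0.0]; the masked positions get zero weight
</code></pre></div></div>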

<p>How does <code class="language-plaintext highlighter-rouge">LlmScratch.MultiheadAttention</code> improve computational efficiency compared with <code class="language-plaintext highlighter-rouge">LlmScratch.MultiheadAttentionWrapper</code>? We calculate <code class="language-plaintext highlighter-rouge">keys</code>, <code class="language-plaintext highlighter-rouge">queries</code>, and <code class="language-plaintext highlighter-rouge">values</code> only once, with a single projection each for all heads, and these projections are the most computationally intensive part of multi-head attention.</p>

<p>Let’s use <code class="language-plaintext highlighter-rouge">LlmScratch.MultiheadAttention</code> in a test:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="n">test</span> <span class="s2">"multihead attention combines split heads and output projection in one module"</span> <span class="k">do</span>
    <span class="n">inputs</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span><span class="mf">0.43</span><span class="p">,</span> <span class="mf">0.15</span><span class="p">,</span> <span class="mf">0.89</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.55</span><span class="p">,</span> <span class="mf">0.87</span><span class="p">,</span> <span class="mf">0.66</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.57</span><span class="p">,</span> <span class="mf">0.85</span><span class="p">,</span> <span class="mf">0.64</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.22</span><span class="p">,</span> <span class="mf">0.58</span><span class="p">,</span> <span class="mf">0.33</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.77</span><span class="p">,</span> <span class="mf">0.25</span><span class="p">,</span> <span class="mf">0.10</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.80</span><span class="p">,</span> <span class="mf">0.55</span><span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">inputs</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">inputs</span><span class="p">,</span> <span class="n">inputs</span><span class="p">],</span> <span class="ss">axis:</span> <span class="mi">0</span><span class="p">)</span>

    <span class="n">mha</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">MultiheadAttention</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="no">false</span><span class="p">,</span> <span class="ss">seed:</span> <span class="mi">123</span><span class="p">)</span>

    <span class="n">result</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">MultiheadAttention</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">mha</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="ss">mode:</span> <span class="ss">:inference</span><span class="p">)</span>

    <span class="n">expected</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span>
            <span class="p">[</span><span class="o">-</span><span class="mf">0.47695064544677734</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.23235172033309937</span><span class="p">],</span>
            <span class="p">[</span><span class="o">-</span><span class="mf">0.10124677419662476</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.023370809853076935</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.031040165573358536</span><span class="p">,</span> <span class="mf">0.04926960915327072</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.07956627011299133</span><span class="p">,</span> <span class="mf">0.07443806529045105</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.15488147735595703</span><span class="p">,</span> <span class="mf">0.10296830534934998</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.14108411967754364</span><span class="p">,</span> <span class="mf">0.10342075675725937</span><span class="p">]</span>
          <span class="p">],</span>
          <span class="p">[</span>
            <span class="p">[</span><span class="o">-</span><span class="mf">0.47695064544677734</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.23235172033309937</span><span class="p">],</span>
            <span class="p">[</span><span class="o">-</span><span class="mf">0.10124677419662476</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.023370809853076935</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.031040165573358536</span><span class="p">,</span> <span class="mf">0.04926960915327072</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.07956627011299133</span><span class="p">,</span> <span class="mf">0.07443806529045105</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.15488147735595703</span><span class="p">,</span> <span class="mf">0.10296830534934998</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.14108411967754364</span><span class="p">,</span> <span class="mf">0.10342075675725937</span><span class="p">]</span>
          <span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">result</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">2</span><span class="p">}</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="n">expected</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_number</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>And here it is with real-world GPT-2 Small parameter values, where we switch the default <code class="language-plaintext highlighter-rouge">Nx</code> backend to EXLA to keep it fast:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
  <span class="n">test</span> <span class="s2">"exercise 3.3 initializes a GPT-2 small attention module"</span> <span class="k">do</span>
    <span class="n">previous_backend</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">default_backend</span><span class="p">()</span>
    <span class="no">Nx</span><span class="o">.</span><span class="n">default_backend</span><span class="p">(</span><span class="no">EXLA</span><span class="o">.</span><span class="no">Backend</span><span class="p">)</span>
    <span class="n">on_exit</span><span class="p">(</span><span class="k">fn</span> <span class="o">-&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">default_backend</span><span class="p">(</span><span class="n">previous_backend</span><span class="p">)</span> <span class="k">end</span><span class="p">)</span>

    <span class="n">mha</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">MultiheadAttention</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="mi">768</span><span class="p">,</span> <span class="mi">768</span><span class="p">,</span> <span class="mi">1024</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mi">12</span><span class="p">,</span> <span class="no">false</span><span class="p">,</span> <span class="ss">seed:</span> <span class="mi">123</span><span class="p">)</span>

    <span class="n">assert</span> <span class="n">mha</span><span class="o">.</span><span class="n">d_in</span> <span class="o">==</span> <span class="mi">768</span>
    <span class="n">assert</span> <span class="n">mha</span><span class="o">.</span><span class="n">d_out</span> <span class="o">==</span> <span class="mi">768</span>
    <span class="n">assert</span> <span class="n">mha</span><span class="o">.</span><span class="n">context_length</span> <span class="o">==</span> <span class="mi">1024</span>
    <span class="n">assert</span> <span class="n">mha</span><span class="o">.</span><span class="n">num_heads</span> <span class="o">==</span> <span class="mi">12</span>
    <span class="n">assert</span> <span class="n">mha</span><span class="o">.</span><span class="n">head_dim</span> <span class="o">==</span> <span class="mi">64</span>
    <span class="n">assert</span> <span class="n">mha</span><span class="o">.</span><span class="n">qkv_bias</span> <span class="o">==</span> <span class="no">false</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">mha</span><span class="o">.</span><span class="n">mask</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">1024</span><span class="p">,</span> <span class="mi">1024</span><span class="p">}</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">mha</span><span class="o">.</span><span class="n">w_q</span><span class="o">.</span><span class="n">kernel</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">768</span><span class="p">,</span> <span class="mi">768</span><span class="p">}</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">mha</span><span class="o">.</span><span class="n">w_k</span><span class="o">.</span><span class="n">kernel</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">768</span><span class="p">,</span> <span class="mi">768</span><span class="p">}</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">mha</span><span class="o">.</span><span class="n">w_v</span><span class="o">.</span><span class="n">kernel</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">768</span><span class="p">,</span> <span class="mi">768</span><span class="p">}</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">mha</span><span class="o">.</span><span class="n">out_proj</span><span class="o">.</span><span class="n">kernel</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">768</span><span class="p">,</span> <span class="mi">768</span><span class="p">}</span>
  <span class="k">end</span>
</code></pre></div></div>

<h2 id="conclusion">Conclusion</h2>

<p>We started with SimpleAttention and finished with MultiheadAttention, which we will use later in the book. For me, this section was really hard, and I will still need several rereads before I can explain it to someone else.</p>]]></content><author><name></name></author><category term="llm-from-scratch" /><category term="ai" /><category term="tutorials" /><category term="llm" /><category term="attention" /><category term="elixir" /><summary type="html"><![CDATA[A practical walkthrough of the attention mechanism in Elixir, from simple self-attention to causal and multi-head attention, based on Chapter 3 of Build a LLM from Scratch.]]></summary></entry><entry><title type="html">Workbook Answers on Chapter 2, Build a LLM from Scratch, Working with Text Data</title><link href="https://karlosmid.com/2026/03/workbook-answers-chapter-2-build-llm-from-scratch-working-with-text-data/" rel="alternate" type="text/html" title="Workbook Answers on Chapter 2, Build a LLM from Scratch, Working with Text Data" /><published>2026-03-01T23:00:00+00:00</published><updated>2026-03-01T23:00:00+00:00</updated><id>https://karlosmid.com/2026/03/workbook-answers-chapter-2-build-llm-from-scratch-working-with-text-data</id><content type="html" xml:base="https://karlosmid.com/2026/03/workbook-answers-chapter-2-build-llm-from-scratch-working-with-text-data/"><![CDATA[<h2 id="tldr">TL;DR</h2>

<p>Workbook answers and notes for Chapter 2 of <em>Build a LLM from Scratch</em> by Sebastian Raschka, focused on working with text data (tokenization, token IDs, embeddings, and context windows). But first, let’s reread Giles’s blog posts about Chapter 2 (<a href="https://www.gilesthomas.com/2024/12/llm-from-scratch-2">part 2</a>, <a href="https://www.gilesthomas.com/2024/12/llm-from-scratch-3">part 3</a>).</p>

<h2 id="giless-blog-posts">Giles’s blog posts</h2>

<p>In part 2, Giles briefly explains byte-pair encoding with <code class="language-plaintext highlighter-rouge">tiktoken</code>. <code class="language-plaintext highlighter-rouge">tiktoken</code> starts with letters, numbers, and punctuation, then learns word patterns from the input and creates smarter tokens (in my blog post, I mentioned the pattern we can find in words like <code class="language-plaintext highlighter-rouge">define</code> and <code class="language-plaintext highlighter-rouge">delfine</code>, …). It seems that Giles read the byte-pair encoding paper that I skipped.<br />
The second thing is Python <code class="language-plaintext highlighter-rouge">tiktoken</code>’s <code class="language-plaintext highlighter-rouge">allowed_special</code> option, which we use when passing special tokens like <code class="language-plaintext highlighter-rouge">|endoftext|</code>. If it encounters special tokens that are not in that list, it raises an error. Elixir <code class="language-plaintext highlighter-rouge">tiktoken</code> also accepts special tokens as input, but it does not raise an error. Here is a ChatGPT <a href="https://github.com/karlosmid/llm-from-scratch/blob/main/TIKTOKEN_SPECIAL_TOKENS_EXPLANATION.md">explanation</a> of the problem in the Elixir <code class="language-plaintext highlighter-rouge">tiktoken</code> wrapper.<br />
In part 3, Giles nicely explains the concept of batches and why we need them. Batches are just how we stack <code class="language-plaintext highlighter-rouge">{input, target}</code> pairs. One batch is a list of <code class="language-plaintext highlighter-rouge">{input, target}</code> pairs, where the number of elements in that list is <code class="language-plaintext highlighter-rouge">batch_size</code>. The second batch is the next such list. Doing LLM-related calculations on batches is faster than running calculations on each pair individually.<br />
One-hot embeddings are also explained, and this is actually a pre-step to the embedding data explained in the book.<br />
I learned something new from Giles’s understanding of Chapter 2. Let’s look at a quick batching sketch below, and then do the workbook questions!</p>
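
<p>This is only a sketch with made-up values for the window length, stride, and batch size, not the book’s exact data loader:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Sliding windows over token IDs produce {input, target} pairs, where the
# target is the input shifted by one token; a batch then stacks batch_size pairs.
token_ids = Enum.to_list(1..12)
max_length = 4
batch_size = 2

pairs =
  token_ids
  |&gt; Enum.chunk_every(max_length + 1, max_length, :discard)
  |&gt; Enum.map(fn window -&gt;
    {Enum.take(window, max_length), Enum.drop(window, 1)}
  end)

batches = Enum.chunk_every(pairs, batch_size)
# =&gt; [[{[1, 2, 3, 4], [2, 3, 4, 5]}, {[5, 6, 7, 8], [6, 7, 8, 9]}]]
</code></pre></div></div>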

<h2 id="workbook-multiple-choice-questions">Workbook multiple choice questions</h2>

<h3 id="what-is-the-primary-purpose-of-tokenization-in-the-context-of-llms">What is the primary purpose of tokenization in the context of LLMs?</h3>

<p>A. Tokenization is used to convert text into lowercase.<br />
B. Tokenization splits text into individual words or special characters<br />
C. Tokenization is used to identify the parts of speech in a sentence.<br />
D. Tokenization is used to remove stop words from text.</p>

<p>My answer: B</p>

<h3 id="what-is-the-purpose-of-the-unk-token-in-a-vocabulary-used-for-llms">What is the purpose of the &lt;|unk|&gt; token in a vocabulary used for LLMs?</h3>

<p>A. The &lt;|unk|&gt; token is used to represent punctuation marks.<br />
B. The &lt;|unk|&gt; token is used to mark the beginning of a sentence.<br />
C. The &lt;|unk|&gt; token represents unknown words that were not present in the training data. <br />
D. The &lt;|unk|&gt; token is used to mark the end of a sentence.</p>

<p>My answer: C</p>

<h3 id="what-is-the-primary-task-of-an-llm-during-training">What is the primary task of an LLM during training?</h3>

<p>A. LLMs are trained to translate text from one language to another.<br />
B. LLMs are trained to summarize text.<br />
C. LLMs are trained to answer questions based on a given text.<br />
D. LLMs are trained to predict the next word in a sequence, given a preceding context.</p>

<p>My answer: D (all are LLM tasks, but D is the primary task)</p>

<h3 id="what-is-the-difference-between-absolute-positional-embeddings-and-relative-positional-embeddings">What is the difference between absolute positional embeddings and relative positional embeddings?</h3>

<p>A. Absolute positional embeddings encode the exact position of a token in a sequence, while relative positional embeddings encode the relative distance between tokens.<br />
B. Absolute positional embeddings are only used for short sequences, while relative positional embeddings are used for longer sequences.<br />
C. Absolute positional embeddings are more efficient than relative positional embeddings.<br />
D. Relative positional embeddings are more accurate than absolute positional embeddings.</p>

<p>My answer is A (or, as it is stated in the book for relative positions, how far away a token is from other tokens).</p>

<h3 id="the-purpose-of-____-in-the-context-of-llms-is-to-provide-information-about-the-order-and-location-of-tokens-within-a-sequence-helping-the-llm-understand-the-relationships-between-words">The purpose of <strong>__</strong><strong>__</strong> in the context of LLMs is to provide information about the order and location of tokens within a sequence, helping the LLM understand the relationships between words.</h3>

<p>A. attention mechanism<br />
B. positional embeddings<br />
C. tokenization</p>

<p>My answer: B</p>

<h3 id="what-is-the-final-output-of-the-input-processing-pipeline-for-an-llm-before-it-is-fed-into-the-main-llm-layers">What is the final output of the input processing pipeline for an LLM, before it is fed into the main LLM layers?</h3>

<p>A. The final output is a tensor of probabilities for each word in the vocabulary.<br />
B. The final output is a tensor of text tokens.<br />
C. The final output is a tensor of input embeddings, created by combining token embeddings and positional embeddings.<br />
D. The final output is a tensor of token IDs.</p>

<p>My answer: C</p>

<p>HoHoHo! My answers are correct!</p>

<h2 id="understanding-word-embeddings">Understanding word embeddings</h2>

<h3 id="why-are-word-embeddings-necessary-for-processing-text-data-in-deep-learning-models">Why are word embeddings necessary for processing text data in deep learning models?</h3>

<p>My answer:</p>

<p>Deep learning models are deep neural networks that work with tensors of floating-point numbers. We use floats because we need precision that we cannot achieve with integers. These values represent probabilities that particular words are related.</p>

<p>I failed to mention that text is categorical and that embedding vectors are compatible with the mathematical operations used in deep neural networks.</p>

<p>Workbook answer:
<em>Deep learning models operate on numerical data, while text is
categorical. Word embeddings convert words into continuous-valued
vectors, making them compatible with the mathematical operations used
in neural networks.</em></p>

<h3 id="what-is-the-main-idea-behind-the-word2vec-approach-to-generating-word-embeddings">What is the main idea behind the Word2Vec approach to generating word embeddings?</h3>

<p>My answer:</p>

<p>In the Word2Vec approach, we translate words into embedding vectors that tell us how close or far apart words are based on their meaning.</p>

<p>This was also an incomplete answer.</p>

<p>Workbook answer:</p>

<p><em>Word2Vec trains a neural network to predict the context of a word given
the target word or vice versa. This approach assumes that words
appearing in similar contexts tend to have similar meanings, resulting in
clustered representations of related words in the embedding space.</em></p>

<p>As I did not understand the context/target relationship, I asked ChatGPT 5.3 to help explain the workbook answer.</p>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>It trains a small neural network on a simple language task:

- `target -&gt; context` (predict nearby words from a center word) = `Skip-gram`
- `context -&gt; target` (predict the center word from surrounding words) = `CBOW`

Example:
- In “the cat sat on the mat”, if the target word is `cat`, context words might be `the`, `sat`.

During training, the model adjusts each word’s vector so it gets better at these predictions.

Why this creates meaning:
- Words used in similar situations (like `cat` and `dog`) tend to have similar neighboring words (`pet`, `food`, `runs`, `cute`).
- So the model learns similar vectors for them.

Result:
- Related words become clustered in embedding space.
- The vectors capture semantic relationships (similarity/relatedness) even though the model was only trained on prediction, not dictionary definitions.
</code></pre></div></div>

<h3 id="explain-the-trade-off-involved-in-choosing-the-dimensionality-of-word-embeddings">Explain the trade-off involved in choosing the dimensionality of word embeddings.</h3>

<p>My answer:</p>

<p>With bigger dimensionality we get better word prediction, but we need more processing power to train such a model and more memory to use that LLM.</p>

<p>Workbook answer:</p>

<p><em>Higher dimensionality in word embeddings can capture more nuanced
relationships between words but comes at the cost of computational
efficiency. Lower dimensionality offers faster processing but may sacrifice
some semantic detail.</em></p>

<p>I was close enough with my answer.</p>

<h3 id="how-do-llms-typically-handle-word-embeddings-compared-to-using-pretrained-models-like-word2vec">How do LLMs typically handle word embeddings compared to using pretrained models like Word2Vec?</h3>

<p>My answer:</p>

<p>I do not know how to compare word embeddings and Word2Vec.</p>

<p>Workbook answer:</p>

<p><em>LLMs often generate their own embeddings as part of the input layer and
optimize them during training. This allows for embeddings tailored to the
specific task and data, potentially leading to better performance than
using pre-trained embeddings.</em></p>

<p>Aha, so Word2Vec is a separate pretraining step that is not needed here, because the embedding vectors are created during the LLM’s own training phase.</p>

<h3 id="what-is-the-primary-challenge-associated-with-visualizing-high-dimensional-word-embeddings">What is the primary challenge associated with visualizing high-dimensional word embeddings?</h3>

<p>My answer:</p>

<p>The primary challenge is human imagination. Humans can visualize at most four dimensions: three space dimensions and time.</p>

<p>Close enough with workbook answer:</p>

<p><em>Our visual perception and common graphical representations are limited
to three dimensions or fewer. Visualizing high-dimensional embeddings
requires specialized techniques or dimensionality reduction methods.</em></p>

<h2 id="tokenizing-text">Tokenizing text</h2>

<h3 id="what-is-the-purpose-of-tokenizing-text-in-the-context-of-building-a-large-language-model">What is the purpose of tokenizing text in the context of building a large language model?</h3>

<p>My answer:<br />
As deep neural networks do not understand human text, we must represent it with tensors of floating-point numbers. To do that, we first break the text into words (or subwords) and special characters. Then we create a vocabulary where each token is represented by an integer ID. In the end, we map those IDs to embedding vectors that are ready for the mathematical operations in deep neural networks.</p>

<p>Workbook answer:
<em>Tokenization is a crucial preprocessing step for creating embeddings for
an LLM. It involves splitting input text into individual tokens, which are
either words or special characters, to prepare the text for further
processing and embedding creation.</em></p>

<p>Partially correct, because I included the embedding step under tokenization.</p>

<h3 id="describe-the-process-of-tokenizing-text-using-elixir-regular-expression-module-regex">Describe the process of tokenizing text using Elixir regular expression module Regex</h3>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="no">Regex</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="sr">~r/([,.:;?_!"()\']|--|\s)/</span><span class="p">,</span> <span class="n">text</span><span class="p">,</span> <span class="ss">include_captures:</span> <span class="no">true</span><span class="p">,</span> <span class="ss">trim:</span> <span class="no">true</span><span class="p">)</span>
</code></pre></div></div>

<p>My answer:</p>

<p>Ok, when I see Regex, I always get chills down my spine. Here we go (with LLM help):</p>

<ul>
  <li><code class="language-plaintext highlighter-rouge">~r/regex_expression_goes_here/</code></li>
  <li><code class="language-plaintext highlighter-rouge">(capture_goes_here)</code> we tell Regex to remember what was captured in regex expression</li>
  <li><code class="language-plaintext highlighter-rouge">|</code> this separates the three groups of regex expressions</li>
  <li><code class="language-plaintext highlighter-rouge">[list_of_equal_characters]</code> so dear Regex, first try to split on the characters between <code class="language-plaintext highlighter-rouge">[</code> and <code class="language-plaintext highlighter-rouge">]</code>; those are the special characters from The Verdict novel.</li>
  <li><code class="language-plaintext highlighter-rouge">--</code> yes, The Verdict contains two hyphens in a row</li>
  <li><code class="language-plaintext highlighter-rouge">\s</code> this means any whitespace character, since whitespace can be encoded in various ways.</li>
  <li><code class="language-plaintext highlighter-rouge">include_captures</code> keep all the captured strings, which means everything that matches between <code class="language-plaintext highlighter-rouge">(</code> and <code class="language-plaintext highlighter-rouge">)</code></li>
  <li><code class="language-plaintext highlighter-rouge">trim</code> do not return empty strings, i.e. <code class="language-plaintext highlighter-rouge">""</code></li>
</ul>
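
<p>For example (a minimal sketch; the sample sentence below is made up rather than taken from The Verdict), the split keeps words, punctuation, and whitespace as separate tokens:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>text = "Hello, world -- it's me!"

Regex.split(~r/([,.:;?_!"()\']|--|\s)/, text, include_captures: true, trim: true)
# =&gt; ["Hello", ",", " ", "world", " ", "--", " ", "it", "'", "s", " ", "me", "!"]
# The whitespace-only tokens are usually filtered out in a later step.
</code></pre></div></div>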

<p>Allegedly, somebody once solved an Advent of Code puzzle using only regex.</p>

<p>Workbook answer:
<em>The ‘re.split’ function can be used to split text based on specific patterns.
By defining a regular expression that matches whitespace characters,
punctuation marks, and other special characters, we can separate the
text into individual tokens. The resulting list can then be further
processed to remove redundant whitespace characters.</em></p>

<p>I was correct, but I cheated: I asked an LLM about the <code class="language-plaintext highlighter-rouge">include_captures</code> and <code class="language-plaintext highlighter-rouge">trim</code> options.</p>

<h3 id="why-is-it-important-to-consider-capitalization-when-tokenizing-text-for-llm-training">Why is it important to consider capitalization when tokenizing text for LLM training?</h3>

<p>Didn’t know what to answer.</p>

<p>Workbook answer:<br />
<em>Capitalization helps LLMs distinguish between proper nouns and common
nouns, understand sentence structure, and learn to generate text with
proper capitalization. Therefore, preserving capitalization during
tokenization is beneficial for training effective language models.</em></p>

<p>Yes, now I remember, to distinguish noun types.</p>

<h3 id="explain-the-trade-off-between-removing-whitespaces-during-tokenization-and-keeping-them">Explain the trade-off between removing whitespaces during tokenization and keeping them.</h3>

<p>Didn’t know what to answer.</p>

<p>Workbook answer:<br />
<em>Removing whitespaces reduces memory and computing requirements.
However, keeping whitespaces can be useful for training models that are
sensitive to the exact structure of the text, such as Python code, which
relies on indentation and spacing.</em></p>

<p>Ah, the famous Python indentation! We are spoiled with Elixir syntax and <code class="language-plaintext highlighter-rouge">mix format</code>.</p>

<h3 id="match-the-term-on-the-left-to-its-description-on-the-right">Match the term on the left to its description on the right:</h3>

<table>
  <thead>
    <tr>
      <th>Term</th>
      <th> </th>
      <th>Definition</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Word Embeddings</td>
      <td> </td>
      <td>The process of converting various data types, such as text, audio, or video, into a dense vector representation that deep learning models can understand.</td>
    </tr>
    <tr>
      <td>Embedding</td>
      <td> </td>
      <td>The dimensionality of a word embedding, which determines the number of dimensions used to represent each word, influencing the complexity and computational efficiency of the model.</td>
    </tr>
    <tr>
      <td>Word2Vec</td>
      <td> </td>
      <td>An algorithm that generates word embeddings by predicting the context of a word given the target word or vice versa, based on the idea that words appearing in similar contexts tend to have similar meanings.</td>
    </tr>
    <tr>
      <td>Embedding Size</td>
      <td> </td>
      <td>A method of representing words as continuous-valued vectors, allowing deep learning models to process text data.</td>
    </tr>
  </tbody>
</table>

<p>My answer:</p>

<p>Word embeddings =&gt; row 4.<br />
Embedding =&gt; row 1.<br />
Word2Vec =&gt; row 3.<br />
Embedding size =&gt; row 2.</p>

<p>Correct!</p>

<h2 id="converting-tokens-into-token-ids">Converting tokens into token IDs</h2>

<h3 id="what-is-the-purpose-of-converting-tokens-into-token-ids">What is the purpose of converting tokens into token IDs?</h3>

<p>My answer:</p>

<p>The purpose of converting tokens to IDs (integers) is to prepare them in a format that is suitable for mathematical operations in a deep neural network.</p>

<p>Workbook answer:</p>

<p><em>Converting tokens into token IDs is an intermediate step before
converting them into embedding vectors. This process allows for efficient
representation and processing of text data within a language model.</em></p>

<p>I would say my answer is correct and maybe even more concrete because it explains what this processing is in an LLM.</p>

<h3 id="how-is-a-vocabulary-created-for-tokenization">How is a vocabulary created for tokenization?</h3>

<p>My answer:</p>

<p>We sort tokens, then remove duplicates, and then create a dictionary where each token gets a unique ID.</p>

<p>Workbook answer:</p>

<p><em>A vocabulary is created by tokenizing the entire training dataset, sorting
the unique tokens alphabetically, and assigning a unique integer to each
token. This mapping allows for efficient conversion between tokens and
their corresponding integer representations.</em></p>

<p>I did not mention that tokens are sorted alphabetically, but the workbook also misses that tokens are unique.</p>
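
<p>As a minimal Elixir sketch of building such a vocabulary (the <code class="language-plaintext highlighter-rouge">tokens</code> list below is made up, standing in for the tokenizer output):</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tokens = ["the", "cat", "sat", "on", "the", "mat", "."]

vocab =
  tokens
  |&gt; Enum.uniq()
  |&gt; Enum.sort()
  |&gt; Enum.with_index()
  |&gt; Map.new()

# =&gt; %{"." =&gt; 0, "cat" =&gt; 1, "mat" =&gt; 2, "on" =&gt; 3, "sat" =&gt; 4, "the" =&gt; 5}
</code></pre></div></div>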

<h3 id="what-is-the-purpose-of-the-encode-function-in-the-simpletokenizerv1-module">What is the purpose of the encode function in the SimpleTokenizerV1 module?</h3>

<p>My answer:</p>

<p>The <code class="language-plaintext highlighter-rouge">encode</code> function converts tokens to vocabulary ID values.</p>

<p>Workbook answer:</p>

<p><em>The encode method takes text as input, splits it into tokens, and uses the
vocabulary to convert these tokens into their corresponding integer IDs.
This process allows for representing text data as a sequence of integers,
which can be processed by the language model.</em></p>

<p>I missed the tokenization part. Note that the workbook uses <code class="language-plaintext highlighter-rouge">method</code>, and I use <code class="language-plaintext highlighter-rouge">function</code>. If you use <code class="language-plaintext highlighter-rouge">method</code> in a conversation with Sasa Juric, he would definitely correct you!</p>

<h3 id="what-is-the-purpose-of-the-decode-function-in-the-simpletokenizerv1-module">What is the purpose of the decode function in the SimpleTokenizerV1 module?</h3>

<p>My answer:</p>

<p>The <code class="language-plaintext highlighter-rouge">decode</code> function converts vocabulary IDs to token values.</p>

<p>Workbook answer:</p>

<p><em>The decode method takes a sequence of token IDs as input and uses the
inverse vocabulary to convert these IDs back into their corresponding
text tokens. This process allows for converting the output of the language
model, which is a sequence of integers, back into human-readable text.</em></p>

<p>The workbook answer is more extensive.</p>
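
<p>A minimal sketch of such a decode function in Elixir (assuming the <code class="language-plaintext highlighter-rouge">vocab</code> map from the sketch above; a real implementation would also fix the spacing around punctuation):</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def decode(ids, vocab) do
  # Invert the vocabulary: token ID -&gt; token.
  inverse_vocab = Map.new(vocab, fn {token, id} -&gt; {id, token} end)

  ids
  |&gt; Enum.map(&amp;Map.fetch!(inverse_vocab, &amp;1))
  |&gt; Enum.join(" ")
end
</code></pre></div></div>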

<h3 id="what-is-the-limitation-of-using-a-vocabulary-built-from-a-small-training-set">What is the limitation of using a vocabulary built from a small training set?</h3>

<p>My answer:</p>

<p>The limitation is that we would have a lot of <code class="language-plaintext highlighter-rouge">&lt;|unk|&gt;</code> tokens, and our LLM would not be able to predict the next token reliably.</p>

<p>Workbook answer:</p>

<p><em>Using a vocabulary built from a small training set can lead to issues when
encountering new words or phrases not present in the training data. This
can result in errors during tokenization and decoding, highlighting the
importance of using large and diverse training sets for building robust
language models.</em></p>

<p>I did not mention that we hit <code class="language-plaintext highlighter-rouge">&lt;|unk|&gt;</code> because the input word is not part of the training dataset.</p>

<h3 id="match-the-term-on-the-left-to-its-description-on-the-right-1">Match the term on the left to its description on the right:</h3>

<table>
  <thead>
    <tr>
      <th>Term</th>
      <th> </th>
      <th>Description</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Tokenization</td>
      <td> </td>
      <td>individual units of text that result from tokenization, representing words, punctuation, or other special characters.</td>
    </tr>
    <tr>
      <td>Tokens</td>
      <td> </td>
      <td>used to define patterns in text, allowing for flexible and precise text manipulation, including tokenization.</td>
    </tr>
    <tr>
      <td>Regular Expressions</td>
      <td> </td>
      <td>initial steps taken to prepare text data for further processing, such as tokenization, which makes the text suitable for use in language models.</td>
    </tr>
    <tr>
      <td>Preprocessing</td>
      <td> </td>
      <td>splitting text into individual units, called tokens, which can be words, punctuation marks, or other special characters.</td>
    </tr>
  </tbody>
</table>

<p>My answer:</p>

<table>
  <thead>
    <tr>
      <th>Left Hand Column</th>
      <th>1</th>
      <th>2</th>
      <th>3</th>
      <th>4</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Right Hand Column</td>
      <td>4</td>
      <td>1</td>
      <td>2</td>
      <td>3</td>
    </tr>
  </tbody>
</table>

<p>Correct answer.</p>

<h2 id="adding-special-context-tokens">Adding special context tokens</h2>

<h3 id="what-are-the-two-special-tokens-added-to-the-vocabulary-and-what-are-their-purposes">What are the two special tokens added to the vocabulary and what are their purposes?</h3>

<p>My answer:</p>

<p>The purpose of <code class="language-plaintext highlighter-rouge">&lt;|unk|&gt;</code> is to mark a token that is not in the vocabulary.<br />
The purpose of <code class="language-plaintext highlighter-rouge">&lt;|could_not_remember|&gt;</code> is to mark different input sources.</p>

<p>Workbook answer:</p>

<p><em>The two special tokens added are &lt;|unk|&gt; and &lt;|endoftext|&gt;. &lt;|unk|&gt;
represents unknown words not in the training data, while &lt;|endoftext|&gt;
separates unrelated text sources, helping the LLM understand their
distinct nature.</em></p>

<p>My answer is correct, aside from the fact that I could not remember <code class="language-plaintext highlighter-rouge">endoftext</code>.</p>
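
<p>A minimal sketch of extending the vocabulary from the earlier sketch with these two special tokens (assuming <code class="language-plaintext highlighter-rouge">vocab</code> is the token-to-ID map built above):</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>special_tokens = ["&lt;|endoftext|&gt;", "&lt;|unk|&gt;"]

# Special tokens get the next free IDs after the existing vocabulary entries.
special_map =
  special_tokens
  |&gt; Enum.with_index(map_size(vocab))
  |&gt; Map.new()

vocab = Map.merge(vocab, special_map)
</code></pre></div></div>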

<h3 id="how-does-the-modified-simpletokenizerv2-handle-unknown-words">How does the modified SimpleTokenizerV2 handle unknown words?</h3>

<p>My answer:</p>

<p>It adds <code class="language-plaintext highlighter-rouge">&lt;|unk|&gt;</code> as an optional <code class="language-plaintext highlighter-rouge">special_tokens</code> attribute, so this token is used in place of all unknown tokens.</p>

<p>Workbook answer:</p>

<p><em>When encountering a word not in the vocabulary, SimpleTokenizerV2
replaces it with the &lt;|unk|&gt; token, ensuring that all words are
represented in the encoded text.</em></p>

<p>Correct answer.</p>

<h3 id="explain-the-purpose-of-the-endoftext-token-when-training-on-multiple-independent-documents">Explain the purpose of the &lt;|endoftext|&gt; token when training on multiple independent documents.</h3>

<p>My answer:</p>

<p>This special token separates different text inputs so we know when the input context changes.</p>

<p>Workbook answer:</p>

<p><em>The &lt;|endoftext|&gt; token acts as a marker between unrelated text
sources, signaling the start or end of a particular segment. This helps the
LLM understand that these texts, though concatenated for training, are
distinct entities.</em></p>

<p><code class="language-plaintext highlighter-rouge">Different segments</code> is better phrasing than <code class="language-plaintext highlighter-rouge">different text inputs</code>.</p>

<h3 id="a-piece-of-the-code-has-been-removed-from-this-listing-which-of-these-terms-has-been-removed-from-position_1">A piece of the code has been removed from this listing. Which of these terms has been removed from position_1?</h3>

<p>A unk<br />
B \n<br />
C &lt;|unk|&gt;<br />
D |unk|</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
 <span class="k">def</span> <span class="n">encode</span><span class="p">(</span><span class="n">text</span><span class="p">,</span> <span class="n">vocab</span><span class="p">)</span> <span class="k">do</span>
    <span class="n">tokenize</span><span class="p">(</span><span class="n">text</span><span class="p">)</span>
    <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">fn</span> <span class="n">token</span> <span class="o">-&gt;</span>
      <span class="no">Enum</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="n">vocab</span><span class="p">,</span> <span class="k">fn</span> <span class="p">{</span><span class="n">vocab_token</span><span class="p">,</span> <span class="n">_</span><span class="p">}</span> <span class="o">-&gt;</span> <span class="n">token</span> <span class="o">==</span> <span class="n">vocab_token</span> <span class="k">end</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="k">case</span> <span class="k">do</span>
        <span class="p">{</span><span class="n">_</span><span class="p">,</span> <span class="n">id</span><span class="p">}</span> <span class="o">-&gt;</span> <span class="n">id</span>
        <span class="no">nil</span> <span class="o">-&gt;</span> <span class="n">find_unknown_token</span><span class="p">(</span><span class="n">vocab</span><span class="p">,</span> <span class="n">token</span><span class="p">)</span>
      <span class="k">end</span>
    <span class="k">end</span><span class="p">)</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">find_unknown_token</span><span class="p">(</span><span class="n">vocab</span><span class="p">,</span> <span class="n">token</span><span class="p">)</span> <span class="k">do</span>
    <span class="no">Enum</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="n">vocab</span><span class="p">,</span> <span class="k">fn</span> <span class="p">{</span><span class="n">vocab_token</span><span class="p">,</span> <span class="n">_</span><span class="p">}</span> <span class="o">-&gt;</span> <span class="n">vocab_token</span> <span class="o">==</span> <span class="s2">" position_1 "</span> <span class="k">end</span><span class="p">)</span>
    <span class="o">|&gt;</span> <span class="k">case</span> <span class="k">do</span>
      <span class="p">{</span><span class="n">_</span><span class="p">,</span> <span class="n">id</span><span class="p">}</span> <span class="o">-&gt;</span> <span class="n">id</span>
      <span class="no">nil</span> <span class="o">-&gt;</span> <span class="k">raise</span> <span class="s2">"Token not found in vocab: </span><span class="si">#{</span><span class="n">token</span><span class="si">}</span><span class="s2">"</span>
    <span class="k">end</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>My answer: C</p>

<p>Note: I used Elixir instead of the original Python.</p>

<h3 id="what-are-the-additional-special-tokens-commonly-used-in-llms-and-what-are-their-functions">What are the additional special tokens commonly used in LLMs, and what are their functions?</h3>

<p>My answer:</p>

<p>We could use a token that pads all input texts, separated with <code class="language-plaintext highlighter-rouge">&lt;|endoftext|&gt;</code>, to the same length.</p>

<p>Workbook answer:</p>

<p><em>Other common special tokens include [BOS] (beginning of sequence),
[EOS] (end of sequence), and [PAD] (padding). [BOS] marks the start of
a text, [EOS] indicates the end, and [PAD] is used to extend shorter texts
to match the length of the longest text in a batch for training.</em></p>

<p><code class="language-plaintext highlighter-rouge">Padding</code> is what I could not remember.</p>

<h3 id="match-the-term-on-the-left-to-its-description-on-the-right-2">Match the term on the left to its description on the right:</h3>

<table>
  <thead>
    <tr>
      <th>Term</th>
      <th> </th>
      <th>Description</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Vocabulary</td>
      <td> </td>
      <td>integer representations of tokens, used as an intermediate step before converting tokens into embedding vectors.</td>
    </tr>
    <tr>
      <td>Token IDs</td>
      <td> </td>
      <td>the dataset used to build the vocabulary and train the language model.</td>
    </tr>
    <tr>
      <td>Tokenizer</td>
      <td> </td>
      <td>a mapping from unique tokens to unique integer values, created by tokenizing the entire training dataset and sorting the tokens alphabetically.</td>
    </tr>
    <tr>
      <td>Training Set</td>
      <td> </td>
      <td>a class that implements methods for encoding text into token IDs and decoding token IDs back into text.</td>
    </tr>
  </tbody>
</table>

<p>My answer:</p>

<table>
  <thead>
    <tr>
      <th>Left Hand Column</th>
      <th>1</th>
      <th>2</th>
      <th>3</th>
      <th>4</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Right Hand Column</td>
      <td>3</td>
      <td>1</td>
      <td>4</td>
      <td>2</td>
    </tr>
  </tbody>
</table>

<p>Correct.</p>

<h2 id="byte-pair-encoding">Byte pair encoding</h2>

<h3 id="what-are-the-two-stages-in-this-diagram">What are the two stages in this diagram?</h3>

<p><img src="/assets/byte_pair_encoding_stages.png" alt="Byte pair encoding stages" /></p>

<p>My answer:</p>

<p>Stage 1 is where the byte pair encoder splits input text into tokens.<br />
Stage 2 is mapping tokens to their vocabulary ID values.</p>

<p>Correct answer.</p>

<h3 id="what-is-the-primary-advantage-of-using-byte-pair-encoding-bpe-for-tokenization-especially-when-dealing-with-unknown-words">What is the primary advantage of using Byte Pair Encoding (BPE) for tokenization, especially when dealing with unknown words?</h3>

<p>My answer:</p>

<p>The byte pair encoder splits tokens into subwords, so it can map any word into tokens.</p>

<p>Workbook answer:</p>

<p><em>BPE tokenizers break down unknown words into smaller subword units or
even individual characters. This allows them to handle any word without
needing a special &lt;|unk|&gt; token, ensuring that the tokenizer and the LLM
can process any text, even if it contains words not present in the training
data.</em></p>

<p>I did not mention that a token could be a single character.</p>

<h3 id="what-is-the-total-vocabulary-size-of-the-bpe-tokenizer-used-in-models-like-gpt-2-gpt-3-and-the-original-chatgpt">What is the total vocabulary size of the BPE tokenizer used in models like GPT-2, GPT-3, and the original ChatGPT?</h3>

<p>My answer:</p>

<p>52345</p>

<p>Workbook answer:</p>

<p><em>The BPE tokenizer used in these models has a vocabulary size of 50,257,
with the &lt;|endoftext|&gt; token assigned the largest token ID.</em></p>

<p>Not correct.</p>

<h3 id="how-does-the-bpe-tokenizer-handle-unknown-words-such-as-someunknownplace-without-using-unk-tokens">How does the BPE tokenizer handle unknown words, such as someunknownPlace, without using &lt;|unk|&gt; tokens?</h3>

<p>My answer:</p>

<p>Since the byte pair encoder uses subwords as tokens, such as <code class="language-plaintext highlighter-rouge">so</code> and <code class="language-plaintext highlighter-rouge">me</code>, it can split any word into byte tokens.</p>

<p>Workbook answer:</p>

<p><em>The BPE tokenizer breaks down unknown words into smaller subword
units or individual characters. This allows it to represent any word as a
sequence of known subword tokens or characters, enabling it to process
any text without needing a special token for unknown words.</em></p>

<p>I forgot to mention single characters.</p>
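
<p>A hedged sketch with the Elixir <code class="language-plaintext highlighter-rouge">tiktoken</code> wrapper (the same <code class="language-plaintext highlighter-rouge">encode/3</code> call that appears in the listings later in this post; I have not verified the exact IDs it returns, only that the unknown word is represented by several subword IDs instead of a single <code class="language-plaintext highlighter-rouge">&lt;|unk|&gt;</code> ID):</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>{:ok, token_ids} = Tiktoken.encode("code-davinci-002", "someunknownPlace", [])

# token_ids holds multiple subword token IDs for this single unknown word;
# decoding them concatenates the subwords back into "someunknownPlace".
</code></pre></div></div>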

<h3 id="what-python-library-is-used-to-implement-the-bpe-tokenizer-in-the-provided-code-example">What Python library is used to implement the BPE tokenizer in the provided code example?</h3>

<p>My answer: tiktoken</p>

<p>Correct.</p>

<h2 id="data-sampling-with-a-sliding-window">Data sampling with a sliding window</h2>

<h3 id="explain-the-purpose-of-creating-input-target-pairs-in-the-context-of-training-a-large-language-model-llm">Explain the purpose of creating input-target pairs in the context of training a large language model (LLM).</h3>

<p>My answer:</p>

<p>With input-target pairs, we encode in the training dataset what the next token is for a particular input.</p>

<p>Workbook answer:</p>

<p><em>the model with examples of text sequences and their corresponding next
words. This allows the LLM to learn the relationships between words and
predict the most likely next word in a given context.</em></p>

<p>Now I know the purpose of the sliding window.</p>

<h3 id="describe-the-sliding-window-approach-used-for-generating-input-target-pairs-and-how-it-works">Describe the sliding window approach used for generating input-target pairs and how it works.</h3>

<p>My answer:</p>

<p>In the sliding window approach, we generate input-target pairs with a stride parameter that defines how many tokens we skip in the input to define the next input-target pair.<br />
For example, with <code class="language-plaintext highlighter-rouge">My name is Karlo Smid</code>, stride <code class="language-plaintext highlighter-rouge">2</code>, and input size <code class="language-plaintext highlighter-rouge">3</code>, we have:<br />
My name is =&gt; Karlo<br />
is Karlo Smid =&gt; “”</p>

<p>Workbook answer:</p>

<p><em>The sliding window approach involves iterating through a text sequence
and extracting overlapping chunks of text as inputs. Each input chunk is
paired with the corresponding next word as the target. The window slides
across the text, creating multiple input-target pairs for training.</em></p>

<p>I correctly described the sliding window process.</p>

<h3 id="pieces-of-code-have-been-removed-from-three-places-in-this-listing-which-of-these-terms-have-been-removed-and-where-should-they-go">Pieces of code have been removed from three places in this listing. Which of these terms have been removed, and where should they go?</h3>

<p>A Nx.vector<br />
B Tiktoken<br />
C Tokenizer<br />
D Nx.tensor</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="k">def</span> <span class="n">chunk_dataset</span><span class="p">(</span><span class="n">txt</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">max_length</span><span class="p">,</span> <span class="n">stride</span><span class="p">)</span> <span class="k">do</span>
    <span class="p">{</span><span class="ss">:ok</span><span class="p">,</span> <span class="n">token_ids</span><span class="p">}</span> <span class="o">=</span> <span class="n">position_1</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">txt</span><span class="p">,</span> <span class="p">[</span><span class="s2">"&lt;|endoftext|&gt;"</span><span class="p">])</span>

    <span class="k">if</span> <span class="n">length</span><span class="p">(</span><span class="n">token_ids</span><span class="p">)</span> <span class="o">&lt;</span> <span class="n">max_length</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
      <span class="k">do</span><span class="p">:</span> <span class="k">raise</span><span class="p">(</span><span class="s2">"Number of tokenized inputs must be at least max_length + 1"</span><span class="p">)</span>

    <span class="n">num_chunks</span> <span class="o">=</span> <span class="n">length</span><span class="p">(</span><span class="n">token_ids</span><span class="p">)</span> <span class="o">-</span> <span class="n">max_length</span>

    <span class="n">chunks</span> <span class="o">=</span>
      <span class="mi">0</span><span class="o">..</span><span class="n">num_chunks</span><span class="o">//</span><span class="n">stride</span>
      <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">reduce</span><span class="p">([</span><span class="ss">input_chunks:</span> <span class="p">[],</span> <span class="ss">target_chunks:</span> <span class="p">[]],</span> <span class="k">fn</span> <span class="n">i</span><span class="p">,</span> <span class="n">acc</span> <span class="o">-&gt;</span>
        <span class="n">input_chunk</span> <span class="o">=</span> <span class="no">Enum</span><span class="o">.</span><span class="n">slice</span><span class="p">(</span><span class="n">token_ids</span><span class="p">,</span> <span class="n">i</span><span class="o">..</span><span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="n">max_length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span>
        <span class="n">target_chunk</span> <span class="o">=</span> <span class="no">Enum</span><span class="o">.</span><span class="n">slice</span><span class="p">(</span><span class="n">token_ids</span><span class="p">,</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">..</span><span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="n">max_length</span><span class="p">))</span>

        <span class="p">[</span>
          <span class="ss">input_chunks:</span> <span class="p">[</span><span class="n">position_2</span><span class="p">(</span><span class="n">input_chunk</span><span class="p">)</span> <span class="o">|</span> <span class="n">acc</span><span class="p">[</span><span class="ss">:input_chunks</span><span class="p">]],</span>
          <span class="ss">target_chunks:</span> <span class="p">[</span><span class="n">position_2</span><span class="p">(</span><span class="n">target_chunk</span><span class="p">)</span> <span class="o">|</span> <span class="n">acc</span><span class="p">[</span><span class="ss">:target_chunks</span><span class="p">]]</span>
        <span class="p">]</span>
      <span class="k">end</span><span class="p">)</span>

    <span class="p">[</span>
      <span class="ss">input_chunks:</span> <span class="no">Enum</span><span class="o">.</span><span class="n">reverse</span><span class="p">(</span><span class="n">chunks</span><span class="p">[</span><span class="ss">:input_chunks</span><span class="p">]),</span>
      <span class="ss">target_chunks:</span> <span class="no">Enum</span><span class="o">.</span><span class="n">reverse</span><span class="p">(</span><span class="n">chunks</span><span class="p">[</span><span class="ss">:target_chunks</span><span class="p">])</span>
    <span class="p">]</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>My answer:</p>

<p>position 1 =&gt; B<br />
  position 2 =&gt; D<br />
  position 3 =&gt; D</p>

<p>Correct, as I created the Elixir example.</p>

<h3 id="what-is-the-role-of-the-stride-parameter-in-the-gptdatasetv1-class-and-how-does-it-affect-the-generation-of-input-target-pairs">What is the role of the stride parameter in the GPTDatasetV1 class, and how does it affect the generation of input-target pairs?</h3>

<p>My answer:</p>

<p>The <code class="language-plaintext highlighter-rouge">stride</code> parameter defines the size of the sliding window. This is the number of input tokens we skip when defining the next input-target pair.</p>

<p>Workbook answer:</p>

<p><em>The stride parameter determines the step size of the sliding window. A
smaller stride results in more overlapping input chunks, while a larger
stride creates less overlap. The choice of stride influences the amount of
data generated and the potential for capturing long-range dependencies
in the text.</em></p>

<p>I did not describe how it affects input-target pairs.</p>

<h3 id="explain-the-purpose-of-the-max_length-parameter-in-the-gptdatasetv1-class-and-its-impact-on-the-input-target-pairs">Explain the purpose of the max_length parameter in the GPTDatasetV1 class and its impact on the input-target pairs.</h3>

<p>My answer:</p>

<p>The <code class="language-plaintext highlighter-rouge">max_length</code> parameter defines the number of tokens per input-target pair. Its purpose is to define how many chunks we want to split the list of input tokens into.</p>

<p>Workbook answer:</p>

<p><em>The max_length parameter defines the size of the input chunks extracted
from the text. It determines the number of tokens included in each input
sequence. A larger max_length allows the LLM to process longer
contexts, but it also increases the computational cost of training.</em></p>

<p>Again, I did not explain the effect. The effect is the context size, a famous LLM parameter!</p>

<h3 id="pieces-of-code-have-been-removed-from-two-places-in-this-listing-which-of-these-terms-have-been-removed-and-where-should-they-go">Pieces of code have been removed from two places in this listing. Which of these terms have been removed, and where should they go?</h3>

<p>A input_chunks<br />
B Tiktoken<br />
C dataset<br />
D target_chunks</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="k">def</span> <span class="n">create_dataloader_v1</span><span class="p">(</span><span class="n">opts</span><span class="p">)</span> <span class="k">do</span>
    <span class="n">raw_text</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">fetch!</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:raw_text</span><span class="p">)</span>
    <span class="n">batch_size</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:batch_size</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
    <span class="n">max_length</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:max_length</span><span class="p">,</span> <span class="mi">256</span><span class="p">)</span>
    <span class="n">stride</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:stride</span><span class="p">,</span> <span class="mi">128</span><span class="p">)</span>
    <span class="n">shuffle</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:shuffle</span><span class="p">,</span> <span class="no">true</span><span class="p">)</span>
    <span class="n">drop_last</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:drop_last</span><span class="p">,</span> <span class="no">true</span><span class="p">)</span>
    <span class="n">num_workers</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:num_workers</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>

    <span class="c1"># Create dataset</span>
    <span class="p">[</span><span class="ss">input_chunks:</span> <span class="n">input_chunks</span><span class="p">,</span> <span class="ss">target_chunks:</span> <span class="n">target_chunks</span><span class="p">]</span> <span class="o">=</span>
      <span class="no">LlmScratch</span><span class="o">.</span><span class="no">GptDatasetV1</span><span class="o">.</span><span class="n">chunk_dataset</span><span class="p">(</span><span class="n">raw_text</span><span class="p">,</span> <span class="s2">"code-davinci-002"</span><span class="p">,</span> <span class="n">max_length</span><span class="p">,</span> <span class="n">stride</span><span class="p">)</span>

    <span class="c1"># Zip input and target chunks together so each item is {input, target}</span>
    <span class="n">dataset</span> <span class="o">=</span> <span class="no">Enum</span><span class="o">.</span><span class="n">zip</span><span class="p">(</span><span class="n">input_chunks</span><span class="p">,</span> <span class="n">position_1</span><span class="p">)</span>

    <span class="no">LlmScratch</span><span class="o">.</span><span class="no">DataLoader</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="n">position_2</span><span class="p">,</span>
      <span class="ss">batch_size:</span> <span class="n">batch_size</span><span class="p">,</span>
      <span class="ss">shuffle:</span> <span class="n">shuffle</span><span class="p">,</span>
      <span class="ss">drop_last:</span> <span class="n">drop_last</span><span class="p">,</span>
      <span class="ss">num_workers:</span> <span class="n">num_workers</span>
    <span class="p">)</span>
  <span class="k">end</span>
</code></pre></div></div>
<p>My answer:</p>

<p>position_1 =&gt; target_chunks<br />
position_2 =&gt; dataset</p>

<p>Correct, because I created the Elixir example.</p>

<h3 id="what-is-the-significance-of-using-pytorchs-dataset-and-dataloader-classes-for-creating-a-data-loader-for-llm-training">What is the significance of using PyTorch’s Dataset and DataLoader classes for creating a data loader for LLM training?</h3>

<p>My answer:</p>

<p>Dataset and DataLoader help us organize input training data into sliding windows and prepare it for the embedding stage.</p>

<p>Workbook answer:</p>

<p><em>PyTorch’s Dataset and DataLoader classes provide a convenient and
efficient way to manage and iterate over large datasets. They allow for
batching, shuffling, and parallel data loading, which are crucial for
optimizing the training process of LLMs.</em></p>

<p>I did not mention the shuffle, concurrency, and batching features.</p>

<h3 id="match-the-term-on-the-left-to-its-description-on-the-right-3">Match the term on the left to its description on the right:</h3>

<table>
  <thead>
    <tr>
      <th>Term</th>
      <th> </th>
      <th>Description</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Byte Pair Encoding (BPE)</td>
      <td> </td>
      <td>words that are not present in the tokenizer’s predefined vocabulary.</td>
    </tr>
    <tr>
      <td>Subword Units</td>
      <td> </td>
      <td>total number of unique tokens that a tokenizer can recognize and represent.</td>
    </tr>
    <tr>
      <td>Out-of-Vocabulary (OOV) Words</td>
      <td> </td>
      <td>smaller units of text that a BPE tokenizer breaks down words into, which can be individual characters or combinations of characters.</td>
    </tr>
    <tr>
      <td>Vocabulary Size</td>
      <td> </td>
      <td>a tokenization scheme that breaks down words into smaller subword units or individual characters, allowing it to handle unknown words by representing them as sequences of subword tokens or characters.</td>
    </tr>
  </tbody>
</table>

<p>My answer:</p>

<table>
  <thead>
    <tr>
      <th>Left Hand Column</th>
      <th>1</th>
      <th>2</th>
      <th>3</th>
      <th>4</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Right Hand Column</td>
      <td>4</td>
      <td>3</td>
      <td>1</td>
      <td>2</td>
    </tr>
  </tbody>
</table>

<p>Correct.</p>

<h2 id="creating-token-embeddings">Creating token embeddings</h2>

<h3 id="why-are-embedding-vectors-necessary-for-training-gpt-like-llms">Why are embedding vectors necessary for training GPT-like LLMs?</h3>

<p>My answer:</p>

<p>They are necessary because, to encode how tokens relate to each other in an LLM, we need floating-point numbers instead of integer token IDs. With integers, we cannot encode these relationships because integers are discrete values.</p>

<p>Workbook answer:</p>

<p><em>Embedding vectors are essential for training GPT-like LLMs because these
models are deep neural networks that rely on the backpropagation
algorithm for learning. Backpropagation requires continuous vector
representations, which embedding vectors provide.</em></p>

<p>I missed mentioning the backpropagation algorithm.</p>

<h3 id="how-are-embedding-weights-initialized-in-the-beginning-of-llm-training">How are embedding weights initialized in the beginning of LLM training?</h3>

<p>My answer:</p>

<p>Embedding weights are initialized with random floating-point values at the beginning.</p>

<p>Workbook answer:</p>

<p><em>Embedding weights are initially assigned random values. These random
values serve as the starting point for the LLM’s learning process. During
training, the embedding weights are optimized through backpropagation
to improve the model’s performance.</em></p>

<p>This is correct, but the workbook also explains what happens to the initial weights during the training phase.</p>

<h3 id="what-is-the-missing-stage-from-this-diagram">What is the missing stage from this diagram?</h3>

<p><img src="/assets/gpt-tokenization-embeddings-diagram.png" alt="GPT pipeline with tokenization and embeddings" /></p>

<p>My answer:</p>

<p>Token IDs</p>

<p>Correct.</p>

<h3 id="describe-the-process-of-converting-a-token-id-into-an-embedding-vector-using-an-embedding-layer">Describe the process of converting a token ID into an embedding vector using an embedding layer.</h3>

<p>My answer:</p>

<p>We first create embedding weights with random values, in a tensor of shape <code class="language-plaintext highlighter-rouge">{vocabulary_size, embedding_dimensions}</code>. Each row in this tensor belongs to one token ID, using zero-based indexing: the first row is for ID 0, the second for ID 1, and so on. When we forward the embedding weights on the input IDs, we get a new tensor of shape <code class="language-plaintext highlighter-rouge">{number_of_input_ids, embedding_dimensions}</code>, and this is our embedding vector that we can use to train the LLM.</p>

<p>Workbook answer:</p>

<p><em>The embedding layer acts as a lookup table. When given a token ID, it
retrieves the corresponding embedding vector from its weight matrix.
This embedding vector is a continuous representation of the token,
allowing the LLM to process it effectively.</em></p>

<p>Correct, but the workbook gives the better explanation: the embedding layer is a lookup table for input IDs.</p>
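
<p>A minimal Nx sketch of that lookup (a vocabulary size of 6 and an embedding dimension of 3 are made-up values for illustration):</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>key = Nx.Random.key(123)

# Random embedding weights: one row per vocabulary entry, shape {6, 3}.
{embedding_weights, _key} = Nx.Random.normal(key, 0.0, 1.0, shape: {6, 3})

token_ids = Nx.tensor([2, 3, 5, 1])

# Row lookup along axis 0: the result has shape {4, 3}, one embedding vector per token ID.
input_embeddings = Nx.take(embedding_weights, token_ids)
</code></pre></div></div>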

<h3 id="how-does-the-embedding-layers-weight-matrix-relate-to-the-vocabulary-size-and-embedding-dimension">How does the embedding layer’s weight matrix relate to the vocabulary size and embedding dimension?</h3>

<p>My answer:</p>

<p>The shape of the embedding layer’s weight matrix is defined as <code class="language-plaintext highlighter-rouge">{vocabulary_size, embedding_dimension}</code>.</p>

<p>Workbook answer:</p>

<p><em>The embedding layer’s weight matrix has a number of rows equal to the
vocabulary size, representing each unique token. The number of columns
corresponds to the embedding dimension, which determines the size of
the embedding vector for each token.</em></p>

<p>Correct.</p>

<h3 id="match-the-term-on-the-left-to-its-description-on-the-right-4">Match the term on the left to its description on the right:</h3>

<table>
  <thead>
    <tr>
      <th>Term</th>
      <th> </th>
      <th>Definition</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Context Size</td>
      <td> </td>
      <td>number of positions the input window is shifted when creating the next batch of input–target pairs.</td>
    </tr>
    <tr>
      <td>Input–Target Pairs</td>
      <td> </td>
      <td>a technique used to create input–target pairs from a text dataset by moving a window of tokens across the text.</td>
    </tr>
    <tr>
      <td>Sliding Window</td>
      <td> </td>
      <td>number of tokens that the LLM uses as input to predict the next word.</td>
    </tr>
    <tr>
      <td>Stride</td>
      <td> </td>
      <td>a set of data used to train an LLM, where the input is a sequence of tokens and the target is the next token in the sequence.</td>
    </tr>
  </tbody>
</table>

<p>My answers:</p>

<p>Context size =&gt; row 3<br />
Input-Target Pairs =&gt; row 4<br />
Sliding Window =&gt; row 2<br />
Stride =&gt; row 1</p>

<p>Correct.</p>

<h2 id="encoding-word-positions">Encoding word positions</h2>

<h3 id="what-is-the-main-shortcoming-of-llms-in-terms-of-token-order-and-how-is-it-addressed">What is the main shortcoming of LLMs in terms of token order and how is it addressed?</h3>

<p>My answer:</p>

<p>The main shortcoming of LLMs in terms of token order is that we always map token IDs to the same embedding vector weights because token IDs are indices of tokens in the vocabulary, and we do not have information about token position in the input text. Token position in the input text is important for next-token prediction. This problem is addressed by adding a vector with token absolute positions to the embedding vector.</p>

<p>Workbook answer:</p>

<p><em>LLMs’ self-attention mechanism lacks a notion of token order. To address
this, positional embeddings are introduced, which provide information
about the position of each token within a sequence.</em></p>

<p>Wrong. I missed that position indices are first forwarded through embedding weights and only then added to the token embeddings. I also did not mention that this is what gives the self-attention mechanism a notion of token order.</p>

<h3 id="explain-the-difference-between-absolute-and-relative-positional-embeddings">Explain the difference between absolute and relative positional embeddings.</h3>

<p>My answer:</p>

<p>In absolute positioning, we use the token index position in the input text. In relative positional embeddings, we use a value that represents how far tokens are from each other.</p>

<p>Workbook answer:</p>

<p><em>Absolute positional embeddings assign a unique embedding to each
position in a sequence, indicating its exact location. Relative positional
embeddings focus on the relative distance between tokens, allowing the
model to generalize better to sequences of varying lengths.</em></p>

<p>Correct. But exact location and index position could use a better explanation, which I provided in a later answer.</p>

<h3 id="how-are-positional-embeddings-used-in-openais-gpt-models">How are positional embeddings used in OpenAI’s GPT models?</h3>

<p>My answer:</p>

<p>We add positional embeddings to token embeddings.</p>

<p>Workbook answer:</p>

<p><em>GPT models use absolute positional embeddings that are optimized during
training. These embeddings are not fixed or predefined but are learned
alongside the model’s other parameters.</em></p>

<p>Incorrect. I did not mention absolute positions, or that those embeddings are not fixed but learned, since the positional embeddings are optimized during training together with the token embeddings.</p>

<h3 id="describe-the-process-of-creating-input-embeddings-for-an-llm-using-token-embeddings-and-positional-embeddings">Describe the process of creating input embeddings for an LLM using token embeddings and positional embeddings.</h3>

<p>My answer:</p>

<p>In the case of absolute positions, we create a positional embedding vector that has, for each token, its index, which represents distance from the beginning of the input. Then we add this positional embedding vector to the token embedding vector.</p>

<p>Workbook answer:</p>

<p><em>Token embeddings are generated by mapping token IDs to vectors.
Positional embeddings are then added to these token embeddings,
resulting in input embeddings that incorporate both token identity and
positional information.</em></p>

<p>Incorrect. I forgot one step: the positional IDs are also passed through an embedding layer to obtain positional embeddings. Only then do we add the token embedding and positional embedding vectors.</p>
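
<p>A minimal Nx sketch of that missing step (the shapes are illustrative: GPT-2’s vocabulary size of 50,257, a context length of 4, an embedding dimension of 3, and made-up token IDs):</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>key = Nx.Random.key(42)

{token_embedding_weights, key} = Nx.Random.normal(key, 0.0, 1.0, shape: {50_257, 3})
{pos_embedding_weights, _key} = Nx.Random.normal(key, 0.0, 1.0, shape: {4, 3})

token_ids = Nx.tensor([40, 367, 2885, 1464])
token_embeddings = Nx.take(token_embedding_weights, token_ids)

# Positional IDs are simply 0..max_length-1, looked up in their own embedding table.
pos_embeddings = Nx.take(pos_embedding_weights, Nx.iota({4}))

# Input embeddings: element-wise sum, shape {4, 3}.
input_embeddings = Nx.add(token_embeddings, pos_embeddings)
</code></pre></div></div>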

<h3 id="what-is-the-purpose-of-the-token_embedding_layer-and-pos_embedding_layer-in-the-code-provided">What is the purpose of the token_embedding_layer and pos_embedding_layer in the code provided?</h3>

<p>My answer:</p>

<p>These are token embedding and positional embedding layers.</p>

<p>Workbook answer:</p>

<p><em>The token_embedding_layer converts token IDs into embedding vectors,
while the pos_embedding_layer generates positional embeddings based
on the position of each token in the sequence.</em></p>

<p>Incorrect. I did not mention the conversion process.</p>

<h2 id="conclusion">Conclusion</h2>

<p>That is it for the workbook questions. Open-ended questions are still hard for me, but I am satisfied with how I generally answered them in this chapter.</p>

<h2 id="references">References</h2>

<ul>
  <li><a href="https://www.manning.com/books/build-a-large-language-model-from-scratch">Build a LLM from Scratch</a></li>
</ul>]]></content><author><name></name></author><category term="llm-from-scratch" /><category term="ai" /><category term="tutorials" /><category term="llm" /><category term="nlp" /><category term="workbook" /><category term="text-data" /><summary type="html"><![CDATA[Workbook answers and notes for Chapter 2 of Build a LLM from Scratch: Working with Text Data.]]></summary></entry><entry><title type="html">Build LLM from Scratch, Chapter 2 — Working with Text Data</title><link href="https://karlosmid.com/2026/01/build-llm-from-scratch-chapter-2-working-with-text-data/" rel="alternate" type="text/html" title="Build LLM from Scratch, Chapter 2 — Working with Text Data" /><published>2026-01-27T09:00:00+00:00</published><updated>2026-01-27T09:00:00+00:00</updated><id>https://karlosmid.com/2026/01/build-llm-from-scratch-chapter-2-working-with-text-data</id><content type="html" xml:base="https://karlosmid.com/2026/01/build-llm-from-scratch-chapter-2-working-with-text-data/"><![CDATA[<h3 id="tldr">TL;DR</h3>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>raw text
  -&gt; tokens
  -&gt; token IDs
  -&gt; sliding windows (input/target pairs)
  -&gt; token embeddings
  -&gt; positional embeddings
  -&gt; model-ready tensors
</code></pre></div></div>

<p>Now that we have <a href="/2026/01/chapter-1-recap-build-llm-from-scratch/">basic knowledge about LLMs</a>, it’s time to get our hands dirty. Sebastian’s book uses Python, but I’m taking a different route: Elixir. The reason is simple: Elixir has solid machine-learning support, and I want to dig into those libraries. All Elixir code is in my <a href="https://github.com/karlosmid/llm-from-scratch">git repo</a>.</p>

<h3 id="embeddings">Embeddings</h3>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>human-readable text
  -&gt; tokens
  -&gt; numbers (IDs / vectors)
  -&gt; tensors
  -&gt; neural network computations
</code></pre></div></div>

<p>We use mathematical operations to train and implement neural networks. Those operations work on numbers, so human-readable text is not compatible input. We split input text (text that humans can read and understand) into tokens (for example, words and punctuation), and then represent tokens with continuous-valued vectors.
Sebastian directs us to Appendix A for a brief introduction to vectors and tensors. We need to understand the math if we want to understand LLMs under the hood. If you hate math, don’t despair: the application of this math is a chatbot, a very exciting technology.</p>

<h4 id="vectors-and-tensors-in-elixir">Vectors and Tensors in Elixir</h4>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>rank 0 -&gt; scalar
rank 1 -&gt; vector
rank 2 -&gt; matrix
rank 3+ -&gt; tensor

Nx.rank / Nx.shape -&gt; inspect dimensions
</code></pre></div></div>

<p>A tensor is a mathematical object that groups numbers. It has a rank that represents the number of dimensions. In Elixir we have the Nx library, which supports tensors. Here is a tensor of rank 0:</p>
<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">iex</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span><span class="o">&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="c1">#Nx.Tensor&lt;</span>
  <span class="n">s32</span>
  <span class="mi">1</span>
<span class="o">&gt;</span>
<span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">rank</span>
<span class="mi">0</span>
</code></pre></div></div>
<p>Yes, we all know that tensor: it is a regular number. Nx handles number types. s32 stands for a signed 32-bit integer (int32), which means it occupies 32 bits in memory. Remember that next time you want to run an LLM on your machine. Every model has its size in its name. For example, 7B stands for 7 billion parameters. If each parameter occupies 32 bits (4 bytes), we need around 28 GB of memory (about 26 GiB).</p>
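<p>Here is a back-of-the-envelope check of that number in Elixir (a small sketch of the arithmetic, not code from the book):</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Rough memory estimate for the weights of a 7B-parameter model at 32 bits each.
params = 7.0e9
bytes = params * 32 / 8
gb = bytes / 1.0e9                 # decimal gigabytes
gib = bytes / :math.pow(2, 30)     # binary gibibytes
IO.puts("#{Float.round(gb, 1)} GB (~#{Float.round(gib, 1)} GiB)")
# prints: 28.0 GB (~26.1 GiB)
</code></pre></div></div>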

<p>Moving forward, a one-dimensional tensor is a vector (list), and a two-dimensional tensor is a matrix (list of lists). A tensor of three dimensions is, well, just a tensor of three dimensions — or nested lists.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">iex</span><span class="p">(</span><span class="mi">12</span><span class="p">)</span><span class="o">&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">3</span><span class="p">])</span>
<span class="c1">#Nx.Tensor&lt;</span>
  <span class="n">s32</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span>
  <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">]</span>
<span class="o">&gt;</span>

<span class="n">iex</span><span class="p">(</span><span class="mi">13</span><span class="p">)</span><span class="o">&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">],</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span><span class="mi">4</span><span class="p">],</span> <span class="p">[</span><span class="mi">5</span><span class="p">,</span><span class="mi">6</span><span class="p">]])</span>
<span class="c1">#Nx.Tensor&lt;</span>
  <span class="n">s32</span><span class="p">[</span><span class="mi">3</span><span class="p">][</span><span class="mi">2</span><span class="p">]</span>
  <span class="p">[</span>
    <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span>
    <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span>
    <span class="p">[</span><span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">]</span>
  <span class="p">]</span>
<span class="o">&gt;</span>

<span class="c1">#Nx.Tensor&lt;</span>
  <span class="n">s32</span><span class="p">[</span><span class="mi">3</span><span class="p">][</span><span class="mi">1</span><span class="p">][</span><span class="mi">2</span><span class="p">]</span>
  <span class="p">[</span>
    <span class="p">[</span>
      <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span>
    <span class="p">],</span>
    <span class="p">[</span>
      <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span>
    <span class="p">],</span>
    <span class="p">[</span>
      <span class="p">[</span><span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">]</span>
    <span class="p">]</span>
  <span class="p">]</span>
<span class="o">&gt;</span>
</code></pre></div></div>
<p>Tensors have some typical operations. We have already seen Nx.rank. We also have Nx.shape:</p>
<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[[</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">]],</span> <span class="p">[[</span><span class="mi">3</span><span class="p">,</span><span class="mi">4</span><span class="p">]],</span> <span class="p">[[</span><span class="mi">5</span><span class="p">,</span><span class="mi">6</span><span class="p">]]])</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span>
<span class="p">{</span><span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">}</span>
</code></pre></div></div>

<p>The outer list has 3 elements, each nested list has one element, and each innermost list has two elements, hence the shape {3, 1, 2}.</p>

<p>Nx.reshape is another useful function that changes the shape of a tensor. This will be very useful later in this book.</p>
<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">],</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span><span class="mi">4</span><span class="p">],</span> <span class="p">[</span><span class="mi">5</span><span class="p">,</span><span class="mi">6</span><span class="p">]])</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">reshape</span><span class="p">({</span><span class="mi">2</span><span class="p">,</span><span class="mi">3</span><span class="p">})</span>
<span class="c1">#Nx.Tensor&lt;</span>
  <span class="n">s32</span><span class="p">[</span><span class="mi">2</span><span class="p">][</span><span class="mi">3</span><span class="p">]</span>
  <span class="p">[</span>
    <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span>
    <span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">]</span>
  <span class="p">]</span>
<span class="o">&gt;</span>
</code></pre></div></div>

<p>PyTorch also has the <code class="language-plaintext highlighter-rouge">view</code> function, which behaves like <code class="language-plaintext highlighter-rouge">reshape</code> but works only on contiguous tensors and never copies data. Nx does not have a <code class="language-plaintext highlighter-rouge">view</code> function. The next useful function is transpose, which flips a matrix across its diagonal:</p>
<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span> <span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span><span class="mi">6</span><span class="p">],</span> <span class="p">[</span><span class="mi">7</span><span class="p">,</span><span class="mi">8</span><span class="p">,</span><span class="mi">9</span><span class="p">]])</span>
<span class="c1">#Nx.Tensor&lt;</span>
  <span class="n">s32</span><span class="p">[</span><span class="mi">3</span><span class="p">][</span><span class="mi">3</span><span class="p">]</span>
  <span class="p">[</span>
    <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span>
    <span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">],</span>
    <span class="p">[</span><span class="mi">7</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">9</span><span class="p">]</span>
  <span class="p">]</span>
<span class="o">&gt;</span>
<span class="n">iex</span><span class="p">(</span><span class="mi">23</span><span class="p">)</span><span class="o">&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span> <span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span><span class="mi">6</span><span class="p">],</span> <span class="p">[</span><span class="mi">7</span><span class="p">,</span><span class="mi">8</span><span class="p">,</span><span class="mi">9</span><span class="p">]])</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">transpose</span>
<span class="c1">#Nx.Tensor&lt;</span>
  <span class="n">s32</span><span class="p">[</span><span class="mi">3</span><span class="p">][</span><span class="mi">3</span><span class="p">]</span>
  <span class="p">[</span>
    <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">7</span><span class="p">],</span>
    <span class="p">[</span><span class="mi">2</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">8</span><span class="p">],</span>
    <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">9</span><span class="p">]</span>
  <span class="p">]</span>
<span class="o">&gt;</span>
</code></pre></div></div>

<p>We also want to multiply tensors. Note that Nx.multiply is element-wise (with broadcasting); matrix multiplication is a different operation, Nx.dot:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># rank 0 </span>
<span class="n">iex</span><span class="p">(</span><span class="mi">33</span><span class="p">)</span><span class="o">&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">multiply</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">)</span>
<span class="c1">#Nx.Tensor&lt;</span>
  <span class="n">s32</span>
  <span class="mi">2</span>
<span class="o">&gt;</span>
<span class="c1"># rank 1</span>
<span class="n">iex</span><span class="p">(</span><span class="mi">34</span><span class="p">)</span><span class="o">&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">multiply</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">3</span><span class="p">]),</span><span class="mi">2</span><span class="p">)</span>
<span class="c1">#Nx.Tensor&lt;</span>
  <span class="n">s32</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span>
  <span class="p">[</span><span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">6</span><span class="p">]</span>
<span class="o">&gt;</span>
<span class="c1"># rank 2</span>
<span class="n">iex</span><span class="p">(</span><span class="mi">35</span><span class="p">)</span><span class="o">&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">multiply</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">3</span><span class="p">]),</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">3</span><span class="p">]))</span>
<span class="c1">#Nx.Tensor&lt;</span>
  <span class="n">s32</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span>
  <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">9</span><span class="p">]</span>
<span class="o">&gt;</span>
</code></pre></div></div>
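<p>Since later chapters lean heavily on matrix multiplication, here is a small contrast between the two operations (my own sketch, not an example from the book):</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>a = Nx.tensor([[1, 2], [3, 4]])
b = Nx.tensor([[5, 6], [7, 8]])

# Element-wise product: multiplies matching entries.
Nx.multiply(a, b)
# [[5, 12], [21, 32]]

# Matrix product: rows of a combined with columns of b.
Nx.dot(a, b)
# [[19, 22], [43, 50]]
</code></pre></div></div>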

<p>This detour about tensors ends here, and we survived. Elixir has a tensor implementation in the Nx library!</p>

<p>Embedding is the concept where we transform input data (human-readable text) into vector form. We can also embed video, audio, and images, so embedding is not only for text. Each data type has its own embedding model, so a text embedding model cannot be used for videos. A unit in text processing could be a word, sentence, or paragraph. Now I know what RAG is: retrieval-augmented generation. The model combines text generation with retrieval of external knowledge sources (the internet or databases).</p>

<p>An LLM is not the only way to create word embeddings. One example is the Word2Vec algorithm, which either predicts the surrounding context words from a target word (skip-gram) or the target word from its context (CBOW). The idea is that words appearing in similar contexts end up close to each other in the embedding space. What does “close” mean? For example, if a vector has 2 dimensions, words are represented as points in an x-y coordinate system, and nearby points are similar words.</p>

<p>Below is a tiny 2D coordinate sketch that shows “closeness” in the embedding space:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>EMBEDDINGS: 2D COORDINATE SKETCH

y
^
|             (king)
|               *
|      (queen) *      *
|                    (prince)
|
|   (cat)  *           (dog) *
|
+----------------------------------&gt; x
</code></pre></div></div>
<p>Word2Vec has more than 2 dimensions, up to 1000. LLMs have more dimensions; for example, GPT-3 has 12,288 dimensions. The number of dimensions is a trade-off between performance and efficiency.</p>
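<p>In practice, “closeness” is usually measured with cosine similarity. Here is a minimal Nx sketch (my own, not from the book), using made-up 2D vectors:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Cosine similarity between two made-up 2D embedding vectors.
cat = Nx.tensor([2.0, 3.0])
dog = Nx.tensor([2.5, 2.8])

norm = fn t -&gt; t |&gt; Nx.multiply(t) |&gt; Nx.sum() |&gt; Nx.sqrt() end
cosine = Nx.divide(Nx.dot(cat, dog), Nx.multiply(norm.(cat), norm.(dog)))
# A value close to 1.0 means the vectors point in almost the same direction.
</code></pre></div></div>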

<h3 id="give-me-some-tokens">Give me some tokens!</h3>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>The Verdict (raw text)
  -&gt; split text
  -&gt; keep useful punctuation
  -&gt; token list
</code></pre></div></div>

<p>We will tokenize the short story The Verdict by Edith Wharton.</p>

<p><em>The narrator visits the Riviera home of Jack Gisburn, a once‑celebrated painter who abruptly stopped working after marrying a wealthy widow. Looking at Gisburn’s last portrait, the narrator recognizes it as empty showmanship and realizes the artist’s reputation rested more on social fashion than genuine talent. The story ends with the narrator quietly keeping this “verdict” to himself, leaving Gisburn’s public acclaim undisturbed.</em></p>

<p>Here is how you can work with files and URLs in Elixir. We use the proven Req library, created by Wojtek Mach, for HTTP requests; the File module is part of Elixir’s standard library.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">test</span> <span class="s2">"the-verdict.txt character count and exact line 99 content"</span> <span class="k">do</span>
    <span class="n">url</span> <span class="o">=</span>
      <span class="s2">"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt"</span>

    <span class="n">filename</span> <span class="o">=</span> <span class="s2">"the-verdict.txt"</span>
    <span class="p">%</span><span class="no">Req</span><span class="o">.</span><span class="no">Response</span><span class="p">{</span><span class="ss">status:</span> <span class="mi">200</span><span class="p">,</span> <span class="ss">body:</span> <span class="n">body</span><span class="p">}</span> <span class="o">=</span> <span class="no">Req</span><span class="o">.</span><span class="n">get!</span><span class="p">(</span><span class="n">url</span><span class="p">)</span>
    <span class="no">File</span><span class="o">.</span><span class="n">write!</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="n">body</span><span class="p">)</span>
    <span class="c1"># Read the file content as a single string</span>
    
    <span class="p">{</span><span class="ss">:ok</span><span class="p">,</span> <span class="n">file_content</span><span class="p">}</span> <span class="o">=</span> <span class="no">File</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="n">filename</span><span class="p">)</span>

    <span class="c1"># Assert on the number of characters in the file</span>
    <span class="n">char_count</span> <span class="o">=</span> <span class="no">String</span><span class="o">.</span><span class="n">length</span><span class="p">(</span><span class="n">file_content</span><span class="p">)</span>
    <span class="n">assert</span> <span class="n">char_count</span> <span class="o">==</span> <span class="mi">20_479</span>
    <span class="n">first_99_chars</span> <span class="o">=</span> <span class="no">String</span><span class="o">.</span><span class="n">slice</span><span class="p">(</span><span class="n">file_content</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">99</span><span class="p">)</span>

    <span class="c1"># Assert on content of line 99</span>
    <span class="n">assert</span> <span class="n">first_99_chars</span> <span class="o">==</span>
             <span class="s2">"I HAD always thought Jack Gisburn rather a cheap genius--though a good fellow enough--so it was no "</span>
  <span class="k">end</span>
</code></pre></div></div>

<h4 id="regular-expressions">Regular expressions</h4>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Regex.split(...)
  -&gt; split on whitespace / punctuation
  -&gt; optionally keep separators
  -&gt; trim empties
  -&gt; tokenizer output pieces
</code></pre></div></div>

<p>For educational purposes, we will implement the first version of the tokenizer using regular expressions:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>  <span class="n">test</span> <span class="s2">"split text on whitespace, keep the whitespace"</span> <span class="k">do</span>
    <span class="n">text</span> <span class="o">=</span> <span class="s2">"Hello, world. This, is a test."</span>
    <span class="n">result</span> <span class="o">=</span> <span class="no">Regex</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="sr">~r{\s}</span><span class="p">,</span> <span class="n">text</span><span class="p">,</span> <span class="ss">include_captures:</span> <span class="no">true</span><span class="p">,</span> <span class="ss">trim:</span> <span class="no">false</span><span class="p">)</span>
    <span class="n">assert</span> <span class="n">result</span> <span class="o">==</span> <span class="p">[</span><span class="s2">"Hello,"</span><span class="p">,</span> <span class="s2">" "</span><span class="p">,</span> <span class="s2">"world."</span><span class="p">,</span> <span class="s2">" "</span><span class="p">,</span> <span class="s2">"This,"</span><span class="p">,</span> <span class="s2">" "</span><span class="p">,</span> <span class="s2">"is"</span><span class="p">,</span> <span class="s2">" "</span><span class="p">,</span> <span class="s2">"a"</span><span class="p">,</span> <span class="s2">" "</span><span class="p">,</span> <span class="s2">"test."</span><span class="p">]</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>The Regex module is included in Elixir. <code class="language-plaintext highlighter-rouge">\s</code> means we split the input on whitespace; <code class="language-plaintext highlighter-rouge">include_captures</code> determines whether we keep what we split on (in this example, the whitespace); and <code class="language-plaintext highlighter-rouge">trim: false</code> means empty strings produced by the split are kept in the result (<code class="language-plaintext highlighter-rouge">trim: true</code> drops them). The result is the token list in the assertion above.</p>
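<p>To make the <code class="language-plaintext highlighter-rouge">trim</code> option concrete, here is a small hypothetical iex session (input of my own choosing, not from the book):</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>iex&gt; Regex.split(~r{[,.]|\s}, "Hi, there.", include_captures: true, trim: false)
["Hi", ",", "", " ", "there", ".", ""]

iex&gt; Regex.split(~r{[,.]|\s}, "Hi, there.", include_captures: true, trim: true)
["Hi", ",", " ", "there", "."]
</code></pre></div></div>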

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>  <span class="n">test</span> <span class="s2">"split text on whitespace, commas, and periods, keep them except whitespace"</span> <span class="k">do</span>
    <span class="n">text</span> <span class="o">=</span> <span class="s2">"Hello, world. This, is a test."</span>
    <span class="n">result</span> <span class="o">=</span> <span class="no">Regex</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="sr">~r{[,.]|\s}</span><span class="p">,</span> <span class="n">text</span><span class="p">,</span> <span class="ss">include_captures:</span> <span class="no">true</span><span class="p">,</span> <span class="ss">trim:</span> <span class="no">true</span><span class="p">)</span>

    <span class="n">assert</span> <span class="n">result</span> <span class="o">==</span> <span class="p">[</span>
             <span class="s2">"Hello"</span><span class="p">,</span>
             <span class="s2">","</span><span class="p">,</span>
             <span class="s2">" "</span><span class="p">,</span>
             <span class="s2">"world"</span><span class="p">,</span>
             <span class="s2">"."</span><span class="p">,</span>
             <span class="s2">" "</span><span class="p">,</span>
             <span class="s2">"This"</span><span class="p">,</span>
             <span class="s2">","</span><span class="p">,</span>
             <span class="s2">" "</span><span class="p">,</span>
             <span class="s2">"is"</span><span class="p">,</span>
             <span class="s2">" "</span><span class="p">,</span>
             <span class="s2">"a"</span><span class="p">,</span>
             <span class="s2">" "</span><span class="p">,</span>
             <span class="s2">"test"</span><span class="p">,</span>
             <span class="s2">"."</span>
           <span class="p">]</span>

    <span class="c1"># Remove whitespaces from the result</span>
    <span class="n">result_no_whitespace</span> <span class="o">=</span> <span class="no">Enum</span><span class="o">.</span><span class="n">reject</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="k">fn</span> <span class="n">s</span> <span class="o">-&gt;</span> <span class="n">s</span> <span class="o">==</span> <span class="s2">" "</span> <span class="k">end</span><span class="p">)</span>

    <span class="n">assert</span> <span class="n">result_no_whitespace</span> <span class="o">==</span> <span class="p">[</span>
             <span class="s2">"Hello"</span><span class="p">,</span>
             <span class="s2">","</span><span class="p">,</span>
             <span class="s2">"world"</span><span class="p">,</span>
             <span class="s2">"."</span><span class="p">,</span>
             <span class="s2">"This"</span><span class="p">,</span>
             <span class="s2">","</span><span class="p">,</span>
             <span class="s2">"is"</span><span class="p">,</span>
             <span class="s2">"a"</span><span class="p">,</span>
             <span class="s2">"test"</span><span class="p">,</span>
             <span class="s2">"."</span>
           <span class="p">]</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>Here we add the comma and period to the split characters, keep them via <code class="language-plaintext highlighter-rouge">include_captures</code>, and use <code class="language-plaintext highlighter-rouge">trim: true</code> to drop the empty strings produced by the split.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>  <span class="n">test</span> <span class="s2">"split text on punctuation, keep them except whitespace"</span> <span class="k">do</span>
    <span class="n">text</span> <span class="o">=</span> <span class="s2">"Hello, world. Is this-- a test?"</span>
    <span class="n">result</span> <span class="o">=</span> <span class="no">Regex</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="sr">~r{[,.:;?_!"()\']|--|\s}</span><span class="p">,</span> <span class="n">text</span><span class="p">,</span> <span class="ss">include_captures:</span> <span class="no">true</span><span class="p">,</span> <span class="ss">trim:</span> <span class="no">true</span><span class="p">)</span>
    <span class="c1"># Remove whitespaces from the result</span>
    <span class="n">result_no_whitespace</span> <span class="o">=</span> <span class="no">Enum</span><span class="o">.</span><span class="n">reject</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="k">fn</span> <span class="n">s</span> <span class="o">-&gt;</span> <span class="n">s</span> <span class="o">==</span> <span class="s2">" "</span> <span class="k">end</span><span class="p">)</span>

    <span class="n">assert</span> <span class="n">result_no_whitespace</span> <span class="o">==</span> <span class="p">[</span>
             <span class="s2">"Hello"</span><span class="p">,</span>
             <span class="s2">","</span><span class="p">,</span>
             <span class="s2">"world"</span><span class="p">,</span>
             <span class="s2">"."</span><span class="p">,</span>
             <span class="s2">"Is"</span><span class="p">,</span>
             <span class="s2">"this"</span><span class="p">,</span>
             <span class="s2">"--"</span><span class="p">,</span>
             <span class="s2">"a"</span><span class="p">,</span>
             <span class="s2">"test"</span><span class="p">,</span>
             <span class="s2">"?"</span>
           <span class="p">]</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>We extend the list of split characters to match what is used in The Verdict, and we remove whitespace from the result.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>  <span class="n">test</span> <span class="s2">"predprocess the verdict.txt file"</span> <span class="k">do</span>
    <span class="n">filename</span> <span class="o">=</span> <span class="s2">"the-verdict.txt"</span>
    <span class="p">{</span><span class="ss">:ok</span><span class="p">,</span> <span class="n">file_content</span><span class="p">}</span> <span class="o">=</span> <span class="no">File</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="n">filename</span><span class="p">)</span>

    <span class="n">result</span> <span class="o">=</span>
      <span class="no">Regex</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="sr">~r/([,.:;?_!"()\']|--|\s)/</span><span class="p">,</span> <span class="n">file_content</span><span class="p">,</span> <span class="ss">include_captures:</span> <span class="no">true</span><span class="p">,</span> <span class="ss">trim:</span> <span class="no">true</span><span class="p">)</span>

    <span class="c1"># Remove whitespaces from the result - match Python's behavior: strip and filter empty</span>
    <span class="n">predprocessed_text</span> <span class="o">=</span>
      <span class="n">result</span>
      <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="o">&amp;</span><span class="no">String</span><span class="o">.</span><span class="n">trim</span><span class="o">/</span><span class="mi">1</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">reject</span><span class="p">(</span><span class="o">&amp;</span><span class="p">(</span><span class="nv">&amp;1</span> <span class="o">==</span> <span class="s2">""</span><span class="p">))</span>

    <span class="n">assert</span> <span class="no">Enum</span><span class="o">.</span><span class="n">slice</span><span class="p">(</span><span class="n">predprocessed_text</span><span class="p">,</span> <span class="mi">0</span><span class="o">..</span><span class="mi">29</span><span class="p">)</span> <span class="o">==</span> <span class="p">[</span>
             <span class="s2">"I"</span><span class="p">,</span>
             <span class="s2">"HAD"</span><span class="p">,</span>
             <span class="s2">"always"</span><span class="p">,</span>
             <span class="s2">"thought"</span><span class="p">,</span>
             <span class="s2">"Jack"</span><span class="p">,</span>
             <span class="s2">"Gisburn"</span><span class="p">,</span>
             <span class="s2">"rather"</span><span class="p">,</span>
             <span class="s2">"a"</span><span class="p">,</span>
             <span class="s2">"cheap"</span><span class="p">,</span>
             <span class="s2">"genius"</span><span class="p">,</span>
             <span class="s2">"--"</span><span class="p">,</span>
             <span class="s2">"though"</span><span class="p">,</span>
             <span class="s2">"a"</span><span class="p">,</span>
             <span class="s2">"good"</span><span class="p">,</span>
             <span class="s2">"fellow"</span><span class="p">,</span>
             <span class="s2">"enough"</span><span class="p">,</span>
             <span class="s2">"--"</span><span class="p">,</span>
             <span class="s2">"so"</span><span class="p">,</span>
             <span class="s2">"it"</span><span class="p">,</span>
             <span class="s2">"was"</span><span class="p">,</span>
             <span class="s2">"no"</span><span class="p">,</span>
             <span class="s2">"great"</span><span class="p">,</span>
             <span class="s2">"surprise"</span><span class="p">,</span>
             <span class="s2">"to"</span><span class="p">,</span>
             <span class="s2">"me"</span><span class="p">,</span>
             <span class="s2">"to"</span><span class="p">,</span>
             <span class="s2">"hear"</span><span class="p">,</span>
             <span class="s2">"that"</span><span class="p">,</span>
             <span class="s2">","</span><span class="p">,</span>
             <span class="s2">"in"</span>
           <span class="p">]</span>

    <span class="n">assert</span> <span class="n">length</span><span class="p">(</span><span class="n">predprocessed_text</span><span class="p">)</span> <span class="o">==</span> <span class="mi">4690</span>

    <span class="c1"># Use Pythonx to read the file</span>
    <span class="p">{</span><span class="n">result_obj</span><span class="p">,</span> <span class="n">_globals</span><span class="p">}</span> <span class="o">=</span>
      <span class="no">Pythonx</span><span class="o">.</span><span class="n">eval</span><span class="p">(</span>
        <span class="sd">"""
        import re
        with open("the-verdict.txt", "r", encoding="utf-8") as f:
            raw_text = f.read()
        pattern = r'''([,.:;?_!"()']|--|\s)'''
        preprocessed = re.split(pattern, raw_text)
        preprocessed = [item.strip() for item in preprocessed if item.strip()]
        preprocessed
        """</span><span class="p">,</span>
        <span class="p">%{}</span>
      <span class="p">)</span>

    <span class="n">preprocessed_text_python</span> <span class="o">=</span> <span class="no">Pythonx</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">result_obj</span><span class="p">)</span>
    <span class="n">assert</span> <span class="n">preprocessed_text_python</span> <span class="o">==</span> <span class="n">predprocessed_text</span>
 
    <span class="c1"># Verify the file content starts correctly</span>
    <span class="n">assert</span> <span class="no">String</span><span class="o">.</span><span class="n">starts_with?</span><span class="p">(</span>
             <span class="n">file_content</span><span class="p">,</span>
             <span class="s2">"I HAD always thought Jack Gisburn rather a cheap genius"</span>
           <span class="p">)</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>We preprocessed The Verdict file: after the split we trim each piece and remove the empty strings. To compare the result with Sebastian’s Python implementation, we run his Python code through the Pythonx Elixir library.</p>

<h3 id="give-me-some-ids">Give me some IDs</h3>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tokens
  -&gt; sort + deduplicate
  -&gt; vocabulary
  -&gt; token &lt;-&gt; integer ID mapping
  -&gt; encode / decode
</code></pre></div></div>

<p>The idea is very simple. From the tokens, we create a vocabulary: we remove duplicate tokens and sort them, and that sorted, deduplicated list is the vocabulary. Then each token gets an integer ID.</p>
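<p>Here is the same idea on a toy token list (my own sketch, not code from the book or the repo):</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tokens = ["the", "cat", "sat", "on", "the", "mat"]

vocab =
  tokens
  |&gt; MapSet.new()       # remove duplicates
  |&gt; Enum.sort()        # deterministic order
  |&gt; Enum.with_index()  # assign integer IDs

# [{"cat", 0}, {"mat", 1}, {"on", 2}, {"sat", 3}, {"the", 4}]
</code></pre></div></div>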

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">test</span> <span class="s2">"token IDs"</span> <span class="k">do</span>
    <span class="n">filename</span> <span class="o">=</span> <span class="s2">"the-verdict.txt"</span>
    <span class="p">{</span><span class="ss">:ok</span><span class="p">,</span> <span class="n">file_content</span><span class="p">}</span> <span class="o">=</span> <span class="no">File</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="n">filename</span><span class="p">)</span>

    <span class="n">result</span> <span class="o">=</span>
      <span class="no">Regex</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="sr">~r/([,.:;?_!"()\']|--|\s)/</span><span class="p">,</span> <span class="n">file_content</span><span class="p">,</span> <span class="ss">include_captures:</span> <span class="no">true</span><span class="p">,</span> <span class="ss">trim:</span> <span class="no">true</span><span class="p">)</span>

    <span class="c1"># Remove whitespaces from the result - match Python's behavior: strip and filter empty</span>
    <span class="n">predprocessed_text</span> <span class="o">=</span>
      <span class="n">result</span>
      <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="o">&amp;</span><span class="no">String</span><span class="o">.</span><span class="n">trim</span><span class="o">/</span><span class="mi">1</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">reject</span><span class="p">(</span><span class="o">&amp;</span><span class="p">(</span><span class="nv">&amp;1</span> <span class="o">==</span> <span class="s2">""</span><span class="p">))</span>

    <span class="n">all_words</span> <span class="o">=</span> <span class="no">Enum</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="no">MapSet</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="n">predprocessed_text</span><span class="p">))</span>
    <span class="n">vocab_size</span> <span class="o">=</span> <span class="n">length</span><span class="p">(</span><span class="n">all_words</span><span class="p">)</span>
    <span class="n">assert</span> <span class="n">vocab_size</span> <span class="o">==</span> <span class="mi">1130</span>

    <span class="n">vocab</span> <span class="o">=</span>
      <span class="no">Enum</span><span class="o">.</span><span class="n">with_index</span><span class="p">(</span><span class="n">all_words</span><span class="p">)</span>

    <span class="n">assert</span> <span class="no">Enum</span><span class="o">.</span><span class="n">slice</span><span class="p">(</span><span class="n">vocab</span><span class="p">,</span> <span class="mi">0</span><span class="o">..</span><span class="mi">50</span><span class="p">)</span> <span class="o">==</span> <span class="p">[</span>
             <span class="p">{</span><span class="s2">"!"</span><span class="p">,</span> <span class="mi">0</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"</span><span class="se">\"</span><span class="s2">"</span><span class="p">,</span> <span class="mi">1</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"'"</span><span class="p">,</span> <span class="mi">2</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"("</span><span class="p">,</span> <span class="mi">3</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">")"</span><span class="p">,</span> <span class="mi">4</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">","</span><span class="p">,</span> <span class="mi">5</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"--"</span><span class="p">,</span> <span class="mi">6</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"."</span><span class="p">,</span> <span class="mi">7</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">":"</span><span class="p">,</span> <span class="mi">8</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">";"</span><span class="p">,</span> <span class="mi">9</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"?"</span><span class="p">,</span> <span class="mi">10</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"A"</span><span class="p">,</span> <span class="mi">11</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Ah"</span><span class="p">,</span> <span class="mi">12</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Among"</span><span class="p">,</span> <span class="mi">13</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"And"</span><span class="p">,</span> <span class="mi">14</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Are"</span><span class="p">,</span> <span class="mi">15</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Arrt"</span><span class="p">,</span> <span class="mi">16</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"As"</span><span class="p">,</span> <span class="mi">17</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"At"</span><span class="p">,</span> <span class="mi">18</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Be"</span><span class="p">,</span> <span class="mi">19</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Begin"</span><span class="p">,</span> <span class="mi">20</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Burlington"</span><span class="p">,</span> <span class="mi">21</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"But"</span><span class="p">,</span> <span class="mi">22</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"By"</span><span class="p">,</span> <span class="mi">23</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Carlo"</span><span class="p">,</span> <span class="mi">24</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Chicago"</span><span class="p">,</span> <span class="mi">25</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Claude"</span><span class="p">,</span> <span class="mi">26</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Come"</span><span class="p">,</span> <span class="mi">27</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Croft"</span><span class="p">,</span> <span class="mi">28</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Destroyed"</span><span class="p">,</span> <span class="mi">29</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Devonshire"</span><span class="p">,</span> <span class="mi">30</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Don"</span><span class="p">,</span> <span class="mi">31</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Dubarry"</span><span class="p">,</span> <span class="mi">32</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Emperors"</span><span class="p">,</span> <span class="mi">33</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Florence"</span><span class="p">,</span> <span class="mi">34</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"For"</span><span class="p">,</span> <span class="mi">35</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Gallery"</span><span class="p">,</span> <span class="mi">36</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Gideon"</span><span class="p">,</span> <span class="mi">37</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Gisburn"</span><span class="p">,</span> <span class="mi">38</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Gisburns"</span><span class="p">,</span> <span class="mi">39</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Grafton"</span><span class="p">,</span> <span class="mi">40</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Greek"</span><span class="p">,</span> <span class="mi">41</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Grindle"</span><span class="p">,</span> <span class="mi">42</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Grindles"</span><span class="p">,</span> <span class="mi">43</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"HAD"</span><span class="p">,</span> <span class="mi">44</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Had"</span><span class="p">,</span> <span class="mi">45</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Hang"</span><span class="p">,</span> <span class="mi">46</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Has"</span><span class="p">,</span> <span class="mi">47</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"He"</span><span class="p">,</span> <span class="mi">48</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Her"</span><span class="p">,</span> <span class="mi">49</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"Hermia"</span><span class="p">,</span> <span class="mi">50</span><span class="p">}</span>
           <span class="p">]</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>To remove duplicate tokens, we use the MapSet data structure.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>  <span class="n">iex</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">&gt;</span> <span class="no">MapSet</span><span class="o">.</span><span class="n">new</span><span class="p">([</span><span class="s2">"Karlo"</span><span class="p">,</span> <span class="s2">"karlo"</span><span class="p">,</span> <span class="s2">"Karlo"</span><span class="p">,</span> <span class="s2">"Smid"</span><span class="p">])</span>
<span class="no">MapSet</span><span class="o">.</span><span class="n">new</span><span class="p">([</span><span class="s2">"Karlo"</span><span class="p">,</span> <span class="s2">"Smid"</span><span class="p">,</span> <span class="s2">"karlo"</span><span class="p">])</span>
</code></pre></div></div>

<p>At the end, we list the first 51 vocabulary entries.</p>

<p>The last thing to do is a simple tokenizer. Python implements it as a class; in Elixir we use a module. It has the following functions:</p>

<table>
  <thead>
    <tr>
      <th>Function</th>
      <th>Description</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>vocab_from_file</td>
      <td>Creates a sorted token vocabulary from text file (with special tokens)</td>
    </tr>
    <tr>
      <td>encode</td>
      <td>Converts a text string into a sequence of IDs using vocabulary</td>
    </tr>
    <tr>
      <td>decode</td>
      <td>Converts a sequence of IDs back into a text string using vocabulary</td>
    </tr>
    <tr>
      <td>tokenize</td>
      <td>Splits text into tokens (words and punctuation used in The Verdict)</td>
    </tr>
  </tbody>
</table>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">defmodule</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SimpleTokenizerV1</span> <span class="k">do</span>
  <span class="nv">@moduledoc</span> <span class="sd">"""
  A simple tokenizer that tokenizes text into words.
  """</span>
  <span class="k">def</span> <span class="n">vocab_from_file</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="n">special_tokens</span> <span class="p">\\</span> <span class="p">[])</span> <span class="k">do</span>
    <span class="p">{</span><span class="ss">:ok</span><span class="p">,</span> <span class="n">file_content</span><span class="p">}</span> <span class="o">=</span> <span class="no">File</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="n">filename</span><span class="p">)</span>

    <span class="n">tokens</span> <span class="o">=</span>
      <span class="n">tokenize</span><span class="p">(</span><span class="n">file_content</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="no">MapSet</span><span class="o">.</span><span class="n">new</span><span class="p">()</span>
      <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">sort</span><span class="p">()</span>

    <span class="p">(</span><span class="n">tokens</span> <span class="o">++</span> <span class="n">special_tokens</span><span class="p">)</span>
    <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">with_index</span><span class="p">()</span>
  <span class="k">end</span>

  <span class="k">def</span> <span class="n">encode</span><span class="p">(</span><span class="n">text</span><span class="p">,</span> <span class="n">vocab</span><span class="p">)</span> <span class="k">do</span>
    <span class="n">tokenize</span><span class="p">(</span><span class="n">text</span><span class="p">)</span>
    <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">fn</span> <span class="n">token</span> <span class="o">-&gt;</span>
      <span class="no">Enum</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="n">vocab</span><span class="p">,</span> <span class="k">fn</span> <span class="p">{</span><span class="n">vocab_token</span><span class="p">,</span> <span class="n">_</span><span class="p">}</span> <span class="o">-&gt;</span> <span class="n">token</span> <span class="o">==</span> <span class="n">vocab_token</span> <span class="k">end</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="k">case</span> <span class="k">do</span>
        <span class="p">{</span><span class="n">_</span><span class="p">,</span> <span class="n">id</span><span class="p">}</span> <span class="o">-&gt;</span> <span class="n">id</span>
        <span class="no">nil</span> <span class="o">-&gt;</span> <span class="n">find_unknown_token</span><span class="p">(</span><span class="n">vocab</span><span class="p">,</span> <span class="n">token</span><span class="p">)</span>
      <span class="k">end</span>
    <span class="k">end</span><span class="p">)</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">find_unknown_token</span><span class="p">(</span><span class="n">vocab</span><span class="p">,</span> <span class="n">token</span><span class="p">)</span> <span class="k">do</span>
    <span class="no">Enum</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="n">vocab</span><span class="p">,</span> <span class="k">fn</span> <span class="p">{</span><span class="n">vocab_token</span><span class="p">,</span> <span class="n">_</span><span class="p">}</span> <span class="o">-&gt;</span> <span class="n">vocab_token</span> <span class="o">==</span> <span class="s2">"&lt;|unk|&gt;"</span> <span class="k">end</span><span class="p">)</span>
    <span class="o">|&gt;</span> <span class="k">case</span> <span class="k">do</span>
      <span class="p">{</span><span class="n">_</span><span class="p">,</span> <span class="n">id</span><span class="p">}</span> <span class="o">-&gt;</span> <span class="n">id</span>
      <span class="no">nil</span> <span class="o">-&gt;</span> <span class="k">raise</span> <span class="s2">"Token not found in vocab: </span><span class="si">#{</span><span class="n">token</span><span class="si">}</span><span class="s2">"</span>
    <span class="k">end</span>
  <span class="k">end</span>

  <span class="k">def</span> <span class="n">decode</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="n">vocab</span><span class="p">)</span> <span class="k">do</span>
    <span class="n">text</span> <span class="o">=</span>
      <span class="no">Enum</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="k">fn</span> <span class="n">id</span> <span class="o">-&gt;</span>
        <span class="no">Enum</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="n">vocab</span><span class="p">,</span> <span class="k">fn</span> <span class="p">{</span><span class="n">_</span><span class="p">,</span> <span class="n">vocab_id</span><span class="p">}</span> <span class="o">-&gt;</span> <span class="n">id</span> <span class="o">==</span> <span class="n">vocab_id</span> <span class="k">end</span><span class="p">)</span>
        <span class="o">|&gt;</span> <span class="k">case</span> <span class="k">do</span>
          <span class="p">{</span><span class="n">vocab_token</span><span class="p">,</span> <span class="n">_</span><span class="p">}</span> <span class="o">-&gt;</span> <span class="n">vocab_token</span>
          <span class="no">nil</span> <span class="o">-&gt;</span> <span class="k">raise</span> <span class="s2">"ID not found in vocab: </span><span class="si">#{</span><span class="n">id</span><span class="si">}</span><span class="s2">"</span>
        <span class="k">end</span>
      <span class="k">end</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="s2">" "</span><span class="p">)</span>

    <span class="no">Regex</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="sr">~r/\s+([,.?!"()\'])/</span><span class="p">,</span> <span class="n">text</span><span class="p">,</span> <span class="s2">"</span><span class="se">\\</span><span class="s2">1"</span><span class="p">)</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">tokenize</span><span class="p">(</span><span class="n">text</span><span class="p">)</span> <span class="k">do</span>
    <span class="no">Regex</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="sr">~r/([,.:;?_!"()\']|--|\s)/</span><span class="p">,</span> <span class="n">text</span><span class="p">,</span> <span class="ss">include_captures:</span> <span class="no">true</span><span class="p">,</span> <span class="ss">trim:</span> <span class="no">true</span><span class="p">)</span>
    <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="o">&amp;</span><span class="no">String</span><span class="o">.</span><span class="n">trim</span><span class="o">/</span><span class="mi">1</span><span class="p">)</span>
    <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">reject</span><span class="p">(</span><span class="o">&amp;</span><span class="p">(</span><span class="nv">&amp;1</span> <span class="o">==</span> <span class="s2">""</span><span class="p">))</span>
  <span class="k">end</span>
<span class="k">end</span>
</code></pre></div></div>
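<p>One design note: <code class="language-plaintext highlighter-rouge">encode</code> and <code class="language-plaintext highlighter-rouge">decode</code> scan the vocabulary list with Enum.find, which is linear per token. For a vocabulary of about 1,100 entries that is fine; for larger vocabularies you could build maps in both directions for constant-time lookups. A possible sketch (my own, not part of the repo):</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># vocab is a list of {token, id} tuples, so it converts directly into maps.
token_to_id = Map.new(vocab)
id_to_token = Map.new(vocab, fn {token, id} -&gt; {id, token} end)

Map.fetch!(token_to_id, "the")   # =&gt; the integer ID of "the"
Map.fetch!(id_to_token, 0)       # =&gt; "!" for the vocabulary built above
</code></pre></div></div>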

<p>And now we can use the simple tokenizer:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">test</span> <span class="s2">"encode and decode text with simple tokenizer"</span> <span class="k">do</span>
    <span class="n">filename</span> <span class="o">=</span> <span class="s2">"the-verdict.txt"</span>
    <span class="n">vocab</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SimpleTokenizerV1</span><span class="o">.</span><span class="n">vocab_from_file</span><span class="p">(</span><span class="n">filename</span><span class="p">)</span>

    <span class="n">text</span> <span class="o">=</span> <span class="sd">"""
    "It's the last he painted, you know," 
           Mrs. Gisburn said with pardonable pride.
    """</span>

    <span class="n">encoded_text</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SimpleTokenizerV1</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">text</span><span class="p">,</span> <span class="n">vocab</span><span class="p">)</span>

    <span class="n">assert</span> <span class="n">encoded_text</span> <span class="o">==</span> <span class="p">[</span>
             <span class="mi">1</span><span class="p">,</span>
             <span class="mi">56</span><span class="p">,</span>
             <span class="mi">2</span><span class="p">,</span>
             <span class="mi">850</span><span class="p">,</span>
             <span class="mi">988</span><span class="p">,</span>
             <span class="mi">602</span><span class="p">,</span>
             <span class="mi">533</span><span class="p">,</span>
             <span class="mi">746</span><span class="p">,</span>
             <span class="mi">5</span><span class="p">,</span>
             <span class="mi">1126</span><span class="p">,</span>
             <span class="mi">596</span><span class="p">,</span>
             <span class="mi">5</span><span class="p">,</span>
             <span class="mi">1</span><span class="p">,</span>
             <span class="mi">67</span><span class="p">,</span>
             <span class="mi">7</span><span class="p">,</span>
             <span class="mi">38</span><span class="p">,</span>
             <span class="mi">851</span><span class="p">,</span>
             <span class="mi">1108</span><span class="p">,</span>
             <span class="mi">754</span><span class="p">,</span>
             <span class="mi">793</span><span class="p">,</span>
             <span class="mi">7</span>
           <span class="p">]</span>

    <span class="n">decoded_text</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SimpleTokenizerV1</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">encoded_text</span><span class="p">,</span> <span class="n">vocab</span><span class="p">)</span>

    <span class="n">assert</span> <span class="n">decoded_text</span> <span class="o">==</span>
             <span class="s2">"</span><span class="se">\"</span><span class="s2"> It' s the last he painted, you know,</span><span class="se">\"</span><span class="s2"> Mrs. Gisburn said with pardonable pride."</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>What about unknown tokens?</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
  <span class="n">test</span> <span class="s2">"missing token in vocab raises an error"</span> <span class="k">do</span>
    <span class="n">vocab</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SimpleTokenizerV1</span><span class="o">.</span><span class="n">vocab_from_file</span><span class="p">(</span><span class="s2">"the-verdict.txt"</span><span class="p">)</span>
    <span class="n">text</span> <span class="o">=</span> <span class="s2">"Hello, do you like tea. Is this-- a test?"</span>

    <span class="n">assert_raise</span> <span class="no">RuntimeError</span><span class="p">,</span> <span class="s2">"Token not found in vocab: Hello"</span><span class="p">,</span> <span class="k">fn</span> <span class="o">-&gt;</span>
      <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SimpleTokenizerV1</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">text</span><span class="p">,</span> <span class="n">vocab</span><span class="p">)</span>
    <span class="k">end</span>
  <span class="k">end</span>
</code></pre></div></div>

<h3 id="i-am-a-special-token">I am a special token</h3>

<p>In the token world we also have special tokens. These tokens have a special purpose: they set context in the training data. For example, when we encounter a token that is not in the vocabulary, raising a RuntimeError is not a useful training practice; instead, we mark that we hit an unknown token and move on. For that purpose, we use the token &lt;|unk|&gt;. We are training our LLM on one short story, The Verdict, but in the real world we will use many books. To mark where a book ends, we use the token &lt;|endoftext|&gt; as its last token; the training algorithm then knows that a new context starts with the next book. Let’s see that in our <code class="language-plaintext highlighter-rouge">SimpleTokenizerV1</code>:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
  <span class="n">test</span> <span class="s2">"encode and decode text with special token"</span> <span class="k">do</span>
    <span class="n">vocab</span> <span class="o">=</span>
      <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SimpleTokenizerV1</span><span class="o">.</span><span class="n">vocab_from_file</span><span class="p">(</span><span class="s2">"the-verdict.txt"</span><span class="p">,</span> <span class="p">[</span><span class="s2">"&lt;|endoftext|&gt;"</span><span class="p">,</span> <span class="s2">"&lt;|unk|&gt;"</span><span class="p">])</span>

    <span class="n">assert</span> <span class="n">length</span><span class="p">(</span><span class="n">vocab</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1132</span>
    <span class="n">last_five</span> <span class="o">=</span> <span class="no">Enum</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="n">vocab</span><span class="p">,</span> <span class="o">-</span><span class="mi">5</span><span class="p">)</span>

    <span class="n">assert</span> <span class="n">last_five</span> <span class="o">==</span> <span class="p">[</span>
             <span class="p">{</span><span class="s2">"younger"</span><span class="p">,</span> <span class="mi">1127</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"your"</span><span class="p">,</span> <span class="mi">1128</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"yourself"</span><span class="p">,</span> <span class="mi">1129</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"&lt;|endoftext|&gt;"</span><span class="p">,</span> <span class="mi">1130</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">"&lt;|unk|&gt;"</span><span class="p">,</span> <span class="mi">1131</span><span class="p">}</span>
           <span class="p">]</span>

    <span class="n">text1</span> <span class="o">=</span> <span class="s2">"Hello, do you like tea?"</span>
    <span class="n">text2</span> <span class="o">=</span> <span class="s2">"In the sunlit terraces of the palace."</span>
    <span class="n">text</span> <span class="o">=</span> <span class="n">text1</span> <span class="o">&lt;&gt;</span> <span class="s2">" &lt;|endoftext|&gt; "</span> <span class="o">&lt;&gt;</span> <span class="n">text2</span>

    <span class="n">assert</span> <span class="n">text</span> <span class="o">==</span> <span class="s2">"Hello, do you like tea? &lt;|endoftext|&gt; In the sunlit terraces of the palace."</span>

    <span class="n">assert</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SimpleTokenizerV1</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">text</span><span class="p">,</span> <span class="n">vocab</span><span class="p">)</span> <span class="o">==</span> <span class="p">[</span>
             <span class="mi">1131</span><span class="p">,</span>
             <span class="mi">5</span><span class="p">,</span>
             <span class="mi">355</span><span class="p">,</span>
             <span class="mi">1126</span><span class="p">,</span>
             <span class="mi">628</span><span class="p">,</span>
             <span class="mi">975</span><span class="p">,</span>
             <span class="mi">10</span><span class="p">,</span>
             <span class="mi">1130</span><span class="p">,</span>
             <span class="mi">55</span><span class="p">,</span>
             <span class="mi">988</span><span class="p">,</span>
             <span class="mi">956</span><span class="p">,</span>
             <span class="mi">984</span><span class="p">,</span>
             <span class="mi">722</span><span class="p">,</span>
             <span class="mi">988</span><span class="p">,</span>
             <span class="mi">1131</span><span class="p">,</span>
             <span class="mi">7</span>
           <span class="p">]</span>

    <span class="n">assert</span> <span class="n">text</span>
           <span class="o">|&gt;</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SimpleTokenizerV1</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">vocab</span><span class="p">)</span>
           <span class="o">|&gt;</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">SimpleTokenizerV1</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">vocab</span><span class="p">)</span> <span class="o">==</span>
             <span class="s2">"&lt;|unk|&gt;, do you like tea? &lt;|endoftext|&gt; In the sunlit terraces of the &lt;|unk|&gt;."</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>We add &lt;|endoftext|&gt; and &lt;|unk|&gt; at the end of our vocabulary, and we concatenate two text examples with different contexts that contain unknown vocabulary tokens. At the end of the test, you see what we get after the <code class="language-plaintext highlighter-rouge">encode() |&gt; decode()</code> sequence. You can use other tokens for the same purpose, like [BOS] and [EOS] for the beginning and end of a sequence, or [PAD] for padding input texts of different lengths to the same length.</p>
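
<p>To make the [PAD] idea concrete, here is a minimal sketch (my own illustration, not from the book or our tokenizer) that pads lists of token ids to the same length using a hypothetical <code class="language-plaintext highlighter-rouge">pad_id</code>:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># A hypothetical pad_id; in a real vocab this would be the id of a [PAD] token.
pad_id = 1132

pad_to_same_length = fn sequences -&gt;
  max_len = sequences |&gt; Enum.map(&amp;length/1) |&gt; Enum.max()

  Enum.map(sequences, fn seq -&gt;
    seq ++ List.duplicate(pad_id, max_len - length(seq))
  end)
end

pad_to_same_length.([[1131, 5, 355], [55, 988]])
# =&gt; [[1131, 5, 355], [55, 988, 1132]]
</code></pre></div></div>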

<h3 id="byte-pair-encoding">Byte pair encoding</h3>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>text
  -&gt; bytes / subword pieces
  -&gt; frequent merges
  -&gt; BPE tokens
  -&gt; handles unknown words better
</code></pre></div></div>

<p>Byte pair encoding (BPE) is the real deal because it was used to tokenize the training data for LLMs like GPT-2, GPT-3, and the original models used in ChatGPT. Sebastian does not go into the details of the <a href="https://arxiv.org/abs/2306.16837">Byte pair encoding algorithm</a> because it is a complicated topic. In the byte pair encoding algorithm, text is first broken into individual symbols, like a, b, c, … Then, in repeated passes, it merges pairs of symbols that frequently appear together, like de in dentist or define. That is the basic idea; the rest you can find in the paper. Because of that, byte pair encoding can handle all words, not just words from the vocabulary. tiktoken is a popular library with a byte pair encoding implementation. There is a Rust implementation, and thanks to the rustler library, we have tiktoken for Elixir. Note how <code class="language-plaintext highlighter-rouge">someunknownPlace</code> is successfully encoded and decoded.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="n">test</span> <span class="s2">"byte pair encoding using gpt2 tiktoken"</span> <span class="k">do</span>
    <span class="n">model</span> <span class="o">=</span> <span class="s2">"code-davinci-002"</span>

    <span class="n">text</span> <span class="o">=</span> <span class="s2">"Hello, do you like tea? &lt;|endoftext|&gt; In the sunlit terraces\of someunknownPlace."</span>

    <span class="p">{</span><span class="ss">:ok</span><span class="p">,</span> <span class="n">encoded_tokens</span><span class="p">}</span> <span class="o">=</span> <span class="no">Tiktoken</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">text</span><span class="p">,</span> <span class="p">[</span><span class="s2">"&lt;|endoftext|&gt;"</span><span class="p">])</span>

    <span class="n">assert</span> <span class="n">encoded_tokens</span> <span class="o">==</span> <span class="p">[</span>
             <span class="mi">15496</span><span class="p">,</span>
             <span class="mi">11</span><span class="p">,</span>
             <span class="mi">466</span><span class="p">,</span>
             <span class="mi">345</span><span class="p">,</span>
             <span class="mi">588</span><span class="p">,</span>
             <span class="mi">8887</span><span class="p">,</span>
             <span class="mi">30</span><span class="p">,</span>
             <span class="mi">220</span><span class="p">,</span>
             <span class="mi">50256</span><span class="p">,</span>
             <span class="mi">554</span><span class="p">,</span>
             <span class="mi">262</span><span class="p">,</span>
             <span class="mi">4252</span><span class="p">,</span>
             <span class="mi">18250</span><span class="p">,</span>
             <span class="mi">8812</span><span class="p">,</span>
             <span class="mi">2114</span><span class="p">,</span>
             <span class="mi">1659</span><span class="p">,</span>
             <span class="mi">617</span><span class="p">,</span>
             <span class="mi">34680</span><span class="p">,</span>
             <span class="mi">27271</span><span class="p">,</span>
             <span class="mi">13</span>
           <span class="p">]</span>

    <span class="n">assert</span> <span class="no">Tiktoken</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">encoded_tokens</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="ss">:ok</span><span class="p">,</span> <span class="n">text</span><span class="p">}</span>
  <span class="k">end</span>
</code></pre></div></div>
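
<p>The merging itself happens inside tiktoken. To give a feeling for the idea, here is a toy sketch of a single merge step (my own illustration, not the actual GPT-2/tiktoken algorithm): count adjacent pairs of symbols, pick the most frequent pair, and merge it everywhere:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>words = [["d", "e", "n", "t", "i", "s", "t"], ["d", "e", "f", "i", "n", "e"]]

# Count every adjacent pair of symbols and pick the most frequent one.
{[a, b], _count} =
  words
  |&gt; Enum.flat_map(&amp;Enum.chunk_every(&amp;1, 2, 1, :discard))
  |&gt; Enum.frequencies()
  |&gt; Enum.max_by(fn {_pair, count} -&gt; count end)
# =&gt; {["d", "e"], 2}

# Merge the winning pair left to right in every word.
merged =
  Enum.map(words, fn word -&gt;
    word
    |&gt; Enum.reduce([], fn
      token, [prev | rest] when prev == a and token == b -&gt; [a &lt;&gt; b | rest]
      token, acc -&gt; [token | acc]
    end)
    |&gt; Enum.reverse()
  end)
# =&gt; [["de", "n", "t", "i", "s", "t"], ["de", "f", "i", "n", "e"]]
</code></pre></div></div>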

<h3 id="sliding-windows">Sliding windows</h3>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>token sequence
  -&gt; window(size=max_length, stride)
  -&gt; input chunk
  -&gt; target chunk (shifted by 1)
  -&gt; many training samples
</code></pre></div></div>

<p>OK, this section was hard for me; I did several reading passes. For training the LLM, we do not use just the embedded input text. We need embedded input/target pairs. The idea is rather simple: we use a sliding window.</p>

<p>Example with sliding window size <code class="language-plaintext highlighter-rouge">5</code> (text only, no embeddings):</p>

<p>Text:
<code class="language-plaintext highlighter-rouge">I love learning how to build small language models</code></p>

<p>Tokens:
<code class="language-plaintext highlighter-rouge">["I", "love", "learning", "how", "to", "build", "small", "language", "models"]</code></p>

<p>Training input/target pairs:</p>

<ol>
  <li>input: <code class="language-plaintext highlighter-rouge">["I", "love", "learning", "how", "to"]</code> -&gt; target: <code class="language-plaintext highlighter-rouge">"build"</code></li>
  <li>input: <code class="language-plaintext highlighter-rouge">["love", "learning", "how", "to", "build"]</code> -&gt; target: <code class="language-plaintext highlighter-rouge">"small"</code></li>
  <li>input: <code class="language-plaintext highlighter-rouge">["learning", "how", "to", "build", "small"]</code> -&gt; target: <code class="language-plaintext highlighter-rouge">"language"</code></li>
  <li>input: <code class="language-plaintext highlighter-rouge">["how", "to", "build", "small", "language"]</code> -&gt; target: <code class="language-plaintext highlighter-rouge">"models"</code></li>
</ol>
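
<p>As a quick sanity check of the idea (my own sketch, using plain strings instead of token ids), the same four pairs can be produced with <code class="language-plaintext highlighter-rouge">Enum.chunk_every/4</code>:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tokens = ~w(I love learning how to build small language models)
window_size = 5

pairs =
  tokens
  |&gt; Enum.chunk_every(window_size + 1, 1, :discard)
  |&gt; Enum.map(fn chunk -&gt; {Enum.take(chunk, window_size), List.last(chunk)} end)
# =&gt; [{["I", "love", "learning", "how", "to"], "build"}, ...]
</code></pre></div></div>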

<p>Let’s do that in Elixir using The Verdict story. We drop the first 50 tokens of The Verdict because the passage that follows them is a more interesting part of the story.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="n">test</span> <span class="s2">"data sampling with sliding window"</span> <span class="k">do</span>
    <span class="p">{</span><span class="ss">:ok</span><span class="p">,</span> <span class="n">file_content</span><span class="p">}</span> <span class="o">=</span> <span class="no">File</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="s2">"the-verdict.txt"</span><span class="p">)</span>
    <span class="n">model</span> <span class="o">=</span> <span class="s2">"code-davinci-002"</span>
    <span class="p">{</span><span class="ss">:ok</span><span class="p">,</span> <span class="n">encoded_tokens</span><span class="p">}</span> <span class="o">=</span> <span class="no">Tiktoken</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">file_content</span><span class="p">)</span>

    <span class="n">assert</span> <span class="n">length</span><span class="p">(</span><span class="n">encoded_tokens</span><span class="p">)</span> <span class="o">==</span> <span class="mi">5145</span>
    <span class="n">encoded_last_50</span> <span class="o">=</span> <span class="no">Enum</span><span class="o">.</span><span class="n">drop</span><span class="p">(</span><span class="n">encoded_tokens</span><span class="p">,</span> <span class="mi">50</span><span class="p">)</span>
    <span class="n">context_size</span> <span class="o">=</span> <span class="mi">4</span>

    <span class="n">context_desired_pairs</span> <span class="o">=</span>
      <span class="n">for</span> <span class="n">i</span> <span class="o">&lt;-</span> <span class="mi">1</span><span class="o">..</span><span class="n">context_size</span> <span class="k">do</span>
        <span class="n">context</span> <span class="o">=</span> <span class="no">Enum</span><span class="o">.</span><span class="n">slice</span><span class="p">(</span><span class="n">encoded_last_50</span><span class="p">,</span> <span class="mi">0</span><span class="o">..</span><span class="p">(</span><span class="n">i</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span>
        <span class="n">desired</span> <span class="o">=</span> <span class="no">Enum</span><span class="o">.</span><span class="n">at</span><span class="p">(</span><span class="n">encoded_last_50</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>
        <span class="p">{</span><span class="n">context</span><span class="p">,</span> <span class="n">desired</span><span class="p">}</span>
      <span class="k">end</span>

    <span class="n">assert</span> <span class="n">context_desired_pairs</span> <span class="o">==</span> <span class="p">[</span>
             <span class="p">{[</span><span class="mi">290</span><span class="p">],</span> <span class="mi">4920</span><span class="p">},</span>
             <span class="p">{[</span><span class="mi">290</span><span class="p">,</span> <span class="mi">4920</span><span class="p">],</span> <span class="mi">2241</span><span class="p">},</span>
             <span class="p">{[</span><span class="mi">290</span><span class="p">,</span> <span class="mi">4920</span><span class="p">,</span> <span class="mi">2241</span><span class="p">],</span> <span class="mi">287</span><span class="p">},</span>
             <span class="p">{[</span><span class="mi">290</span><span class="p">,</span> <span class="mi">4920</span><span class="p">,</span> <span class="mi">2241</span><span class="p">,</span> <span class="mi">287</span><span class="p">],</span> <span class="mi">257</span><span class="p">}</span>
           <span class="p">]</span>

    <span class="n">decoded_context_desired_pairs</span> <span class="o">=</span>
      <span class="no">Enum</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">context_desired_pairs</span><span class="p">,</span> <span class="k">fn</span> <span class="p">{</span><span class="n">current_context</span><span class="p">,</span> <span class="n">current_desired</span><span class="p">}</span> <span class="o">-&gt;</span>
        <span class="p">{</span><span class="ss">:ok</span><span class="p">,</span> <span class="n">decoded_current_context</span><span class="p">}</span> <span class="o">=</span> <span class="no">Tiktoken</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">current_context</span><span class="p">)</span>
        <span class="p">{</span><span class="ss">:ok</span><span class="p">,</span> <span class="n">decoded_current_desired</span><span class="p">}</span> <span class="o">=</span> <span class="no">Tiktoken</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="p">[</span><span class="n">current_desired</span><span class="p">])</span>
        <span class="p">{</span><span class="n">decoded_current_context</span><span class="p">,</span> <span class="n">decoded_current_desired</span><span class="p">}</span>
      <span class="k">end</span><span class="p">)</span>

    <span class="n">assert</span> <span class="n">decoded_context_desired_pairs</span> <span class="o">==</span> <span class="p">[</span>
             <span class="p">{</span><span class="s2">" and"</span><span class="p">,</span> <span class="s2">" established"</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">" and established"</span><span class="p">,</span> <span class="s2">" himself"</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">" and established himself"</span><span class="p">,</span> <span class="s2">" in"</span><span class="p">},</span>
             <span class="p">{</span><span class="s2">" and established himself in"</span><span class="p">,</span> <span class="s2">" a"</span><span class="p">}</span>
           <span class="p">]</span>
  <span class="k">end</span>
</code></pre></div></div>

<blockquote>
  <p><strong><em>Note:</em></strong> <code class="language-plaintext highlighter-rouge">DataSet</code> and <code class="language-plaintext highlighter-rouge">DataLoader</code> are the hardest part so far. I will do my best, but you will have to use your imagination to make these concepts easier to understand.</p>
</blockquote>

<p>Next, we need DataLoader and DataSet helpers. In the Python world we use PyTorch classes; the last time I checked, there are no such modules in any Elixir library. If you know of one, please let me know (LinkedIn is the best option). I crafted DataLoader and DataSet with LLM help:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>  <span class="k">defmodule</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">GptDatasetV1</span> <span class="k">do</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Tokenizes `txt` and builds next-token prediction windows.

  ## Parameters

  - `txt` - input text to tokenize (`String.t()`)
  - `model` - tokenizer model name passed to `Tiktoken.encode/3` (`String.t()`)
  - `max_length` - number of tokens per input/target chunk (`pos_integer()`)
  - `stride` - number of tokens to move between consecutive windows (`pos_integer()`)

  Returns keyword list with:
  - `:input_chunks` - list of input tensors shaped `{max_length}`
  - `:target_chunks` - list of target tensors shaped `{max_length}`

  Each target chunk is the corresponding input chunk shifted by one token.
  Windows are generated with step `stride`.
  """</span>
  <span class="k">def</span> <span class="n">chunk_dataset</span><span class="p">(</span><span class="n">txt</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">max_length</span><span class="p">,</span> <span class="n">stride</span><span class="p">)</span> <span class="k">do</span>
    <span class="p">{</span><span class="ss">:ok</span><span class="p">,</span> <span class="n">token_ids</span><span class="p">}</span> <span class="o">=</span> <span class="no">Tiktoken</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">txt</span><span class="p">,</span> <span class="p">[</span><span class="s2">"&lt;|endoftext|&gt;"</span><span class="p">])</span>

    <span class="k">if</span> <span class="n">length</span><span class="p">(</span><span class="n">token_ids</span><span class="p">)</span> <span class="o">&lt;</span> <span class="n">max_length</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
      <span class="k">do</span><span class="p">:</span> <span class="k">raise</span><span class="p">(</span><span class="s2">"Number of tokenized inputs must be at least max_length + 1"</span><span class="p">)</span>

    <span class="n">num_chunks</span> <span class="o">=</span> <span class="n">length</span><span class="p">(</span><span class="n">token_ids</span><span class="p">)</span> <span class="o">-</span> <span class="n">max_length</span>

    <span class="n">chunks</span> <span class="o">=</span>
      <span class="mi">0</span><span class="o">..</span><span class="n">num_chunks</span><span class="o">//</span><span class="n">stride</span>
      <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">reduce</span><span class="p">([</span><span class="ss">input_chunks:</span> <span class="p">[],</span> <span class="ss">target_chunks:</span> <span class="p">[]],</span> <span class="k">fn</span> <span class="n">i</span><span class="p">,</span> <span class="n">acc</span> <span class="o">-&gt;</span>
        <span class="n">input_chunk</span> <span class="o">=</span> <span class="no">Enum</span><span class="o">.</span><span class="n">slice</span><span class="p">(</span><span class="n">token_ids</span><span class="p">,</span> <span class="n">i</span><span class="o">..</span><span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="n">max_length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span>
        <span class="n">target_chunk</span> <span class="o">=</span> <span class="no">Enum</span><span class="o">.</span><span class="n">slice</span><span class="p">(</span><span class="n">token_ids</span><span class="p">,</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">..</span><span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="n">max_length</span><span class="p">))</span>

        <span class="p">[</span>
          <span class="ss">input_chunks:</span> <span class="p">[</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">input_chunk</span><span class="p">)</span> <span class="o">|</span> <span class="n">acc</span><span class="p">[</span><span class="ss">:input_chunks</span><span class="p">]],</span>
          <span class="ss">target_chunks:</span> <span class="p">[</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">target_chunk</span><span class="p">)</span> <span class="o">|</span> <span class="n">acc</span><span class="p">[</span><span class="ss">:target_chunks</span><span class="p">]]</span>
        <span class="p">]</span>
      <span class="k">end</span><span class="p">)</span>

    <span class="p">[</span>
      <span class="ss">input_chunks:</span> <span class="no">Enum</span><span class="o">.</span><span class="n">reverse</span><span class="p">(</span><span class="n">chunks</span><span class="p">[</span><span class="ss">:input_chunks</span><span class="p">]),</span>
      <span class="ss">target_chunks:</span> <span class="no">Enum</span><span class="o">.</span><span class="n">reverse</span><span class="p">(</span><span class="n">chunks</span><span class="p">[</span><span class="ss">:target_chunks</span><span class="p">])</span>
    <span class="p">]</span>
  <span class="k">end</span>

  <span class="nv">@doc</span> <span class="sd">"""
  Creates a `LlmScratch.DataLoader` from raw text using GPT-style input/target pairs.

  ## Parameters

  - `opts` - keyword options controlling dataset chunking and dataloader behavior

  ## Options (`opts`)

  - `:raw_text` (required) - source text to tokenize
  - `:batch_size` (default: `4`) - number of `{input, target}` pairs per batch
  - `:max_length` (default: `256`) - token length of each sequence chunk
  - `:stride` (default: `128`) - step between consecutive windows
  - `:shuffle` (default: `true`) - whether to shuffle dataset before cycling
  - `:drop_last` (default: `true`) - whether to drop incomplete batches
  - `:num_workers` (default: `0`) - worker count for async iteration
  """</span>
  <span class="k">def</span> <span class="n">create_dataloader_v1</span><span class="p">(</span><span class="n">opts</span><span class="p">)</span> <span class="k">do</span>
    <span class="n">raw_text</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">fetch!</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:raw_text</span><span class="p">)</span>
    <span class="n">batch_size</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:batch_size</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
    <span class="n">max_length</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:max_length</span><span class="p">,</span> <span class="mi">256</span><span class="p">)</span>
    <span class="n">stride</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:stride</span><span class="p">,</span> <span class="mi">128</span><span class="p">)</span>
    <span class="n">shuffle</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:shuffle</span><span class="p">,</span> <span class="no">true</span><span class="p">)</span>
    <span class="n">drop_last</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:drop_last</span><span class="p">,</span> <span class="no">true</span><span class="p">)</span>
    <span class="n">num_workers</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:num_workers</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>

    <span class="c1"># Create dataset</span>
    <span class="p">[</span><span class="ss">input_chunks:</span> <span class="n">input_chunks</span><span class="p">,</span> <span class="ss">target_chunks:</span> <span class="n">target_chunks</span><span class="p">]</span> <span class="o">=</span>
      <span class="no">LlmScratch</span><span class="o">.</span><span class="no">GptDatasetV1</span><span class="o">.</span><span class="n">chunk_dataset</span><span class="p">(</span><span class="n">raw_text</span><span class="p">,</span> <span class="s2">"code-davinci-002"</span><span class="p">,</span> <span class="n">max_length</span><span class="p">,</span> <span class="n">stride</span><span class="p">)</span>

    <span class="c1"># Zip input and target chunks together so each item is {input, target}</span>
    <span class="n">dataset</span> <span class="o">=</span> <span class="no">Enum</span><span class="o">.</span><span class="n">zip</span><span class="p">(</span><span class="n">input_chunks</span><span class="p">,</span> <span class="n">target_chunks</span><span class="p">)</span>

    <span class="no">LlmScratch</span><span class="o">.</span><span class="no">DataLoader</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span>
      <span class="ss">batch_size:</span> <span class="n">batch_size</span><span class="p">,</span>
      <span class="ss">shuffle:</span> <span class="n">shuffle</span><span class="p">,</span>
      <span class="ss">drop_last:</span> <span class="n">drop_last</span><span class="p">,</span>
      <span class="ss">num_workers:</span> <span class="n">num_workers</span>
    <span class="p">)</span>
  <span class="k">end</span>
<span class="k">end</span>
</code></pre></div></div>

<p>Let’s start with the <code class="language-plaintext highlighter-rouge">chunk_dataset</code> function and a test that uses it:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">test</span> <span class="s2">"chunk dataset"</span> <span class="k">do</span>
    <span class="n">txt</span> <span class="o">=</span> <span class="s2">"Hello, do you like tea? &lt;|endoftext|&gt; In the sunlit terraces of someunknownPlace."</span>
    <span class="n">model</span> <span class="o">=</span> <span class="s2">"code-davinci-002"</span>
    <span class="n">max_length</span> <span class="o">=</span> <span class="mi">10</span> <span class="o">-</span> <span class="mi">1</span>
    <span class="n">stride</span> <span class="o">=</span> <span class="mi">5</span>
    <span class="n">chunks</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">GptDatasetV1</span><span class="o">.</span><span class="n">chunk_dataset</span><span class="p">(</span><span class="n">txt</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">max_length</span><span class="p">,</span> <span class="n">stride</span><span class="p">)</span>
    <span class="n">assert</span> <span class="n">length</span><span class="p">(</span><span class="n">chunks</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span>

    <span class="p">[</span><span class="ss">input_chunks:</span> <span class="n">input_chunks</span><span class="p">,</span> <span class="ss">target_chunks:</span> <span class="n">target_chunks</span><span class="p">]</span> <span class="o">=</span> <span class="n">chunks</span>

    <span class="n">assert</span> <span class="n">input_chunks</span> <span class="o">==</span> <span class="p">[</span>
             <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">15496</span><span class="p">,</span> <span class="mi">11</span><span class="p">,</span> <span class="mi">466</span><span class="p">,</span> <span class="mi">345</span><span class="p">,</span> <span class="mi">588</span><span class="p">,</span> <span class="mi">8887</span><span class="p">,</span> <span class="mi">30</span><span class="p">,</span> <span class="mi">220</span><span class="p">,</span> <span class="mi">50256</span><span class="p">]),</span>
             <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">8887</span><span class="p">,</span> <span class="mi">30</span><span class="p">,</span> <span class="mi">220</span><span class="p">,</span> <span class="mi">50256</span><span class="p">,</span> <span class="mi">554</span><span class="p">,</span> <span class="mi">262</span><span class="p">,</span> <span class="mi">4252</span><span class="p">,</span> <span class="mi">18250</span><span class="p">,</span> <span class="mi">8812</span><span class="p">]),</span>
             <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">262</span><span class="p">,</span> <span class="mi">4252</span><span class="p">,</span> <span class="mi">18250</span><span class="p">,</span> <span class="mi">8812</span><span class="p">,</span> <span class="mi">2114</span><span class="p">,</span> <span class="mi">286</span><span class="p">,</span> <span class="mi">617</span><span class="p">,</span> <span class="mi">34680</span><span class="p">,</span> <span class="mi">27271</span><span class="p">])</span>
           <span class="p">]</span>

    <span class="n">assert</span> <span class="n">target_chunks</span> <span class="o">==</span> <span class="p">[</span>
             <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">11</span><span class="p">,</span> <span class="mi">466</span><span class="p">,</span> <span class="mi">345</span><span class="p">,</span> <span class="mi">588</span><span class="p">,</span> <span class="mi">8887</span><span class="p">,</span> <span class="mi">30</span><span class="p">,</span> <span class="mi">220</span><span class="p">,</span> <span class="mi">50256</span><span class="p">,</span> <span class="mi">554</span><span class="p">]),</span>
             <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">30</span><span class="p">,</span> <span class="mi">220</span><span class="p">,</span> <span class="mi">50256</span><span class="p">,</span> <span class="mi">554</span><span class="p">,</span> <span class="mi">262</span><span class="p">,</span> <span class="mi">4252</span><span class="p">,</span> <span class="mi">18250</span><span class="p">,</span> <span class="mi">8812</span><span class="p">,</span> <span class="mi">2114</span><span class="p">]),</span>
             <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">4252</span><span class="p">,</span> <span class="mi">18250</span><span class="p">,</span> <span class="mi">8812</span><span class="p">,</span> <span class="mi">2114</span><span class="p">,</span> <span class="mi">286</span><span class="p">,</span> <span class="mi">617</span><span class="p">,</span> <span class="mi">34680</span><span class="p">,</span> <span class="mi">27271</span><span class="p">,</span> <span class="mi">13</span><span class="p">])</span>
           <span class="p">]</span>
  <span class="k">end</span>
</code></pre></div></div>
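
<p>To see where the three windows above come from, here is a small sketch (my own, reusing the numbers from this test and assuming the test text encodes to 20 tokens) of the index at which each window starts:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Assumption: the test text above encodes to 20 tokens.
num_tokens = 20
max_length = 9
stride = 5

# One window every `stride` tokens, as long as a full window still fits.
window_starts = Enum.to_list(0..(num_tokens - max_length)//stride)
# =&gt; [0, 5, 10]
</code></pre></div></div>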

<p><code class="language-plaintext highlighter-rouge">chunk_dataset</code> has four attributes, please check the function documentation. <code class="language-plaintext highlighter-rouge">txt</code> is input text that we want to tokenize. <code class="language-plaintext highlighter-rouge">model</code> is a code for tokenizer model. We aligned with model used in book, “code-davinci-002” (but it took me some trial and errors to figure out which code model to use, but little tester in me enjoyed that part!).<br />
  <code class="language-plaintext highlighter-rouge">max_length</code> defines number of Tensor elements, or using Tensor language, Nx.shape is {9} as Nx.rank is 1. So we have a list of 9 elements. You can see that shape in test assert. And that is actually window size.<br />
  <code class="language-plaintext highlighter-rouge">stride</code> is how many tokens we skip when we start new window. This is why in input_chunks, second tensor starts with token 8887, because this is first tensor element with index 5 (where we start counting from 0).<br />
  So we are cunking encoded input text in the shape that will be used to train our model.<br />
  <code class="language-plaintext highlighter-rouge">chunk_dataset</code> function is used in <code class="language-plaintext highlighter-rouge">create_dataloader_v1</code>, so let’s learn about that function also with a test example.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>  <span class="n">test</span> <span class="s2">"gpt dataset v1"</span> <span class="k">do</span>
    <span class="p">{</span><span class="ss">:ok</span><span class="p">,</span> <span class="n">file_content</span><span class="p">}</span> <span class="o">=</span> <span class="no">File</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="s2">"the-verdict.txt"</span><span class="p">)</span>
    <span class="n">model</span> <span class="o">=</span> <span class="s2">"code-davinci-002"</span>

    <span class="n">dataloader</span> <span class="o">=</span>
      <span class="no">LlmScratch</span><span class="o">.</span><span class="no">GptDatasetV1</span><span class="o">.</span><span class="n">create_dataloader_v1</span><span class="p">(</span>
        <span class="ss">raw_text:</span> <span class="n">file_content</span><span class="p">,</span>
        <span class="ss">batch_size:</span> <span class="mi">1</span><span class="p">,</span>
        <span class="ss">max_length:</span> <span class="mi">4</span><span class="p">,</span>
        <span class="ss">stride:</span> <span class="mi">1</span><span class="p">,</span>
        <span class="ss">shuffle:</span> <span class="no">false</span><span class="p">,</span>
        <span class="ss">drop_last:</span> <span class="no">false</span><span class="p">,</span>
        <span class="ss">num_workers:</span> <span class="mi">0</span>
      <span class="p">)</span>

    <span class="n">batch_0</span> <span class="o">=</span> <span class="n">dataloader</span><span class="o">.</span><span class="n">stream</span> <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">at</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    <span class="p">[{</span><span class="n">input_chunk</span><span class="p">,</span> <span class="n">target_chunk</span><span class="p">}]</span> <span class="o">=</span> <span class="n">batch_0</span>
    <span class="n">input_decoded_chunk</span> <span class="o">=</span> <span class="n">decode_token_pieces</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_flat_list</span><span class="p">(</span><span class="n">input_chunk</span><span class="p">))</span> 
    <span class="n">target_decoded_chunk</span> <span class="o">=</span> <span class="n">decode_token_pieces</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_flat_list</span><span class="p">(</span><span class="n">target_chunk</span><span class="p">))</span>  
    <span class="n">assert</span> <span class="n">batch_0</span> <span class="o">==</span> <span class="p">[{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">40</span><span class="p">,</span> <span class="mi">367</span><span class="p">,</span> <span class="mi">2885</span><span class="p">,</span> <span class="mi">1464</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">367</span><span class="p">,</span> <span class="mi">2885</span><span class="p">,</span> <span class="mi">1464</span><span class="p">,</span> <span class="mi">1807</span><span class="p">])}]</span>
    <span class="n">assert</span> <span class="p">[{</span><span class="n">input_decoded_chunk</span><span class="p">,</span> <span class="n">target_decoded_chunk</span><span class="p">}]</span> <span class="o">==</span> <span class="p">[</span>
             <span class="p">{[</span><span class="s2">"I"</span><span class="p">,</span> <span class="s2">" H"</span><span class="p">,</span> <span class="s2">"AD"</span><span class="p">,</span> <span class="s2">" always"</span><span class="p">],</span> <span class="p">[</span><span class="s2">" H"</span><span class="p">,</span> <span class="s2">"AD"</span><span class="p">,</span> <span class="s2">" always"</span><span class="p">,</span> <span class="s2">" thought"</span><span class="p">]}</span>
           <span class="p">]</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>The input parameters are well documented in the @doc comment; please check that documentation first. <code class="language-plaintext highlighter-rouge">raw_text</code> is human-readable text, in our case The Verdict story. <code class="language-plaintext highlighter-rouge">batch_size</code> determines the number of {input_tensor, target_tensor} pairs in one batch. <code class="language-plaintext highlighter-rouge">max_length</code> and <code class="language-plaintext highlighter-rouge">stride</code> we already know from the <code class="language-plaintext highlighter-rouge">chunk_dataset</code> function. The last three parameters, <code class="language-plaintext highlighter-rouge">shuffle</code>, <code class="language-plaintext highlighter-rouge">drop_last</code>, and <code class="language-plaintext highlighter-rouge">num_workers</code>, we just pass to <code class="language-plaintext highlighter-rouge">DataLoader</code>; they will be explained next as part of the <code class="language-plaintext highlighter-rouge">DataLoader</code> explanation.<br />
  To see how the data grows with <code class="language-plaintext highlighter-rouge">batch_size</code> 8, we have a test for that. In practice, <code class="language-plaintext highlighter-rouge">batch_size</code> is at least 256.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>  <span class="n">test</span> <span class="s2">"gpt dataset v1, batch_size is 8"</span> <span class="k">do</span>
    <span class="p">{</span><span class="ss">:ok</span><span class="p">,</span> <span class="n">file_content</span><span class="p">}</span> <span class="o">=</span> <span class="no">File</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="s2">"the-verdict.txt"</span><span class="p">)</span>
    <span class="n">model</span> <span class="o">=</span> <span class="s2">"code-davinci-002"</span>

    <span class="n">dataloader</span> <span class="o">=</span>
      <span class="no">LlmScratch</span><span class="o">.</span><span class="no">GptDatasetV1</span><span class="o">.</span><span class="n">create_dataloader_v1</span><span class="p">(</span>
        <span class="ss">raw_text:</span> <span class="n">file_content</span><span class="p">,</span>
        <span class="ss">batch_size:</span> <span class="mi">8</span><span class="p">,</span>
        <span class="ss">max_length:</span> <span class="mi">4</span><span class="p">,</span>
        <span class="ss">stride:</span> <span class="mi">4</span><span class="p">,</span>
        <span class="ss">shuffle:</span> <span class="no">false</span><span class="p">,</span>
        <span class="ss">drop_last:</span> <span class="no">false</span><span class="p">,</span>
        <span class="ss">num_workers:</span> <span class="mi">0</span>
      <span class="p">)</span>

    <span class="n">batch_0</span> <span class="o">=</span> <span class="n">dataloader</span><span class="o">.</span><span class="n">stream</span> <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">at</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>

    <span class="n">decoded_batch_0</span> <span class="o">=</span>
      <span class="no">Enum</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">batch_0</span><span class="p">,</span> <span class="k">fn</span> <span class="p">{</span><span class="n">input_chunk</span><span class="p">,</span> <span class="n">target_chunk</span><span class="p">}</span> <span class="o">-&gt;</span>
        <span class="p">{</span>
          <span class="n">decode_token_pieces</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_flat_list</span><span class="p">(</span><span class="n">input_chunk</span><span class="p">)),</span>
          <span class="n">decode_token_pieces</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="no">Nx</span><span class="o">.</span><span class="n">to_flat_list</span><span class="p">(</span><span class="n">target_chunk</span><span class="p">))</span>
        <span class="p">}</span>
      <span class="k">end</span><span class="p">)</span>

    <span class="n">assert</span> <span class="n">batch_0</span> <span class="o">==</span> <span class="p">[</span>
             <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">40</span><span class="p">,</span> <span class="mi">367</span><span class="p">,</span> <span class="mi">2885</span><span class="p">,</span> <span class="mi">1464</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">367</span><span class="p">,</span> <span class="mi">2885</span><span class="p">,</span> <span class="mi">1464</span><span class="p">,</span> <span class="mi">1807</span><span class="p">])},</span>
             <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1807</span><span class="p">,</span> <span class="mi">3619</span><span class="p">,</span> <span class="mi">402</span><span class="p">,</span> <span class="mi">271</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">3619</span><span class="p">,</span> <span class="mi">402</span><span class="p">,</span> <span class="mi">271</span><span class="p">,</span> <span class="mi">10899</span><span class="p">])},</span>
             <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">10899</span><span class="p">,</span> <span class="mi">2138</span><span class="p">,</span> <span class="mi">257</span><span class="p">,</span> <span class="mi">7026</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">2138</span><span class="p">,</span> <span class="mi">257</span><span class="p">,</span> <span class="mi">7026</span><span class="p">,</span> <span class="mi">15632</span><span class="p">])},</span>
             <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">15632</span><span class="p">,</span> <span class="mi">438</span><span class="p">,</span> <span class="mi">2016</span><span class="p">,</span> <span class="mi">257</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">438</span><span class="p">,</span> <span class="mi">2016</span><span class="p">,</span> <span class="mi">257</span><span class="p">,</span> <span class="mi">922</span><span class="p">])},</span>
             <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">922</span><span class="p">,</span> <span class="mi">5891</span><span class="p">,</span> <span class="mi">1576</span><span class="p">,</span> <span class="mi">438</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">5891</span><span class="p">,</span> <span class="mi">1576</span><span class="p">,</span> <span class="mi">438</span><span class="p">,</span> <span class="mi">568</span><span class="p">])},</span>
             <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">568</span><span class="p">,</span> <span class="mi">340</span><span class="p">,</span> <span class="mi">373</span><span class="p">,</span> <span class="mi">645</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">340</span><span class="p">,</span> <span class="mi">373</span><span class="p">,</span> <span class="mi">645</span><span class="p">,</span> <span class="mi">1049</span><span class="p">])},</span>
             <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1049</span><span class="p">,</span> <span class="mi">5975</span><span class="p">,</span> <span class="mi">284</span><span class="p">,</span> <span class="mi">502</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">5975</span><span class="p">,</span> <span class="mi">284</span><span class="p">,</span> <span class="mi">502</span><span class="p">,</span> <span class="mi">284</span><span class="p">])},</span>
             <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">284</span><span class="p">,</span> <span class="mi">3285</span><span class="p">,</span> <span class="mi">326</span><span class="p">,</span> <span class="mi">11</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">3285</span><span class="p">,</span> <span class="mi">326</span><span class="p">,</span> <span class="mi">11</span><span class="p">,</span> <span class="mi">287</span><span class="p">])}</span>
           <span class="p">]</span>

    <span class="n">assert</span> <span class="n">decoded_batch_0</span> <span class="o">==</span> <span class="p">[</span>
             <span class="p">{[</span><span class="s2">"I"</span><span class="p">,</span> <span class="s2">" H"</span><span class="p">,</span> <span class="s2">"AD"</span><span class="p">,</span> <span class="s2">" always"</span><span class="p">],</span> <span class="p">[</span><span class="s2">" H"</span><span class="p">,</span> <span class="s2">"AD"</span><span class="p">,</span> <span class="s2">" always"</span><span class="p">,</span> <span class="s2">" thought"</span><span class="p">]},</span>
             <span class="p">{[</span><span class="s2">" thought"</span><span class="p">,</span> <span class="s2">" Jack"</span><span class="p">,</span> <span class="s2">" G"</span><span class="p">,</span> <span class="s2">"is"</span><span class="p">],</span> <span class="p">[</span><span class="s2">" Jack"</span><span class="p">,</span> <span class="s2">" G"</span><span class="p">,</span> <span class="s2">"is"</span><span class="p">,</span> <span class="s2">"burn"</span><span class="p">]},</span>
             <span class="p">{[</span><span class="s2">"burn"</span><span class="p">,</span> <span class="s2">" rather"</span><span class="p">,</span> <span class="s2">" a"</span><span class="p">,</span> <span class="s2">" cheap"</span><span class="p">],</span> <span class="p">[</span><span class="s2">" rather"</span><span class="p">,</span> <span class="s2">" a"</span><span class="p">,</span> <span class="s2">" cheap"</span><span class="p">,</span> <span class="s2">" genius"</span><span class="p">]},</span>
             <span class="p">{[</span><span class="s2">" genius"</span><span class="p">,</span> <span class="s2">"--"</span><span class="p">,</span> <span class="s2">"though"</span><span class="p">,</span> <span class="s2">" a"</span><span class="p">],</span> <span class="p">[</span><span class="s2">"--"</span><span class="p">,</span> <span class="s2">"though"</span><span class="p">,</span> <span class="s2">" a"</span><span class="p">,</span> <span class="s2">" good"</span><span class="p">]},</span>
             <span class="p">{[</span><span class="s2">" good"</span><span class="p">,</span> <span class="s2">" fellow"</span><span class="p">,</span> <span class="s2">" enough"</span><span class="p">,</span> <span class="s2">"--"</span><span class="p">],</span> <span class="p">[</span><span class="s2">" fellow"</span><span class="p">,</span> <span class="s2">" enough"</span><span class="p">,</span> <span class="s2">"--"</span><span class="p">,</span> <span class="s2">"so"</span><span class="p">]},</span>
             <span class="p">{[</span><span class="s2">"so"</span><span class="p">,</span> <span class="s2">" it"</span><span class="p">,</span> <span class="s2">" was"</span><span class="p">,</span> <span class="s2">" no"</span><span class="p">],</span> <span class="p">[</span><span class="s2">" it"</span><span class="p">,</span> <span class="s2">" was"</span><span class="p">,</span> <span class="s2">" no"</span><span class="p">,</span> <span class="s2">" great"</span><span class="p">]},</span>
             <span class="p">{[</span><span class="s2">" great"</span><span class="p">,</span> <span class="s2">" surprise"</span><span class="p">,</span> <span class="s2">" to"</span><span class="p">,</span> <span class="s2">" me"</span><span class="p">],</span> <span class="p">[</span><span class="s2">" surprise"</span><span class="p">,</span> <span class="s2">" to"</span><span class="p">,</span> <span class="s2">" me"</span><span class="p">,</span> <span class="s2">" to"</span><span class="p">]},</span>
             <span class="p">{[</span><span class="s2">" to"</span><span class="p">,</span> <span class="s2">" hear"</span><span class="p">,</span> <span class="s2">" that"</span><span class="p">,</span> <span class="s2">","</span><span class="p">],</span> <span class="p">[</span><span class="s2">" hear"</span><span class="p">,</span> <span class="s2">" that"</span><span class="p">,</span> <span class="s2">","</span><span class="p">,</span> <span class="s2">" in"</span><span class="p">]}</span>
           <span class="p">]</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>And now it is time to tackle the <code class="language-plaintext highlighter-rouge">DataLoader</code> beast:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
  <span class="k">defmodule</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">DataLoader</span> <span class="k">do</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Builds a data loader map from an in-memory dataset.

  Returns a map with:
  - `:stream` - an infinite stream of batches
  - `:batch_size` - configured batch size (default `32`)
  - `:drop_last` - whether incomplete batches are dropped
  - `:num_workers` - concurrency used by `iterate/2`

  ## Options
  - `:batch_size` - number of samples per batch (default: `32`)
  - `:shuffle` - shuffles dataset once before cycling (default: `true`)
  - `:drop_last` - drops batches smaller than `:batch_size` (default: `false`)
  - `:num_workers` - parallel workers for iteration (default: `0`)
  """</span>
  <span class="k">def</span> <span class="n">new</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">opts</span> <span class="p">\\</span> <span class="p">[])</span> <span class="ow">when</span> <span class="n">is_list</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span> <span class="k">do</span>
    <span class="n">batch_size</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:batch_size</span><span class="p">,</span> <span class="mi">32</span><span class="p">)</span>
    <span class="n">shuffle</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:shuffle</span><span class="p">,</span> <span class="no">true</span><span class="p">)</span>
    <span class="n">drop_last</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:drop_last</span><span class="p">,</span> <span class="no">false</span><span class="p">)</span>
    <span class="n">num_workers</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:num_workers</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>

    <span class="n">stream</span> <span class="o">=</span>
      <span class="n">dataset</span>
      <span class="o">|&gt;</span> <span class="n">prepare_dataset</span><span class="p">(</span><span class="n">shuffle</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="no">Stream</span><span class="o">.</span><span class="n">chunk_every</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
      <span class="o">|&gt;</span> <span class="n">filter_incomplete_batches</span><span class="p">(</span><span class="n">drop_last</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">)</span>

    <span class="p">%{</span>
      <span class="ss">stream:</span> <span class="n">stream</span><span class="p">,</span>
      <span class="ss">batch_size:</span> <span class="n">batch_size</span><span class="p">,</span>
      <span class="ss">drop_last:</span> <span class="n">drop_last</span><span class="p">,</span>
      <span class="ss">num_workers:</span> <span class="n">num_workers</span>
    <span class="p">}</span>
  <span class="k">end</span>
  
  <span class="nv">@doc</span> <span class="sd">"""
  Iterates over loader batches and applies `fun` to each batch.

  When `num_workers` is `0`, batches are processed sequentially.
  When `num_workers` is greater than `0`, batches are processed concurrently
  using `Task.async_stream/3`.
  """</span>
  <span class="k">def</span> <span class="n">iterate</span><span class="p">(%{</span><span class="ss">stream:</span> <span class="n">stream</span><span class="p">,</span> <span class="ss">num_workers:</span> <span class="mi">0</span><span class="p">},</span> <span class="n">fun</span><span class="p">)</span> <span class="ow">when</span> <span class="n">is_function</span><span class="p">(</span><span class="n">fun</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="k">do</span>
    <span class="no">Enum</span><span class="o">.</span><span class="n">each</span><span class="p">(</span><span class="n">stream</span><span class="p">,</span> <span class="n">fun</span><span class="p">)</span>
  <span class="k">end</span>

  <span class="k">def</span> <span class="n">iterate</span><span class="p">(%{</span><span class="ss">stream:</span> <span class="n">stream</span><span class="p">,</span> <span class="ss">num_workers:</span> <span class="n">num_workers</span><span class="p">},</span> <span class="n">fun</span><span class="p">)</span>
      <span class="ow">when</span> <span class="n">is_function</span><span class="p">(</span><span class="n">fun</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="ow">and</span> <span class="n">num_workers</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">do</span>
    <span class="n">stream</span>
    <span class="o">|&gt;</span> <span class="no">Task</span><span class="o">.</span><span class="n">async_stream</span><span class="p">(</span><span class="n">fun</span><span class="p">,</span> <span class="ss">max_concurrency:</span> <span class="n">num_workers</span><span class="p">,</span> <span class="ss">ordered:</span> <span class="no">false</span><span class="p">)</span>
    <span class="o">|&gt;</span> <span class="no">Stream</span><span class="o">.</span><span class="n">run</span><span class="p">()</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">prepare_dataset</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="no">true</span><span class="p">)</span> <span class="k">do</span>
    <span class="n">dataset</span>
    <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">shuffle</span><span class="p">()</span>
    <span class="o">|&gt;</span> <span class="no">Stream</span><span class="o">.</span><span class="n">cycle</span><span class="p">()</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">prepare_dataset</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="no">false</span><span class="p">)</span> <span class="k">do</span>
    <span class="no">Stream</span><span class="o">.</span><span class="n">cycle</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">filter_incomplete_batches</span><span class="p">(</span><span class="n">stream</span><span class="p">,</span> <span class="no">true</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">)</span> <span class="k">do</span>
    <span class="no">Stream</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">stream</span><span class="p">,</span> <span class="o">&amp;</span><span class="p">(</span><span class="n">length</span><span class="p">(</span><span class="nv">&amp;1</span><span class="p">)</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">))</span>
  <span class="k">end</span>

  <span class="k">defp</span> <span class="n">filter_incomplete_batches</span><span class="p">(</span><span class="n">stream</span><span class="p">,</span> <span class="no">false</span><span class="p">,</span> <span class="n">_batch_size</span><span class="p">)</span> <span class="k">do</span>
    <span class="n">stream</span>
  <span class="k">end</span>
<span class="k">end</span>
</code></pre></div></div>

<p>The documentation starts with the map that the function returns. The first key in that map is <code class="language-plaintext highlighter-rouge">:stream</code>, an infinite stream of batches. So what does an infinite stream really mean? If you check the private function <code class="language-plaintext highlighter-rouge">prepare_dataset</code>, you will see that it returns a call to <code class="language-plaintext highlighter-rouge">Stream.cycle(dataset)</code>. We know that <code class="language-plaintext highlighter-rouge">dataset</code> holds our chunks of input data, but the input data itself is not infinite; it is The Verdict story. Here is an example of what <code class="language-plaintext highlighter-rouge">Stream.cycle</code> does:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="p">[</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">3</span><span class="p">]</span> <span class="o">|&gt;</span> <span class="no">Stream</span><span class="o">.</span><span class="n">cycle</span><span class="p">()</span> <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">8</span><span class="p">)</span>
<span class="c1"># =&gt; [1,2,3,1,2,3,1,2]</span>
</code></pre></div></div>
<p>As the name states, it just cycles through the input list, repeating it indefinitely. <code class="language-plaintext highlighter-rouge">Stream.cycle</code> does not materialize the repeated elements in memory; it is lazy and only produces the next element when asked. <code class="language-plaintext highlighter-rouge">Stream.cycle</code> always goes hand in hand with <code class="language-plaintext highlighter-rouge">Enum</code> function calls: <code class="language-plaintext highlighter-rouge">Enum.take(8)</code> takes the first 8 elements of the stream.<br />
But this will run indefinitely:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="p">[</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">3</span><span class="p">]</span> <span class="o">|&gt;</span> <span class="no">Stream</span><span class="o">.</span><span class="n">cycle</span><span class="p">()</span> <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">each</span><span class="p">(</span><span class="k">fn</span> <span class="n">x</span> <span class="o">-&gt;</span> <span class="n">x</span> <span class="k">end</span><span class="p">)</span>
</code></pre></div></div>

<p>Because <code class="language-plaintext highlighter-rouge">Stream.cycle</code> generates an infinite stream.<br />
<code class="language-plaintext highlighter-rouge">batch_size</code> we already know from <code class="language-plaintext highlighter-rouge">create_dataloader_v1</code>: it is the number of samples per batch, i.e. how we chunk the <code class="language-plaintext highlighter-rouge">dataset</code>.<br />
<code class="language-plaintext highlighter-rouge">shuffle</code> randomizes the original order of the input <code class="language-plaintext highlighter-rouge">dataset</code> that we got from <code class="language-plaintext highlighter-rouge">chunk_dataset</code>, which helps us train a better LLM model.<br />
<code class="language-plaintext highlighter-rouge">num_workers</code> is used by the <code class="language-plaintext highlighter-rouge">iterate</code> function, which runs a function of our choice over the batches. If <code class="language-plaintext highlighter-rouge">num_workers</code> is greater than 0, the batches are processed concurrently.<br />
<code class="language-plaintext highlighter-rouge">drop_last</code> decides whether a trailing incomplete batch is kept or dropped; a short sketch below shows the chunking, and the test after it exercises both settings:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="n">test</span> <span class="s2">"data loader drop_last true drops trailing incomplete batch from tuple dataset"</span> <span class="k">do</span>
    <span class="n">dataset</span> <span class="o">=</span> <span class="p">[</span>
      <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">40</span><span class="p">,</span> <span class="mi">367</span><span class="p">,</span> <span class="mi">2885</span><span class="p">,</span> <span class="mi">1464</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">367</span><span class="p">,</span> <span class="mi">2885</span><span class="p">,</span> <span class="mi">1464</span><span class="p">,</span> <span class="mi">1807</span><span class="p">])},</span>
      <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1807</span><span class="p">,</span> <span class="mi">3619</span><span class="p">,</span> <span class="mi">402</span><span class="p">,</span> <span class="mi">271</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">3619</span><span class="p">,</span> <span class="mi">402</span><span class="p">,</span> <span class="mi">271</span><span class="p">,</span> <span class="mi">10899</span><span class="p">])},</span>
      <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">10899</span><span class="p">,</span> <span class="mi">2138</span><span class="p">,</span> <span class="mi">257</span><span class="p">,</span> <span class="mi">7026</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">2138</span><span class="p">,</span> <span class="mi">257</span><span class="p">,</span> <span class="mi">7026</span><span class="p">,</span> <span class="mi">15632</span><span class="p">])},</span>
      <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">15632</span><span class="p">,</span> <span class="mi">438</span><span class="p">,</span> <span class="mi">2016</span><span class="p">,</span> <span class="mi">257</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">438</span><span class="p">,</span> <span class="mi">2016</span><span class="p">,</span> <span class="mi">257</span><span class="p">,</span> <span class="mi">922</span><span class="p">])},</span>
      <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">922</span><span class="p">,</span> <span class="mi">5891</span><span class="p">,</span> <span class="mi">1576</span><span class="p">,</span> <span class="mi">438</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">5891</span><span class="p">,</span> <span class="mi">1576</span><span class="p">,</span> <span class="mi">438</span><span class="p">,</span> <span class="mi">568</span><span class="p">])},</span>
      <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">568</span><span class="p">,</span> <span class="mi">340</span><span class="p">,</span> <span class="mi">373</span><span class="p">,</span> <span class="mi">645</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">340</span><span class="p">,</span> <span class="mi">373</span><span class="p">,</span> <span class="mi">645</span><span class="p">,</span> <span class="mi">1049</span><span class="p">])},</span>
      <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1049</span><span class="p">,</span> <span class="mi">5975</span><span class="p">,</span> <span class="mi">284</span><span class="p">,</span> <span class="mi">502</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">5975</span><span class="p">,</span> <span class="mi">284</span><span class="p">,</span> <span class="mi">502</span><span class="p">,</span> <span class="mi">284</span><span class="p">])},</span>
      <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">284</span><span class="p">,</span> <span class="mi">3285</span><span class="p">,</span> <span class="mi">326</span><span class="p">,</span> <span class="mi">11</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">3285</span><span class="p">,</span> <span class="mi">326</span><span class="p">,</span> <span class="mi">11</span><span class="p">,</span> <span class="mi">287</span><span class="p">])}</span>
    <span class="p">]</span>

    <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">3</span>

    <span class="n">dataloader_keep_last</span> <span class="o">=</span>
      <span class="no">LlmScratch</span><span class="o">.</span><span class="no">DataLoader</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="ss">batch_size:</span> <span class="n">batch_size</span><span class="p">,</span> <span class="ss">shuffle:</span> <span class="no">false</span><span class="p">,</span> <span class="ss">drop_last:</span> <span class="no">false</span><span class="p">)</span>

    <span class="n">dataloader_drop_last</span> <span class="o">=</span>
      <span class="no">LlmScratch</span><span class="o">.</span><span class="no">DataLoader</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="ss">batch_size:</span> <span class="n">batch_size</span><span class="p">,</span> <span class="ss">shuffle:</span> <span class="no">false</span><span class="p">,</span> <span class="ss">drop_last:</span> <span class="no">true</span><span class="p">)</span>

    <span class="n">assert</span> <span class="no">Enum</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="n">dataloader_keep_last</span><span class="o">.</span><span class="n">stream</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="o">==</span> <span class="p">[</span>
             <span class="p">[</span>
               <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">40</span><span class="p">,</span> <span class="mi">367</span><span class="p">,</span> <span class="mi">2885</span><span class="p">,</span> <span class="mi">1464</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">367</span><span class="p">,</span> <span class="mi">2885</span><span class="p">,</span> <span class="mi">1464</span><span class="p">,</span> <span class="mi">1807</span><span class="p">])},</span>
               <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1807</span><span class="p">,</span> <span class="mi">3619</span><span class="p">,</span> <span class="mi">402</span><span class="p">,</span> <span class="mi">271</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">3619</span><span class="p">,</span> <span class="mi">402</span><span class="p">,</span> <span class="mi">271</span><span class="p">,</span> <span class="mi">10899</span><span class="p">])},</span>
               <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">10899</span><span class="p">,</span> <span class="mi">2138</span><span class="p">,</span> <span class="mi">257</span><span class="p">,</span> <span class="mi">7026</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">2138</span><span class="p">,</span> <span class="mi">257</span><span class="p">,</span> <span class="mi">7026</span><span class="p">,</span> <span class="mi">15632</span><span class="p">])}</span>
             <span class="p">],</span>
             <span class="p">[</span>
               <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">15632</span><span class="p">,</span> <span class="mi">438</span><span class="p">,</span> <span class="mi">2016</span><span class="p">,</span> <span class="mi">257</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">438</span><span class="p">,</span> <span class="mi">2016</span><span class="p">,</span> <span class="mi">257</span><span class="p">,</span> <span class="mi">922</span><span class="p">])},</span>
               <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">922</span><span class="p">,</span> <span class="mi">5891</span><span class="p">,</span> <span class="mi">1576</span><span class="p">,</span> <span class="mi">438</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">5891</span><span class="p">,</span> <span class="mi">1576</span><span class="p">,</span> <span class="mi">438</span><span class="p">,</span> <span class="mi">568</span><span class="p">])},</span>
               <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">568</span><span class="p">,</span> <span class="mi">340</span><span class="p">,</span> <span class="mi">373</span><span class="p">,</span> <span class="mi">645</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">340</span><span class="p">,</span> <span class="mi">373</span><span class="p">,</span> <span class="mi">645</span><span class="p">,</span> <span class="mi">1049</span><span class="p">])}</span>
             <span class="p">],</span>
             <span class="p">[</span>
               <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1049</span><span class="p">,</span> <span class="mi">5975</span><span class="p">,</span> <span class="mi">284</span><span class="p">,</span> <span class="mi">502</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">5975</span><span class="p">,</span> <span class="mi">284</span><span class="p">,</span> <span class="mi">502</span><span class="p">,</span> <span class="mi">284</span><span class="p">])},</span>
               <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">284</span><span class="p">,</span> <span class="mi">3285</span><span class="p">,</span> <span class="mi">326</span><span class="p">,</span> <span class="mi">11</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">3285</span><span class="p">,</span> <span class="mi">326</span><span class="p">,</span> <span class="mi">11</span><span class="p">,</span> <span class="mi">287</span><span class="p">])}</span>
             <span class="p">]</span>
           <span class="p">]</span>

    <span class="n">assert</span> <span class="no">Enum</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="n">dataloader_drop_last</span><span class="o">.</span><span class="n">stream</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="o">==</span> <span class="p">[</span>
             <span class="p">[</span>
               <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">40</span><span class="p">,</span> <span class="mi">367</span><span class="p">,</span> <span class="mi">2885</span><span class="p">,</span> <span class="mi">1464</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">367</span><span class="p">,</span> <span class="mi">2885</span><span class="p">,</span> <span class="mi">1464</span><span class="p">,</span> <span class="mi">1807</span><span class="p">])},</span>
               <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1807</span><span class="p">,</span> <span class="mi">3619</span><span class="p">,</span> <span class="mi">402</span><span class="p">,</span> <span class="mi">271</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">3619</span><span class="p">,</span> <span class="mi">402</span><span class="p">,</span> <span class="mi">271</span><span class="p">,</span> <span class="mi">10899</span><span class="p">])},</span>
               <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">10899</span><span class="p">,</span> <span class="mi">2138</span><span class="p">,</span> <span class="mi">257</span><span class="p">,</span> <span class="mi">7026</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">2138</span><span class="p">,</span> <span class="mi">257</span><span class="p">,</span> <span class="mi">7026</span><span class="p">,</span> <span class="mi">15632</span><span class="p">])}</span>
             <span class="p">],</span>
             <span class="p">[</span>
               <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">15632</span><span class="p">,</span> <span class="mi">438</span><span class="p">,</span> <span class="mi">2016</span><span class="p">,</span> <span class="mi">257</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">438</span><span class="p">,</span> <span class="mi">2016</span><span class="p">,</span> <span class="mi">257</span><span class="p">,</span> <span class="mi">922</span><span class="p">])},</span>
               <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">922</span><span class="p">,</span> <span class="mi">5891</span><span class="p">,</span> <span class="mi">1576</span><span class="p">,</span> <span class="mi">438</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">5891</span><span class="p">,</span> <span class="mi">1576</span><span class="p">,</span> <span class="mi">438</span><span class="p">,</span> <span class="mi">568</span><span class="p">])},</span>
               <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">568</span><span class="p">,</span> <span class="mi">340</span><span class="p">,</span> <span class="mi">373</span><span class="p">,</span> <span class="mi">645</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">340</span><span class="p">,</span> <span class="mi">373</span><span class="p">,</span> <span class="mi">645</span><span class="p">,</span> <span class="mi">1049</span><span class="p">])}</span>
             <span class="p">],</span>
             <span class="p">[</span>
               <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">40</span><span class="p">,</span> <span class="mi">367</span><span class="p">,</span> <span class="mi">2885</span><span class="p">,</span> <span class="mi">1464</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">367</span><span class="p">,</span> <span class="mi">2885</span><span class="p">,</span> <span class="mi">1464</span><span class="p">,</span> <span class="mi">1807</span><span class="p">])},</span>
               <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1807</span><span class="p">,</span> <span class="mi">3619</span><span class="p">,</span> <span class="mi">402</span><span class="p">,</span> <span class="mi">271</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">3619</span><span class="p">,</span> <span class="mi">402</span><span class="p">,</span> <span class="mi">271</span><span class="p">,</span> <span class="mi">10899</span><span class="p">])},</span>
               <span class="p">{</span><span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">10899</span><span class="p">,</span> <span class="mi">2138</span><span class="p">,</span> <span class="mi">257</span><span class="p">,</span> <span class="mi">7026</span><span class="p">]),</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">2138</span><span class="p">,</span> <span class="mi">257</span><span class="p">,</span> <span class="mi">7026</span><span class="p">,</span> <span class="mi">15632</span><span class="p">])}</span>
             <span class="p">]</span>
           <span class="p">]</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>If we divide the number of dataset elements by <code class="language-plaintext highlighter-rouge">batch_size</code>, in this test we get 8/3: two full batches with a remainder of 2 elements. With <code class="language-plaintext highlighter-rouge">drop_last: true</code>, that incomplete final batch is dropped, so every batch the model sees has exactly <code class="language-plaintext highlighter-rouge">batch_size</code> elements.</p>
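<p>Putting it together, here is a small hedged sketch of driving the loader with <code class="language-plaintext highlighter-rouge">iterate/2</code> and two workers. The data is a placeholder list (any list satisfies the <code class="language-plaintext highlighter-rouge">is_list/1</code> guard), and bounding the stream with <code class="language-plaintext highlighter-rouge">Stream.take/2</code> is only there so the example terminates:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
# Placeholder data, not the tokenized tuples from the tests.
dataset = Enum.to_list(1..8)
loader = LlmScratch.DataLoader.new(dataset, batch_size: 3, shuffle: false, num_workers: 2)

# The loader's stream cycles forever, so we bound it before handing it to iterate/2,
# which runs the callback concurrently via Task.async_stream/3 because num_workers &gt; 0.
loader
|&gt; Map.update!(:stream, &amp;Stream.take(&amp;1, 10))
|&gt; LlmScratch.DataLoader.iterate(fn batch -&gt; IO.inspect(length(batch), label: "batch size") end)
</code></pre></div></div>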

<p>That is it for the dataloader, which helps organize training data. Now we move to token embeddings.</p>

<h3 id="token-embeddings">Token embeddings</h3>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>input token IDs {batch, seq_len}
  -&gt; embedding lookup table {vocab_size, emb_dim}
  -&gt; token vectors {batch, seq_len, emb_dim}
</code></pre></div></div>

<p>To use token IDs for LLM training, we must convert them (remember Bruce’s <code class="language-plaintext highlighter-rouge">construct |&gt; reduce |&gt; convert</code>?) into embedding vectors. But what is an embedding vector, how do we create it, and why do we need it?<br />
  We need embeddings because an LLM is a neural network, and token IDs like 7 and 8 are just integers with no built-in notion of similarity. Embeddings use floating-point vectors so the model can learn relationships between tokens. We also use matrices for the same reason: to represent how tokens relate to one another.
  Next is the “what” and “how.” Python has <code class="language-plaintext highlighter-rouge">torch.nn.Embedding</code>. In Elixir, we have Nx and Axon. Meet <code class="language-plaintext highlighter-rouge">LlmScratch.EmbeddingNative</code>.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
  <span class="k">defmodule</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">EmbeddingNative</span> <span class="k">do</span>
  <span class="nv">@moduledoc</span> <span class="sd">"""
  Native embedding layer implemented with Nx and Axon.

  Provides a small API compatible with `LlmScratch.Embedding`:

    * `new/3`
    * `weight/1`
    * `forward/2`
    * `call/2`
  """</span>

  <span class="k">defstruct</span> <span class="p">[</span><span class="ss">:weight</span><span class="p">,</span> <span class="ss">:vocab_size</span><span class="p">,</span> <span class="ss">:embedding_dim</span><span class="p">,</span> <span class="ss">:seed</span><span class="p">]</span>

  <span class="nv">@type</span> <span class="n">t</span> <span class="p">::</span> <span class="p">%</span><span class="bp">__MODULE__</span><span class="p">{</span>
          <span class="ss">weight:</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">(),</span>
          <span class="ss">vocab_size:</span> <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="ss">embedding_dim:</span> <span class="n">pos_integer</span><span class="p">(),</span>
          <span class="ss">seed:</span> <span class="n">integer</span><span class="p">()</span> <span class="o">|</span> <span class="no">nil</span>
        <span class="p">}</span>

  <span class="nv">@spec</span> <span class="n">new</span><span class="p">(</span><span class="n">pos_integer</span><span class="p">(),</span> <span class="n">pos_integer</span><span class="p">(),</span> <span class="n">keyword</span><span class="p">())</span> <span class="p">::</span> <span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Creates a new embedding layer.

  ## Options

    * `:seed` - integer seed used to initialize weights deterministically.
    * `:weight` - optional pre-initialized weight tensor with shape
      `{vocab_size, embedding_dim}`.
  """</span>
  <span class="k">def</span> <span class="n">new</span><span class="p">(</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">,</span> <span class="n">opts</span> <span class="p">\\</span> <span class="p">[])</span>
      <span class="ow">when</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">vocab_size</span><span class="p">)</span> <span class="ow">and</span> <span class="n">vocab_size</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">embedding_dim</span><span class="p">)</span> <span class="ow">and</span>
             <span class="n">embedding_dim</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">do</span>
    <span class="n">seed</span> <span class="o">=</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:seed</span><span class="p">)</span>

    <span class="n">weight</span> <span class="o">=</span>
      <span class="k">case</span> <span class="no">Keyword</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">opts</span><span class="p">,</span> <span class="ss">:weight</span><span class="p">)</span> <span class="k">do</span>
        <span class="p">%</span><span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="p">{}</span> <span class="o">=</span> <span class="n">provided_weight</span> <span class="o">-&gt;</span>
          <span class="n">expected_shape</span> <span class="o">=</span> <span class="p">{</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">}</span>

          <span class="k">if</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">provided_weight</span><span class="p">)</span> <span class="o">!=</span> <span class="n">expected_shape</span> <span class="k">do</span>
            <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span>
                  <span class="s2">"expected :weight shape </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="n">expected_shape</span><span class="p">)</span><span class="si">}</span><span class="s2">, got: </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">provided_weight</span><span class="p">))</span><span class="si">}</span><span class="s2">"</span>
          <span class="k">end</span>

          <span class="no">Nx</span><span class="o">.</span><span class="n">as_type</span><span class="p">(</span><span class="n">provided_weight</span><span class="p">,</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">})</span>

        <span class="no">nil</span> <span class="o">-&gt;</span>
          <span class="n">init_weight</span><span class="p">(</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">,</span> <span class="n">seed</span><span class="p">)</span>
      <span class="k">end</span>

    <span class="p">%</span><span class="bp">__MODULE__</span><span class="p">{</span>
      <span class="ss">weight:</span> <span class="n">weight</span><span class="p">,</span>
      <span class="ss">vocab_size:</span> <span class="n">vocab_size</span><span class="p">,</span>
      <span class="ss">embedding_dim:</span> <span class="n">embedding_dim</span><span class="p">,</span>
      <span class="ss">seed:</span> <span class="n">seed</span>
    <span class="p">}</span>
  <span class="k">end</span>

  <span class="nv">@spec</span> <span class="n">weight</span><span class="p">(</span><span class="n">t</span><span class="p">())</span> <span class="p">::</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Returns the embedding weight matrix.

  The returned tensor has shape `{vocab_size, embedding_dim}`.
  """</span>
  <span class="k">def</span> <span class="n">weight</span><span class="p">(%</span><span class="bp">__MODULE__</span><span class="p">{</span><span class="ss">weight:</span> <span class="n">weight</span><span class="p">}),</span> <span class="k">do</span><span class="p">:</span> <span class="n">weight</span>

  <span class="nv">@spec</span> <span class="n">forward</span><span class="p">(</span><span class="n">t</span><span class="p">(),</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">())</span> <span class="p">::</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Maps token ids to embedding vectors.

  `token_ids` is cast to `{:s, 64}` and used as row indices into the
  embedding matrix.
  """</span>
  <span class="k">def</span> <span class="n">forward</span><span class="p">(%</span><span class="bp">__MODULE__</span><span class="p">{</span><span class="ss">weight:</span> <span class="n">weight</span><span class="p">},</span> <span class="n">token_ids</span><span class="p">)</span> <span class="k">do</span>
    <span class="n">token_ids</span>
    <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">as_type</span><span class="p">({</span><span class="ss">:s</span><span class="p">,</span> <span class="mi">64</span><span class="p">})</span>
    <span class="o">|&gt;</span> <span class="n">then</span><span class="p">(</span><span class="o">&amp;</span><span class="no">Nx</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="n">weight</span><span class="p">,</span> <span class="nv">&amp;1</span><span class="p">,</span> <span class="ss">axis:</span> <span class="mi">0</span><span class="p">))</span>
  <span class="k">end</span>

  <span class="nv">@spec</span> <span class="n">call</span><span class="p">(</span><span class="n">t</span><span class="p">(),</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">())</span> <span class="p">::</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Tensor</span><span class="o">.</span><span class="n">t</span><span class="p">()</span>
  <span class="nv">@doc</span> <span class="sd">"""
  Alias for `forward/2`.
  """</span>
  <span class="k">def</span> <span class="n">call</span><span class="p">(</span><span class="n">embedding</span><span class="p">,</span> <span class="n">token_ids</span><span class="p">),</span> <span class="k">do</span><span class="p">:</span> <span class="n">forward</span><span class="p">(</span><span class="n">embedding</span><span class="p">,</span> <span class="n">token_ids</span><span class="p">)</span>

  <span class="k">defp</span> <span class="n">init_weight</span><span class="p">(</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">,</span> <span class="n">seed</span><span class="p">)</span> <span class="k">do</span>
    <span class="n">key</span> <span class="o">=</span>
      <span class="k">case</span> <span class="n">seed</span> <span class="k">do</span>
        <span class="no">nil</span> <span class="o">-&gt;</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">Random</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="no">System</span><span class="o">.</span><span class="n">unique_integer</span><span class="p">([</span><span class="ss">:positive</span><span class="p">]))</span>
        <span class="n">int</span> <span class="ow">when</span> <span class="n">is_integer</span><span class="p">(</span><span class="n">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">Random</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">int</span><span class="p">)</span>
        <span class="n">other</span> <span class="o">-&gt;</span> <span class="k">raise</span> <span class="no">ArgumentError</span><span class="p">,</span> <span class="s2">"seed must be an integer or nil, got: </span><span class="si">#{</span><span class="n">inspect</span><span class="p">(</span><span class="n">other</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
      <span class="k">end</span>

    <span class="no">Axon</span><span class="o">.</span><span class="no">Initializers</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="ss">scale:</span> <span class="mf">1.0</span><span class="p">)</span><span class="o">.</span><span class="p">({</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">},</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">},</span> <span class="n">key</span><span class="p">)</span>
  <span class="k">end</span>
<span class="k">end</span>
</code></pre></div></div>

<p>This module returns a struct that includes vocabulary size, embedding dimension, seed, and a weight tensor.<br />
The <code class="language-plaintext highlighter-rouge">new</code> function requires vocabulary size and embedding dimension, while <code class="language-plaintext highlighter-rouge">weight</code> and <code class="language-plaintext highlighter-rouge">seed</code> are optional. If we provide <code class="language-plaintext highlighter-rouge">weight</code>, it must be an <code class="language-plaintext highlighter-rouge">Nx.Tensor</code> with shape <code class="language-plaintext highlighter-rouge">{vocab_size, embedding_dim}</code>. Then <code class="language-plaintext highlighter-rouge">Nx.as_type</code> casts the weight values to <code class="language-plaintext highlighter-rouge">float32</code>.<br />
If we do not provide <code class="language-plaintext highlighter-rouge">weight</code>, the private <code class="language-plaintext highlighter-rouge">init_weight</code> function is used. First, it calculates a random key from <code class="language-plaintext highlighter-rouge">seed</code>. If <code class="language-plaintext highlighter-rouge">seed</code> is not set, it uses a random positive integer. Why do we need a random key? Because we start with random <code class="language-plaintext highlighter-rouge">weight</code> values for token IDs. <code class="language-plaintext highlighter-rouge">LlmScratch.Random.manual_seed(int)</code> is based on <code class="language-plaintext highlighter-rouge">Nx.Random.key(int)</code>:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="no">Nx</span><span class="o">.</span><span class="no">Random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">123</span><span class="p">)</span>
<span class="c1">#Nx.Tensor&lt;</span>
  <span class="n">u32</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
  <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">123</span><span class="p">]</span>
<span class="o">&gt;</span>
<span class="n">iex</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span><span class="o">&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">123</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span>
<span class="p">{</span><span class="mi">2</span><span class="p">}</span>
<span class="n">iex</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span><span class="o">&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="no">Random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">123</span><span class="p">)</span> <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">rank</span>
<span class="mi">1</span>
</code></pre></div></div>

<p>So the random key is an <code class="language-plaintext highlighter-rouge">Nx.Tensor</code> (because we do math operations with other <code class="language-plaintext highlighter-rouge">Nx.Tensor</code>s), with shape <code class="language-plaintext highlighter-rouge">{2}</code> and rank 1, where values are 0 and <code class="language-plaintext highlighter-rouge">seed</code>.</p>
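<p>As a small aside (this snippet is not part of the module), the key is what makes initialization reproducible: the same key always produces the same "random" tensor, which is why a fixed seed gives identical embedding weights on every run:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
key = Nx.Random.key(123)

# Both calls receive the same key, so they return identical tensors.
{t1, _next_key} = Nx.Random.uniform(key, shape: {2, 3})
{t2, _next_key} = Nx.Random.uniform(key, shape: {2, 3})

Nx.equal(t1, t2) |&gt; Nx.all() |&gt; Nx.to_number()
# =&gt; 1
</code></pre></div></div>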

<p>The final call is <code class="language-plaintext highlighter-rouge">Axon.Initializers.uniform(scale: 1.0)</code>, which generates an initializer function that fills weights with random values from a uniform distribution. The <code class="language-plaintext highlighter-rouge">weight</code> shape is <code class="language-plaintext highlighter-rouge">{vocab_size, embedding_dim}</code>, values are cast to <code class="language-plaintext highlighter-rouge">float32</code>, and the random key controls initialization.<br />
<code class="language-plaintext highlighter-rouge">forward</code> maps token IDs to weights. Let’s look at a test to see how this works:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="n">test</span> <span class="s2">"Elixir-style EmbeddingNative with manual_seed (torch.nn.Embedding equivalent)"</span> <span class="k">do</span>
    <span class="n">vocab_size</span> <span class="o">=</span> <span class="mi">6</span>
    <span class="n">embedding_dim</span> <span class="o">=</span> <span class="mi">3</span>

    <span class="n">embedding</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">EmbeddingNative</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">,</span> <span class="ss">seed:</span> <span class="mi">123</span><span class="p">)</span>

    <span class="n">assert</span> <span class="n">embedding</span><span class="o">.</span><span class="n">vocab_size</span> <span class="o">==</span> <span class="n">vocab_size</span>
    <span class="n">assert</span> <span class="n">embedding</span><span class="o">.</span><span class="n">embedding_dim</span> <span class="o">==</span> <span class="n">embedding_dim</span>

    <span class="n">weight</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">EmbeddingNative</span><span class="o">.</span><span class="n">weight</span><span class="p">(</span><span class="n">embedding</span><span class="p">)</span>

    <span class="n">expected_weight</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span><span class="mf">0.3373701572418213</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.1777772158384323</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.16895616054534912</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">0.9177640080451965</span><span class="p">,</span> <span class="mf">1.5809690952301025</span><span class="p">,</span> <span class="mf">1.3010399341583252</span><span class="p">],</span>
          <span class="p">[</span><span class="mf">1.275301218032837</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.20095309615135193</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.16056379675865173</span><span class="p">],</span>
          <span class="p">[</span><span class="o">-</span><span class="mf">0.40148791670799255</span><span class="p">,</span> <span class="mf">0.966571569442749</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.1481444835662842</span><span class="p">],</span>
          <span class="p">[</span><span class="o">-</span><span class="mf">1.158868670463562</span><span class="p">,</span> <span class="mf">0.32547101378440857</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.6315054297447205</span><span class="p">],</span>
          <span class="p">[</span><span class="o">-</span><span class="mf">2.839993953704834</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.7848533391952515</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.4095723628997803</span><span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">weight</span><span class="p">,</span> <span class="n">expected_weight</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">),</span>
           <span class="s2">"Embedding weights should match PyTorch's weights exactly with seed=123"</span>

    <span class="n">row_3</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">slice</span><span class="p">(</span><span class="n">weight</span><span class="p">,</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">])</span>

    <span class="n">expected_row_3</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mf">0.40148791670799255</span><span class="p">,</span> <span class="mf">0.966571569442749</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.1481444835662842</span><span class="p">],</span> <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">})</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">row_3</span><span class="p">,</span> <span class="n">expected_row_3</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">)</span>

    <span class="n">input_ids</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">1</span><span class="p">]],</span> <span class="ss">type:</span> <span class="p">{</span><span class="ss">:s</span><span class="p">,</span> <span class="mi">64</span><span class="p">})</span>
    <span class="n">embeddings</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">EmbeddingNative</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">embedding</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">)</span>

    <span class="n">expected_forward</span> <span class="o">=</span>
      <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
        <span class="p">[</span>
          <span class="p">[</span>
            <span class="p">[</span><span class="mf">1.275301218032837</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.20095309615135193</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.16056379675865173</span><span class="p">],</span>
            <span class="p">[</span><span class="o">-</span><span class="mf">0.40148791670799255</span><span class="p">,</span> <span class="mf">0.966571569442749</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.1481444835662842</span><span class="p">],</span>
            <span class="p">[</span><span class="o">-</span><span class="mf">2.839993953704834</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.7848533391952515</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.4095723628997803</span><span class="p">],</span>
            <span class="p">[</span><span class="mf">0.9177640080451965</span><span class="p">,</span> <span class="mf">1.5809690952301025</span><span class="p">,</span> <span class="mf">1.3010399341583252</span><span class="p">]</span>
          <span class="p">]</span>
        <span class="p">],</span>
        <span class="ss">type:</span> <span class="p">{</span><span class="ss">:f</span><span class="p">,</span> <span class="mi">32</span><span class="p">}</span>
      <span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">expected_forward</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">),</span>
           <span class="s2">"Embeddings from forward pass should match expected values exactly"</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>In this test, for simplicity, we have a vocabulary of six tokens, with possible values <code class="language-plaintext highlighter-rouge">[0, 1, 2, 3, 4, 5]</code>. We create an embedding with shape <code class="language-plaintext highlighter-rouge">{6, 3}</code>, and <code class="language-plaintext highlighter-rouge">expected_weight</code> is an <code class="language-plaintext highlighter-rouge">Nx.tensor</code> with random values based on seed 123. We do not assert with <code class="language-plaintext highlighter-rouge">==</code>; we use <code class="language-plaintext highlighter-rouge">Nx.all_close</code> with a threshold of <code class="language-plaintext highlighter-rouge">1.0e-6</code> (<code class="language-plaintext highlighter-rouge">0.000001</code>), so values are considered equal if they differ only within that tolerance. The reason is that the weights are floating-point values, so bit-exact equality is too strict.<br />
 <code class="language-plaintext highlighter-rouge">Nx.slice</code> is a handy way to extract data from a weight <code class="language-plaintext highlighter-rouge">Nx.Tensor</code>. We extract from the fourth row and first column, taking one row and three columns.<br />
 Here is an example of using <code class="language-plaintext highlighter-rouge">forward</code>. We first create token IDs as an <code class="language-plaintext highlighter-rouge">Nx.Tensor</code> of type <code class="language-plaintext highlighter-rouge">{:s, 64}</code>, and then call <code class="language-plaintext highlighter-rouge">forward</code> on the embedding. The result is an <code class="language-plaintext highlighter-rouge">Nx.Tensor</code> that contains the 3rd, 4th, 6th, and 2nd rows from the <code class="language-plaintext highlighter-rouge">weight</code> tensor, because those are the ID values (zero-based indexing). Under the hood, <code class="language-plaintext highlighter-rouge">forward</code> uses <code class="language-plaintext highlighter-rouge">Nx.take</code>.</p>
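
<p>To make the last point concrete, here is a minimal sketch of the row-lookup idea behind <code class="language-plaintext highlighter-rouge">Nx.take</code>, using a toy weight tensor instead of the real embedding layer:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Toy {6, 3} weight tensor; in the test above it comes from the embedding layer.
weight = Nx.iota({6, 3}, type: {:f, 32})

# Token IDs select rows of the weight tensor, one row per ID.
input_ids = Nx.tensor([[2, 3, 5, 1]], type: {:s, 64})
embeddings = Nx.take(weight, input_ids)

# embeddings has shape {1, 4, 3}: rows 2, 3, 5, and 1 of the weight tensor.
</code></pre></div></div>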

<h3 id="encoding-token-positions">Encoding token positions</h3>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>token embeddings (same token -&gt; same vector)
                +
positional embeddings (position 0..n-1)
                |
                v
order-aware embeddings for the transformer
</code></pre></div></div>

<p>We are almost done. We have token embedding vectors, but one important ingredient is still missing: token position (order) in the input text. This means that for our previous token IDs, <code class="language-plaintext highlighter-rouge">[2,3,5,1]</code>, token ID <code class="language-plaintext highlighter-rouge">3</code> will always be mapped to the same row in the weight tensor, regardless of where it appears in the sequence. Here is a ChatGPT 5.3 explanation:</p>
<blockquote>
  <p><strong>Note</strong>:</p>
</blockquote>

<blockquote>
  <p>It’s a problem because language meaning depends on word order, not just which words appear.
If the embedding layer maps a token ID to the same vector everywhere, then:</p>

  <ul>
    <li>“dog bites man” and “man bites dog” start with the same set of token vectors (just reordered)</li>
    <li>The model has no built-in way to know which token came first, second, etc.</li>
    <li>So it loses syntax/order information (who did what to whom)</li>
  </ul>

  <p>In other words, token embeddings alone encode identity (“this is the word dog”), but not position (“this dog is at position 3”).</p>

  <p>That’s why transformers add positional encodings/positional embeddings: to combine</p>
  <ul>
    <li>token meaning</li>
    <li>token position<br />
so the model can distinguish sequences with the same tokens in different orders.</li>
  </ul>
</blockquote>

<p>We have absolute and relative positional embeddings. With absolute positional embeddings, we encode the position of each token. With relative positional embeddings, we encode <code class="language-plaintext highlighter-rouge">how far apart</code> tokens are from each other. OpenAI ChatGPT uses absolute positional embeddings. Let’s add positional embeddings in a more realistic example:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="n">test</span> <span class="s2">"positional embedding"</span> <span class="k">do</span>
    <span class="n">previous_backend</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">default_backend</span><span class="p">()</span>
    <span class="no">Nx</span><span class="o">.</span><span class="n">default_backend</span><span class="p">(</span><span class="no">EXLA</span><span class="o">.</span><span class="no">Backend</span><span class="p">)</span>
    <span class="n">on_exit</span><span class="p">(</span><span class="k">fn</span> <span class="o">-&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">default_backend</span><span class="p">(</span><span class="n">previous_backend</span><span class="p">)</span> <span class="k">end</span><span class="p">)</span>

    <span class="n">vocab_size</span> <span class="o">=</span> <span class="mi">50257</span>
    <span class="n">embedding_dim</span> <span class="o">=</span> <span class="mi">256</span>

    <span class="n">token_embeding_layer</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">EmbeddingNative</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">,</span> <span class="ss">seed:</span> <span class="mi">123</span><span class="p">)</span>

    <span class="p">{</span><span class="ss">:ok</span><span class="p">,</span> <span class="n">file_content</span><span class="p">}</span> <span class="o">=</span> <span class="no">File</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="s2">"the-verdict.txt"</span><span class="p">)</span>

    <span class="n">dataloader</span> <span class="o">=</span>
      <span class="no">LlmScratch</span><span class="o">.</span><span class="no">GptDatasetV1</span><span class="o">.</span><span class="n">create_dataloader_v1</span><span class="p">(</span>
        <span class="ss">raw_text:</span> <span class="n">file_content</span><span class="p">,</span>
        <span class="ss">batch_size:</span> <span class="mi">8</span><span class="p">,</span>
        <span class="ss">max_length:</span> <span class="mi">4</span><span class="p">,</span>
        <span class="ss">stride:</span> <span class="mi">4</span><span class="p">,</span>
        <span class="ss">shuffle:</span> <span class="no">false</span><span class="p">,</span>
        <span class="ss">drop_last:</span> <span class="no">false</span><span class="p">,</span>
        <span class="ss">num_workers:</span> <span class="mi">0</span>
      <span class="p">)</span>

    <span class="n">batch_0</span> <span class="o">=</span> <span class="n">dataloader</span><span class="o">.</span><span class="n">stream</span> <span class="o">|&gt;</span> <span class="no">Enum</span><span class="o">.</span><span class="n">at</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    <span class="n">inputs_list</span> <span class="o">=</span> <span class="no">Enum</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">batch_0</span><span class="p">,</span> <span class="k">fn</span> <span class="p">{</span><span class="n">input</span><span class="p">,</span> <span class="n">_target</span><span class="p">}</span> <span class="o">-&gt;</span> <span class="n">input</span> <span class="k">end</span><span class="p">)</span>

    <span class="c1"># Stack list of tensors into a single tensor: [tensor1, tensor2, ...] -&gt; tensor with shape [batch_size, seq_len]</span>
    <span class="n">inputs</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">inputs_list</span><span class="p">)</span>

    <span class="n">expected_inputs</span> <span class="o">=</span>
      <span class="p">[</span>
        <span class="p">[</span><span class="mi">40</span><span class="p">,</span> <span class="mi">367</span><span class="p">,</span> <span class="mi">2885</span><span class="p">,</span> <span class="mi">1464</span><span class="p">],</span>
        <span class="p">[</span><span class="mi">1807</span><span class="p">,</span> <span class="mi">3619</span><span class="p">,</span> <span class="mi">402</span><span class="p">,</span> <span class="mi">271</span><span class="p">],</span>
        <span class="p">[</span><span class="mi">10899</span><span class="p">,</span> <span class="mi">2138</span><span class="p">,</span> <span class="mi">257</span><span class="p">,</span> <span class="mi">7026</span><span class="p">],</span>
        <span class="p">[</span><span class="mi">15632</span><span class="p">,</span> <span class="mi">438</span><span class="p">,</span> <span class="mi">2016</span><span class="p">,</span> <span class="mi">257</span><span class="p">],</span>
        <span class="p">[</span><span class="mi">922</span><span class="p">,</span> <span class="mi">5891</span><span class="p">,</span> <span class="mi">1576</span><span class="p">,</span> <span class="mi">438</span><span class="p">],</span>
        <span class="p">[</span><span class="mi">568</span><span class="p">,</span> <span class="mi">340</span><span class="p">,</span> <span class="mi">373</span><span class="p">,</span> <span class="mi">645</span><span class="p">],</span>
        <span class="p">[</span><span class="mi">1049</span><span class="p">,</span> <span class="mi">5975</span><span class="p">,</span> <span class="mi">284</span><span class="p">,</span> <span class="mi">502</span><span class="p">],</span>
        <span class="p">[</span><span class="mi">284</span><span class="p">,</span> <span class="mi">3285</span><span class="p">,</span> <span class="mi">326</span><span class="p">,</span> <span class="mi">11</span><span class="p">]</span>
      <span class="p">]</span>
      <span class="o">|&gt;</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="ss">type:</span> <span class="p">{</span><span class="ss">:s</span><span class="p">,</span> <span class="mi">32</span><span class="p">})</span>

    <span class="c1"># Verify inputs match expected output exactly</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">all_close</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">expected_inputs</span><span class="p">,</span> <span class="ss">atol:</span> <span class="mf">1.0e-6</span><span class="p">),</span>
           <span class="s2">"Input embeddings should match expected values exactly"</span>

    <span class="c1"># Assert the size/shape of inputs is [8, 4]</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">8</span><span class="p">,</span> <span class="mi">4</span><span class="p">}</span>

    <span class="c1"># Get token embeddings: shape [8, 4, 256]</span>
    <span class="n">token_embeddings</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">EmbeddingNative</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">token_embeding_layer</span><span class="p">,</span> <span class="n">inputs</span><span class="p">)</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">token_embeddings</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">8</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">256</span><span class="p">}</span>

    <span class="c1"># Create positional embedding layer: vocab_size=4 (positions 0,1,2,3), embedding_dim=256</span>
    <span class="n">positional_embedding_layer</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">EmbeddingNative</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="ss">seed:</span> <span class="mi">123</span><span class="p">)</span>
    <span class="n">positional_embedding_weights</span> <span class="o">=</span> <span class="no">LlmScratch</span><span class="o">.</span><span class="no">EmbeddingNative</span><span class="o">.</span><span class="n">weight</span><span class="p">(</span><span class="n">positional_embedding_layer</span><span class="p">)</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">positional_embedding_weights</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">4</span><span class="p">,</span> <span class="mi">256</span><span class="p">}</span>

    <span class="c1"># Create positional indices: [0, 1, 2, 3] for each position in the sequence</span>
    <span class="c1"># Shape: [4] -&gt; expand to [1, 4] -&gt; broadcast to [8, 4]</span>
    <span class="n">positional_indices</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span> <span class="ss">type:</span> <span class="p">{</span><span class="ss">:s</span><span class="p">,</span> <span class="mi">64</span><span class="p">})</span>
    <span class="n">positional_indices_batch</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="no">Nx</span><span class="o">.</span><span class="n">new_axis</span><span class="p">(</span><span class="n">positional_indices</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="p">{</span><span class="mi">8</span><span class="p">,</span> <span class="mi">4</span><span class="p">})</span>

    <span class="c1"># Get positional embeddings: shape [8, 4, 256]</span>
    <span class="n">positional_embeddings</span> <span class="o">=</span>
      <span class="no">LlmScratch</span><span class="o">.</span><span class="no">EmbeddingNative</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">positional_embedding_layer</span><span class="p">,</span> <span class="n">positional_indices_batch</span><span class="p">)</span>

    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">positional_embeddings</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">8</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">256</span><span class="p">}</span>

    <span class="c1"># Add token embeddings and positional embeddings: shape [8, 4, 256]</span>
    <span class="n">embeddings_sum</span> <span class="o">=</span> <span class="no">Nx</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">token_embeddings</span><span class="p">,</span> <span class="n">positional_embeddings</span><span class="p">)</span>
    <span class="n">assert</span> <span class="no">Nx</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">embeddings_sum</span><span class="p">)</span> <span class="o">==</span> <span class="p">{</span><span class="mi">8</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">256</span><span class="p">}</span>
  <span class="k">end</span>
</code></pre></div></div>

<p>We first set the EXLA backend because Elixir is slow for this type of computation without it. On my MacBook Pro M3, this test takes about 71 seconds without EXLA, but under a second with EXLA. Then we use the full vocabulary size of the BPE tokenizer and an embedding dimension of 256. GPT-3 uses 12,288 dimensions. We load <em>The Verdict</em> story with a batch size of 8, a maximum length of 4, and a stride of 4. Then we extract the inputs from the first batch (<code class="language-plaintext highlighter-rouge">batch_0</code>). To get token embeddings, we pass the inputs through the embedding layer, which gives us a tensor with shape <code class="language-plaintext highlighter-rouge">{8,4,256}</code>.<br />
  Next, we create absolute positional embeddings. The positional “vocabulary” size is 4 (one entry per position in the sequence), and the number of dimensions is the same: 256. The positional inputs are zero-based indices, <code class="language-plaintext highlighter-rouge">[0,1,2,3]</code>. <code class="language-plaintext highlighter-rouge">Nx.broadcast</code> is a handy way to turn these positional indices into an <code class="language-plaintext highlighter-rouge">Nx.Tensor</code> with shape <code class="language-plaintext highlighter-rouge">{8, 4}</code>.<br />
  The last thing to do is add these two tensors with <code class="language-plaintext highlighter-rouge">Nx.add</code>. Now we are ready to build the attention mechanism!</p>
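
<p>Here is a small, self-contained sketch (values taken from the test above, not the real dataloader) of how those positional indices can be built and broadcast:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Positions 0..3, one per slot in the sequence.
positions = Nx.iota({4}, type: {:s, 64})

# Add a batch axis and broadcast to the batch shape {8, 4};
# every row of the result is [0, 1, 2, 3].
positional_indices_batch = Nx.broadcast(Nx.new_axis(positions, 0), {8, 4})

# The same positional rows are then looked up for every sequence in the batch,
# so the position information is identical across the batch.
</code></pre></div></div>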

<h3 id="conclusion">Conclusion</h3>

<p>We finally got our hands dirty with Elixir code, and we survived because we were able to do everything in Elixir instead of Python. EXLA is supported by Nx out of the box, which helps with heavy computations, because this is one area where Elixir does not shine. I reread this post, and now I understand the codebase and the theory better. Next, I want to reread Giles’s blog posts about Chapter 2 (<a href="https://www.gilesthomas.com/2024/12/llm-from-scratch-2">part 2</a>, <a href="https://www.gilesthomas.com/2024/12/llm-from-scratch-3">part 3</a>) to fill possible gaps in my knowledge and answer the Chapter 2 questions from the workbook.</p>]]></content><author><name></name></author><category term="llm-from-scratch" /><category term="ai" /><category term="tutorials" /><category term="llm" /><category term="nlp" /><category term="text-prep" /><summary type="html"><![CDATA[An Elixir/Nx walkthrough of preparing text for LLM training: tokenization, token IDs, BPE, sliding windows, token embeddings, and positional embeddings.]]></summary></entry><entry><title type="html">Foundation of Taking Testing Seriously</title><link href="https://karlosmid.com/2026/01/foundation-taking-testing-seriously/" rel="alternate" type="text/html" title="Foundation of Taking Testing Seriously" /><published>2026-01-23T00:00:00+00:00</published><updated>2026-01-23T00:00:00+00:00</updated><id>https://karlosmid.com/2026/01/foundation-taking-testing-seriously</id><content type="html" xml:base="https://karlosmid.com/2026/01/foundation-taking-testing-seriously/"><![CDATA[<p><img src="/assets/foundations.jpg" alt="Cover of Isaac Asimov’s Foundation (1951)" /></p>

<h3 id="tldr">TL;DR</h3>

<p>In the <a href="/2025/12/taking-testing-seriously-chapter-1/">previous post on Taking Testing Seriously Chapter 1</a>, the authors explained why we need another book about software testing. The reason is very simple: a book like that has never been written about software testing. It sets the tone and philosophy so you can decide if this approach is for you or not. Let’s move on to Chapter 2, The Foundation. Again, I will move through my <a href="https://botie.app">Botie.app</a> notes to deepen my understanding of the foundations.</p>

<h3 id="el3e">el3e</h3>

<p>The authors stress again that <em>Testing is a human, social, and technical activity</em>. To be honest, we all start with the technical activity, forget the other two, and because of that miss a lot of bugs. It is human because most of our colleagues are human (I also have two LLM colleagues), and social because humans are social beings.</p>

<p>As it should, the foundation states the definition of software testing. It is the el3e mnemonic:</p>

<p><em>Software testing is the process of evaluating the product by learning about it by experimenting, experiencing and exploring with it.</em></p>

<p>Isn’t this more fun to do than to write browser automation scripts!?</p>

<h3 id="the-evidence">The evidence</h3>

<p>We are gathering evidence about the product under test. To determine if this evidence matters to the client, we need to apply the CRC pattern. CRC is my addition, not James and Michael’s. CRC comes from Elixir; it was developed and popularized by Bruce Tate, a well-known Elixir instructor. If you’re curious to learn more about Bruce Tate’s teachings or Elixir in general, check out <a href="https://grox.io">Groxio</a>, his training site. CRC stands for Construct, Reduce, Convert. We first need to construct the evidence, telling a story about what we found in the product. Then we need to reduce it to the bone so it contains only the essential information we can sell to the client. Convert is the final rewrite, in terms of the client’s business domain (this risk could be realized).</p>

<h3 id="learning-opportunity">Learning opportunity</h3>

<p>The authors state that each product failure is a learning opportunity for a software tester. I would add it is the same for the developer. If you do not have the urge to open champagne when you find a product issue, then maybe software testing is not the right career choice. Here is my example. I am developing a feature that integrates a client product with another product. During the connect phase, we get, via REST API, a number of objects from the other product. My integration handles hundreds of objects, while others handle tens. This integration happens in the background, and it takes time—I can see hundreds of requests in the logs. I kept that in the back of my mind. The application was working, but it strangely changed data in the UI after the logs were done. QA reported one issue that I could not replicate at first because I was waiting for the load to be done. QA did not know about those loads and used the UI immediately. Keeping in mind the loading time for the logs, I managed to replicate and fix the async issue. I changed the order of the data load so we first load the data that is immediately used in the UI. I have never dismissed those REST API requests in the logs since they finally fit the application model.</p>

<h3 id="checking-vs-testing">Checking vs testing</h3>

<p>I have witnessed a lot of bad blood spilled in conversations about this topic. Here is what James and Michael say about it. In programming, you need to understand both the business and machine sides of the story. Memory management, networking, and disk are parts of the machine story. In the end, you are programming a machine. Programmers perform checking, where checking is the mechanistic (automated) process of verifying propositions about the product. Those propositions can only be true or false; there is never a third option. And now the important thing about quality.</p>

<p><em>Quality is an opinion, not a fact; quality cannot be verified.</em></p>

<p>If quality is an opinion, how could we assure or assist quality (QA)? We cannot quantify quality, because it is <em>value to a person that matters (Weinberg)</em>. Testing cannot be automated, but checking can. Testing is an <em>encounter with a product, an activity, not an artifact</em>. Testing is not an idea. It is everything a tester does as part of their role to fulfill the mission of testing.</p>

<h3 id="testing">Testing</h3>

<p>Deep testing uncovers bugs that are challenging to detect and issues that truly matter. Preparing for deep testing requires more time and effort, making it a more demanding activity, while shallow testing typically exposes the obvious bugs. Oracles help us decide whether something is actually a bug or not, and in deep testing, we have to be especially thoughtful when designing these oracles.</p>

<p>You probably recognize this pattern: we often start with shallow testing and then proceed to deep testing. A bug that matters is one that ties directly to significant business risk.</p>

<p>For testers, the product is often a mystery—we are searching for problems, not simply confirming that things work. Our real aim is to uncover business risks.</p>

<p>Testers use heuristics, not rigid algorithms. Clear terminology is crucial because abbreviations can cause confusion. Testing is an event: testers work on behalf of a client and should do their best to simulate, anticipate, and inhabit the client’s perspective. In <a href="https://www.moderntesting.org/">Modern Testing</a>, Alan Page states that only the customer can truly evaluate the quality of a product. This isn’t a contradiction: our job is to prepare the product for real customers, ensuring that when it reaches them for the final test, it will meet or exceed their expectations. Some argue that it is impossible to test exactly like a client because clients are driven by business success and will use the application in ways testers might never imagine. So what counts as fake testing? Here are a few signals:</p>

<ul>
  <li>creating test cases before seeing the product</li>
  <li>blindly following a script and dismissing all side findings</li>
  <li>reporting only pass/fail rates</li>
  <li>constraining testing to whatever the test tool supports</li>
  <li>using a process we do not understand</li>
  <li>pushing reported bugs so aggressively that it looks like we are angling to take over as project lead</li>
</ul>

<p>We need to be humble about our testing expertise (I still consider myself an evolving software developer and tester who can learn new things about software testing from anybody).</p>

<p><em>I have expertise, instead of I am expert.</em></p>

<h3 id="the-model">The Model</h3>

<p>So how do we test? We create a model of a product. Two different testers would have two different models of the same product. We try to answer the question: how good is the product? A flowchart diagram or product requirements are examples of a product model. How do we create those models? We do black/white/gray box testing. We explore the product using one of those techniques, and based on the findings (evidence), we create a better product model. And there is a famous testing fallacy that we do not need to do integration testing because we have unit testing. Do not fall for that—black/white/gray box testing can uncover a lot of interesting product evidence. If you run out of testing ideas while you try to simulate the product user, page 35 has a list of sixteen types of product user.</p>

<h3 id="jerry-weinberg">Jerry Weinberg</h3>

<p><em>Quality is value to some person.</em></p>

<p>Bragging time. I had the pleasure of meeting Jerry in the flesh at the STP testing conference in 2010, where he signed my copy of <em>Perfect Software and Other Illusions About Testing</em>. Michael and James’s work is deeply inspired by Jerry’s work. When you ask them what book they recommend on software testing, the answer is “Anything from Jerry Weinberg.” Remember, quality is value, and as we have already stated, value cannot be quantified into numbers, and because of that it cannot be measured (verified). So we gather product evidence, report it, and then the user decides whether it is valuable.</p>

<p>Then we have the <em>ludic fallacy</em>. This is when product quality is defined as a set of predefined tests that must pass. If those pass, we claim the product has quality. Or there could be a set of requirements. And then we stop listening to the users of the product.</p>

<h3 id="risk">Risk</h3>

<p>Risk is the danger of the product doing something bad or failing to do something good that it is supposed to do. Risk potential directs our testing because we cannot test all product variations. Risk could be unknown, identified, or tested. Unknown risks are the hardest—we interact with the product and then must figure out which evidence could be a risk. For an identified risk we have direction on how to test. “Tested” is when we have evidence that the risk could not be fulfilled.</p>

<h3 id="context">Context</h3>

<p>It is common that for any question you ask a context-driven tester, the answer is, “It depends on the context.” There are three principles of context-driven testing:</p>

<ul>
  <li>the value of any practice depends on its context</li>
  <li>there are good practices in the context, but there are no best practices</li>
  <li>people who work together are the most important part of any project’s context</li>
</ul>

<p>Simple as that. And yes, project context changes over time—especially in startups when pivoting is a normal modus operandi. Every tester must have a testing mission. What is the purpose of our testing? That mission must be agreed upon with the client. Mixing other roles into testing is not good; the reason is simple: testing is a full-time job. But we can ask for testing help from other teammates. In that case, we must take care that testing is convergent, not divergent.</p>

<h3 id="the-knowledge">The knowledge</h3>

<p>We have tacit and explicit knowledge. Tacit knowledge is for solving problems, and explicit knowledge is when we can explain those problems. Exploratory testing is moving through a space and learning about it. We use the product, and we learn about it by making choices when we encounter the unknown. <em>If others make choices for us, then they are explorers</em>; we are just a vessel. Opponents of exploratory testing say that there is no formality in it, but they are wrong. Exploratory testing moves from freedom to formality as we discover more product patterns. We cannot get something right the first time; this is why we iterate. At the beginning of the project we have confusion, and this is expected. We must know how to frame our testing. We must explain how our testing is related to business needs. In general systems thinking (Jerry Weinberg), the trick is to know what and when to ignore findings about the product. How do testing and strategy relate to each other? Testing produces a story of product status, and strategy is the story of how we are going to achieve that testing.</p>

<h3 id="the-boundary-bug">The boundary bug</h3>

<p>I found one instance of a boundary bug: the last footnote only has the number, but the footnote text is missing.</p>

<p><img src="/assets/boundary_bug.jpeg" alt="Boundary bug" /></p>

<h3 id="the-heuristic">The heuristic</h3>

<p><em>A heuristic is a fallible means of solving a problem.</em> For example, my sink has a 9 cm drain hole. In order to clean the sink trap, I need a big screwdriver. My heuristic was that my biggest screwdriver would be enough. That turned out not to be true; my heuristic failed.</p>

<p>Solving a problem can involve many things—for example, understanding the situation, discovering the problem, or making a decision. Our job is to craft heuristics that are less likely to fail. Now we can also explain best practices vs. good practices. A “best practice” is when some authority claims, without proof, that a particular practice should be used to solve a problem. Heuristics tell a different story: here is a practice, here is how it works, why it works, and how it may not work—so be careful. You need skill to apply heuristics, and you need practice to build that skill. Working with heuristics is how you develop judgement in your context.</p>

<h3 id="the-soul-of-rst">The soul of rst</h3>

<p>rst teaches us about “enoughness”—how to recognize that we can stop testing. Testers naturally have distance from the product (they did not create it), which often makes them better at spotting issues than developers. We should focus when we feel confused and defocus when we feel frustrated. During focused testing, change one variable at a time; during defocused testing, change many. Focus reduces chaos and confusion, while defocus helps us uncover new ideas.</p>

<p>For rst, experience is what happens to the tester during the testing activity. The challenge is capturing that experience, which is different from merely collecting screencasts or logs. Sensemaking is a social activity where we interpret data, but unconscious bias can cause us to miss bugs. Testing is an experiment—we seek surprises in the product. If we already know what will happen, it’s just a demonstration. We aim for product knowledge first, then surprises. Because any metric can be gamed, assessment relies on thoughtful, open-ended questions about quality. That’s how rst evaluates testing, not by counting bugs.</p>

<h3 id="closing">Closing</h3>

<p>Just as Asimov’s <em>Foundation</em> collected the Empire’s knowledge, this chapter gathers the essential rst building blocks. And guess what—we start testing in the next chapters.</p>]]></content><author><name></name></author><category term="rapid software testing" /><category term="book-notes" /><category term="testing" /><category term="rst" /><summary type="html"><![CDATA[My foundation plan for studying Taking Testing Seriously, the foundations, so the lessons from James Bach and Michael Bolton become habits instead of inspirational quotes.]]></summary></entry><entry><title type="html">Chapter 1 recap of Build LLM from Scratch</title><link href="https://karlosmid.com/2026/01/chapter-1-recap-build-llm-from-scratch/" rel="alternate" type="text/html" title="Chapter 1 recap of Build LLM from Scratch" /><published>2026-01-20T00:00:00+00:00</published><updated>2026-01-20T00:00:00+00:00</updated><id>https://karlosmid.com/2026/01/chapter-1-recap-build-llm-from-scratch</id><content type="html" xml:base="https://karlosmid.com/2026/01/chapter-1-recap-build-llm-from-scratch/"><![CDATA[<h2 id="tldr">TL;DR</h2>

<p>I reviewed every Chapter 1 study note I published and reread Giles Thomas’s first “LLM from Scratch” blog to cement the fundamentals, capturing several new takeaways from the chapter.</p>

<h2 id="integration-test-of-llm-on-my-blog-posts">Integration test of LLM on my blog posts</h2>

<p>Recently I switched from WordPress to Jekyll on GitHub Pages, which gave me a more efficient workflow. I now draft posts in Cursor, use Cursor or Codex agents, and no longer copy and paste to the ChatGPT UI. When I finished my recent LLM series, I asked an agent to check and fix the grammar. I knew that several answers to the Chapter 1 study questions were incomplete or wrong. Because I trusted the LLM, I never reviewed its edits. Since it predicts the next token based on patterns, it “fixed” my incorrect answers instead of just polishing the grammar. So I lost my history of learning about LLMs. My prompt never stated “fix grammar only,” so the model tried to improve the content. I attempted a review using git diff, but with Markdown every paragraph sits on a single line, which made tracking changes hard. From now on I will proofread the rendered page in my local Jekyll instance.</p>

<p>That experience reinforced Chapter 1’s explanation of how an LLM works: it predicts the next token based on the weights it has learned.</p>

<h2 id="what-i-learned-from-giles-thomass-blog-post">What I learned from Giles Thomas’s blog post</h2>

<p>The first Transformer architecture had both an encoder and a decoder. I had missed that the encoder is also a neural network with its own attention blocks: the encoder produces contextual token representations, and the decoder consumes them to generate output tokens. Today we mostly see decoder-only architectures, yet encoder-only models still exist. Google’s BERT is a well-known example that handles classification and masked-word prediction but not token generation. Giles also highlighted that traditional machine learning includes algorithms that auto-label data from their inputs—Naive Bayes is one such case, so “automatic labeling” is not unique to LLMs.</p>

<h2 id="conclusion">Conclusion</h2>

<p>An important part of learning about LLMs is to use them. Applying one to my own posts showed me how it behaves on real content and why I must review every change. Rereading Giles’s blog post was also the right move; it reinforced my Chapter 1 knowledge and clarified the points I previously glossed over.</p>]]></content><author><name></name></author><category term="llm" /><category term="llm" /><category term="study-notes" /><summary type="html"><![CDATA[Re-read my Chapter 1 study posts and Giles Thomas’s companion article to reinforce Sebastian Raschka’s Build LLMs from Scratch takeaways and capture the new insights I missed the first time.]]></summary></entry><entry><title type="html">Building a large language model</title><link href="https://karlosmid.com/2026/01/building-a-large-language-model/" rel="alternate" type="text/html" title="Building a large language model" /><published>2026-01-12T00:00:00+00:00</published><updated>2026-01-12T00:00:00+00:00</updated><id>https://karlosmid.com/2026/01/building-a-large-language-model</id><content type="html" xml:base="https://karlosmid.com/2026/01/building-a-large-language-model/"><![CDATA[<h2 id="tldr">TL;DR</h2>

<p>This is the sixth post in my LLM learning series. In the <a href="/2026/01/a-closer-look-at-the-gpt-architecture/">previous post</a>, we explored the GPT architecture in detail. Today, we dive into the process of building a large language model.</p>

<h2 id="what-are-the-three-main-stages-involved-in-building-a-large-language-model-from-scratch">What are the three main stages involved in building a large language model from scratch?</h2>

<p>The first stage is data preparation. Various data sources like books, web forums, and blog posts must be scraped into files and prepared for LLM training. The second stage is LLM pretraining, where we create a base model that understands the training language’s grammar, patterns, and nuances.
The final stage is fine-tuning the base model on a smaller dataset so the LLM can perform specific tasks with more accuracy.</p>

<p><em>The three stages are implementing the LLM architecture and data
preparation process, pretraining an LLM to create a foundation model,
and fine-tuning the foundation model for specific tasks.</em></p>

<p>I cheated a little bit on this question. I only knew two stages, so I reached for my botie.app notes and remembered that the first stage is data preparation.</p>

<h2 id="what-is-the-key-idea-behind-the-transformer-architecture-used-in-llms">What is the key idea behind the transformer architecture used in LLMs?</h2>

<p>The key idea is auto-labeling of input data during the model pretraining stage.</p>

<p><em>The transformer architecture utilizes an attention mechanism that allows
the LLM to selectively access the entire input sequence when generating
output, word by word.</em></p>

<p>This one was a total miss :)</p>

<h2 id="what-is-the-primary-task-used-for-pretraining-llms-like-gpt-3">What is the primary task used for pretraining LLMs like GPT-3?</h2>

<p>The primary task in the pretraining stage is to make the LLM understand language grammar, patterns, and nuances.</p>

<p><em>LLMs like GPT-3 are pretrained on a massive corpus of text by predicting
the next word in a sentence, using this prediction as a label.</em></p>

<p>Another failure.</p>

<h2 id="explain-the-concept-of-emergent-properties-in-llms">Explain the concept of emergent properties in LLMs.</h2>

<p>The emergent properties concept is when an LLM can solve tasks that we did not expect it to solve, like language translation.</p>

<p><em>While the primary pretraining task for GPT-like models is next-word
prediction, they exhibit emergent properties, meaning they can perform
tasks like classification, translation, and summarization without explicit
training for those tasks</em></p>

<p>I forgot to mention summarization and classification as examples of emergent tasks.</p>

<h2 id="why-is-fine-tuning-a-pretrained-llm-beneficial-for-specific-tasks">Why is fine-tuning a pretrained LLM beneficial for specific tasks?</h2>

<p>It is beneficial because the LLM can then better solve specific tasks in the domain of the fine-tuned dataset. An example is answering questions about specific medical problems.</p>

<p><em>Fine-tuning a pretrained LLM on a custom dataset allows it to specialize in
specific tasks and potentially outperform general LLMs on those tasks.</em></p>

<p>The key point I forgot is that it can outperform the general (base) LLM in specific tasks.</p>

<h2 id="put-these-stages-of-creating-a-pretrained-llm-base-model-in-order">Put these stages of creating a pretrained LLM (base model) in order:</h2>

<p>A. Evaluate the model’s performance on text generation tasks.<br />
B. Implement the transformer decoder architecture (GPT-like).<br />
C. Prepare the text data by cleaning and tokenizing it.<br />
D. Train the model using a next-word prediction task on a large text dataset.</p>

<table>
  <thead>
    <tr>
      <th>Step</th>
      <th>Description</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>C</td>
      <td>Prepare the text data by cleaning and tokenizing it.</td>
    </tr>
    <tr>
      <td>B</td>
      <td>Implement the transformer decoder architecture (GPT-like).</td>
    </tr>
    <tr>
      <td>D</td>
      <td>Train the model using a next-word prediction task on a large text dataset.</td>
    </tr>
    <tr>
      <td>A</td>
      <td>Evaluate the model’s performance on text generation tasks.</td>
    </tr>
  </tbody>
</table>

<p>It seems that I am good at table questions; this one was correct.</p>

<h2 id="match-the-term-to-its-description-on-the-right">Match the term to its description on the right:</h2>

<table>
  <thead>
    <tr>
      <th>Term</th>
      <th>Description</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Autoregressive Model</td>
      <td>The task of predicting the next word in a sequence, which is used to train GPT models.</td>
    </tr>
    <tr>
      <td>Self-Supervised Learning</td>
      <td>A type of machine learning where the model learns from the data itself, without requiring explicit labels.</td>
    </tr>
    <tr>
      <td>Next-Word Prediction</td>
      <td>A type of model that generates text by predicting the next word in a sequence based on the words that have already been generated.</td>
    </tr>
    <tr>
      <td>Decoder-Only Architecture</td>
      <td>The architecture of GPT models, which uses only the decoder portion of the transformer architecture, making it suitable for text generation.</td>
    </tr>
  </tbody>
</table>

<table>
  <thead>
    <tr>
      <th>Term</th>
      <th>Row</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Autoregressive Model</td>
      <td>Row 3</td>
    </tr>
    <tr>
      <td>Self-Supervised Learning</td>
      <td>Row 2</td>
    </tr>
    <tr>
      <td>Next-Word Prediction</td>
      <td>Row 1</td>
    </tr>
    <tr>
      <td>Decoder-Only Architecture</td>
      <td>Row 4</td>
    </tr>
  </tbody>
</table>

<h2 id="conclusion">Conclusion</h2>

<p>This section was also hard for me; mostly I could not remember the facts that I read one month ago. To recap, I will read chapter 1 again, read all my blog posts for chapter 1, and read <a href="https://www.gilesthomas.com/2024/12/llm-from-scratch-1">Giles’s blog part 1</a>.</p>]]></content><author><name></name></author><category term="llm" /><category term="llm" /><category term="ai" /><category term="machine-learning" /><summary type="html"><![CDATA[Exploring the three main stages of building an LLM — data preparation, pretraining, and fine-tuning — along with key concepts like transformer architecture, emergent properties, and self-supervised learning.]]></summary></entry><entry><title type="html">A Closer Look at the GPT Architecture</title><link href="https://karlosmid.com/2026/01/a-closer-look-at-the-gpt-architecture/" rel="alternate" type="text/html" title="A Closer Look at the GPT Architecture" /><published>2026-01-05T00:00:00+00:00</published><updated>2026-01-05T00:00:00+00:00</updated><id>https://karlosmid.com/2026/01/a-closer-look-at-the-gpt-architecture</id><content type="html" xml:base="https://karlosmid.com/2026/01/a-closer-look-at-the-gpt-architecture/"><![CDATA[<h2 id="tldr">TL;DR</h2>

<p>In the <a href="/2026/01/utilizing-large-datasets/">previous post</a>, I discussed utilizing large datasets. Today, I am continuing with the study questions from <em>Build LLM from Scratch</em> by Sebastian Raschka, and the topic is a closer look at the GPT architecture.</p>

<h2 id="what-is-the-primary-task-that-gpt-models-are-trained-on-and-how-does-this-relate-to-their-ability-to-perform-other-tasks-like-translation">What is the primary task that GPT models are trained on, and how does this relate to their ability to perform other tasks like translation?</h2>

<p>The primary task that GPT models are trained on is to understand language nuances, context, and patterns.</p>

<p><em>GPT models are primarily trained on a next-word prediction task, which
involves predicting the next word in a sequence. This seemingly simple
task allows the models to learn the relationships between words and
phrases, enabling them to perform other tasks like translation, even
though they were not explicitly trained for it.</em></p>

<p>I totally failed on this question. I did not know what the primary task of a GPT model is. Ouch!</p>

<h2 id="explain-the-concept-of-self-supervised-learning-in-the-context-of-gpt-models">Explain the concept of self-supervised learning in the context of GPT models.</h2>

<p>The concept of self-supervised learning is where we send input to the decoder module gradually, first sending the first token, then we send the first and predicted token back to the decoder module, and we repeat that until there is no predicted token.</p>

<p><em>GPT models utilize self-supervised learning, where the model learns from
the data itself without requiring explicit labels. In the case of GPT, the
next word in a sentence serves as the label for the model to predict,
allowing for training on massive unlabeled text datasets.</em></p>

<p>Again, a total miss. Ok, I knew that the next word is important in the concept of self-supervised learning. So the number of labels in one sentence is equal to the number of words (tokens) in the sentence. And what comes before that label also matters.</p>
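
<p>To make the “next word as label” idea concrete, here is a tiny sketch (the sentence and variable names are made up for illustration) of how one sentence yields many training pairs:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tokens = ["the", "cat", "sat", "on", "the", "mat"]

# Every prefix of the sentence becomes a training sample,
# and the token that follows it becomes the label.
pairs =
  for n &lt;- 1..(length(tokens) - 1) do
    {Enum.take(tokens, n), Enum.at(tokens, n)}
  end

# {["the"], "cat"}
# {["the", "cat"], "sat"}
# {["the", "cat", "sat"], "on"} ... and so on.
</code></pre></div></div>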

<h2 id="how-does-the-gpt-architecture-differ-from-the-original-transformer-architecture-and-what-are-the-implications-of-this-difference">How does the GPT architecture differ from the original transformer architecture, and what are the implications of this difference?</h2>

<p>The original transformer architecture had an encoder and decoder, and the LLM transformer only has a decoder, where encoding/decoding happens.</p>

<p><em>The GPT architecture uses only the decoder portion of the transformer,
making it a decoder-only model. This design makes it suitable for text
generation and next-word prediction tasks, as it generates text one word
at a time in a unidirectional, left-to-right manner.</em></p>

<p>This is what I do not understand: who does the encoder job in an LLM where we only have a decoder? So I asked Cursor:</p>

<p><em>In transformer architecture for llm, there is only decoder. Who then does encoder work from old transformer architecture</em></p>

<p>The answer:</p>

<h2 id="original-transformer-2017--encoder-decoder">Original Transformer (2017) — Encoder-Decoder</h2>

<p><strong>Encoder’s job</strong>: Read the entire input and create a rich representation (understanding)</p>
<ul>
  <li>Example: Read “The cat sat on the mat” and understand its meaning</li>
</ul>

<p><strong>Decoder’s job</strong>: Generate output using the encoder’s representation</p>
<ul>
  <li>Example: Generate translation “Le chat s’est assis sur le tapis”</li>
</ul>

<p>The encoder could see ALL tokens at once (bidirectional attention).</p>

<h2 id="gpts-decoder-only-architecture">GPT’s Decoder-Only Architecture</h2>

<p>In GPT, the <strong>self-attention mechanism does both jobs</strong>.</p>

<h3 id="how-self-attention-encodes">How Self-Attention “Encodes”</h3>

<p>When processing “The cat sat on the”:</p>
<ol>
  <li>Each token attends to all <strong>previous</strong> tokens</li>
  <li>“on” can look at “The”, “cat”, “sat” to understand context</li>
  <li>“the” can look at “The”, “cat”, “sat”, “on” to understand context</li>
</ol>

<p>This builds up a contextual representation — similar to what an encoder does!</p>

<h3 id="key-difference-causal-masking">Key Difference: Causal Masking</h3>

<table>
  <thead>
    <tr>
      <th>Architecture</th>
      <th>What each token can see</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Encoder (BERT)</td>
      <td>All tokens (past AND future)</td>
    </tr>
    <tr>
      <td>Decoder (GPT)</td>
      <td>Only past tokens (causal mask)</td>
    </tr>
  </tbody>
</table>

<p>BERT (encoder): “sat” sees → The, cat, <strong>sat</strong>, on, the, mat</p>

<p>GPT (decoder): “sat” sees → The, cat, <strong>sat</strong> ← can’t see future</p>

<h3 id="why-it-works">Why It Works</h3>

<p>The decoder’s self-attention layers progressively build understanding through multiple layers:</p>
<ul>
  <li><strong>Layer 1</strong>: Basic word relationships</li>
  <li><strong>Layer 2-N</strong>: Deeper semantic understanding</li>
  <li><strong>Final layers</strong>: Rich contextual representation</li>
</ul>

<p>So the “encoding” happens <strong>implicitly</strong> within the decoder’s self-attention layers, just restricted to only seeing past context.</p>

<p>So we reached the layers concept, but more about that in future posts.</p>
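
<p>The causal mask itself is easy to picture. Here is a minimal Nx sketch (my own illustration, not from the book) of a 4x4 mask in which position <em>i</em> may attend only to positions up to <em>i</em>:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>seq_len = 4
rows = Nx.iota({seq_len, 1})
cols = Nx.iota({1, seq_len})

# 1 where a token may attend (row index &gt;= column index), 0 otherwise:
# [[1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [1, 1, 1, 1]]
causal_mask = Nx.greater_equal(rows, cols)
</code></pre></div></div>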

<h2 id="what-is-happening-at-the-labels-1-and-2-in-the-diagram">What is happening at the labels 1 and 2 in the diagram?</h2>

<p><img src="/assets/decoder_stages.png" alt="Decoder Stages" /></p>

<table>
  <thead>
    <tr>
      <th>Label</th>
      <th>Description</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>1</td>
      <td>“is” token is the predicted token for “This”</td>
    </tr>
    <tr>
      <td>2</td>
      <td>“This is” is the output of iteration 1 and is the input for iteration 2</td>
    </tr>
  </tbody>
</table>

<p>The answer:</p>

<table>
  <thead>
    <tr>
      <th>Label</th>
      <th>Description</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>1</td>
      <td>The next word is created based on the input text</td>
    </tr>
    <tr>
      <td>2</td>
      <td>The output of the previous round becomes the input to the next round</td>
    </tr>
  </tbody>
</table>

<p>Finally, one correct answer from my side!</p>

<h2 id="what-is-the-significance-of-gpt-models-being-considered-autoregressive-models">What is the significance of GPT models being considered autoregressive models?</h2>

<p>The significance of autoregressive models is that as we move through iterations, the model gets better at predicting the next token.</p>

<p><em>Autoregressive models, like GPT, incorporate their previous outputs as
inputs for future predictions. This means that each new word generated
by GPT is based on the preceding sequence, ensuring coherence and
fluency in the generated text.</em></p>

<p>By “getting better and better at predicting new tokens,” I meant predicting the correct token. But what is a correct token? It seems to be one that ensures coherence and fluency in the text. The predicted token should logically connect to the previous tokens.</p>

<p>Example:</p>
<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Coherent (what GPT produces):
&gt; "The cat sat on the mat. It was warm from the sunlight streaming through the window."
"It" clearly refers to "the mat"
The second sentence logically connects to the first
The narrative flows naturally
Incoherent (if words were generated independently):
&gt; "The cat sat on the mat. Purple ideas sleep furiously yesterday."
No connection between sentences
No logical flow
</code></pre></div></div>

<h2 id="describe-the-relationship-between-the-size-and-complexity-of-gpt-models-and-their-capabilities">Describe the relationship between the size and complexity of GPT models and their capabilities.</h2>

<p>As an LLM model is trained on a large set of input tokens (size is in billions), the LLM consists of a large set of weight parameters that help it to consistently predict the next token. The LLM is more complex as it contains a larger set of weight parameters.</p>

<p><em>GPT models, particularly GPT-3, are significantly larger than the original
transformer model, with a greater number of layers and parameters. This
increased size and complexity contribute to their ability to perform a
wider range of tasks and achieve higher accuracy.</em></p>

<p>The correct answer is that a GPT model is able to perform a wider range of tasks and achieve higher accuracy, versus my answer of consistently predicting the next token, which is wrong. Tasks could include:</p>
<ul>
  <li>predict next token</li>
  <li>classify input</li>
  <li>summarize input</li>
  <li>generate essay</li>
</ul>

<h2 id="match-the-term-on-the-left-to-its-description-on-the-right">Match the term on the left to its description on the right:</h2>

<p><img src="/assets/models.png" alt="GPT Model Structure" />
| Term              | Description |
|——————-|————-|
| pretrained models | row 2       |
| fine-tuning       | row 3       |
| base model        | row 1       |</p>

<p>Correct answer</p>

<h2 id="conclusion">Conclusion</h2>

<p>For me, this part was tough. We are scratching the surface of LLM magic, and things are getting complicated.</p>]]></content><author><name></name></author><category term="LLM" /><category term="llm" /><category term="gpt" /><category term="transformer" /><category term="architecture" /><category term="deep-learning" /><summary type="html"><![CDATA[Exploring GPT architecture through study questions — understanding next-word prediction, self-supervised learning, decoder-only design, autoregressive generation, and how model size impacts capabilities.]]></summary></entry><entry><title type="html">Utilizing Large Datasets</title><link href="https://karlosmid.com/2026/01/utilizing-large-datasets/" rel="alternate" type="text/html" title="Utilizing Large Datasets" /><published>2026-01-04T11:00:00+00:00</published><updated>2026-01-04T11:00:00+00:00</updated><id>https://karlosmid.com/2026/01/utilizing-large-datasets</id><content type="html" xml:base="https://karlosmid.com/2026/01/utilizing-large-datasets/"><![CDATA[<h2 id="tldr">TL;DR</h2>

<p>In the <a href="/2026/01/introduction-to-transformer-architecture/">previous post</a>, I discussed the Transformer architecture. Today, I am continuing with the study questions from <em>Build LLM from Scratch</em> by Sebastian Raschka, and the topic is utilizing large datasets.</p>

<h2 id="what-are-the-key-characteristics-of-the-training-datasets-used-for-large-language-models-like-gpt-3-and-bert">What are the key characteristics of the training datasets used for large language models like GPT-3 and BERT?</h2>

<p>The key characteristics of the training dataset are size and diversity. For GPT-3, the training corpus had 410 billion tokens. For diversity, the dataset consisted of books, Wikipedia, blog posts, and text from sites like Reddit.</p>

<p><em>These datasets are vast, encompassing billions of words and covering a
wide range of topics and languages. They are designed to expose the
models to diverse text, enabling them to learn language syntax,
semantics, and context.</em></p>

<p>I forgot to mention that various languages are used.</p>

<h2 id="explain-the-significance-of-the-size-and-diversity-of-the-training-dataset-for-the-performance-of-large-language-models">Explain the significance of the size and diversity of the training dataset for the performance of large language models.</h2>

<p>As the training dataset becomes larger and more diverse, the LLM can predict the next word more precisely because it builds a better internal model of the language.</p>

<p><em>The scale and diversity of the training data allow these models to perform
well on various tasks, including those requiring general knowledge. The
models learn to understand and generate text that reflects the real-world
complexities of language</em></p>

<p>I forgot to mention that this is what allows the LLM to perform well on various tasks.</p>

<h2 id="what-is-the-concept-of-tokenization-in-the-context-of-large-language-models">What is the concept of ‘tokenization’ in the context of large language models?</h2>

<p>To tokenize input text means to break it into tokens, where a token is, for example, a word or a punctuation mark such as “?”. A token can also be smaller than a word. The attention mechanism then operates on those tokens to determine how strongly they relate to each other.</p>

<p><em>Tokenization is the process of converting text into individual units called
tokens, which are the basic building blocks that the model reads and
processes. These tokens can be words, punctuation marks, or other
meaningful units of text.</em></p>

<p>This I answered correctly.</p>
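
<p>To make tokenization concrete, here is a minimal word-level tokenizer sketch in Elixir. It is my own illustration (the NaiveTokenizer module name is made up), and real LLMs use subword schemes such as byte pair encoding rather than whole words:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code>defmodule NaiveTokenizer do
  # Split text into word and punctuation tokens (a very naive scheme).
  def tokenize(text) do
    Regex.scan(~r/\w+|[^\w\s]/u, text) |&gt; List.flatten()
  end

  # Map each token to an integer id, the form the model actually consumes.
  def to_ids(tokens) do
    vocab = tokens |&gt; Enum.uniq() |&gt; Enum.with_index() |&gt; Map.new()
    Enum.map(tokens, fn token -&gt; Map.fetch!(vocab, token) end)
  end
end

tokens = NaiveTokenizer.tokenize("Is this a token?")
IO.inspect(tokens)                        # ["Is", "this", "a", "token", "?"]
IO.inspect(NaiveTokenizer.to_ids(tokens)) # [0, 1, 2, 3, 4]
</code></pre></div></div>

<p>The attention mechanism then works on those integer ids, or rather on the embeddings derived from them, never on the raw text.</p>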

<h2 id="describe-the-concept-of-pretraining-in-the-context-of-large-language-models-and-its-significance">Describe the concept of ‘pretraining’ in the context of large language models and its significance.</h2>

<p>In the pretraining phase, we train a large language model on a general dataset; the goal is for the LLM to develop a general understanding of language structure, concepts, and patterns.</p>

<p><em>Pretraining involves training a large language model on a massive dataset
to learn general language patterns and knowledge. This pre-trained
model serves as a foundation, making it adaptable for various
downstream tasks through fine-tuning, which involves further training on
specific datasets for specific applications.</em></p>

<p>I forgot to mention that the result of the pretraining phase is a base for the fine-tuning process, where we train the LLM for specific tasks.</p>
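
<p>As a sketch of what “learning general language patterns” means in practice (my own illustration, not code from the book): pretraining is self-supervised next-token prediction, so the training pairs are simply shifted slices of the token stream, with no human labels required.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Self-supervised next-token pairs: the input is the context so far,
# and the label is whatever token actually comes next in the corpus.
tokens = ["the", "cat", "sat", "on", "the", "mat"]

pairs =
  for i &lt;- 1..(length(tokens) - 1) do
    {Enum.take(tokens, i), Enum.at(tokens, i)}
  end

IO.inspect(pairs)
# [{["the"], "cat"}, {["the", "cat"], "sat"}, ...]
</code></pre></div></div>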

<h2 id="explain-the-concept-of-fine-tuning-in-the-context-of-large-language-models-and-its-advantages">Explain the concept of ‘fine-tuning’ in the context of large language models and its advantages.</h2>

<p>Fine-tuning is a phase in which we use a much smaller dataset than in pretraining. The goal is to teach the LLM to solve a specific task, for example, detecting whether an email is spam or not. Fine-tuning is done on a pretrained LLM that already has general knowledge of the language.</p>

<p><em>Fine-tuning involves further training a pre-trained large language model
on a smaller, task-specific dataset. This process adapts the model to
perform well on specific tasks, such as text summarization or question
answering, while leveraging the general knowledge learned during
pretraining.</em></p>

<p>This is a correct answer.</p>
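
<p>A minimal sketch of what a fine-tuning dataset for the spam example might look like (the texts and labels below are made up for illustration): compared to pretraining data, it is tiny, labelled, and tied to a single task.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Hypothetical labelled dataset for classification fine-tuning.
# The pretrained model already "knows" the language; here it only learns the task.
dataset = [
  {"You won a free cruise, click here!", :spam},
  {"Meeting moved to 10am tomorrow", :ham},
  {"Lowest prices on watches!!!", :spam},
  {"Can you review my pull request?", :ham}
]

Enum.each(dataset, fn {text, label} -&gt;
  IO.puts("#{label}: #{text}")
end)
</code></pre></div></div>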

<h2 id="match-the-term-on-the-left-to-its-description-on-the-right">Match the term on the left to its description on the right:</h2>

<p><img src="/assets/large_dataset.png" alt="Large Dataset Matching Exercise" /></p>

<table>
  <thead>
    <tr>
      <th>Term</th>
      <th>Description</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Encoder</td>
      <td>row 3</td>
    </tr>
    <tr>
      <td>Decoder</td>
      <td>row 2</td>
    </tr>
    <tr>
      <td>Self-attention mechanism</td>
      <td>row 4</td>
    </tr>
    <tr>
      <td>Zero-shot learning</td>
      <td>row 1</td>
    </tr>
  </tbody>
</table>

<p>This is correct!</p>

<h2 id="conclusion">Conclusion</h2>

<p>An LLM without a dataset cannot exist. We learned the most important characteristics of LLM datasets.</p>]]></content><author><name></name></author><category term="llm" /><summary type="html"><![CDATA[Exploring the role of large datasets in LLMs — tokenization, pretraining, and fine-tuning. Study questions from Build LLM from Scratch by Sebastian Raschka.]]></summary></entry><entry><title type="html">Introduction to Transformer Architecture</title><link href="https://karlosmid.com/2026/01/introduction-to-transformer-architecture/" rel="alternate" type="text/html" title="Introduction to Transformer Architecture" /><published>2026-01-03T10:00:00+00:00</published><updated>2026-01-03T10:00:00+00:00</updated><id>https://karlosmid.com/2026/01/introduction-to-transformer-architecture</id><content type="html" xml:base="https://karlosmid.com/2026/01/introduction-to-transformer-architecture/"><![CDATA[<h2 id="tldr">TL;DR</h2>

<p>In the <a href="/2026/01/stages-of-building-and-using-llms/">previous post</a>, I discussed the stages of building and using LLMs. Today, I am continuing with the study questions from <em>Build LLM from Scratch</em> by Sebastian Raschka, and the topic is the Transformer architecture.</p>

<h2 id="what-is-the-transformer-architecture-and-what-is-its-significance-in-the-development-of-llms">What is the transformer architecture and what is its significance in the development of LLMs?</h2>

<p>The Transformer architecture consists of encoder and decoder parts. In the encoder, input text is broken into tokens, and each token is transformed into an embedding vector. The embedding vector encodes the token’s relation to other tokens. Token embeddings are input to the decoder, and the decoder predicts the most probable next word. LLMs only use the decoder part, where input text is broken into tokens, tokens are encoded to embeddings, and decoded to predict the next word.</p>

<p><em>The transformer architecture is a deep neural network architecture that
revolutionized natural language processing. It’s the foundation for most
modern LLMs, enabling them to process and understand language
effectively.</em></p>

<p>I forgot to state that this is a deep neural network. My understanding was that LLMs use only the decoder part, and this is true for GPT-style models: the original Transformer had both an encoder and a decoder, while newer GPT-style models keep only the decoder.</p>
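
<p>As a tiny illustration of the “token to embedding vector” step described above (my own sketch with made-up numbers, not code from the book): an embedding layer is essentially a lookup table from token ids to learned vectors.</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Hypothetical 3-dimensional embedding table with invented values; real models
# learn these vectors during training and use hundreds or thousands of dimensions.
embedding_table = %{
  0 =&gt; [0.1, -0.4, 0.7],   # "the"
  1 =&gt; [0.9, 0.2, -0.3],   # "cat"
  2 =&gt; [-0.5, 0.8, 0.1]    # "sat"
}

token_ids = [0, 1, 2]
embeddings = Enum.map(token_ids, fn id -&gt; embedding_table[id] end)
IO.inspect(embeddings)
</code></pre></div></div>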

<h2 id="which-label-in-the-diagram-does-the-output-of-the-embeddings-stage-go-to">Which label in the diagram does the output of the Embeddings stage go to?</h2>

<p><img src="/assets/transformer.png" alt="Transformer Architecture" /></p>

<p>Label 2. This picture shows the original Transformer with encoder/decoder blocks, where the embeddings are input to the decoder blocks. The original Transformer was designed for translation, and this picture shows translation from English to German.</p>

<h2 id="describe-the-two-main-components-of-the-transformer-architecture-and-their-roles-in-language-processing">Describe the two main components of the transformer architecture and their roles in language processing.</h2>

<p>The two main components of the original Transformer are the Encoder and the Decoder. The Encoder breaks text into tokens and then embeds each token into a vector of numerical values. The Decoder uses these embeddings as input and, based on them, predicts the next word.</p>

<p><em>The transformer architecture consists of an encoder and a decoder. The
encoder processes the input text and converts it into numerical
representations, while the decoder uses these representations to
generate the output text.</em></p>

<p>On the spot! It seems that I correctly grasped the Transformer concept!</p>

<h2 id="what-is-the-self-attention-mechanism-and-how-does-it-contribute-to-the-transformers-effectiveness">What is the self-attention mechanism and how does it contribute to the transformer’s effectiveness?</h2>

<p>The self-attention mechanism is encoded in embeddings and helps the decoder pay attention to selected parts of the input text to generate the next token (word).</p>

<p><em>The self-attention mechanism allows the transformer to weigh the
importance of different words in a sequence relative to each other. This
helps the model capture long-range dependencies and contextual
relationships, leading to more coherent and relevant output.</em></p>

<p>I missed the key concept: words pay attention to each other, and the result is an “understanding” of conceptual word relationships.</p>
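
<p>To see what “words paying attention to each other” means numerically, here is a toy dot-product attention sketch with invented vectors. It is a simplified illustration only; the real mechanism uses learned query, key, and value projections:</p>

<div class="language-elixir highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Toy dot-product attention: one query vector attends over three key vectors.
# All numbers are invented; in a real model they come from learned projections.
dot = fn a, b -&gt;
  Enum.zip(a, b) |&gt; Enum.map(fn {x, y} -&gt; x * y end) |&gt; Enum.sum()
end

softmax = fn scores -&gt;
  exps = Enum.map(scores, fn s -&gt; :math.exp(s) end)
  total = Enum.sum(exps)
  Enum.map(exps, fn e -&gt; e / total end)
end

query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]

scores = Enum.map(keys, fn key -&gt; dot.(query, key) end)
weights = softmax.(scores)
IO.inspect(weights)
# The first key gets the largest weight because it is most similar to the query.
</code></pre></div></div>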

<h2 id="explain-the-key-differences-between-bert-and-gpt-models-in-terms-of-their-training-approaches-and-primary-applications">Explain the key differences between BERT and GPT models in terms of their training approaches and primary applications.</h2>

<p>BERT is a text classifier, and GPT is a chatbot.</p>

<p><em>BERT focuses on masked word prediction and excels in tasks like text classification, while GPT is designed for generative tasks like text completion, translation, and summarization.</em></p>

<p>I learned that the three main capabilities of GPT-style chatbots are text completion, translation, and summarization.</p>
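
<p>A small illustration of the difference in training objectives, with my own example sentence:</p>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>BERT-style masked word prediction:
Input:  "The cat sat on the [MASK]."
Target: "mat"

GPT-style next-word prediction (autoregressive):
Input:  "The cat sat on the"
Target: "mat"
</code></pre></div></div>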

<h2 id="what-are-zero-shot-and-few-shot-learning-and-how-do-they-relate-to-gpt-models">What are zero-shot and few-shot learning, and how do they relate to GPT models?</h2>

<p>GPT completes sentences. It can also perform tasks. In zero-shot, we provide no examples and let the model figure it out. In few-shot, we provide several examples to guide the output.</p>

<p><em>Zero-shot learning allows GPT models to perform tasks without prior
training on specific examples, while few-shot learning enables them to
learn from a minimal number of examples. These capabilities
demonstrate GPT’s versatility and adaptability</em></p>

<p>I initially got zero-shot wrong; as the name says, we provide zero examples.</p>

<p>Zero-shot example:</p>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Classify the sentiment of this review as positive or negative:

"The food was amazing and the service was excellent!"

Sentiment:

Answer: positive
</code></pre></div></div>

<p>Few-shot example:</p>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Classify the sentiment of reviews:

Review: "Terrible experience, never coming back"
Sentiment: negative

Review: "Best pizza I've ever had!"
Sentiment: positive

Review: "The movie was boring and too long"
Sentiment: negative

Review: "I love this product, works perfectly!"
Sentiment:

Answer: positive
</code></pre></div></div>

<h2 id="conclusion">Conclusion</h2>

<p>The Transformer architecture is the key concept of LLMs, and it is not a simple one. Answering these questions gave us an easy introduction to it.</p>]]></content><author><name></name></author><category term="llm" /><summary type="html"><![CDATA[Diving into the Transformer architecture — encoder vs decoder, self-attention, BERT vs GPT, and zero-shot/few-shot learning. Study questions from Build LLM from Scratch.]]></summary></entry></feed>