AF - BatchTopK: A Simple Improvement for TopK-SAEs by Bart Bussmann

Welcome to The Nonlinear Library, where we use Text-to-Speech software to convert the best writing from the Rationalist and EA communities into audio. This is: BatchTopK: A Simple Improvement for TopK-SAEs, published by Bart Bussmann on July 20, 2024 on The AI Alignment Forum.

Work done in Neel Nanda's stream of MATS 6.0.

Epistemic status: Tried this on a single sweep and it seems to work well, but it might well be a fluke of something particular to our implementation or experimental set-up. Since there are also theoretical reasons to expect this technique to work (adaptive sparsity), it seems likely that trying BatchTopK is worthwhile for many TopK SAE set-ups. As we're not planning to investigate this much further and it might be useful to others, we're just sharing what we've found so far.

TL;DR: Instead of taking the TopK feature activations per token during training, taking the Top(K*batch_size) activations for every batch seems to improve SAE performance. During inference, this activation function can be replaced with a single global threshold for all features.

Introduction

Sparse autoencoders (SAEs) have emerged as a promising tool for interpreting the internal representations of large language models. By learning to reconstruct activations using only a small number of features, SAEs can extract monosemantic concepts from the representations inside transformer models. Recently, OpenAI published a paper exploring the use of TopK activation functions in SAEs. This approach directly enforces sparsity by keeping only the K largest activations per sample. While effective, TopK forces every token to use exactly K features, which is likely suboptimal. We came up with a simple modification that solves this and seems to improve performance.

BatchTopK

Standard TopK SAEs apply the TopK operation independently to each sample in a batch. For a target sparsity of K, this means exactly K features are activated for every sample. BatchTopK instead applies the TopK operation across the entire flattened batch:

1. Flatten all feature activations across the batch
2. Take the top (K * batch_size) activations
3. Reshape back to the original batch shape

This allows more flexibility in how many features activate per sample, while still maintaining an average of K active features across the batch. (A short code sketch of this operation is included at the end of this post.)

Experimental Set-Up

For both the TopK and the BatchTopK SAEs we train a sweep with the following hyperparameters:

Model: gpt2-small
Site: layer 8 resid_pre
Batch size: 4096
Optimizer: Adam (lr=3e-4, beta1=0.9, beta2=0.99)
Number of tokens: 1e9
Expansion factor: [4, 8, 16, 32]
Target L0 (K): [16, 32, 64]

As in the OpenAI paper, the input is normalized before being fed into the SAE and calculating the reconstruction loss. We also use the same auxiliary loss for dead features (features that haven't activated for 5 batches): it computes the loss on the residual using the top 512 dead features per sample and is scaled by a factor of 1/32.

Results

For a fixed number of active features (L0=32), the BatchTopK SAE has a lower normalized MSE than the TopK SAE and less downstream loss degradation across different dictionary sizes. Similarly, for a fixed dictionary size (12288), BatchTopK outperforms TopK for different values of K. Our main hypothesis is that the improved performance comes from adaptive sparsity: some samples contain more highly activating features than others. Let's have a look at the distribution of the number of active features per sample for the BatchTopK model.
The BatchTopK model indeed takes advantage of the ability to use different sparsities for different inputs. We suspect that the odd peak on the left side corresponds to feature activations on BOS tokens, given that its frequency is very close to 1 in 128, which is the sequence length. This is a good example of why BatchTopK might outperform TopK: at the BOS token, a sequence contains very little information yet, but a TopK SAE still activates 32 features. The BatchTopK model "saves" th...
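For concreteness, here is a minimal PyTorch sketch of the operations described above. This is not the authors' implementation: the names (batch_topk, threshold_activation, aux_dead_feature_loss, theta, dead_mask, decoder) are illustrative, dead_mask is assumed to be a 0/1 tensor marking features that haven't activated for 5 batches, and decoder stands in for the SAE's decoder.

```python
import torch


def batch_topk(acts: torch.Tensor, k: int) -> torch.Tensor:
    """Training-time BatchTopK: flatten the batch of feature activations,
    keep the largest k * batch_size values, zero the rest, and reshape back.
    Each sample ends up with k active features on average, not exactly k."""
    batch_size, dict_size = acts.shape
    flat = acts.flatten()
    vals, idx = torch.topk(flat, k * batch_size)
    out = torch.zeros_like(flat)
    out[idx] = vals
    return out.view(batch_size, dict_size)


def threshold_activation(acts: torch.Tensor, theta: float) -> torch.Tensor:
    """Inference-time replacement: a single global threshold for all features,
    so a sample's sparsity no longer depends on the other samples in the batch."""
    return torch.where(acts > theta, acts, torch.zeros_like(acts))


def aux_dead_feature_loss(residual: torch.Tensor, acts: torch.Tensor,
                          dead_mask: torch.Tensor, decoder: torch.nn.Module,
                          k_aux: int = 512, scale: float = 1 / 32) -> torch.Tensor:
    """Rough sketch of the auxiliary loss described in the set-up: reconstruct the
    residual (input minus SAE reconstruction) from the top 512 dead-feature
    activations per sample, and scale the loss by a factor of 1/32."""
    dead_acts = acts * dead_mask                       # keep only dead features
    vals, idx = torch.topk(dead_acts, k_aux, dim=-1)   # top k_aux dead features per sample
    sparse = torch.zeros_like(dead_acts).scatter_(-1, idx, vals)
    recon = decoder(sparse)
    return scale * (recon - residual).pow(2).mean()
```

One practical detail the excerpt leaves open is how the inference-time threshold theta is chosen; a natural choice would be some statistic of the smallest activations kept by BatchTopK during training, but that is an assumption rather than something stated in the post.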