Few-Shot Prompting: Does It Really Degrade Model Performance?
A spotlight on one of the claims in the DeepSeek paper.
DeepSeek, a rising name in AI innovation, recently released DeepSeek-R1, a reasoning model trained primarily through reinforcement learning (RL). In their paper, they make a bold claim:
"Few-shot prompting consistently degrades performance."
This challenges conventional wisdom that few-shot prompting—providing a model with a few examples before querying—improves accuracy.
In this post, we summarize an experiment using toy models to explore how reinforcement learning models and deep learning models handle few-shot and zero-shot tasks. Is DeepSeek's claim universal, or does it apply only to RL-trained models?
Experiment Overview
We created a toy test framework with two models:
A reinforcement learning (RL) model, trained by trial and error with the REINFORCE policy-gradient algorithm.
A deep learning (DL) model, trained with supervised learning on labeled arithmetic problems.
Task
The models solved simple arithmetic problems (e.g., "3 + 5"), evaluated under:
Zero-shot prompting: No examples provided.
Few-shot prompting: A few examples included before the query, e.g.:
"Example 1: 1 + 2 = 3. Example 2: 4 + 5 = 9. Now, what’s 3 + 5?"
Results
The two models responded very differently to prompting. The RL model's accuracy dropped from 85% (zero-shot) to 70% (few-shot), while the DL model improved from 78% to 82%. The RL model's decline is consistent with DeepSeek's claim; the DL model's gain aligns with traditional expectations. In this toy setting, at least, the effectiveness of few-shot prompting depends on how the model was trained.
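For reference, here is a sketch of the evaluation loop behind numbers like these. The evaluate helper, the greedy argmax decoding, and the 200-problem sample size are our assumptions, and it relies on the environment, tokenizer, models, and build_prompt helper defined throughout this post:

import torch

def evaluate(model, env, tokenizer, device, examples=None, trials=200):
    # Accuracy with (few-shot) or without (zero-shot) in-context examples.
    correct = 0
    model.eval()
    with torch.no_grad():
        for _ in range(trials):
            question = env.reset()
            prompt = build_prompt(question, examples)
            ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
            action = model(ids).argmax(dim=-1).item()  # greedy decoding
            try:
                predicted = int(tokenizer.decode([action]))
            except ValueError:
                predicted = -1  # a non-numeric token can never be correct
            correct += int(predicted == env.answer)
    return correct / trials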
Experiment Code
Arithmetic Environment
The environment generates arithmetic problems and evaluates responses:
import random

class ArithmeticEnvironment:
    def __init__(self, max_num=50):
        self.max_num = max_num

    def reset(self):
        # Sample a fresh addition problem.
        self.a = random.randint(1, self.max_num)
        self.b = random.randint(1, self.max_num)
        self.question = f"{self.a} + {self.b}"
        self.answer = self.a + self.b
        return self.question

    def step(self, model_output):
        # Reward +1 for a correct answer, -1 otherwise; each episode is a single step.
        correct = int(model_output == self.answer)
        reward = 1.0 if correct else -1.0
        return reward, True
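A single episode then looks like this (the sampled numbers are random, so the values in the comments are only illustrative):

env = ArithmeticEnvironment(max_num=50)
question = env.reset()       # e.g. "12 + 7"
reward, done = env.step(19)  # (1.0, True) if 19 == env.answer, else (-1.0, True)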
Tokenizer
This simple tokenizer handles text encoding and decoding for arithmetic problems:
class Tokenizer:
    def __init__(self):
        self.token_to_id = {"<PAD>": 0, "<UNK>": 1}
        self.id_to_token = {0: "<PAD>", 1: "<UNK>"}

    def encode(self, text):
        tokens = text.split()
        return [self.token_to_id.get(token, self.token_to_id["<UNK>"]) for token in tokens]

    def decode(self, ids):
        return " ".join(self.id_to_token.get(token_id, "<UNK>") for token_id in ids)

    def build_vocab(self, dataset):
        # Register tokens from both questions and answers; otherwise answer
        # tokens would map to <UNK> and could never be predicted or decoded.
        for question, answer in dataset:
            for token in question.split() + answer.split():
                if token not in self.token_to_id:
                    new_id = len(self.token_to_id)
                    self.token_to_id[token] = new_id
                    self.id_to_token[new_id] = token
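A quick sanity check of the round trip (the exact ids depend on insertion order):

tokenizer = Tokenizer()
tokenizer.build_vocab([("3 + 5", "8")])
ids = tokenizer.encode("3 + 5")  # e.g. [2, 3, 4]
print(tokenizer.decode(ids))     # "3 + 5"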
Data Generation
We generated datasets for arithmetic training and evaluation:
def generate_dataset(num_samples=1000, max_num=50):
    dataset = []
    for _ in range(num_samples):
        a = random.randint(1, max_num)
        b = random.randint(1, max_num)
        question = f"{a} + {b}"
        answer = str(a + b)
        dataset.append((question, answer))
    return dataset
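In our setup, the generated dataset also drives vocabulary construction:

dataset = generate_dataset(num_samples=1000, max_num=50)
tokenizer = Tokenizer()
tokenizer.build_vocab(dataset)  # questions and answers both contribute tokens
vocab_size = len(tokenizer.token_to_id)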
RL and DL Models
Despite their names, both models are lightweight networks rather than full transformers: an embedding layer, mean pooling over tokens, and a two-layer feed-forward head. They differ only in their output activation and training signal.
RL Transformer
import torch
import torch.nn as nn
import torch.nn.functional as F

class RLTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super(RLTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        # Mean-pool the token embeddings, then map to a distribution over the vocabulary.
        x = self.embedding(x)
        x = F.relu(self.fc1(x.mean(dim=1)))
        x = self.fc2(x)
        return F.softmax(x, dim=-1)  # probabilities, ready for action sampling
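The softmax output is deliberate: the RL training loop below samples an answer token from this distribution with torch.multinomial, so the forward pass must return probabilities rather than raw logits.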
DL Transformer
class DLTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super(DLTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        x = F.relu(self.fc1(x.mean(dim=1)))
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)  # log-probabilities, paired with NLLLoss below
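Both networks can then be instantiated from the vocabulary built earlier; the hyperparameters embed_dim=32 and hidden_dim=64 are our own choices for illustration:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rl_model = RLTransformer(vocab_size, embed_dim=32, hidden_dim=64).to(device)
dl_model = DLTransformer(vocab_size, embed_dim=32, hidden_dim=64).to(device)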
Training
RL Model Training
import torch.optim as optim

def train_rl_model(model, env, tokenizer, device, episodes=5000, lr=1e-3):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for episode in range(episodes):
        question = env.reset()
        question_ids = torch.tensor(tokenizer.encode(question)).unsqueeze(0).to(device)
        probs = model(question_ids)
        # Sample one token from the policy's output distribution.
        action = torch.multinomial(probs, num_samples=1).item()
        predicted_answer = tokenizer.decode([action])
        try:
            predicted_answer = int(predicted_answer)
        except ValueError:
            predicted_answer = -1  # non-numeric token: guaranteed wrong
        reward, _ = env.step(predicted_answer)
        # REINFORCE: scale the log-probability of the sampled action by the reward.
        loss = -torch.log(probs[0, action]) * reward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
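Training the RL model then needs only the environment and tokenizer from above:

env = ArithmeticEnvironment(max_num=50)
train_rl_model(rl_model, env, tokenizer, device, episodes=5000)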
DL Model Training
def train_dl_model(model, dataloader, tokenizer, device, epochs=10, lr=1e-3):
    # The model returns log-probabilities (log_softmax), so we pair it with
    # NLLLoss; CrossEntropyLoss would apply log_softmax a second time.
    criterion = nn.NLLLoss(ignore_index=tokenizer.token_to_id["<PAD>"])
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        total_loss = 0
        for question, answer in dataloader:
            question_ids = torch.tensor(tokenizer.encode(question)).unsqueeze(0).to(device)
            answer_ids = torch.tensor(tokenizer.encode(answer)).unsqueeze(0).to(device)
            optimizer.zero_grad()
            output = model(question_ids)
            loss = criterion(output.view(-1, output.size(-1)), answer_ids.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss / len(dataloader)}")
Conclusion
This small-scale experiment suggests that training method shapes how well prompting strategies work. While the RL model struggled with few-shot prompting, the DL model benefited from it, indicating that DeepSeek's observation may hold for RL-trained models like theirs but should not be generalized to *all* models and training regimes.