# smol-llama 🦙
A 360M parameter LLaMA-style language model trained from scratch on 6B tokens. Built with GQA, RoPE, and SwiGLU on a single H100 GPU in 22 hours for $53.
## Overview
smol-llama is a minimal, from-scratch implementation of a LLaMA-style language model for pre-training on custom data. The project demonstrates that a capable small language model can be trained on a modest budget: the entire 360M parameter model was trained on 6B tokens using a single NVIDIA H100 GPU in approximately 22 hours, at a total cost of around $53.
## Model Architecture

### Key Features
- Grouped Query Attention (GQA): Efficient inference with 15 query heads and 5 key-value heads
- RoPE: Rotary Position Embeddings for better position encoding
- RMSNorm: Root Mean Square normalization instead of LayerNorm
- SwiGLU: Gated Linear Unit activation in the feed-forward network
- Flash Attention 2: Fast and memory-efficient attention with SDPA fallback
- Gradient Checkpointing: Memory-efficient training for larger models
- torch.compile: Optimized training speed with PyTorch 2.0 compilation
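The GQA layout above (15 query heads sharing 5 key-value heads) means each KV head serves a group of 3 query heads. A minimal sketch of that sharing, using PyTorch's built-in SDPA rather than the repo's `utils/model.py` code (the head dimension and sequence length below are illustrative, not the model's actual values):

```python
import torch

# Illustrative GQA sketch: 15 query heads, 5 KV heads -> groups of 3.
# head_dim and seq are made-up values for demonstration only.
n_q_heads, n_kv_heads, head_dim, seq = 15, 5, 64, 8
group_size = n_q_heads // n_kv_heads  # 3 query heads per KV head

q = torch.randn(1, n_q_heads, seq, head_dim)
k = torch.randn(1, n_kv_heads, seq, head_dim)
v = torch.randn(1, n_kv_heads, seq, head_dim)

# Expand each KV head across its query group, then run standard attention.
k = k.repeat_interleave(group_size, dim=1)  # (1, 15, seq, head_dim)
v = v.repeat_interleave(group_size, dim=1)

out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 15, 8, 64])
```

The KV cache only stores the 5 key-value heads, which is where GQA's inference memory savings come from.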
## Training Details

### Dataset
The model is trained on fineweb-6b, a curated 6B token dataset pre-tokenized with a custom 49K BPE vocabulary. The dataset includes 11.3 GB of training tokens and 57 MB of validation tokens, all pre-processed for immediate use.
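Pre-tokenized binaries like these are typically read with a memory map, so no tokenization happens at training time. A hedged sketch of that pattern (the actual file layout and dtype of fineweb-6b are assumptions here; `utils/data.py` is the authoritative implementation):

```python
import numpy as np

# Assumption: tokens are stored as a flat binary of uint16 ids,
# as in common nanoGPT-style pipelines. Build a tiny fake file to demo.
np.arange(1024, dtype=np.uint16).tofile("train.bin")

# Memory-map the file and sample input/target batches without loading it all.
data = np.memmap("train.bin", dtype=np.uint16, mode="r")
block_size = 256
ix = np.random.randint(0, len(data) - block_size - 1, size=4)
x = np.stack([data[i : i + block_size].astype(np.int64) for i in ix])
y = np.stack([data[i + 1 : i + 1 + block_size].astype(np.int64) for i in ix])
print(x.shape, y.shape)  # (4, 256) (4, 256)
```

Targets are simply the inputs shifted by one token, which is all a causal language-modeling objective needs.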
## Quick Start
```bash
# Install dependencies
uv sync

# Run training
uv run ./pretrain.py
```

The training script automatically downloads the pre-tokenized dataset, initializes the model, trains with gradient accumulation and mixed precision, and saves checkpoints every 200 steps.
## Using the Pre-trained Model
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "ifkash/smol-llama",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("ifkash/smol-llama")

# Generate text
prompt = "The future of artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

## Training Configuration
## Project Structure
- `pretrain.py`: Main training script with gradient accumulation and checkpointing
- `utils/model.py`: Complete LLaMA architecture implementation
- `utils/rotary.py`: Rotary position embeddings (RoPE)
- `utils/data.py`: Efficient data loading from pre-tokenized binaries
- `utils/checkpoint.py`: Checkpoint saving/loading and HuggingFace uploads
- `utils/lr_schedule.py`: Cosine learning rate schedule with warmup
- `utils/logging.py`: Weights & Biases integration for experiment tracking
- `notebooks/1-train-tokenizer.ipynb`: Custom BPE tokenizer training
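The cosine-with-warmup schedule handled by `utils/lr_schedule.py` can be sketched in a few lines. The learning rates and step counts below are placeholders, not the repo's actual hyperparameters:

```python
import math

# Hedged sketch of linear warmup followed by cosine decay to a floor.
# max_lr, min_lr, warmup, and max_steps are illustrative assumptions.
def get_lr(step, max_lr=3e-4, min_lr=3e-5, warmup=100, max_steps=1000):
    if step < warmup:
        # Linear warmup from ~0 up to max_lr.
        return max_lr * (step + 1) / warmup
    if step >= max_steps:
        # Hold the floor after decay finishes.
        return min_lr
    # Cosine decay from max_lr to min_lr over the remaining steps.
    t = (step - warmup) / (max_steps - warmup)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t))
```

Warmup avoids unstable early updates while the optimizer state is cold; the cosine tail lets the model settle into a minimum at a small but nonzero learning rate.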