Training with AION tokenisation
This guide describes how to train AstroPT models using learned AION tokenisation instead of patch-based tokenisation.
Overview
AstroPT supports two main tokenisation approaches for image data:
- Patch-based tokenisation (train.py) - Splits images into fixed spatial patches
- AION tokenisation (train_aion.py) - Uses a learned VQ-VAE tokenizer from the AION project
AION tokenisation offers a more compressed representation by learning a discrete codebook of visual tokens, similar to how DALL-E or VQGAN work. Instead of treating each image patch as a continuous vector, AION maps image regions to discrete token IDs from a learned vocabulary.
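The idea of mapping continuous features to IDs from a learned vocabulary can be sketched with a generic nearest-neighbour codebook lookup. This is a toy illustration of vector quantisation in general, not AION's actual codec: the `quantise` function, the 2-D codebook, and the latent values are all hypothetical.

```python
import math


def quantise(vectors, codebook):
    """Map each continuous vector to the index of its nearest codebook entry."""
    token_ids = []
    for vec in vectors:
        # Euclidean distance from this vector to every codebook entry
        distances = [math.dist(vec, code) for code in codebook]
        token_ids.append(distances.index(min(distances)))
    return token_ids


# Hypothetical three-entry codebook; a real codebook is learned during training.
codebook = [(0.0, 0.0), (1.0, 1.0), (-1.0, 2.0)]
# Hypothetical continuous latents, one per image region.
latents = [(0.1, -0.2), (0.9, 1.2), (-0.8, 1.7)]

token_ids = quantise(latents, codebook)
print(token_ids)
```

Each image region is then represented by a single integer rather than a vector of floats, which is what lets the downstream model use an embedding table.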
Key differences from patch-based tokenisation
- Tokenisation location:
Patch-based: Happens in the dataset class, produces continuous vectors
AION: Happens in the dataloader’s collate function, produces discrete token IDs
- Model architecture:
Patch-based: Uses Encoder module to project continuous patches to embedding space
AION: Uses Embedder module (embedding table) to look up discrete tokens
- Modality configuration:
Patch-based: vocab_size=0 indicates continuous input
AION: vocab_size=10000 indicates discrete tokens from a learned codebook
- Memory footprint:
Patch-based: Full image data flows through the model
AION: Compressed token IDs flow through the model
Installation
To use AION tokenisation, you’ll need to install the AION library:
uv sync --all-extras
Command-line usage
Single GPU
uv run scripts/train_aion.py
Modality configuration for AION
The key difference in configuration is the vocab_size parameter:
modalities = [
    ModalityConfig(
        name="images_aion",
        input_size=10000,  # Must == vocab_size
        patch_size=0,  # Placeholder, not used for discrete tokens
        loss_weight=1.0,
        embed_pos=True,
        pos_input_size=1,
        vocab_size=10000,  # Size of AION's learned codebook
    ),
]
When vocab_size > 0, AstroPT knows to use an embedding table instead of a continuous encoder.
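The switch on vocab_size can be pictured with the minimal sketch below. It is not AstroPT's implementation: `ToyModalityConfig` and `input_layer_kind` are hypothetical stand-ins that only mirror the rule described above, including the requirement that input_size equals vocab_size for discrete modalities.

```python
from dataclasses import dataclass


@dataclass
class ToyModalityConfig:
    """Hypothetical stand-in for ModalityConfig with only the relevant fields."""
    name: str
    input_size: int
    vocab_size: int = 0


def input_layer_kind(cfg):
    """Return which kind of input layer a modality would need."""
    if cfg.vocab_size > 0:
        # Discrete tokens: look IDs up in an embedding table
        if cfg.input_size != cfg.vocab_size:
            raise ValueError("input_size must equal vocab_size for discrete tokens")
        return "embedder"
    # Continuous input: project patches into the embedding space
    return "encoder"


aion_cfg = ToyModalityConfig(name="images_aion", input_size=10000, vocab_size=10000)
patch_cfg = ToyModalityConfig(name="images", input_size=768)  # 768 is an arbitrary example size

print(input_layer_kind(aion_cfg), input_layer_kind(patch_cfg))
```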
Dataset and dataloader setup
AION tokenisation requires a different dataset and dataloader setup from patch-based tokenisation, because the AION tokeniser has only been trained on raw astronomical imagery. This guide uses Legacy Survey imagery.
Dataset: Load raw image data from Hugging Face
from datasets import load_dataset

tds = load_dataset(
    "Smith42/legacysurvey_hsc_crossmatched",
    split="train",
    streaming=True,
).select_columns("legacysurvey_image").rename_column("legacysurvey_image", "image")
Codec initialisation: Load the pre-trained AION tokenizer
from aion.codecs.image import ImageCodec
from aion.modalities import LegacySurveyImage

image_codec = ImageCodec.from_pretrained(
    "polymathic-ai/aion-base", modality=LegacySurveyImage
)
image_codec = image_codec.eval()
Collate function: Tokenise in batches during loading
from functools import partial

import torch
from torch.utils.data import DataLoader


def collate_and_tokenise(batch, image_codec):
    """Tokenise a whole batch with one codec call (runs in the main process only when num_workers=0)."""
    # Stack all fluxes into a single batch tensor
    flux_batch = torch.stack([
        torch.tensor(item["image"]["flux"]) for item in batch
    ])
    # Create a single batched LegacySurveyImage
    batched_img = LegacySurveyImage(
        flux=flux_batch,
        bands=["DES-G", "DES-R", "DES-I", "DES-Z"],
    )
    # Single encode call for the entire batch
    tokens = image_codec.encode(batched_img)
    batch_size = len(batch)
    return {
        "images_aion": tokens.long(),
        # One row of positions 0..n_tokens-1 per item in the batch
        "images_aion_positions": torch.stack([
            torch.arange(tokens.shape[1]) for _ in range(batch_size)
        ]),
    }


collate_fn = partial(collate_and_tokenise, image_codec=image_codec)
# batch_size and num_workers are set elsewhere in the training configuration
train_loader = DataLoader(
    tds,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=num_workers > 0,
    collate_fn=collate_fn,
)
Loss function differences
AION tokens are discrete, so the loss function differs from patch-based training:
Patch-based: Uses Huber loss (robust L1/L2 hybrid) for continuous predictions
AION: Uses cross-entropy loss for discrete token classification
# In the model's forward pass
if "aion" in mod_name:
    loss += F.cross_entropy(
        pred.reshape(-1, pred.size(-1)),
        target.reshape(-1),
    ) * mod_config.loss_weight
else:
    loss += F.huber_loss(pred, target) * mod_config.loss_weight
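To make the two losses concrete, here is a minimal pure-Python sketch of what each computes for a single prediction. The function names are illustrative, not part of AstroPT; the Huber form assumes the default delta of 1.0 used by F.huber_loss.

```python
import math


def cross_entropy(logits, target):
    """Negative log-probability of the target class under softmax(logits)."""
    peak = max(logits)  # subtract the maximum for numerical stability
    log_norm = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_norm - logits[target]


def huber(pred, target, delta=1.0):
    """Quadratic for small errors, linear for large ones."""
    err = abs(pred - target)
    if err <= delta:
        return 0.5 * err ** 2
    return delta * (err - 0.5 * delta)


# Uniform logits over a 10000-token vocabulary give a loss of ln(10000)
uniform_loss = cross_entropy([0.0] * 10000, target=3)
small_error = huber(0.5, 0.0)  # quadratic regime
large_error = huber(3.0, 0.0)  # linear regime
```

A practical consequence is that the two losses live on different scales: an untrained model guessing uniformly over 10000 tokens starts near ln(10000), about 9.21, so loss values from AION runs are not directly comparable with Huber losses from patch-based runs.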
Validation and visualisation
During validation, AION tokens must be decoded back to images for visualisation:
@torch.no_grad()
def validate(iter_num, out_dir):
    # model, gid, vdl, modality_registry, device, ctx and axs come from the surrounding training script
    model.eval()
    # Get predictions
    B = gid.process_modes(next(vdl), modality_registry, device)
    with ctx:
        P, loss = model(B["X"], B["Y"])
    # Decode target tokens to images
    # (a placeholder token 0 is prepended; integer dtype keeps the IDs integral)
    Yim = torch.cat((
        torch.zeros(B["Y"]["images_aion"].shape[0], 1, dtype=torch.long),
        B["Y"]["images_aion"].cpu(),
    ), dim=1)
    Yim = image_codec.decode(
        Yim,
        bands=["DES-G", "DES-R", "DES-Z"],
    )
    # Decode predicted tokens to images (greedy argmax over the vocabulary)
    Pim = torch.cat((
        torch.zeros(B["Y"]["images_aion"].shape[0], 1, dtype=torch.long),
        torch.argmax(P["images_aion"], dim=-1).cpu(),
    ), dim=1)
    Pim = image_codec.decode(
        Pim,
        bands=["DES-G", "DES-R", "DES-Z"],
    )
    # Visualise: clip at the 99th percentile, then rescale to [0, 1]
    clip_and_norm = lambda x: (torch.clamp(x, x.min(), x.quantile(0.99)) - x.min()) / (x.quantile(0.99) - x.min())
    for ax, p, y in zip(axs, Pim.flux, Yim.flux):
        ax[0].imshow(clip_and_norm(y.swapaxes(0, -1)))
        ax[1].imshow(clip_and_norm(p.swapaxes(0, -1)))
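The step that turns model outputs into decodable token sequences can be sketched without any tensors. This assumes, as in the excerpt above, that predictions are per-position scores over the vocabulary and that a placeholder token 0 is prepended before decoding; `greedy_tokens` and `with_placeholder` are hypothetical helper names.

```python
def greedy_tokens(logits):
    """Pick the highest-scoring token ID at each sequence position."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def with_placeholder(tokens, placeholder=0):
    """Prepend a placeholder token to the front of the sequence."""
    return [placeholder] + list(tokens)


# Two sequence positions, toy vocabulary of three tokens (hypothetical scores)
logits = [
    [0.1, 2.0, -1.0],
    [3.0, 0.0, 0.5],
]
predicted = greedy_tokens(logits)
sequence = with_placeholder(predicted)
print(sequence)
```

Greedy argmax is the simplest choice for visualisation; sampling from the predicted distribution would give different, more varied reconstructions.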
References
AION project: https://github.com/PolymathicAI/aion
AION paper: https://arxiv.org/abs/2510.17960
VQ-VAE paper: https://arxiv.org/abs/1711.00937