Training with AION tokenisation

This guide describes how to train AstroPT models using learned AION tokenisation instead of patch-based tokenisation.

Overview

AstroPT supports two main tokenisation approaches for image data:

  1. Patch-based tokenisation (train.py) - Splits images into fixed spatial patches

  2. AION tokenisation (train_aion.py) - Uses a learned VQ-VAE tokeniser from the AION project

AION tokenisation offers a more compressed representation by learning a discrete codebook of visual tokens, similar to how DALL-E or VQGAN work. Instead of treating each image patch as a continuous vector, AION maps image regions to discrete token IDs from a learned vocabulary.
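
To make the idea of a discrete codebook concrete, here is a minimal sketch of generic nearest-neighbour vector quantisation. It is an illustration of the concept only, not AION's actual quantiser; the function name quantise, the toy codebook, and the feature values are all hypothetical.

```python
import torch


def quantise(features: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Return the index of the nearest codebook vector for each feature vector."""
    # features: (N, D), codebook: (K, D) -> pairwise Euclidean distances: (N, K)
    distances = torch.cdist(features, codebook)
    return distances.argmin(dim=1)


# Toy codebook with K=4 entries of dimension D=2 (hypothetical values).
codebook = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
# Three feature vectors, standing in for encoded image regions.
features = torch.tensor([[0.1, -0.1], [0.9, 0.2], [0.8, 0.9]])

token_ids = quantise(features, codebook)  # tensor([0, 1, 3])
reconstruction = codebook[token_ids]      # vectors a decoder would start from
```

Each region is now represented by one integer, and the codebook lookup recovers an approximation of the original feature vector.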

Key differences from patch-based tokenisation

Tokenisation location:
  • Patch-based: Happens in the dataset class, produces continuous vectors

  • AION: Happens in the dataloader’s collate function, produces discrete token IDs

Model architecture:
  • Patch-based: Uses Encoder module to project continuous patches to embedding space

  • AION: Uses Embedder module (embedding table) to look up discrete tokens

Modality configuration:
  • Patch-based: vocab_size=0 indicates continuous input

  • AION: vocab_size=10000 indicates discrete tokens from a learned codebook

Memory footprint:
  • Patch-based: Every patch enters the model as a continuous vector of pixel values

  • AION: Each image is reduced to a sequence of integer token IDs before it enters the model
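
The model architecture difference listed above can be sketched with plain PyTorch layers. Here nn.Linear and nn.Embedding are stand-ins for AstroPT's Encoder and Embedder modules, whose actual implementations may differ, and the sizes n_embd and patch_dim are hypothetical.

```python
import torch
import torch.nn as nn

n_embd = 32         # hypothetical transformer width
patch_dim = 48      # hypothetical flattened patch length
vocab_size = 10000  # codebook size used elsewhere in this guide

# Continuous path: a linear layer projects each patch vector to n_embd.
encoder = nn.Linear(patch_dim, n_embd)
patches = torch.randn(2, 5, patch_dim)            # (batch, tokens, patch_dim), float
patch_embeddings = encoder(patches)               # (2, 5, n_embd)

# Discrete path: an embedding table is indexed by integer token IDs.
embedder = nn.Embedding(vocab_size, n_embd)
token_ids = torch.randint(0, vocab_size, (2, 5))  # (batch, tokens), int64
token_embeddings = embedder(token_ids)            # (2, 5, n_embd)
```

Both paths produce tensors of the same shape, so the rest of the transformer does not need to know which tokenisation was used.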

Installation

AION tokenisation requires the AION library. Install it together with AstroPT's other optional dependencies:

uv sync --all-extras

Command-line usage

Single GPU

uv run scripts/train_aion.py

Modality configuration for AION

The key difference in configuration is the vocab_size parameter:

modalities = [
    ModalityConfig(
        name="images_aion",
        input_size=10000,  # Must equal vocab_size
        patch_size=0,      # Placeholder, not used for discrete tokens
        loss_weight=1.0,
        embed_pos=True,
        pos_input_size=1,
        vocab_size=10000,  # Size of AION's learned codebook
    ),
]

When vocab_size > 0, AstroPT uses an embedding table instead of a continuous encoder.
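
The two rules in this section (a positive vocab_size marks a discrete modality, and input_size must match it) can be checked before training starts. The following is a sketch only: ToyModalityConfig, is_discrete, and check_modality are hypothetical names, not part of AstroPT, and the patch-based input_size of 768 is an arbitrary example value.

```python
from dataclasses import dataclass


@dataclass
class ToyModalityConfig:
    """Hypothetical stand-in carrying only the fields checked below."""
    name: str
    input_size: int
    vocab_size: int = 0


def is_discrete(cfg: ToyModalityConfig) -> bool:
    """A modality counts as discrete when it declares a vocabulary."""
    return cfg.vocab_size > 0


def check_modality(cfg: ToyModalityConfig) -> None:
    """Raise if a discrete modality's input_size differs from its vocab_size."""
    if is_discrete(cfg) and cfg.input_size != cfg.vocab_size:
        raise ValueError(
            f"{cfg.name}: input_size={cfg.input_size} must equal "
            f"vocab_size={cfg.vocab_size}"
        )


aion_cfg = ToyModalityConfig(name="images_aion", input_size=10000, vocab_size=10000)
patch_cfg = ToyModalityConfig(name="images", input_size=768, vocab_size=0)

check_modality(aion_cfg)   # passes: sizes match
check_modality(patch_cfg)  # passes: continuous modalities are not checked
```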

Dataset and dataloader setup

AION tokenisation needs a different dataset and dataloader setup from patch-based tokenisation, because the AION tokeniser has only been trained on raw astronomical imagery. This example uses Legacy Survey imagery.

  1. Dataset: Load raw image data from Hugging Face

    from datasets import load_dataset
    
    tds = load_dataset(
        "Smith42/legacysurvey_hsc_crossmatched",
        split="train",
        streaming=True,
    ).select_columns("legacysurvey_image").rename_column("legacysurvey_image", "image")
    
  2. Codec initialisation: Load the pre-trained AION tokeniser

    from aion.codecs.image import ImageCodec
    from aion.modalities import LegacySurveyImage
    
    image_codec = ImageCodec.from_pretrained(
        "polymathic-ai/aion-base",
        modality=LegacySurveyImage
    )
    image_codec = image_codec.eval()
    
  3. Collate function: Tokenise in batches during loading

    from functools import partial
    
    import torch
    from torch.utils.data import DataLoader
    
    def collate_and_tokenise(batch, image_codec):
        """Tokenise a whole batch with a single codec call.
    
        Runs in the main process when num_workers=0, otherwise in the
        DataLoader worker processes.
        """
        # Stack all fluxes into a single batch tensor
        flux_batch = torch.stack([
            torch.tensor(item["image"]["flux"])
            for item in batch
        ])
    
        # Create single batched LegacySurveyImage
        batched_img = LegacySurveyImage(
            flux=flux_batch,
            bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z']
        )
    
        # Single encode call for entire batch
        tokens = image_codec.encode(batched_img)
    
        batch_size = len(batch)
    
        return {
            "images_aion": tokens.long(),
            # Positions 0..T-1, repeated for every item in the batch
            "images_aion_positions": torch.stack([
                torch.arange(tokens.shape[1]) for _ in range(batch_size)
            ])
        }
    
    collate_fn = partial(collate_and_tokenise, image_codec=image_codec)
    
    # batch_size and num_workers come from the training configuration
    train_loader = DataLoader(
        tds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=num_workers > 0,
        collate_fn=collate_fn,
    )
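
The collate function returns two tensors of identical shape: the token IDs and their positions. The sketch below reproduces that batch layout without needing the AION library, using random integers as a stand-in for the codec output. The sequence length of 16 is an arbitrary illustrative value, not the number of tokens AION produces per image.

```python
import torch

batch_size, seq_len, vocab_size = 4, 16, 10000  # seq_len is a hypothetical value

# Stand-in for image_codec.encode(...): one row of token IDs per image.
tokens = torch.randint(0, vocab_size, (batch_size, seq_len))

batch = {
    "images_aion": tokens.long(),
    # The same positions 0..seq_len-1 for every image in the batch.
    "images_aion_positions": torch.arange(seq_len).repeat(batch_size, 1),
}

for key, value in batch.items():
    print(key, tuple(value.shape), value.dtype)
```

Here repeat builds the same positions tensor as stacking one torch.arange per item.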
    

Loss function differences

AION tokens are discrete, so the loss function differs from patch-based training:

  • Patch-based: Uses Huber loss (robust L1/L2 hybrid) for continuous predictions

  • AION: Uses cross-entropy loss for discrete token classification

# In the model's forward pass
if "aion" in mod_name:
    loss += F.cross_entropy(
        pred.reshape(-1, pred.size(-1)),
        target.reshape(-1),
    ) * mod_config.loss_weight
else:
    loss += F.huber_loss(pred, target) * mod_config.loss_weight
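
The reshape in the cross-entropy branch flattens the batch and sequence dimensions, so that every token becomes one classification problem over the vocabulary. The following self-contained sketch runs both loss types on toy tensors; the sizes B, T, V and the feature width of 16 are hypothetical.

```python
import torch
import torch.nn.functional as F

B, T, V = 2, 3, 10  # batch size, sequence length, toy vocabulary size

# Discrete modality: logits over the vocabulary, integer targets.
logits = torch.randn(B, T, V)
target_ids = torch.randint(0, V, (B, T))
ce = F.cross_entropy(
    logits.reshape(-1, V),   # (B*T, V): one row of class scores per token
    target_ids.reshape(-1),  # (B*T,): one class index per token
)

# Continuous modality: predictions and targets are same-shaped float tensors.
pred = torch.randn(B, T, 16)
target = torch.randn(B, T, 16)
huber = F.huber_loss(pred, target)

# Sanity check: uniform logits give a cross-entropy of ln(V).
uniform_ce = F.cross_entropy(torch.zeros(B * T, V), target_ids.reshape(-1))
```

For the 10000-entry codebook used in this guide, ln(10000) is about 9.21. If the model starts out predicting roughly uniformly, the AION loss should begin near that value.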

Validation and visualisation

During validation, AION tokens must be decoded back to images for visualisation:

@torch.no_grad()
def validate(iter_num, out_dir):
    model.eval()

    # Get predictions
    B = gid.process_modes(next(vdl), modality_registry, device)
    with ctx:
        P, loss = model(B["X"], B["Y"])

        # Decode target tokens to images
        Yim = torch.cat((
            torch.zeros(B["Y"]["images_aion"].shape[0], 1),
            B["Y"]["images_aion"].cpu(),
        ), dim=1)
        Yim = image_codec.decode(
            Yim,
            bands=["DES-G", "DES-R", "DES-Z"],
        )

        # Decode predicted tokens to images
        Pim = torch.cat((
            torch.zeros(B["Y"]["images_aion"].shape[0], 1),
            torch.argmax(P["images_aion"], dim=-1).cpu(),
        ), dim=1)
        Pim = image_codec.decode(
            Pim,
            bands=["DES-G", "DES-R", "DES-Z"],
        )

        # Visualise: clip each image at its 99th percentile and rescale to [0, 1]
        def clip_and_norm(x):
            lo, hi = x.min(), x.quantile(0.99)
            return (torch.clamp(x, lo, hi) - lo) / (hi - lo)

        for ax, p, y in zip(axs, Pim.flux, Yim.flux):
            ax[0].imshow(clip_and_norm(y.swapaxes(0, -1)))
            ax[1].imshow(clip_and_norm(p.swapaxes(0, -1)))
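
Before decoding, validate() turns the model's logits into token IDs with argmax and prepends one column of zeros to each sequence. The guide does not state the reason for the extra column; presumably the next-token targets are one token shorter than the full tokenised image. A toy sketch of those two steps is below, with hypothetical sizes. The placeholder is given the integer dtype of the token IDs so that the result stays integer-typed.

```python
import torch

B, T, V = 2, 5, 10  # toy batch size, target length, vocabulary size

logits = torch.randn(B, T, V)            # stand-in for P["images_aion"]
pred_ids = torch.argmax(logits, dim=-1)  # (B, T): most likely token per position

# Prepend one placeholder token (ID 0) per sequence.
placeholder = torch.zeros(B, 1, dtype=pred_ids.dtype)
full_ids = torch.cat((placeholder, pred_ids), dim=1)  # (B, T + 1)
```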

References