Training single and multimodal AstroPT models
This guide describes how to train AstroPT-style Large Observation Models (LOMs) using the provided training scripts.
Overview
We provide two example training scripts in scripts/:
train.py: for training on single-modality data (here, galaxy imagery)
train_multimodal.py: for training on multiple modalities (here, both galaxy imagery and spectra)
Either script can be adapted to other modalities by modifying the tokeniser and position-embedding code. Both scripts support single-GPU training and distributed multi-GPU training with PyTorch Distributed Data Parallel (DDP).
Command-line usage
Single GPU
```bash
# For single-modality training
python train.py
# For multi-modality training
python train_multimodal.py
```
Distributed Data Parallel (4 GPUs on 1 node)
```bash
# For single-modality training
torchrun --standalone --nproc_per_node=4 train.py
# For multi-modality training
torchrun --standalone --nproc_per_node=4 train_multimodal.py
```
Distributed Data Parallel (8 GPUs across 2 nodes)
On the first (master) node:
```bash
# For either training script
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py  # or train_multimodal.py
```
On the worker node:
```bash
# For either training script
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py  # or train_multimodal.py
```
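Each process launched by torchrun learns its place in the job from environment variables (RANK, LOCAL_RANK, WORLD_SIZE). The sketch below shows how a training script can turn these into a device and a "master process" flag. It assumes, as nanoGPT-derived scripts commonly do, that a missing RANK means a plain single-GPU run; ddp_context is a hypothetical helper, not a function from the AstroPT scripts. With the two-node launch above, ranks are typically assigned as node_rank * nproc_per_node + local_rank, so GPU 3 on the worker node gets RANK 11 out of a WORLD_SIZE of 16.

```python
import os


def ddp_context(env=None):
    """Summarise the process layout from torchrun's environment variables.

    Sketch assuming a missing RANK variable means a single-GPU run.
    """
    env = os.environ if env is None else env
    if "RANK" not in env:
        return {"ddp": False, "rank": 0, "local_rank": 0, "world_size": 1,
                "device": "cuda", "master_process": True}
    rank = int(env["RANK"])              # global rank across all nodes
    local_rank = int(env["LOCAL_RANK"])  # GPU index on this node
    world_size = int(env["WORLD_SIZE"])  # total processes (nnodes * nproc_per_node)
    return {
        "ddp": True,
        "rank": rank,
        "local_rank": local_rank,
        "world_size": world_size,
        "device": f"cuda:{local_rank}",  # pin each process to its own GPU
        "master_process": rank == 0,     # e.g. only rank 0 logs and saves
    }


# Simulated environment for GPU 3 on the worker node (node_rank=1) of the
# 2 x 8 GPU launch; torchrun would normally set these variables itself.
print(ddp_context({"RANK": "11", "LOCAL_RANK": "3", "WORLD_SIZE": "16"}))
```

Restricting logging and checkpointing to the master process avoids eight or sixteen processes writing the same files at once.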
Key differences between scripts
While train.py and train_multimodal.py share most of their code, they differ in several important ways:
Modality configuration
train.py configures a single modality for galaxy imagery:

```python
modalities = [
    ModalityConfig(
        name="images",
        input_size=16 * 16 * n_chan,
        patch_size=16,
        loss_weight=1.0,
        embed_pos=True,
        pos_input_size=1,
    ),
]
```
train_multimodal.py configures multiple modalities:

```python
modalities = [
    ModalityConfig(
        name="images",
        input_size=16 * 16 * n_chan,
        patch_size=16,
        loss_weight=1.0,
        embed_pos=True,
        pos_input_size=1,
    ),
    ModalityConfig(
        name="spectra",
        input_size=256,
        patch_size=256,
        pos_input_size=256,
        loss_weight=0.5,
        embed_pos=False,
    ),
]
```
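Because block_size sets the maximum sequence length, it helps to count how many tokens each modality contributes: an image token holds one flattened 16 x 16 x n_chan patch, and a spectrum token holds 256 consecutive values. The sketch below does this arithmetic, assuming square images whose side is divisible by patch_size and spectra that are zero-padded to a multiple of patch_size (as in the spectra example later in this guide). ModalitySpec is a hypothetical stand-in for the ModalityConfig fields used here, and the image size, channel count, and spectrum length are illustrative values only.

```python
import math
from dataclasses import dataclass


@dataclass
class ModalitySpec:
    """Hypothetical stand-in for the ModalityConfig fields used below."""
    name: str
    input_size: int  # number of values in one token
    patch_size: int


def image_tokens(image_size, patch_size):
    # A square image is cut into non-overlapping patch_size x patch_size patches.
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    return (image_size // patch_size) ** 2


def spectrum_tokens(n_pixels, patch_size):
    # Assumes the spectrum is zero-padded to a multiple of patch_size, then chunked.
    return math.ceil(n_pixels / patch_size)


n_chan = 3  # hypothetical
images = ModalitySpec("images", 16 * 16 * n_chan, patch_size=16)
spectra = ModalitySpec("spectra", 256, patch_size=256)

n_img = image_tokens(224, images.patch_size)        # hypothetical 224 x 224 image
n_spec = spectrum_tokens(7781, spectra.patch_size)  # hypothetical 7781-pixel spectrum
print(images.input_size, n_img, n_spec, n_img + n_spec)  # → 768 196 31 227
```

Under these assumptions the combined image-plus-spectrum sequence (227 tokens here) would need to fit within block_size.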
Dataset handling
train.py can stream Hugging Face datasets (here we use Smith42/galaxies):

```python
from datasets import load_dataset

# When use_hf=True
tds_hf = load_dataset("Smith42/galaxies", split="train", streaming=True)
```
train_multimodal.py uses local file paths for both modalities:

```python
tds = GalaxyImageDataset(
    paths={"images": "./hsc_matched.txt", "spectra": "./spectra_matched.txt"},
    spiral=spiral,
    transform=transforms,
    modality_registry=modality_registry,
)
```
where hsc_matched.txt and spectra_matched.txt are crossmatched text files of the same length, each line giving the path to one FITS or JPG image or spectrum, so that line i of both files refers to the same object.
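Since the two files are paired line by line, a mismatch in length or ordering silently pairs the wrong image with the wrong spectrum. A quick sanity check before training can catch this. The helper names read_path_list and load_crossmatched are hypothetical, and the sketch assumes one path per line with blank lines ignored.

```python
def read_path_list(txt_file):
    """Read one file path per line, skipping blank lines."""
    with open(txt_file) as f:
        return [line.strip() for line in f if line.strip()]


def load_crossmatched(paths):
    """Load several path lists and check that they line up row by row.

    paths: mapping of modality name -> text file (e.g. hsc_matched.txt).
    Returns a list of dicts, one per object, keyed by modality name.
    """
    lists = {mode: read_path_list(txt) for mode, txt in paths.items()}
    lengths = {mode: len(v) for mode, v in lists.items()}
    if len(set(lengths.values())) != 1:
        raise ValueError(f"crossmatched files differ in length: {lengths}")
    modes = list(lists)
    # zip pairs line i of every file, i.e. all modalities of object i
    return [dict(zip(modes, row)) for row in zip(*lists.values())]
```

For example, load_crossmatched({"images": "./hsc_matched.txt", "spectra": "./spectra_matched.txt"}) either returns one dictionary per object or raises if the files disagree in length. It cannot detect two equal-length files in different orders, so the crossmatch itself still has to be done carefully.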
Configuration options
Both scripts accept many configuration parameters, which can be set on the command line or in configuration files:
Model architecture
n_layer: Number of transformer layers
n_head: Number of attention heads
n_embd: Embedding dimension
n_chan: Number of image channels
block_size: Maximum sequence length
dropout: Dropout rate (0.0 recommended for pretraining)
bias: Whether to use bias in LayerNorm and Linear layers
attn_type: Attention type (“causal” is standard)
Data parameters
gradient_accumulation_steps: Number of steps over which to accumulate gradients
batch_size: Batch size per GPU
spiral: Process galaxy patches in spiral order (as described in our paper)
image_size: Size of input images
use_hf: Use the Hugging Face version of the dataset (train.py only)
stream_hf_dataset: Stream the galaxies from Hugging Face (train.py only)
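The spiral option changes the order in which image patches are fed to the model. The exact ordering is defined in the AstroPT paper and code; the sketch below assumes a clockwise spiral that starts at the central patch and winds outward, which may differ in direction or starting point from the real implementation. It builds the familiar inward spiral from the top-left patch and reverses it, so the returned sequence begins at the centre. spiral_order is a hypothetical helper that works on row-major patch indices.

```python
def spiral_order(n):
    """Return row-major patch indices of an n x n grid, ordered centre-outward.

    Builds a clockwise inward spiral from the top-left patch, then reverses it.
    """
    top, bottom, left, right = 0, n - 1, 0, n - 1
    inward = []
    while top <= bottom and left <= right:
        for c in range(left, right + 1):          # top row, left to right
            inward.append(top * n + c)
        top += 1
        for r in range(top, bottom + 1):          # right column, downwards
            inward.append(r * n + right)
        right -= 1
        if top <= bottom:
            for c in range(right, left - 1, -1):  # bottom row, right to left
                inward.append(bottom * n + c)
            bottom -= 1
        if left <= right:
            for r in range(bottom, top - 1, -1):  # left column, upwards
                inward.append(r * n + left)
            left += 1
    return inward[::-1]  # reversed: start at the centre, end at a corner


print(spiral_order(3))  # → [4, 3, 6, 7, 8, 5, 2, 1, 0]
```

Given a list of patches in row-major order, [patches[i] for i in spiral_order(n)] reorders them so that the central patches come first in the sequence.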
Optimiser settings
learning_rate: Maximum learning rate
weight_decay: Weight decay value
beta1: Adam beta1
beta2: Adam beta2
grad_clip: Gradient clipping value
decay_lr: Whether to decay the learning rate
warmup_iters: Number of warmup iterations
lr_decay_iters: Total iterations for LR decay
min_lr: Minimum learning rate (typically learning_rate/10)
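Several of these settings combine into a learning-rate schedule. Assuming the nanoGPT-style schedule that these parameter names suggest (linear warmup over warmup_iters, cosine decay from learning_rate to min_lr by lr_decay_iters, then constant at min_lr), the rate at iteration it could be computed as below. get_lr is illustrative and may not match the scripts exactly; it also assumes lr_decay_iters > warmup_iters.

```python
import math


def get_lr(it, learning_rate, warmup_iters, lr_decay_iters, min_lr, decay_lr=True):
    """Linear warmup followed by cosine decay to min_lr (sketch)."""
    if not decay_lr:
        return learning_rate
    if it < warmup_iters:
        # Linear ramp from near zero up to learning_rate.
        return learning_rate * (it + 1) / (warmup_iters + 1)
    if it > lr_decay_iters:
        return min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # falls from 1 to 0
    return min_lr + coeff * (learning_rate - min_lr)


# Hypothetical values: peak 6e-4, floor 6e-5 (learning_rate / 10)
for it in (0, 100, 550, 1000, 2000):
    print(it, get_lr(it, 6e-4, warmup_iters=100, lr_decay_iters=1000, min_lr=6e-5))
```

The cosine shape keeps the rate near its peak for a while after warmup and flattens out smoothly as it approaches min_lr.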
Training loop settings
max_iters: Total number of training iterations
eval_interval: Interval for evaluation
log_interval: Interval for logging
eval_iters: Number of batches for evaluation
checkpoint_interval: Interval for saving checkpoints
eval_only: Only perform evaluation, no training
always_save_checkpoint: Always save checkpoints regardless of loss
init_from: Initialize from scratch or resume training
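To relate max_iters to how much data the model actually sees, multiply out the batch settings. The sketch below assumes gradient_accumulation_steps is counted per process, so each optimiser step uses gradient_accumulation_steps x batch_size x world_size samples; some nanoGPT-derived scripts instead divide the accumulation steps by the number of GPUs, in which case world_size should be left at 1. training_budget and the numbers in the example are hypothetical.

```python
def samples_per_iter(gradient_accumulation_steps, batch_size, world_size=1):
    # Micro-batches per process x samples per micro-batch x number of GPUs.
    return gradient_accumulation_steps * batch_size * world_size


def training_budget(max_iters, gradient_accumulation_steps, batch_size,
                    world_size=1, dataset_size=None):
    """Return (samples per step, total samples seen, approximate epochs)."""
    per_iter = samples_per_iter(gradient_accumulation_steps, batch_size, world_size)
    total = per_iter * max_iters
    epochs = total / dataset_size if dataset_size else None
    return per_iter, total, epochs


# Hypothetical run: 8 GPUs, batch_size=32, 5 accumulation steps, 30k iterations
print(training_budget(30_000, 5, 32, world_size=8, dataset_size=1_280_000))
# → (1280, 38400000, 30.0)
```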
System settings
device: Device to use (default: “cuda”)
dtype: Data type to use (default: “bfloat16”)
compile: Use PyTorch 2.0 to compile the model
backend: DDP backend (default: “nccl”)
out_dir: Output directory for logs and checkpoints
log_via_wandb: Use WandB for logging
log_emissions: Use CodeCarbon to track emissions
Configuration files
Instead of specifying all parameters via command line, you can create a configuration file:
```python
# config/astropt.py
out_dir = "logs/astropt"
batch_size = 32
n_layer = 24
n_head = 16
n_embd = 1024
```
And then pass it to the script:
```bash
python train.py config/astropt.py  # or train_multimodal.py
```
You can also override specific parameters from the config file:
```bash
python train.py config/astropt.py --batch_size=64
```
Example config files are provided in the config/ directory.
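The layering of config files and command-line overrides can be pictured as follows. This is a minimal sketch of a nanoGPT-style configurator, assuming that positional arguments are Python files whose assignments overwrite the defaults and that --key=value arguments are applied afterwards; apply_config is hypothetical and the real scripts may parse arguments differently. Values are parsed with ast.literal_eval where possible, and an override must keep the type of the setting it replaces.

```python
import ast


def apply_config(defaults, argv):
    """Apply config files, then --key=value overrides, to a copy of defaults."""
    config = dict(defaults)
    for arg in argv:
        if not arg.startswith("--"):
            # A config file: run it as (trusted) Python and collect its assignments.
            namespace = {}
            with open(arg) as f:
                exec(f.read(), {}, namespace)
            for key, value in namespace.items():
                if key not in config:
                    raise KeyError(f"unknown config key in {arg}: {key}")
                config[key] = value
        else:
            key, _, raw = arg[2:].partition("=")
            if key not in config:
                raise KeyError(f"unknown config key: {key}")
            try:
                value = ast.literal_eval(raw)  # numbers, booleans, quoted strings
            except (ValueError, SyntaxError):
                value = raw                    # bare strings such as logs/run1
            if type(value) is not type(config[key]):
                raise TypeError(f"{key}: expected {type(config[key]).__name__}")
            config[key] = value
    return config
```

With defaults {"batch_size": 12, ...}, apply_config(defaults, ["config/astropt.py", "--batch_size=64"]) would first take batch_size = 32 from the file and then override it to 64, matching the command-line example above.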
Creating custom dataloaders
The AstroPT training pipeline uses the GalaxyImageDataset class from local_datasets.py to handle data loading and preprocessing. You can create your own custom dataloader for different modalities by extending this class or creating a similar class that follows the same interface.
Here’s a guide to creating a custom dataloader for AstroPT:
Basic structure
Your dataloader should inherit from torch.utils.data.Dataset and implement the following methods:
__init__: Initialize the dataset with paths, transforms, and modality registry
__len__: Return the dataset length
__getitem__: Return a dictionary of data for a given index
process_modes: Process data into X and Y tensors for the model
Example skeleton
```python
import torch
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    def __init__(self, paths, transform=None, modality_registry=None):
        """
        Args:
            paths (dict): Maps each modality name to a text file listing one
                data file per line (as with hsc_matched.txt above)
            transform (dict, optional): Dictionary of transforms for each modality
            modality_registry: ModalityRegistry object containing modality configurations
        """
        # Read each path list once so that self.paths[mode][idx] is a file path
        self.paths = {mode: self.read_path_list(p) for mode, p in paths.items()}
        self.transform = transform
        self.modality_registry = modality_registry

    @staticmethod
    def read_path_list(txt_file):
        """Return the non-empty lines of txt_file as a list of paths"""
        with open(txt_file) as f:
            return [line.strip() for line in f if line.strip()]

    def __len__(self):
        """Return the total number of samples in the dataset"""
        return len(next(iter(self.paths.values())))

    def __getitem__(self, idx):
        """Get a single sample from the dataset"""
        # Process each modality and return a dictionary
        result = {}

        # Example for image modality
        if "images" in self.paths:
            # Load and process image data
            image_data = self.load_image(self.paths["images"][idx])
            processed_image = self.process_image(image_data)
            result["images"] = processed_image
            result["images_positions"] = torch.arange(0, len(processed_image), dtype=torch.long)

        # Example for another modality
        if "spectra" in self.paths:
            # Load and process spectral data
            spectral_data = self.load_spectrum(self.paths["spectra"][idx])
            processed_spectrum, wavelengths = self.process_spectrum(spectral_data)
            result["spectra"] = processed_spectrum
            result["spectra_positions"] = wavelengths

        result["idx"] = idx
        return result

    def load_image(self, path):
        """Load image data from path"""
        # Implement loading logic for your image format
        raise NotImplementedError

    def process_image(self, image_data):
        """Process loaded image data into model-ready format"""
        # Implement processing logic (patching, standardization, etc.)
        raise NotImplementedError

    def load_spectrum(self, path):
        """Load spectral data from path"""
        # Implement loading logic for your spectral format
        raise NotImplementedError

    def process_spectrum(self, spectral_data):
        """Process loaded spectral data into model-ready format"""
        # Implement processing logic; return (patches, positions)
        raise NotImplementedError

    @staticmethod
    def process_modes(x, modality_registry, device, shuf=False):
        """
        Process data dictionary into X and Y tensors for model input/output

        Args:
            x (dict): Data dictionary from __getitem__
            modality_registry: ModalityRegistry object
            device: torch device to move tensors to
            shuf (bool): Whether to shuffle modality order

        Returns:
            dict: Dictionary containing 'X' and 'Y' keys with model-ready tensors
        """
        modes = modality_registry.generate_sequence(shuf=shuf)

        # Move all tensors to device
        x_on_device = {
            k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in x.items()
        }

        X = {}
        Y = {}
        # Process each modality
        for ii, mode in enumerate(modes):
            X[mode] = x_on_device[mode]
            X[f"{mode}_positions"] = x_on_device[f"{mode}_positions"]
            Y[mode] = x_on_device[mode]

            # Handle autoregressive prediction (shift by one token)
            if ii == 0:
                Y[mode] = Y[mode][:, 1:]
            if len(modes) == 1:
                X[mode] = X[mode][:, :-1]
                X[f"{mode}_positions"] = X[f"{mode}_positions"][:, :-1]

        return {"X": X, "Y": Y}
```
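To see what the single-modality branch of process_modes does to a batch, here is a dependency-free sketch that uses nested Python lists in place of tensors. shift_for_next_token is a hypothetical helper, not part of AstroPT; it reproduces only the one-modality case, where inputs drop the last token and targets drop the first, so that each input token is trained to predict the token that follows it.

```python
def shift_for_next_token(tokens, positions, mode="images"):
    """Mirror the single-modality branch of process_modes on nested lists.

    tokens, positions: a batch of sequences, indexed [batch][token].
    Input token i is paired with target token i + 1 of the original sequence.
    """
    X = {
        mode: [seq[:-1] for seq in tokens],
        f"{mode}_positions": [seq[:-1] for seq in positions],
    }
    Y = {mode: [seq[1:] for seq in tokens]}
    return {"X": X, "Y": Y}


batch = [["p0", "p1", "p2", "p3"]]  # one image split into four patch tokens
out = shift_for_next_token(batch, [[0, 1, 2, 3]])
print(out["X"]["images"], out["Y"]["images"])
# → [['p0', 'p1', 'p2']] [['p1', 'p2', 'p3']]
```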
Example: custom spectra dataset
```python
import einops
import numpy as np
import torch
import torch.nn.functional as F
from astropy.io import fits
from torch.utils.data import Dataset


class StellarSpectraDataset(Dataset):
    def __init__(self, paths, transform=None, modality_registry=None):
        self.paths = paths
        self.transform = transform
        self.modality_registry = modality_registry

        # Read file paths
        if "spectra" in paths and paths["spectra"] is not None:
            self.spectra_paths = np.genfromtxt(paths["spectra"], delimiter=",", dtype=str)
        else:
            self.spectra_paths = None

        # Dataset length is the number of listed spectra (0 if none were given)
        self.dataset_len = len(self.spectra_paths) if self.spectra_paths is not None else 0

    def __len__(self):
        return self.dataset_len

    def process_spectrum(self, raw_spectrum, wavelength):
        patch_size = self.modality_registry.get_config("spectra").patch_size

        # Pad the spectrum so that its length is divisible by patch_size
        w = raw_spectrum.shape[0]
        pad_w = (patch_size - w % patch_size) % patch_size
        padded_spectrum = F.pad(raw_spectrum, (0, pad_w))
        padded_wl = F.pad(wavelength, (0, pad_w))

        # Rearrange into patches of shape (n_patches, patch_size)
        patch_spectrum = einops.rearrange(padded_spectrum, "(w p) -> w p", p=patch_size)
        patch_wl = einops.rearrange(padded_wl, "(w p) -> w p", p=patch_size)

        # Apply transforms if specified (transform may be None)
        if self.transform is not None and "spectra" in self.transform:
            patch_spectrum = self.transform["spectra"](patch_spectrum)

        return patch_spectrum, patch_wl

    def __getitem__(self, idx):
        try:
            # Load spectral data from FITS file
            with fits.open(self.spectra_paths[idx]) as hdul:
                raw_spectrum = hdul[1].data["Flux"].astype(np.float32)
                wavelength = hdul[1].data["Wave"].astype(np.float32)

            # Convert to tensor and normalize wavelengths to roughly [0, 1]
            raw_spectrum = torch.tensor(raw_spectrum).to(torch.bfloat16)
            wavelength = (torch.tensor(wavelength).to(torch.bfloat16) - 3000) / (10000 - 3000)

            # Process the spectrum
            patch_spectrum, patch_wl = self.process_spectrum(raw_spectrum, wavelength)

            # Check for NaN values
            if torch.isnan(patch_spectrum).any() or torch.isnan(patch_wl).any():
                raise ValueError("Found NaNs in spectra, skipping file")

            return {
                "spectra": patch_spectrum,
                "spectra_positions": patch_wl,
                "idx": idx,
            }
        except Exception as err:
            print(f"Error processing file {self.spectra_paths[idx]}: {err}")
            raise

    @staticmethod
    def process_modes(x, modality_registry, device, shuf=False):
        modes = modality_registry.generate_sequence(shuf=shuf)
        x_on_device = {
            k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in x.items()
        }
        X = {}
        Y = {}
        for ii, mode in enumerate(modes):
            X[mode] = x_on_device[mode]
            X[f"{mode}_positions"] = x_on_device[f"{mode}_positions"]
            Y[mode] = x_on_device[mode]
            if ii == 0:
                Y[mode] = Y[mode][:, 1:]
            if len(modes) == 1:
                X[mode] = X[mode][:, :-1]
                X[f"{mode}_positions"] = X[f"{mode}_positions"][:, :-1]
        return {"X": X, "Y": Y}
```
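The padding and patching arithmetic in StellarSpectraDataset can be checked without torch, einops or FITS files. The sketch below mirrors its steps on plain lists: it scales wavelengths with the same 3000 to 10000 range, zero-pads both arrays to a multiple of patch_size, and groups consecutive values into patches as the "(w p) -> w p" rearrangement does. patch_spectrum is a hypothetical helper and the input spectrum is synthetic.

```python
def patch_spectrum(flux, wavelength, patch_size, wl_min=3000.0, wl_max=10000.0):
    """Dependency-free sketch of StellarSpectraDataset's preprocessing."""
    pad = (patch_size - len(flux) % patch_size) % patch_size  # 0 if divisible
    wl = [(w - wl_min) / (wl_max - wl_min) for w in wavelength]
    flux = list(flux) + [0.0] * pad
    wl = wl + [0.0] * pad  # padded positions sit at the scaled value 0.0

    def to_patches(xs):
        # Same grouping as einops "(w p) -> w p": row i holds xs[i*p:(i+1)*p]
        return [xs[i:i + patch_size] for i in range(0, len(xs), patch_size)]

    return to_patches(flux), to_patches(wl)


flux = [1.0] * 1000                                 # synthetic 1000-pixel spectrum
wave = [3000.0 + 7.0 * i for i in range(1000)]
flux_patches, wl_patches = patch_spectrum(flux, wave, patch_size=256)
print(len(flux_patches), len(flux_patches[-1]))     # → 4 256
```

A 1000-pixel spectrum therefore needs 24 values of padding and becomes four 256-value tokens, with the padding confined to the end of the last token.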
Integration with training script
To use your custom dataloader in the training script:
Import your custom dataset class
Replace the GalaxyImageDataset instantiation with your custom dataset
Use the same dataloader configuration as in the original script
```python
from torch.utils.data import DataLoader

from custom_dataloader import CustomDataset

# ...

# Initialize custom dataset
tds = CustomDataset(
    paths={"modality1": "path1.txt", "modality2": "path2.txt"},
    transform=transforms,
    modality_registry=modality_registry,
)

# Create DataLoader with the custom dataset
tdl = iter(
    DataLoader(
        tds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
    )
)
```
Output and monitoring
During training, both scripts provide:
Loss values for training and validation sets
Visual comparisons of original data and model predictions
Checkpoint saving based on validation performance
Optional WandB integration for experiment tracking
MFU (Model FLOP Utilization) estimates
Optional carbon emissions tracking
Training progress, model visualizations, and metrics are saved to the output directory set by out_dir.
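MFU compares the FLOPs a training step actually achieves with the hardware's peak. A common estimate, used by nanoGPT and taken from the PaLM paper's appendix, counts about 6N + 12·L·H·Q·T FLOPs per token for a forward plus backward pass, where N is the parameter count, L = n_layer, H = n_head, Q = n_embd / n_head and T = block_size. The sketch below assumes AstroPT reports something similar and that every sequence is a full block_size long. estimate_mfu is illustrative, and its default peak_flops is the A100's 312 TFLOPS bfloat16 figure, which should be changed for other GPUs.

```python
def estimate_mfu(n_params, n_layer, n_head, n_embd, block_size,
                 fwdbwd_per_iter, dt, peak_flops=312e12):
    """Approximate model FLOP utilisation for one training iteration.

    fwdbwd_per_iter: sequences processed per iteration on this GPU
        (e.g. batch_size * gradient_accumulation_steps).
    dt: wall-clock seconds taken by the iteration.
    """
    L, H, Q, T = n_layer, n_head, n_embd // n_head, block_size
    flops_per_token = 6 * n_params + 12 * L * H * Q * T
    flops_per_iter = flops_per_token * T * fwdbwd_per_iter
    return (flops_per_iter / dt) / peak_flops
```

For example, estimate_mfu(n_params, 24, 16, 1024, block_size, batch_size * gradient_accumulation_steps, dt) would give the fraction of peak reached by one GPU in an iteration lasting dt seconds.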