FourCastNeXt Training Demo

This notebook shows how to train the high-resolution global weather model “FourCastNeXt” by Guo et al. (2024) (https://doi.org/10.48550/arXiv.2401.05584). This model runs at 0.25 degree resolution and produces ten-day forecasts. Training in this tutorial takes 48 hours on a single V100 GPU to converge; other GPU configurations may be faster or slower.

Note: This model is not identical to the model from Guo et al. (2024). This tutorial uses a reduced set of variables and does not precisely match the number of blocks used for data augmentation.

Training can be initiated in a few different ways, each suited to a different use case. It can be done:

  1. Inside a Jupyter notebook, as Python code, with or without experiment tracking

  2. From the command-line or in a Jupyter notebook using the command-line execution magic, leveraging “Hydra” for experiment tracking

  3. Using a supercomputer job scheduler to submit training jobs as part of a queuing system

This notebook starts with the first approach to illustrate the process. For those researching new model architectures, approaches (2) and (3) offer more flexibility for training multiple models at once, and for operating across multiple HPC nodes or cloud instances to accelerate training and discovery.

Performing the training from the command line is essentially a matter of making a .py version of this notebook. Job scheduler submission is a matter of calling that .py file via a job script wrapper.
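As a rough illustration of what such a .py version might look like, the sketch below wraps the notebook's steps in a small command-line entry point. It uses the standard-library `argparse` rather than Hydra to keep the example short, and the file name `train.py`, the function names, and the two options are all hypothetical; none of them come from the FourCastNeXt or pyearthtools code.

```python
"""train.py: hypothetical command-line wrapper around the notebook's training steps."""
import argparse


def parse_args(argv=None):
    """Parse the settings that the notebook otherwise hard-codes."""
    parser = argparse.ArgumentParser(description="Train the FourCastNeXt demo model")
    parser.add_argument("--max-epochs", type=int, default=2,
                        help="number of training epochs")
    parser.add_argument("--checkpoint-path", default=".",
                        help="directory in which checkpoints are written")
    return parser.parse_args(argv)


def main(argv=None):
    args = parse_args(argv)
    # The body of the notebook cells would go here: build the pipeline,
    # the data module, the model and the trainer, then call trainer.fit().
    print(f"Would train for {args.max_epochs} epoch(s), "
          f"writing checkpoints to {args.checkpoint_path}")
    return args


# Example invocation with an explicit argument list.
main(["--max-epochs", "1", "--checkpoint-path", "/tmp/demo"])
```

In a real script, the final line would normally be replaced by a call to `main()` under an `if __name__ == "__main__":` guard, so that a job script wrapper can run `python train.py --max-epochs 2`.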

Training from ERA5 - Limited Variable, Early Stopping Demonstration

[1]:
import hydra
from omegaconf import OmegaConf
import site_archive_nci
import fourcastnext
import pathlib
[2]:
doi = '20220222T00'
[3]:
import pyearthtools.data  # Used below for the ERA5 archive and coordinate transforms
import pyearthtools.training
import pyearthtools.pipeline
[4]:
variables = ['msl', '10u', '10v', '2t']
train_start = 2000
train_end = 2015
valid_start = 2018
valid_end = 2020
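To get a feel for how much data these year ranges imply, the following sketch counts 6-hourly timestamps using only the standard library. It assumes a bare year means 1 January at 00:00 and that the end of the range is exclusive; the actual `DateRange` iterator may treat its bounds differently, so the numbers are indicative only. The helper `count_steps` is a hypothetical name.

```python
from datetime import datetime, timedelta


def count_steps(start_year, end_year, hours=6):
    """Count timestamps in [start_year-01-01, end_year-01-01) at a fixed interval."""
    start = datetime(start_year, 1, 1)
    end = datetime(end_year, 1, 1)
    # Dividing two timedeltas with // gives a whole number of intervals.
    return (end - start) // timedelta(hours=hours)


train_steps = count_steps(2000, 2015)
valid_steps = count_steps(2018, 2020)
print(train_steps, valid_steps)
```

Under these assumptions the training range holds 21916 timestamps and the validation range 2920, which gives a sense of how many steps one epoch takes at a batch size of 1.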
[5]:
training_pipeline = pyearthtools.pipeline.Pipeline(
    pyearthtools.data.archive.ERA5(variables),
    pyearthtools.data.transforms.coordinates.StandardLongitude('-180-180'),
    fourcastnext.CropToRectangle(),  # Shave off a pixel for convenience in the demo
    pyearthtools.pipeline.modifications.idx_modification.TemporalRetrieval(((-6,1), (6,2,6))),
    pyearthtools.pipeline.operations.xarray.conversion.ToNumpy(),
    pyearthtools.pipeline.operations.numpy.reshape.Rearrange('c t h w -> t c h w'),
    sampler=pyearthtools.pipeline.samplers.Default(),
    iterator=pyearthtools.pipeline.iterators.DateRange(train_start, train_end, interval='6 hours')

)
# training_pipeline
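The `Rearrange('c t h w -> t c h w')` step uses an einops-style pattern string. Assuming it is a pure axis permutation (channel and time axes swapped, no reshaping), its effect on array shapes can be sketched in plain Python. The helper `rearrange_shape` is hypothetical, and the input shape below (4 variables, 3 time steps) is made up for illustration rather than taken from the pipeline's real output.

```python
def rearrange_shape(pattern, shape):
    """Return the shape after a pure axis permutation such as 'c t h w -> t c h w'."""
    source, target = (side.split() for side in pattern.split("->"))
    if sorted(source) != sorted(target) or len(source) != len(shape):
        raise ValueError("pattern must permute exactly the axes of the input")
    # Map each axis name to its size, then read the sizes back in target order.
    sizes = dict(zip(source, shape))
    return tuple(sizes[name] for name in target)


# Hypothetical input: 4 variables (c), 3 time steps (t), a 720 x 1440 grid (h, w).
print(rearrange_shape("c t h w -> t c h w", (4, 3, 720, 1440)))
```

With this input the time axis moves to the front, giving a shape of (3, 4, 720, 1440).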
[6]:
valid_pipeline = pyearthtools.pipeline.Pipeline(
    pyearthtools.data.archive.ERA5(variables),
    pyearthtools.data.transforms.coordinates.StandardLongitude('-180-180'),
    fourcastnext.CropToRectangle(),  # Shave off a pixel for convenience in the demo
    pyearthtools.pipeline.modifications.idx_modification.TemporalRetrieval(((-6,1), (6,2,6))),
    pyearthtools.pipeline.operations.xarray.conversion.ToNumpy(),
    pyearthtools.pipeline.operations.numpy.reshape.Rearrange('c t h w -> t c h w'),
    sampler=pyearthtools.pipeline.samplers.Default(),
    iterator=pyearthtools.pipeline.iterators.DateRange(valid_start, valid_end, interval='6 hours')

)
# valid_pipeline
[10]:
datamodule = pyearthtools.training.data.lightning.PipelineLightningDataModule(
    training_pipeline,
    # Reuse the year ranges defined earlier so the splits stay consistent with the pipelines
    train_split=pyearthtools.pipeline.iterators.DateRange(train_start, train_end, interval='6 hours'),
    valid_split=pyearthtools.pipeline.iterators.DateRange(valid_start, valid_end, interval='6 hours'),
    num_workers=4,
    batch_size=1
)
[11]:
# Uncomment this to see the output of the pipeline at a given hour
# datamodule.pipelines['2001-01-01T00']
[12]:
model = fourcastnext.registered_model.FourCastNextRM(
    pipeline=datamodule,
    lightning_model_params = {'img_size': (720, 1440),
                              'in_channels': 4, # Increase this if using additional data
                              'out_channels': 4, # Increase this if using additional data
                              'embed_dim': 768,
                              'num_blocks': 1,
                              'patch_size': (4,4),  # Larger patches use less GPU memory; increase this if memory is exceeded
                              'depth': 12,
                             },
    output='.',
    lead_time=6  # Time delta for autoregressive step
)
Setting up PyTorch Lightning Model
{'img_size': (720, 1440), 'in_channels': 4, 'out_channels': 4, 'embed_dim': 768, 'num_blocks': 1, 'patch_size': (4, 4), 'depth': 12}
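The relationship between `img_size`, `patch_size` and memory use can be made concrete with a little arithmetic. The sketch below assumes the model splits the image into non-overlapping patches, as vision-transformer-style models typically do; `patch_grid` is a hypothetical helper and not part of `fourcastnext`.

```python
def patch_grid(img_size, patch_size):
    """Return (rows, cols, total) of non-overlapping patches covering the image."""
    (height, width), (patch_h, patch_w) = img_size, patch_size
    if height % patch_h or width % patch_w:
        raise ValueError("image size must be divisible by patch size")
    rows, cols = height // patch_h, width // patch_w
    return rows, cols, rows * cols


for patch in [(4, 4), (8, 8)]:
    print(patch, patch_grid((720, 1440), patch))
```

A (4, 4) patch yields a 180 x 360 grid of 64800 patches, while (8, 8) yields 90 x 180, or 16200, a fourfold reduction. The 0.25 degree ERA5 grid has 721 latitude rows, which is not divisible by 4; that is presumably why the pipeline crops one row to reach 720.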
[13]:
trainer_configuration = {'precision': '16',
                         'checkpointing': [
                             {'monitor': 'train_loss', 'mode': 'min',
                              'dirpath': '{path}/Checkpoints/Train',
                              'filename': 'model-{epoch:02d}-{step:02d}',
                              'every_n_train_steps': 1000, 'save_top_k': 10},
                             {'monitor': 'valid_loss', 'mode': 'min',
                              'dirpath': '{path}/Checkpoints/Valid',
                              'filename': 'model-{epoch:02d}-{step:02d}-{valid_loss}',
                              'every_n_train_steps': 5000, 'save_top_k': 10},
                             {'monitor': 'epoch', 'mode': 'max',
                              'dirpath': '{path}/Checkpoints/Epoch',
                              'filename': 'model-{epoch:02d}',
                              'save_on_train_epoch_end': True,
                              'save_top_k': 50}]}
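The `{path}` placeholder and the filename templates above are ordinary Python format strings. The sketch below shows how they would expand under plain `str.format`, assuming the training wrapper substitutes its `path` argument for `{path}`; that substitution is an assumption about pyearthtools, not something this notebook confirms. Note also that PyTorch Lightning's `ModelCheckpoint` by default prefixes each value with its name (for example `epoch=01`), so real file names may differ from the plain expansion shown here.

```python
from pathlib import PurePosixPath

checkpoint_path = "/scratch/kd24/ML/full_res"
dirpaths = [
    "{path}/Checkpoints/Train",
    "{path}/Checkpoints/Valid",
    "{path}/Checkpoints/Epoch",
]

# Substitute the base path into each directory template.
resolved = [PurePosixPath(template.format(path=checkpoint_path)) for template in dirpaths]
for directory in resolved:
    print(directory)

# Plain expansion of a filename template with illustrative epoch and step values.
example_name = "model-{epoch:02d}-{step:02d}".format(epoch=1, step=1000)
print(example_name)
```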

[14]:
MAX_EPOCHS = 2  # Suggest training 1 or 2 epochs at first.
           # It is worth experimenting with longer training

checkpoint_path = '/scratch/kd24/ML/full_res'

trainer = pyearthtools.training.lightning.Train(
    model.lightning_model,  # Train the underlying Lightning model, not the registered-model wrapper
    datamodule,
    path=checkpoint_path,
    trainer_kwargs={'num_sanity_val_steps': 0, 'max_epochs': MAX_EPOCHS, },
    **trainer_configuration
)
[15]:
# Should probably checkpoint the weights every 1k steps or so and allow resumption of training
[ ]:
trainer.fit()
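For resuming an interrupted run, a first step is locating the most recent checkpoint on disk. The sketch below searches a directory tree for `.ckpt` files (the default extension PyTorch Lightning uses) and returns the newest by modification time. `latest_checkpoint` is a hypothetical helper, not part of pyearthtools.

```python
from pathlib import Path


def latest_checkpoint(directory):
    """Return the most recently modified .ckpt file under directory, or None."""
    root = Path(directory)
    if not root.is_dir():
        return None
    candidates = list(root.rglob("*.ckpt"))
    if not candidates:
        return None
    # Newest file by modification time.
    return max(candidates, key=lambda p: p.stat().st_mtime)


print(latest_checkpoint("/scratch/kd24/ML/full_res"))
```

With plain PyTorch Lightning, such a path can be passed as `ckpt_path` to `Trainer.fit` to resume training. Whether the `trainer.fit()` wrapper used here forwards that argument is not confirmed by this notebook and would need checking against the pyearthtools documentation.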