Bundled Models API Docs

Note: at some point, “bundled models” will be renamed simply “models”, and each model will be added to that namespace.

fourcastnext

class fourcastnext.lightning_model.FourCastNextLM(model_params={}, *, base_lr=0.001, grad_accum_schedule=None, precision=32, loss_function='L1Loss', loss_kwargs={})

PyTorch Lightning module for the FourCastNeXt model.

Expects each batch as a pair of tensors shaped (B, T, C, H, W) and (B, T_1, C, H, W).

The first element is the input and the second is the target. T_1 can be any length; it sets the number of rollout steps the model is trained up to.
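To make the rollout idea concrete, here is a framework-free sketch. It assumes (this is not taken from the fourcastnext source) that each prediction is fed back in as the next input, and that the default L1Loss amounts to a mean absolute error averaged over all T_1 target steps. Scalar floats stand in for the (B, C, H, W) tensors of a single time step, and rollout_l1 is a hypothetical helper.

```python
from typing import Callable, List, Sequence


def rollout_l1(step: Callable[[float], float], initial: float, targets: Sequence[float]) -> float:
    """Roll `step` forward once per target and return the mean absolute error.

    Hypothetical sketch: scalars stand in for (B, C, H, W) tensors.
    """
    if not targets:
        raise ValueError("need at least one target step (T_1 >= 1)")
    state = initial
    errors: List[float] = []
    for target in targets:
        state = step(state)  # the previous prediction becomes the next input
        errors.append(abs(state - target))
    return sum(errors) / len(errors)


# A longer target sequence means a longer rollout, i.e. training up to that lead.
print(rollout_l1(lambda x: x + 1.0, 0.0, [1.0, 2.0, 4.0]))
```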

Parameters:
  • model_params (dict) – Model parameters passed to AFNONet.

  • base_lr – Base learning rate. Defaults to 0.001.

  • grad_accum_schedule – Gradient accumulation schedule (to be documented).

  • precision – Floating-point precision. Defaults to 32.

  • loss_function (str) – Loss function to use. Defaults to “L1Loss”.

  • loss_kwargs (dict) – Keyword arguments passed to the loss function.
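The default loss_function string matches the name of a torch.nn loss class, so one plausible way to resolve it is a getattr lookup on torch.nn with loss_kwargs forwarded to the constructor. The sketch below assumes that behaviour; build_loss is a hypothetical helper, not part of fourcastnext.

```python
from typing import Optional

import torch
from torch import nn


def build_loss(loss_function: str = "L1Loss", loss_kwargs: Optional[dict] = None) -> nn.Module:
    """Hypothetical helper: instantiate torch.nn.<loss_function>(**loss_kwargs)."""
    try:
        loss_cls = getattr(nn, loss_function)
    except AttributeError as err:
        raise ValueError(f"torch.nn has no loss class named {loss_function!r}") from err
    return loss_cls(**(loss_kwargs or {}))


pred = torch.tensor([1.0, 2.0])
target = torch.tensor([0.0, 0.0])
print(build_loss()(pred, target))  # mean absolute error (reduction="mean")
print(build_loss("L1Loss", {"reduction": "sum"})(pred, target))  # summed instead
```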

configure_optimizers()

tbd

forward(x, net)

Same as torch.nn.Module.forward().

Parameters:
  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.

Returns:

Your model’s output

predict_step(batch, batch_idx)

Step function called during Trainer.predict(). By default, it calls forward(). Override it to add any processing logic.

predict_step() is used to scale inference across multiple devices.

To prevent out-of-memory (OOM) errors, use the BasePredictionWriter callback to write predictions to disk or a database after each batch or at the end of each epoch.

BasePredictionWriter should also be used with spawn-based strategies, such as Trainer(strategy="ddp_spawn") or training on 8 TPU cores with Trainer(accelerator="tpu", devices=8), because predictions are not returned in those cases.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

Predicted output (optional).

Example

from lightning.pytorch import LightningModule, Trainer


class MyModel(LightningModule):

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

dm = ...
model = MyModel()
trainer = Trainer(accelerator="gpu", devices=2)
predictions = trainer.predict(model, dm)

training_step(batch, batch_idx)

Expects batch tensors with dimensions ordered B, T, C, H, W (batch, time, channel, height, width).

validation_step(batch, batch_idx)

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, such as accuracy.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'.

  • None - Skip to the next batch.

# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple validation dataloaders, validation_step() takes an additional dataloader_idx argument. We recommend giving it a default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)

    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})

Note

If you don’t need to validate, you don’t need to implement this method.

Note

When validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

class fourcastnext.registered_model.FourCastNextRM(*, pipeline_name=None, pipeline=None, output, lead_time, ckpt_path=None, interval=6, lightning_model_params={}, **kwargs)

FourCastNeXt was originally developed by [Guo et al. 2024](https://doi.org/10.48550/arXiv.2401.05584).

This class exposes the FourCastNeXt architecture as a registered model within the framework, so that it can be trained on whatever data and at whatever resolution are of interest.

Users need to train their own model weights.

Parameters:
  • lead_time (int | str | pyearthtools.data.TimeDelta) – Lead time to predict to. If an int, it is interpreted as hours. In string delta notation, separate the parts with -.

  • interval (int) – Data interval in hours. Defaults to 6.

  • ckpt_path (str, optional) – Override for the weights path.

  • pipeline_name (str) – Pipeline name to use.

  • output (str | Path) – Output location.

Create a FourCastNeXt model.

Parameters:
  • pipeline_name (str) – Pipeline name to use

  • output (str | Path) – Output location

  • lead_time (int | str) – Lead time of forecast (hours).

  • interval (int) – Data interval in hours. Defaults to 6.

  • ckpt_path (str | None) – Override for the weights path.
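The relationship between lead_time and interval can be sketched as follows, assuming (not confirmed by this page) that the model reaches the lead time in steps of interval hours. to_timedelta and rollout_steps are hypothetical helpers, and the "<value>-<unit>" string format with unit names such as hours or days is an assumption about the delta notation, not the pyearthtools parser.

```python
from datetime import timedelta
from typing import Union

# Assumed unit names for the hypothetical "<value>-<unit>" notation.
_UNITS = ("minutes", "hours", "days", "weeks")


def to_timedelta(lead_time: Union[int, str]) -> timedelta:
    """Interpret an int as hours, or parse a string such as '12-hours' or '2-days'."""
    if isinstance(lead_time, int):
        return timedelta(hours=lead_time)
    value, unit = lead_time.split("-", 1)
    if unit not in _UNITS:
        raise ValueError(f"unsupported unit {unit!r}")
    return timedelta(**{unit: int(value)})


def rollout_steps(lead_time: Union[int, str], interval: int = 6) -> int:
    """Number of interval-hour steps needed to reach lead_time."""
    lead = to_timedelta(lead_time)
    step = timedelta(hours=interval)
    if lead % step:  # a non-zero remainder means the lead is not a whole number of steps
        raise ValueError(f"lead time {lead} is not a multiple of {interval} h")
    return lead // step


print(rollout_steps(24))  # 24 h in 6 h steps
print(rollout_steps("2-days"))  # 48 h in 6 h steps
```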

load(**kwargs)

Load the model.

Returns:

A tuple of (predictor, index kwargs).

Return type:

(tuple[Any, dict[str, Any]])

class fourcastnext.CropToRectangle(warn=True)

Crop to a bounding box.

The default ERA5 grid is 721x1440 (latitude x longitude). FourCastNeXt needs to be able to use 2x2 kernels, so it requires even grid dimensions. For now, this class simply discards the surplus pixels. In future, the cropping strategy from the paper could be implemented, or the data could be regridded onto an even grid.
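The current strategy of discarding surplus pixels can be sketched as below. Which pixels the class actually drops is not stated here; this sketch assumes the trailing row and/or column is removed, and even_shape and crop_to_even are hypothetical helpers operating on nested lists rather than on gridded datasets.

```python
from typing import List, Sequence, Tuple


def even_shape(height: int, width: int) -> Tuple[int, int]:
    """Largest even (height, width) that fits inside the given grid."""
    return height - height % 2, width - width % 2


def crop_to_even(grid: Sequence[Sequence[float]]) -> List[List[float]]:
    """Drop the last row and/or column so both dimensions are even."""
    rows, cols = even_shape(len(grid), len(grid[0]))
    return [list(row[:cols]) for row in grid[:rows]]


print(even_shape(721, 1440))  # default ERA5 grid -> (720, 1440)
```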

class fourcastnext.CropToRectangleSmall(warn=True)

Crop to a bounding box.

The default ERA5 grid is 721x1440 (latitude x longitude). FourCastNeXt needs to be able to use 2x2 kernels, so it requires even grid dimensions. For now, this class simply discards the surplus pixels. In future, the cropping strategy from the paper could be implemented, or the data could be regridded onto an even grid.

class fourcastnext.architecture.afnonet.Mlp(in_features, hidden_features=None, out_features=None, act_layer=<class 'torch.nn.modules.activation.GELU'>)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
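
The Mlp signature matches the timm-style MLP used in the original FourCastNet code: Linear, then activation, then Linear, applied to the last dimension. The sketch below assumes fourcastnext follows the same pattern; MlpSketch is a hypothetical stand-in, not the library class. It is invoked as mlp(x) rather than mlp.forward(x), as the Note above recommends.

```python
import torch
from torch import nn


class MlpSketch(nn.Module):
    """Hypothetical two-layer MLP mirroring the documented Mlp signature."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU):
        super().__init__()
        hidden_features = hidden_features or in_features  # default: keep the width
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()  # act_layer is a class, so instantiate it
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


mlp = MlpSketch(in_features=8, hidden_features=32)
tokens = torch.randn(2, 5, 8)  # (batch, tokens, features)
print(mlp(tokens).shape)  # call the module, not .forward(), so hooks run
```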

class fourcastnext.architecture.afnonet.AFNO2D(hidden_size, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1, hidden_size_factor=1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
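
In the reference AFNO2D from the original FourCastNet code (assumed, not confirmed, to carry over here), the layer applies a 2D real FFT over the token grid, mixes channels with a block-diagonal MLP split into num_blocks blocks, soft-thresholds small values using sparsity_threshold, and keeps only a fraction of the frequency modes set by hard_thresholding_fraction. The hypothetical helper below reproduces just the size bookkeeping; the example uses a 180 x 360 token grid, which is what a 720x1440 field with 4x4 patches would give.

```python
from typing import Tuple


def afno2d_layout(hidden_size: int, h: int, w: int, num_blocks: int = 8,
                  hard_thresholding_fraction: float = 1.0) -> Tuple[int, Tuple[int, int], int]:
    """Return (block_size, frequency-grid shape, kept modes) for an AFNO2D-style layer."""
    if hidden_size % num_blocks:
        raise ValueError("hidden_size must be divisible by num_blocks")
    block_size = hidden_size // num_blocks  # channels per block-diagonal weight
    freq_shape = (h, w // 2 + 1)  # rfft2 keeps only non-negative frequencies on the last axis
    total_modes = h // 2 + 1
    kept_modes = int(total_modes * hard_thresholding_fraction)  # modes kept after hard thresholding
    return block_size, freq_shape, kept_modes


print(afno2d_layout(768, h=180, w=360))
```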

class fourcastnext.architecture.afnonet.Block(dim, mlp_ratio=4.0, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, layer_scale=1.0, double_skip=True, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1.0, is_last_block=False)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
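
In the original FourCastNet Block, double_skip controls whether a residual connection is added after the AFNO filter as well as after the MLP. The sketch below shows only that wiring, assuming fourcastnext follows the same pattern. Plain callables stand in for the norm, filter and MLP layers, block_forward is hypothetical, and layer_scale, drop_path and is_last_block are omitted because their behaviour is not documented here.

```python
from typing import Callable

Layer = Callable[[float], float]


def block_forward(x: float, norm1: Layer, filt: Layer, norm2: Layer, mlp: Layer,
                  double_skip: bool = True) -> float:
    """Hypothetical Block forward pass showing the two residual paths."""
    residual = x
    x = filt(norm1(x))
    if double_skip:
        x = x + residual  # first skip: around the spectral filter
        residual = x
    x = mlp(norm2(x))
    return x + residual  # second skip: around the MLP (or the whole block)


def identity(v: float) -> float:
    return v


print(block_forward(1.0, identity, lambda v: 2 * v, identity, lambda v: v + 1))
print(block_forward(1.0, identity, lambda v: 2 * v, identity, lambda v: v + 1, double_skip=False))
```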

class fourcastnext.architecture.afnonet.AFNONet(img_size=(128, 128), in_channels=10, out_channels=10, patch_size=(4, 4), embed_dim=768, depth=12, mlp_ratio=4.0, drop_path_rate=0.0, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1.0)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
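
Assuming AFNONet follows the original FourCastNet layout (patch embedding, a grid of tokens processed by depth Blocks, then a head that projects back to out_channels at the input resolution), the tensor shapes can be worked out from the constructor arguments. afnonet_shapes is a hypothetical helper; its divisibility check shows why the input grid must tile exactly into patch_size patches.

```python
from typing import Dict, Tuple


def afnonet_shapes(batch: int, img_size: Tuple[int, int] = (128, 128), in_channels: int = 10,
                   out_channels: int = 10, patch_size: Tuple[int, int] = (4, 4),
                   embed_dim: int = 768) -> Dict[str, Tuple[int, ...]]:
    """Shapes at the input, token grid and output of an AFNONet-style model."""
    (height, width), (ph, pw) = img_size, patch_size
    if height % ph or width % pw:
        raise ValueError(f"img_size {img_size} must be divisible by patch_size {patch_size}")
    h, w = height // ph, width // pw  # token grid seen by each Block
    return {
        "input": (batch, in_channels, height, width),
        "tokens": (batch, h, w, embed_dim),
        "output": (batch, out_channels, height, width),
    }


print(afnonet_shapes(2))
```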

class fourcastnext.architecture.afnonet.PatchEmbed(img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
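
The PatchEmbed signature matches the ViT-style patch embedding used in FourCastNet, where a convolution whose kernel size and stride both equal patch_size cuts the image into non-overlapping patches and projects each one to embed_dim. The sketch below assumes that design; PatchEmbedSketch is hypothetical and may differ from the fourcastnext class.

```python
import torch
from torch import nn


class PatchEmbedSketch(nn.Module):
    """Hypothetical patch embedding: (B, C, H, W) -> (B, num_patches, embed_dim)."""

    def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
        # kernel_size == stride, so each output position sees exactly one patch
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (B, embed_dim, H/ph, W/pw) -> flatten the grid -> (B, num_patches, embed_dim)
        return self.proj(x).flatten(2).transpose(1, 2)


embed = PatchEmbedSketch()
print(embed(torch.zeros(1, 3, 224, 224)).shape)
```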