Bundled Models API Docs
Note - at some point, “bundled models” will become simply “models”, and each model will be added to that namespace.
fourcastnext
- class fourcastnext.lightning_model.FourCastNextLM(model_params={}, *, base_lr=0.001, grad_accum_schedule=None, precision=32, loss_function='L1Loss', loss_kwargs={})
FourCastNeXt model (PyTorch Lightning module).
Expects each batch as a pair of tensors shaped (B, T, C, H, W) and (B, T_1, C, H, W), where the first element is the input and the second is the target.
T_1 can be any length; it sets the number of rollout steps the model is trained up to.
- Parameters:
model_params (dict) – Model params to pass to AFNONet.
base_lr – Base learning rate. Defaults to 0.001.
grad_accum_schedule – tbd.
precision – Float precision. Defaults to 32.
loss_function (str) – Loss function to use. Defaults to “L1Loss”.
loss_kwargs (dict) – Kwargs to pass to the loss function.
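The batch layout above can be made concrete with a minimal sketch. It assumes plain torch tensors, and it assumes that rollout starts from the last input time step and feeds each prediction back in, with L1Loss (the documented default loss_function) averaged over the T_1 target steps. HypotheticalStepModel and rollout_loss are hypothetical stand-ins, not FourCastNextLM's actual training logic.

```python
import torch
from torch import nn

# Illustrative sizes only (not model defaults).
B, T, C, H, W = 2, 1, 4, 8, 16
T_1 = 3  # number of target steps, i.e. the rollout length

x = torch.randn(B, T, C, H, W)    # input:  (B, T, C, H, W)
y = torch.randn(B, T_1, C, H, W)  # target: (B, T_1, C, H, W)
batch = (x, y)


class HypotheticalStepModel(nn.Module):
    """Stand-in one-step model mapping (B, C, H, W) -> (B, C, H, W)."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, state):
        return self.conv(state)


def rollout_loss(step_model, batch, loss_fn=None):
    """Average loss over T_1 autoregressive steps (hypothetical sketch)."""
    loss_fn = loss_fn or nn.L1Loss()
    x, y = batch
    state = x[:, -1]  # start from the last input time step
    total = 0.0
    for t in range(y.shape[1]):
        state = step_model(state)  # feed each prediction back in
        total = total + loss_fn(state, y[:, t])
    return total / y.shape[1]


loss = rollout_loss(HypotheticalStepModel(C), batch)
```

Because T_1 is read from the target tensor, the same loop handles any rollout length without code changes.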
- configure_optimizers()
tbd
- forward(x, net)
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
- predict_step(batch, batch_idx)
Step function called during predict(). By default, it calls forward(). Override to add any processing logic.
The predict_step() is used to scale inference across multiple devices.
To prevent an OOM error, it is possible to use a BasePredictionWriter callback to write the predictions to disk or a database after each batch or at the end of the epoch.
The BasePredictionWriter should be used with a spawn-based accelerator. This happens for Trainer(strategy="ddp_spawn") or for training on 8 TPU cores with Trainer(accelerator="tpu", devices=8), because predictions won’t be returned.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch (only if multiple dataloaders are used).
- Returns:
Predicted output (optional).
Example
```python
from lightning.pytorch import LightningModule, Trainer


class MyModel(LightningModule):
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)


dm = ...
model = MyModel()
trainer = Trainer(accelerator="gpu", devices=2)
predictions = trainer.predict(model, dm)
```
- training_step(batch, batch_idx)
Operates on a batch laid out as (B, T, C, H, W): batch, time, channel, height, width.
- validation_step(batch, batch_idx)
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, such as accuracy.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch (only if multiple dataloaders are used).
- Returns:
Tensor - The loss tensor.
dict - A dictionary. Can include any keys, but must include the key 'loss'.
None - Skip to the next batch.
```python
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
```
Examples:
```python
import torch
import torchvision


# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
```
If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting its default value to 0 so that you can quickly switch between single and multiple dataloaders.
```python
import torch


# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)

    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})
```
Note
If you don’t need to validate, you don’t need to implement this method.
Note
When validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
- class fourcastnext.registered_model.FourCastNextRM(*, pipeline_name=None, pipeline=None, output, lead_time, ckpt_path=None, interval=6, lightning_model_params={}, **kwargs)
FourCastNeXt was originally developed by [Guo et al. (2024)](https://doi.org/10.48550/arXiv.2401.05584).
This class provides the underlying architecture as a registered model within the framework, so that it can be trained on whatever data, and at whatever resolution, is of interest.
Users need to train their own model weights.
- Parameters:
lead_time (int | str | pyearthtools.data.TimeDelta) – Lead time to predict to. If an int, it is interpreted as hours. In delta notation, separate the components with -.
interval (int) – Data interval in hours. Defaults to 6.
ckpt_path (str, optional) – Override for weights path
pipeline_name (str) – Pipeline name to use.
output (str | Path) – Output location.
Create FourCastNeXt Model
- Parameters:
pipeline_name (str) – Pipeline name to use
output (str | Path) – Output location
lead_time (int | str) – Lead time of forecast (hours).
interval (int) – Data interval in hours. Defaults to 6.
ckpt_path (str | None) – Override for weights path
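The lead_time and interval parameters together determine how far ahead, and in how many steps, the model forecasts. The sketch below uses hypothetical helpers and is not pyearthtools.data.TimeDelta. It assumes delta notation looks like "N-unit" (for example "6-hours") and maps it onto datetime.timedelta. It also assumes the number of forecast steps is lead_time divided by interval.

```python
from datetime import timedelta

# Units accepted by datetime.timedelta as keyword arguments.
_UNITS = {"weeks", "days", "hours", "minutes", "seconds"}


def parse_lead_time(lead_time):
    """Hypothetical parser: int -> hours, "N-unit" str -> timedelta."""
    if isinstance(lead_time, timedelta):
        return lead_time
    if isinstance(lead_time, int):
        return timedelta(hours=lead_time)  # ints are interpreted as hours
    if "-" not in lead_time:
        raise ValueError("expected 'N-unit', e.g. '6-hours'")
    value, unit = lead_time.split("-", 1)
    unit = unit.strip().lower()
    if not unit.endswith("s"):
        unit += "s"  # accept "day" as well as "days"
    if unit not in _UNITS:
        raise ValueError(f"unsupported unit: {unit!r}")
    return timedelta(**{unit: float(value)})


def rollout_steps(lead_time, interval=6):
    """Number of interval-hour steps needed to reach lead_time (assumed)."""
    lead = parse_lead_time(lead_time)
    step = timedelta(hours=interval)
    if lead % step:  # a non-zero remainder means the steps don't line up
        raise ValueError("lead_time must be a multiple of interval")
    return lead // step
```

For example, rollout_steps(24) gives 4, and rollout_steps("10-days") gives 40 with the default 6-hour interval.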
- load(**kwargs)
Load the model.
- Returns:
A tuple of the predictor and the index kwargs.
- Return type:
(tuple[Any, dict[str, Any]])
- class fourcastnext.CropToRectangle(warn=True)
Crop to a bounding box.
The default ERA5 grid is 721x1440. FourCastNeXt needs to be able to use 2x2 kernels, so each grid dimension must be even. For now, this class simply discards the surplus pixels. In future, the cropping strategy from the paper could be implemented, or a more complex regrid could be performed to resample to an even grid.
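Discarding surplus pixels can be sketched with NumPy. This is a hypothetical illustration, not the CropToRectangle implementation: it assumes the trailing row or column is the one dropped, and these docs don't say which one the class actually removes.

```python
import numpy as np


def crop_to_even(field):
    """Drop a trailing row/column so the last two dims are even (hypothetical)."""
    h, w = field.shape[-2:]
    return field[..., : h - h % 2, : w - w % 2]


era5_like = np.zeros((3, 721, 1440))  # (channel, lat, lon)
cropped = crop_to_even(era5_like)
print(cropped.shape)  # (3, 720, 1440)

# With even dims, the grid tiles exactly into 2x2 patches.
patches = cropped.reshape(3, 360, 2, 720, 2)
```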
- class fourcastnext.CropToRectangleSmall(warn=True)
Crop to a bounding box.
The default ERA5 grid is 721x1440. FourCastNeXt needs to be able to use 2x2 kernels, so each grid dimension must be even. For now, this class simply discards the surplus pixels. In future, the cropping strategy from the paper could be implemented, or a more complex regrid could be performed to resample to an even grid.
- class fourcastnext.architecture.afnonet.Mlp(in_features, hidden_features=None, out_features=None, act_layer=<class 'torch.nn.modules.activation.GELU'>)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
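A minimal sketch of the feed-forward network this signature suggests. It assumes the common ViT pattern of Linear, activation, Linear over the channel dimension, with hidden_features and out_features falling back to in_features when None. TinyMlp is a hypothetical name; dropout and any other details of the real Mlp are omitted.

```python
import torch
from torch import nn


class TinyMlp(nn.Module):
    """Two-layer feed-forward network over the last (channel) dimension."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU):
        super().__init__()
        # Assumed fallback: unspecified sizes default to in_features.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()  # act_layer is a class, instantiated here
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


mlp = TinyMlp(16, hidden_features=64)
out = mlp(torch.randn(2, 5, 16))  # leading dims are passed through
```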
- class fourcastnext.architecture.afnonet.AFNO2D(hidden_size, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1, hidden_size_factor=1)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
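The parameter names of AFNO2D match the Adaptive Fourier Neural Operator (AFNO) token mixer. That design moves tokens to the frequency domain and applies a block-diagonal complex MLP (num_blocks blocks over hidden_size channels, widened by hidden_size_factor). It then soft-thresholds the result with sparsity_threshold before transforming back. TinyAFNO2D is a simplified, hypothetical sketch of that idea. It assumes channels-last (B, H, W, C) input and omits biases and hard_thresholding_fraction, so it is not the fourcastnext implementation.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TinyAFNO2D(nn.Module):
    """Simplified AFNO-style spectral mixer; input is (B, H, W, C)."""

    def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.01, hidden_size_factor=1):
        super().__init__()
        assert hidden_size % num_blocks == 0
        self.num_blocks = num_blocks
        self.block_size = hidden_size // num_blocks
        self.sparsity_threshold = sparsity_threshold
        hs = self.block_size * hidden_size_factor
        # Complex weights stored as (real, imag) pairs, one matrix per block.
        self.w1 = nn.Parameter(0.02 * torch.randn(2, num_blocks, self.block_size, hs))
        self.w2 = nn.Parameter(0.02 * torch.randn(2, num_blocks, hs, self.block_size))

    def forward(self, x):
        B, H, W, C = x.shape
        residual = x
        # Real FFT over the two spatial dims: (B, H, W // 2 + 1, C), complex.
        xf = torch.fft.rfft2(x.float(), dim=(1, 2), norm="ortho")
        xf = xf.reshape(B, H, W // 2 + 1, self.num_blocks, self.block_size)
        w1 = torch.complex(self.w1[0], self.w1[1])
        w2 = torch.complex(self.w2[0], self.w2[1])
        # Block-diagonal complex MLP applied independently at each frequency.
        h = torch.einsum("...bi,bio->...bo", xf, w1)
        h = torch.complex(F.relu(h.real), F.relu(h.imag))
        out = torch.einsum("...bi,bio->...bo", h, w2)
        # Soft-thresholding encourages sparsity in the frequency domain.
        lam = self.sparsity_threshold
        out = torch.complex(F.softshrink(out.real, lam), F.softshrink(out.imag, lam))
        out = out.reshape(B, H, W // 2 + 1, C)
        x = torch.fft.irfft2(out, s=(H, W), dim=(1, 2), norm="ortho")
        return x + residual
```

Mixing in the frequency domain gives every token a global receptive field, while the block-diagonal weights keep the parameter count proportional to hidden_size**2 / num_blocks.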
- class fourcastnext.architecture.afnonet.Block(dim, mlp_ratio=4.0, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, layer_scale=1.0, double_skip=True, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1.0, is_last_block=False)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
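A hypothetical sketch of how a block with double_skip might combine a token mixer and an MLP. It assumes the pre-norm layout of the original FourCastNet AFNO block, where double_skip adds an extra residual connection around the mixer; these docs don't confirm that. TinyBlock takes any mixer module (for example an AFNO2D-style filter) and omits drop_path, layer_scale and is_last_block.

```python
import torch
from torch import nn


class TinyBlock(nn.Module):
    """Pre-norm block: token mixer + MLP, with an optional extra skip."""

    def __init__(self, dim, mixer, mlp_ratio=4.0, double_skip=True):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.mixer = mixer  # e.g. an AFNO2D-style spectral filter
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))
        self.double_skip = double_skip

    def forward(self, x):
        residual = x
        x = self.mixer(self.norm1(x))
        if self.double_skip:
            # Extra residual connection around the token mixer.
            x = x + residual
            residual = x
        x = self.mlp(self.norm2(x))
        return x + residual


block = TinyBlock(16, mixer=nn.Identity())
x = torch.randn(2, 4, 4, 16)  # channels-last (B, H, W, C)
out = block(x)
```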
- class fourcastnext.architecture.afnonet.AFNONet(img_size=(128, 128), in_channels=10, out_channels=10, patch_size=(4, 4), embed_dim=768, depth=12, mlp_ratio=4.0, drop_path_rate=0.0, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1.0)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
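With the AFNONet defaults, a 128x128 image split into 4x4 patches gives a 32x32 grid of 1024 tokens. Mapping each token back to a 4x4 patch of 10 output channels needs 4 * 4 * 10 = 160 values per token. The sketch below is a hypothetical patchify/unpatchify pair showing that bookkeeping. Whether AFNONet uses this exact memory layout is an assumption.

```python
import torch


def patchify(img, patch_size):
    """(B, C, H, W) -> (B, h*w, ph*pw*C), with h = H // ph and w = W // pw."""
    B, C, H, W = img.shape
    ph, pw = patch_size
    h, w = H // ph, W // pw
    x = img.reshape(B, C, h, ph, w, pw)
    x = x.permute(0, 2, 4, 3, 5, 1)  # (B, h, w, ph, pw, C)
    return x.reshape(B, h * w, ph * pw * C)


def unpatchify(tokens, grid, patch_size, out_channels):
    """Inverse of patchify: (B, h*w, ph*pw*C) -> (B, C, h*ph, w*pw)."""
    B = tokens.shape[0]
    h, w = grid
    ph, pw = patch_size
    x = tokens.reshape(B, h, w, ph, pw, out_channels)
    x = x.permute(0, 5, 1, 3, 2, 4)  # (B, C, h, ph, w, pw)
    return x.reshape(B, out_channels, h * ph, w * pw)


# AFNONet defaults: img_size=(128, 128), patch_size=(4, 4), out_channels=10
img = torch.randn(2, 10, 128, 128)
tokens = patchify(img, (4, 4))
print(tokens.shape)  # torch.Size([2, 1024, 160])
restored = unpatchify(tokens, (32, 32), (4, 4), 10)
```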
- class fourcastnext.architecture.afnonet.PatchEmbed(img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
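A common way to implement a patch embedding with this signature is a Conv2d whose kernel_size and stride both equal patch_size. Each output position then covers exactly one non-overlapping patch. TinyPatchEmbed is a hypothetical sketch under that assumption, not the fourcastnext class. It also shows why an odd grid such as 721 rows cannot be split into 2x2 patches.

```python
import torch
from torch import nn


class TinyPatchEmbed(nn.Module):
    """Conv-based patch embedding: (B, C, H, W) -> (B, num_patches, embed_dim)."""

    def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768):
        super().__init__()
        if img_size[0] % patch_size[0] or img_size[1] % patch_size[1]:
            raise ValueError("img_size must be divisible by patch_size")
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        # kernel_size == stride, so each output position sees exactly one patch
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)                     # (B, embed_dim, H/ph, W/pw)
        return x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)


embed = TinyPatchEmbed()
print(embed.num_patches)  # 196
```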