Nowcasting Autoencoder Tutorial

Creating an effective autoencoder is a great first step in developing a predictive model. Learning how to create a high-quality latent state is critical, because that latent state can later be reused to train a new encoder for a pre-existing predictive model, or a new decoder. For example, one could train an encoder that uses different or fewer variables than the original model, or a higher-resolution decoder that produces higher-resolution predictions from the same latent state.

This tutorial presents a simple architecture that is suitable as a first step in learning how to build an autoencoder, and that can be extended to produce higher-quality images in future exercises.

Autoencoders are also a useful concept for understanding neural network architectures more generally.

In this example, we first blend radar and satellite data onto the same grid, then train an autoencoder on the satellite surface irradiance field to perform dimensionality reduction and produce a useful latent state from which the original input can be reconstructed.

[1]:
import pyearthtools.data as petdata
import pyearthtools.pipeline as petpipe
import site_archive_nci  # provides the NCI site archive definitions (e.g. Himawari, Rainfields3) used below

from pyearthtools.data.time import Petdt
from pyearthtools.pipeline.operations.xarray.join import GeospatialTimeSeriesMerge

import xarray as xr

import torch
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt


# Set random seed for reproducibility
torch.manual_seed(42)

# Autodetect GPU and use if possible
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
[2]:
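# Rainfields3 radar data is stored on a projected x/y grid; build a transform
# that converts it to a rectilinear longitude/latitude grid
rf3proj = petdata.transforms.projection.Rainfields3ProjAus()
radar_projector = petdata.transforms.projection.XYtoLonLatRectilinear(rf3proj)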
[3]:
# Date and hour of interest, used to query a sample
doi = '2021-06-09T02'
[4]:
himawari = petdata.archive.Himawari('surface_global_irradiance')
[5]:
# TODO: add a normalisation step to this pipeline (the `fullsat` pipeline below divides by 1200)
satpipe = petpipe.Pipeline(
    himawari
)
[6]:
himawari_sample = satpipe[doi]
himawari_sample
/g/data/kd24/tjl/src/PyEarthTools/packages/data/src/pyearthtools/data/operations/index_routines.py:326: FutureWarning: In a future version of xarray the default value for data_vars will change from data_vars='all' to data_vars=None. This is likely to lead to different results when multiple datasets have matching variables with overlapping values. To opt in to new defaults and get rid of these warnings now use `set_options(use_new_combine_kwarg_defaults=True) or set data_vars explicitly.
  full_ds = xr.open_mfdataset(
[6]:
<xarray.Dataset> Size: 183MB
Dimensions:                    (time: 6, latitude: 1726, longitude: 2214)
Coordinates:
  * time                       (time) datetime64[ns] 48B 2021-06-09T02:00:00 ...
  * latitude                   (latitude) float32 7kB -44.5 -44.48 ... -10.0
  * longitude                  (longitude) float32 9kB 112.0 112.0 ... 156.3
Data variables:
    surface_global_irradiance  (time, latitude, longitude) float64 183MB dask.array<chunksize=(1, 1726, 2214), meta=np.ndarray>
Attributes: (12/50)
    Conventions:                      CF-1.7
    Metadata_Conventions:             Unidata Dataset Discovery v1.0
    acknowledgment:                   The following acknowledgement is requir...
    cdm_data_type:                    grid
    comment:                          Solar radiation data derived from satel...
    contributor_name:                 Mines ParisTech; Commonwealth of Austra...
    ...                               ...
    geospatial_lon_resolution:        0.02
    bias_correction_applied_meaning:  0: not applied; 1:applied
    quality_meaning:                  0: no_known_issues 1: known_issue
    project:                          Gridded Solar Observations
    references:                       Poulsen C., Majewski L. J. (2022) Gridd...
    NCO:                              netCDF Operators version 4.7.7 (Homepag...
[7]:
radar = petdata.archive.Rainfields3(variables='prcp-crate')
[8]:
radarpipe = petpipe.Pipeline(
    radar,
    radar_projector,
    petpipe.operations.xarray.metadata.Rename({'valid_time':'time'}),
)
[9]:
prepare = petpipe.Pipeline(
    (satpipe, radarpipe),
    GeospatialTimeSeriesMerge(reference_dataset=himawari_sample), # These are pretty similar grids, so just pick one
    iterator=petpipe.iterators.DateRange(2021, 2023, interval='20 minutes')
)
prepare

ipipe = iter(prepare)  # Make an iterator to walk the time period
[10]:
%%time

# Takes around 15 seconds per sample to retrieve, largely due to the zip compression used on-disk
merged_sample = next(ipipe)
merged_sample
CPU times: user 12.7 s, sys: 1.63 s, total: 14.4 s
Wall time: 14.9 s
[10]:
<xarray.Dataset> Size: 245MB
Dimensions:                    (time: 1, latitude: 1726, longitude: 2214, n2: 2)
Coordinates:
  * time                       (time) datetime64[ns] 8B 2021-01-01
  * latitude                   (latitude) float32 7kB -44.5 -44.48 ... -10.0
  * longitude                  (longitude) float32 9kB 112.0 112.0 ... 156.3
    x                          (longitude, latitude) float64 31MB -1.651e+03 ...
    y                          (longitude, latitude) float64 31MB -4.99e+03 ....
Dimensions without coordinates: n2
Data variables:
    surface_global_irradiance  (time, latitude, longitude) float64 31MB dask.array<chunksize=(1, 1726, 2214), meta=np.ndarray>
    proj                       (time) int8 1B 0
    y_bounds                   (time, longitude, latitude, n2) float64 61MB -...
    x_bounds                   (time, longitude, latitude, n2) float64 61MB -...
    rain_rate                  (time, longitude, latitude) float64 31MB nan ....
Attributes: (12/58)
    Conventions:                      CF-1.7
    Metadata_Conventions:             Unidata Dataset Discovery v1.0
    acknowledgment:                   The following acknowledgement is requir...
    cdm_data_type:                    grid
    comment:                          Solar radiation data derived from satel...
    contributor_name:                 Mines ParisTech; Commonwealth of Austra...
    ...                               ...
    quality:                          0
    quality_meaning:                  0: no_known_issues 1: known_issue
    project:                          Gridded Solar Observations
    history:                          Mon Mar  4 01:55:23 2024: ncatted -a re...
    references:                       Poulsen C., Majewski L. J. (2022) Gridd...
    NCO:                              netCDF Operators version 4.7.7 (Homepag...
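A quick back-of-envelope check helps explain the cost of working with the merged data. The grid size, float64 storage and roughly 15 s per sample come from the outputs above; the assumption that `DateRange(2021, 2023, interval='20 minutes')` runs from 2021-01-01 up to but not including 2023-01-01 is ours and is not confirmed by the pyearthtools documentation.

```python
from datetime import datetime, timedelta

# Full Himawari grid size, from the dataset repr above
n_lat, n_lon = 1726, 2214
bytes_per_value = 8  # float64
field_mb = n_lat * n_lon * bytes_per_value / 1e6  # one variable at one time step

# ASSUMPTION: DateRange(2021, 2023, ...) covers [2021-01-01, 2023-01-01)
start = datetime(2021, 1, 1)
end = datetime(2023, 1, 1)
step = timedelta(minutes=20)
n_steps = (end - start) // step  # number of timestamps visited

seconds_per_sample = 15  # rough retrieval time measured in the %%time cell
total_hours = n_steps * seconds_per_sample / 3600

print(f"{field_mb:.1f} MB per field, {n_steps} timestamps, ~{total_hours:.0f} h to read them all")
```

That works out to roughly 30.6 MB per float64 field (the repr rounds it to 31MB), 52,560 timestamps, and about 219 hours of reading for the whole range: a reason to crop the region and to train on a subset of samples, as the following cells do.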
[11]:
full = petpipe.Pipeline(
    (satpipe, radarpipe),
    GeospatialTimeSeriesMerge(reference_dataset=himawari_sample), # These are pretty similar grids, so just pick one
    petdata.transforms.variables.Drop(['x_bounds', 'y_bounds', 'proj', 'x', 'y']),
    petpipe.operations.xarray.Sort(order=['time', 'latitude', 'longitude']),  # Put coordinates in (time, latitude, longitude) order
    petpipe.operations.xarray.AlignDataVariableDimensionsToDatasetCoords(),  # Align data variables coordinate ordering to dataset coordinate ordering
    petdata.transform.region.Bounding(-40, -25, 135, 152),  # cut down on region for example
    petpipe.operations.xarray.conversion.ToNumpy(),
    petpipe.operations.numpy.reshape.Rearrange('c t h w -> t c h w'), # channel time height width -> time channel height width
    iterator=petpipe.iterators.DateRange('20200101T00', '20210101T00', interval='20 minutes')
)
full

ipipe = iter(full)  # Make an iterator to walk the time period
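The final `Rearrange('c t h w -> t c h w')` step only swaps the first two axes, so each time step becomes one item in a batch shaped (N, C, H, W), the layout `torch.nn.Conv2d` expects. Below is a stdlib sketch of the same swap on nested lists, assuming (as the cell's comment says) that `ToNumpy` stacks variables along a leading channel axis; with NumPy the equivalent is `np.transpose(a, (1, 0, 2, 3))`. The helper name `swap_first_two_axes` is hypothetical.

```python
def swap_first_two_axes(arr):
    """Return nested-list arr with axes 0 and 1 swapped, e.g. (c, t, h, w) -> (t, c, h, w)."""
    n0 = len(arr)     # size of the first axis (channels)
    n1 = len(arr[0])  # size of the second axis (time)
    # Inner (h, w) blocks are carried over unchanged
    return [[arr[i][j] for i in range(n0)] for j in range(n1)]

# Tiny (c, t, h, w) = (2, 3, 1, 1) array whose values encode their (c, t) index as 10*c + t
c_t_h_w = [[[[10 * c + t]] for t in range(3)] for c in range(2)]
t_c_h_w = swap_first_two_axes(c_t_h_w)

print(len(t_c_h_w), len(t_c_h_w[0]))  # new leading axes: time, then channel
print(t_c_h_w[1][0])                  # block for t=1, c=0
```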
[12]:
fullsat = petpipe.Pipeline(
    satpipe,
    # GeospatialTimeSeriesMerge(reference_dataset=himawari_sample), # These are pretty similar grids, so just pick one
    # petdata.transforms.variables.Drop(['x_bounds', 'y_bounds', 'proj', 'x', 'y']),
    petpipe.operations.xarray.Sort(order=['time', 'latitude', 'longitude']),  # Put coordinates in (time, latitude, longitude) order
    # Align the data variable's coordinate order to the dataset coordinate order so all arrays are the same shape
    petpipe.operations.xarray.AlignDataVariableDimensionsToDatasetCoords(),
    petdata.transform.region.Bounding(-35, -25, 138, 150),  # cut down on region for example
    petpipe.operations.xarray.normalisation.SingleValueDivision(1200),  # Scale values to roughly [0, 1]
    petpipe.operations.xarray.conversion.ToNumpy(),
    petpipe.operations.numpy.reshape.Rearrange('c t h w -> t c h w'), # channel time height width -> time channel height width
    iterator=petpipe.iterators.DateRange('20200101T00', '20210101T00', interval='10 minutes'),
    exceptions_to_ignore=petdata.exceptions.DataNotFoundError
)
fullsat

ipipe = iter(fullsat)  # Make an iterator to walk the time period
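The `Bounding(-35, -25, 138, 150)` crop determines the image size the model sees. Assuming the crop keeps both edges of the box and the grid has the 0.02° spacing reported in the dataset attributes, the point counts can be worked out directly; the same formula reproduces the full 1726-point latitude axis as a sanity check. The helper `n_points` is hypothetical.

```python
def n_points(start, stop, step=0.02):
    """Points on a regular grid from start to stop, both ends inclusive (hypothetical helper)."""
    # round() absorbs floating-point error in (stop - start) / step
    return round((stop - start) / step) + 1

# Sanity check against the full grid: latitude runs from -44.5 to -10.0 in the repr above
full_lat = n_points(-44.5, -10.0)

# ASSUMPTION: Bounding keeps both edges of the box
crop_lat = n_points(-35, -25)  # latitude bounds from Bounding(-35, -25, 138, 150)
crop_lon = n_points(138, 150)  # longitude bounds

print(full_lat, crop_lat, crop_lon)
```

The 501 x 601 result matches the `input_height=501` and `input_width=601` defaults of the `AutoEncoder` class defined below.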
[13]:
fullsat.exceptions_to_ignore
[13]:
(pyearthtools.data.exceptions.DataNotFoundError,)
[14]:
n = next(ipipe)
# n
[15]:
plt.imshow(n[0][0])
[15]:
[Figure: channel 0 of the first `fullsat` sample (surface global irradiance divided by 1200)]
[16]:
# Here we define an "AutoEncoder": a model that reproduces its inputs after passing
# them through a bottleneck layer. It is one of the primary concepts behind many
# neural network architectures that you will work with in future, and is key to
# understanding them.

# The full Himawari grid is latitude: 1726, longitude: 2214. The `fullsat` pipeline
# crops it to 501 x 601 points, matching the default input_height and input_width below.

class AutoEncoder(nn.Module):
    def __init__(self,
                 input_height = 501,
                 input_width = 601,
                 kernel_size = 4,
                 stride=2,
                 input_channel_count = 2,
                 output_channel_count = 2,
                 latent_dim=300):  # Note: latent_dim is not currently used by any layer
        super().__init__()

        self.input_width = input_width
        self.input_height = input_height
        self.input_channel_count = input_channel_count
        self.output_channel_count = output_channel_count

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=self.input_channel_count, out_channels=16, kernel_size=kernel_size, stride = stride, padding = 1),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride =2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=7),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=16, out_channels=self.output_channel_count, kernel_size=kernel_size, stride=stride, padding=1, output_padding=1),
            nn.Sigmoid()
        )


    def forward(self, x):

        # Get latent representation
        latent = self.encoder(x)

        # Reconstruct input
        reconstructed = self.decoder(latent)

        return reconstructed
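Because the decoder must exactly undo the encoder's downsampling, it is worth tracing the spatial size through each layer. The functions below apply the output-size formulas from the PyTorch documentation for `nn.Conv2d` and `nn.ConvTranspose2d` to the layer settings used in `AutoEncoder`; they are a stdlib sketch for checking shapes, not part of pyearthtools or PyTorch.

```python
def conv2d_out(size, kernel, stride=1, padding=0, dilation=1):
    # Output-size formula from the torch.nn.Conv2d documentation
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

def conv_transpose2d_out(size, kernel, stride=1, padding=0, output_padding=0, dilation=1):
    # Output-size formula from the torch.nn.ConvTranspose2d documentation
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

# (kernel, stride, padding) for each encoder Conv2d, as in AutoEncoder's defaults
ENCODER = [(4, 2, 1), (3, 2, 1), (7, 1, 0)]
# (kernel, stride, padding, output_padding) for each decoder ConvTranspose2d
DECODER = [(7, 1, 0, 0), (3, 2, 1, 1), (4, 2, 1, 1)]

def trace(size):
    """Spatial size after each layer: input, three encoder layers, three decoder layers."""
    sizes = [size]
    for k, s, p in ENCODER:
        sizes.append(conv2d_out(sizes[-1], k, s, p))
    for k, s, p, op in DECODER:
        sizes.append(conv_transpose2d_out(sizes[-1], k, s, p, op))
    return sizes

print(trace(501))  # height
print(trace(601))  # width
print(trace(500))  # a height that does not round-trip
```

A 501 x 601 input is encoded to 119 x 144 and decoded back to 501 x 601, consistent with the `latent.shape` printed later in the notebook. An input with 500 rows would also encode to 119 rows but decode to 501, so the reconstruction would not match the input's shape.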
[17]:
# Initialize model and move to device
model = AutoEncoder(input_channel_count=1, output_channel_count=1).to(device)

# Loss function and optimizer
criterion = nn.L1Loss()
# criterion = nn.KLDivLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
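`nn.L1Loss` with its default `reduction='mean'` is the mean absolute difference between prediction and target over all elements. A minimal pure-Python version on flat lists (the name `l1_loss` is ours) makes the definition concrete:

```python
def l1_loss(prediction, target):
    """Mean absolute error, as computed by nn.L1Loss with reduction='mean'."""
    if len(prediction) != len(target):
        raise ValueError("prediction and target must have the same length")
    return sum(abs(p - t) for p, t in zip(prediction, target)) / len(target)

# Absolute differences are 0.1, 0.2 and 0.0, so the mean is about 0.1
print(l1_loss([0.5, 0.2, 0.9], [0.4, 0.4, 0.9]))
```

Compared with a squared-error loss such as `nn.MSELoss`, L1 penalises large errors less heavily.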
[18]:
x = torch.from_numpy(n).float().to(device)
[19]:
# The untrained model produces a noisy prediction with the same shape as the input
prediction = model(x)
[20]:
image_numpy_for_display = prediction.to('cpu').detach().numpy()
# image_numpy_for_display
[21]:
plt.imshow(image_numpy_for_display[0][0])
[21]:
[Figure: channel 0 of the untrained model's prediction]
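The decoder ends in `nn.Sigmoid`, so every output pixel lies between 0 and 1. The `SingleValueDivision(1200)` step in `fullsat` puts the irradiance on a comparable scale, assuming observed values do not exceed 1200 in the data's units. The sketch below (all names hypothetical) shows the scaling, its inverse, and a numerically stable sigmoid.

```python
import math

SCALE = 1200.0  # divisor used by SingleValueDivision(1200) in the fullsat pipeline

def normalise(irradiance):
    """Map irradiance onto the model's scale."""
    return irradiance / SCALE

def denormalise(value):
    """Map a model output back to the original units."""
    return value * SCALE

def sigmoid(z):
    """Logistic sigmoid, written to avoid overflow in math.exp for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

# Any decoder output, mapped back to the original units, never exceeds SCALE
print(normalise(600.0), sigmoid(0.0), denormalise(sigmoid(5.0)))
```

A consequence of this design is that the trained model cannot reconstruct values above 1200.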
[22]:
%%time

# 1000 samples take about 2 minutes. This could be made much faster, but it is
# fast enough to test the model and make progress.
# Train for roughly 30 minutes, i.e. 15k samples

def train(debug=True, num_epochs=1, max_samples=10, print_per=20):
    # Training loop

    sample_ix = 0


    for epoch in range(num_epochs):
        total_loss = 0
        epoch_samples = 0
        ipipe = iter(fullsat)  # Make an iterator to walk the time period

        while True:
            try:
                sample = next(ipipe)
            except StopIteration:
                break  # advance the epoch loop
            except Exception:
                continue  # Some samples are missing; skip them instead of reusing the previous sample

            sample_ix += 1
            epoch_samples += 1
            if epoch_samples % print_per == 0:
                print(epoch_samples)

            if sample_ix > max_samples:
                break

            if debug:
                print(sample_ix)

            x = torch.from_numpy(sample).float().to(device)

            if torch.any(torch.isnan(x)):
                # Skip nan inputs, they break the training
                continue

            if debug:
                print("Input")
                print(x)

            optimizer.zero_grad()

            # Forward pass
            y = model(x)

            if debug:
                print("prediction")
                print(y)

            loss = criterion(y, x)
            if debug:
                print("loss")
                print(loss)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Print epoch statistics
        avg_loss = total_loss / max(epoch_samples, 1)
        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

train(debug=False, num_epochs=1, max_samples=15 * 1000, print_per = 500)
500
1000
1500
2000
2500
3000
3500
4000
4500
5000
5500
6000
6500
7000
7500
8000
8500
9000
9500
10000
10500
11000
11500
12000
12500
13000
13500
14000
14500
15000
Epoch [1/1], Average Loss: 0.0376
CPU times: user 23min 14s, sys: 5min 43s, total: 28min 58s
Wall time: 46min 15s
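The sample-selection logic in `train` (draw at most `max_samples` samples and skip any containing NaN) does not depend on PyTorch. Here is a stdlib sketch using a hypothetical generator `clean_samples`, with each sample represented as a flat list of floats and missing data assumed to be filtered already by the pipeline's `exceptions_to_ignore`:

```python
import math
from itertools import islice

def clean_samples(samples, max_samples):
    """Yield samples from the first `max_samples` drawn, skipping any that contain NaN.

    Like the loop in `train`, skipped NaN samples still count towards `max_samples`.
    """
    for sample in islice(samples, max_samples):
        if any(math.isnan(v) for v in sample):
            continue  # NaN inputs would break training
        yield sample

stream = iter([[0.2, 0.4], [float("nan"), 0.1], [0.3, 0.5], [0.6, 0.7]])
print(list(clean_samples(stream, max_samples=3)))  # [[0.2, 0.4], [0.3, 0.5]]
```

For NumPy arrays, `np.isnan(sample).any()` plays the role of the `any(...)` check.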
[23]:
x.min()
[23]:
tensor(0.2183, device='cuda:0')
[24]:
prediction = model(x)
image_numpy_for_display = prediction.to('cpu').detach().numpy()
plt.imshow(image_numpy_for_display[0][0])
[24]:
[Figure: channel 0 of the trained model's reconstruction of `x`]
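Two rough numbers help interpret the training run: throughput, and the average L1 loss converted back to the original scale by undoing the division by 1200. Both use only values printed by the training cell; the average is approximate because `train` also counts skipped samples in its denominator, and the result is in the irradiance data's original units.

```python
# Values copied from the training cell's output
samples = 15_000
wall_seconds = 46 * 60 + 15  # "Wall time: 46min 15s"
avg_l1 = 0.0376              # "Average Loss: 0.0376", in normalised units
scale = 1200                 # SingleValueDivision(1200) in fullsat

seconds_per_sample = wall_seconds / samples
avg_l1_original_units = avg_l1 * scale

print(f"{seconds_per_sample:.3f} s/sample, mean absolute error ~{avg_l1_original_units:.1f}")
```

At about 0.185 s per sample, the run was slower than the 2 minutes per 1000 samples (0.12 s per sample) estimated in the training cell's comments.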
[25]:
# 2021-03-03 is outside the 2020 training period, so this is a held-out example
z = torch.from_numpy(fullsat['20210303T0400']).float().to(device)
image_numpy_for_display = z.to('cpu').detach().numpy()
plt.imshow(image_numpy_for_display[0][0])
[25]:
[Figure: channel 0 of the normalised surface global irradiance at 2021-03-03T04:00]
[26]:
latent = model.encoder(z)
[27]:
latent.shape
[27]:
torch.Size([1, 64, 119, 144])
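Counting elements is a quick check on how much the latent state compresses the input. With the default layer settings, the latent tensor for one 501 x 601 single-channel image holds more values than the image itself. The sketch below computes this from the `latent.shape` above and the cropped input size, and finds how many latent channels (at the same 119 x 144 spatial size) would make the latent smaller than the input.

```python
import math

input_shape = (1, 1, 501, 601)    # (batch, channel, height, width) of z
latent_shape = (1, 64, 119, 144)  # latent.shape from the cell above

input_size = math.prod(input_shape)
latent_size = math.prod(latent_shape)
ratio = latent_size / input_size

# Largest channel count whose 119 x 144 latent still has fewer values than the input
per_channel = latent_shape[2] * latent_shape[3]
max_channels = (input_size - 1) // per_channel

print(input_size, latent_size, round(ratio, 2), max_channels)
```

With 64 channels the latent is about 3.6 times larger than the input. One way to turn it into a true bottleneck would be to reduce the last encoder layer's `out_channels` (and the first decoder layer's `in_channels`) to 17 or fewer, or to add further downsampling.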
[28]:
reconstruction = model(z)
image_numpy_for_display = reconstruction.to('cpu').detach().numpy()
plt.imshow(image_numpy_for_display[0][0])
[28]:
[Figure: channel 0 of the trained model's reconstruction of the 2021-03-03T04:00 sample]