Model and Related Functions
First up is a callback for tracking the loss history, which lets us watch how training is progressing. The specifics aren't too important, and the callback can be altered to trigger at different frequencies and on different events (an epoch end rather than a batch end, for example); a short sketch of an epoch-end variant follows the class definition.
class LossHistory(tf.keras.callbacks.Callback):
    """Record per-batch metrics and refresh a live plot every `frequency` batches."""
    def __init__(self, frequency=30):
        super().__init__()
        self.frequency = frequency
        self.batch_counter = 0
        self.i = 0
        self.x = []
        self.mae = []
        self.mse = []
        self.accuracy = []
        self.logs = []

    def on_train_begin(self, logs={}):
        self.fig = plt.figure()

    def on_batch_end(self, batch, logs={}):
        # if batch % self.frequency == 0:
        self.batch_counter += self.frequency
        # Record the running metrics after every batch
        self.x.append(self.i)
        self.mae.append(logs.get('mae'))
        self.mse.append(logs.get('mse'))
        self.accuracy.append(logs.get('accuracy'))
        self.i += 1
        # Only refresh the printout and plots every `frequency` batches
        if batch % self.frequency == 0:
            self.logs.append(logs)
            clear_output(wait=True)
            print(f"MAE: {self.mae[-1]} \t\t MSE: {self.mse[-1]} \t\t Accuracy: {self.accuracy[-1]}")

            plt.figure(figsize=(18, 5))
            plt.subplot(131)
            plt.plot(self.x, self.mae, color='#ff6347', label="mae")
            plt.plot(self.x[-1], self.mae[-1], marker='o', markersize=10, color='#ff6347')
            plt.legend()
            plt.xlabel('batch')
            plt.ylabel('Mean Absolute Error')
            plt.ylim([0., 100.])

            plt.subplot(132)
            plt.plot(self.x, self.mse, color='#6495ed')
            plt.plot(self.x[-1], self.mse[-1], marker='o', markersize=10, color='#6495ed')
            plt.xlabel('batch')
            plt.ylabel(r'Mean Squared Error [$cm^2/s^2$]')
            plt.ylim([0., 1000.])

            plt.subplot(133)
            plt.plot(self.x, self.accuracy, color='#3cb371')
            plt.plot(self.x[-1], self.accuracy[-1], marker='o', markersize=10, color='#3cb371')
            plt.xlabel('batch')
            plt.ylabel('Model Accuracy')
            plt.ylim([0., 1.])
            plt.show()
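As a concrete example of triggering on a different event, here's a minimal, hypothetical sketch of an epoch-end variant that reuses the same bookkeeping and plotting logic (this subclass isn't used anywhere else in this post):

class EpochLossHistory(LossHistory):
    """Same plots as LossHistory, but refreshed once per epoch."""
    def on_batch_end(self, batch, logs=None):
        pass  # turn off the per-batch updates

    def on_epoch_end(self, epoch, logs=None):
        # Reuse the parent's logic; frequency=1 means every epoch triggers a redraw
        super().on_batch_end(epoch, logs or {})

# history = EpochLossHistory(frequency=1)
# model.fit(..., callbacks=[history])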
We define a custom loss function that takes the MAE of multidimensional objects while masking out NaNs. (This should solve the coastline problem, but it doesn't!) With no NaNs present it should match a standard MAE; when NaNs are present it should give different results from using .fillna(0) on the training data, because it won't entrain zeros when the model takes a convolution. But does it actually work? ¯\_(ツ)_/¯ A small worked comparison follows the class definition.
# Corner case: what happens when everything is NaN?
class Grid_MAE(tf.keras.losses.Loss):
    def call(self, y_true, y_pred):
        # Absolute error at every grid point
        err = tf.math.abs(y_true - y_pred)
        # Zero out non-finite entries (e.g. NaNs over land) so they don't poison the sum
        finite = tf.math.is_finite(err)
        masked = tf.where(finite, err, tf.zeros_like(err))
        # Normalize by the number of finite points so this is a true mean; the
        # tf.maximum guard keeps the all-NaN corner case from dividing by zero
        count = tf.maximum(tf.reduce_sum(tf.cast(finite, err.dtype)), 1.0)
        return tf.math.reduce_sum(masked) / count
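To illustrate the intended behaviour, here's a tiny, purely illustrative comparison on made-up tensors: masking drops the NaN point from the average, while filling it with zero adds a spurious zero-error term that drags a standard MAE down.

import numpy as np
import tensorflow as tf

y_true = tf.constant([[1.0, 2.0, np.nan]])   # pretend the NaN is a land point
y_pred = tf.constant([[1.5, 2.5, 0.0]])

print(Grid_MAE()(y_true, y_pred).numpy())    # mean over the two ocean points -> 0.5

# Filling the land point with zero instead, then using a standard MAE
filled = tf.where(tf.math.is_finite(y_true), y_true, tf.zeros_like(y_true))
print(tf.keras.losses.MeanAbsoluteError()(filled, y_pred).numpy())  # (0.5 + 0.5 + 0.0)/3 ≈ 0.33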
The get_model function generates a neural network based on Sinha and Abernathey (2021), but offers some parameters that enable a broader class of neural networks of similar form; a toy invocation follows the function definition.
def get_model(halo_size, ds, sc, conv_dims, nfilters, conv_kernels, dense_layers):
    # Convolutional branch: a patch of size conv_dims with one channel per conv_var
    conv_init = tf.keras.Input(shape=tuple(conv_dims) + (len(sc.conv_var),))
    last_layer = conv_init
    for kernel in conv_kernels:
        this_layer = tf.keras.layers.Conv2D(nfilters, kernel)(last_layer)
        last_layer = this_layer
        nfilters = nfilters // 2  # halve the filter count for each successive layer

    # Each valid convolution strips a halo, so the non-convolved inputs are smaller
    halo_dims = [x - 2*halo_size for x in conv_dims]
    input_init = tf.keras.Input(shape=tuple(halo_dims) + (len(sc.input_var),))

    last_layer = tf.keras.layers.concatenate([last_layer, input_init])
    last_layer = tf.keras.layers.LeakyReLU(alpha=0.3)(last_layer)
    for layer in range(dense_layers):
        this_layer = tf.keras.layers.Dense(nfilters, activation='relu')(last_layer)
        last_layer = this_layer
        nfilters = nfilters // 2
    output_layer = tf.keras.layers.Dense(len(sc.target))(last_layer)

    model = tf.keras.Model(inputs=[conv_init, input_init], outputs=output_layer)
    opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
    model.compile(loss=Grid_MAE(), optimizer=opt, metrics=['mae', 'mse', 'accuracy'])
    model.summary()
    return model
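To make the signature concrete, here is a toy invocation. Everything in it is hypothetical: the variable names and the sc_demo config are made up for illustration, and ds isn't actually touched inside get_model. With a single 3×3 kernel the convolved branch collapses to a 1×1 patch, which is exactly the size of the halo-stripped inputs it gets concatenated with.

from types import SimpleNamespace

# Hypothetical scenario config: two convolved variables, one extra input, one target
sc_demo = SimpleNamespace(conv_var=['SSH', 'SST'], input_var=['TAUX'],
                          target=['WVEL'], name='demo')

# conv_dims=[3,3] with a single 3x3 kernel -> halo_size=1, so the Input for the
# non-convolved variables is a 1x1 patch that lines up with the Conv2D output
demo_model = get_model(halo_size=1, ds=None, sc=sc_demo, conv_dims=[3, 3],
                       nfilters=80, conv_kernels=[3], dense_layers=3)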
def train(ds, sc, conv_dims=[3,3], nfilters=80, conv_kernels=[3], dense_layers=3):
    pars = locals()  # capture the arguments so they can be passed straight to get_model
    # Each k x k valid convolution strips (k-1)/2 points from every edge
    halo_size = int((np.sum(conv_kernels) - len(conv_kernels))/2)
    nlons, nlats = conv_dims
    # bgen = xb.BatchGenerator(
    #     ds,
    #     {'nlon':nlons, 'nlat':nlats},
    #     {'nlon':2*halo_size, 'nlat':2*halo_size},
    #     concat_input_dims=True
    # )
    latlen = len(ds['nlat'])
    lonlen = len(ds['nlon'])
    nlon_range = range(nlons, lonlen, nlons - 2*halo_size)
    nlat_range = range(nlats, latlen, nlats - 2*halo_size)

    # Build overlapping (nlats x nlons) patches and stack them along a single batch dimension
    batch = (
        ds
        .rolling({"nlat": nlats, "nlon": nlons})
        .construct({"nlat": "nlat_input", "nlon": "nlon_input"})[{'nlat':nlat_range, 'nlon':nlon_range}]
        .stack({"input_batch": ("nlat", "nlon")}, create_index=False)
        .rename_dims({'nlat_input':'nlat', 'nlon_input':'nlon'})
        .transpose('input_batch', ...)
        # .chunk({'input_batch':32, 'nlat':nlats, 'nlon':nlons})
        .dropna('input_batch')
    )

    # Shuffle the patches before training
    rnds = list(range(len(batch['input_batch'])))
    np.random.shuffle(rnds)
    batch = batch[{'input_batch': rnds}]

    def batch_generator(batch_set, batch_size):
        n = 0
        while n < len(batch_set['input_batch']) - batch_size:
            yield batch_set.isel({'input_batch': range(n, n + batch_size)})
            n += batch_size

    # We need this subsetting stencil to compensate for the fact that a halo is
    # removed by each convolution layer. This means that the input_var variables
    # will be the wrong size at the concat layer unless we strip a halo from them
    sub = {'nlon_input': range(halo_size, nlons - halo_size),
           'nlat_input': range(halo_size, nlats - halo_size)}

    model = get_model(halo_size, **pars)
    history = LossHistory()

    bgen = batch_generator(batch, 4096)
    for batch in bgen:
        batch_conv   = [batch[x]      for x in sc.conv_var]
        batch_input  = [batch[x][sub] for x in sc.input_var]
        batch_target = [batch[x][sub] for x in sc.target]

        batch_conv   = xr.merge(batch_conv  ).to_array('var').transpose(..., 'var')
        batch_input  = xr.merge(batch_input ).to_array('var').transpose(..., 'var')
        batch_target = xr.merge(batch_target).to_array('var').transpose(..., 'var')

        clear_output(wait=True)
        model.fit([batch_conv, batch_input],
                  batch_target,
                  batch_size=32, verbose=0, # epochs=4,
                  callbacks=[history])

    model.save('models/' + sc.name)
    np.savez('models/history_' + sc.name,
             losses=history.mae, mse=history.mse, accuracy=history.accuracy)
    return model, history
def test(ds, sc, conv_dims=[3,3], conv_kernels=[3]):
    halo_size = int((np.sum(conv_kernels) - len(conv_kernels))/2)
    nlons, nlats = conv_dims
    latlen = len(ds['nlat'])
    lonlen = len(ds['nlon'])
    nlon_range = range(nlons, lonlen, nlons - 2*halo_size)
    nlat_range = range(nlats, latlen, nlats - 2*halo_size)

    # Same patch construction as in train()
    batch = (
        ds
        .rolling({"nlat": nlats, "nlon": nlons})
        .construct({"nlat": "nlat_input", "nlon": "nlon_input"})[{'nlat':nlat_range, 'nlon':nlon_range}]
        .stack({"input_batch": ("nlat", "nlon")}, create_index=False)
        .rename_dims({'nlat_input':'nlat', 'nlon_input':'nlon'})
        .transpose('input_batch', ...)
        # .chunk({'input_batch':32, 'nlat':nlats, 'nlon':nlons})
        .dropna('input_batch')
    )

    # Strip the halo from the non-convolved variables, as in train()
    sub = {'nlon_input': range(halo_size, nlons - halo_size),
           'nlat_input': range(halo_size, nlats - halo_size)}

    model = tf.keras.models.load_model('models/' + sc.name, custom_objects={'Grid_MAE': Grid_MAE})

    batch_conv   = [batch[x]      for x in sc.conv_var]
    batch_input  = [batch[x][sub] for x in sc.input_var]
    batch_target = [batch[x][sub] for x in sc.target]

    batch_conv   = xr.merge(batch_conv  ).to_array('var').transpose(..., 'var')
    batch_input  = xr.merge(batch_input ).to_array('var').transpose(..., 'var')
    batch_target = xr.merge(batch_target).to_array('var').transpose(..., 'var')

    return model.evaluate([batch_conv, batch_input], batch_target)
def predict(ds, sc, conv_dims=[3,3], conv_kernels=[3]):
    halo_size = int((np.sum(conv_kernels) - len(conv_kernels))/2)
    nlons, nlats = conv_dims
    latlen = len(ds['nlat'])
    lonlen = len(ds['nlon'])
    nlon_range = range(nlons, lonlen, nlons - 2*halo_size)
    nlat_range = range(nlats, latlen, nlats - 2*halo_size)

    # Same patch construction as in train()
    batch = (
        ds
        .rolling({"nlat": nlats, "nlon": nlons})
        .construct({"nlat": "nlat_input", "nlon": "nlon_input"})[{'nlat':nlat_range, 'nlon':nlon_range}]
        .stack({"input_batch": ("nlat", "nlon")}, create_index=False)
        .rename_dims({'nlat_input':'nlat', 'nlon_input':'nlon'})
        .transpose('input_batch', ...)
        # .chunk({'input_batch':32, 'nlat':nlats, 'nlon':nlons})
        .dropna('input_batch')
    )

    # Strip the halo from the non-convolved variables, as in train()
    sub = {'nlon_input': range(halo_size, nlons - halo_size),
           'nlat_input': range(halo_size, nlats - halo_size)}

    model = tf.keras.models.load_model('models/' + sc.name, custom_objects={'Grid_MAE': Grid_MAE})

    batch_conv  = [batch[x]      for x in sc.conv_var]
    batch_input = [batch[x][sub] for x in sc.input_var]

    batch_conv  = xr.merge(batch_conv ).to_array('var').transpose(..., 'var')
    batch_input = xr.merge(batch_input).to_array('var').transpose(..., 'var')

    target = model.predict([batch_conv, batch_input])
    return target
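Putting the pieces together, the intended call sequence looks something like the sketch below; ds_train, ds_test, and sc are placeholders for the dataset(s) and scenario configuration defined elsewhere in this post, and the keyword values are just the defaults spelled out.

# Sketch only: `ds_train`, `ds_test`, and `sc` are assumed to exist already
model, history = train(ds_train, sc, conv_dims=[3, 3], nfilters=80,
                       conv_kernels=[3], dense_layers=3)
test(ds_test, sc, conv_dims=[3, 3], conv_kernels=[3])        # prints/returns evaluation metrics
prediction = predict(ds_test, sc, conv_dims=[3, 3], conv_kernels=[3])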