
In this section, you’ll learn:¶
- Using UXarray to compute analysis increments and visualize them in horizontal and vertical cross sections
Prerequisites¶
| Concepts | Importance | Notes |
|---|---|---|
| Atmospheric Data Assimilation | Helpful | |
Time to learn: 10 minutes
Import packages¶
%%time
# autoload external python modules if they changed
%load_ext autoreload
%autoreload 2
# add ../funcs to the current path
import sys, os
sys.path.append(os.path.join(os.getcwd(), ".."))
# import modules
import warnings
import math
import cartopy.crs as ccrs
import geoviews as gv
import geoviews.feature as gf
import holoviews as hv
import hvplot.xarray
from holoviews import opts
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import s3fs
import geopandas as gp
import numpy as np
import uxarray as ux
import xarray as xr
Configure visualization tools¶
# Render HoloViews plots with the Bokeh backend (interactive, with hover).
# For static figures, use hv.extension("matplotlib") instead.
hv.extension("bokeh")

# Shared map overlays (coastlines and state borders at 50m resolution),
# composed with each contour plot below via the `*` operator.
coast_lines = gf.coastline(
    scale="50m",
    projection=ccrs.PlateCarree(),
    line_width=1,
)
state_lines = gf.states(
    scale="50m",
    projection=ccrs.PlateCarree(),
    line_color="gray",
    line_width=1,
)
Helper functions¶
The following helper function is used to visualize the data: `horizontal_contour` generates a filled contour map for a single horizontal slice (one time and one vertical level) of the data.
# Generates a contour plot for a horizontal slice
def horizontal_contour(ux_hslice, title, cmin=None, cmax=None, width=800, height=500,
                       clevs=20, cmap="coolwarm", symmetric_cmap=False):
    """Return a filled-contour HoloViews plot of one horizontal slice.

    Parameters
    ----------
    ux_hslice : UxDataArray
        Field on the unstructured grid, already reduced to a single time and
        vertical level (e.g. ``da.isel(Time=0, nVertLevels=k)``).
    title : str
        Plot title; the slice's minimum and maximum are appended to it.
    cmin, cmax : float, optional
        Contour range. Default to ``floor(min)`` / ``ceil(max)`` of the slice.
    width, height : int
        Plot size in pixels.
    clevs : int
        Number of contour levels, evenly spaced from ``cmin`` to ``cmax`` inclusive.
    cmap : str or matplotlib Colormap
        Colormap, given by name or as an object.
    symmetric_cmap : bool
        If True, use the range ``(-m, m)`` with ``m = max(|cmin|, |cmax|)`` so
        that zero sits at the center of the colormap (useful for increments).

    Returns
    -------
    The element produced by ``hv.operation.contours`` with styling applied.
    """
    # Data range of the slice: used for default limits and shown in the title
    amin = ux_hslice.min().item()
    amax = ux_hslice.max().item()
    cmin = math.floor(amin) if cmin is None else cmin
    cmax = math.ceil(amax) if cmax is None else cmax
    if symmetric_cmap:
        # abs() on both ends keeps the range symmetric whatever the signs are
        cmax = max(abs(cmin), abs(cmax))
        cmin = -cmax
    if cmin == cmax:
        # e.g. an all-zero increment field: every contour level would be identical
        # (degenerate), so widen the range by one unit on each side.
        cmin, cmax = cmin - 1, cmax + 1
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)

    # Keep the caller's title (e.g. the level) and append the data range
    full_title = f"{title} min={amin:.1f} max={amax:.1f}"
    contour_plot = hv.operation.contours(
        ux_hslice.plot(),
        levels=np.linspace(cmin, cmax, num=clevs),
        filled=True,
    ).opts(
        line_color=None,  # filled polygons only, no contour outlines
        width=width, height=height,
        cmap=cmap, clim=(cmin, cmax),
        colorbar=True, show_legend=False,
        tools=['hover'], title=full_title,
    )
    return contour_plot
Retrieve/load MPAS/JEDI data¶
The example MPAS/JEDI data are stored on JetStream2, so we need to retrieve them first.
There are two ways to retrieve MPAS data:
- Download all example data from JetStream2 to local disk and then load them locally. This approach downloads the data only once per machine and reuses them across notebooks.
- Stream the JetStream2 S3 objects on demand. In this case, each notebook (including restarting a notebook) will retrieve the required data separately as needed.
# Choose how the example data are obtained (see the cell above):
#   1 -> download everything once to local disk and reuse it across notebooks
#   2 -> stream the JetStream2 S3 objects on demand (the default used here)
data_load_method = 2  # 1 or 2
if data_load_method not in (1, 2):
    # Neither loading cell below would define grid_file/ana_file/bkg_file,
    # which would otherwise only surface later as a NameError.
    raise ValueError(f"data_load_method must be 1 or 2, got {data_load_method!r}")
Method 1: Download all example data once and reuse them in multiple notebooks
%%time
local_dir = "/tmp"
# The background file doubles as a marker that the dataset was already downloaded.
if data_load_method != 1 or os.path.exists(local_dir + "/conus12km/bkg/mpasout.2024-05-06_01.00.00.nc"):
    print("Skip..., either data is available in local or data_load_method is NOT 1")
else:
    # Anonymous (read-only) access to the JetStream2 S3-compatible endpoint
    jetstream_url = 'https://js2.jetstream-cloud.org:8001/'
    fs = s3fs.S3FileSystem(
        anon=True,
        asynchronous=False,
        client_kwargs={"endpoint_url": jetstream_url},
    )
    conus12_path = 's3://pythia/mpas/conus12km'
    # Recursively copy the conus12km tree under local_dir
    fs.get(conus12_path, local_dir, recursive=True)
    print("Data downloading completed")
Skip..., either data is available in local or data_load_method is NOT 1
CPU times: user 442 μs, sys: 0 ns, total: 442 μs
Wall time: 447 μs
# Set local file paths (method 1 only). The analysis comes from ana/ and the
# background from bkg/, matching the streamed URLs used by method 2.
# (These were previously swapped, which flipped the sign of the increments.)
if data_load_method == 1:
    grid_file = local_dir + "/conus12km/conus12km.invariant.nc_L60_GFS"
    ana_file = local_dir + "/conus12km/ana/mpasout.2024-05-06_01.00.00.nc"
    bkg_file = local_dir + "/conus12km/bkg/mpasout.2024-05-06_01.00.00.nc"
Method 2: Stream the JetStream2 S3 objects on demand¶
%%time
if data_load_method != 2:
    print("Skip..., data_load_method is NOT 2")
else:
    # Anonymous (read-only) access to the JetStream2 S3-compatible endpoint
    jetstream_url = 'https://js2.jetstream-cloud.org:8001/'
    fs = s3fs.S3FileSystem(
        anon=True,
        asynchronous=False,
        client_kwargs={"endpoint_url": jetstream_url},
    )
    conus12_path = 's3://pythia/mpas/conus12km'
    grid_url = f"{conus12_path}/conus12km.invariant.nc_L60_GFS"    # grid file
    bkg_url = f"{conus12_path}/bkg/mpasout.2024-05-06_01.00.00.nc"  # background
    ana_url = f"{conus12_path}/ana/mpasout.2024-05-06_01.00.00.nc"  # analysis
    # Handles are intentionally not closed here: ux.open_dataset reads from them below.
    grid_file, ana_file, bkg_file = (fs.open(url) for url in (grid_url, ana_url, bkg_url))
CPU times: user 56.9 ms, sys: 16 ms, total: 72.8 ms
Wall time: 122 ms
Loading the data into UXarray datasets¶
We use the UXarray data structures for working with the data. This package supports data defined on unstructured grids and provides utilities for modifying and visualizing them. The available functionality is discussed in the UxDataset documentation.
# Analysis (uxds_a) and background (uxds_b) are both paired with the same grid file.
uxds_a, uxds_b = (ux.open_dataset(grid_file, data_file) for data_file in (ana_file, bkg_file))
Compute the analysis increments from the JEDI data assimilation
JEDI updates the background atmospheric state (`uxds_b`) with observation innovations to obtain a new atmospheric state called the analysis (`uxds_a`).
The difference `uxds_a - uxds_b` is called the “analysis increment”.
# Analysis increment (analysis minus background), all dimensions kept
var_name = "theta"
uxvar = uxdiff0 = uxds_a[var_name] - uxds_b[var_name]
Plot potential temperature (`theta`) analysis increments at different levels
## Utility to create ncl style color bars.
# def make_discrete_cmap(n_colors, base="coolwarm", *, center_white=False):
# if center_white and n_colors % 2 == 0:
# raise ValueError("center_white=True requires an odd n_colors")
# base_cmap = plt.get_cmap(base)
# positions = np.linspace(0, 1, n_colors)
# colours = base_cmap(positions)
# if center_white:
# colours[n_colors // 2] = (1.0, 1.0, 1.0, 1.0)
# suffix = "_cw" if center_white else ""
# return ListedColormap(colours, name=f"{base}_{n_colors}{suffix}")
# # levels = np.arange(-4.0, 4.0 + 0.5, 0.5)
# cmap = make_discrete_cmap(base="coolwarm", n_colors=16, center_white=False)
from matplotlib.colors import ListedColormap, BoundaryNorm, to_rgba
def make_interval_cmap(
    edges,
    colors=None,
    *,
    base_cmap="viridis",
    bad="none",
    under=None,
    over=None
):
    """Build a discrete colormap and a matching BoundaryNorm from bin edges.

    Parameters
    ----------
    edges : sequence of float
        Strictly increasing bin edges; ``len(edges) - 1`` bins are produced,
        bin ``i`` covering ``[edges[i], edges[i+1])``.
    colors : sequence of color specs, optional
        One color per bin (``len(edges) - 1`` entries). If omitted, colors are
        sampled from ``base_cmap`` at the bin midpoints.
    base_cmap : str
        Matplotlib colormap name used when ``colors`` is None.
    bad : color spec, "none" or None
        Color for masked/invalid values; ``"none"`` or None keeps the default.
    under, over : color spec, optional
        Colors for values below ``edges[0]`` / at or above ``edges[-1]``.

    Returns
    -------
    (ListedColormap, BoundaryNorm)

    Raises
    ------
    ValueError
        If ``edges`` is not 1-D with at least two strictly increasing values,
        or ``colors`` does not have one entry per bin.
    """
    edges = np.asarray(edges, dtype=float)
    if edges.ndim != 1 or len(edges) < 2:
        raise ValueError("`edges` must be 1D with at least two values.")
    if not np.all(np.diff(edges) > 0):
        raise ValueError("`edges` must be strictly increasing.")
    n_bins = len(edges) - 1

    if colors is None:
        # Sample the base colormap at bin midpoints, normalized to 0..1 over the edge range
        cm = plt.get_cmap(base_cmap)
        mids = 0.5 * (edges[:-1] + edges[1:])
        cols = cm((mids - edges[0]) / (edges[-1] - edges[0]))
    else:
        if len(colors) != n_bins:
            raise ValueError(f"`colors` must have length {n_bins}.")
        cols = [to_rgba(c) for c in colors]

    cmap = ListedColormap(cols, name="interval_cmap")
    if bad is not None and not (isinstance(bad, str) and bad == "none"):
        cmap.set_bad(bad)
    if under is not None:
        cmap.set_under(under)
    if over is not None:
        cmap.set_over(over)

    # BoundaryNorm maps values below/above the edges to -1 / cmap.N, which the
    # colormap renders with its under/over colors. Passing extend="both" to
    # BoundaryNorm would require cmap.N >= n_bins + 2 and raises ValueError for
    # this n_bins-color map, so the extension is recorded on the colormap
    # instead (as matplotlib.colors.from_levels_and_colors does), which
    # colorbars use to draw the extension arrows.
    if under is not None and over is not None:
        extend = "both"
    elif under is not None:
        extend = "min"
    elif over is not None:
        extend = "max"
    else:
        extend = "neither"
    cmap.colorbar_extend = extend
    norm = BoundaryNorm(edges, cmap.N)
    return cmap, norm
# Contour bin edges for the increment: 0.5-wide bins from -4 to 4
# (17 edges -> 16 bins, hence 16 colors below).
edges = [
    -4, -3.5, -3, -2.5, -2, -1.5, -1, -0.5,
    0,
    0.5, 1.0, 1.5, 2, 2.5, 3, 3.5, 4,
]
# One color per bin; with the returned norm, colors[i] fills [edges[i], edges[i+1]).
colors = [
    "green",    # [-4.0, -3.5)
    "blue",     # [-3.5, -3.0)
    "#313695",  # [-3.0, -2.5) deep navy
    "#4575b4",  # [-2.5, -2.0) medium blue
    "#74add1",  # [-2.0, -1.5) light blue
    "#abd9e9",  # [-1.5, -1.0) pale blue
    "#d0f0c0",  # [-1.0, -0.5) light green
    "#ffffbf",  # [-0.5,  0.0) pale yellow
    "white",    # [ 0.0,  0.5) white (first bin above zero, not centered on zero)
    "#fee090",  # [ 0.5,  1.0) pale yellow
    "#f46d43",  # [ 1.0,  1.5) orange-red
    "#d73027",  # [ 1.5,  2.0) red
    "#a50026",  # [ 2.0,  2.5) deep red
    "#800026",  # [ 2.5,  3.0) darker maroon
    "#4d0013",  # [ 3.0,  3.5) deep maroon
    "black",    # [ 3.5,  4.0) extreme end
]
# Masked values drawn light gray; no dedicated under/over colors.
cmap, norm = make_interval_cmap(edges, colors, over=None, under=None, bad="lightgray")
%%time
nt = 0  # index along the Time dimension
plot_levels = [0, 19, 29, 39, 42, 49, 58]  # nVertLevels indices to plot
zero_shift = 0.0  # NOTE(review): not used in this cell
plots = []
for lev in plot_levels:
    # Limits and level count come from `edges`, so the contour levels coincide
    # with the interval boundaries of the discrete `cmap` (16 bins -> 17 levels).
    # The previous call passed cmin=4 (typo for -4) and the default 20 levels,
    # which did not line up with the colormap's bins.
    tmp = horizontal_contour(
        uxvar.isel(Time=nt, nVertLevels=lev),
        title=f'lev={lev}',
        symmetric_cmap=True,
        cmap=cmap,
        cmin=edges[0],
        cmax=edges[-1],
        clevs=len(edges),
    )  # for the whole domain
    plots.append(tmp * coast_lines * state_lines)

# Display each plot separately so each keeps its own toolbar
for p in plots:
    display(p)
# Tip: to hide near-zero increments, plot e.g. dat.where((dat > 0.1) | (dat < -0.1)).
Zoomed into Colorado using the subset capability¶
%%time
# Box centered on Colorado: ±lon_incr degrees longitude, ±lat_incr degrees latitude
lon_center = -105.03
lat_center = 39.0
lon_incr = 5  # degree
lat_incr = 3  # degree
lon_bounds = (lon_center - lon_incr, lon_center + lon_incr)
lat_bounds = (lat_center - lat_incr, lat_center + lat_incr)

### subset to a small domain
uxdiff1 = uxdiff0.subset.bounding_box(lon_bounds, lat_bounds)
uxvar = uxdiff1

nt = 0  # index along the Time dimension
plot_levels = [0, 29, 42]  # nVertLevels indices to plot
plots = []
for lev in plot_levels:
    tmp = horizontal_contour(uxvar.isel(Time=nt, nVertLevels=lev), title=f'lev={lev}', width=700, height=500)  # for the subdomain
    # Apply the axis limits to the whole overlay. Writing `a * b * state_lines.opts(...)`
    # applies them to `state_lines` only. HoloViews `.opts()` works in place by default,
    # so that also clipped the shared `state_lines` used by the other plots.
    overlay = tmp * coast_lines * state_lines
    plots.append(overlay.opts(xlim=lon_bounds, ylim=lat_bounds))

# Each plot gets its own toolbar so it can be controlled individually; use
# hv.Layout(plots).cols(1) instead for one shared toolbar with synced zoom.
for p in plots:
    display(p)
Vertical cross section along an arbitrary great circle arc (GCA)
# lat=43.3
# step_between_points = 100
# Great-circle arc endpoints as (lon, lat) in degrees
start_point = (-110, 20)
end_point = (-70, 50)

# Increment (analysis - background) at the first time index
var_name = "theta"
uxvar = uxdiff0 = uxds_a[var_name].isel(Time=0) - uxds_b[var_name].isel(Time=0)

# Sample the increment at 100 points along the arc
cross_section_gca = uxvar.cross_section(start=start_point, end=end_point, steps=100)

# Two-line tick label per sample: latitude above longitude
hlabelticks = [
    "{:.1f}°{}\n{:.1f}°{}".format(
        abs(la), "N" if la >= 0 else "S",
        abs(lo), "E" if lo >= 0 else "W",
    )
    for la, lo in zip(cross_section_gca['lat'], cross_section_gca['lon'])
]
# cross_section_gca.isel(Time=0).transpose().plot.contourf()
%matplotlib inline
fig = plt.figure(figsize=(8, 3))
gs = fig.add_gridspec(1, 1)
ax = fig.add_subplot(gs[0, 0])
# Transposed so that columns follow the samples along the arc (x axis)
section = cross_section_gca.transpose()
cf = ax.contourf(section, cmap='Reds', extend='both')
# Without a colorbar the shading cannot be read as increment values
fig.colorbar(cf, ax=ax, label=f"{var_name} increment (analysis - background)")
tick_stride = 10  # label every 10th sample
ax.set_xticks(cross_section_gca['steps'][::tick_stride])
ax.set_xticklabels(hlabelticks[::tick_stride])
ax.set_ylabel(f"{section.dims[0]} index")
# Show the figure and avoid echoing the tick-label list as cell output
plt.show()
[Text(0, 0, '20.0°N\n110.0°W'),
Text(10, 0, '23.5°N\n107.1°W'),
Text(20, 0, '27.0°N\n104.1°W'),
Text(30, 0, '30.3°N\n100.9°W'),
Text(40, 0, '33.6°N\n97.4°W'),
Text(50, 0, '36.8°N\n93.7°W'),
Text(60, 0, '39.9°N\n89.7°W'),
Text(70, 0, '42.8°N\n85.3°W'),
Text(80, 0, '45.5°N\n80.5°W'),
Text(90, 0, '48.0°N\n75.2°W')]

# Constant-longitude vertical cross section (degrees; negative = west)
lon = -83.3
cross_section_lon = uxvar.cross_section(lon=lon, steps=100)
# NOTE(review): the recorded tick labels run from 90°S northward, so the section
# spans the whole meridian; presumably most samples fall outside this regional mesh.
hlabelticks = [
    "{:.1f}°{}".format(abs(phi), "N" if phi >= 0 else "S")
    for phi in cross_section_lon['lat']
]
%matplotlib inline
fig = plt.figure(figsize=(8, 3))
gs = fig.add_gridspec(1, 1)
ax = fig.add_subplot(gs[0, 0])
# Transposed so that columns follow the samples along the meridian (x axis)
section = cross_section_lon.transpose()
cf = ax.contourf(section, cmap='Reds', extend='both')
fig.colorbar(cf, ax=ax, label=f"{var_name} increment (analysis - background)")
# Defined here too so this cell does not depend on another cell's variable
tick_stride = 10  # label every 10th sample
ax.set_xticks(cross_section_lon['steps'][::tick_stride])
ax.set_xticklabels(hlabelticks[::tick_stride])
ax.set_ylabel(f"{section.dims[0]} index")
ax.set_title(f"lon = {lon}°")
# Show the figure and avoid echoing the tick-label list as cell output
plt.show()
[Text(0, 0, '90.0°S'),
Text(10, 0, '71.8°S'),
Text(20, 0, '53.6°S'),
Text(30, 0, '35.5°S'),
Text(40, 0, '17.3°S'),
Text(50, 0, '0.9°N'),
Text(60, 0, '19.1°N'),
Text(70, 0, '37.3°N'),
Text(80, 0, '55.5°N'),
Text(90, 0, '73.6°N')]

# A second constant-longitude section, at 100°W, shown as the cell output
cross_section_lon = uxvar.cross_section(steps=100, lon=-100.0)
cross_section_lon
Save plots to files
# hv.save(tmp, 'vslice.png')