Full-Waveform Inversion of a portion of Marmousi

Full-Waveform Inversion (FWI) offers the potential to invert for a model that matches the whole wavefield, including refracted arrivals. It performs inversion using the regular (non-linear) propagator rather than the Born (linearised) propagator.
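Written out as a standard formulation (the symbols are introduced here and are not Deepwave names), let $F(m)$ be the receiver data modelled for wavespeed model $m$ and $d_{\mathrm{obs}}$ the observed data. FWI minimises a misfit built on the full non-linear $F$. Born-based inversion instead replaces $F$ by its linearisation $DF(m_0)$ about a background model $m_0$ and solves for a perturbation $\delta m$:

```latex
J_{\mathrm{FWI}}(m) = \tfrac{1}{2}\,\bigl\lVert F(m) - d_{\mathrm{obs}} \bigr\rVert_2^2,
\qquad
J_{\mathrm{Born}}(\delta m) = \tfrac{1}{2}\,\bigl\lVert DF(m_0)\,\delta m - \bigl(d_{\mathrm{obs}} - F(m_0)\bigr) \bigr\rVert_2^2 .
```

The `torch.nn.MSELoss` used later in this example computes the mean rather than half the sum of the squared residuals, so it equals $J_{\mathrm{FWI}}$ up to a constant factor.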

We continue to use the Marmousi model, but because running many iterations of FWI is computationally expensive, we will work on only a portion of it:

import torch
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt
import deepwave
from deepwave import scalar

device = torch.device('cuda' if torch.cuda.is_available()
                      else 'cpu')
ny = 2301
nx = 751
dx = 4.0
v_true = torch.from_file('marmousi_vp.bin',
                         size=ny*nx).reshape(ny, nx)

# Select portion of model for inversion
ny = 600
nx = 250
v_true = v_true[:ny, :nx]
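The loading and slicing above can be mimicked on a tiny synthetic file. This is a sketch under two assumptions: that the model file is headerless, row-major float32 (what `torch.from_file` reads with the default dtype), and, as the source and receiver code later suggests, that the first index is horizontal position and the second is depth. NumPy's `np.fromfile` stands in for `torch.from_file`, and the file name is hypothetical:

```python
import os
import tempfile

import numpy as np

ny_full, nx_full = 23, 8   # toy stand-ins for ny=2301, nx=751
dx = 4.0                   # grid spacing in metres, as in the example

# A toy model stored as headerless, row-major float32 values
model = np.arange(ny_full * nx_full, dtype=np.float32).reshape(ny_full, nx_full)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_vp.bin")  # hypothetical file name
    model.tofile(path)
    # NumPy analogue of torch.from_file(path, size=ny_full*nx_full)
    v_full = np.fromfile(path, dtype=np.float32).reshape(ny_full, nx_full)

# Keep the first ny cells horizontally and the first nx cells in depth
ny, nx = 6, 3
v_part = v_full[:ny, :nx]

print(v_part.shape)                   # (6, 3)
print(ny * dx, "m by", nx * dx, "m")  # 24.0 m by 12.0 m
```

Applied to the real model, the same slice keeps 600 × 4 m = 2400 m horizontally and 250 × 4 m = 1000 m in depth.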

We smooth the true model to create our initial guess of the wavespeed, which we will attempt to improve with inversion. We also load the data that we generated in the forward modelling example to serve as our target observed data:

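# Smooth the true model (in slowness, 1/v) to use as the starting model
v_init = (torch.tensor(1/gaussian_filter(1/v_true.numpy(), 40))
          .to(device))
v = v_init.clone()
v.requires_grad_()  # so that PyTorch computes gradients with respect to v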

n_shots = 115

n_sources_per_shot = 1
d_source = 20  # 20 * 4m = 80m
first_source = 10  # 10 * 4m = 40m
source_depth = 2  # 2 * 4m = 8m

n_receivers_per_shot = 384
d_receiver = 6  # 6 * 4m = 24m
first_receiver = 0  # 0 * 4m = 0m
receiver_depth = 2  # 2 * 4m = 8m

freq = 25
nt = 750
dt = 0.004
peak_time = 1.5 / freq

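# Load the data generated in the forward modelling example
# (the file name 'marmousi_data.bin' is an assumption; use the name
# of the file that you saved the data to)
observed_data = (
    torch.from_file('marmousi_data.bin',
                    size=n_shots*n_receivers_per_shot*nt)
    .reshape(n_shots, n_receivers_per_shot, nt)
)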
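The starting model is created by smoothing the slowness 1/v and inverting the result, rather than by smoothing v directly. One reason this is a sensible choice (our interpretation; the text does not say why): traveltime through a layered model is the sum of dz/v, which is linear in slowness, so a normalised smoothing kernel applied to slowness preserves that traveltime, while smoothing velocity makes the model too fast. The 1D sketch below uses a hand-built Gaussian kernel with periodic boundaries as a stand-in for `scipy.ndimage.gaussian_filter`. The exact equality relies on those periodic boundaries; SciPy's default `reflect` mode only approximately preserves it:

```python
import numpy as np

def gaussian_smooth_periodic(x, sigma):
    """Smooth a 1D array with a normalised Gaussian kernel, wrapping at the ends."""
    r = int(4 * sigma)
    offsets = np.arange(-r, r + 1)
    kernel = np.exp(-offsets**2 / (2.0 * sigma**2))
    kernel /= kernel.sum()               # weights sum to one
    padded = np.pad(x, r, mode="wrap")   # periodic boundaries
    return np.convolve(padded, kernel, mode="valid")

dz = 4.0  # cell size in metres
# Two-layer toy model: slow over fast
v = np.concatenate([np.full(100, 1500.0), np.full(100, 4500.0)])

def traveltime(v):
    """One-way vertical traveltime through the whole column."""
    return np.sum(dz / v)

v_from_slowness = 1.0 / gaussian_smooth_periodic(1.0 / v, sigma=10)
v_from_velocity = gaussian_smooth_periodic(v, sigma=10)

print(traveltime(v))                # true traveltime
print(traveltime(v_from_slowness))  # the same, up to rounding
print(traveltime(v_from_velocity))  # smaller: the smoothed model is too fast
```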

As our model is now smaller, we also need to extract only the portion of the observed data that covers this section of the model:

n_shots = 20
n_receivers_per_shot = 100
nt = 300
observed_data = (
    observed_data[:n_shots, :n_receivers_per_shot, :nt].to(device)
)
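The reduced numbers can be checked by hand with the acquisition parameters above. Because the receiver positions in this example are the same for every shot, keeping the first 20 shots and the first 100 receivers places every source and receiver inside the 600-cell-wide portion, and keeping 300 time samples retains 1.2 s of recording:

```python
dx = 4.0     # grid spacing (m)
dt = 0.004   # time step (s)
ny = 600     # horizontal cells in the selected portion

n_shots, d_source, first_source = 20, 20, 10
n_receivers, d_receiver, first_receiver = 100, 6, 0
nt = 300

# Horizontal cell index of the last source and last receiver kept
last_source = (n_shots - 1) * d_source + first_source
last_receiver = (n_receivers - 1) * d_receiver + first_receiver

print(last_source, last_source * dx)      # 390 1560.0
print(last_receiver, last_receiver * dx)  # 594 2376.0
print(nt * dt)                            # record length in seconds
```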

We set up the sources and receivers as before:

# source_locations
source_locations = torch.zeros(n_shots, n_sources_per_shot, 2,
                               dtype=torch.long, device=device)
source_locations[..., 1] = source_depth
source_locations[:, 0, 0] = (torch.arange(n_shots) * d_source +
                             first_source)

# receiver_locations
receiver_locations = torch.zeros(n_shots, n_receivers_per_shot, 2,
                                 dtype=torch.long, device=device)
receiver_locations[..., 1] = receiver_depth
receiver_locations[:, :, 0] = (
    (torch.arange(n_receivers_per_shot) * d_receiver +
     first_receiver)
    .repeat(n_shots, 1)
)

# source_amplitudes
source_amplitudes = (
    (deepwave.wavelets.ricker(freq, nt, dt, peak_time))
    .repeat(n_shots, n_sources_per_shot, 1).to(device)
)
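To see the layout of these three tensors, here is a NumPy mirror of the code above. Two points deserve care. PyTorch's `Tensor.repeat` tiles its input, so it corresponds to `np.tile`, not `np.repeat`. The `ricker` function below is our own implementation of the standard Ricker formula with its peak at `peak_time`; we assume, but have not shown here, that `deepwave.wavelets.ricker` evaluates the same expression:

```python
import numpy as np

n_shots, n_sources_per_shot, n_receivers_per_shot = 20, 1, 100
d_source, first_source, source_depth = 20, 10, 2
d_receiver, first_receiver, receiver_depth = 6, 0, 2
freq, nt, dt = 25, 300, 0.004
peak_time = 1.5 / freq

# Last axis: index 0 is the horizontal cell, index 1 the depth cell
source_locations = np.zeros((n_shots, n_sources_per_shot, 2), dtype=np.int64)
source_locations[..., 1] = source_depth
source_locations[:, 0, 0] = np.arange(n_shots) * d_source + first_source

receiver_locations = np.zeros((n_shots, n_receivers_per_shot, 2), dtype=np.int64)
receiver_locations[..., 1] = receiver_depth
# Same receiver row for every shot (torch's .repeat(n_shots, 1) == np.tile)
receiver_locations[:, :, 0] = np.tile(
    np.arange(n_receivers_per_shot) * d_receiver + first_receiver, (n_shots, 1))

def ricker(freq, length, dt, peak_time):
    """Standard Ricker wavelet sampled every dt, peaking at peak_time."""
    t = np.arange(length) * dt - peak_time
    a = (np.pi * freq * t) ** 2
    return (1 - 2 * a) * np.exp(-a)

wavelet = ricker(freq, nt, dt, peak_time)
# Shape [shot, source, time], as Deepwave expects for source amplitudes
source_amplitudes = np.tile(wavelet, (n_shots, n_sources_per_shot, 1))

print(source_locations.shape)    # (20, 1, 2)
print(receiver_locations.shape)  # (20, 100, 2)
print(source_amplitudes.shape)   # (20, 1, 300)
print(np.argmax(wavelet))        # 15, i.e. peak_time / dt
```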

We are now ready to run the optimiser to perform iterative inversion of the wavespeed model. We multiply the loss by 1e10, which scales the gradient values into a range that lets us make good progress with each iteration. We also clip the gradients to the 98th percentile of their magnitude to avoid making very large changes at a small number of points (such as around the sources):

# Setup optimiser to perform inversion
optimiser = torch.optim.SGD([v], lr=0.1, momentum=0.9)
loss_fn = torch.nn.MSELoss()

# Run optimisation/inversion
n_epochs = 250
v_true = v_true.to(device)

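for epoch in range(n_epochs):
    def closure():
        optimiser.zero_grad()
        # Forward model with the current wavespeed estimate
        out = scalar(
            v, dx, dt,
            source_amplitudes=source_amplitudes,
            source_locations=source_locations,
            receiver_locations=receiver_locations,
            pml_freq=freq,
        )
        # out[-1] is the receiver data; scale the loss to boost gradients
        loss = 1e10 * loss_fn(out[-1], observed_data)
        loss.backward()
        # Clip gradients to the 98th percentile of their magnitude
        torch.nn.utils.clip_grad_value_(
            v,
            torch.quantile(v.grad.detach().abs(), 0.98)
        )
        return loss

    optimiser.step(closure)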

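The update step inside the loop can be sketched in NumPy, with a random gradient and a quadratic toy loss standing in for Deepwave output. The first part mimics the clipping: `torch.nn.utils.clip_grad_value_` clamps each gradient entry to `[-clip_value, clip_value]`, and with `clip_value` set to the 98th percentile of `|grad|`, at most about 2% of entries change. The second part follows the momentum update documented for `torch.optim.SGD` with its defaults `dampening=0` and `nesterov=False`. The helper `sgd_momentum` is hypothetical and for illustration only. Passing a closure to `optimiser.step` lets the optimiser re-evaluate the loss itself; `SGD` calls it once per step.

```python
import numpy as np

rng = np.random.default_rng(0)

# --- Clipping to the 98th percentile of |grad| ---
grad = rng.standard_normal(10_000)
grad[:5] *= 1000.0  # a few very large values, e.g. near the sources
clip_value = np.quantile(np.abs(grad), 0.98)
clipped = np.clip(grad, -clip_value, clip_value)  # what clip_grad_value_ does
print(np.mean(clipped != grad))  # fraction of entries changed, at most 0.02

# Scaling the loss (hence the gradient) by c scales the threshold by c too,
# so the same entries are clipped whatever the scale
c = 1e10
clip_scaled = np.quantile(np.abs(c * grad), 0.98)
print(np.allclose(np.clip(c * grad, -clip_scaled, clip_scaled), c * clipped))  # True

# --- SGD with momentum: buf = momentum*buf + grad; p = p - lr*buf ---
def sgd_momentum(grad_fn, p, lr=0.1, momentum=0.9, n_steps=250):
    buf = None
    for _ in range(n_steps):
        g = grad_fn(p)
        buf = g.copy() if buf is None else momentum * buf + g
        p = p - lr * buf
    return p

# Toy quadratic loss 0.5 * ||p - target||^2, whose gradient is p - target
target = np.array([1500.0, 2000.0, 3000.0])
p = sgd_momentum(lambda p: p - target, np.zeros(3))
print(np.max(np.abs(p - target)))  # small: the iteration has converged
```

Because the percentile threshold scales with the gradient, the 1e10 factor does not change which points are clipped; with plain SGD it acts like a larger learning rate.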

The result is a clear improvement in the accuracy of our estimate of the wavespeed model compared with the smoothed initial model.
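One way to put a number on this improvement (an addition of ours; the text does not quantify it) is the relative L2 model error. `relative_model_error` is a hypothetical helper, and the 2×2 arrays are made-up toy values, not results from this example:

```python
import numpy as np

def relative_model_error(v, v_true):
    """Relative L2 error ||v - v_true|| / ||v_true|| of a wavespeed model."""
    v = np.asarray(v, dtype=float)
    v_true = np.asarray(v_true, dtype=float)
    return np.linalg.norm(v - v_true) / np.linalg.norm(v_true)

# Toy models (made-up values)
v_true_toy = np.array([[1500.0, 1500.0], [2500.0, 3500.0]])
v_init_toy = np.full((2, 2), 2250.0)

print(relative_model_error(v_true_toy, v_true_toy))  # 0.0
print(relative_model_error(v_init_toy, v_true_toy))  # error of the flat guess
```

In this example it could be evaluated as `relative_model_error(v.detach().cpu().numpy(), v_true.cpu().numpy())` and compared with the same quantity for `v_init`.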


This is a simple implementation of FWI. Faster convergence and greater robustness in more realistic situations can be achieved with modifications such as a more sophisticated loss function. Some of the other examples show slightly more sophisticated setups, and Deepwave makes it easy for you to come up with your own. Because PyTorch automatically backpropagates through any differentiable operations that you apply to the output of Deepwave, you only need to specify the forward action of such a loss function; PyTorch handles the backpropagation.
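As an illustration of specifying only the forward action, here is a sketch of one alternative misfit: a trace-normalised least-squares loss that compares waveform shapes while ignoring each trace's overall amplitude. This is our choice for illustration, not a loss taken from Deepwave's other examples. PyTorch's autograd supplies the gradient, and the script checks it against a central finite difference. It uses `torch.linalg.vector_norm`, which is available in recent PyTorch releases:

```python
import torch

def normalised_trace_mse(pred, obs, eps=1e-12):
    """MSE between traces after scaling each trace (last axis) to unit L2 norm."""
    pred_n = pred / (torch.linalg.vector_norm(pred, dim=-1, keepdim=True) + eps)
    obs_n = obs / (torch.linalg.vector_norm(obs, dim=-1, keepdim=True) + eps)
    return torch.mean((pred_n - obs_n) ** 2)

torch.manual_seed(0)
# Toy data with the layout [shot, receiver, time]
obs = torch.randn(2, 3, 50, dtype=torch.float64)
pred = torch.randn(2, 3, 50, dtype=torch.float64, requires_grad=True)

loss = normalised_trace_mse(pred, obs)
loss.backward()  # only the forward pass was written; autograd does the rest

# Central finite-difference check of one gradient entry
h = 1e-6
with torch.no_grad():
    plus = pred.clone()
    plus[0, 0, 0] += h
    minus = pred.clone()
    minus[0, 0, 0] -= h
    fd = (normalised_trace_mse(plus, obs) - normalised_trace_mse(minus, obs)) / (2 * h)
print(float(pred.grad[0, 0, 0]), float(fd))  # should agree closely

# Rescaling the traces leaves the loss (almost exactly) zero
print(float(normalised_trace_mse(3.0 * obs, obs)))
```

In the FWI loop, such a function could replace `loss_fn(out[-1], observed_data)`, although the 1e10 scale factor would probably need retuning.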

Full example code