Robust optimal control on a single spin¶
In this tutorial, we’ll show you, step-by-step, how to generate a robust Gaussian pulse on a single spin. We’ll begin by simulating the system, then perform optimal control without fluctuations, and finally implement robust optimal control that accounts for pulse amplitude and detuning fluctuations.
The tutorial consists of the following:
- Simulating a spin flip
- Performing optimal control
- Performing robust optimal control
Tip: Make sure that your environment is correctly set to handle float64 precision, either by setting the environment variable JAX_ENABLE_X64=True or by adding the following code block to your script's preamble:
# Enable float64 precision
import jax
jax.config.update("jax_enable_x64", True)
1. Simulating a spin flip¶
Since we're dealing with a single spin, our Hamiltonian, $H$, is given by
$$ H = \Omega(t) \sigma_x , $$
where $\Omega(t)$ is the time-dependent Rabi frequency, and $\sigma_x$ is the Pauli-$x$ matrix.
Although the Hamiltonian contains only a time-dependent term, the qruise-toolset requires that we first define a stationary Hamiltonian. To satisfy this requirement, we'll use a zero matrix.
import qruise.toolset as qr
import qutip as qt
# stationary Hamiltonian, H0
H = qr.Hamiltonian(qt.sigmaz() * 0)
For the drive, we'll use a Gaussian pulse given by:
$$ \Omega(t; \{a,\sigma,\mu\}) = \frac{a}{\sigma\sqrt{2\pi}} \mathrm{exp}\left[-\frac{(t - \mu)^2}{2\sigma^2}\right], $$
which is characterised by three parameters:
- $a$, a scalar that adjusts the amplitude of the pulse
- $\sigma$, the pulse width (standard deviation)
- $\mu$, the time at which the pulse reaches its maximum amplitude
We can use this equation to define our drive pulse, which takes the time and pulse parameters as inputs.
import jax.numpy as jnp
def drive(t, params):
    a = params["a"]  # Gaussian pulse amplitude
    sigma = params["sigma"]  # pulse width (standard deviation)
    factor = a / jnp.sqrt(2.0 * jnp.pi) / sigma
    # the centre mu is fixed to tfinal / 2
    return factor * jnp.exp(-((t - tfinal / 2) ** 2) / (2.0 * sigma**2))
We can then add the time-dependent term to our Hamiltonian.
H.add_term(qt.sigmax(), drive)
Let's now define the time and pulse parameters and plot the drive to see how it looks.
t0 = 0.0 # start time of simulation
tfinal = 1.0 # end time of simulation
grid_size = 200 # number of time points
ts = jnp.linspace(t0, tfinal, grid_size) # time array
# Gaussian pulse parameters
params = {"a": 2.14, "sigma": 0.1 * tfinal}
canvas = qr.PlotUtil(
x_axis_label="t [a.u.]", y_axis_label="Amplitude [a.u.]", notebook=True
)
canvas.plot(ts, drive(ts, params), labels=["Pulse"])
canvas.show_canvas()
Great, our drive looks as expected.
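As an optional sanity check (a quick sketch, not part of the original workflow), we can confirm that the prefactor normalises the Gaussian so that its time integral, the pulse area, approximately equals the amplitude parameter $a$. This area is what determines the rotation angle on the Bloch sphere, and it will help us interpret the optimisation result later.
# approximate the pulse area with a simple Riemann sum
dt = ts[1] - ts[0]
area = jnp.sum(drive(ts, params)) * dt
print(area)  # close to a = 2.14 (small deviation from the truncated tails)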
Now let's start simulating our gate operation. We'll initialise the spin in the ground state, with the aim of driving it to the excited state.
from qruise.toolset.utils import computational_basis
# define initial and target states
basis, labels = computational_basis([2])
y0 = basis[0] # initial state
yt = basis[1] # target state
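If you'd like to see exactly what computational_basis returns, a quick print helps (the label format suggested in the comment is an assumption and may differ between versions):
print(labels)  # basis-state labels, e.g. something like [0, 1]
print(y0)  # state vector of the ground state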
To simulate the spin flip, we need to specify the equations and parameters that govern the system. In the qruise-toolset, we combine these using a Problem, which we instantiate with:
- the Hamiltonian (H)
- the initial qubit state (y0, here the ground state)
- the pulse parameters (params)
- the time interval of the simulation (t0 to tfinal)
The user can then define the type of solver they want to use and the equation they want to solve (for example, the Schrödinger equation, the master equation, or the Lindblad equation).
In this case, we’ll choose to use the piecewise constant solver (PWCSolver) to solve the Schrödinger equation (sepwc). We can then calculate the wavefunction of the system at each timestamp (n=grid_size).
from qruise.toolset.solvers.solvers import PWCSolver
# define Problem
prob = qr.Problem(H, y0, params, (t0, tfinal))
# use PWCSolver to solve Schrödinger equation
solver = PWCSolver(n=grid_size, store=True)
solver.set_system(prob.sepwc())
_, res = solver.evolve(*prob.problem())
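Since we set store=True, res contains the wavefunction at every timestamp. As a quick check (the exact layout is an assumption; adapt if your version differs), we can inspect its shape:
print(res.shape)  # expected along the lines of (grid_size, state_dim) = (200, 2)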
We can then calculate and plot the state populations to see how they look.
from qruise.toolset.utils import get_population
# calculate expectation value of each qubit state
population = get_population(res)
# plot populations against time
canvas.init_canvas(x_axis_label="t [a.u.]", y_axis_label="Population")
canvas.plot(ts, population, labels=[str(label) for label in labels])
canvas.show_canvas()
We can see that while the populations do invert, they partially revert before settling, suggesting our drive might be too strong. We can perform quantum optimal control to try and enhance the performance.
2. Performing optimal control¶
We start by defining a loss function, $\mathcal{L}$, to quantify how far we are from our desired target. We can use the infidelity, $1-\mathcal{F}$, where $\mathcal{F}$ is the state fidelity. The loss function is then given by
$$ \mathcal{L} = 1 - \mathcal{F} = 1 - |\langle \psi(t=t_\text{final})|\psi_t \rangle|^2, $$
where $|\psi(t=t_\text{final})\rangle$ is the wavefunction of the system at the end time of the pulse, and $|\psi_t\rangle$ is the target (desired) wavefunction. The objective is to optimise the parameters of the drive pulse by minimising $\mathcal{L}$.
We can define the loss function as follows:
def loss(x, y):
    """
    Returns the infidelity (1 - |<x|y>|^2)
    of two wavefunctions x and y.
    """
    # calculate the inner product (overlap) of the two wavefunctions
    o = jnp.matmul(x.conj().T, y)
    # return the real part of the infidelity
    return jnp.real(1.0 - o.conj() * o)
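As a quick usage check (using the y0 and yt states defined above), identical states should give zero infidelity and orthogonal states an infidelity of one:
print(loss(y0, y0))  # expected: ~0.0
print(loss(y0, yt))  # expected: ~1.0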
We then define a quantum optimal control problem (QOCProblem). Similar to the definition of a Problem for simulating the dynamics of a quantum system, a QOCProblem further defines the optimisation protocol. As inputs, it takes:
- the Hamiltonian, H
- the initial qubit state, y0
- the pulse parameters, params
- the time interval of interest, t0 to tfinal
- the desired qubit state, yt, i.e., the 1st excited state
- the loss function, loss
from qruise.toolset import QOCProblem
opt_prob = QOCProblem(H, y0, params, (t0, tfinal), yt, loss)
The whole workflow now reduces to solving the optimal control problem using our initial guess for the pulse parameters. At each step, the system is simulated with the current parameters to obtain the final wavefunction and compute the infidelity with respect to the target state. Based on the gradient values, the parameter values are updated iteratively until they converge and minimise the loss function.
The Optimiser class is provided for this specific task. As a user, you need to specify both the solver used to simulate the dynamics and the optimisation algorithm itself. In this case, we again use a piecewise constant solver (PWCSolver), together with the Broyden–Fletcher–Goldfarb–Shanno (BFGS) algorithm provided by the OptimistixMinimiser class from the qruise-toolset.
import optimistix as optx
from qruise.toolset import OptimistixMinimiser, Optimiser
# set up optimiser
minimiser = OptimistixMinimiser(optx.BFGS, rtol=1e-2, atol=1e-3)
opt = Optimiser(minimiser, PWCSolver(n=grid_size, eq=prob.sepwc()))
# Set the loss function for the optimiser to minimise (infidelity)
opt.set_optimisation(opt_prob.loss)
# run optimisation on problem
result, opt_summary = opt.optimise(*opt_prob.problem())
We can then plot the optimised Gaussian alongside the original to see how it's changed.
from pprint import pprint
canvas = qr.PlotUtil(x_axis_label="t [a.u.]", y_axis_label="Amplitude [a.u.]", notebook=True)
# Print initial (unoptimised) and optimised parameters
pprint(f"Unoptimised {params}")
pprint(f"Optimised {result}")
# Plot initial (unoptimised) and optimised parameters
canvas.plot(ts, drive(ts, params), labels=["Unoptimised"])
canvas.plot(ts, drive(ts, result), labels=["Optimised"])
canvas.show_canvas()
"Unoptimised {'a': 2.14, 'sigma': 0.1}" ("Optimised {'a': Array(1.57079877, dtype=float64), 'sigma': Array(0.1002266, " 'dtype=float64)}')
As we expected, the optimised pulse has a reduced amplitude (and a very slightly larger width) than the unoptimised one. In fact, the optimised amplitude $a \approx 1.5708 \approx \pi/2$ is exactly what the pulse-area condition predicts: the Gaussian is normalised so that $\int \Omega(t)\,dt = a$, the drive $H = \Omega(t)\sigma_x$ then implements a rotation by an angle $2a$ about the $x$-axis, and a full spin flip requires $2a = \pi$. Let's now simulate and plot the state populations to see how the new pulse parameters affect them.
# check the simulation with optimised parameters
_, res_opt = solver.evolve(*prob.problem(params=result))
# get the state populations after optimisation
opt_pop = get_population(res_opt)
# plot unoptimised and optimised populations on same graph
canvas.init_canvas(x_axis_label="t [a.u.]", y_axis_label="Population")
canvas.plot(
    ts, population, labels=["Unopt. ground state", "Unopt. 1st excited state"]
)
canvas.plot(ts, opt_pop, labels=["Opt. ground state", "Opt. 1st excited state"])
canvas.show_canvas()
We can see that the population exchange between the ground and first excited states is now perfect - our optimisation was successful!
We can quantify this success by printing the fidelity before and after optimisation. (Remember: we defined our loss function as the infidelity = 1 - fidelity).
fidelity = [1 - loss(res[-1, :], yt), 1 - loss(res_opt[-1, :], yt)]
pprint(f"Unoptimised: {fidelity[0]*100:.10f} %")
pprint(f"Optimised: {fidelity[1]*100:.10f} %")
'Unoptimised: 70.9521676090 %' 'Optimised: 99.9999999998 %'
3. Performing robust optimal control¶
Okay, now that we know how to perform optimal control, we can investigate how robust our pulse is to fluctuations. Specifically, we'll model fluctuations in the detuning offset and in the drive pulse amplitude. By considering a range of these parameters simultaneously, we can design pulses that maintain high fidelity under realistic experimental imperfections.
We can modify our Hamiltonian accordingly:
$$ H = \Delta \sigma_z + \Omega(t) \sigma_x , $$
where $\Delta$ models the detuning offset and $\Omega(t)$ includes a scaling parameter for pulse amplitude variations.
We can then define the new drive and the detuning offset as follows:
def drive(t, params):
    a = params["a"]
    scale = params["scale"]  # scaling factor for amplitude variations
    sigma = params["sigma"]
    factor = a / jnp.sqrt(2.0 * jnp.pi) / sigma
    return scale * factor * jnp.exp(-((t - tfinal / 2) ** 2) / (2.0 * sigma**2))

def offset(t, params):
    return params["offset"]  # detuning offset
Then we can define our new Hamiltonian.
H = qr.Hamiltonian(qt.sigmaz() * 0)
H.add_term(qt.sigmax(), drive)
H.add_term(qt.sigmaz(), offset)
Now we can set up our parameters to optimise over. We'll use the optimised values of $a$ and $\sigma$ that we obtained in the last section, and we'll define ranges for the amplitude scaling factor (applied to $\Omega(t)$) and for the detuning offset ($\Delta$).
ensemble_params = {
"a": result["a"],
"sigma": result["sigma"],
"offset": list(jnp.linspace(-5, 5, 10)),
"scale": list(jnp.linspace(0.8, 1.2, 11)),
}
This time, we’ll use EnsembleProblem instead of Problem because we’re dealing with an ensemble of values for the detuning offset and the amplitude scale (here a 10 × 11 Cartesian grid, so 110 simulations), though the overall process remains largely the same.
from qruise.toolset import EnsembleProblem
# define EnsembleProblem
ens = EnsembleProblem(H, y0, ensemble_params, (t0, tfinal))
# solve EnsembleProblem using PWCSolver
solver = PWCSolver(n=grid_size, store=False)
solver.set_system(ens.sepwc())
_, res = solver.ensemble_evolve(*ens.problem(), cartesian=True)
# calculate expectation value of each qubit state for each value of Delta and Omega
population = get_population(res)
We can visualise the population for our initial parameters by plotting a heatmap. We install and use the matplotlib library for this purpose.
%pip install -q matplotlib
import matplotlib.pyplot as plt
f, ax = plt.subplots()
c = ax.imshow(
population[:, :, 1],
cmap="viridis",
aspect="auto",
extent=[
ensemble_params["scale"][0],
ensemble_params["scale"][-1],
ensemble_params["offset"][0],
ensemble_params["offset"][-1],
],
vmin=0,
vmax=1,
) # [scale_min, scale_max, offset_min, offset_max]
ax.set_xlabel("Amplitude scaling factor")
ax.set_ylabel("Detuning offset")
ax.set_title("Population")
plt.colorbar(c)
plt.show()
We can see that the pulse is already fairly robust to fluctuations in the drive amplitude and the detuning offset, but can we do better?
We can write a custom ensemble loss function by combining the losses of individual ensemble members. If needed, this approach could be parallelised to improve performance.
def ensemble_loss(x):
    # update the scalar pulse parameters with the current optimisation values
    ensemble_params["a"] = jnp.array([x[0]])
    ensemble_params["sigma"] = jnp.array([x[1]])
    ens = EnsembleProblem(H, y0, ensemble_params, (t0, tfinal))
    solver = PWCSolver(n=grid_size, store=False)
    solver.set_system(ens.sepwc())
    _, res = solver.ensemble_evolve(*ens.problem(), cartesian=True)
    # sum the infidelities of all ensemble members
    err = 0.0
    for i in range(res.shape[0]):
        for j in range(res.shape[1]):
            err += loss(yt, res[i, j])
    return err
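The nested loop above is simple but scales with the ensemble size. As a minimal sketch of the parallelisation mentioned above (assuming res keeps the two leading ensemble axes used by the loop, with the state vector on the last axis), the per-member losses could instead be vectorised with jax.vmap:
def ensemble_loss_vmapped(res):
    # evaluate the single-pair loss across both ensemble axes at once
    per_member = jax.vmap(jax.vmap(lambda r: loss(yt, r)))(res)
    return jnp.sum(per_member)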
We can then use SciPy to do the optimisation.
import numpy as np
from scipy.optimize import minimize
value_and_grad = jax.value_and_grad(ensemble_loss)
def scipy_loss_grad(x):
    # SciPy expects a plain float loss and a NumPy array gradient
    value, grad = value_and_grad(x)
    return float(value), np.array(grad)
x0 = jnp.array([2.14, 0.1 * tfinal])
result = minimize(
fun=scipy_loss_grad,
x0=x0,
method="L-BFGS-B",
jac=True,
options={"disp": True, "maxiter": 15, "gtol": 1e-3, "ftol": 1e-3},
)
optimised_params = result.x
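We can print the robust-optimised parameters for reference (the exact values will vary slightly between runs):
# inspect the robust-optimised pulse parameters
pprint(f"Robust optimised: a = {optimised_params[0]:.6f}, sigma = {optimised_params[1]:.6f}")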
Now let's again solve the EnsembleProblem to see if our pulse is now more robust.
opt_ens_params = {
"a": optimised_params[0],
"sigma": optimised_params[1],
"offset": list(jnp.linspace(-5, 5, 10)),
"scale": list(jnp.linspace(0.8, 1.2, 11)),
}
# define EnsembleProblem
ens = EnsembleProblem(H, y0, opt_ens_params, (t0, tfinal))
# solve EnsembleProblem using PWCSolver
solver = PWCSolver(n=grid_size, store=False)
solver.set_system(ens.sepwc())
_, res = solver.ensemble_evolve(*ens.problem(), cartesian=True)
# calculate expectation value of each qubit state for each value of Delta and Omega
population = get_population(res)
We can visualise the effect of our optimisation by plotting another heatmap.
f, ax = plt.subplots()
c = ax.imshow(
population[:, :, 1],
cmap="viridis",
aspect="auto",
extent=[
opt_ens_params["scale"][0],
opt_ens_params["scale"][-1],
opt_ens_params["offset"][0],
opt_ens_params["offset"][-1],
],
vmin=0,
vmax=1,
) # [scale_min, scale_max, offset_min, offset_max]
ax.set_xlabel("Amplitude scaling factor")
ax.set_ylabel("Detuning offset")
ax.set_title("Population")
plt.colorbar(c)
plt.show()
Excellent! You can see that the excited state population stays close to 1 over a much wider range of detuning offsets and pulse amplitudes, demonstrating that our pulse is significantly more robust than before.