Simulation and robust optimal control of a single NV centre¶
In this tutorial, we’ll show you, step-by-step, how to simulate the dynamics of a single NV centre in the presence of inhomogeneities in detuning and microwave field power. We'll then show you how to perform robust optimal control using the qruise-toolset.
The tutorial consists of the following:
- Defining the Hamiltonian and the drive
- Simulating an NV spin with a distribution of inhomogeneities
- Performing robust optimal control
Note: We encourage the reader to first refer to the tutorial on the Simulation of a single NV centre qubit before reading this tutorial.
1. Defining the Hamiltonian and the drive¶
The system we are interested in here is an NV centre. Its Hamiltonian, $H$, is given by:
$$ H = \underbrace{D S_z^2}_{\text{zero-field splitting}} + \underbrace{\gamma_e B_0 S_z}_{\text{Zeeman interaction}} + \underbrace{\gamma_e B_1(t) \left[\cos(\omega t + \phi) S_x + \sin(\omega t + \phi) S_y\right]}_{\text{driving field interaction}}, \tag{1} $$
where $S_{x,y,z}$ are the $S=1$ spin operators.
Please refer to sections 1 and 2 in our tutorial on Simulation of a single NV centre qubit for a detailed treatment of the theory and implementation of an NV qubit Hamiltonian.
After some simplification and taking into account the effect of noise, our final Hamiltonian is given by
$$ H = \delta \frac{\sigma_z}{2} + (1+ \beta)\Omega(t)\frac{\sigma_x}{2}, \tag{5} $$
where $\delta$ accounts for inhomogeneities in the static magnetic field ($B_0$) and hyperfine interactions between the NV centre and nearby nuclear spins, and $\beta$ represents the variation in Rabi rate due to inhomogeneities in the applied microwave field ($B_1$).
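To make Eq. (5) concrete, here is a minimal NumPy sketch that builds the 2×2 matrix at a single instant and checks that it is Hermitian with eigenvalues $\pm\frac{1}{2}\sqrt{\delta^2 + [(1+\beta)\Omega]^2}$. The helper name `qubit_hamiltonian` and the numerical values are illustrative assumptions and are not part of the qruise-toolset.

```python
import numpy as np

# Pauli matrices
SIGMA_X = np.array([[0, 1], [1, 0]], dtype=complex)
SIGMA_Z = np.array([[1, 0], [0, -1]], dtype=complex)


def qubit_hamiltonian(delta, beta, omega):
    """Eq. (5) at one instant: delta * sz / 2 + (1 + beta) * omega * sx / 2."""
    return 0.5 * delta * SIGMA_Z + 0.5 * (1.0 + beta) * omega * SIGMA_X


# arbitrary illustrative values (delta and omega in the same units)
delta, beta, omega = 2.0, 0.1, 3.0
h = qubit_hamiltonian(delta, beta, omega)

# eigenvalues in ascending order, compared with the closed-form expression
evals = np.linalg.eigvalsh(h)
half_rabi = 0.5 * np.hypot(delta, (1.0 + beta) * omega)
print(evals, half_rabi)
```

Both $\delta$ and $\beta$ enter the eigenvalue splitting, which is why a pulse tuned for $\delta = 0$, $\beta = 0$ can under- or over-rotate elsewhere in the ensemble.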
We can now start coding this Hamiltonian using qruise-toolset.
Tip: Make sure that your environment is correctly set to handle float64 precision by setting JAX_ENABLE_X64=True or by adding the following code block to your script's preamble:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
# define Gaussian drive (centred at mu = tfinal / 2; tfinal is defined below)
def drive(t, params):
    a = params["a"]
    # "beta" holds the full amplitude scale factor, i.e. (1 + beta) in Eq. (5)
    beta = params["beta"]
    sigma = params["sigma"]
    factor = a / jnp.sqrt(2.0 * jnp.pi) / sigma
    return beta * factor * jnp.exp(-((t - tfinal / 2) ** 2) / (2.0 * sigma**2))


# define detuning (delta is given in MHz and converted to Hz)
def detuning(t, params):
    return params["delta"] * 1e6
# time parameters for simulation
t0 = 0.0 # initial time (s)
tfinal = 100e-9 # final time (s)
grid_size = 1000 # number of time points
ts = jnp.linspace(t0, tfinal, grid_size)
ensemble_params = {
"a": 3.229, # (Hz)
"sigma": 2.82e-08, # (s)
"delta": list(jnp.linspace(-40, 40, 21)), # (MHz)
"beta": list(jnp.linspace(0.8, 1.2, 21)),
}
Okay, now let's plot our drive (choosing an intermediate value for $\beta$) to see if it looks how we'd expect.
from qruise.toolset.plots import PlotUtil
canvas = PlotUtil(x_axis_label="t [s]", y_axis_label="Amplitude", notebook=True)
canvas.plot(
ts, drive(ts, {"a": 3.229, "sigma": 2.82e-8, "beta": 1.0}), labels=["Pulse"]
)
canvas.show_canvas()
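Because the Gaussian envelope is normalised, the parameter a would equal the integral of the pulse over an infinite time window. On the finite window from 0 to tfinal the tails are cut off, so the delivered area is somewhat smaller. The sketch below, written in plain NumPy with the scale factor set to 1, compares a hand-written trapezoidal integral with the closed-form value for a symmetrically truncated Gaussian. The variable names are illustrative assumptions.

```python
import math

import numpy as np

a, sigma, t_final = 3.229, 2.82e-8, 100e-9
ts = np.linspace(0.0, t_final, 1000)

# same Gaussian envelope as drive(), with the amplitude scale factor set to 1
envelope = a / (math.sqrt(2.0 * math.pi) * sigma) * np.exp(
    -((ts - t_final / 2) ** 2) / (2.0 * sigma**2)
)

# trapezoidal rule written out by hand
numeric_area = float(np.sum(0.5 * (envelope[1:] + envelope[:-1]) * np.diff(ts)))

# closed form for a Gaussian truncated symmetrically at +/- t_final / 2
analytic_area = a * math.erf(t_final / (2.0 * math.sqrt(2.0) * sigma))

print(numeric_area, analytic_area)
```

The two numbers should agree closely and both lie below a, which is worth keeping in mind when interpreting the optimised value of a later on.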
2. Simulating an NV spin with a distribution of inhomogeneities¶
We’ll define the Hamiltonian using Qruise’s Hamiltonian function, setting the first input (the stationary term) to None and passing the drive term as the second input. We'll then use H.add_term to include the effect of detuning variations.
from qruise.toolset import Hamiltonian
from qutip import sigmax, sigmaz
H = Hamiltonian(None, [(sigmax() / 2, drive)])
H.add_term(sigmaz() / 2, detuning)
To solve the Schrödinger equation, we need to specify the equations and parameters that govern the system. In the qruise-toolset, these are usually combined using a Problem object. Since we're working with an ensemble of values for $\delta$ and $\beta$, we’ll use EnsembleProblem instead. To instantiate it, we need:
- the Hamiltonian (H)
- the initial qubit state (y0, here the ground state)
- the pulse parameters (ensemble_params)
- the time interval of the simulation (t0 to tfinal).
from qruise.toolset import EnsembleProblem
# define initial qubit state
y0 = jnp.array([1.0, 0.0], dtype=jnp.complex128)
# define ensemble problem
ens_prob = EnsembleProblem(H, y0, ensemble_params, (t0, tfinal))
The user can then define the type of solver they want to use and the equation they want to solve (for example, the Schrödinger, master, or Lindblad equation).
In this case, we’ll choose to use the piecewise constant solver (PWCSolver) to solve the Schrödinger equation (sepwc). We can then calculate the wavefunction of the system at each timestamp.
from qruise.toolset import PWCSolver
solver = PWCSolver(n=grid_size, store=True)
solver.set_system(ens_prob.sepwc())
_, unopt_res = solver.ensemble_evolve(*ens_prob.problem(), cartesian=True)
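To build some intuition for what a piecewise constant solver does, here is a minimal NumPy sketch for a single ensemble member. It is not how PWCSolver is implemented. It assumes $\hbar = 1$, treats the Hamiltonian coefficients directly as angular frequencies, and samples the drive at the midpoint of each slice. The function names `gaussian_drive` and `pwc_evolve` are hypothetical.

```python
import numpy as np

SIGMA_X = np.array([[0, 1], [1, 0]], dtype=complex)
SIGMA_Z = np.array([[1, 0], [0, -1]], dtype=complex)


def gaussian_drive(t, a, sigma, scale, t_final):
    """Gaussian envelope centred at t_final / 2, times an amplitude scale factor."""
    factor = a / (np.sqrt(2.0 * np.pi) * sigma)
    return scale * factor * np.exp(-((t - t_final / 2) ** 2) / (2.0 * sigma**2))


def pwc_evolve(psi0, delta, a, sigma, scale, t_final, n_slices):
    """Propagate psi0 through n_slices intervals on which H is held constant."""
    dt = t_final / n_slices
    t_mid = (np.arange(n_slices) + 0.5) * dt  # sample the drive at slice midpoints
    psi = np.asarray(psi0, dtype=complex)
    for omega in gaussian_drive(t_mid, a, sigma, scale, t_final):
        h = 0.5 * (delta * SIGMA_Z + omega * SIGMA_X)
        # exact propagator exp(-i h dt) for this slice via eigendecomposition
        evals, evecs = np.linalg.eigh(h)
        u = evecs @ np.diag(np.exp(-1j * evals * dt)) @ evecs.conj().T
        psi = u @ psi
    return psi


# resonant case (delta = 0) with the tutorial's initial pulse parameters
psi_final = pwc_evolve([1.0, 0.0], 0.0, 3.229, 2.82e-8, 1.0, 100e-9, 1000)
print(np.abs(psi_final) ** 2)  # final ground- and excited-state populations
```

On resonance all slice Hamiltonians commute, so the final excited-state population reduces to $\sin^2(\theta/2)$, where $\theta$ is the summed pulse area. This gives a convenient check of the propagation loop.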
We can view the results of our simulation by calculating and plotting the expectation values of the NV centre state populations over time, averaged over $\delta$ and $\beta$.
from qruise.toolset.utils import get_population
from qruise.toolset.plots import PlotUtil
# calculate the expectation value of each qubit state for each value of delta and beta
unopt_pop = get_population(unopt_res)
# calculate the average population of each qubit state
average_unopt_pop = jnp.mean(unopt_pop, axis=(0, 1))
# plot populations against time
canvas = PlotUtil(
x_axis_label="t [s]", y_axis_label="Population [Unoptimised]", notebook=True
)
canvas.plot(ts, average_unopt_pop, labels=["Ground state", "Excited state"])
canvas.show_canvas()
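The averaging step can be illustrated on synthetic data. The sketch below assumes that the ensemble result is laid out as (n_delta, n_beta, n_time, n_states), which is inferred from how the arrays are indexed in this tutorial, and that a population is the squared magnitude of an amplitude. It uses random normalised states, not the output of get_population.

```python
import numpy as np

rng = np.random.default_rng(0)
n_delta, n_beta, n_t = 3, 4, 5

# synthetic normalised two-level states, shape (n_delta, n_beta, n_t, 2)
shape = (n_delta, n_beta, n_t, 2)
states = rng.normal(size=shape) + 1j * rng.normal(size=shape)
states /= np.linalg.norm(states, axis=-1, keepdims=True)

# population of each basis state, then the average over the two ensemble axes
populations = np.abs(states) ** 2
average_pop = populations.mean(axis=(0, 1))

print(average_pop.shape)
```

Averaging over axes 0 and 1 leaves one row per time point, and each row still sums to 1 because every individual state is normalised.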
Let's also plot a heatmap of the final population of the excited state ($|1\rangle$) as a function of $\delta$ and $\beta$. We install and use the matplotlib library for this purpose.
%pip install -q matplotlib
import numpy as np
import matplotlib.pyplot as plt
# select data to plot
beta_vals = ensemble_params["beta"]
delta_vals = ensemble_params["delta"]
Z = unopt_pop[:, :, -1, 1] # first excited state at tfinal
# create figure and imshow plot
fig, ax = plt.subplots()
im = ax.imshow(
Z,
extent=(min(beta_vals), max(beta_vals), min(delta_vals), max(delta_vals)),
origin="lower",
cmap="viridis",
vmin=0,
vmax=1,
aspect="auto",
)
# add colour bar
cbar = plt.colorbar(im, ax=ax)
cbar.set_label("Population")
# add labels and title
ax.set_xlabel(r"$\beta$")
ax.set_ylabel(r"$\delta$ [MHz]")
ax.set_title("Excited state population")
# Show plot
plt.show()
As you can see from the heatmap, the current pulse is only effective over a restricted range of $\delta$ and $\beta$ values. This might be acceptable in the absence of fluctuations, but real systems typically exhibit some degree of variation.
Now that we know how to simulate a spin in an NV centre with inhomogeneities, we can perform optimal control to enhance the fidelity of the operation across a broader range of $\delta$ and $\beta$ values.
3. Performing robust optimal control¶
To implement an X gate on our qubit, we need to optimise the drive pulse parameters such that, at the end of the simulation, the population of the ground and first excited states is fully inverted for all the values of $\delta$ and $\beta$, i.e. the entire population ends in the excited state.
We start by defining a loss function, $\mathcal{L}$, to quantify how far we are from our desired target. We can use the infidelity, which is given by $1-\mathcal{F}$, with $\mathcal{F}$ being the fidelity. The loss function is then given by
$$ \mathcal{L} = 1 - \mathcal{F} = 1 - |\langle \psi(t=t_\text{final})|\psi_t \rangle|^2, \tag{7} $$
where $|\psi(t=t_\text{final})\rangle$ is the wavefunction of the system at the end of the pulse duration, and $|\psi_t\rangle$ is the target (desired) wavefunction. The objective is to optimise the parameters of the drive pulse by minimising $\mathcal{L}$.
We can define the loss function as follows:
def loss(x, y):
    """
    Returns the infidelity (1 - |<x|y>|^2)
    of two wavefunctions x and y.
    """
    # inner product (overlap) of the two wavefunctions
    o = jnp.matmul(x.conj().T, y)
    # |o|^2 is real by construction; jnp.real drops the zero imaginary part
    return jnp.real(1.0 - o.conj() * o)
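A few limiting cases help to check that the infidelity behaves as expected. The sketch below re-implements the same formula in plain NumPy so that it runs on its own. The name `infidelity` is an illustrative assumption.

```python
import numpy as np


def infidelity(x, y):
    """1 - |<x|y>|^2 for two normalised state vectors."""
    overlap = np.vdot(x, y)  # np.vdot conjugates its first argument
    return float(1.0 - abs(overlap) ** 2)


ground = np.array([1.0, 0.0], dtype=complex)
excited = np.array([0.0, 1.0], dtype=complex)
superposition = (ground + excited) / np.sqrt(2.0)

print(infidelity(excited, excited))                       # identical states
print(infidelity(ground, excited))                        # orthogonal states
print(infidelity(excited, np.exp(1j * 0.7) * excited))    # global phase only
print(infidelity(superposition, excited))                 # equal superposition
```

Identical states give 0, orthogonal states give 1, and a global phase leaves the loss unchanged, so the optimiser is not penalised for an unobservable phase on the final state.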
The whole workflow now reduces to simulating the dynamics using the initial guess values for the parameters we defined earlier. At the end of the simulation, we get the new wavefunction and calculate the infidelity with respect to the target wavefunction. We then calculate the gradients of the loss function for the initial value of the parameters. Based on the gradient values, we update the parameter values iteratively until they converge and minimise the loss function.
import numpy as np

yt = jnp.array([0.0, 1.0], dtype=jnp.complex128)  # define desired state


# define ensemble loss function
def ensemble_loss(x):
    # copy the ensemble parameters and overwrite the two we optimise,
    # so the global ensemble_params dict is not modified during tracing
    params = dict(ensemble_params)
    params["a"] = jnp.array([x[0]])
    params["sigma"] = jnp.array([x[1]])
    # define ensemble problem
    ens = EnsembleProblem(H, y0, params, (t0, tfinal))
    solver = PWCSolver(n=grid_size, store=False)
    solver.set_system(ens.sepwc())
    _, res = solver.ensemble_evolve(*ens.problem(), cartesian=True)
    # sum the infidelity over every (delta, beta) pair in the ensemble
    err = 0.0
    for i in range(res.shape[0]):
        for j in range(res.shape[1]):
            err += loss(yt, res[i, j])
    return err


value_and_grad = jax.value_and_grad(ensemble_loss)


# wrap for scipy: return a Python float and a NumPy gradient array
def scipy_loss_grad(x):
    value, grad = value_and_grad(x)
    return float(value), np.array(grad)
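The update loop described above can be illustrated on a toy problem that needs no solver. For a resonant pulse ($\delta = 0$) with area $a$ and amplitude scale factor $s$, the infidelity of the population inversion is $\cos^2(s a / 2)$. The sketch below averages this over the same scale factors as the ensemble and minimises it with plain gradient descent and an analytic gradient. This is a deliberate simplification of the L-BFGS-B optimisation with JAX gradients used in this tutorial, and the function name and step size are assumptions.

```python
import numpy as np

# amplitude scale factors, matching the range used for "beta" in the ensemble
scales = np.linspace(0.8, 1.2, 21)


def toy_loss_and_grad(a):
    """Ensemble-averaged infidelity of a resonant pulse of area a, and d(loss)/da."""
    loss = np.mean(np.cos(scales * a / 2.0) ** 2)
    grad = np.mean(-0.5 * scales * np.sin(scales * a))
    return loss, grad


a = 2.5  # initial guess for the pulse area
initial_loss, _ = toy_loss_and_grad(a)

# plain gradient descent with a fixed step size of 1.0
for _ in range(200):
    _, grad = toy_loss_and_grad(a)
    a -= 1.0 * grad

final_loss, final_grad = toy_loss_and_grad(a)
print(a, initial_loss, final_loss)
```

The loss decreases but does not reach zero: a single amplitude parameter cannot make the rotation exact for every scale factor at once, so the optimiser settles on the best compromise across the ensemble.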
Now let's carry out the optimisation using scipy.optimize:
from scipy.optimize import minimize
x0 = np.array([ensemble_params["a"], ensemble_params["sigma"]]) # initial guess
result = minimize(
fun=scipy_loss_grad,
x0=x0,
method="L-BFGS-B",
jac=True,
bounds=[(2.0, 4), (10e-9, tfinal)],
options={"disp": True, "maxiter": 1000, "gtol": 1e-4, "ftol": 1e-4},
)
optimised_params = result.x
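One practical point: the two optimisation variables differ by about eight orders of magnitude (a is of order 1, sigma of order 1e-8), and gradient-based optimisers can be sensitive to such scaling. A possible remedy, sketched below under stated assumptions, is to optimise in rescaled variables and apply the chain rule to the gradient. Here `toy_loss_grad` is a hypothetical quadratic stand-in for scipy_loss_grad, and `scaled` is an illustrative helper, not a SciPy or Qruise function.

```python
import numpy as np

# physical value = scaled value * SCALE, so sigma is optimised in nanoseconds
SCALE = np.array([1.0, 1e-9])


def toy_loss_grad(x):
    """Hypothetical stand-in for scipy_loss_grad: a quadratic bowl in (a, sigma)."""
    target = np.array([3.1, 25e-9])
    diff = (x - target) / SCALE
    return float(np.sum(diff**2)), 2.0 * diff / SCALE


def scaled(fun):
    """Wrap a (value, grad) function so that it accepts rescaled variables z."""

    def wrapper(z):
        value, grad = fun(np.asarray(z) * SCALE)
        # chain rule: dL/dz = dL/dx * dx/dz, with dx/dz = SCALE
        return value, grad * SCALE

    return wrapper


value, grad = scaled(toy_loss_grad)(np.array([3.0, 20.0]))
print(value, grad)
```

With this wrapper, the initial guess and the bounds would also need to be expressed in the rescaled units, and the result multiplied by SCALE afterwards.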
Let's plot the unoptimised and optimised pulses side-by-side to compare.
from qruise.toolset.plots import PlotUtil
canvas = PlotUtil(x_axis_label="t [s]", y_axis_label="Amplitude", notebook=True)
canvas.plot(
ts,
drive(ts, {"a": 3.229, "sigma": 2.82e-8, "beta": 1.0}),
labels=["Unopt. pulse"],
)
canvas.plot(
ts,
drive(ts, {"a": optimised_params[0], "sigma": optimised_params[1], "beta": 1.0}),
labels=["Opt. pulse"],
)
canvas.show_canvas()
We can see that the optimised pulse has a larger amplitude and smaller variance than the unoptimised pulse. Let’s now simulate and plot the state populations to see how the new pulse parameters affect them.
opt_ens_params = {
"a": optimised_params[0],
"sigma": optimised_params[1],
"delta": ensemble_params["delta"],
"beta": ensemble_params["beta"],
}
ens_prob = EnsembleProblem(H, y0, opt_ens_params, (t0, tfinal))
solver = PWCSolver(n=grid_size, store=True)
solver.set_system(ens_prob.sepwc())
_, opt_res = solver.ensemble_evolve(*ens_prob.problem(), cartesian=True)
# calculate the expectation value of each qubit state for each value of delta and beta
opt_pop = get_population(opt_res)
# calculate the average population of each qubit state
average_opt_pop = jnp.mean(opt_pop, axis=(0, 1))
# plots populations against time
canvas.init_canvas(x_axis_label="t [s]", y_axis_label="Population")
canvas.plot(
ts, average_unopt_pop, labels=["Unopt. ground state", "Unopt. 1st excited state"]
)
canvas.plot(ts, average_opt_pop, labels=["Opt. ground state", "Opt. 1st excited state"])
canvas.show_canvas()
We can clearly see that the final state populations (averaged over all values of $\delta$ and $\beta$) are significantly closer to a full inversion than before optimal control. We can quantify this improvement by printing the fidelity of the ensemble-averaged final state before and after optimisation. (Remember: we defined our loss function as the infidelity, i.e. 1 - fidelity.)
from pprint import pprint
fidelity = [
1 - loss(jnp.mean(unopt_res, axis=(0, 1))[-1, :], yt),
1 - loss(jnp.mean(opt_res, axis=(0, 1))[-1, :], yt),
]
pprint(f"Unoptimised: {fidelity[0]*100:.5f} %")
pprint(f"Optimised: {fidelity[1]*100:.5f} %")
'Unoptimised: 67.24266 %'
'Optimised: 90.47101 %'
We can see that our optimisation has significantly enhanced the fidelity!
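When interpreting these numbers, note that the fidelity of an ensemble-averaged wavefunction is sensitive to the relative phases of the ensemble members, so it is generally not the same quantity as the mean of the individual fidelities. The sketch below shows the difference on two synthetic states. They are hypothetical and are not taken from the simulation results.

```python
import numpy as np

target = np.array([0.0, 1.0], dtype=complex)


def fidelity(state, target_state):
    """|<target|state>|^2 for a single state vector."""
    return float(abs(np.vdot(target_state, state)) ** 2)


# two hypothetical ensemble members, both fully in |1> but with opposite phases
members = np.array([[0.0, 1.0], [0.0, -1.0]], dtype=complex)

# average the wavefunctions first, then compute the fidelity
fidelity_of_mean_state = fidelity(members.mean(axis=0), target)

# compute each member's fidelity first, then average
mean_of_fidelities = float(np.mean([fidelity(m, target) for m in members]))

print(fidelity_of_mean_state, mean_of_fidelities)
```

In this extreme case the averaged state vanishes even though every member is exactly in the target state up to a phase. If only the population transfer matters, the mean of the individual fidelities, or the averaged populations plotted in the heatmaps, may be the more informative figure.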
We can also plot a heatmap of the final populations of the first excited state as a function of $\delta$ and $\beta$.
Z1 = unopt_pop[:, :, -1, 1] # unoptimised population data
Z2 = opt_pop[:, :, -1, 1] # optimised population data
# create figure and axes
fig, axes = plt.subplots(1, 2, figsize=(12, 5), constrained_layout=True)
# plot unoptimised population using imshow
im1 = axes[0].imshow(
Z1,
extent=(min(beta_vals), max(beta_vals), min(delta_vals), max(delta_vals)),
origin="lower",
cmap="viridis",
vmin=0,
vmax=1,
aspect="auto",
)
cbar1 = fig.colorbar(im1, ax=axes[0])
axes[0].set_title("Unoptimised Population")
axes[0].set_xlabel(r"$\beta$")
axes[0].set_ylabel(r"$\delta$ [MHz]")
# plot optimised population using imshow
im2 = axes[1].imshow(
Z2,
extent=(min(beta_vals), max(beta_vals), min(delta_vals), max(delta_vals)),
origin="lower",
cmap="viridis",
vmin=0,
vmax=1,
aspect="auto",
)
cbar2 = fig.colorbar(im2, ax=axes[1])
axes[1].set_title("Optimised Population")
axes[1].set_xlabel(r"$\beta$")
axes[1].set_ylabel(r"$\delta$ [MHz]")
# Show plot
plt.show()
As you can clearly see from the heatmaps, the population of the first excited state is significantly higher and more uniform across the variation in $\delta$ and $\beta$, demonstrating that our control pulse is now much more robust.