{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Simulation & optimal control of a single spin qubit\n",
    "\n",
    "In this tutorial, we’ll show you, step-by-step, how to simulate the dynamics of a single spin qubit and perform optimal control using the `qruise-toolset`. \n",
    "\n",
    "The tutorial consists of the following:\n",
    "\n",
    "1. Defining the Hamiltonian\n",
    "2. Defining the time-dependent Hamiltonian with external drive\n",
    "3. Solving the Schrödinger equation\n",
    "4. Plotting the qubit dynamics\n",
    "5. Optimal control\n",
    "\n",
    "## 1. Defining the Hamiltonian\n",
    "\n",
    "The system we are interested in here is a single electron in a single quantum dot. In this tutorial, spin-up corresponds to the ground state, i.e., $|0\\rangle = |\\uparrow \\rangle$, and spin-down corresponds to the excited state, i.e., $|1 \\rangle = |\\downarrow \\rangle$.\n",
    "\n",
    "As we only have one spin in this system, the only term we have in our stationary (time-independent) Hamiltonian, $H_0$ (also known as the drift term), corresponds to the Zeeman effect:\n",
    "\n",
    "$$\n",
    "H_0 =\\frac{\\gamma}{2} \\vec{B} \\cdot \\vec{\\sigma} ,\n",
    "$$\n",
    "\n",
    "where **$\\vec{B}$** is the external magnetic field vector, and **$\\vec{\\sigma} = (\\sigma_x,\\sigma_y,\\sigma_z)$** is the Pauli matrix vector. The gyromagnetic ratio, $\\gamma$, quantifies the response of an electron to a magnetic field and can be expressed by $\\gamma = \\frac{-g\\mu_B}{\\hbar}$, where $g$ is the $g$-factor ($\\sim$2 for a free electron), $\\mu_B$ is the Bohr magneton quantifying the magnetic moment of an electron, and $\\hbar$ is the reduced Planck’s constant. \n",
    "\n",
    "---\n",
    "\n",
    "**Note:** This Hamiltonian does not yet include an external drive. \n",
    "\n",
    "---\n",
    "\n",
    "\n",
    "Let’s assume we only want to apply our external magnetic field in the $z$-direction, i.e., $B_x=B_y=0$. Our Hamiltonian then simplifies to\n",
    "\n",
    "$$\n",
    "H_0 = \\frac{\\gamma}{2} B_z \\sigma_z.\n",
    "$$\n",
    "\n",
    "Substituting in the eigenvalues of $\\sigma_z$ (±1), we see that the two spin states have an energy (Zeeman) splitting, $\\Delta E$, given by\n",
    "\n",
    "$$\n",
    "\\Delta E = \\frac{\\gamma}{2} B_z (1) - \\frac{\\gamma}{2} B_z (-1) = \\gamma B_z .\n",
    "$$\n",
    "\n",
    "The qubit frequency, $\\omega_q$, can then be extracted as $\\omega_q=\\gamma B_z$. This is also known as the Larmor frequency, which is the frequency at which the spin precesses around the magnetic field.\n",
    "\n",
    "Now that we understand our stationary Hamiltonian, we can start coding it. First, we need to define some initial qubit parameters: for an electron, $\\gamma \\approx -2.8 \\times 10^{10} \\, \\text{Hz} \\, \\text{T}^{-1}$ (strictly, this is the value of $\\gamma/2\\pi$), and we’ll set $B_z = 0.1 \\, \\text{T}$.\n",
    "\n",
    "---\n",
    "\n",
    "**Tip:** Make sure that your environment is set up to handle `float64` precision, either by setting the environment variable `JAX_ENABLE_X64=True` or by adding the following code block to your script’s preamble:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "jax.config.update(\"jax_enable_x64\", True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spin qubit constants\n",
    "gamma = -2.8024951386169e10  # electron gyromagnetic ratio, gamma/(2*pi) (Hz/T)\n",
    "Bz = 0.1  # magnetic field strength (T)\n",
    "freq = gamma * Bz  # Larmor or qubit frequency (Hz)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4",
   "metadata": {},
   "source": [
    "Then we can define the stationary part of the Hamiltonian."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from qruise.toolset import Hamiltonian\n",
    "from qutip import sigmaz\n",
    "\n",
    "# define stationary Hamiltonian\n",
    "H = Hamiltonian(freq * sigmaz() / 2)"
   ]
  },
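  {
   "cell_type": "markdown",
   "id": "zeeman-check-md",
   "metadata": {},
   "source": [
    "As a quick sanity check of the Zeeman splitting derived above, we can diagonalise $H_0 = \\frac{\\gamma}{2} B_z \\sigma_z$ numerically. The sketch below is a standalone illustration rather than part of the `qruise-toolset` workflow: it uses plain `jax.numpy` and its own copies of the constants (the hypothetical `_demo` names avoid overwriting variables used later). We expect an energy splitting of $|\\gamma B_z|$ and, since $\\gamma < 0$ for an electron, a ground state of $|0\\rangle = |\\uparrow\\rangle = (1, 0)^T$ up to a sign."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "zeeman-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "# Standalone copies of the tutorial's constants (same values as `gamma` and `Bz`)\n",
    "gamma_demo = -2.8024951386169e10  # Hz/T\n",
    "bz_demo = 0.1  # T\n",
    "\n",
    "sz_demo = jnp.array([[1.0, 0.0], [0.0, -1.0]])  # Pauli sigma_z\n",
    "h0_demo = 0.5 * gamma_demo * bz_demo * sz_demo  # H0 = (gamma / 2) * Bz * sigma_z\n",
    "\n",
    "# eigh returns eigenvalues in ascending order, with eigenvectors as columns\n",
    "energies_demo, vecs_demo = jnp.linalg.eigh(h0_demo)\n",
    "splitting_demo = energies_demo[1] - energies_demo[0]\n",
    "\n",
    "print(\"energies:\", energies_demo)\n",
    "print(\"splitting:\", splitting_demo, \"vs |gamma * Bz| =\", abs(gamma_demo * bz_demo))\n",
    "print(\"ground state:\", vecs_demo[:, 0])"
   ]
  },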
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "## 2. Defining the time-dependent Hamiltonian with external drive\n",
    "\n",
    "To construct our time-dependent Hamiltonian, $H$, we simply take our stationary Hamiltonian and add a term describing the effect of the external drive, $c(t)$. The time-dependent Hamiltonian is then given by\n",
    "\n",
    "$$\n",
    "H = \\frac{\\gamma}{2} B_z \\sigma_z + c(t)\\sigma_x .\n",
    "$$\n",
    "\n",
    "For this system, we’ll use a Gaussian pulse modulated by a cosine:\n",
    "\n",
    "$$\n",
    "c(t; \\{a,\\sigma,\\mu,\\omega_{d}\\}) = \\frac{a}{\\sigma\\sqrt{2\\pi}} \\mathrm{exp}\\left[-\\frac{(t - \\mu)^2}{2\\sigma^2}\\right] \\cos(\\omega_{d}t),\n",
    "$$\n",
    "\n",
    "which is characterised by four parameters:\n",
    "\n",
    "(i) $a$, a scalar that adjusts the amplitude of the pulse\n",
    "\n",
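    "(ii) $\\sigma$, the standard deviation of the Gaussian, which sets the pulse width\n",
    "\n",
    "(iii) $\\mu$, the centre of the pulse (set to half the total duration, $t_\\text{final}/2$, in the code below)\n",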
    "\n",
    "(iv) $\\omega_{d}$, the local oscillator frequency, which is often set to $\\omega_q=\\gamma B_z$.\n",
    "\n",
    "We can use this equation to define our `drive` pulse."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def drive(t, params):\n",
    "    \"\"\"Envelope: Gaussian pulse modulated by a cosine\"\"\"\n",
    "\n",
    "    amp = params[\"a\"]  # Gaussian pulse amplitude\n",
    "    sigma = params[\"sigma\"]  # Std dev of Gaussian pulse (controls pulse width)\n",
    "    lo_freq = params[\"omega\"]  # Carrier (local oscillator) frequency\n",
    "\n",
    "    factor = amp / jnp.sqrt(2 * jnp.pi) / sigma\n",
    "    gaussian = factor * jnp.exp(-((t - tfinal / 2) ** 2) / (2 * sigma**2))\n",
    "    # Note: mu is set to tfinal/2\n",
    "    return gaussian * jnp.cos(lo_freq * t)"
   ]
  },
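  {
   "cell_type": "markdown",
   "id": "envelope-area-md",
   "metadata": {},
   "source": [
    "The prefactor $\\frac{a}{\\sigma\\sqrt{2\\pi}}$ normalises the Gaussian envelope so that its area is $a$, independently of $\\sigma$. This is useful because, under the rotating-wave approximation (RWA, an approximation that the tutorial's solver does not make), a resonant drive $c(t)\\sigma_x$ with envelope $E(t)$ rotates the qubit about the $x$ axis of the rotating frame by an angle $\\theta = \\int E(t)\\,dt$ (with $\\hbar = 1$). So $a$ mainly sets the rotation angle, while $\\sigma$ sets the pulse width. The standalone sketch below checks the area numerically with a trapezoidal rule; `gaussian_envelope`, `trapezoid_area` and the `_demo` names are hypothetical helpers, not part of `qruise-toolset`. When the window $[0, t_\\text{final}]$ cuts off the Gaussian tails, the area inside it is slightly smaller than $a$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "envelope-area-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def gaussian_envelope(t, amp, sigma, mu):\n",
    "    \"\"\"Gaussian envelope of `drive` without the cosine carrier.\"\"\"\n",
    "    norm = amp / (sigma * jnp.sqrt(2 * jnp.pi))\n",
    "    return norm * jnp.exp(-((t - mu) ** 2) / (2 * sigma**2))\n",
    "\n",
    "\n",
    "def trapezoid_area(y, x):\n",
    "    \"\"\"Trapezoidal-rule integral of the samples y over the grid x.\"\"\"\n",
    "    return jnp.sum(0.5 * (y[1:] + y[:-1]) * jnp.diff(x))\n",
    "\n",
    "\n",
    "t_demo = jnp.linspace(0.0, 20e-9, 2001)  # same window as the tutorial, finer grid\n",
    "areas_demo = {}\n",
    "for sigma_demo in (2e-9, 4e-9):  # 4e-9 s equals 0.2 * tfinal, as in the tutorial\n",
    "    env_demo = gaussian_envelope(t_demo, 2.0, sigma_demo, 10e-9)  # a = 2.0, mu = tfinal / 2\n",
    "    areas_demo[sigma_demo] = float(trapezoid_area(env_demo, t_demo))\n",
    "    print(f\"sigma = {sigma_demo:.0e} s -> area = {areas_demo[sigma_demo]:.4f}\")"
   ]
  },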
  {
   "cell_type": "markdown",
   "id": "8",
   "metadata": {},
   "source": [
    "Now we need to define the time and pulse parameters for the simulation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# time parameters for simulation\n",
    "t0 = 0.0  # initial time (s)\n",
    "tfinal = 20e-9  # final time (s)\n",
    "grid_size = 1000  # simulation grid size (number of time points)\n",
    "ts = jnp.linspace(t0, tfinal, grid_size)\n",
    "\n",
    "# drive pulse parameters\n",
    "params = {\"a\": 2.0, \"sigma\": 0.2 * tfinal, \"omega\": freq}"
   ]
  },
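  {
   "cell_type": "markdown",
   "id": "time-step-md",
   "metadata": {},
   "source": [
    "Because a piecewise-constant solver holds $H(t)$ fixed within each time step, the step should be much shorter than the fastest oscillation in the problem, here the carrier $\\cos(\\omega_{d} t)$. Since `drive` evaluates `jnp.cos(lo_freq * t)`, the value stored in `params[\"omega\"]` acts as an angular frequency, so the carrier period is $2\\pi/|\\omega_{d}|$. The standalone sketch below counts the grid points per carrier period using copies of the tutorial's values (the `_demo` names are hypothetical). Having many points per period is a general rule of thumb for accuracy, not a documented requirement of `PWCSolver`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "time-step-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Copies of the tutorial's values so this cell stands alone\n",
    "omega_demo = -2.8024951386169e10 * 0.1  # same value as `freq` (gamma * Bz)\n",
    "t_final_demo = 20e-9  # s\n",
    "grid_size_demo = 1000\n",
    "\n",
    "dt_demo = t_final_demo / (grid_size_demo - 1)  # spacing of jnp.linspace(t0, tfinal, grid_size)\n",
    "carrier_period_demo = 2 * math.pi / abs(omega_demo)  # period of cos(omega * t)\n",
    "points_per_period_demo = carrier_period_demo / dt_demo\n",
    "\n",
    "print(f\"dt = {dt_demo:.2e} s, carrier period = {carrier_period_demo:.2e} s\")\n",
    "print(f\"grid points per carrier period: {points_per_period_demo:.0f}\")"
   ]
  },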
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "It’s a good idea to check that your drive behaves as expected before proceeding, so let’s quickly plot it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "from qruise.toolset import PlotUtil\n",
    "\n",
    "canvas = PlotUtil(x_axis_label=\"t [ns]\", y_axis_label=\"Amplitude\", notebook=True)\n",
    "canvas.plot(ts / 1e-9, drive(ts, params), labels=[\"Pulse\"])\n",
    "canvas.show_canvas()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12",
   "metadata": {},
   "source": [
    "Great, our drive looks as expected!\n",
    "\n",
    "Now, we simply add the drive term to our Hamiltonian."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13",
   "metadata": {},
   "outputs": [],
   "source": [
    "from qutip import sigmax\n",
    "\n",
    "# Add the drive term to the Hamiltonian\n",
    "H.add_term(sigmax(), drive)  # H(t) = H0 + sigma_x * drive"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14",
   "metadata": {},
   "source": [
    "## 3. Solving the Schrödinger equation\n",
    "\n",
    "Now we can solve the Schrödinger equation using a piecewise constant solver.\n",
    "\n",
    "We’ll start by specifying our initial qubit state, which we choose as the ground state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define initial qubit state\n",
    "y0 = jnp.array([1.0, 0.0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16",
   "metadata": {},
   "source": [
    "Next, we define our `Problem`, which takes the Hamiltonian, the initial qubit state, the pulse parameters, and the time interval as inputs. We then use Qruise’s `PWCSolver` to solve the `Problem` and calculate the wavefunction of the system at each of the `n=grid_size` time points."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17",
   "metadata": {},
   "outputs": [],
   "source": [
    "from qruise.toolset import Problem, PWCSolver\n",
    "\n",
    "# define Problem\n",
    "prob = Problem(H, y0, params, (t0, tfinal))\n",
    "\n",
    "# uses PWCSolver to solve Schrödinger equation\n",
    "solver = PWCSolver(n=grid_size, store=True)\n",
    "solver.set_system(prob.sepwc())\n",
    "_, res = solver.evolve(*prob.problem())"
   ]
  },
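  {
   "cell_type": "markdown",
   "id": "pwc-sketch-md",
   "metadata": {},
   "source": [
    "To make the idea of a piecewise-constant solver concrete, the standalone sketch below holds the Hamiltonian constant over each interval $[t_k, t_{k+1}]$ and applies the exact propagator for that interval, $U_k = \\exp[-i H(t_k)\\,\\Delta t_k]$ (with $\\hbar = 1$), so that $|\\psi_{k+1}\\rangle = U_k |\\psi_k\\rangle$. It is written in plain JAX purely as an illustration of the technique; we are not claiming that `PWCSolver` is implemented this way (for instance, it may evaluate $H$ at a different point within each step). `pwc_evolve` and the `_demo` names are hypothetical. The usage example applies the constant Hamiltonian $H = \\frac{\\pi}{2T}\\sigma_x$ for a time $T$, which should transfer $|0\\rangle$ completely to $|1\\rangle$; for the tutorial's time-dependent $H(t)$, such a scheme only approaches the exact dynamics as the time step shrinks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pwc-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax.scipy.linalg import expm\n",
    "\n",
    "SX_DEMO = jnp.array([[0, 1], [1, 0]], dtype=complex)  # Pauli sigma_x\n",
    "\n",
    "\n",
    "def pwc_evolve(h_of_t, psi0, times):\n",
    "    \"\"\"Evolve psi0, holding H(t) constant over each interval [t_k, t_{k+1}].\"\"\"\n",
    "\n",
    "    def step(psi, t_pair):\n",
    "        t_k, t_next = t_pair\n",
    "        # exact propagator for a constant H over this interval (hbar = 1)\n",
    "        u_k = expm(-1j * h_of_t(t_k) * (t_next - t_k))\n",
    "        psi_next = u_k @ psi\n",
    "        return psi_next, psi_next\n",
    "\n",
    "    _, states = jax.lax.scan(step, psi0, (times[:-1], times[1:]))\n",
    "    return jnp.vstack([psi0[None, :], states])  # prepend the initial state\n",
    "\n",
    "\n",
    "# Usage: constant Hamiltonian H = pi / (2T) * sigma_x applied for a time T\n",
    "T_demo = 1.0\n",
    "times_demo = jnp.linspace(0.0, T_demo, 51)\n",
    "psi0_demo = jnp.array([1.0, 0.0], dtype=complex)\n",
    "states_demo = pwc_evolve(lambda t: jnp.pi / (2 * T_demo) * SX_DEMO, psi0_demo, times_demo)\n",
    "pops_demo = jnp.abs(states_demo) ** 2  # populations |<i|psi(t)>|^2\n",
    "print(\"final populations:\", pops_demo[-1])"
   ]
  },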
  {
   "cell_type": "markdown",
   "id": "18",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "**Note:** Setting `store=True` in the `PWCSolver` ensures that the results are stored at each timestamp during the simulation.\n",
    "\n",
    "---\n",
    "\n",
    "## 4. Plotting the qubit dynamics\n",
    "\n",
    "Finally, we can view the results of our simulation by calculating the populations of the qubit states and plotting them against time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19",
   "metadata": {},
   "outputs": [],
   "source": [
    "from qruise.toolset import get_population\n",
    "\n",
    "# calculates the population of each qubit state at every time point\n",
    "pop = get_population(res)\n",
    "\n",
    "# plots populations against time\n",
    "canvas.init_canvas(x_axis_label=\"t [ns]\", y_axis_label=\"Population\")\n",
    "canvas.plot(ts / 1e-9, pop, labels=[\"Ground state\", \"1st excited state\"])\n",
    "canvas.show_canvas()"
   ]
  },
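  {
   "cell_type": "markdown",
   "id": "rwa-estimate-md",
   "metadata": {},
   "source": [
    "We can roughly anticipate the final populations in this plot. Assuming the rotating-wave approximation (an assumption of this estimate, not of the simulation), the resonant Gaussian drive rotates the qubit by an angle equal to the envelope area inside the simulation window, $\\theta = a\\,\\mathrm{erf}\\left(\\frac{t_\\text{final}/2}{\\sqrt{2}\\,\\sigma}\\right)$, and starting from $|0\\rangle$ the excited-state population is $\\sin^2(\\theta/2)$. The standalone sketch below evaluates this for the tutorial's parameters, together with the amplitude $a_\\pi$ that would give a full $\\pi$ rotation in the same model (the `_demo` names are hypothetical). Effects beyond the RWA, such as the counter-rotating terms, make the exact numerical result differ somewhat."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "rwa-estimate-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Tutorial values copied so this cell stands alone\n",
    "a_demo = 2.0\n",
    "t_final_demo = 20e-9\n",
    "sigma_demo = 0.2 * t_final_demo  # pulse centred at t_final_demo / 2\n",
    "\n",
    "# Fraction of the Gaussian's area that lies inside [0, t_final]\n",
    "area_fraction_demo = math.erf((t_final_demo / 2) / (math.sqrt(2) * sigma_demo))\n",
    "theta_demo = a_demo * area_fraction_demo  # RWA rotation angle about x\n",
    "\n",
    "p_ground_demo = math.cos(theta_demo / 2) ** 2\n",
    "p_excited_demo = math.sin(theta_demo / 2) ** 2\n",
    "a_pi_demo = math.pi / area_fraction_demo  # amplitude for a full pi rotation in this model\n",
    "\n",
    "print(f\"theta = {theta_demo:.3f} rad\")\n",
    "print(f\"P(ground) = {p_ground_demo:.3f}, P(excited) = {p_excited_demo:.3f}\")\n",
    "print(f\"a for a pi rotation = {a_pi_demo:.3f}\")"
   ]
  },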
  {
   "cell_type": "markdown",
   "id": "20",
   "metadata": {},
   "source": [
    "Congratulations! You’ve just simulated your first spin qubit. You’ll notice, however, that the population exchange between the ground and first excited states is not complete. We can improve this by performing quantum optimal control.\n",
    "\n",
    "## 5. Optimal control\n",
    "\n",
    "As we saw in the previous plot, the drive pulse we defined caused only partial population exchange between the ground and excited states. However, to perform an X gate on our qubit, we need to optimise the drive pulse parameters such that, at the end of the simulation, the populations of the ground and first excited states are fully inverted.\n",
    "\n",
    "To perform the optimisation, we first need to define a loss function, $\\mathcal{L}$, to quantify how far we are from our desired target. We can use the infidelity, which is given by $1-\\mathcal{F}$, with $\\mathcal{F}$ being the fidelity. The loss function is then given by\n",
    "\n",
    "$$\n",
    "\\mathcal{L}=1-\\mathcal{F}=1-|\\langle \\psi(t=t_\\text{final})|\\psi_t \\rangle|^2,\n",
    "$$\n",
    "\n",
    "where $\\mathcal{F}$ is the state fidelity, $|\\psi(t=t_\\text{final})\\rangle$ is the wavefunction of the system at the end of the pulse, and $|\\psi_t\\rangle$ is the target (desired) wavefunction. The objective is to optimise the drive pulse parameters by minimising $\\mathcal{L}$.\n",
    "\n",
    "We can define the loss function as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss(x, y):\n",
    "    \"\"\"\n",
    "    Returns the infidelity (1 - |<x|y>|^2)\n",
    "    of two wavefunctions x and y.\n",
    "    \"\"\"\n",
    "    o = jnp.matmul(\n",
    "        x.conj().T, y\n",
    "    )  # Calculates the inner product (overlap) of the two wavefunctions\n",
    "    return jnp.real(1.0 - o.conj() * o)  # Returns the real part of the infidelity"
   ]
  },
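  {
   "cell_type": "markdown",
   "id": "loss-check-md",
   "metadata": {},
   "source": [
    "Because the loss depends only on the magnitude of the overlap, it is insensitive to a global phase of either state. The standalone sketch below evaluates the same formula on a few simple states; `infidelity_demo` is a hypothetical re-implementation using `jnp.vdot` (which conjugates its first argument), so the original `loss` is left untouched. We expect an infidelity of 1 for orthogonal states, 0 for identical states (with or without a global phase), and 0.5 for an equal superposition compared with $|1\\rangle$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "loss-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def infidelity_demo(x, y):\n",
    "    \"\"\"1 - |<x|y>|^2, computed with jnp.vdot (which conjugates x).\"\"\"\n",
    "    overlap = jnp.vdot(x, y)\n",
    "    return jnp.real(1.0 - jnp.abs(overlap) ** 2)\n",
    "\n",
    "\n",
    "ket0_demo = jnp.array([1.0, 0.0], dtype=complex)  # |0>\n",
    "ket1_demo = jnp.array([0.0, 1.0], dtype=complex)  # |1>\n",
    "plus_demo = (ket0_demo + ket1_demo) / jnp.sqrt(2.0)  # equal superposition\n",
    "\n",
    "checks_demo = {\n",
    "    \"orthogonal\": infidelity_demo(ket0_demo, ket1_demo),\n",
    "    \"identical\": infidelity_demo(ket1_demo, ket1_demo),\n",
    "    \"global phase\": infidelity_demo(jnp.exp(1j * 0.7) * ket1_demo, ket1_demo),\n",
    "    \"superposition\": infidelity_demo(plus_demo, ket1_demo),\n",
    "}\n",
    "for label_demo, value_demo in checks_demo.items():\n",
    "    print(f\"{label_demo}: {float(value_demo):.3f}\")"
   ]
  },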
  {
   "cell_type": "markdown",
   "id": "22",
   "metadata": {},
   "source": [
    "We are now ready to define a *quantum optimal control problem* (`QOCProblem`). Like the `Problem` we used to simulate the dynamics, `QOCProblem` describes the system, but it additionally specifies the optimisation target. As inputs, it takes\n",
    "\n",
    "- the Hamiltonian, `H`\n",
    "- the initial qubit state, `y0`\n",
    "- the pulse parameters, `params`\n",
    "- the time interval of interest, `t0` to `tfinal`\n",
    "- the target qubit state, `yt` (here, the first excited state)\n",
    "- the loss function, `loss`\n",
    "\n",
    "---\n",
    "\n",
    "**Note:** We defined the first four inputs earlier in this tutorial.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23",
   "metadata": {},
   "outputs": [],
   "source": [
    "from qruise.toolset import QOCProblem\n",
    "\n",
    "yt = jnp.array([0.0, 1.0])  # define desired state\n",
    "\n",
    "opt_prob = QOCProblem(H, y0, params, (t0, tfinal), yt, loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24",
   "metadata": {},
   "source": [
    "The whole workflow now reduces to solving the optimal control problem, starting from our initial guess for the pulse parameters. At each step, the system is simulated with the current parameters to obtain the final wavefunction, and the infidelity with respect to the target state is computed. The parameters are then updated using the gradient of the loss, and this is repeated until the optimiser converges to a minimum of the loss function.\n",
    "\n",
    "The `Optimiser` class is provided for this specific task. As a user, you need to specify both the solver used to simulate the dynamics and the optimisation algorithm itself. In this case, we again use a piecewise constant solver (`PWCSolver`), together with the Broyden–Fletcher–Goldfarb–Shanno (BFGS) method from the `optimistix` library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25",
   "metadata": {},
   "outputs": [],
   "source": [
    "import optimistix as optx\n",
    "from qruise.toolset import OptimistixMinimiser, Optimiser\n",
    "\n",
    "minimiser = OptimistixMinimiser(optx.BFGS, rtol=1e-2, atol=1e-3)\n",
    "\n",
    "opt = Optimiser(minimiser, PWCSolver(n=grid_size, eq=prob.sepwc()))\n",
    "\n",
    "# Set the loss function for the optimiser to minimise (infidelity)\n",
    "opt.set_optimisation(opt_prob.loss)\n",
    "\n",
    "# Run optimisation on problem\n",
    "opt_params, opt_summary = opt.optimise(*opt_prob.problem())"
   ]
  },
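  {
   "cell_type": "markdown",
   "id": "grad-loop-md",
   "metadata": {},
   "source": [
    "The optimiser above repeats a simulate, evaluate, differentiate, update loop. That loop can be illustrated on a toy model whose final state is known in closed form: under the rotating-wave approximation, a resonant pulse of area $\\theta$ maps $|0\\rangle$ to $\\exp(-i\\frac{\\theta}{2}\\sigma_x)|0\\rangle = \\cos\\frac{\\theta}{2}|0\\rangle - i\\sin\\frac{\\theta}{2}|1\\rangle$. The standalone sketch below computes the infidelity of this state with respect to $|1\\rangle$, differentiates it with `jax.value_and_grad`, and updates $\\theta$. To keep it short, plain gradient descent is swapped in for BFGS, and the sketch does not use the `Optimiser` or `OptimistixMinimiser` classes; all names in it are hypothetical. Starting from $\\theta = 2$ (the tutorial's initial amplitude $a$), it should converge to $\\theta = \\pi$, i.e., a $\\pi$ rotation about $x$, which is an X gate up to a global phase."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "grad-loop-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "target_demo = jnp.array([0.0, 1.0], dtype=complex)  # |1>\n",
    "\n",
    "\n",
    "def rwa_final_state(theta):\n",
    "    \"\"\"exp(-i theta/2 sigma_x)|0> = [cos(theta/2), -i sin(theta/2)].\"\"\"\n",
    "    return jnp.stack([jnp.cos(theta / 2) + 0j, -1j * jnp.sin(theta / 2)])\n",
    "\n",
    "\n",
    "def toy_loss(theta):\n",
    "    \"\"\"Infidelity of the toy final state with respect to |1>.\"\"\"\n",
    "    overlap = jnp.vdot(rwa_final_state(theta), target_demo)\n",
    "    return jnp.real(1.0 - jnp.abs(overlap) ** 2)\n",
    "\n",
    "\n",
    "loss_and_grad = jax.jit(jax.value_and_grad(toy_loss))\n",
    "\n",
    "theta_opt_demo = jnp.asarray(2.0)  # initial guess\n",
    "learning_rate_demo = 1.0\n",
    "for step_demo in range(50):\n",
    "    loss_val, grad_val = loss_and_grad(theta_opt_demo)  # simulate, evaluate, differentiate\n",
    "    theta_opt_demo = theta_opt_demo - learning_rate_demo * grad_val  # gradient-descent update\n",
    "\n",
    "print(f\"theta = {float(theta_opt_demo):.6f} (pi = {jnp.pi:.6f}), last loss = {float(loss_val):.2e}\")"
   ]
  },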
  {
   "cell_type": "markdown",
   "id": "26",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "**Note:** Since `qruise-toolset` is fully written in JAX, it integrates with other libraries from the same ecosystem, such as `diffrax` for solving differential equations and `optimistix` for minimisation problems.\n",
    "\n",
    "---\n",
    "\n",
    "We can now compare the initial values of the parameters with the optimised ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pprint import pprint\n",
    "from qruise.toolset import PlotUtil\n",
    "\n",
    "canvas = PlotUtil(x_axis_label=\"t [ns]\", y_axis_label=\"Amplitude\", notebook=True)\n",
    "\n",
    "# Print initial (unoptimised) and optimised parameters\n",
    "pprint(f\"Unoptimised {params}\")\n",
    "pprint(f\"Optimised {opt_params}\")\n",
    "\n",
    "# Plot initial (unoptimised) and optimised parameters\n",
    "canvas.plot(ts / 1e-9, drive(ts, params), labels=[\"Unoptimised\"])\n",
    "canvas.plot(ts / 1e-9, drive(ts, opt_params), labels=[\"Optimised\"])\n",
    "canvas.show_canvas()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28",
   "metadata": {},
   "source": [
    "We can see that the optimised pulse has a larger amplitude and a wider envelope (larger $\\sigma$) than the unoptimised pulse.\n",
    "\n",
    "Let’s now simulate and plot the state populations to see if our optimisation worked."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the simulation with optimised parameters\n",
    "_, opt_res = solver.evolve(*opt_prob.problem(params=opt_params)[:-1])\n",
    "\n",
    "# Get the population of the ground and excited states before optimisation\n",
    "pop = get_population(res)\n",
    "# Get the population of the ground and excited states after optimisation\n",
    "opt_pop = get_population(opt_res)\n",
    "\n",
    "# Plot unoptimised and optimised populations on same graph\n",
    "canvas.init_canvas(x_axis_label=\"t [ns]\", y_axis_label=\"Population\")\n",
    "canvas.plot(ts / 1e-9, pop, labels=[\"Unopt. ground state\", \"Unopt. 1st excited state\"])\n",
    "canvas.plot(ts / 1e-9, opt_pop, labels=[\"Opt. ground state\", \"Opt. 1st excited state\"])\n",
    "canvas.show_canvas()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30",
   "metadata": {},
   "source": [
    "We can see here that the population exchange between the ground and first excited states is almost perfect. The final populations are around 0.3% and 99.7%, respectively, compared to 30% and 70% before optimisation. Our optimisation was successful!\n",
    "\n",
    "We can quantify this success by printing the fidelity before and after optimisation. (Recall that we defined our loss function as the infidelity, $1 - \\mathcal{F}$.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31",
   "metadata": {},
   "outputs": [],
   "source": [
    "fidelity = [1 - loss(res[-1, :], yt), 1 - loss(opt_res[-1, :], yt)]\n",
    "pprint(f\"Unoptimised: {fidelity[0]*100:.1f} %\")\n",
    "pprint(f\"Optimised: {fidelity[1]*100:.1f} %\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32",
   "metadata": {},
   "source": [
    "We can clearly see that our optimisation has significantly enhanced the fidelity!\n",
    "\n",
    "Congratulations, you’ve just performed optimal control on a spin qubit! You can now use these methods to start optimising your qubit performance."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "qruise-simple",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
