{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "source": [
    "# Simulating a closed two-level quantum system\n",
    "\n",
    "This tutorial provides a step-by-step guide to simulating the dynamics of a simple two-level quantum system using the `qruise-toolset`.\n",
    "\n",
    "We'll begin by defining the system Hamiltonian, then show how to simulate pure states with the Schrödinger equation and mixed states with the von Neumann equation.\n",
    "\n",
    "## 1. Defining the Hamiltonian\n",
    "\n",
    "The two-level system (TLS) we are considering is described by the following Hamiltonian:\n",
    "\n",
    "$$\n",
    "H(t) = \\frac{\\omega}{2}\\sigma_z + c(t)\\sigma_x, \\tag{1}\n",
    "$$\n",
    "\n",
    "where $\\omega$ is the TLS resonance frequency, $\\sigma_z$ and $\\sigma_x$ are Pauli matrices, and $c(t)$ is the drive function. For this we'll use a Gaussian pulse given by: \n",
    "\n",
    "$$\n",
    "c(t; \\{a, \\sigma, \\mu \\}) =\n",
    "  \\frac{a}{\\sigma\\sqrt{2\\pi}} \\exp\\Big[-\\frac{(t - \\mu)^2}{2\\sigma^2}\\Big], \\tag{2}\n",
    "$$\n",
    "\n",
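    "which is characterised by three parameters:\n",
    "\n",
    "- $a$, the pulse area ($\\int_{-\\infty}^{\\infty} c(t)\\,dt = a$), which scales the pulse amplitude; the peak value is $a/(\\sigma\\sqrt{2\\pi})$\n",
    "\n",
    "- $\\sigma$, the standard deviation, which sets the pulse width\n",
    "\n",
    "- $\\mu$, the time at which the pulse is centred (its peak).\n",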
    "\n",
    "We can start by defining the stationary part of our Hamiltonian, $H_0 = \\frac{\\omega}{2}\\sigma_z$, i.e., the first term of eq. (1). For this, we can use `sigmaz` from the `QuTiP` library.\n",
    "\n",
    "---\n",
    "\n",
    "**Tip:** Make sure that your environment is set up for `float64` precision, either by setting the environment variable `JAX_ENABLE_X64=True` or by adding the following code block to your script's preamble:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "jax.config.update(\"jax_enable_x64\", True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import qruise.toolset as qr\n",
    "import qutip as qt\n",
    "\n",
    "omega = 2.0  # TLS frequency\n",
    "\n",
    "# stationary Hamiltonian, H0\n",
    "H = qr.Hamiltonian(omega / 2 * qt.sigmaz())"
   ]
  },
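  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional cross-check that does not depend on the `qruise-toolset`, here is a minimal NumPy sketch of the stationary term $H_0 = \\frac{\\omega}{2}\\sigma_z$ written out as an explicit $2\\times 2$ matrix. It assumes the same value of `omega` and the QuTiP convention $\\sigma_z = \\mathrm{diag}(1, -1)$, and shows that the two energy levels are $\\pm\\omega/2$, so their splitting, the transition frequency of the TLS, is $\\omega$. The names `sigma_z_mat` and `H0_mat` are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "omega = 2.0  # TLS frequency, same value as above\n",
    "\n",
    "# Pauli z matrix in the QuTiP convention, sigma_z = diag(1, -1)\n",
    "sigma_z_mat = np.array([[1, 0], [0, -1]], dtype=complex)\n",
    "\n",
    "# stationary Hamiltonian H0 = (omega / 2) * sigma_z as a plain matrix\n",
    "H0_mat = omega / 2 * sigma_z_mat\n",
    "\n",
    "# eigvalsh returns the real eigenvalues of a Hermitian matrix in ascending order\n",
    "energies = np.linalg.eigvalsh(H0_mat)\n",
    "splitting = energies[1] - energies[0]  # gap between the two levels\n",
    "print(energies, splitting)"
   ]
  },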
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before we add the time-dependent term to our Hamiltonian, we need to define our Gaussian `drive` pulse, $c(t)$, from eq. (2)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def drive(t, params):\n",
    "    \"\"\"Gaussian pulse c(t) of eq. (2)\"\"\"\n",
    "\n",
    "    a = params[\"a\"]  # pulse area (scales the amplitude)\n",
    "    sigma = params[\"sigma\"]  # standard deviation (sets the pulse width)\n",
    "    mu = tfinal / 2  # pulse centre; tfinal is a global defined in a later cell\n",
    "\n",
    "    factor = a / jnp.sqrt(2.0 * jnp.pi) / sigma\n",
    "    return factor * jnp.exp(-((t - mu) ** 2) / (2.0 * sigma**2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
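    "**Note:** The signature of the pulse (envelope) function, in our case `drive`, must strictly follow the `(t, params)` pattern, where `t` is the time (a `float`; because the function uses `jax.numpy`, it also accepts an array of times, which we use for plotting below) and `params` is a dictionary (`dict`) of pulse parameters. Make sure that any envelope function you write respects this constraint.\n",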
    "\n",
    "---\n",
    "\n",
    "We now need to define the time and pulse parameters for the simulation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t0 = 0.0  # start time of simulation\n",
    "tfinal = 1.0  # end time of simulation\n",
    "N = 200  # number of time points\n",
    "ts = jnp.linspace(t0, tfinal, N)  # time array\n",
    "\n",
    "# Parameters of the Gaussian pulse\n",
    "params = {\"a\": 2.14, \"sigma\": 0.8 * tfinal}"
   ]
  },
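  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since eq. (2) is a normalised Gaussian, integrating it over all times gives exactly $a$. With $\\sigma = 0.8\\,t_\\mathrm{final}$, however, the pulse is wide compared with the simulation window, so only part of its area falls inside $[t_0, t_\\mathrm{final}]$. The closed form is $\\frac{a}{2}\\big[\\mathrm{erf}\\big(\\frac{t_\\mathrm{final}-\\mu}{\\sqrt{2}\\,\\sigma}\\big) - \\mathrm{erf}\\big(\\frac{t_0-\\mu}{\\sqrt{2}\\,\\sigma}\\big)\\big]$, which for the values above is about 47% of $a$. The minimal NumPy sketch below checks these statements numerically. It re-implements the envelope with plain NumPy; `gaussian_envelope` and `trapezoid` are illustrative helper names, not part of the toolset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "a_area, sigma_w, t_lo, t_hi = 2.14, 0.8, 0.0, 1.0  # same values as above\n",
    "mu_c = (t_lo + t_hi) / 2  # pulse centre, tfinal / 2 as in `drive`\n",
    "\n",
    "\n",
    "def gaussian_envelope(t, a, sigma, mu):\n",
    "    \"\"\"Plain-NumPy version of eq. (2).\"\"\"\n",
    "    norm = a / (sigma * np.sqrt(2 * np.pi))\n",
    "    return norm * np.exp(-((t - mu) ** 2) / (2 * sigma**2))\n",
    "\n",
    "\n",
    "def trapezoid(y, x):\n",
    "    \"\"\"Trapezoidal rule, written out to avoid NumPy version differences.\"\"\"\n",
    "    return float(np.sum((y[1:] + y[:-1]) * np.diff(x)) / 2)\n",
    "\n",
    "\n",
    "# 1) Over a window much wider than sigma, the integral recovers a\n",
    "t_wide = np.linspace(mu_c - 10 * sigma_w, mu_c + 10 * sigma_w, 20001)\n",
    "area_total = trapezoid(gaussian_envelope(t_wide, a_area, sigma_w, mu_c), t_wide)\n",
    "\n",
    "# 2) Area inside the simulation window: numerical vs. closed form via erf\n",
    "t_win = np.linspace(t_lo, t_hi, 2001)\n",
    "area_window = trapezoid(gaussian_envelope(t_win, a_area, sigma_w, mu_c), t_win)\n",
    "z = math.sqrt(2) * sigma_w\n",
    "area_exact = a_area / 2 * (math.erf((t_hi - mu_c) / z) - math.erf((t_lo - mu_c) / z))\n",
    "\n",
    "# 3) Peak value of the pulse, reached at t = mu\n",
    "peak = gaussian_envelope(mu_c, a_area, sigma_w, mu_c)\n",
    "print(area_total, area_window, area_exact, peak)"
   ]
  },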
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It’s a good idea to check that your drive behaves as intended before proceeding, so let’s quickly plot it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "canvas = qr.PlotUtil(\n",
    "    x_axis_label=\"t [a.u.]\", y_axis_label=\"Amplitude [a.u.]\", notebook=True\n",
    ")\n",
    "canvas.plot(ts, drive(ts, params), labels=[\"Pulse\"])\n",
    "canvas.show_canvas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great, our drive looks as expected!\n",
    "\n",
    "Now we simply add the drive term, $c(t)\\sigma_x$, to our Hamiltonian."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "H.add_term(qt.sigmax(), drive)"
   ]
  },
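  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the drive added, `H` represents the full time-dependent Hamiltonian of eq. (1). To make this concrete, here is a plain-NumPy sketch, independent of the toolset, that evaluates $H(t)$ as a $2\\times 2$ matrix at a given time, checks that it is Hermitian, and compares its instantaneous eigenvalues with the analytic result $\\pm\\sqrt{(\\omega/2)^2 + c(t)^2}$. The helper names `envelope` and `hamiltonian_at` are hypothetical, and the pulse parameters are copied from above with $\\mu = t_\\mathrm{final}/2$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "omega_tls = 2.0  # same TLS frequency as above\n",
    "# illustrative copy of the pulse parameters; mu = tfinal / 2 as in `drive`\n",
    "pulse = {\"a\": 2.14, \"sigma\": 0.8, \"mu\": 0.5}\n",
    "\n",
    "sx = np.array([[0, 1], [1, 0]], dtype=complex)  # Pauli x\n",
    "sz = np.array([[1, 0], [0, -1]], dtype=complex)  # Pauli z\n",
    "\n",
    "\n",
    "def envelope(t, p):\n",
    "    \"\"\"Gaussian drive c(t) of eq. (2), written with plain NumPy.\"\"\"\n",
    "    norm = p[\"a\"] / (p[\"sigma\"] * np.sqrt(2 * np.pi))\n",
    "    return norm * np.exp(-((t - p[\"mu\"]) ** 2) / (2 * p[\"sigma\"] ** 2))\n",
    "\n",
    "\n",
    "def hamiltonian_at(t, p):\n",
    "    \"\"\"Matrix form of eq. (1): H(t) = (omega/2) sigma_z + c(t) sigma_x.\"\"\"\n",
    "    return omega_tls / 2 * sz + envelope(t, p) * sx\n",
    "\n",
    "\n",
    "H_mid = hamiltonian_at(0.5, pulse)  # Hamiltonian at the pulse peak\n",
    "is_hermitian = np.allclose(H_mid, H_mid.conj().T)\n",
    "levels = np.linalg.eigvalsh(H_mid)  # instantaneous eigenvalues, ascending\n",
    "expected = np.sqrt((omega_tls / 2) ** 2 + envelope(0.5, pulse) ** 2)\n",
    "print(is_hermitian, levels, expected)"
   ]
  },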
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Simulating pure states using the Schrödinger equation\n",
    "\n",
    "Pure states in a two-level quantum system can be simulated by solving the Schrödinger equation, which describes the unitary evolution of the system's wavefunction over time (we use units where $\\hbar = 1$):\n",
    "\n",
    "$$\n",
    "i \\frac{\\partial}{\\partial t} \\psi(t) = H(t) \\psi(t). \\tag{3}\n",
    "$$\n",
    "\n",
    "To solve the Schrödinger equation, we need to specify the equations and parameters that govern the system. In the `qruise-toolset`, we combine these using a `Problem`, which is instantiated with:\n",
    "\n",
    "- the Hamiltonian (`H`)\n",
    "- the initial system state (`y0`, here the ground state)\n",
    "- the pulse parameters (`params`)\n",
    "- the time range of the simulation (`t0` to `tfinal`).\n",
    "\n",
    "---\n",
    "\n",
    "**Note:** Generally speaking, a `Problem` is a set of differential or algebraic equations that describe the dynamics of a system. Several well-known equations describe the dynamics of quantum systems, e.g. the Schrödinger equation for pure states, the von Neumann equation for mixed states, and the Lindblad master equation for open quantum systems with dissipation. The cornerstone of all these equations is the Hamiltonian.\n",
    "\n",
    "--- "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qutip import basis\n",
    "\n",
    "# define initial system state\n",
    "y0 = basis(2, 0)\n",
    "\n",
    "# define Problem\n",
    "prob = qr.Problem(H, y0, params, (t0, tfinal))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the Schrödinger equation is an ordinary differential equation (ODE), we can use Qruise's `ODESolver` to solve it and calculate the wavefunction of the system at each time in `ts`. For this, we employ Tsitouras' adaptive 5(4) Runge-Kutta method, implemented in JAX as the `Tsit5` solver of the `diffrax` library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffrax import Tsit5\n",
    "\n",
    "# use ODESolver to solve the Schrödinger equation\n",
    "solver = qr.ODESolver(\n",
    "    Tsit5(),  # Tsitouras' 5(4) Runge-Kutta method\n",
    "    dt0=None,  # initial step size; None because the adaptive solver chooses it\n",
    "    saveat={\"ts\": ts},  # times at which we want to save the solution\n",
    "    pid_args={\n",
    "        \"atol\": 1e-8,  # absolute tolerance for adaptive step-size control\n",
    "        \"rtol\": 1e-6,  # relative tolerance for adaptive step-size control\n",
    "    },\n",
    ")\n",
    "# specify that we want to solve the Schrödinger equation\n",
    "solver.set_system(prob.schroedinger())\n",
    "_, res = solver.evolve(*prob.problem())"
   ]
  },
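  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Tsit5` adapts its step size, but at heart it is an explicit Runge-Kutta scheme that repeatedly steps $\\psi$ forward using the right-hand side of eq. (3), $-iH(t)\\psi$. To illustrate the idea (this is not the toolset's implementation), the NumPy sketch below integrates eq. (3) for the same Hamiltonian with the classic fixed-step fourth-order Runge-Kutta (RK4) method. It assumes the parameter values used above and the initial state $|0\\rangle = (1, 0)^T$; the helper names `h_of_t`, `schroedinger_rhs` and `rk4` are illustrative. Because the exact evolution is unitary, the populations should still sum to 1 at the end, which is a useful sanity check on the step size."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "omega_tls, a_amp, sigma_w, mu_c = 2.0, 2.14, 0.8, 0.5  # values used above\n",
    "sx = np.array([[0, 1], [1, 0]], dtype=complex)  # Pauli x\n",
    "sz = np.array([[1, 0], [0, -1]], dtype=complex)  # Pauli z\n",
    "\n",
    "\n",
    "def h_of_t(t):\n",
    "    \"\"\"H(t) of eq. (1) with the Gaussian drive of eq. (2).\"\"\"\n",
    "    norm = a_amp / (sigma_w * np.sqrt(2 * np.pi))\n",
    "    c = norm * np.exp(-((t - mu_c) ** 2) / (2 * sigma_w**2))\n",
    "    return omega_tls / 2 * sz + c * sx\n",
    "\n",
    "\n",
    "def schroedinger_rhs(t, psi):\n",
    "    \"\"\"Right-hand side of eq. (3): dpsi/dt = -i H(t) psi (hbar = 1).\"\"\"\n",
    "    return -1j * h_of_t(t) @ psi\n",
    "\n",
    "\n",
    "def rk4(f, y0, t_start, t_end, n_steps):\n",
    "    \"\"\"Classic fixed-step RK4; returns the state at t_end.\"\"\"\n",
    "    dt = (t_end - t_start) / n_steps\n",
    "    y, t = np.asarray(y0, dtype=complex), t_start\n",
    "    for _ in range(n_steps):\n",
    "        k1 = f(t, y)\n",
    "        k2 = f(t + dt / 2, y + dt / 2 * k1)\n",
    "        k3 = f(t + dt / 2, y + dt / 2 * k2)\n",
    "        k4 = f(t + dt, y + dt * k3)\n",
    "        y = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)\n",
    "        t += dt\n",
    "    return y\n",
    "\n",
    "\n",
    "psi_final = rk4(schroedinger_rhs, [1.0, 0.0], 0.0, 1.0, 2000)\n",
    "populations = np.abs(psi_final) ** 2  # |<0|psi>|^2 and |<1|psi>|^2\n",
    "print(populations, populations.sum())"
   ]
  },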
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can view the results of our simulation by calculating the population of each basis state (the expectation value of the projector onto that state) and plotting it against time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# calculate the population of each basis state\n",
    "pop = qr.get_population(res)\n",
    "\n",
    "# plot populations against time\n",
    "canvas.init_canvas(x_axis_label=\"t [a.u.]\", y_axis_label=\"Population\")\n",
    "canvas.plot(ts, pop, labels=[\"ground state\", \"1st excited state\"])\n",
    "canvas.show_canvas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Nice! You can see we get a (partial) population exchange between the ground and excited states of our TLS. \n",
    "\n",
    "## 3. Simulating mixed states using the von Neumann equation\n",
    "\n",
    "Mixed states cannot be simulated with the Schrödinger equation, because a mixed state is a statistical ensemble of pure states and cannot be described by a single wavefunction. For this, we need the von Neumann equation:\n",
    "\n",
    "$$\n",
    "\\frac{\\partial \\rho(t)}{\\partial t} = -i \\left[ H(t), \\rho(t) \\right] , \\tag{4}\n",
    "$$\n",
    "\n",
    "which describes the time evolution of the system's density matrix, $\\rho(t)$.\n",
    "\n",
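    "We start by defining the initial system state, `rho0`, this time as a density matrix, $\\rho_0 = |0\\rangle\\langle 0|$. Here $\\rho_0$ is still a pure state, but the same approach works for a mixed initial state. We then redefine our `Problem` with `rho0` in place of `y0`, leaving all other variables unchanged."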
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define initial system state (ground state) as a column vector\n",
    "rho0 = jnp.asarray(qt.basis(2, 0).full())\n",
    "rho0 = rho0 @ rho0.conj().T  # outer product |0><0| gives the density matrix\n",
    "\n",
    "# define a new Problem from the previous one,\n",
    "# changing only the initial condition\n",
    "new_prob = prob.remake(y0=rho0)"
   ]
  },
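  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick way to tell pure and mixed states apart is the purity $\\mathrm{Tr}(\\rho^2)$, which equals 1 for a pure state and is smaller than 1 for a mixed one (down to $1/2$ for a maximally mixed qubit). The NumPy sketch below computes it for the pure initial state $|0\\rangle\\langle 0|$ used here and for an illustrative mixed state $p|0\\rangle\\langle 0| + (1-p)|1\\rangle\\langle 1|$; the value of `p` and the helper names `projector` and `purity` are arbitrary choices for this example. A mixed $\\rho_0$ like this should be usable with `prob.remake(y0=...)` in the same way, assuming it is passed as a JAX array like `rho0` above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "ket0 = np.array([[1.0], [0.0]], dtype=complex)  # |0>, column vector\n",
    "ket1 = np.array([[0.0], [1.0]], dtype=complex)  # |1>, column vector\n",
    "\n",
    "\n",
    "def projector(ket):\n",
    "    \"\"\"Outer product |k><k| (conjugate transpose on the right).\"\"\"\n",
    "    return ket @ ket.conj().T\n",
    "\n",
    "\n",
    "def purity(rho):\n",
    "    \"\"\"Tr(rho^2): 1 for pure states, < 1 for mixed states.\"\"\"\n",
    "    return float(np.real(np.trace(rho @ rho)))\n",
    "\n",
    "\n",
    "p = 0.75  # illustrative population of |0> in the mixed state (arbitrary)\n",
    "rho_pure = projector(ket0)\n",
    "rho_mixed = p * projector(ket0) + (1 - p) * projector(ket1)\n",
    "\n",
    "print(purity(rho_pure), purity(rho_mixed))  # 1.0 0.625"
   ]
  },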
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can reuse the `solver` we defined earlier. All we need to do is tell it to use the von Neumann equation instead of the Schrödinger equation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# specify that we want to solve the von Neumann equation\n",
    "solver.set_system(new_prob.von_neumann())\n",
    "_, rhos = solver.evolve(*new_prob.problem())"
   ]
  },
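  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Eq. (4) is again an ODE, now for the $2\\times 2$ matrix $\\rho$. As an illustration only (not the toolset's implementation), the NumPy sketch below integrates it with a classic fixed-step fourth-order Runge-Kutta (RK4) scheme, assuming the same Hamiltonian and parameters as above; all helper names are illustrative. It checks two properties of the exact dynamics: the trace of $\\rho$ stays 1, and for a pure initial state $\\rho(t)$ equals $|\\psi(t)\\rangle\\langle\\psi(t)|$, where $\\psi(t)$ solves the Schrödinger equation. For the pure initial state used here, the von Neumann and Schrödinger simulations therefore describe the same dynamics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "omega_tls, a_amp, sigma_w, mu_c = 2.0, 2.14, 0.8, 0.5  # values used above\n",
    "sx = np.array([[0, 1], [1, 0]], dtype=complex)  # Pauli x\n",
    "sz = np.array([[1, 0], [0, -1]], dtype=complex)  # Pauli z\n",
    "\n",
    "\n",
    "def h_of_t(t):\n",
    "    \"\"\"H(t) of eq. (1) with the Gaussian drive of eq. (2).\"\"\"\n",
    "    norm = a_amp / (sigma_w * np.sqrt(2 * np.pi))\n",
    "    c = norm * np.exp(-((t - mu_c) ** 2) / (2 * sigma_w**2))\n",
    "    return omega_tls / 2 * sz + c * sx\n",
    "\n",
    "\n",
    "def von_neumann_rhs(t, rho):\n",
    "    \"\"\"Right-hand side of eq. (4): drho/dt = -i [H(t), rho] (hbar = 1).\"\"\"\n",
    "    h = h_of_t(t)\n",
    "    return -1j * (h @ rho - rho @ h)\n",
    "\n",
    "\n",
    "def schroedinger_rhs(t, psi):\n",
    "    \"\"\"Right-hand side of eq. (3), used here only for comparison.\"\"\"\n",
    "    return -1j * h_of_t(t) @ psi\n",
    "\n",
    "\n",
    "def rk4(f, y0, t_start, t_end, n_steps):\n",
    "    \"\"\"Classic fixed-step RK4 for vector- or matrix-valued states.\"\"\"\n",
    "    dt = (t_end - t_start) / n_steps\n",
    "    y, t = np.asarray(y0, dtype=complex), t_start\n",
    "    for _ in range(n_steps):\n",
    "        k1 = f(t, y)\n",
    "        k2 = f(t + dt / 2, y + dt / 2 * k1)\n",
    "        k3 = f(t + dt / 2, y + dt / 2 * k2)\n",
    "        k4 = f(t + dt, y + dt * k3)\n",
    "        y = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)\n",
    "        t += dt\n",
    "    return y\n",
    "\n",
    "\n",
    "rho_start = np.array([[1, 0], [0, 0]], dtype=complex)  # |0><0|\n",
    "rho_end = rk4(von_neumann_rhs, rho_start, 0.0, 1.0, 2000)\n",
    "psi_end = rk4(schroedinger_rhs, np.array([1, 0], dtype=complex), 0.0, 1.0, 2000)\n",
    "\n",
    "trace_end = np.trace(rho_end).real  # should stay 1\n",
    "# for a pure initial state, rho(t) should equal |psi(t)><psi(t)|\n",
    "matches_pure = np.allclose(rho_end, np.outer(psi_end, psi_end.conj()), atol=1e-8)\n",
    "print(trace_end, matches_pure)"
   ]
  },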
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plotting the state populations only gives us the diagonal elements of the density matrix, so we would lose information on the coherence of the system, which is contained in the off-diagonal elements. It's therefore more informative to plot the expectation values of the Pauli matrices: $\\langle \\sigma_z \\rangle$ tracks the population difference, while $\\langle \\sigma_y \\rangle$ is sensitive to the coherences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# calculate expectation values of sigma_z and sigma_y\n",
    "sigmaz_exp = qr.op_expectation(qt.sigmaz(), rhos)\n",
    "sigmay_exp = qr.op_expectation(qt.sigmay(), rhos)\n",
    "\n",
    "# plot expectation values against time\n",
    "canvas = qr.PlotUtil(\n",
    "    x_axis_label=\"Time [a.u.]\", y_axis_label=\"Operator expectation value\", notebook=True\n",
    ")\n",
    "canvas.plot(ts, sigmaz_exp, labels=[\"<sigma_z>\"])\n",
    "canvas.plot(ts, sigmay_exp, labels=[\"<sigma_y>\"])\n",
    "canvas.show_canvas()"
   ]
  },
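  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The expectation values are given by $\\langle \\sigma \\rangle = \\mathrm{Tr}(\\sigma\\rho)$. Writing this out for a $2\\times 2$ density matrix shows which elements each one probes: $\\langle \\sigma_z \\rangle = \\rho_{00} - \\rho_{11}$ depends only on the populations, whereas $\\langle \\sigma_y \\rangle = 2\\,\\mathrm{Im}\\,\\rho_{10}$ depends only on the coherences. The NumPy sketch below verifies these identities on an example density matrix built from an arbitrary Bloch vector (the numbers are illustrative, not taken from the simulation). It computes the trace directly and makes no assumption about how `qr.op_expectation` is implemented; the helper name `expectation` is ours."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "sx = np.array([[0, 1], [1, 0]], dtype=complex)  # Pauli x\n",
    "sy = np.array([[0, -1j], [1j, 0]])  # Pauli y\n",
    "sz = np.array([[1, 0], [0, -1]], dtype=complex)  # Pauli z\n",
    "\n",
    "# A valid qubit density matrix from a Bloch vector r with |r| <= 1:\n",
    "# rho = (I + r_x sigma_x + r_y sigma_y + r_z sigma_z) / 2\n",
    "r = np.array([0.3, -0.4, 0.5])  # illustrative Bloch vector\n",
    "rho = (np.eye(2) + r[0] * sx + r[1] * sy + r[2] * sz) / 2\n",
    "\n",
    "\n",
    "def expectation(op, rho):\n",
    "    \"\"\"<op> = Tr(op rho); real for Hermitian op and rho.\"\"\"\n",
    "    return np.trace(op @ rho).real\n",
    "\n",
    "\n",
    "sz_exp = expectation(sz, rho)  # depends only on the diagonal\n",
    "sy_exp = expectation(sy, rho)  # depends only on the off-diagonal\n",
    "print(sz_exp, (rho[0, 0] - rho[1, 1]).real)\n",
    "print(sy_exp, 2 * rho[1, 0].imag)"
   ]
  },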
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great! Now you know how to simulate a closed two-level system using both the Schrödinger equation and the von Neumann equation. "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "qruise-simple",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
