{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simulating a Mølmer-Sørensen gate\n",
    "\n",
    "This tutorial shows you step by step how to simulate a Mølmer-Sørensen (MS) gate in a trapped-ion system using the `qruise-toolset`. The MS gate creates maximally entangled states by coupling two qubits through a shared motional mode. We'll first simulate an ideal gate using a rectangular pulse, and then include a finite rise time to model a more realistic, imperfect gate.\n",
    "\n",
    "The tutorial consists of the following:\n",
    "\n",
    "1. Defining the Hamiltonian (theory)\n",
    "2. Defining the external drive\n",
    "3. Defining the Hamiltonian\n",
    "4. Solving the Schrödinger equation\n",
    "5. Adding a Gaussian filter (imperfect gate)\n",
    "\n",
    "## 1. Defining the Hamiltonian (theory)\n",
    "The system we're interested in consists of two trapped-ion qubits coupled through a shared motional mode. In the interaction frame and within the Lamb-Dicke limit, its Hamiltonian, $\\hat{H}$, is given by\n",
    "\n",
    "$$\n",
    "\\hat{H}=-2 \\hbar \\eta \\Omega \\hat{J}_y\\left(\\hat{a} e^{i \\delta t} + \\hat{a}^\\dagger e^{-i \\delta t}\\right) .\n",
    "$$\n",
    "\n",
    "Here, $\\eta$ is the Lamb-Dicke parameter, $\\Omega$ is the Rabi frequency associated with the laser-ion interaction, $\\delta$ is the detuning from the sideband transition frequency, and $\\hat{a}$ and $\\hat{a}^{\\dagger}$ are the annihilation and creation operators of the shared motional mode, respectively. $\\hat{J}_y$ is the global spin-$y$ operator in the two-qubit subspace and is given by\n",
    "\n",
    "$$ \\hat{J}_y=\\frac{1}{2}(\\hat{I} \\otimes \\hat{\\sigma}_y+\\hat{\\sigma}_y \\otimes \\hat{I}) ,$$\n",
    "\n",
    "where $\\hat{I}$ is the identity operator and $\\hat{\\sigma}_y$ is the Pauli $y$-matrix acting on a single qubit. You can read more about where this Hamiltonian comes from <a href=\"https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.121.180502\" target=\"_blank\">here</a>.\n",
    "\n",
    "---\n",
    "\n",
    "**Note:** Within the Lamb-Dicke limit, the spatial extent of the ion's motion is much smaller than the wavelength of the driving laser field, which corresponds to $\\eta \\ll 1$.\n",
    "\n",
    "---\n",
    "\n",
    "**Tip:** Make sure that your environment is set up to handle `float64` precision, either by setting the environment variable `JAX_ENABLE_X64=True` or by adding the following code block to your script's preamble:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "jax.config.update(\"jax_enable_x64\", True)"
   ]
  },
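  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick, library-independent check of the operator $\\hat{J}_y$ defined above, the sketch below builds it with plain NumPy and confirms that it is Hermitian with eigenvalues $-1, 0, 0, +1$. The names ending in `_demo` are illustrative and chosen so that they don't overwrite any variables used later in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Pauli y-matrix and single-qubit identity (plain NumPy)\n",
    "sy_demo = np.array([[0.0, -1.0j], [1.0j, 0.0]])\n",
    "id_demo = np.eye(2)\n",
    "\n",
    "# global spin-y operator: J_y = (kron(I, sigma_y) + kron(sigma_y, I)) / 2\n",
    "jy_demo = 0.5 * (np.kron(id_demo, sy_demo) + np.kron(sy_demo, id_demo))\n",
    "\n",
    "# J_y is Hermitian, so eigvalsh applies; it returns the eigenvalues in ascending order\n",
    "jy_eigenvalues_demo = np.linalg.eigvalsh(jy_demo)\n",
    "print(np.round(jy_eigenvalues_demo, 12))"
   ]
  },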
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "-----"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Defining the external drive\n",
    "\n",
    "For our external drive, we'll use a piecewise constant (PWC) pulse. This type of pulse is a good choice for this exercise because it lets us control the amplitude independently at each time step, rather than being constrained by a predefined function. To start with, we'll use a rectangular pulse sampled at $N$ equally spaced time points between $0$ and $t_\\text{final}$, the total duration of our pulse, so that the time step is $dt = \\frac{t_\\text{final}}{N-1}$.\n",
    "\n",
    "Let's begin by defining the relevant system parameters. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "# Hamiltonian parameters\n",
    "Omega = 2 * jnp.pi * 8e5  # Rabi frequency (angular frequency, rad/s)\n",
    "eta = 0.1  # Lamb-Dicke parameter (Note: eta<<1 as required for the Lamb-Dicke limit)\n",
    "delta = 4 * eta * Omega  # detuning from sideband transition frequency (rad/s)\n",
    "\n",
    "# pulse parameters\n",
    "amp = 1.0  # pulse amplitude, envelope for the drive\n",
    "t0 = 0.0  # start time of simulation (s)\n",
    "tfinal = jnp.pi / (2 * eta * Omega)  # pulse length (s)\n",
    "N = 1000  # number of steps\n",
    "ts = jnp.linspace(0.0, tfinal, N)  # time values\n",
    "dt = float(ts[1] - ts[0])  # time step, dt = tfinal / (N - 1) (s)\n",
    "\n",
    "# define dictionary of pulse parameters\n",
    "parameters_rect = {\n",
    "    \"amp\": amp,  # pulse amplitude\n",
    "}"
   ]
  },
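  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These parameters are related. Since $\\delta = 4\\eta\\Omega$ and $t_\\text{final} = \\pi/(2\\eta\\Omega)$, we have $\\delta\\, t_\\text{final} = 2\\pi$, so the motional mode completes exactly one closed loop in phase space during the gate and ends up back in its initial state. The standalone sketch below recomputes these quantities with NumPy (the `_chk` names are illustrative copies, not the notebook's variables). It also checks that `np.linspace` with $N$ points gives a time step of $t_\\text{final}/(N-1)$, and that a grid of $4N-3$ points over $4\\,t_\\text{final}$, as used later for the four-gate sequence, keeps the same step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "omega_chk = 2 * np.pi * 8e5  # Rabi frequency (rad/s)\n",
    "eta_chk = 0.1  # Lamb-Dicke parameter\n",
    "delta_chk = 4 * eta_chk * omega_chk  # detuning (rad/s)\n",
    "tfinal_chk = np.pi / (2 * eta_chk * omega_chk)  # gate duration (s)\n",
    "n_chk = 1000  # number of time points\n",
    "\n",
    "# number of motional loops completed during the gate: delta * t_final / (2*pi)\n",
    "loops_chk = delta_chk * tfinal_chk / (2 * np.pi)\n",
    "\n",
    "# np.linspace includes both end points, so the spacing is t_final / (N - 1)\n",
    "dt_chk = np.linspace(0.0, tfinal_chk, n_chk)[1]\n",
    "dt_seq_chk = np.linspace(0.0, 4 * tfinal_chk, 4 * n_chk - 3)[1]  # four-gate grid\n",
    "\n",
    "print(f\"t_final = {tfinal_chk * 1e6:.3f} us, loops = {loops_chk:.3f}\")\n",
    "print(f\"dt = {dt_chk:.4e} s, dt (four-gate grid) = {dt_seq_chk:.4e} s\")"
   ]
  },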
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now need to define a rectangular shape for the pulse envelope."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define rectangular pulse\n",
    "def rect_pulse(t, params):\n",
    "    \"\"\"Rectangular pulse. Returns amplitude at every time step.\"\"\"\n",
    "    amp = params[\"amp\"]\n",
    "    return jnp.ones_like(t) * amp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll use Qruise's `PWCPulse` class to create the drive pulse. For this, we need to define an amplitude array for the envelope using `rect_pulse()`. Then we can define a parameter dictionary for the drive pulse, consisting of the start time, the time step, and the envelope, and use it to create our pulse."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import qruise.toolset as qr\n",
    "\n",
    "# define pulse amplitude array using rect_pulse\n",
    "amplitude_array = jnp.asarray(rect_pulse(ts, parameters_rect))\n",
    "\n",
    "# define parameters for piecewise constant pulse\n",
    "pwc_params = {\n",
    "    \"t0\": (t0, False),  # start time (s)\n",
    "    \"dt\": (dt, False),  # time step (s)\n",
    "    \"env\": (amplitude_array, True),  # pulse envelope (rectangle)\n",
    "}\n",
    "\n",
    "# create piecewise constant pulse using pwc_params\n",
    "rabi_drive_pwc = qr.PWCPulse(\"pwc\", pwc_params)"
   ]
  },
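  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, a piecewise constant pulse holds the envelope value `env[k]` constant over the $k$-th time slice, $[t_0 + k\\,dt,\\ t_0 + (k+1)\\,dt)$. The sketch below illustrates this idea with NumPy. `pwc_eval_demo` is a hypothetical helper, not part of the `qruise-toolset` API, and returning zero outside the pulse is an assumption made here; `PWCPulse` may handle such times differently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def pwc_eval_demo(t, t0, dt, env):\n",
    "    \"\"\"Hypothetical PWC evaluation: env[k] on [t0 + k*dt, t0 + (k+1)*dt), zero outside.\"\"\"\n",
    "    t = np.asarray(t, dtype=float)\n",
    "    env = np.asarray(env, dtype=float)\n",
    "    k = np.floor((t - t0) / dt).astype(int)  # index of the slice containing each t\n",
    "    inside = (k >= 0) & (k < env.size)\n",
    "    return np.where(inside, env[np.clip(k, 0, env.size - 1)], 0.0)\n",
    "\n",
    "\n",
    "# usage: three slices of width dt = 1 starting at t0 = 0\n",
    "pwc_demo_values = pwc_eval_demo([-0.5, 0.0, 0.5, 1.2, 2.9, 3.0], 0.0, 1.0, [0.2, 1.0, 0.7])\n",
    "print(pwc_demo_values)"
   ]
  },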
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can merge the parameter dictionaries to make them easier to work with. This way, instead of passing `parameters_rect` for the envelope and `rabi_drive_pwc.params` for the PWC pulse, we can simply pass a single dictionary, `params`, and the correct parameters are extracted automatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# merge parameter dictionaries\n",
    "params = parameters_rect | rabi_drive_pwc.params"
   ]
  },
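  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `|` operator (available for dictionaries since Python 3.9) returns a new dictionary with the keys of both operands; if a key appears in both, the value from the right-hand operand wins, and neither operand is modified. The toy example below uses made-up keys for illustration; the actual keys in `rabi_drive_pwc.params` come from the toolset (later in this notebook, the envelope entry is accessed with the key `pwc/env`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# toy parameter dictionaries with illustrative keys\n",
    "envelope_params_demo = {\"amp\": 1.0}\n",
    "pulse_params_demo = {\"pwc/t0\": 0.0, \"pwc/dt\": 1e-9, \"amp\": 0.5}\n",
    "\n",
    "# merge: keys from both, right-hand values win on duplicates (\"amp\" becomes 0.5)\n",
    "merged_demo = envelope_params_demo | pulse_params_demo\n",
    "print(merged_demo)"
   ]
  },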
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's quickly plot the pulse to see if it looks how we'd expect."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "canvas = qr.PlotUtil(\n",
    "    x_axis_label=\"t [µs]\", y_axis_label=\"Amplitude [a.u.]\", notebook=True\n",
    ")\n",
    "canvas.plot(ts / 1e-6, rabi_drive_pwc(ts, params), labels=[\"Drive (PWC) pulse\"])\n",
    "canvas.show_canvas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great, we have a pulse envelope with amplitude equal to $1.0$ starting at $t=0$ and ending at $t=t_\\text{final} = 3.125~\\mu \\text{s}$.\n",
    "\n",
    "## 3. Defining the Hamiltonian\n",
    "\n",
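    "We can now start coding our Hamiltonian. First we'll define the time-dependent envelopes that appear in it using our PWC pulse. Each envelope sets the time dependence of one of the terms in the Hamiltonian and includes the prefactor $-\\eta \\Omega$ and the oscillating term $e^{\\pm i \\delta t}$. Here we set $\\hbar = 1$, and the remaining factor of 2 in the prefactor $-2\\eta\\Omega$ is supplied by the operator `Jy` defined below, which omits the factor $\\frac{1}{2}$ in the definition of $\\hat{J}_y$."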
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define envelopes\n",
    "def env1(t, params):\n",
    "    return rabi_drive_pwc(t, params) * (-eta * Omega * jnp.exp(1j * delta * t))\n",
    "\n",
    "\n",
    "def env2(t, params):\n",
    "    return rabi_drive_pwc(t, params) * (-eta * Omega * jnp.exp(-1j * delta * t))"
   ]
  },
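  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the pulse amplitude is real, the two envelopes are complex conjugates of each other, $\\text{env}_2(t) = \\text{env}_1(t)^*$. Since the operators they multiply, $\\hat{J}_y\\hat{a}$ and $\\hat{J}_y\\hat{a}^\\dagger$, are Hermitian conjugates of each other, this is what makes the Hamiltonian Hermitian. The NumPy sketch below checks the conjugate relation on the same time grid; `env1_demo` and `env2_demo` are illustrative stand-ins that reproduce the formulas of `env1` and `env2` without the `qruise-toolset` pulse."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "omega_env = 2 * np.pi * 8e5  # Rabi frequency (rad/s)\n",
    "eta_env = 0.1  # Lamb-Dicke parameter\n",
    "delta_env = 4 * eta_env * omega_env  # detuning (rad/s)\n",
    "ts_env = np.linspace(0.0, np.pi / (2 * eta_env * omega_env), 1000)  # time grid (s)\n",
    "amp_env = np.ones_like(ts_env)  # real rectangular amplitude\n",
    "\n",
    "# NumPy stand-ins for env1 and env2 (same prefactor and phase factors)\n",
    "env1_demo = amp_env * (-eta_env * omega_env * np.exp(1j * delta_env * ts_env))\n",
    "env2_demo = amp_env * (-eta_env * omega_env * np.exp(-1j * delta_env * ts_env))\n",
    "\n",
    "# for a real amplitude, env2(t) is the complex conjugate of env1(t)\n",
    "print(np.allclose(env2_demo, np.conj(env1_demo)))  # True"
   ]
  },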
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's quickly plot the imaginary parts of these envelopes to see how they look."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "canvas = qr.PlotUtil(\n",
    "    x_axis_label=\"t [µs]\", y_axis_label=\"Amplitude [a.u.]\", notebook=True\n",
    ")\n",
    "canvas.plot(ts / 1e-6, jnp.imag(env1(ts, params)), labels=[\"Envelope 1\"])\n",
    "canvas.plot(ts / 1e-6, jnp.imag(env2(ts, params)), labels=[\"Envelope 2\"])\n",
    "canvas.show_canvas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great! We can now define the operators needed to construct the Hamiltonian. To do this, we'll first specify the Hilbert space dimensions of the system. For the qubits, this is 2 since they're two-level systems. The motional mode, being a harmonic oscillator, has an infinite-dimensional Hilbert space, but for computational purposes we'll truncate it to the lowest 8 Fock states. This is sufficient here because the mode is only weakly excited during the gate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import qutip as qt\n",
    "\n",
    "# Hilbert dimensions\n",
    "q_dim = 2  # 2-level system\n",
    "mot_dim = 8  # truncate to 8 Fock states\n",
    "\n",
    "# qubit operators\n",
    "sigma_y = qt.sigmay()  # sigma_y\n",
    "Id = qt.qeye(q_dim)  # identity\n",
    "Jy = qt.tensor(Id, sigma_y) + qt.tensor(sigma_y, Id)  # 2*J_y (no factor 1/2; supplies the 2 in -2*eta*Omega)\n",
    "\n",
    "# motional mode operators\n",
    "a = qt.destroy(mot_dim)  # a\n",
    "a_dag = qt.create(mot_dim)  # a-dagger\n",
    "\n",
    "# spin-motion interaction operators\n",
    "Ja = qt.tensor(Jy, a)\n",
    "Ja_dag = qt.tensor(Jy, a_dag)"
   ]
  },
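  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, the same operators can be built without `qutip`. The NumPy sketch below (illustrative `_np` names, chosen so they don't overwrite the operators above) makes two points explicit: the full Hilbert space has dimension $2 \\times 2 \\times 8 = 32$, and because the motional mode is truncated, the commutator $[\\hat{a}, \\hat{a}^\\dagger] = 1$ holds on every Fock state except the highest one kept. As in the cell above, `jy2_np` omits the factor $\\frac{1}{2}$, so it equals $2\\hat{J}_y$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "dim_mot_np = 8  # number of Fock states kept\n",
    "\n",
    "sy_np = np.array([[0.0, -1.0j], [1.0j, 0.0]])\n",
    "jy2_np = np.kron(np.eye(2), sy_np) + np.kron(sy_np, np.eye(2))  # equals 2 * J_y\n",
    "\n",
    "# annihilation operator: a|n> = sqrt(n)|n-1>, i.e. sqrt(n) on the first superdiagonal\n",
    "a_np = np.diag(np.sqrt(np.arange(1, dim_mot_np)), k=1)\n",
    "a_dag_np = a_np.conj().T\n",
    "\n",
    "# spin-motion interaction operators on the 32-dimensional space\n",
    "ja_np = np.kron(jy2_np, a_np)\n",
    "ja_dag_np = np.kron(jy2_np, a_dag_np)\n",
    "\n",
    "# [a, a^dagger] is the identity except in the last (truncated) Fock state\n",
    "commutator_np = a_np @ a_dag_np - a_dag_np @ a_np\n",
    "print(ja_np.shape)  # (32, 32)\n",
    "print(np.diag(commutator_np))  # 1.0 on the first seven entries, -7.0 on the last"
   ]
  },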
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can combine our drive envelopes and operators to create our time-dependent Hamiltonian."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create time-dependent Hamiltonian\n",
    "H = qr.Hamiltonian(None, [(Ja, env1), (Ja_dag, env2)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Solving the Schrödinger equation\n",
    "\n",
    "To solve the Schrödinger equation, we need to specify the equations and parameters that govern the system. In the `qruise-toolset`, we combine these using a `Problem`. To instantiate the `Problem`, we need:\n",
    "\n",
    "- the Hamiltonian (`H`)\n",
    "- the initial state of the qubits and the motional mode (`y0`)\n",
    "- the pulse parameters (`params`)\n",
    "- the time interval of the simulation (`t0` to `tfinal`).\n",
    "\n",
    "Before we can define our initial state, we need to define and label our basis states. Since we have two two-level systems (qubits) and a motional mode truncated to 8 levels, the total system has $ 2^2 \\times 8 = 32 $ basis states:\n",
    "\n",
    "- Qubits: $ |00\\rangle, |01\\rangle, |10\\rangle, |11\\rangle $\n",
    "- Motional mode: $ |0\\rangle, |1\\rangle, |2\\rangle, |3\\rangle, |4\\rangle, |5\\rangle, |6\\rangle, |7\\rangle $\n",
    "\n",
    "Thus, the full basis states of the system are:\n",
    "\n",
    "$$\n",
    "|000\\rangle, |001\\rangle, |002\\rangle, |003\\rangle, |004\\rangle, |005\\rangle, |006\\rangle, |007\\rangle,  \\\\\n",
    "|010\\rangle, |011\\rangle, |012\\rangle, |013\\rangle, |014\\rangle, |015\\rangle, |016\\rangle, |017\\rangle,  \\\\\n",
    "|100\\rangle, |101\\rangle, |102\\rangle, |103\\rangle, |104\\rangle, |105\\rangle, |106\\rangle, |107\\rangle,  \\\\\n",
    "|110\\rangle, |111\\rangle, |112\\rangle, |113\\rangle, |114\\rangle, |115\\rangle, |116\\rangle, |117\\rangle .\n",
    "$$\n",
    "\n",
    "We'll start with both qubits and the motional mode in their ground states, so our initial state, `y0`, is $ |000\\rangle $. We can then define our `Problem`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define basis states and label\n",
    "basis, labels = qr.utils.computational_basis(\n",
    "    [q_dim, q_dim, mot_dim]\n",
    ")  # Q1, Q2, motional mode\n",
    "\n",
    "# define states\n",
    "state_000 = basis[labels.index((0, 0, 0))]  # state |000>\n",
    "state_110 = basis[labels.index((1, 1, 0))]  # state |110>\n",
    "state_010 = basis[labels.index((0, 1, 0))]  # state |010>\n",
    "state_100 = basis[labels.index((1, 0, 0))]  # state |100>\n",
    "\n",
    "# define initial qubit and motional mode state (ground state)\n",
    "y0 = state_000  # initial state\n",
    "\n",
    "# define Problem\n",
    "problem = qr.Problem(H, y0, params, (t0, tfinal))"
   ]
  },
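  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Assuming that `qr.utils.computational_basis` orders the basis states in the same way as the tensor products used to build the operators (first qubit most significant, motional mode varying fastest, as in the listing above), the label $(q_1, q_2, n)$ corresponds to index $8\\,(2q_1 + q_2) + n$. The sketch below reproduces that ordering with `itertools.product`; it is a hypothetical stand-in for illustration, not the toolset's implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "dims_demo = [2, 2, 8]  # qubit 1, qubit 2, motional mode\n",
    "\n",
    "# row-major ordering: the last subsystem (motional mode) varies fastest\n",
    "labels_demo = list(itertools.product(*(range(d) for d in dims_demo)))\n",
    "\n",
    "\n",
    "def label_to_index_demo(q1, q2, n, mot_dim=8):\n",
    "    \"\"\"Index of |q1 q2 n> in the row-major (tensor-product) ordering.\"\"\"\n",
    "    return (2 * q1 + q2) * mot_dim + n\n",
    "\n",
    "\n",
    "print(len(labels_demo), labels_demo.index((1, 1, 0)), label_to_index_demo(1, 1, 0))  # 32 24 24"
   ]
  },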
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the `qruise-toolset`, you can define the type of solver you want to use and the equation you want to solve (for example, the Schrödinger equation, the master equation, or the Lindblad equation).\n",
    "\n",
    "In this case, we’ll choose to use the piecewise constant solver (`PWCSolver`) to solve the Schrödinger equation (`sepwc`). We can then calculate the wavefunction of the system at each timestamp."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# use PWCSolver to solve Schrödinger equation\n",
    "solver = qr.PWCSolver(n=N, store=True)\n",
    "solver.set_system(problem.sepwc())\n",
    "time, res = solver.evolve(*problem.problem())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "**Note:** Setting `store=True` in the `PWCSolver` ensures that the results are stored at each timestamp during the simulation.\n",
    "\n",
    "---\n",
    "\n",
    "We can then calculate the population of each basis state (the expectation value of its projector) at each timestamp."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qruise.toolset.utils import get_population\n",
    "\n",
    "# calculate expectation value of each state\n",
    "pop = get_population(res)"
   ]
  },
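  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the role of the piecewise constant solver more concrete, here is a minimal NumPy sketch of the same idea (`ms_final_state_sketch` is an illustrative name). It is not the `PWCSolver` implementation: the Hamiltonian is held constant over each of $N$ equal time slices (sampled at the slice midpoint, a choice made for this sketch), and the state is advanced by $e^{-i\\hat{H}_k\\,dt}$ with $\\hbar = 1$. Each step's exponential is computed from an eigendecomposition, which is valid because $\\hat{H}_k$ is Hermitian. The final populations of $|000\\rangle$ and $|110\\rangle$ should both come out close to 0.5, and the population is simply $|\\langle k|\\psi\\rangle|^2$ for each basis state $|k\\rangle$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def ms_final_state_sketch(n_steps=1000, mot_dim=8):\n",
    "    \"\"\"Propagate |000> under the MS Hamiltonian with piecewise constant (midpoint) steps.\"\"\"\n",
    "    omega = 2 * np.pi * 8e5  # Rabi frequency (rad/s)\n",
    "    eta = 0.1  # Lamb-Dicke parameter\n",
    "    delta = 4 * eta * omega  # detuning (rad/s)\n",
    "    t_final = np.pi / (2 * eta * omega)  # gate duration (s)\n",
    "    dt = t_final / n_steps  # n_steps slices of equal width\n",
    "\n",
    "    sy = np.array([[0.0, -1.0j], [1.0j, 0.0]])\n",
    "    jy2 = np.kron(np.eye(2), sy) + np.kron(sy, np.eye(2))  # 2 * J_y, as in the notebook\n",
    "    a = np.diag(np.sqrt(np.arange(1, mot_dim)), k=1)  # truncated annihilation operator\n",
    "    ja = np.kron(jy2, a)\n",
    "    ja_dag = ja.conj().T\n",
    "\n",
    "    psi = np.zeros(4 * mot_dim, dtype=complex)\n",
    "    psi[0] = 1.0  # |000>\n",
    "    for k in range(n_steps):\n",
    "        t_mid = (k + 0.5) * dt  # sample the Hamiltonian at the slice midpoint\n",
    "        g = -eta * omega * np.exp(1j * delta * t_mid)  # same prefactor as env1\n",
    "        h_k = g * ja + np.conj(g) * ja_dag  # Hermitian by construction\n",
    "        evals, evecs = np.linalg.eigh(h_k)\n",
    "        # exp(-i H dt) = V exp(-i E dt) V^dagger for Hermitian H (hbar = 1)\n",
    "        psi = evecs @ (np.exp(-1j * evals * dt) * (evecs.conj().T @ psi))\n",
    "    return psi\n",
    "\n",
    "\n",
    "psi_final_sketch = ms_final_state_sketch()\n",
    "pop_final_sketch = np.abs(psi_final_sketch) ** 2  # populations |<k|psi>|^2\n",
    "print(pop_final_sketch[0], pop_final_sketch[24])  # |000> and |110> (index 24)"
   ]
  },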
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To keep the plot clear and readable, we’ll exclude basis states whose population remains zero for the entire gate duration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Find states that are zero across all time points\n",
    "zero_mask = np.all(pop == 0, axis=0)\n",
    "\n",
    "# keep only non-zero states\n",
    "nonzero_mask = ~zero_mask\n",
    "filtered_pop = pop[:, nonzero_mask]\n",
    "filtered_labels = [str(label) for i, label in enumerate(labels) if not zero_mask[i]]\n",
    "\n",
    "# plot non-zero states\n",
    "canvas.init_canvas(x_axis_label=\"t [µs]\", y_axis_label=\"Population\")\n",
    "canvas.plot(ts / 1e-6, filtered_pop, labels=filtered_labels)\n",
    "canvas.canvas.legend.spacing = -4\n",
    "canvas.show_canvas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As expected, by the end of the gate the initial state has been transformed into an equal superposition of $|000\\rangle$ and $|110\\rangle$, each with 50% population. States in which the motional mode is excited, such as $|101\\rangle$, are populated during the gate, but their population returns to zero at $t = t_\\text{final}$.\n",
    "\n",
    "To check that the gate is perfect, we can calculate the state fidelity, $\\mathcal{F}$, which is given by\n",
    "\n",
    "$$\n",
    "\\mathcal{F}=|\\langle \\psi(t=t_\\text{final})|\\psi_t \\rangle|^2,\n",
    "$$\n",
    "\n",
    "where $|\\psi(t=t_\\text{final})\\rangle$ is the wavefunction of the system at the end time of the pulse, and $|\\psi_t\\rangle$ is the target (desired) wavefunction."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We just need to define the desired final states and compare them with the simulated ones using a `fidelity()` function. To do so, we'll consider all four computational initial states of the qubits (with the motional mode in $|0\\rangle$), compute the fidelity of each with its target final state, and average the results. The targets follow from the MS gate truth table, where $|0\\rangle \\equiv |g\\rangle$ and $|1\\rangle \\equiv |e\\rangle$:\n",
    "\n",
    "$$\n",
    "|00\\rangle \\rightarrow (|00\\rangle + i |11\\rangle)/ \\sqrt{2} \\\\\n",
    "|10\\rangle \\rightarrow (|10\\rangle - i |01\\rangle)/ \\sqrt{2} \\\\\n",
    "|01\\rangle \\rightarrow (|01\\rangle - i |10\\rangle)/ \\sqrt{2} \\\\\n",
    "|11\\rangle \\rightarrow (|11\\rangle + i |00\\rangle)/ \\sqrt{2}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pprint import pprint\n",
    "\n",
    "\n",
    "initial_states = [state_000, state_100, state_010, state_110]  # initial states\n",
    "\n",
    "# target final states (motional mode back in |0>)\n",
    "yf_000 = (state_000 + 1j * state_110) / jnp.sqrt(2)\n",
    "yf_100 = (state_100 - 1j * state_010) / jnp.sqrt(2)\n",
    "yf_010 = (state_010 - 1j * state_100) / jnp.sqrt(2)\n",
    "yf_110 = (state_110 + 1j * state_000) / jnp.sqrt(2)\n",
    "\n",
    "final_states = [yf_000, yf_100, yf_010, yf_110]  # final states\n",
    "\n",
    "\n",
    "def get_fidelity_over_sates(H):\n",
    "    \"\"\"Calculates the average fidelity of the final states over the initial states.\"\"\"\n",
    "    states_fidelity = []  # list to store fidelity results\n",
    "\n",
    "    # define fidelity function\n",
    "    def fidelity(x, y):\n",
    "        \"\"\"Returns the fidelity (|<x|y>|^2) of two wavefunctions x and y.\"\"\"\n",
    "        o = jnp.matmul(\n",
    "            x.conj().T, y\n",
    "        )  # Calculates the inner product (overlap) of the two wavefunctions\n",
    "        return jnp.real(o.conj() * o)  # |<x|y>|^2 (real by construction; jnp.real drops the zero imaginary part)\n",
    "\n",
    "    for i, state in enumerate(initial_states):\n",
    "        # define the problem for the different initial states and solve them\n",
    "        problem = qr.Problem(H, state, params, (t0, tfinal))\n",
    "        solver = qr.PWCSolver(n=N, store=True)\n",
    "        solver.set_system(problem.sepwc())\n",
    "        time, res = solver.evolve(*problem.problem())\n",
    "\n",
    "        state_fidelity = fidelity(res[-1, :], final_states[i])\n",
    "        states_fidelity.append(state_fidelity)  # append fidelity result\n",
    "\n",
    "    return sum(states_fidelity) / len(\n",
    "        states_fidelity\n",
    "    )  # return average fidelity of all states"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "states_fidelity_average = get_fidelity_over_sates(H)\n",
    "pprint(f\"Ideal gate: {states_fidelity_average*100:.3f} %\")"
   ]
  },
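  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The truth table above can also be checked independently of the simulation. Up to a global phase, the ideal MS gate with these parameters is expected to act on the two qubits as $\\hat{U} = \\exp\\left(-i \\frac{\\pi}{4} \\hat{\\sigma}_y \\otimes \\hat{\\sigma}_y\\right) = \\frac{1}{\\sqrt{2}}\\left(\\hat{I} - i\\, \\hat{\\sigma}_y \\otimes \\hat{\\sigma}_y\\right)$, where the second form uses $(\\hat{\\sigma}_y \\otimes \\hat{\\sigma}_y)^2 = \\hat{I}$. The NumPy sketch below builds this matrix (illustrative `_tt` names) and evaluates $|\\langle x | y \\rangle|^2$ for each row of the truth table; because this fidelity ignores global phases, the overall phase factor doesn't matter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "sy_tt = np.array([[0.0, -1.0j], [1.0j, 0.0]])\n",
    "yy_tt = np.kron(sy_tt, sy_tt)  # kron(sigma_y, sigma_y) on the two qubits\n",
    "\n",
    "# U = exp(-i pi/4 YY) = (I - i YY) / sqrt(2), because (YY)^2 = I\n",
    "u_ms_tt = (np.eye(4) - 1j * yy_tt) / np.sqrt(2)\n",
    "\n",
    "# two-qubit basis |00>, |01>, |10>, |11> (first qubit is the left index)\n",
    "ket00_tt, ket01_tt, ket10_tt, ket11_tt = np.eye(4)\n",
    "\n",
    "# rows of the truth table: (initial state, target state)\n",
    "truth_table_tt = [\n",
    "    (ket00_tt, (ket00_tt + 1j * ket11_tt) / np.sqrt(2)),\n",
    "    (ket10_tt, (ket10_tt - 1j * ket01_tt) / np.sqrt(2)),\n",
    "    (ket01_tt, (ket01_tt - 1j * ket10_tt) / np.sqrt(2)),\n",
    "    (ket11_tt, (ket11_tt + 1j * ket00_tt) / np.sqrt(2)),\n",
    "]\n",
    "\n",
    "# fidelity |<target|U psi0>|^2 for every row (np.vdot conjugates its first argument)\n",
    "fidelities_tt = [abs(np.vdot(target, u_ms_tt @ psi0)) ** 2 for psi0, target in truth_table_tt]\n",
    "print(np.round(fidelities_tt, 12))  # [1. 1. 1. 1.]"
   ]
  },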
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can simulate four consecutive identical pulses to illustrate the full coherent evolution of the system, showing how the population dynamics evolve through entanglement and eventually return to the initial state.\n",
    "\n",
    "We can do this simply by updating the pulse parameters and the `Problem` and again solving it using the `PWCSolver`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "sequence_time = 4 * tfinal  # 4 pulses\n",
    "\n",
    "# define time and amplitude arrays (4 * (N - 1) + 1 = 4N-3 time points keep the same dt)\n",
    "N_seq = 4 * N - 3  # number of time steps for full sequence\n",
    "sequence_ts = jnp.linspace(0.0, sequence_time, N_seq)  # time array\n",
    "amplitude_array_seq = jnp.asarray(\n",
    "    rect_pulse(sequence_ts, parameters_rect)\n",
    ")  # amplitude array\n",
    "\n",
    "# update pulse parameters for 4-gate sequence (use deepcopy to avoid modifying original params)\n",
    "params_full_seq = deepcopy(params)\n",
    "params_full_seq[\"pwc/env\"] = amplitude_array_seq\n",
    "\n",
    "# define and solve Problem for 4-gate sequence\n",
    "problem_seq = qr.Problem(H, y0, params_full_seq, (0.0, sequence_time))\n",
    "solver_seq = qr.PWCSolver(n=N_seq, store=True)\n",
    "solver_seq.set_system(problem_seq.sepwc())\n",
    "time_seq, res_seq = solver_seq.evolve(*problem_seq.problem())\n",
    "\n",
    "# calculate expectation value of each state for 4-gate sequence\n",
    "pop_seq = get_population(res_seq)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can then plot the state populations to see how they look for the four-pulse cycle."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# keep only non-zero states\n",
    "filtered_pop_seq = pop_seq[:, nonzero_mask]\n",
    "\n",
    "# plot populations for MS gate (4 identical PWC pulses)\n",
    "canvas.init_canvas(x_axis_label=\"t [µs]\", y_axis_label=\"Population\")\n",
    "canvas.plot(sequence_ts / 1e-6, filtered_pop_seq, labels=filtered_labels)\n",
    "canvas.canvas.legend.spacing = -4\n",
    "canvas.show_canvas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can see the populations evolve coherently, with the system returning to the $|000\\rangle$ state after four pulses, confirming the cyclic behaviour of the gate."
   ]
  },
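  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The return to $|000\\rangle$ follows from the qubit unitary $\\hat{U} = \\exp\\left(-i\\frac{\\pi}{4}\\hat{\\sigma}_y \\otimes \\hat{\\sigma}_y\\right)$, which is the expected action of one ideal gate up to a global phase. Two gates give $\\hat{U}^2 = -i\\,\\hat{\\sigma}_y \\otimes \\hat{\\sigma}_y$, which maps $|00\\rangle$ to $i|11\\rangle$ (all population in $|110\\rangle$ halfway through the sequence), and four gates give $\\hat{U}^4 = -\\hat{I}$, the identity up to a global phase. A short NumPy check, with illustrative `_cyc` names:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "sy_cyc = np.array([[0.0, -1.0j], [1.0j, 0.0]])\n",
    "yy_cyc = np.kron(sy_cyc, sy_cyc)  # kron(sigma_y, sigma_y)\n",
    "u_ms_cyc = (np.eye(4) - 1j * yy_cyc) / np.sqrt(2)  # exp(-i pi/4 YY)\n",
    "\n",
    "u2_cyc = np.linalg.matrix_power(u_ms_cyc, 2)  # two gates\n",
    "u4_cyc = np.linalg.matrix_power(u_ms_cyc, 4)  # four gates\n",
    "\n",
    "print(np.allclose(u2_cyc, -1j * yy_cyc))  # True\n",
    "print(np.allclose(u4_cyc, -np.eye(4)))  # True\n",
    "print(abs((u2_cyc @ np.eye(4)[0])[3]) ** 2)  # population of |11> after two gates"
   ]
  },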
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Adding a Gaussian filter (imperfect gate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In reality, inherent limitations of the control electronics prevent us from generating a perfect rectangular pulse. To account for this, we can apply a Gaussian filter to simulate a finite rise time.\n",
    "\n",
    "Here, we'll use a fourth-order filter with a 150 ns rise time. We can then compute the filter coefficients based on these parameters.\n",
    "\n",
    "\n",
    "---\n",
    "\n",
    "**Note:** You can read more about simulating a Gaussian rise time in our [rise time tutorial](https://docs.qruise.com/latest/notebooks/control_stack/gaussian_rise_time/). \n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qruise.toolset import TransferFunc\n",
    "from qruise.toolset.utils import compute_rise_time_gaussian_coeffs\n",
    "\n",
    "rise_time = 150e-9  # rise time (s)\n",
    "\n",
    "# compute Gaussian filter coefficients to simulate finite rise time\n",
    "abs_coeffs = compute_rise_time_gaussian_coeffs(rise_time, n_order=4)\n",
    "# filter coefficients (named b_filt/a_filt so the annihilation operator `a` is not overwritten)\n",
    "b_filt = jnp.array([1.0])\n",
    "a_filt = jnp.array(abs_coeffs)\n",
    "\n",
    "# define scalar version of PWC pulse for time differentiation\n",
    "scalar_pwc = lambda t, p: rabi_drive_pwc(jnp.array([t]), p)[0]\n",
    "dsource = jax.grad(scalar_pwc, argnums=0)  # derivative of the source pulse with respect to t\n",
    "\n",
    "# define transfer function that adds Gaussian rise time filter to original PWC pulse\n",
    "pwc_with_rise = TransferFunc(\n",
    "    \"tf_pwc\",\n",
    "    {\n",
    "        \"b\": (b_filt, False),\n",
    "        \"a\": (a_filt, False),\n",
    "        \"t0\": (0.0, False),\n",
    "        \"t1\": (tfinal, False),\n",
    "    },\n",
    "    rabi_drive_pwc,\n",
    "    dsource=dsource,\n",
    ")"
   ]
  },
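  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build some intuition for what a Gaussian filter does to a rectangular pulse, the sketch below applies an ideal (non-causal) Gaussian filter to a step switched on at $t = 0$. This is a stand-in, not the toolset's fourth-order filter from `compute_rise_time_gaussian_coeffs`, whose response may differ in shape and delay, and it assumes that the rise time is defined as the 10%-90% rise time. A Gaussian impulse response with standard deviation $\\sigma$ turns a step into $\\frac{1}{2}\\left[1 + \\operatorname{erf}\\left(t / (\\sqrt{2}\\,\\sigma)\\right)\\right]$, whose 10%-90% rise time is $2 z_{0.9}\\,\\sigma \\approx 2.563\\,\\sigma$, where $z_{0.9} \\approx 1.2816$ is the 90th percentile of the standard normal distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "rise_time_demo = 150e-9  # target 10%-90% rise time (s)\n",
    "z90 = 1.2815515655446004  # 90th percentile of the standard normal distribution\n",
    "sigma_demo = rise_time_demo / (2 * z90)  # Gaussian width giving that rise time (s)\n",
    "\n",
    "# step response of an ideal (non-causal) Gaussian filter to a step switched on at t = 0\n",
    "t_demo = np.linspace(-400e-9, 400e-9, 8001)  # 0.1 ns resolution\n",
    "step_response_demo = np.array(\n",
    "    [0.5 * (1.0 + math.erf(t / (math.sqrt(2.0) * sigma_demo))) for t in t_demo]\n",
    ")\n",
    "\n",
    "# measure the 10%-90% rise time from the sampled curve (first crossing of each level)\n",
    "t10_demo = t_demo[np.argmax(step_response_demo >= 0.1)]\n",
    "t90_demo = t_demo[np.argmax(step_response_demo >= 0.9)]\n",
    "print(f\"sigma = {sigma_demo * 1e9:.1f} ns, rise time = {(t90_demo - t10_demo) * 1e9:.1f} ns\")"
   ]
  },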
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we did earlier, we can merge the parameter dictionaries to make them easier to work with. To do this, we just need to add `pwc_with_rise.params` to `params`.\n",
    "\n",
    "We can then evaluate the pulse with the Gaussian rise time at every time in `ts`, using `jax.vmap` to map `pwc_with_rise` over the individual time points."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# merge new parameters with existing ones\n",
    "params = params | pwc_with_rise.params\n",
    "\n",
    "# evaluate PWC pulse with rise time\n",
    "pwc_with_rise_vec = jax.vmap(pwc_with_rise, (0, None))\n",
    "pwc_with_rise_result = pwc_with_rise_vec(ts, params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's quickly plot the pulses without and with rise time on the same axes to see the difference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "canvas = qr.PlotUtil(\n",
    "    x_axis_label=\"t [µs]\",\n",
    "    y_axis_label=\"Amplitude [a.u.]\",\n",
    "    notebook=True,\n",
    "    x_range=[0, tfinal / 1e-6],\n",
    ")\n",
    "canvas.plot(ts / 1e-6, rabi_drive_pwc(ts, params), labels=[\"Drive pulse\"])\n",
    "canvas.plot(\n",
    "    ts / 1e-6,\n",
    "    pwc_with_rise_result,\n",
    "    labels=[\"Drive pulse with rise time\"],\n",
    "    line_dash=\"dashed\",\n",
    ")\n",
    "canvas.show_canvas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great, we see that without the rise time, the pulse envelope is flat with amplitude $1.0$, starting at $t=0$ and ending at $t = t_\\text{final}$. With the rise time included, the pulse no longer switches on instantaneously; instead, its amplitude increases smoothly over a time of the order of the 150 ns rise time.\n",
    "\n",
    "We can then redefine our Hamiltonian and `Problem` using the drive pulse with rise time, and simulate the resulting state population dynamics as before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define envelopes\n",
    "def env1_rise(t, params):\n",
    "    return pwc_with_rise(t, params) * (-eta * Omega * jnp.exp(1j * delta * t))\n",
    "\n",
    "\n",
    "def env2_rise(t, params):\n",
    "    return pwc_with_rise(t, params) * (-eta * Omega * jnp.exp(-1j * delta * t))\n",
    "\n",
    "\n",
    "# create time-dependent Hamiltonian\n",
    "H = qr.Hamiltonian(None, [(Ja, env1_rise), (Ja_dag, env2_rise)])\n",
    "\n",
    "# define Problem\n",
    "problem = qr.Problem(H, y0, params, (0.0, tfinal))\n",
    "\n",
    "# use PWCSolver to solve Schrödinger equation\n",
    "solver = qr.PWCSolver(n=N, store=True)\n",
    "solver.set_system(problem.sepwc())\n",
    "time, res = solver.evolve(*problem.problem())\n",
    "\n",
    "# calculate expectation value of each state\n",
    "pop = get_population(res)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll again plot only the states with non-zero population for readability, and calculate the state fidelity at the end of the gate duration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find states that are zero across all time points\n",
    "zero_mask = np.all(pop == 0, axis=0)\n",
    "\n",
    "# keep only non-zero states\n",
    "nonzero_mask = ~zero_mask\n",
    "filtered_pop = pop[:, nonzero_mask]\n",
    "filtered_labels = [str(label) for i, label in enumerate(labels) if not zero_mask[i]]\n",
    "\n",
    "# plot non-zero states\n",
    "canvas.init_canvas(x_axis_label=\"t [µs]\", y_axis_label=\"Population\")\n",
    "canvas.plot(ts / 1e-6, filtered_pop, labels=filtered_labels)\n",
    "canvas.canvas.legend.spacing = -4\n",
    "canvas.show_canvas()\n",
    "\n",
    "# calculate state fidelity\n",
    "states_fidelity_average = get_fidelity_over_sates(H)\n",
    "pprint(f\"Non-ideal gate: {states_fidelity_average*100:.3f} %\")"
   ]
  },
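  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The fidelity drop can be reproduced qualitatively with a standalone NumPy sketch that propagates the state with piecewise constant steps (`ms_fidelity_sketch` and `ramp_amplitude` are illustrative names). Here the rectangular amplitude is replaced by the smooth turn-on $A(t) = \\frac{1}{2}\\left[1 + \\operatorname{erf}\\left((t - t_\\text{d}) / (\\sqrt{2}\\,\\sigma)\\right)\\right]$, with $\\sigma$ chosen for a 150 ns 10%-90% rise time; the delay $t_\\text{d} = 2\\sigma$ is an arbitrary choice so that the pulse starts close to zero. Because this ideal Gaussian ramp is only a stand-in for the toolset's fourth-order filter, the numbers won't match the fidelity printed above. The point is qualitative: the missing pulse area at the start means the motional loop no longer closes at $t_\\text{final}$, leaving the qubits slightly entangled with the motional mode, which lowers the fidelity of the $|000\\rangle \\rightarrow (|000\\rangle + i|110\\rangle)/\\sqrt{2}$ transformation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def ms_fidelity_sketch(amplitude, n_steps=1000, mot_dim=8):\n",
    "    \"\"\"Fidelity of |000> -> (|000> + i|110>)/sqrt(2) for a real amplitude function A(t).\"\"\"\n",
    "    omega = 2 * np.pi * 8e5  # Rabi frequency (rad/s)\n",
    "    eta = 0.1  # Lamb-Dicke parameter\n",
    "    delta = 4 * eta * omega  # detuning (rad/s)\n",
    "    t_final = np.pi / (2 * eta * omega)  # gate duration (s)\n",
    "    dt = t_final / n_steps\n",
    "\n",
    "    sy = np.array([[0.0, -1.0j], [1.0j, 0.0]])\n",
    "    jy2 = np.kron(np.eye(2), sy) + np.kron(sy, np.eye(2))  # 2 * J_y, as in the notebook\n",
    "    ja = np.kron(jy2, np.diag(np.sqrt(np.arange(1, mot_dim)), k=1))  # (2 * J_y) a\n",
    "\n",
    "    psi = np.zeros(4 * mot_dim, dtype=complex)\n",
    "    psi[0] = 1.0  # |000>\n",
    "    for k in range(n_steps):\n",
    "        t_mid = (k + 0.5) * dt  # Hamiltonian held constant over each slice\n",
    "        g = -eta * omega * amplitude(t_mid) * np.exp(1j * delta * t_mid)\n",
    "        evals, evecs = np.linalg.eigh(g * ja + np.conj(g) * ja.conj().T)\n",
    "        psi = evecs @ (np.exp(-1j * evals * dt) * (evecs.conj().T @ psi))\n",
    "\n",
    "    target = np.zeros_like(psi)\n",
    "    target[0] = 1 / np.sqrt(2)  # |000>\n",
    "    target[3 * mot_dim] = 1j / np.sqrt(2)  # |110> has index (2*1 + 1) * mot_dim\n",
    "    return abs(np.vdot(target, psi)) ** 2\n",
    "\n",
    "\n",
    "sigma_ramp = 150e-9 / (2 * 1.2815515655446004)  # Gaussian width for a 150 ns rise time (s)\n",
    "t_delay = 2 * sigma_ramp  # arbitrary delay so that the ramp starts close to zero (s)\n",
    "\n",
    "\n",
    "def ramp_amplitude(t):\n",
    "    \"\"\"Smooth turn-on: delayed step response of an ideal Gaussian filter.\"\"\"\n",
    "    return 0.5 * (1.0 + math.erf((t - t_delay) / (math.sqrt(2.0) * sigma_ramp)))\n",
    "\n",
    "\n",
    "fid_ideal_sketch = ms_fidelity_sketch(lambda t: 1.0)\n",
    "fid_ramp_sketch = ms_fidelity_sketch(ramp_amplitude)\n",
    "print(f\"ideal: {fid_ideal_sketch * 100:.3f} %, with ramp: {fid_ramp_sketch * 100:.3f} %\")"
   ]
  },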
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference in the plot is somewhat subtle, but if you look carefully (try zooming in!), you'll notice that the $|101\\rangle$ state doesn't return to zero and the populations of the $|000\\rangle$ and $|110\\rangle$ states are less than 50%. The fidelity makes the effect clearer: it has dropped to just 97.947%.\n",
    "\n",
    "We can again simulate the dynamics for four MS gates to see more clearly how the rise time affects the system dynamics. This follows the same process as for the ideal gate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "# update pulse parameters for 4-gate sequence\n",
    "params_full_seq = deepcopy(params)\n",
    "params_full_seq[\"pwc/env\"] = amplitude_array_seq\n",
    "\n",
    "# define and solve Problem for 4-gate sequence\n",
    "problem_seq = qr.Problem(H, y0, params_full_seq, (0.0, sequence_time))\n",
    "solver_seq = qr.PWCSolver(\n",
    "    n=N_seq,\n",
    "    store=True,\n",
    ")\n",
    "solver_seq.set_system(problem_seq.sepwc())\n",
    "time_seq, res_seq = solver_seq.evolve(*problem_seq.problem())\n",
    "\n",
    "# calculate expectation value of each state for 4-gate sequence\n",
    "pop_seq = get_population(res_seq)\n",
    "\n",
    "# keep only non-zero states\n",
    "filtered_pop_seq = pop_seq[:, nonzero_mask]\n",
    "\n",
    "# plot populations for MS gate (4 identical PWC pulses)\n",
    "canvas.init_canvas(x_axis_label=\"t [µs]\", y_axis_label=\"Population\")\n",
    "canvas.plot(sequence_ts / 1e-6, filtered_pop_seq, labels=filtered_labels)\n",
    "canvas.canvas.legend.spacing = -4\n",
    "canvas.show_canvas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the population of the $|000\\rangle$ state does not return to 1, with residual population in the $|101\\rangle$ state, demonstrating how the rise time negatively impacts the gate operation.\n",
    "\n",
    "Great, you've just simulated a Mølmer-Sørensen gate and explored how rise time affects the state fidelity."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "qruise-simple",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
