{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Rise time\n",
    "\n",
    "This notebook demonstrates how to simulate the effect of rise time using the `qruise-toolset`. To achieve this, we define a Heaviside (step) function as the source of the signal chain and convolve it with a Gaussian filter to model the smoothing caused by a finite rise time.\n",
    "\n",
    "The tutorial consists of the following:\n",
    "\n",
    "1. Motivation\n",
    "2. Defining parameters\n",
    "3. Defining the Gaussian filter\n",
    "4. Constructing the signal chain\n",
    "5. Visualising & calculating the rise time\n",
    "\n",
    "## 1. Motivation\n",
    "\n",
    "An inherent trait of electronic devices is that their output cannot change instantaneously when the input signal changes; instead, the transition takes a finite amount of time. The duration of this transition is called the rise time. In high-speed electronics, and particularly in quantum control, minimising the rise time is crucial for generating pulses with minimal distortion.\n",
    "\n",
    "In control electronics specification sheets, the rise time is defined as the time it takes a signal to transition from a lower to an upper percentage of its total amplitude. These are typically 20% and 80%, so the transition spans a fraction of $0.6$ of the amplitude. The range can vary (10% to 90% is also common), so it's worth checking which definition your datasheet uses. In the `qruise-toolset`, rise time is modelled by convolving the pulse with a Gaussian filter (`GaussianRiseTime`), where the fraction determines the filter width.\n",
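    "\n",
    "For a step smoothed by a unit-area Gaussian of standard deviation $\\sigma$, the output is the Gaussian cumulative distribution function, so the time taken to go from $(1-f)/2$ to $(1+f)/2$ of the final amplitude is $t_r = 2\\sqrt{2}\\,\\sigma\\,\\mathrm{erf}^{-1}(f)$, where $f$ is the fraction. The sketch below inverts this relation to obtain $\\sigma$ from a given rise time. It assumes this is how `GaussianRiseTime` calibrates its width, which we have not checked against the `qruise-toolset` source, and `sigma_from_rise_time` is a hypothetical helper rather than part of the toolset.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "\n",
    "jax.config.update('jax_enable_x64', True)\n",
    "import jax.numpy as jnp\n",
    "from jax.scipy.special import erfinv\n",
    "\n",
    "\n",
    "def sigma_from_rise_time(rise_time, fraction=0.6):\n",
    "    # hypothetical helper: width (s) of a unit-area Gaussian whose step response\n",
    "    # rises from (1 - fraction)/2 to (1 + fraction)/2 of its final value in rise_time\n",
    "    return rise_time / (2.0 * jnp.sqrt(2.0) * erfinv(fraction))\n",
    "\n",
    "\n",
    "sigma = sigma_from_rise_time(0.1e-9)  # 20%-80% rise time of 0.1 ns\n",
    "print(f'sigma = {float(sigma) * 1e9:.4f} ns')\n",
    "```\n",
    "\n",
    "With $f = 0.6$, this relation gives $t_r \\approx 1.68\\,\\sigma$, so a 0.1 ns rise time corresponds to $\\sigma \\approx 0.059$ ns under this assumption.\n",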
    "\n",
    "---\n",
    "\n",
    "**Tip:** Make sure that your environment is correctly set up to handle `float64` precision, either by setting the environment variable `JAX_ENABLE_X64=True` or by adding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "jax.config.update(\"jax_enable_x64\", True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "to your scripts' preamble.\n",
    "\n",
    "---\n",
    "\n",
    "## 2. Defining parameters\n",
    "\n",
    "We start by defining the time parameters for our pulse. These include the start and end times of the pulse ($t_0$ and $t_\\text{final}$, respectively, in seconds) and the grid size, which specifies the number of evenly spaced points used to discretise the time interval."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "t0 = 0.0  # start time of pulse (s)\n",
    "tfinal = 1e-9  # end (final) time of pulse (s)\n",
    "grid_size = int(1e3)  # simulation grid size (number of time points)\n",
    "t_span = jnp.linspace(t0, tfinal, grid_size)  # defines time span array"
   ]
  },
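  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `jnp.linspace` includes both end points, neighbouring samples are separated by $\\Delta t = (t_\\text{final} - t_0)/(N - 1)$, where $N$ is the grid size. This spacing limits how precisely the rise time can be read off the sampled signal later on. A quick, self-contained check using the same values:\n",
    "\n",
    "```python\n",
    "import jax\n",
    "\n",
    "jax.config.update('jax_enable_x64', True)\n",
    "import jax.numpy as jnp\n",
    "\n",
    "t0, tfinal, grid_size = 0.0, 1e-9, 1000\n",
    "t_span = jnp.linspace(t0, tfinal, grid_size)\n",
    "\n",
    "dt = (tfinal - t0) / (grid_size - 1)  # sample spacing (s)\n",
    "print(f'dt = {dt * 1e12:.3f} ps')\n",
    "# every gap between neighbouring samples equals dt\n",
    "print(bool(jnp.allclose(jnp.diff(t_span), dt, rtol=1e-9, atol=0.0)))\n",
    "```\n",
    "\n",
    "The spacing comes out at about 1 ps, roughly 1% of the 0.1 ns rise time used later, so grid effects on the measured rise time should be small."
   ]
  },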
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we need to define the parameters of the Heaviside (step) function, $H(t)$, which will act as the source in the signal chain. Within the `qruise-toolset`, this is defined as:\n",
    "\n",
    "$$\n",
    "H(t) =\n",
    "\\begin{cases}\n",
    "0, & t < t_{\\text{offset}}, \\\\\n",
    "1, & t \\geq t_{\\text{offset}},\n",
    "\\end{cases}\n",
    "$$\n",
    "\n",
    "where $t_\\text{offset}$ is the time at which the step occurs. In this tutorial, we'll place the step in the middle of the simulation window, i.e., $t_\\text{offset}=\\frac{t_0+t_\\text{final}}{2}$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "heaviside_offset = (t0 + tfinal) / 2"
   ]
  },
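  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the piecewise definition of $H(t)$ above can be written in plain JAX with `jnp.heaviside`, whose second argument sets the value exactly at $t = t_\\text{offset}$ (here $1$, matching the $t \\geq t_\\text{offset}$ branch). This is a sketch of the mathematical definition only, not of the toolset's `Heaviside` component, and `step` is a hypothetical name.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "\n",
    "jax.config.update('jax_enable_x64', True)\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def step(t, t_offset):\n",
    "    # 0 for t < t_offset, 1 for t >= t_offset\n",
    "    return jnp.heaviside(t - t_offset, 1.0)\n",
    "\n",
    "\n",
    "t_span = jnp.linspace(0.0, 1e-9, 1000)\n",
    "heaviside_offset = (0.0 + 1e-9) / 2  # middle of the simulation window\n",
    "h = step(t_span, heaviside_offset)\n",
    "print(int(h.sum()), 'of', h.size, 'samples lie at or after the step')\n",
    "```"
   ]
  },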
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Defining the Gaussian filter\n",
    "\n",
    "In this example, we'll convolve the Heaviside pulse with a Gaussian filter, $G(t)$, to simulate the effect of rise time on the pulse. The resulting pulse, $y(t)$, is given by\n",
    "\n",
    "$$\n",
    "y(t) = (G * H)(t) = \\int_{-\\infty}^{\\infty} G(\\tau;a,\\mu,\\sigma)\\,H(t-\\tau)\\,\\mathrm{d}\\tau,\n",
    "$$\n",
    "\n",
    "where the Gaussian filter is described by\n",
    "\n",
    "$$ G(t;a,\\mu,\\sigma) = \\frac{a}{\\sigma\\sqrt{2\\pi}} \\mathrm{exp}\\left[-\\frac{(t - \\mu)^2}{2\\sigma^2}\\right].$$\n",
    "\n",
    "Here, $a$ is a scalar that scales the amplitude (equivalently, the area) of the filter, $\\sigma$ is its standard deviation, which sets the filter width, and $\\mu$ is the time at which the filter peaks.\n",
    " \n",
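    "To make the convolution concrete, the sketch below evaluates $y(t)$ numerically for a unit-area filter ($a = 1$) centred at $\\mu = 0$ and compares it with the closed-form result $y(t) = \\tfrac{1}{2}\\left[1 + \\mathrm{erf}\\left(\\frac{t - t_\\text{offset}}{\\sqrt{2}\\,\\sigma}\\right)\\right]$. The width `sigma` is an assumed illustrative value, and the discrete sum only approximates the convolution integral; it is not meant to reflect how the `qruise-toolset` implements the filter internally.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "\n",
    "jax.config.update('jax_enable_x64', True)\n",
    "import jax.numpy as jnp\n",
    "from jax.scipy.special import erf\n",
    "\n",
    "t = jnp.linspace(0.0, 1e-9, 1000)\n",
    "dt = t[1] - t[0]\n",
    "t_offset = 0.5e-9\n",
    "sigma = 6e-11  # assumed filter width (s)\n",
    "\n",
    "# source: Heaviside step H(t)\n",
    "step = jnp.heaviside(t - t_offset, 1.0)\n",
    "\n",
    "# Gaussian filter G(t; a=1, mu=0, sigma) sampled on a symmetric grid of +/-400 samples\n",
    "tau = jnp.arange(-400, 401) * dt\n",
    "kernel = jnp.exp(-tau**2 / (2 * sigma**2)) / (sigma * jnp.sqrt(2 * jnp.pi))\n",
    "\n",
    "# Riemann-sum approximation of y(t) = integral of G(tau) H(t - tau) dtau\n",
    "y = jnp.convolve(step, kernel, mode='same') * dt\n",
    "\n",
    "# closed-form result for the same filter\n",
    "y_exact = 0.5 * (1.0 + erf((t - t_offset) / (jnp.sqrt(2.0) * sigma)))\n",
    "\n",
    "# compare away from the window edges, where zero padding distorts the discrete sum\n",
    "mask = jnp.abs(t - t_offset) < 2e-10\n",
    "print(f'max deviation: {float(jnp.max(jnp.abs(y - y_exact)[mask])):.1e}')\n",
    "```\n",
    "\n",
    "Away from the window edges the two curves agree closely, which is why the rise time of a Gaussian-filtered step can be related to $\\sigma$ analytically.\n",
    "\n",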
    "To simulate the rise time, we need to set the order of the Gaussian filter, which determines how many derivative terms are included. In this example, we'll include terms up to the 4<sup>th</sup> order. We also need to define the rise time and the fraction that determines the width of the Gaussian filter, i.e., the value used to calibrate $\\sigma$. Here, we'll use a rise time of $0.1$ ns (`rise_time = 0.1e-9`, since times are in seconds) and `fraction = 0.6`, corresponding to the commonly used lower and upper levels of 20% and 80%, as described earlier. The call below passes only `rise_time`, so the order and fraction are left at the function's defaults, which are assumed here to be the values above. We can then compute the filter coefficients from these parameters. Thankfully, the `compute_rise_time_gaussian_coeffs` utility function does all the heavy lifting for us."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qruise.toolset.utils import compute_rise_time_gaussian_coeffs\n",
    "\n",
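    "# rise time in seconds (0.1 ns)\n",
    "rise_time = 0.1e-9\n",
    "# Gaussian-filter coefficients; order and fraction are left at their defaults\n",
    "abs_coeffs = compute_rise_time_gaussian_coeffs(rise_time)\n",
    "# transfer-function coefficients: b (numerator) and a (denominator)\n",
    "b = jnp.array([1.0])\n",
    "a = jnp.array(abs_coeffs)"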
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Constructing the signal chain\n",
    "\n",
    "Now that we've defined all the relevant parameters, we can construct the signal chain. It consists of two components: the source signal (our Heaviside function) and the rise time pseudo-device, which we model as a transfer function. We start by defining instances of the `Heaviside` and `TransferFunc` component classes with their respective parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qruise.toolset import Heaviside, TransferFunc\n",
    "\n",
    "# define instance of Heaviside named \"awg_heaviside\" with heaviside_offset\n",
    "heaviside = Heaviside(\"awg_heaviside\", {\"tau\": (heaviside_offset, False)})\n",
    "\n",
    "# define instance of TransferFunc (rise time) named \"tf\" with coefficients a and b and time window from t0 to t1 (= tfinal)\n",
    "trans_func = TransferFunc(\n",
    "    \"tf\",\n",
    "    {\"b\": (b, False), \"a\": (a, False), \"t0\": (t0, False), \"t1\": (tfinal, False)},\n",
    "    heaviside,\n",
    "    dsource=jax.grad(heaviside, argnums=0),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we need to evaluate the signal chain over the time span we defined earlier (`t_span`). A useful feature of the signal chain in the `qruise-toolset` is the ability to probe the signal at specific points, allowing us to observe the effects of individual components. For instance, we can evaluate the chain at the source component (the Heaviside function) alone, and then with the rise time component included, to see the latter's impact more clearly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qruise.toolset import ParameterCollection\n",
    "from jax import vmap\n",
    "\n",
    "# combine Heaviside and transfer function parameters into single parameter collection\n",
    "pc = ParameterCollection()\n",
    "pc.add_dict(heaviside.params | trans_func.params)\n",
    "params = pc.get_collection()\n",
    "\n",
    "# evaluate signal over defined time span\n",
    "awg_result = heaviside(t_span, params)  # only Heaviside\n",
    "trans_func_result = vmap(trans_func, (0, None))(\n",
    "    t_span, params\n",
    ")  # Heaviside and rise time"
   ]
  },
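  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The call `vmap(trans_func, (0, None))` maps the component over the first axis of `t_span` while passing the same `params` to every evaluation. The toy example below shows the same `in_axes` pattern on a hypothetical scalar function `step_at`; it is not part of the `qruise-toolset`.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def step_at(t, params):\n",
    "    # hypothetical component: takes one time point t and a dict of parameters\n",
    "    return jnp.where(t >= params['tau'], 1.0, 0.0)\n",
    "\n",
    "\n",
    "params = {'tau': 0.5}\n",
    "t = jnp.linspace(0.0, 1.0, 5)  # 0.0, 0.25, 0.5, 0.75, 1.0\n",
    "\n",
    "# in_axes=(0, None): vectorise over axis 0 of t, pass params unchanged to every call\n",
    "out = jax.vmap(step_at, in_axes=(0, None))(t, params)\n",
    "print(out)\n",
    "```\n",
    "\n",
    "Using `None` for a pytree argument such as a dictionary broadcasts the whole structure, so each call sees identical parameters."
   ]
  },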
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Visualising & calculating the rise time\n",
    "\n",
    "We can now plot the output of the signal chain without and with the rise time component to see how it affects the signal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qruise.toolset import PlotUtil\n",
    "\n",
    "canvas = PlotUtil(x_axis_label=\"t [s]\", y_axis_label=\"Amplitude\", notebook=True)\n",
    "canvas.plot(t_span, awg_result, labels=[\"Heaviside (step) only\"])\n",
    "canvas.plot(t_span, trans_func_result, labels=[\"Heaviside + rise time\"])\n",
    "canvas.show_canvas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As expected, without the rise time component, we see a standard step function that transitions from $0$ to $1$ at $t=\\frac{t_0+t_\\text{final}}{2}$. When we include the rise time component, we observe a smooth transition of finite duration, set by the specified rise time.\n",
    "\n",
    "Let's now check whether the measured rise time matches the value we specified, up to a discretisation error set by the grid spacing of `t_span` (about 1 ps here)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# calculate observed 20%-80% rise time\n",
    "fraction = 0.6  # same fraction as used in compute_rise_time_gaussian_coeffs\n",
    "lower_level = 0.5 * (1.0 - fraction)  # 0.2, i.e. 20% of the final amplitude\n",
    "upper_level = 1.0 - lower_level  # 0.8, i.e. 80% of the final amplitude\n",
    "# last sample still below the lower level and first sample above the upper level\n",
    "ind_lower = jnp.where(trans_func_result < lower_level)[0].max()\n",
    "ind_upper = jnp.where(trans_func_result > upper_level)[0].min()\n",
    "risetime_meas = t_span[ind_upper] - t_span[ind_lower]\n",
    "print(f\"t[{lower_level:.1f}] = {t_span[ind_lower]*1e9:.3f} ns\")\n",
    "print(f\"t[{upper_level:.1f}] = {t_span[ind_upper]*1e9:.3f} ns\")\n",
    "print(f\"Rise time = {risetime_meas*1e9:.3f} ns\")"
   ]
  },
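  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The index-based estimate above resolves each threshold crossing only to within one grid step. A common refinement is linear interpolation between neighbouring samples. The sketch below applies it to the closed-form response of a step filtered by a unit-area Gaussian, with $\\sigma$ chosen from the relation $t_r = 2\\sqrt{2}\\,\\sigma\\,\\mathrm{erf}^{-1}(f)$; that calibration is our assumption about the filter, not something taken from the toolset. Applying the same idea to `trans_func_result` would require the response to increase monotonically through the 20% and 80% levels, since `jnp.interp` expects increasing sample points.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "\n",
    "jax.config.update('jax_enable_x64', True)\n",
    "import jax.numpy as jnp\n",
    "from jax.scipy.special import erf, erfinv\n",
    "\n",
    "fraction = 0.6\n",
    "target_rise_time = 0.1e-9  # s\n",
    "# assumed calibration: Gaussian width giving the target 20%-80% rise time\n",
    "sigma = target_rise_time / (2.0 * jnp.sqrt(2.0) * erfinv(fraction))\n",
    "\n",
    "t = jnp.linspace(0.0, 1e-9, 1000)\n",
    "t_offset = 0.5e-9\n",
    "# closed-form step response of the Gaussian filter\n",
    "y = 0.5 * (1.0 + erf((t - t_offset) / (jnp.sqrt(2.0) * sigma)))\n",
    "\n",
    "lower = 0.5 * (1.0 - fraction)  # 20% level\n",
    "upper = 1.0 - lower  # 80% level\n",
    "# y rises monotonically, so t can be interpolated as a function of y\n",
    "t_lower = jnp.interp(lower, y, t)\n",
    "t_upper = jnp.interp(upper, y, t)\n",
    "print(f'Interpolated rise time = {float(t_upper - t_lower) * 1e9:.4f} ns')\n",
    "```\n",
    "\n",
    "Because the crossing times are interpolated rather than snapped to samples, this estimate is not limited by the roughly 1 ps spacing of the time grid."
   ]
  },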
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great, our rise time is as expected, up to the resolution of the time grid! You can now integrate this into your control stack model."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "qruise-simple",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
