{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Optimising with frozen signal-chain parameters\n",
    "\n",
    "In this tutorial, we demonstrate the `freeze_keys` feature of `Optimiser.optimise()`.\n",
    "When a signal-chain component (such as a `TransferFunc`) is used directly as a Hamiltonian\n",
    "coefficient, it may contain parameters that represent **fixed hardware properties**, such as\n",
    "filter coefficients and timing boundaries, which the optimiser should never modify.\n",
    "`freeze_keys` removes such parameters from the optimiser's search space so that they stay\n",
    "at their initial values.\n",
    "\n",
    "**This tutorial covers:**\n",
    "\n",
    "1. System setup: `PWCPulse` → `TransferFunc` → `Hamiltonian` (direct, no wrapper)\n",
    "2. Inspecting frozen parameters via `opt_prob.inactive_params`\n",
    "3. Optimisation **without** `freeze_keys` — demonstrating the problem\n",
    "4. Optimisation **with** `freeze_keys` — the correct approach\n",
    "5. Summary and best practices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To ensure numerical precision throughout, we enable 64-bit floating point in JAX:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "jax.config.update(\"jax_enable_x64\", True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. System setup\n",
    "\n",
    "We model a two-level qubit with a stationary detuning term ($\\sigma_z$) and an external\n",
    "drive ($\\sigma_x$) shaped by a piecewise-constant (PWC) pulse filtered through a\n",
    "Gaussian rise-time transfer function.\n",
    "\n",
    "The signal chain is:\n",
    "\n",
    "```\n",
    "PWCPulse(\"drive\")  →  TransferFunc(\"tf\")  →  Hamiltonian coefficient\n",
    "```\n",
    "\n",
    "- `PWCPulse` provides the optimisable envelope (`env`, active) plus timing parameters\n",
    "  (`t0`, `dt`) that are frozen — they encode the hardware sampling grid.\n",
    "- `TransferFunc` applies a Gaussian filter whose coefficients (`b`, `a`) and time\n",
    "  window (`t0`, `t1`) are frozen — they represent fixed hardware characteristics.\n",
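    "- The `TransferFunc` instance is passed **directly** to the `Hamiltonian` constructor\n",
    "  via its `Ht` parameter, allowing `Hamiltonian.inactive_params` to discover the\n",
    "  `TransferFunc`'s frozen keys automatically. The `PWCPulse` timing keys are reached only\n",
    "  through a plain-function wrapper and need separate handling (see Section 2)."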
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We begin by setting up the simulation grid and pulse parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "from qruise.toolset.utils import compute_rise_time_gaussian_coeffs\n",
    "\n",
    "t0 = 0.0  # start time (s)\n",
    "tfinal = 20e-9  # end time (s)\n",
    "grid_size = 50  # simulation grid points\n",
    "pulse_size = 10  # number of PWC steps\n",
    "rise_time = 3e-9  # hardware rise time (s)\n",
    "\n",
    "ts_pulse = jnp.linspace(t0, tfinal, pulse_size)\n",
    "dt = float(ts_pulse[1] - ts_pulse[0])"
   ]
  },
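  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `jnp.linspace` includes both endpoints, `pulse_size` sample times divide the window into `pulse_size - 1` intervals, so the spacing computed above works out to:\n",
    "\n",
    "$$\\Delta t = \\frac{t_\\mathrm{final} - t_0}{N_\\mathrm{pulse} - 1} = \\frac{20\\,\\mathrm{ns}}{9} \\approx 2.22\\,\\mathrm{ns}$$"
   ]
  },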
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The optimisable input signal is a piecewise-constant (PWC) pulse. We mark its envelope samples (`env`) as active and its timing grid (`t0`, `dt`) as frozen — the timing encodes a fixed hardware sampling grid that the optimiser must not modify."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qruise.toolset.control_stack import PWCPulse\n",
    "\n",
    "pwc = PWCPulse(\n",
    "    \"drive\",\n",
    "    {\n",
    "        \"env\": (jnp.ones(pulse_size) * 1e7, True),  # active — will be optimised\n",
    "        \"t0\": (t0, False),  # frozen — hardware timing\n",
    "        \"dt\": (dt, False),  # frozen — hardware timing\n",
    "    },\n",
    ")"
   ]
  },
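  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the role of the frozen timing parameters concrete, here is a minimal sketch of how a piecewise-constant envelope can be evaluated from `env`, `t0` and `dt`. The helper `pwc_eval` is hypothetical: it assumes sample `i` is held on `[t0 + i*dt, t0 + (i+1)*dt)`, which is not necessarily how `PWCPulse` works internally. It illustrates why `t0` and `dt` must stay fixed: changing them shifts or stretches every step of the pulse.\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def pwc_eval(t, env, t0, dt):\n",
    "    \"\"\"Hypothetical PWC evaluation: hold env[i] on [t0 + i*dt, t0 + (i+1)*dt).\"\"\"\n",
    "    # index of the step containing each time; clipping reuses the first/last\n",
    "    # sample outside the grid (an assumed edge behaviour)\n",
    "    idx = jnp.clip(jnp.floor((t - t0) / dt).astype(int), 0, env.shape[0] - 1)\n",
    "    return env[idx]\n",
    "\n",
    "\n",
    "env = jnp.array([1.0, 2.0, 3.0])\n",
    "ts = jnp.array([0.0, 0.5, 1.0, 2.5, 10.0])\n",
    "print(pwc_eval(ts, env, t0=0.0, dt=1.0))  # [1. 1. 2. 3. 3.]\n",
    "```\n",
    "\n",
    "Doubling `dt` to `2.0` with the same `env` stretches the pulse, and the same times then map to `[1. 1. 1. 2. 3.]`."
   ]
  },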
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `TransferFunc` applies a Gaussian low-pass filter that models the hardware rise time. We compute its IIR filter coefficients using `compute_rise_time_gaussian_coeffs`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "abs_coeffs = compute_rise_time_gaussian_coeffs(rise_time, n_order=4)\n",
    "b = jnp.array([1.0])\n",
    "a = jnp.array(abs_coeffs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`TransferFunc` also needs the time derivative of its source signal (`dsource`), which we compute with `jax.grad`. `jax.grad` only differentiates functions with a scalar output, while `PWCPulse.__call__` takes and returns a batched time array, so we first wrap it in a thin scalar helper:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plain-function source signal passed to TransferFunc\n",
    "def pwc_signal(t, params):\n",
    "    return pwc(t, params)\n",
    "\n",
    "\n",
    "# scalar wrapper for jax.grad (scalar t → scalar output)\n",
    "def scalar_pwc_signal(t, params):\n",
    "    return pwc_signal(jnp.array([t]), params)[0]\n",
    "\n",
    "\n",
    "# time derivative of the signal with respect to t\n",
    "dsource = jax.grad(scalar_pwc_signal, argnums=0)"
   ]
  },
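  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same scalar-wrapper pattern can be checked on a smooth stand-in signal whose derivative is known analytically. In this sketch, `batched_signal` is a hypothetical replacement for `pwc_signal` (a sine rather than a PWC pulse), and `jax.vmap` evaluates the scalar derivative over a whole time array:\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def batched_signal(ts, params):\n",
    "    # hypothetical batched signal: array of times in, array of values out\n",
    "    return params[\"amp\"] * jnp.sin(params[\"omega\"] * ts)\n",
    "\n",
    "\n",
    "def scalar_signal(t, params):\n",
    "    # lift scalar t to a length-1 batch, then pull the single value back out\n",
    "    return batched_signal(jnp.array([t]), params)[0]\n",
    "\n",
    "\n",
    "dsignal = jax.grad(scalar_signal, argnums=0)  # scalar t -> scalar d/dt\n",
    "dsignal_batched = jax.vmap(dsignal, in_axes=(0, None))  # map over times only\n",
    "\n",
    "p = {\"amp\": 2.0, \"omega\": 3.0}\n",
    "ts = jnp.linspace(0.0, 1.0, 5)\n",
    "print(dsignal(0.0, p))  # 6.0, i.e. amp * omega * cos(0)\n",
    "print(jnp.allclose(dsignal_batched(ts, p), 2.0 * 3.0 * jnp.cos(3.0 * ts)))  # True\n",
    "```"
   ]
  },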
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now construct the `TransferFunc`, marking all filter properties as frozen:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qruise.toolset import TransferFunc\n",
    "\n",
    "tf = TransferFunc(\n",
    "    \"tf\",\n",
    "    {\n",
    "        \"b\": (b, False),  # frozen — filter numerator\n",
    "        \"a\": (a, False),  # frozen — filter denominator\n",
    "        \"t0\": (t0, False),  # frozen — filter start time\n",
    "        \"t1\": (tfinal, False),  # frozen — filter end time\n",
    "    },\n",
    "    pwc_signal,\n",
    "    dsource=dsource,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "**Key point:** `tf` (the `TransferFunc` instance) is passed directly to the `Hamiltonian`\n",
    "constructor via its `Ht` parameter — not wrapped in a plain Python function. This allows\n",
    "`Hamiltonian.inactive_params` to introspect its frozen keys automatically.\n",
    "\n",
    "---"
   ]
  },
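  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One plausible way such introspection can work is sketched below; it is an assumption for illustration, not the `qruise` implementation. `ToyComponent` and `collect_inactive` are hypothetical names: the component records which keys were declared with `False`, and the collector reads that record through `getattr`. A plain function has no such attribute, so wrapping the component hides its frozen keys.\n",
    "\n",
    "```python\n",
    "class ToyComponent:\n",
    "    \"\"\"Hypothetical signal-chain component with (value, active) parameters.\"\"\"\n",
    "\n",
    "    def __init__(self, name, spec):\n",
    "        self.name = name\n",
    "        # namespaced values, e.g. \"tf/t1\"\n",
    "        self.params = {f\"{name}/{k}\": v for k, (v, _) in spec.items()}\n",
    "        # keys declared with active=False are frozen\n",
    "        self.inactive_params = {\n",
    "            f\"{name}/{k}\" for k, (_, active) in spec.items() if not active\n",
    "        }\n",
    "\n",
    "    def __call__(self, t, params):\n",
    "        return params[f\"{self.name}/gain\"] * t\n",
    "\n",
    "\n",
    "def collect_inactive(coeffs):\n",
    "    \"\"\"Gather frozen keys from coefficients that expose `inactive_params`.\"\"\"\n",
    "    frozen = set()\n",
    "    for c in coeffs:\n",
    "        # a plain function has no such attribute, so it contributes nothing\n",
    "        frozen |= getattr(c, \"inactive_params\", set())\n",
    "    return frozen\n",
    "\n",
    "\n",
    "comp = ToyComponent(\"tf\", {\"gain\": (2.0, True), \"t1\": (20e-9, False)})\n",
    "\n",
    "\n",
    "def wrapped(t, params):\n",
    "    return comp(t, params)\n",
    "\n",
    "\n",
    "print(collect_inactive([comp]))  # {'tf/t1'}\n",
    "print(collect_inactive([wrapped]))  # set()\n",
    "```"
   ]
  },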
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now merge the `PWCPulse` and `TransferFunc` parameter dicts into a single flat collection using `ParameterCollection`, and define the Hamiltonian with `tf` as the coefficient of the $\\sigma_x$ drive term."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qruise.toolset import Hamiltonian, ParameterCollection\n",
    "from qutip import sigmaz, sigmax\n",
    "\n",
    "# merge PWCPulse and TransferFunc parameters into a single flat dict\n",
    "pc = ParameterCollection()\n",
    "pc.add_dict(pwc.params | tf.params)\n",
    "params = pc.get_collection()\n",
    "\n",
    "delta = 1e8  # qubit detuning (Hz)\n",
    "\n",
    "# tf passed directly as Hamiltonian coefficient\n",
    "H = Hamiltonian(delta * sigmaz() / 2, [(sigmax() / 2, tf)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the Hamiltonian in place, we define the `QOCProblem` as usual, specifying the initial state, target state, and state infidelity loss function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qruise.toolset import QOCProblem\n",
    "\n",
    "y0 = jnp.array([1.0, 0.0], dtype=jnp.complex128)  # initial state |0>\n",
    "y_t = jnp.array([0.0, 1.0], dtype=jnp.complex128)  # target state |1>\n",
    "\n",
    "\n",
    "def loss(x, y):\n",
    "    \"\"\"State infidelity: 1 - |<x|y>|^2\"\"\"\n",
    "    return jnp.real(1.0 - jnp.abs(jnp.vdot(x, y)) ** 2)\n",
    "\n",
    "\n",
    "opt_prob = QOCProblem(H, y0, params, (t0, tfinal), y_t, loss)"
   ]
  },
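  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the state-infidelity formula $1 - |\\langle x|y\\rangle|^2$, the sketch below evaluates the same expression as `loss` (renamed `infidelity` so the snippet stands alone) on three single-qubit states: the target itself, the orthogonal state, and an equal superposition.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "jax.config.update(\"jax_enable_x64\", True)\n",
    "\n",
    "\n",
    "def infidelity(x, y):\n",
    "    # same expression as the tutorial's loss: 1 - |<x|y>|^2\n",
    "    return jnp.real(1.0 - jnp.abs(jnp.vdot(x, y)) ** 2)\n",
    "\n",
    "\n",
    "zero = jnp.array([1.0, 0.0], dtype=jnp.complex128)  # |0>\n",
    "one = jnp.array([0.0, 1.0], dtype=jnp.complex128)  # |1>\n",
    "plus = (zero + one) / jnp.sqrt(2.0)  # (|0> + |1>) / sqrt(2)\n",
    "\n",
    "print(infidelity(one, one))  # 0.0: target reached\n",
    "print(infidelity(zero, one))  # 1.0: orthogonal to the target\n",
    "print(infidelity(plus, one))  # approximately 0.5\n",
    "```"
   ]
  },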
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Inspecting frozen parameters\n",
    "\n",
    "`QOCProblem.inactive_params` aggregates frozen keys from all signal-chain components\n",
    "that are used directly as Hamiltonian coefficients.\n",
    "\n",
    "The expected frozen keys are those from `TransferFunc`: `tf/a`, `tf/b`, `tf/t0`,\n",
    "and `tf/t1`. Note that `drive/t0` and `drive/dt` from `PWCPulse` are **not** surfaced\n",
    "here because `PWCPulse` is accessed via the `pwc_signal` closure, not directly as a\n",
    "Hamiltonian coefficient — see the Summary section for details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Frozen parameter keys:\")\n",
    "for key in sorted(opt_prob.inactive_params):\n",
    "    print(f\"  {key}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Optimisation without `freeze_keys` — why the search space matters\n",
    "\n",
    "Passing every parameter to the minimiser enlarges the search space with dimensions that\n",
    "should stay fixed, and risks drifting hardware parameters over longer runs. Without\n",
    "`freeze_keys`, the scipy parameter vector contains not only the `drive` parameters but\n",
    "also the array-valued filter coefficients `tf/b` and `tf/a` (one entry per coefficient)\n",
    "and the scalar timing parameters `tf/t0` and `tf/t1`.\n",
    "\n",
    "We first compare the size of the scipy vector with and without `freeze_keys`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we record the initial values of the frozen filter parameters so we can check whether they are modified after optimisation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qruise.toolset import Optimiser, ScipyMinimiser, PWCSolver\n",
    "\n",
    "# record initial filter param values for later comparison\n",
    "b_initial = params[\"tf/b\"].copy()\n",
    "a_initial = params[\"tf/a\"].copy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the difference concretely, we count how many scalar values are packed into the scipy minimiser's parameter vector in each case. Array-valued parameters such as `tf/b` and `tf/a` contribute one element per filter coefficient, which inflates the search space when `freeze_keys` is not used:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# total scalar elements across all params\n",
    "n_all = sum(v.flatten().size if hasattr(v, \"flatten\") else 1 for v in params.values())\n",
    "\n",
    "# active params only (frozen keys excluded)\n",
    "active_keys = set(params.keys()) - opt_prob.inactive_params\n",
    "active_params = {k: v for k, v in params.items() if k in active_keys}\n",
    "n_active = sum(\n",
    "    v.flatten().size if hasattr(v, \"flatten\") else 1 for v in active_params.values()\n",
    ")\n",
    "\n",
    "print(f\"Parameters in scipy vector WITHOUT freeze_keys: {n_all}\")\n",
    "print(f\"Parameters in scipy vector WITH    freeze_keys: {n_active}\")\n",
    "print(f\"Frozen params excluded: {sorted(opt_prob.inactive_params)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's run the optimisation without `freeze_keys` to observe the behaviour:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "minimiser_bad = ScipyMinimiser(\"L-BFGS-B\", maxiter=5)\n",
    "solver_bad = PWCSolver(n=grid_size, eq=opt_prob.sepwc())\n",
    "opt_bad = Optimiser(minimiser_bad, solver_bad)\n",
    "opt_bad.set_optimisation(opt_prob.loss)\n",
    "\n",
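    "# no freeze_keys: all params enter the scipy vector\n",
    "result_bad, _ = opt_bad.optimise(*opt_prob.problem())\n",
    "\n",
    "# check whether the filter coefficients moved during the unfrozen run\n",
    "print(\"tf/b unchanged:\", jnp.allclose(result_bad[\"tf/b\"], b_initial))\n",
    "print(\"tf/a unchanged:\", jnp.allclose(result_bad[\"tf/a\"], a_initial))"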
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Optimisation with `freeze_keys` — the correct approach\n",
    "\n",
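    "Passing `freeze_keys=opt_prob.inactive_params` tells the optimiser to exclude the\n",
    "`TransferFunc` parameters (`tf/a`, `tf/b`, `tf/t0`, `tf/t1`) from the search space; they\n",
    "are held constant throughout. Because `drive/t0` and `drive/dt` are not in\n",
    "`opt_prob.inactive_params` (see Section 2), they remain in the search space alongside\n",
    "`drive/env` unless they are added to `freeze_keys` manually, as shown in the Summary."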
   ]
  },
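  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, freezing amounts to splitting the parameter dict before flattening it for the minimiser. The sketch below shows one way to do this with `jax.flatten_util.ravel_pytree`; `split_params` and `params_demo` are hypothetical, and `Optimiser` may implement `freeze_keys` differently. Only the active entries are flattened into the vector the minimiser updates, and the frozen entries are merged back unchanged when the full dict is rebuilt.\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "from jax.flatten_util import ravel_pytree\n",
    "\n",
    "\n",
    "def split_params(params, freeze_keys):\n",
    "    \"\"\"Hypothetical helper: separate active params from frozen ones.\"\"\"\n",
    "    active = {k: v for k, v in params.items() if k not in freeze_keys}\n",
    "    frozen = {k: v for k, v in params.items() if k in freeze_keys}\n",
    "    return active, frozen\n",
    "\n",
    "\n",
    "params_demo = {\n",
    "    \"drive/env\": jnp.ones(4),\n",
    "    \"tf/a\": jnp.array([1.0, 0.5, 0.25]),\n",
    "    \"tf/t1\": jnp.array(20e-9),\n",
    "}\n",
    "active, frozen = split_params(params_demo, {\"tf/a\", \"tf/t1\"})\n",
    "\n",
    "# only the active entries become the 1-D vector a scipy-style minimiser sees\n",
    "x0, unravel = ravel_pytree(active)\n",
    "print(x0.shape)  # (4,)\n",
    "print(ravel_pytree(params_demo)[0].shape)  # (8,) without freezing\n",
    "\n",
    "# stand-in for an optimiser step, then rebuild the full dict with frozen values untouched\n",
    "x_new = x0 * 2.0\n",
    "full = {**unravel(x_new), **frozen}\n",
    "print(full[\"drive/env\"])  # [2. 2. 2. 2.]\n",
    "```"
   ]
  },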
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "minimiser_good = ScipyMinimiser(\"L-BFGS-B\", maxiter=20)\n",
    "solver_good = PWCSolver(n=grid_size, eq=opt_prob.sepwc())\n",
    "opt_good = Optimiser(minimiser_good, solver_good)\n",
    "opt_good.set_optimisation(opt_prob.loss)\n",
    "\n",
    "result_good, summary = opt_good.optimise(\n",
    "    *opt_prob.problem(),\n",
    "    freeze_keys=opt_prob.inactive_params,  # pass frozen keys here\n",
    ")\n",
    "\n",
    "print(\"tf/b initial:          \", b_initial)\n",
    "print(\"tf/b after (unchanged):\", result_good[\"tf/b\"])\n",
    "print(\"tf/a unchanged:\", jnp.allclose(result_good[\"tf/a\"], a_initial))\n",
    "print(f\"Optimiser converged: {summary.success}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can verify that the frozen parameters are unchanged from their initial values and that the active envelope `drive/env` is present in the result:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert jnp.allclose(result_good[\"tf/b\"], params[\"tf/b\"]), \"tf/b drifted!\"\n",
    "assert jnp.allclose(result_good[\"tf/a\"], params[\"tf/a\"]), \"tf/a drifted!\"\n",
    "assert result_good[\"tf/t0\"] == params[\"tf/t0\"], \"tf/t0 drifted!\"\n",
    "assert result_good[\"tf/t1\"] == params[\"tf/t1\"], \"tf/t1 drifted!\"\n",
    "assert \"drive/env\" in result_good, \"drive/env missing from result!\"\n",
    "\n",
    "print(\"All assertions passed — frozen params are unchanged, active params optimised.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Summary and best practices\n",
    "\n",
    "The `freeze_keys` workflow for signal-chain optimisation:\n",
    "\n",
    "```python\n",
    "opt_prob = QOCProblem(H, y0, params, (t0, tfinal), y_t, loss)\n",
    "\n",
    "result, summary = opt.optimise(\n",
    "    *opt_prob.problem(),\n",
    "    freeze_keys=opt_prob.inactive_params,  # automatically excludes frozen params\n",
    ")\n",
    "```\n",
    "\n",
    "**Key points:**\n",
    "\n",
    "- Mark each parameter as frozen (`False`) or active (`True`) in its `(value, flag)` tuple\n",
    "  when constructing signal-chain components (`PWCPulse`, `TransferFunc`, etc.).\n",
    "- Pass signal-chain components **directly** to the `Hamiltonian` constructor via its\n",
    "  `Ht` parameter (not wrapped in a plain function) so that `Hamiltonian.inactive_params`\n",
    "  can discover frozen keys.\n",
    "- Use `opt_prob.inactive_params` as the `freeze_keys` argument. It aggregates the frozen\n",
    "  keys of every component used directly as a Hamiltonian coefficient.\n",
    "\n",
    "**Known limitation:** If a signal-chain component is reached only through a plain Python\n",
    "function (as `pwc` is here, via `pwc_signal`), `inactive_params` returns an empty set for\n",
    "that component. In that case, add its frozen keys to `freeze_keys` manually:\n",
    "\n",
    "```python\n",
    "# When using a plain-function wrapper, add its frozen keys explicitly:\n",
    "result, _ = opt.optimise(\n",
    "    *opt_prob.problem(),\n",
    "    freeze_keys=opt_prob.inactive_params | {\"drive/t0\", \"drive/dt\"},\n",
    ")\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "qruise-simple",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
