"""Rabi experiment that sweeps length and frequency."""
from dataclasses import dataclass, field
from typing import Optional, Union
import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from qibolab import AcquisitionType, AveragingMode, Parameter, Sweeper
from qibocal import update
from qibocal.auto.operation import Data, Parameters, QubitId, Routine
from qibocal.calibration import CalibrationPlatform
from qibocal.config import log
from qibocal.protocols.utils import table_dict, table_html
from ...result import magnitude, phase
from ..utils import HZ_TO_GHZ, fallback_period, guess_period, readout_frequency
from .length_signal import RabiLengthSignalResults
from .utils import fit_length_function, sequence_length
# Names exported by ``from <module> import *``.
# The underscore-prefixed ``_update`` is listed explicitly, so it is exported
# together with the routine object and the parameter/data/result classes.
__all__ = [
"rabi_length_frequency_signal",
"RabiLengthFrequencySignalParameters",
"RabiLengthFreqSignalData",
"_update",
"RabiLengthFrequencySignalResults",
]
[docs]
@dataclass
class RabiLengthFrequencySignalParameters(Parameters):
    """Runcard inputs of the Rabi length-frequency (signal) routine."""

    pulse_duration_start: float
    """Start of the drive pulse duration sweep [ns]."""
    pulse_duration_end: float
    """End of the drive pulse duration sweep [ns]."""
    pulse_duration_step: float
    """Step of the drive pulse duration sweep [ns]."""
    min_freq: int
    """Lower bound of the frequency sweep, given as an offset that is added
    to the configured drive frequency (presumably in Hz)."""
    max_freq: int
    """Upper bound of the frequency sweep, given as an offset; the grid is
    built with ``np.arange`` so this value itself is excluded."""
    step_freq: int
    """Spacing between consecutive frequency offsets."""
    pulse_amplitude: Optional[float] = None
    """Drive pulse amplitude, shared by all qubits. Defaults to ``None``."""
    rx90: bool = False
    """If ``True`` the pi/2 pulse is calibrated instead of the pi pulse."""
    interpolated_sweeper: bool = False
    """Sweep the duration with ``Parameter.duration_interpolated`` (real-time
    interpolation, if supported by the instruments) instead of
    ``Parameter.duration``."""
[docs]
@dataclass
class RabiLengthFrequencySignalResults(RabiLengthSignalResults):
    """Fit outputs of the Rabi length-frequency (signal) routine."""

    rx90: bool
    """``True`` for a pi/2 pulse calibration, ``False`` for a pi pulse one."""
    frequency: dict[QubitId, Union[float, list[float]]]
    """Drive frequency selected for each qubit."""
RabiLenFreqSignalType = np.dtype(
    [(column, np.float64) for column in ("len", "freq", "signal", "phase")]
)
"""Structured dtype of one sweep point: pulse length, drive frequency,
signal magnitude and phase, all stored as ``float64``."""
[docs]
@dataclass
class RabiLengthFreqSignalData(Data):
    """Raw data acquired by the Rabi length-frequency (signal) routine."""

    rx90: bool
    """``True`` for a pi/2 pulse calibration, ``False`` for a pi pulse one."""
    amplitudes: dict[QubitId, float] = field(default_factory=dict)
    """Drive pulse amplitude used for each qubit."""
    data: dict[QubitId, npt.NDArray[RabiLenFreqSignalType]] = field(
        default_factory=dict
    )
    """Acquired records for each qubit."""

    def register_qubit(self, qubit, freq, lens, signal, phase):
        """Store the acquired sweep of a single qubit.

        ``freq`` and ``lens`` are the swept axes. They are expanded with
        ``np.meshgrid(freq, lens)``, so in the flattened records the frequency
        varies fastest. ``signal`` and ``phase`` are flattened with ``ravel``
        and must therefore hold ``len(freq) * len(lens)`` values in that same
        order.
        """
        n_points = len(freq) * len(lens)
        freq_grid, len_grid = np.meshgrid(freq, lens)
        # Insertion order fixes the order in which the fields are filled.
        columns = {
            "freq": freq_grid,
            "len": len_grid,
            "signal": signal,
            "phase": phase,
        }
        records = np.empty(n_points, dtype=RabiLenFreqSignalType)
        for name, values in columns.items():
            records[name] = values.ravel()
        self.data[qubit] = np.rec.array(records)

    def durations(self, qubit):
        """Return the sorted unique pulse lengths stored for ``qubit``."""
        qubit_records = self[qubit]
        return np.unique(qubit_records.len)

    def frequencies(self, qubit):
        """Return the sorted unique drive frequencies stored for ``qubit``."""
        qubit_records = self[qubit]
        return np.unique(qubit_records.freq)
def _acquisition(
    params: RabiLengthFrequencySignalParameters,
    platform: CalibrationPlatform,
    targets: list[QubitId],
) -> RabiLengthFreqSignalData:
    """Acquire the readout signal while sweeping drive pulse length and frequency.

    The pulse sequence comes from ``sequence_length``. The duration sweeper
    acts on the drive pulses of all targets; unless the interpolated sweeper
    is requested it also acts on the delays returned by ``sequence_length``.
    For every target a frequency sweeper is built on its drive channel, using
    the configured channel frequency plus the offsets
    ``np.arange(min_freq, max_freq, step_freq)``.

    The execution uses integration acquisition with cyclic averaging, and the
    magnitude and phase of each readout result are stored per qubit.
    """
    sequence, qd_pulses, delays, ro_pulses, amplitudes = sequence_length(
        targets, params, platform, params.rx90
    )

    duration_range = (
        params.pulse_duration_start,
        params.pulse_duration_end,
        params.pulse_duration_step,
    )
    swept_pulses = [qd_pulses[q] for q in targets]
    if params.interpolated_sweeper:
        duration_parameter = Parameter.duration_interpolated
    else:
        duration_parameter = Parameter.duration
        # Without interpolation the delays are swept together with the drive pulses.
        swept_pulses += [delays[q] for q in targets]
    len_sweeper = Sweeper(
        parameter=duration_parameter,
        range=duration_range,
        pulses=swept_pulses,
    )

    frequency_offsets = np.arange(
        params.min_freq,
        params.max_freq,
        params.step_freq,
    )

    def drive_frequency_sweeper(qubit):
        """Build the frequency sweeper acting on the drive channel of ``qubit``."""
        drive_channel = platform.qubits[qubit].drive
        return Sweeper(
            parameter=Parameter.frequency,
            values=platform.config(drive_channel).frequency + frequency_offsets,
            channels=[drive_channel],
        )

    freq_sweepers = {qubit: drive_frequency_sweeper(qubit) for qubit in targets}

    data = RabiLengthFreqSignalData(amplitudes=amplitudes, rx90=params.rx90)

    probe_updates = [
        {platform.qubits[q].probe: {"frequency": readout_frequency(q, platform)}}
        for q in targets
    ]
    results = platform.execute(
        [sequence],
        [[len_sweeper], [freq_sweepers[q] for q in targets]],
        updates=probe_updates,
        nshots=params.nshots,
        relaxation_time=params.relaxation_time,
        acquisition_type=AcquisitionType.INTEGRATION,
        averaging_mode=AveragingMode.CYCLIC,
    )

    for qubit in targets:
        acquired = results[ro_pulses[qubit].id]
        data.register_qubit(
            qubit=qubit,
            freq=freq_sweepers[qubit].values,
            lens=len_sweeper.values,
            signal=magnitude(acquired),
            phase=phase(acquired),
        )
    return data
def _fit(data: RabiLengthFreqSignalData) -> RabiLengthFrequencySignalResults:
    """Pick the best drive frequency for each qubit and fit its Rabi oscillation.

    For every qubit the signal is arranged as one trace per frequency. The
    frequency whose trace has the largest peak-to-peak excursion is selected,
    and that trace is normalised to the unit interval on both axes and passed
    to ``fit_length_function``.

    Qubits are left out of the returned dictionaries when the fit raises, or
    when the data cannot be normalised (fewer than two distinct durations, or
    a constant trace). A warning is logged in each of those cases.

    Args:
        data: acquired sweep, as filled by the acquisition step.

    Returns:
        The fitted pulse durations, fit parameters and selected frequencies
        for the qubits whose fit succeeded, plus the amplitudes and ``rx90``
        flag copied from ``data``.
    """
    fitted_frequencies = {}
    fitted_durations = {}
    fitted_parameters = {}

    for qubit in data.data:
        durations = data.durations(qubit)
        freqs = data.frequencies(qubit)

        # The duration axis is rescaled by its range below, so a sweep with
        # fewer than two distinct durations (or no data) cannot be fitted.
        if len(durations) < 2:
            log.warning(
                f"Rabi fit skipped for qubit {qubit}: "
                "at least two pulse durations are required."
            )
            continue

        signal = data[qubit].signal
        # Records are stored duration-major, so after the transpose each row
        # is the trace of one frequency along the duration axis.
        signal_matrix = signal.reshape(len(durations), len(freqs)).T

        # guess optimal frequency maximizing oscillation amplitude
        index = np.argmax(np.ptp(signal_matrix, axis=1))
        frequency = freqs[index]

        y = signal_matrix[index]
        y_min = np.min(y)
        y_max = np.max(y)
        if y_max == y_min:
            # A constant trace would be divided by zero in the normalisation.
            log.warning(
                f"Rabi fit skipped for qubit {qubit}: "
                "the signal is constant along the duration sweep."
            )
            continue

        x_min = np.min(durations)
        x_max = np.max(durations)
        x = (durations - x_min) / (x_max - x_min)
        y = (y - y_min) / (y_max - y_min)

        period = fallback_period(guess_period(x, y))
        pguess = [0, np.sign(y[0]) * 0.5, period, 0, 0]

        try:
            # The original axis limits are forwarded with the normalised data
            # (presumably so the helper can rescale its results).
            popt, _, pi_pulse_parameter = fit_length_function(
                x,
                y,
                pguess,
                signal=True,
                x_limits=(x_min, x_max),
                y_limits=(y_min, y_max),
            )
        except Exception as e:
            # Best effort: a failed fit only drops this qubit from the results.
            log.warning(f"Rabi fit failed for qubit {qubit} due to {e}.")
        else:
            fitted_frequencies[qubit] = frequency
            fitted_durations[qubit] = pi_pulse_parameter
            fitted_parameters[qubit] = popt

    return RabiLengthFrequencySignalResults(
        length=fitted_durations,
        amplitude=data.amplitudes,
        fitted_parameters=fitted_parameters,
        frequency=fitted_frequencies,
        rx90=data.rx90,
    )
def _plot(
    data: RabiLengthFreqSignalData,
    target: QubitId,
    fit: Optional[RabiLengthFrequencySignalResults] = None,
):
    """Plot signal and phase heatmaps for ``target``, with the fit overlay if available.

    Args:
        data: acquired sweep.
        target: qubit to plot.
        fit: fit results, or ``None`` when no fit is available. If the results
            hold no entry for ``target`` (the fit step omits qubits it could
            not fit), only the heatmaps are drawn.

    Returns:
        A one-element list with the figure, and the HTML fitting report
        (an empty string when there is nothing to report).
    """
    fitting_report = ""
    fig = make_subplots(
        rows=1,
        cols=2,
        horizontal_spacing=0.1,
        vertical_spacing=0.2,
        subplot_titles=(
            "Signal [a.u.]",
            "Phase [rad]",
        ),
    )

    qubit_data = data[target]
    frequencies = qubit_data.freq * HZ_TO_GHZ
    durations = qubit_data.len

    # (values, colorbar position) for the left and right panels respectively.
    panels = ((qubit_data.signal, 0.46), (qubit_data.phase, 1.01))
    for col, (values, colorbar_x) in enumerate(panels, start=1):
        fig.add_trace(
            go.Heatmap(
                x=durations,
                y=frequencies,
                z=values,
                colorbar_x=colorbar_x,
            ),
            row=1,
            col=col,
        )
    for col in (1, 2):
        fig.update_xaxes(title_text="Durations [ns]", row=1, col=col)
    fig.update_yaxes(title_text="Frequency [GHz]", row=1, col=1)

    # Qubits whose fit failed are absent from the result dictionaries, so the
    # overlay and report are produced only when both entries exist.
    has_fit = fit is not None and target in fit.frequency and target in fit.length
    if has_fit:
        best_frequency = fit.frequency[target]
        duration_span = [min(durations), max(durations)]
        # Horizontal dashed line at the selected frequency on both panels.
        for col in (1, 2):
            fig.add_trace(
                go.Scatter(
                    x=duration_span,
                    y=[best_frequency * HZ_TO_GHZ] * 2,
                    mode="lines",
                    line={"color": "white", "width": 4, "dash": "dash"},
                ),
                row=1,
                col=col,
            )

        pulse_name = "Pi-half pulse" if data.rx90 else "Pi pulse"
        fitting_report = table_html(
            table_dict(
                target,
                ["Optimal rabi frequency", f"{pulse_name} duration"],
                [
                    best_frequency,
                    f"{fit.length[target]:.2f} ns",
                ],
            )
        )

    fig.update_layout(
        showlegend=False,
        legend={"orientation": "h"},
    )
    return [fig], fitting_report
[docs]
def _update(
    results: RabiLengthFrequencySignalResults,
    platform: CalibrationPlatform,
    target: QubitId,
):
    """Apply the fitted drive amplitude, duration and frequency of ``target`` to ``platform``.

    The three updates run in this order. A missing entry for ``target`` in one
    of the result dictionaries raises ``KeyError`` at the corresponding call,
    after the preceding updates have already been applied.
    """
    is_rx90 = results.rx90
    update.drive_amplitude(results.amplitude[target], is_rx90, platform, target)
    update.drive_duration(results.length[target], is_rx90, platform, target)
    update.drive_frequency(results.frequency[target], platform, target)
# Routine object bundling the acquisition, fit, plot and update functions
# defined in this module, passed in that order.
rabi_length_frequency_signal = Routine(_acquisition, _fit, _plot, _update)
"""Rabi length with frequency tuning."""