"""Rabi experiment that sweeps amplitude and frequency (with probability)."""
from dataclasses import dataclass, field
from typing import Optional, Union
import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from qibolab import AcquisitionType, AveragingMode, Parameter, Sweeper
from qibocal import update
from qibocal.auto.operation import Data, Parameters, QubitId, Routine
from qibocal.calibration import CalibrationPlatform
from qibocal.config import log
from qibocal.protocols.utils import (
HZ_TO_GHZ,
fallback_period,
guess_period,
readout_frequency,
table_dict,
table_html,
)
from ...result import magnitude, phase
from .amplitude_signal import RabiAmplitudeSignalResults
from .utils import fit_amplitude_function, sequence_amplitude
__all__ = [
"RabiAmplitudeSignalResults",
"RabiAmplitudeFrequencySignalParameters",
"RabiAmplitudeFreqSignalData",
"_update",
"rabi_amplitude_frequency_signal",
]
[docs]
@dataclass
class RabiAmplitudeFrequencySignalParameters(Parameters):
"""RabiAmplitudeFrequency runcard inputs."""
min_amp: float
"""Minimum amplitude."""
max_amp: float
"""Maximum amplitude."""
step_amp: float
"""Step amplitude."""
min_freq: int
"""Minimum frequency as an offset."""
max_freq: int
"""Maximum frequency as an offset."""
step_freq: int
"""Frequency to use as step for the scan."""
rx90: bool = False
"""Calibration of native pi pulse, if true calibrates pi/2 pulse"""
pulse_length: Optional[float] = None
"""RX pulse duration [ns]."""
@dataclass
class RabiAmplitudeFrequencySignalResults(RabiAmplitudeSignalResults):
"""RabiAmplitudeFrequency outputs."""
frequency: dict[QubitId, Union[float, list[float]]]
"""Drive frequency for each qubit."""
rx90: bool
"""Pi or Pi_half calibration"""
RabiAmpFreqSignalType = np.dtype(
[
("amp", np.float64),
("freq", np.float64),
("signal", np.float64),
("phase", np.float64),
]
)
"""Custom dtype for rabi amplitude."""
[docs]
@dataclass
class RabiAmplitudeFreqSignalData(Data):
"""RabiAmplitudeFreqSignal data acquisition."""
rx90: bool
"""Pi or Pi_half calibration"""
durations: dict[QubitId, float] = field(default_factory=dict)
"""Pulse durations provided by the user."""
data: dict[QubitId, npt.NDArray[RabiAmpFreqSignalType]] = field(
default_factory=dict
)
"""Raw data acquired."""
[docs]
def register_qubit(self, qubit, freq, amp, signal, phase):
"""Store output for single qubit."""
size = len(freq) * len(amp)
frequency, amplitude = np.meshgrid(freq, amp)
data = np.empty(size, dtype=RabiAmpFreqSignalType)
data["freq"] = frequency.ravel()
data["amp"] = amplitude.ravel()
data["signal"] = signal.ravel()
data["phase"] = phase.ravel()
self.data[qubit] = np.rec.array(data)
[docs]
def amplitudes(self, qubit):
"""Unique qubit amplitudes."""
return np.unique(self[qubit].amp)
[docs]
def frequencies(self, qubit):
"""Unique qubit frequency."""
return np.unique(self[qubit].freq)
def _acquisition(
params: RabiAmplitudeFrequencySignalParameters,
platform: CalibrationPlatform,
targets: list[QubitId],
) -> RabiAmplitudeFreqSignalData:
"""Data acquisition for Rabi experiment sweeping amplitude."""
sequence, qd_pulses, ro_pulses, durations = sequence_amplitude(
targets, params, platform, params.rx90
)
frequency_range = np.arange(
params.min_freq,
params.max_freq,
params.step_freq,
)
freq_sweepers = {}
for qubit in targets:
channel = platform.qubits[qubit].drive
freq_sweepers[qubit] = Sweeper(
parameter=Parameter.frequency,
values=platform.config(channel).frequency + frequency_range,
channels=[channel],
)
amp_sweeper = Sweeper(
parameter=Parameter.amplitude,
range=(params.min_amp, params.max_amp, params.step_amp),
pulses=[qd_pulses[qubit] for qubit in targets],
)
data = RabiAmplitudeFreqSignalData(durations=durations, rx90=params.rx90)
results = platform.execute(
[sequence],
[[amp_sweeper], [freq_sweepers[q] for q in targets]],
updates=[
{platform.qubits[q].probe: {"frequency": readout_frequency(q, platform)}}
for q in targets
],
nshots=params.nshots,
relaxation_time=params.relaxation_time,
acquisition_type=AcquisitionType.INTEGRATION,
averaging_mode=AveragingMode.CYCLIC,
)
for qubit in targets:
result = results[ro_pulses[qubit].id]
data.register_qubit(
qubit=qubit,
freq=freq_sweepers[qubit].values,
amp=amp_sweeper.values,
signal=magnitude(result),
phase=phase(result),
)
return data
def _fit(data: RabiAmplitudeFreqSignalData) -> RabiAmplitudeFrequencySignalResults:
"""Do not perform any fitting procedure."""
fitted_frequencies = {}
fitted_amplitudes = {}
fitted_parameters = {}
for qubit in data.data:
amps = data.amplitudes(qubit)
freqs = data.frequencies(qubit)
signal = data[qubit].signal
signal_matrix = signal.reshape(len(amps), len(freqs)).T
# guess optimal frequency maximizing oscillation amplitude
index = np.argmax([max(x) - min(x) for x in signal_matrix])
frequency = freqs[index]
# Guessing period using fourier transform
y = signal_matrix[index]
y_min = np.min(y)
y_max = np.max(y)
x_min = np.min(amps)
x_max = np.max(amps)
x = (amps - x_min) / (x_max - x_min)
y = (y - y_min) / (y_max - y_min)
period = fallback_period(guess_period(x, y))
pguess = [0.5, 0.5, period, 0]
try:
popt, _, pi_pulse_parameter = fit_amplitude_function(
x,
y,
pguess,
signal=True,
x_limits=(x_min, x_max),
y_limits=(y_min, y_max),
)
fitted_frequencies[qubit] = frequency
fitted_amplitudes[qubit] = pi_pulse_parameter
fitted_parameters[qubit] = popt
except Exception as e:
log.warning(f"Rabi fit failed for qubit {qubit} due to {e}.")
return RabiAmplitudeFrequencySignalResults(
amplitude=fitted_amplitudes,
length=data.durations,
fitted_parameters=fitted_parameters,
frequency=fitted_frequencies,
rx90=data.rx90,
)
def _plot(
data: RabiAmplitudeFreqSignalData,
target: QubitId,
fit: RabiAmplitudeFrequencySignalResults = None,
):
"""Plotting function for RabiAmplitudeFrequency."""
figures = []
fitting_report = ""
fig = make_subplots(
rows=1,
cols=2,
horizontal_spacing=0.1,
vertical_spacing=0.2,
subplot_titles=(
"Signal [a.u.]",
"Phase [rad]",
),
)
qubit_data = data[target]
frequencies = qubit_data.freq * HZ_TO_GHZ
amplitudes = qubit_data.amp
fig.add_trace(
go.Heatmap(
x=amplitudes,
y=frequencies,
z=qubit_data.signal,
colorbar_x=0.46,
),
row=1,
col=1,
)
fig.add_trace(
go.Heatmap(
x=amplitudes,
y=frequencies,
z=qubit_data.phase,
colorbar_x=1.01,
),
row=1,
col=2,
)
fig.update_xaxes(title_text="Amplitude [a.u.]", row=1, col=1)
fig.update_xaxes(title_text="Amplitude [a.u.]", row=1, col=2)
fig.update_yaxes(title_text="Frequency [GHz]", row=1, col=1)
figures.append(fig)
if fit is not None:
fig.add_trace(
go.Scatter(
x=[min(amplitudes), max(amplitudes)],
y=[fit.frequency[target] * HZ_TO_GHZ] * 2,
mode="lines",
line={"color": "white", "width": 4, "dash": "dash"},
),
row=1,
col=1,
)
fig.add_trace(
go.Scatter(
x=[min(amplitudes), max(amplitudes)],
y=[fit.frequency[target] * HZ_TO_GHZ] * 2,
mode="lines",
line={"color": "white", "width": 4, "dash": "dash"},
),
row=1,
col=2,
)
pulse_name = "Pi-half pulse" if data.rx90 else "Pi pulse"
fitting_report = table_html(
table_dict(
target,
["Optimal rabi frequency", f"{pulse_name} amplitude"],
[
fit.frequency[target],
f"{fit.amplitude[target]:.6f} [a.u]",
],
)
)
fig.update_layout(
showlegend=False,
legend={"orientation": "h"},
)
return figures, fitting_report
[docs]
def _update(
results: RabiAmplitudeFrequencySignalResults,
platform: CalibrationPlatform,
target: QubitId,
):
update.drive_duration(results.length[target], results.rx90, platform, target)
update.drive_amplitude(results.amplitude[target], results.rx90, platform, target)
update.drive_frequency(results.frequency[target], platform, target)
rabi_amplitude_frequency_signal = Routine(_acquisition, _fit, _plot, _update)
"""Rabi amplitude with frequency tuning."""