Source code for qiboml.differentiations.psr

import math
from dataclasses import dataclass
from typing import Tuple
from numpy.typing import ArrayLike

from qibo import Circuit
from qibo.config import raise_error
from qiboml.differentiations.abstract import Differentiation


@dataclass
class PSR(Differentiation):
    """
    The Parameter Shift Rule differentiator.

    The derivative with respect to each parameter is obtained from two
    evaluations of the decoder, on circuits whose target parameter has been
    shifted forward and backward. No analytical derivative is required, which
    makes the method hardware compatible.
    """

    def _on_build(self):
        """Check that the decoder is compatible with the parameter shift rule.

        Raises:
            RuntimeError: (through ``raise_error``) if the product of
                ``self.decoding.output_shape`` differs from one, i.e. the
                decoder output is not a scalar.
        """
        if math.prod(self.decoding.output_shape) != 1:
            raise_error(
                RuntimeError,
                "PSR differentiation works only for decoders with scalar outputs, i.e. expectation values.",
            )

    def evaluate(self, parameters: ArrayLike, wrt_inputs: bool = False):
        """
        Evaluate the gradient of the quantum circuit w.r.t its parameters, i.e. its
        rotation angles.

        Args:
            parameters (ArrayLike): the parameters at which to evaluate the model,
                and thus the derivative. Must expose a ``dtype`` attribute, which
                is used to cast the intermediate results.
            wrt_inputs (bool): whether to calculate the derivative with respect to,
                also, inputs (i.e. encoding angles) or not, by default ``False``.

        Returns:
            (ArrayLike): the calculated jacobian, with one entry per differentiated
            parameter, computed as ``(forward - backward) * generator_eigenvalue``.

        Raises:
            AssertionError: if the differentiator has not been built yet.
        """
        assert self._is_built, "Call .build(circuit, decoding) before evaluate()."

        # Shifted circuits are stored interleaved: [fwd_0, bwd_0, fwd_1, bwd_1, ...]
        circuits = []
        eigvals = []
        for i in range(self.nparams(wrt_inputs)):
            forward, backward, eigval = self.one_parameter_shift(
                parameters=parameters, parameter_index=i, wrt_inputs=wrt_inputs
            )
            circuits.extend([forward, backward])
            eigvals.append(eigval)

        # TODO: parallelize when decoding will support
        # the parallel execution of multiple circuits
        expvals = self.backend.cast(
            [self.decoding(circ) for circ in circuits], dtype=parameters.dtype
        )
        # Undo the interleaving: even positions are forward shifts, odd are backward.
        forwards = expvals[::2]
        backwards = expvals[1::2]
        # Give the eigenvalues the same shape as the expectation values so that
        # the product below is taken element by element.
        eigvals = self.backend.reshape(
            self.backend.cast(eigvals, dtype=parameters.dtype), forwards.shape
        )
        return (forwards - backwards) * eigvals

    def one_parameter_shift(
        self, parameters: ArrayLike, parameter_index: int, wrt_inputs: bool = False
    ) -> Tuple[Circuit, Circuit, float]:
        """Compute one derivative of the decoding strategy w.r.t. a target parameter.

        Args:
            parameters (ArrayLike): the parameters at which the derivative is taken.
            parameter_index (int): position of the target parameter; it indexes both
                ``parameters`` and the list of target gates.
            wrt_inputs (bool): if ``True`` the target gates are the circuit's
                ``parametrized_gates``, otherwise its ``trainable_gates``.

        Returns:
            (Tuple[Circuit, Circuit, float]): the circuit with the target parameter
            shifted by ``+s``, the circuit with it shifted by ``-s``, and the
            generator eigenvalue of the target gate, where
            ``s = pi / (4 * generator_eigenvalue)``.
        """
        target_gates = (
            self.circuit.parametrized_gates
            if wrt_inputs
            else self.circuit.trainable_gates
        )
        gate = target_gates[parameter_index]
        generator_eigenval = gate.generator_eigenvalue()
        s = math.pi / (4 * generator_eigenval)

        forward = self._build_shifted_circuit(
            parameters, parameter_index, s, wrt_inputs
        )
        backward = self._build_shifted_circuit(
            parameters, parameter_index, -s, wrt_inputs
        )
        return forward, backward, generator_eigenval

    def _build_shifted_circuit(self, parameters, parameter_index, shift, wrt_inputs):
        """Return a copy of ``self.circuit`` with one parameter moved by ``shift``."""
        # Work on a fresh copy of the parameters: ``shift_parameter`` may
        # modify its argument in place.
        shifted = self.backend.cast(parameters, copy=True, dtype=parameters[0].dtype)
        shifted = self.shift_parameter(shifted, parameter_index, shift, self.backend)

        circuit = self.circuit.copy(True)
        target_gates = (
            circuit.parametrized_gates if wrt_inputs else circuit.trainable_gates
        )
        # Parameters are assigned gate by gate, so that the same code path serves
        # both the trainable and the parametrized gate lists.
        for gate, value in zip(target_gates, shifted):
            gate.parameters = value
        # NOTE(review): presumably discards a cached final state carried over by
        # the copy, so that the new parameters are used on execution - confirm.
        circuit._final_state = None
        return circuit

    @staticmethod
    def shift_parameter(parameters, i, epsilon, backend):
        """Add ``epsilon`` to the ``i``-th entry of ``parameters``.

        Args:
            parameters: indexable collection of parameters.
            i (int): index of the entry to shift.
            epsilon: amount added to the entry.
            backend: object whose ``platform`` attribute selects the update
                strategy (and whose ``engine`` is used for ``"tensorflow"``).

        Returns:
            The shifted parameters. For the ``"tensorflow"`` and ``"jax"``
            platforms a new object is returned; for any other platform the
            entry is assigned in place and the same object passed in is returned.
        """
        if backend.platform == "tensorflow":
            # Rebuild the whole collection instead of assigning to one item.
            return backend.engine.stack(
                [parameters[j] + int(i == j) * epsilon for j in range(len(parameters))]
            )
        if backend.platform == "jax":
            # Functional update through ``.at[i].set``.
            parameters = parameters.at[i].set(parameters[i] + epsilon)
        else:
            parameters[i] = parameters[i] + epsilon
        return parameters