# Source code for pennylane.templates.subroutines.qrom
# Copyright 2018-2025 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This submodule contains the template for QROM.
"""
from collections import Counter
from collections.abc import Sequence
from functools import reduce
import numpy as np
import pennylane.math as pl_math
from pennylane import ops as qp_ops
from pennylane.decomposition import (
add_decomps,
register_resources,
resource_rep,
)
from pennylane.math import ceil_log2
from pennylane.operation import Operation
from pennylane.queuing import QueuingManager, apply
from pennylane.templates.embeddings import BasisEmbedding
from pennylane.typing import TensorLike
from pennylane.wires import Wires, WiresLike
from .select import Select
def _multi_swap(wires1, wires2):
    """Swap two registers wire by wire.

    One SWAP is applied for every pair ``(wires1[k], wires2[k])``. Because ``zip`` stops at the
    shorter sequence, surplus wires in the longer register are left untouched.
    """
    for pair in zip(wires1, wires2):
        qp_ops.SWAP(wires=list(pair))
def _new_ops(depth, target_wires, control_wires, swap_wires, data):
    """Build the column operators that the Select block chooses between.

    Each bitstring in ``data`` becomes a :class:`~.BasisEmbedding` on ``target_wires``. The list
    is padded with identities up to ``2 ** len(control_wires)`` entries. Consecutive entries are
    grouped into columns of ``depth`` operators. Inside a column, the ``k``-th operator is moved
    onto the ``k``-th block of ``len(target_wires)`` wires of ``swap_wires``. Each column is
    returned as a single product operator.

    Args:
        depth (int): number of operators per column
        target_wires (Wires): wires on which the bitstrings are defined
        control_wires (Wires): index register; only its size is used here, to set the padding
        swap_wires (Wires): wires of the SWAP register, in blocks of ``len(target_wires)``
        data (TensorLike): bitstrings, one per row

    Returns:
        list: one ``qp_ops.prod`` operator per column
    """
    width = len(target_wires)

    # The per-index loaders are built without queuing them.
    with QueuingManager.stop_recording():
        loaders = [BasisEmbedding(bits, wires=target_wires) for bits in data]
        n_padding = int(2 ** len(control_wires) - len(loaders))
        padded = loaders + [qp_ops.I(target_wires)] * n_padding

    # Ceiling division: a final, partially filled column picks up identity padding.
    n_entries = data.shape[0]
    n_columns = (n_entries + depth - 1) // depth

    columns = []
    for col in range(n_columns):
        relocated = []
        for slot in range(depth):
            op = padded[col * depth + slot]
            offset = slot * width
            wire_map = {op.wires[k]: swap_wires[offset + k] for k in range(width)}
            relocated.append(op.map_wires(wire_map))
        # Deliberately built outside ``stop_recording``.
        # NOTE(review): ``_select_ops`` discards this return value when no Select is needed,
        # so it appears to rely on the products being queued here — confirm before moving.
        columns.append(qp_ops.prod(*relocated))
    return columns
def _select_ops(
    control_wires, depth, target_wires, swap_wires, data, select_work_wires
):  # pylint:disable=too-many-arguments
    """Load one column of bitstrings into the SWAP register.

    The leading ``ceil_log2(2 ** len(control_wires) / depth)`` control wires address the columns
    produced by ``_new_ops`` through a :class:`~.Select`. When that number is zero there is a
    single column, and no Select is built.

    Args:
        control_wires (Wires): index register
        depth (int): number of operators per column
        target_wires (Wires): wires on which the bitstrings are defined
        swap_wires (Wires): wires of the SWAP register
        data (TensorLike): bitstrings, one per row
        select_work_wires (Wires): work wires handed to the Select block
    """
    n_select_controls = ceil_log2(2 ** len(control_wires) / depth)
    select_controls = control_wires[:n_select_controls]
    columns = _new_ops(depth, target_wires, control_wires, swap_wires, data)

    if not select_controls:
        # NOTE(review): nothing is built here; the column operators are presumably already
        # queued by ``_new_ops`` (via ``qp_ops.prod``) — confirm.
        return

    Select(
        columns,
        control=select_controls,
        work_wires=select_work_wires,
    )
def _swap_ops(control_wires, depth, swap_wires, target_wires):
    """Apply the controlled-SWAP network over the trailing control wires.

    The control wires not used by the Select block each drive one level of the network. At level
    ``level`` (processed from the highest down to 0), block ``b`` of the SWAP register is swapped
    with block ``b + 2**level``. The swap is controlled on ``swap_controls[-level - 1]``, with
    blocks visited in descending order.

    Args:
        control_wires (Wires): index register
        depth (int): number of operators per column
        swap_wires (Wires): wires of the SWAP register, in blocks of ``len(target_wires)``
        target_wires (Wires): wires on which the bitstrings are defined
    """
    width = len(target_wires)
    n_select_controls = ceil_log2(2 ** len(control_wires) / depth)
    swap_controls = control_wires[n_select_controls:]

    for level in reversed(range(len(swap_controls))):
        stride = 2**level
        control = swap_controls[-level - 1]
        for block in reversed(range(stride)):
            qp_ops.ctrl(_multi_swap, control=control)(
                swap_wires[block * width : (block + 1) * width],
                swap_wires[(block + stride) * width : (block + 2 * stride) * width],
            )
class QROM(Operation):
    r"""Applies the QROM operator.

    This operator encodes bitstrings associated with indexes:

    .. math::
        \text{QROM}|i\rangle|0\rangle = |i\rangle |b_i\rangle,

    where :math:`b_i` is the bitstring associated with index :math:`i`.

    Args:
        data (TensorLike | Sequence[str]): the bitstrings to be encoded, given either as a 2D
            array of bits with one row per bitstring or as a sequence of strings such as
            ``"0110"``
        control_wires (WiresLike):
            The register that stores the index for the entry of the classical data we want to
            read.
        target_wires (WiresLike): the wires where the bitstring is loaded
        work_wires (WiresLike): the auxiliary wires used for the computation; ``None`` is
            treated as no work wires
        clean (bool): if True, the work wires are not altered by operator, default is ``True``

    Raises:
        ValueError: if the control, target and work wires overlap, if there are not enough
            control wires to index every bitstring, or if the bitstring length does not match
            the number of target wires

    .. seealso:: :class:`~.BBQRAM`, :class:`~.QROMStatePreparation`

    .. note::
        QRAM and QROM, though similar, have different applications and purposes. QRAM is intended
        for read-and-write capabilities, where the stored data can be loaded and changed. QROM is
        designed to only load stored data into a quantum register.

    **Example**

    In this example, the QROM operator is applied to encode the third bitstring, associated with
    index 2, in the target wires.

    .. code-block:: python

        # a list of bitstrings is defined
        data = [[0, 1, 0], [1, 1, 1], [1, 1, 0], [0, 0, 0]]

        dev = qp.device("default.qubit")

        @qp.qnode(dev, shots=1)
        def circuit():

            # the third index is encoded in the control wires [0, 1]
            qp.BasisEmbedding(2, wires = [0,1])

            qp.QROM(data = data,
                    control_wires = [0,1],
                    target_wires = [2,3,4],
                    work_wires = [5,6,7])

            return qp.sample(wires = [2,3,4])

    >>> print(circuit())
    [[1 1 0]]

    .. details::
        :title: Usage Details

        This template takes as input three different sets of wires. The first one is
        ``control_wires`` which is used to encode the desired index. Therefore, if we have
        :math:`m` bitstrings, we need at least :math:`\lceil \log_2(m)\rceil` control wires.

        The second set of wires is ``target_wires`` which stores the bitstrings. For instance,
        if each bitstring is of the form ``[0, 1, 1, 0]``, we will need four target wires.
        Internally, the bitstrings are encoded using the :class:`~.BasisEmbedding` template.

        The ``work_wires`` are auxiliary qubits used to reduce the gate complexity of the
        operator. These wires are dynamically partitioned into two sets: one for the
        :class:`~.Select` block and another to facilitate parallel data loading via a
        `SWAP network <https://pennylane.ai/compilation/swap-network>`__.

        The template determines the depth, :math:`\lambda` (a power of 2),
        based on the available ``work_wires``. Let :math:`b` be the length of the bitstrings.
        The number of wires allocated to the SWAP network is :math:`k_{swap} = b \cdot (\lambda - 1)`.
        The remaining wires, :math:`k_{select}`, are assigned to the :class:`~.Select` block.
        To ensure the decomposition is valid, the template guarantees that
        :math:`k_{select} \geq c - \log_2(\lambda) - 1`, where :math:`c` is the number of
        control wires, updating the depth if needed.

        The QROM template has two variants. The first one (``clean = False``) is based on
        [`arXiv:1812.00954 <https://arxiv.org/abs/1812.00954>`__] and alters the state in the
        ``work_wires``. The second one (``clean = True``), based on
        [`arXiv:1902.02134 <https://arxiv.org/abs/1902.02134>`__], solves that issue by
        returning ``work_wires`` to their initial state. This technique can be applied when the
        ``work_wires`` are not initialized to zero.
    """

    resource_keys = {
        "num_bitstrings",
        "num_control_wires",
        "num_target_wires",
        "num_work_wires",
        "clean",
    }

    def __init__(
        self,
        data: TensorLike | Sequence[str],
        control_wires: WiresLike,
        target_wires: WiresLike,
        work_wires: WiresLike,
        clean=True,
        id=None,
    ):  # pylint: disable=too-many-arguments,disable=too-many-positional-arguments

        control_wires = Wires(control_wires)
        target_wires = Wires(target_wires)

        # Bitstrings given as strings (e.g. "0110") are converted to rows of integer bits.
        if isinstance(data[0], str):
            data = np.array(list(map(lambda bitstring: [int(bit) for bit in bitstring], data)))

        if isinstance(data, (list, tuple)):
            data = pl_math.array(data)

        work_wires = Wires(() if work_wires is None else work_wires)

        self.hyperparameters["control_wires"] = control_wires
        self.hyperparameters["target_wires"] = target_wires
        self.hyperparameters["work_wires"] = work_wires
        self.hyperparameters["clean"] = clean

        _wires_are_traced = any(
            pl_math.is_abstract(w) for ws in (control_wires, target_wires, work_wires) for w in ws
        )

        # Wire overlap validation must be skipped when wires are JAX tracers,
        # as their concrete values are not available during tracing.
        if not _wires_are_traced:
            if len(work_wires) != 0:
                if any(wire in work_wires for wire in control_wires):
                    raise ValueError("Control wires should be different from work wires.")

                if any(wire in work_wires for wire in target_wires):
                    raise ValueError("Target wires should be different from work wires.")

            if any(wire in control_wires for wire in target_wires):
                raise ValueError("Target wires should be different from control wires.")

        if 2 ** len(control_wires) < data.shape[0]:
            raise ValueError(
                f"Not enough control wires ({len(control_wires)}) for the desired number of "
                + f"data ({data.shape[0]}). At least {ceil_log2(data.shape[0])} control "
                + "wires are required."
            )

        if data[0].shape[0] != len(target_wires):
            raise ValueError("Bitstring length must match the number of target wires.")

        all_wires = target_wires + control_wires + work_wires
        super().__init__(data, wires=all_wires, id=id)

    def _flatten(self):
        metadata = tuple((key, value) for key, value in self.hyperparameters.items())
        return tuple(self.data), metadata

    @property
    def resource_params(self) -> dict:
        """Parameters consumed by the registered resource function (see ``resource_keys``)."""
        return {
            "num_bitstrings": self.data[0].shape[0],
            "num_control_wires": len(self.hyperparameters["control_wires"]),
            "num_target_wires": len(self.hyperparameters["target_wires"]),
            "num_work_wires": len(self.hyperparameters["work_wires"]),
            "clean": self.hyperparameters["clean"],
        }

    @classmethod
    def _unflatten(cls, data, metadata):
        hyperparams_dict = dict(metadata)
        return cls(*data, **hyperparams_dict)

    def __repr__(self):
        return f"QROM(control_wires={self.control_wires}, target_wires={self.target_wires}, work_wires={self.work_wires}, clean={self.clean})"

    def map_wires(self, wire_map: dict):
        """Return a copy of this QROM with its control, target and work wires relabelled.

        Wires absent from ``wire_map`` are kept unchanged. The data, the ``clean`` flag and the
        operator ``id`` are carried over to the new instance.

        Args:
            wire_map (dict): mapping from old wire labels to new wire labels

        Returns:
            QROM: the operator acting on the mapped wires
        """
        new_dict = {
            key: [wire_map.get(w, w) for w in self.hyperparameters[key]]
            for key in ["target_wires", "control_wires", "work_wires"]
        }

        return QROM(
            self.data[0],
            new_dict["control_wires"],
            new_dict["target_wires"],
            new_dict["work_wires"],
            self.clean,
            id=self.id,
        )

    def __copy__(self):
        """Copy this op"""
        cls = self.__class__
        copied_op = cls.__new__(cls)

        for attr, value in vars(self).items():
            setattr(copied_op, attr, value)

        return copied_op

    def decomposition(self):
        """Decompose this instance using :meth:`compute_decomposition`."""
        return self.compute_decomposition(
            self.data[0],
            control_wires=self.control_wires,
            target_wires=self.target_wires,
            work_wires=self.work_wires,
            clean=self.clean,
        )

    @staticmethod
    def compute_decomposition(
        data, control_wires, target_wires, work_wires, clean
    ):  # pylint: disable=arguments-differ
        """Representation of the operator as a product of other operators (static method).

        Here every work wire is used by the SWAP register. The bitstrings are grouped into
        columns of ``depth`` operators, where ``depth`` is a power of 2. A :class:`~.Select` over
        the leading control wires loads one column into the SWAP register. Controlled SWAPs on
        the remaining control wires then move the requested bitstring onto ``target_wires``.

        Args:
            data (TensorLike): bitstrings, one per row
            control_wires (Wires): index register
            target_wires (Wires): wires receiving the selected bitstring
            work_wires (Wires): auxiliary wires
            clean (bool): whether the work wires must be returned to their initial state

        Returns:
            list[Operator]: decomposition of the operator
        """
        if len(control_wires) == 0:
            return [BasisEmbedding(bits, wires=target_wires) for bits in data]

        with QueuingManager.stop_recording():

            swap_wires = target_wires + work_wires

            # Number of operators stored per column. It must be a power of 2, for two reasons:
            # - the SWAP network below routes between 2**k register blocks;
            # - ``ceil_log2`` splits the control register between Select and SWAP.
            # The cap by the number of bitstrings is therefore applied *before* rounding down,
            # as in ``_calculate_n_select_work_wires``. Capping after rounding could produce a
            # non-power-of-2 depth, e.g. 3, and mis-address some indices.
            depth = len(swap_wires) // len(target_wires)
            depth = int(2 ** np.floor(np.log2(min(depth, data.shape[0]))))

            ops = [BasisEmbedding(bits, wires=target_wires) for bits in data]
            ops_identity = ops + [qp_ops.I(target_wires)] * int(2 ** len(control_wires) - len(ops))

            n_columns = len(ops) // depth + int(bool(len(ops) % depth))
            new_ops = []
            for i in range(n_columns):
                column_ops = []
                for j in range(depth):
                    dic_map = {
                        ops_identity[i * depth + j].wires[l]: swap_wires[j * len(target_wires) + l]
                        for l in range(len(target_wires))
                    }
                    column_ops.append(ops_identity[i * depth + j].map_wires(dic_map))

                new_ops.append(qp_ops.prod(*column_ops))

            # Select block
            n_control_select_wires = ceil_log2(2 ** len(control_wires) / depth)
            control_select_wires = control_wires[:n_control_select_wires]

            select_ops = []
            if control_select_wires:
                select_ops += [Select(new_ops, control=control_select_wires)]
            else:
                select_ops = new_ops

            # Swap block
            control_swap_wires = control_wires[n_control_select_wires:]
            swap_ops = []
            for ind in range(len(control_swap_wires)):
                for j in range(2**ind):
                    new_op = qp_ops.prod(_multi_swap)(
                        swap_wires[(j) * len(target_wires) : (j + 1) * len(target_wires)],
                        swap_wires[
                            (j + 2**ind)
                            * len(target_wires) : (j + 2 ** (ind + 1))
                            * len(target_wires)
                        ],
                    )
                    # Prepending yields the network ordered from the last level down.
                    swap_ops.insert(0, qp_ops.ctrl(new_op, control=control_swap_wires[-ind - 1]))

            if not clean or depth == 1:
                # Based on this paper (Fig 1.c): https://arxiv.org/abs/1812.00954
                decomp_ops = select_ops + swap_ops

            else:
                # Based on this paper (Fig 4): https://arxiv.org/abs/1902.02134
                adjoint_swap_ops = swap_ops[::-1]
                hadamard_ops = [qp_ops.Hadamard(wires=w) for w in target_wires]

                decomp_ops = 2 * (hadamard_ops + adjoint_swap_ops + select_ops + swap_ops)

        # The operators above were built without recording; queue them explicitly if needed.
        if QueuingManager.recording():
            for op in decomp_ops:
                apply(op)

        return decomp_ops

    @classmethod
    def _primitive_bind_call(cls, *args, **kwargs):
        return cls._primitive.bind(*args, **kwargs)

    @property
    def control_wires(self):
        """The control wires."""
        return self.hyperparameters["control_wires"]

    @property
    def target_wires(self):
        """The wires where the bitstring is loaded."""
        return self.hyperparameters["target_wires"]

    @property
    def work_wires(self):
        """The auxiliary wires used for the computation."""
        return self.hyperparameters["work_wires"]

    @property
    def wires(self):
        """All wires involved in the operation."""
        return (
            self.hyperparameters["control_wires"]
            + self.hyperparameters["target_wires"]
            + self.hyperparameters["work_wires"]
        )

    @property
    def clean(self):
        """Boolean to select the version of QROM."""
        return self.hyperparameters["clean"]
def _calculate_n_select_work_wires(terms, num_control_wires, num_target_wires, num_work_wires, **_):
"""Calculates the number of work wires passes to the select block.
This utility function determines how many auxiliary wires from the total pool
should be allocated to the Select operation versus the SWAP network.
Args:
terms (int): number of bitstrings/entries in the data
num_control_wires (int): number of control wires
num_target_wires (int): number of target wires (bitstring length)
num_work_wires (int): total number of available work wires
Returns:
int: The number of work wires assigned to the Select component.
"""
if num_work_wires < num_control_wires - 1:
return num_work_wires
# Initialize available swap space using total work wires
n_swap_work_wires = num_work_wires
n_swap_wires = num_target_wires + n_swap_work_wires
# Calculate depth: how many bitstrings we can load in parallel (power of 2)
depth = n_swap_wires // num_target_wires
depth = int(2 ** np.floor(np.log2(min(depth, terms))))
# Recalculate actual wires used by SWAP and the remaining for Select
n_swap_work_wires = num_target_wires * depth - num_target_wires
n_select_work_wires = num_work_wires - n_swap_work_wires
# Adjust depth if Select doesn't have enough work wires for the required control logic
n_select_control_wires = num_control_wires - np.floor(np.log2(depth))
while n_select_work_wires < n_select_control_wires - 1:
depth = depth // 2
n_swap_work_wires = num_target_wires * depth - num_target_wires
n_select_work_wires = num_work_wires - n_swap_work_wires
n_select_control_wires = num_control_wires - np.floor(np.log2(depth))
return n_select_work_wires
def _qrom_decomposition_resources(
    num_bitstrings, num_control_wires, num_target_wires, num_work_wires, clean
):  # pylint: disable=too-many-branches
    """Gate counts of the registered QROM decomposition (``_qrom_decomposition``).

    The work-wire split follows ``_calculate_n_select_work_wires``. The result counts:

    * one product operator per column of ``depth`` loaders (a bare loader when ``depth == 1``),
      wrapped in a single :class:`~.Select` when Select control wires are needed;
    * ``num_target_wires`` CSWAPs for each of the ``2**k - 1`` block pairs of the SWAP network,
      where ``k`` is the number of SWAP control wires.

    For the clean variant, the Hadamards and both SWAP passes are included, and every count is
    doubled.

    Returns:
        dict: mapping of operator (resource representation or class) to count
    """
    select_work = _calculate_n_select_work_wires(
        num_bitstrings, num_control_wires, num_target_wires, num_work_wires
    )
    swap_work = num_work_wires - select_work

    if num_control_wires == 0:
        return {resource_rep(BasisEmbedding, num_wires=num_target_wires): num_bitstrings}

    # Largest power-of-2 number of register blocks, capped by the number of bitstrings.
    n_blocks = (num_target_wires + swap_work) // num_target_wires
    depth = min(int(2 ** np.floor(np.log2(n_blocks))), num_bitstrings)

    loaders = [resource_rep(BasisEmbedding, num_wires=num_target_wires) for _ in range(num_bitstrings)]
    padded = loaders + [qp_ops.I] * int(2**num_control_wires - num_bitstrings)
    n_columns = -(-num_bitstrings // depth)  # ceiling division

    # Column products, in the order they are handed to Select.
    column_counts = Counter()
    for col in range(n_columns):
        counts = Counter(padded[col * depth + slot] for slot in range(depth))
        single_op = len(counts) == 1 and next(iter(counts.values())) == 1
        if single_op:
            column_counts[next(iter(counts))] += 1
        else:
            column_counts[resource_rep(qp_ops.op_math.Prod, resources=dict(counts))] += 1

    n_select_controls = ceil_log2(2**num_control_wires / depth)
    if n_select_controls > 0:
        select_ops = {
            resource_rep(
                Select,
                num_control_wires=n_select_controls,
                op_reps=tuple(column_counts.elements()),
                partial=False,
                num_work_wires=select_work,
            ): 1
        }
    else:
        select_ops = column_counts

    # SWAP network: sum_{level < k} 2**level == 2**k - 1 block pairs, one CSWAP per target wire.
    n_swap_controls = num_control_wires - n_select_controls
    swap_counts = Counter()
    n_pairs = 2**n_swap_controls - 1
    if n_pairs:
        swap_counts[resource_rep(qp_ops.CSWAP)] += n_pairs * max(num_target_wires, 1)

    if not clean or depth == 1:
        swap_counts.update(select_ops)
        return swap_counts

    # Clean variant: H, SWAP^dagger, Select, SWAP — the whole sequence applied twice.
    resources = {qp_ops.Hadamard: num_target_wires}
    resources.update({rep: 2 * count for rep, count in swap_counts.items()})
    resources.update(select_ops)
    return {rep: 2 * count for rep, count in resources.items()}
@register_resources(_qrom_decomposition_resources)
def _qrom_decomposition(
    data, control_wires, target_wires, work_wires, clean, **__
):  # pylint: disable=unused-argument, too-many-arguments
    """Graph-based decomposition of :class:`~.QROM`.

    The work wires are split between the :class:`~.Select` block and the SWAP network by
    ``_calculate_n_select_work_wires``. The same split is used in the registered resource
    function ``_qrom_decomposition_resources``. With ``clean=False``, or when ``depth == 1``,
    the Select block is followed by the SWAP network. Otherwise the sequence Hadamard /
    SWAP^dagger / Select / SWAP is applied twice, which returns the work wires to their initial
    state.

    Args:
        data (TensorLike): bitstrings, one per row
        control_wires (Wires): index register
        target_wires (Wires): wires receiving the selected bitstring
        work_wires (Wires): auxiliary wires
        clean (bool): whether the work wires must be returned to their initial state
    """
    if len(control_wires) == 0:
        # With no control wires there is a single bitstring (2**0 == 1): load it and stop.
        # Returning here matches ``_qrom_decomposition_resources``, which counts only this
        # BasisEmbedding in that case. It also avoids falling through into the Select/SWAP
        # construction.
        BasisEmbedding(data[0, :], wires=target_wires)
        return

    n_select_work_wires = _calculate_n_select_work_wires(
        len(data), len(control_wires), len(target_wires), len(work_wires)
    )
    select_work_wires = work_wires[:n_select_work_wires]
    swap_work_wires = work_wires[n_select_work_wires:]

    swap_wires = target_wires + swap_work_wires

    # Number of operators stored per column (power of 2).
    depth = len(swap_wires) // len(target_wires)
    depth = int(2 ** np.floor(np.log2(depth)))
    depth = min(depth, data.shape[0])

    if not clean or depth == 1:
        # Based on arXiv:1812.00954 (Fig 1.c): Select then SWAP network.
        _select_ops(control_wires, depth, target_wires, swap_wires, data, select_work_wires)
        _swap_ops(control_wires, depth, swap_wires, target_wires)
    else:
        # Based on arXiv:1902.02134 (Fig 4): the H / SWAP^dagger / Select / SWAP block, twice.
        for _ in range(2):
            for w in target_wires:
                qp_ops.Hadamard(wires=w)
            qp_ops.adjoint(_swap_ops, lazy=False)(control_wires, depth, swap_wires, target_wires)
            _select_ops(control_wires, depth, target_wires, swap_wires, data, select_work_wires)
            _swap_ops(control_wires, depth, swap_wires, target_wires)


# Register the graph-based decomposition above for QROM.
add_decomps(QROM, _qrom_decomposition)
# _modules/pennylane/templates/subroutines/qrom
# Download Python script
# Download Notebook
# View on GitHub