qp.capture.register_custom_staging_rule

register_custom_staging_rule(primitive, get_jaxpr_from_params, setup_env=<function _default_setup_env>)

Register a custom staging rule, capable of handling dynamic shapes, for a higher-order primitive.

Parameters:
  • primitive (jax.extend.core.Primitive) – the JAX primitive to register a custom staging rule for

  • get_jaxpr_from_params (Callable[[dict], "jax.extend.core.Jaxpr"]) – a function that takes the equation’s params and returns the target jaxpr

  • setup_env (Callable) – a function that sets up a dictionary mapping the inner jaxpr’s variables to the tracers that are inputs to the equation. It receives the equation’s input tracers and the equation’s params. By default, it returns an empty dictionary.

For example, cond_prim registers its custom staging rule as follows:

register_custom_staging_rule(cond_prim, lambda params: params['jaxpr_branches'][0])

cond cannot provide a setup_env because different branches may have different dynamic shapes, so it relies on the default.
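What the cond callback returns can be sketched with mock objects. The branch "jaxprs" here are SimpleNamespace stand-ins and the params dict is hypothetical apart from the "jaxpr_branches" key, which comes from the registration for cond_prim.

```python
from types import SimpleNamespace

# Mock branch jaxprs (stand-ins for jax.extend.core.Jaxpr objects).
true_branch = SimpleNamespace(name="true_branch", invars=["x"])
false_branch = SimpleNamespace(name="false_branch", invars=["y"])

# Hypothetical cond params; only the "jaxpr_branches" key mirrors the doc.
params = {"jaxpr_branches": [true_branch, false_branch]}


def get_jaxpr_from_params(params):
    # Equivalent to: lambda params: params['jaxpr_branches'][0]
    return params["jaxpr_branches"][0]


target = get_jaxpr_from_params(params)
print(target.name)  # true_branch
```

Only the first branch is returned, so an environment built from its invars would say nothing about the variables of the other branches.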

Compare this to while_loop_prim, which does provide a setup_env:

def setup_env(tracers, params):
    # Equation inputs: loop arguments followed by closed-over constants.
    tracers = tracers[slice(*params['args_slice'])] + tracers[slice(*params['consts_slice'])]
    # Body-jaxpr variables in the same order: invars, then constvars.
    inner_vars = params['jaxpr_body_fn'].invars + params['jaxpr_body_fn'].constvars
    return dict(zip(inner_vars, tracers, strict=True))

register_custom_staging_rule(
    while_loop_prim,
    get_jaxpr_from_params=lambda params: params["jaxpr_body_fn"],
    setup_env=setup_env,
)

for_loop_prim is more complicated: the start, stop, and step must be sliced out of the tracers, and the loop index must be sliced out of the jaxpr invars, since it has no corresponding equation input.
