Experimental noise model and simulation module#1848
Conversation
mhodson-rigetti
left a comment
There was a problem hiding this comment.
I have not yet reviewed simulation/*, transform.py, or the tests (besides one thing I saw while cross-checking the code). Partial review submitted!
| [tool.poetry.dependencies] | ||
| # TODO(#1816): Loosen this bound once we've resolved support for Python 3.13+. | ||
| python = "^3.9,<3.13" | ||
| python = ">=3.11, <3.13" |
There was a problem hiding this comment.
Nominally this is a breaking change but the versioning strategy here has been that retiring unsupported Python versions is allowed without a major version increment. Python 3.8 support was removed in 4.11 (June 2024). Python 3.9 is already EOL, so it should be dropped now. However, Python 3.10 is still in maintenance until October 2026. Can you do some due diligence to see if 3.10 can be accepted as a minimum version here?
The reason to upgrade should also be captured in a separate issue, even if its implementation is linked to this PR.
There was a problem hiding this comment.
- The python bump is because quax is >=3.11
- quax is >= 3.11 because it requires jax >= 0.8.2
- Jax 0.8.2 is >=3.11 because it follows SPEC-0
There was a problem hiding this comment.
That seems reasonable to me.
| CustomGateMap = Dict[str, Union[qx.Unitary, Callable[..., qx.Unitary]]] | ||
|
|
||
|
|
||
| def _parse_quil_instruction(quil_str: str) -> Gate | Measurement | Reset: |
There was a problem hiding this comment.
I was wondering what code paths would require string parsing for individual textual Quil instructions. I can't find any call sites for this method. Since it is private, it is also not meant to be used outside. Remove?
There was a problem hiding this comment.
Left from a previous iteration, will remove
There was a problem hiding this comment.
Correction it's used in the serialization of the noise model.
| if len(pauli) != num_qubits: | ||
| raise ValueError(f"Pauli term '{pauli}' has length {len(pauli)}, expected {num_qubits}.") | ||
|
|
||
| all_pauli_terms = list(map(lambda term: "".join(term), itertools.product(*["IXYZ" for _ in range(num_qubits)]))) |
There was a problem hiding this comment.
Why force a list here? I think it can be left as a generator and use O(1) storage given the single usage site.
| return False | ||
| return bool(jnp.isclose(float(qx.process_fidelity(self.process, other.process)), 1.0, atol=1e-9)) | ||
|
|
||
| def __hash__(self) -> int: |
There was a problem hiding this comment.
Well, spent at least an hour here.
This is inconsistent with __eq__. The hash and the equality check should agree on the semantics.
But I know what it really happening. You want the frozen dataclass. But, Jax arrays are not hashable. If you didn't override the default behavior, any attempt to use the channel as a key would fail with an unhashable type exception. It seems like it should, since Jax arrays are immutable, but they are not hashable, which is a design decision they made related to the JIT compiler. And equality checks on channels (including their process) does make general sense.
Still, the non-standard treatment is buried in this method and not documented to the user. If the user came along with some Python set() of CZ(0, 1) channels which they had prepared to sweep over some wide range of noise processes, the Python set would have hash collisions on every single insert and lookup, degrading the container performance from O(log N) insert and O(N.log N) full traversal to O(N.log N) and O(N^2.log N) respectively. One day they might notice and get quite grumpy.
The codified way to document that you are not hashing the process is to use field(hash=False), and probably say why in the property documentation. However, they still need to read the docs to not fall foul of the aforementioned scenario.
I would personally hash the Jax array. This means you do need to override this dunder method, but you should take inspiration maybe from the below to get some (one-time) binary blob as a key.
You should also include the target process (unitary) which should be resolved prior. Mimic the expected behavior of a dataclass but overcome the non-hashability issue minimally. I don't think it will impact you as you are not keying on this anyway.
There was a problem hiding this comment.
I decided to remove hash, it's not necessary.
| """A depolarizing channel on a qutrit gate mixes the state.""" | ||
| inst = Gate("TX", [], [0]) | ||
| channel = Channel.from_gate_fidelity(inst=inst, fidelity=0.8) | ||
| noise_model = NoiseModel(channels=frozenset([channel])) |
There was a problem hiding this comment.
I think this is the wrong type. How is it not picked up in type checking?
| """Trajectory simulation with qutrit depolarizing noise.""" | ||
| inst = Gate("TX", [], [0]) | ||
| channel = Channel.from_gate_fidelity(inst=inst, fidelity=0.9) | ||
| noise_model = NoiseModel(channels=frozenset([channel])) |
There was a problem hiding this comment.
I think this is the wrong type. How is it not picked up in type checking?
| # Use a Pauli channel: p_I = 1-p, p_X = p, p_Y = 0, p_Z = 0 | ||
| pauli_probs = {"X": p_error} | ||
| channel = Channel.from_pauli_noise(inst=inst, pauli_noise=pauli_probs) | ||
| return NoiseModel(channels=frozenset([channel])) |
There was a problem hiding this comment.
I think this is the wrong type. How is it not picked up in type checking?
| """Create a noise model with depolarizing noise on X gate.""" | ||
| inst = X(qubit) | ||
| channel = Channel.from_gate_fidelity(inst=inst, fidelity=fidelity) | ||
| return NoiseModel(channels=frozenset([channel])) |
There was a problem hiding this comment.
I think this is the wrong type. How is it not picked up in type checking?
| """With noise_model provided, runs a single trajectory.""" | ||
| inst = X(0) | ||
| channel = Channel.from_gate_fidelity(inst=inst, fidelity=1.0) | ||
| noise_model = NoiseModel(channels=frozenset([channel])) |
There was a problem hiding this comment.
I think this is the wrong type. How is it not picked up in type checking?
|
Additional notes on the PR description (which overall was really punchy and well thought out!): Regarding the Jax dependency, Jax could be specified as an extra, and then the parts of the new functionality (principally, the simulators) that require Jax could self-enable if the dependency is detected as fulfilled. "First, the noise.py file is promoted to a module." I think you mean you promoted it to a package? Regarding "by replacing the reference numpy simulators with a jax-accelerated, highly flexible simulation framework", did you really replace them? I thought it was non-breaking wrt the previous existing simulators? Regarding the mapping Regarding "the compressor an the calculator", think there's a typo -- "and"? |
The noise model also depends on jax. We could keep jax as an experimental extra for the time being and only add it to the dependencies when the noise model and simulator becomes the default. However, it may be preferable to add the dependency now to identify potential issues earlier.
It used to be a file, now it's a folder.
Yes, non-breaking that statement is a bit forward-looking. Regarding the mapping
Regarding the Jax dependency, Jax could be specified as an extra, and then the parts of the new functionality (principally, the simulators) that require Jax could self-enable if the dependency is detected as fulfilled.
Regarding "by replacing the reference numpy simulators with a jax-accelerated, highly flexible simulation framework", did you really replace them? I thought it was non-breaking wrt the previous existing simulators?
Yes, its a replacement operation now. |
mhodson-rigetti
left a comment
There was a problem hiding this comment.
Review complete. I'd be interested to check on the test coverage for your new modules once the format/style/type checks and unit tests are all green in CI.
| matplotlib = {version = "^3.9.0", optional = true} | ||
| matplotlib-inline = {version = "^0.1.7", optional = true} | ||
| seaborn = {version = "^0.13.2", optional = true} | ||
| rigetti-quax = ">=0.5.3" |
There was a problem hiding this comment.
In this PR you are referring directly to jax and plotly. These should be stated as explicit dependencies; I think you are currently relying on a transitive dependency chain here.
|
|
||
|
|
||
| # ────────────────────────────────────────────────────────── | ||
| # Re-export quax-based noise model classes (lazy to avoid circular imports) |
There was a problem hiding this comment.
I see the issue, or part of it -- this noise.py module is unreachable code because it is shadowed by noise/__init__.py. Looks like you have uplifted the content of this module to noise/_legacy_noise.py?
Can you delete this file?
| CycleChannel, | ||
| MeasurementChannel, | ||
| ResetChannel, | ||
| get_custom_gates_from_program, |
There was a problem hiding this comment.
Not referenced in this module. Overall, I see that poetry run make check-all is failing all three checks. Still some work to do there?
| DensityMatrixOp = Tuple[qx.SuperOp, Tuple[int, ...]] | ||
|
|
||
| # Custom gate definitions. | ||
| CustomGateMap = dict |
There was a problem hiding this comment.
This is defined as CustomGateMap = Dict[str, Union[qx.Unitary, Callable[..., qx.Unitary]]] in _channels.py. Does your module dependency hierarchy allow you to import?
| def linearize(memory_map: MemoryMap) -> Array: | ||
| if not param_refs: | ||
| return jnp.array([], dtype=float) | ||
| values = [float(memory_map[name][offset]) for name, offset in param_refs] |
There was a problem hiding this comment.
You don't seem to check the declared type of the memory regions, but assume here it's convertible to float. Should you defend against non-real types? I guess a BIT or INTEGER will, in most cases, convert. Not sure about OCTET.
|
The PR now touches many files due to upgrading the style to python 3.11 |
Description
This PR adds new noise modeling and simulation utilities. It is intended to replace the existing reference simulators and the existing NoiseModel. The proposed migration path is to add these utilities as experimental private modules, while marking the existing simulation and noise utilities as deprecated. In the next major version of pyquil, the deprecated utilities will be replaced by the new utilities.
Dependencies
This PR adds the new dependency
rigetti-quax. This package provides a wide set of utilities for representing quantum objects including states, gates, measurements and noise channels. It's based onjax, which delivers high performance but addsjaxas a transitive dependency to pyquil.jaxis a relatively mature package but somewhat larger than existing dependencies (Around 80MB compared to scipy at 40MB). So this is worth some discussion.Noise
Below, I discuss the new noise modeling classes.
First, the
noise.pyfile is promoted to a module. This is non-breaking as the same functions continue to be exported frompyquil.noise. Inside the module are several new private files. When these are promoted to the public api, they will be exported frompyquil.noise.New Noise Model System (
pyquil/noise/_noise_model.py)We introduce a frozen-dataclass-based
NoiseModelcontainer that collects per-instruction noise channels. The role of the noise model is to store a collection of channels, which together make up a device noise model. Channels are associated with instructions in the program and we can get the channel for a particular instruction viaNoiseModel.get_channel(inst).NoiseModelLike— Protocol defining theget_channelinterface for custom implementations.NoiseModel— The canonical use case. Accepts anIterableof channels (list, tuple, set, generator), stores them as an immutableTuple. Supports+for composition andget_channel(inst)for lookup.DepolarizingNoiseModel— Convenience model returning a depolarizing channel for any gate.CompositeNoiseModel— Chains multipleNoiseModelLikeobjects, returning the first non-None channel.We can also construct noise models from the instruction set architecture, giving users a straightforward path to a device-realistic noise model.
Quax-Backed Noise Channels (
pyquil/noise/_channels.py)Four frozen dataclasses representing physical noise processes, each backed by a
quaxoperator (SuperOp,KrausMap, orQuantumInstrument): The main role of the channel is to associate a superoperator with a particular instruction. For example, the instructionCZ 0 1indicates that qubits 0 and 1 undergo the unitary CZ operation (via the quil spec). The channel associates that instruction to a superoperator, which is a higher dimensional representation of the operation which includes the effects of noise.ChannelMeasurementChannelQuantumInstrument)ResetChannelCycleChannelSimulation
A noise model is not very interesting on it's own, we want a simulator which can simulate it's effects on programs. The existing pyquil simulators are limited in various ways. The numpy reference simulators have poor performance, while the QVM has limited options for representing noise. Here we attempt to solve both problems by replacing the reference numpy simulators with a jax-accelerated, highly flexible simulation framework.
Three New Simulators (
pyquil/simulation/_simulator.py)A simulator object is constructed from a program. We have a simulation hierarchy based on the capabilities and scale.
PureStateVectorSimulatorjit+gradDensityMatrixSimulatorjit+gradTrajectorySimulatorjit(per-batch)Linearizer / DAG / Resolver / Compressor Pipeline (
pyquil/simulation/_resolver.py)The reason that the simulator is an object rather than a simple function is because efficient simulation requires the construction of several related closures. We call these the linearizer, the resolver, the compressor an the calculator.
Linearizer— ConvertsMemoryMap→ flat JAX parameter vector.Resolver— Converts parameter vector →(operator, subsystem)pairs, consulting the noise model.Compressor— Greedy edge contraction on the program DAG, merging adjacent operators up tomax_subsystem_sizequbits.Calculator— Compute the final state from the resolved and compressed operators.The functions are all tightly coupled, each closure is constructed based on the program structure. They can also all be compiled with
jax.jit. This is important to achieving good performance. The role of the class is therefore to combine all these objects into a logical construct.Usage