Skip to content

Commit

Permalink
Fix Python->Rust Param conversion
Browse files Browse the repository at this point in the history
This commit adds a custom implementation of the FromPyObject trait for
the Param enum. Previously, the Param trait derived it's impl of the
trait, but this logic wasn't perfect. In cases whern a
ParameterExpression was effectively a constant (such as `0 * x`) the
trait's attempt to coerce to a float first would result in those
ParameterExpressions being dropped from the circuit at insertion time.
This was a change in behavior from before having gates in Rust as the
parameters would disappear from the circuit at insertion time instead of
at bind time. This commit fixes this by having a custom impl for
FromPyObject that first tries to figure out if the parameter is a
ParameterExpression (or a QuantumCircuit) by using a Python isinstance()
check, then tries to extract it as a float, and finally stores a
non-parameter object; which is a new variant in the Param enum. This
new variant also lets us simplify the logic around adding gates to the
parameter table as we're able to know ahead of time which gate
parameters are `ParameterExpression`s and which are other objects (and
don't need to be tracked in the parameter table.

Additionally this commit tweaks two tests, the first is
test.python.circuit.library.test_nlocal.TestNLocal.test_parameters_setter
which was adjusted in the previous commit to workaround the bug fixed
by this commit. The second is test.python.circuit.test_parameters which
was testing that a bound ParameterExpression with a value of 0 defaults
to an int which was a side effect of passing an int input to symengine
for the bind value and not part of the api and didn't need to be
checked. This assertion was removed from the test because the rust
representation is only storing f64 values for the numeric parameters
and it is never an int after binding from the Python perspective it
isn't any different to have float(0) and int(0) unless you explicit
isinstance check like the test previously was.
  • Loading branch information
mtreinish committed May 25, 2024
1 parent ad3e3c5 commit 37c0780
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 44 deletions.
43 changes: 3 additions & 40 deletions crates/circuit/src/circuit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,33 +280,11 @@ impl CircuitData {
let mut new_param = false;
let inst_params = &self.data[inst_index].params;
if let Some(raw_params) = inst_params {
let param_mod =
PyModule::import_bound(py, intern!(py, "qiskit.circuit.parameterexpression"))?;
let param_class = param_mod.getattr(intern!(py, "ParameterExpression"))?;
let circuit_mod =
PyModule::import_bound(py, intern!(py, "qiskit.circuit.quantumcircuit"))?;
let circuit_class = circuit_mod.getattr(intern!(py, "QuantumCircuit"))?;
let params: Vec<(usize, PyObject)> = raw_params
.iter()
.enumerate()
.filter_map(|(idx, x)| match x {
Param::ParameterExpression(param_obj) => {
if param_obj
.clone_ref(py)
.into_bound(py)
.is_instance(&param_class)
.unwrap()
|| param_obj
.clone_ref(py)
.into_bound(py)
.is_instance(&circuit_class)
.unwrap()
{
Some((idx, param_obj.clone_ref(py)))
} else {
None
}
}
Param::ParameterExpression(param_obj) => Some((idx, param_obj.clone_ref(py))),
_ => None,
})
.collect();
Expand Down Expand Up @@ -370,23 +348,7 @@ impl CircuitData {
.iter()
.enumerate()
.filter_map(|(idx, x)| match x {
Param::ParameterExpression(param_obj) => {
let param_mod =
PyModule::import_bound(py, "qiskit.circuit.parameterexpression")
.ok()?;
let param_class =
param_mod.getattr(intern!(py, "ParameterExpression")).ok()?;
if param_obj
.clone_ref(py)
.into_bound(py)
.is_instance(&param_class)
.unwrap()
{
Some((idx, param_obj.clone_ref(py)))
} else {
None
}
}
Param::ParameterExpression(param_obj) => Some((idx, param_obj.clone_ref(py))),
_ => None,
})
.collect();
Expand Down Expand Up @@ -1131,6 +1093,7 @@ impl CircuitData {
}
self.global_phase = Param::ParameterExpression(angle);
}
Param::Obj(_) => return Err(PyValueError::new_err("Invalid type for global phase")),
};
Ok(())
}
Expand Down
11 changes: 11 additions & 0 deletions crates/circuit/src/circuit_instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,17 @@ impl CircuitInstruction {
break;
}
}
Param::Obj(val_a) => {
if let Param::Obj(val_b) = param_b {
if !val_a.bind(py).eq(val_b.bind(py))? {
out = false;
break;
}
} else {
out = false;
break;
}
}
}
}
out
Expand Down
33 changes: 31 additions & 2 deletions crates/circuit/src/operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,41 @@ pub trait Operation {
fn standard_gate(&self) -> Option<StandardGate>;
}

#[derive(FromPyObject, Clone, Debug)]
#[derive(Clone, Debug)]
pub enum Param {
Float(f64),
ParameterExpression(PyObject),
Float(f64),
Obj(PyObject),
}

impl<'py> FromPyObject<'py> for Param {
fn extract_bound(b: &Bound<'py, PyAny>) -> Result<Self, PyErr> {
let param_mod = PyModule::import_bound(
b.py(),
intern!(b.py(), "qiskit.circuit.parameterexpression"),
)?;
let param_class = param_mod.getattr(intern!(b.py(), "ParameterExpression"))?;
let circuit_mod =
PyModule::import_bound(b.py(), intern!(b.py(), "qiskit.circuit.quantumcircuit"))?;
let circuit_class = circuit_mod.getattr(intern!(b.py(), "QuantumCircuit"))?;
Ok(
if b.is_instance(&param_class)? || b.is_instance(&circuit_class)? {
Param::ParameterExpression(b.clone().unbind())
} else if let Ok(val) = b.extract::<f64>() {
Param::Float(val)
} else {
Param::Obj(b.clone().unbind())
},
)
}
}

impl IntoPy<PyObject> for Param {
fn into_py(self, py: Python) -> PyObject {
match &self {
Self::Float(val) => val.to_object(py),
Self::ParameterExpression(val) => val.clone_ref(py),
Self::Obj(val) => val.clone_ref(py),
}
}
}
Expand All @@ -136,6 +160,7 @@ impl ToPyObject for Param {
match self {
Self::Float(val) => val.to_object(py),
Self::ParameterExpression(val) => val.clone_ref(py),
Self::Obj(val) => val.clone_ref(py),
}
}
}
Expand Down Expand Up @@ -328,14 +353,17 @@ impl Operation for StandardGate {
let theta: Option<f64> = match params[0] {
Param::Float(val) => Some(val),
Param::ParameterExpression(_) => None,
Param::Obj(_) => None,
};
let phi: Option<f64> = match params[1] {
Param::Float(val) => Some(val),
Param::ParameterExpression(_) => None,
Param::Obj(_) => None,
};
let lam: Option<f64> = match params[2] {
Param::Float(val) => Some(val),
Param::ParameterExpression(_) => None,
Param::Obj(_) => None,
};
// If let chains as needed here are unstable ignore clippy to
// workaround. Upstream rust tracking issue:
Expand Down Expand Up @@ -471,6 +499,7 @@ impl Operation for StandardGate {
)
.expect("Unexpected Qiskit python bug"),
),
Param::Obj(_) => unreachable!(),
}
}),
Self::ECRGate => todo!("Add when we have RZX"),
Expand Down
2 changes: 1 addition & 1 deletion test/python/circuit/library/test_nlocal.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def test_parameters_setter(self, params):
initial_params = ParameterVector("p", length=6)
circuit = QuantumCircuit(1)
for i, initial_param in enumerate(initial_params):
circuit.ry((i + 1) * initial_param, 0)
circuit.ry(i * initial_param, 0)

# create an NLocal from the circuit and set the new parameters
nlocal = NLocal(1, entanglement_blocks=circuit, reps=1)
Expand Down
1 change: 0 additions & 1 deletion test/python/circuit/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,6 @@ def test_expression_partial_binding_zero(self):
fbqc = pqc.assign_parameters({phi: 1})

self.assertEqual(fbqc.parameters, set())
self.assertIsInstance(fbqc.data[0].operation.params[0], int)
self.assertEqual(float(fbqc.data[0].operation.params[0]), 0)

def test_raise_if_assigning_params_not_in_circuit(self):
Expand Down

0 comments on commit 37c0780

Please sign in to comment.