Skip to content

Commit

Permalink
fix: close #155, close #157
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhaoyilunnn committed Apr 7, 2024
1 parent 2f8c0d7 commit 4881570
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 93 deletions.
60 changes: 34 additions & 26 deletions quafu/results/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
from collections import OrderedDict

import matplotlib.pyplot as plt

from ..algorithms.hamiltonian import Hamiltonian
from ..utils.basis import *


class Result(object):
"""Basis class for quantum results"""

Expand Down Expand Up @@ -79,12 +81,12 @@ def plot_probabilities(self):


class SimuResult(Result):
def __init__(self, res_info : dict):
def __init__(self, res_info: dict):
"""
Args:
res_info: data from simulator
"""
self._meta_data = res_info
self._meta_data = res_info
counts = res_info["counts"]
if counts:
if isinstance(list(counts.keys())[0], int):
Expand All @@ -97,7 +99,7 @@ def __init__(self, res_info : dict):

self._probabilities = []

def __getitem__(self, key:str):
def __getitem__(self, key: str):
"""
Get meta_data of simulate results.
Args:
Expand All @@ -111,65 +113,70 @@ def get_statevector(self):
try:
return self["statevector"]
except KeyError:
raise KeyError("no statevector saved from %s simulator" %self["simulator"])
raise KeyError("no statevector saved from %s simulator" % self["simulator"])

@property
def probabilities(self):
if not self._probabilities:
if len(self._probabilities) == 0:
self.calc_probabilities()
return self._probabilities

@property
def counts(self):
return self["counts"]
def calc_probabilities(self):
psi = self.get_statevector()
return self["counts"]

def calc_probabilities(self):
psi = self.get_statevector()
num = self["qbitnum"]
measures = list(self["measures"].keys())
values_tmp = list(self["measures"].values())
values = np.argsort(values_tmp)


from ..simulators.default_simulator import permutebits, ptrace

psi = permutebits(psi, range(num)[::-1])
if measures:
self._probabilities = ptrace(psi, measures)
self._probabilities = permutebits(self._probabilities, values)
else:
self._probabilities = np.abs(psi)**2


def plot_probabilities(self, full: bool=False, reverse_basis: bool=False, sort:bool=None, from_counts=False):
self._probabilities = np.abs(psi) ** 2

def plot_probabilities(
self,
full: bool = False,
reverse_basis: bool = False,
sort: bool = None,
from_counts=False,
):
"""
Plot the probabilites of measured qubits
"""
import matplotlib.pyplot as plt

if from_counts:
counts = self._meta_data["counts"]
total_counts = sum(counts.values())
probabilities = {}
for key in self._meta_data["counts"]:
probabilities[key] = counts[key]/total_counts
probabilities[key] = counts[key] / total_counts


bitstrs = list(probabilities.keys())
probs = list(probabilities.values())
plt.figure()
plt.bar(range(len(probs)), probs, tick_label = bitstrs)
plt.bar(range(len(probs)), probs, tick_label=bitstrs)
plt.xticks(rotation=70)
plt.ylabel("probabilities")

elif len(self.get_statevector()) > 0:
if not full:
inds = np.where(self.probabilities > 1e-14)[0]
probs = self.probabilities[inds]

measures = self._meta_data["measures"]
num = len(measures) if measures else self["qbitnum"]
basis=np.array([bin(i)[2:].zfill(num) for i in inds])
basis = np.array([bin(i)[2:].zfill(num) for i in inds])
if reverse_basis:
basis=np.array([bin(i)[2:].zfill(num)[::-1] for i in inds])
basis = np.array([bin(i)[2:].zfill(num)[::-1] for i in inds])

if sort == "ascend":
orders = np.argsort(probs)
Expand All @@ -187,7 +194,6 @@ def plot_probabilities(self, full: bool=False, reverse_basis: bool=False, sort:b
else:
raise ValueError("No data for ploting")


def calc_density_matrix(self):
psi = self.get_statevector()
num = self["qbitnum"]
Expand All @@ -198,12 +204,14 @@ def calc_density_matrix(self):
measures = list(range(num))
values = list(range(num))
from ..simulators.default_simulator import permutebits, ptrace

psi = permutebits(psi, range(num)[::-1])
rho = ptrace(psi, measures, diag=False)
rho = permutebits(rho, values)
return rho

#TODO:These should merge to paulis

# TODO:These should merge to paulis
def intersec(a, b):
inter = []
aind = []
Expand Down Expand Up @@ -264,4 +272,4 @@ def merge_measure(obslist):
measure_basis.append(obs)
targ_basis.append(len(measure_basis) - 1)

return measure_basis, targ_basis
return measure_basis, targ_basis
Loading

0 comments on commit 4881570

Please sign in to comment.