Skip to content

Commit

Permalink
Merge pull request #27 from ScQ-Cloud/dev
Browse files Browse the repository at this point in the history
merge dev into master
  • Loading branch information
Zhaoyilunnn authored Jul 21, 2023
2 parents 8d41d23 + c6403b5 commit ba66449
Show file tree
Hide file tree
Showing 71 changed files with 6,779 additions and 991 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ on:
pull_request_review:
types: [submitted, edited]
workflow_dispatch:
release:



jobs:
Expand Down
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ cmake
*.egg-info
test
.vscode
thirdparty
thirdparty
.pyd
MANIFEST.in
56 changes: 50 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

cmake_minimum_required(VERSION 3.14...3.22)

project(qfvm LANGUAGES CXX C)
project(qfvm LANGUAGES CXX C)

set (CMAKE_BUILD_TYPE Release)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_ARCHITECTURES 70;75;80;90)
if(SKBUILD)

execute_process(
Expand Down Expand Up @@ -42,9 +43,6 @@ ExternalProject_Add(Eigen3
PREFIX ${EIGEN3_ROOT}
GIT_REPOSITORY https://gitlab.com/libeigen/eigen.git
GIT_TAG 3.3.9
# CONFIGURE_COMMAND cd ${EIGEN3_ROOT}/src/Eigen3 && cmake -B build -DCMAKE_INSTALL_PREFIX=${EIGEN3_ROOT}
# BUILD_COMMAND ""
# INSTALL_COMMAND cd ${EIGEN3_ROOT}/src/Eigen3 && cmake --build build --target install

CONFIGURE_COMMAND ""
BUILD_COMMAND ""
Expand Down Expand Up @@ -72,6 +70,7 @@ if(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_HOST_SYSTEM_PROCESSOR
endif()
endif()

list (APPEND PRJ_INCLUDE_DIRS src/qfvm)
pybind11_add_module(${PROJECT_NAME} MODULE src/${PROJECT_NAME}/${PROJECT_NAME}.cpp)
add_dependencies(${PROJECT_NAME} Eigen3) #must add dependence for ninja
target_compile_options(${PROJECT_NAME} PUBLIC ${PRJ_COMPILE_OPTIONS})
Expand All @@ -80,4 +79,49 @@ target_link_libraries(${PROJECT_NAME} PUBLIC ${PRJ_LIBRARIES})
set_target_properties(${PROJECT_NAME} PROPERTIES SUFFIX ${PYTHON_MODULE_EXTENSION})
target_compile_definitions(${PROJECT_NAME} PRIVATE VERSION_INFO=${PROJECT_VERSION})

install(TARGETS ${PROJECT_NAME} DESTINATION .)
#GPU version
if (USE_GPU)
add_compile_definitions(_USE_GPU)
enable_language(CUDA)
set_source_files_properties(src/${PROJECT_NAME}/${PROJECT_NAME}.cpp PROPERTIES LANGUAGE CUDA)
target_link_libraries(${PROJECT_NAME} PUBLIC cudart)
target_compile_options(${PROJECT_NAME} PUBLIC $<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda> )
target_include_directories(${PROJECT_NAME} PUBLIC src/qfvm_gpu)
target_include_directories(${PROJECT_NAME} PUBLIC src/qfvm_gpu/cuda_utils)
target_include_directories(${PROJECT_NAME} PUBLIC ${CUDA_INCLUDE_DIRS})
message("cuda_include" ${CUDA_INCLUDE_DIRS})
if (USE_CUQUANTUM)
add_compile_definitions(_USE_CUQUANTUM)
function(set_with_fallback VARIABLE FALLBACK)
if (NOT DEFINED ${VARIABLE} OR ${VARIABLE} STREQUAL "")
set(${VARIABLE} $ENV{${VARIABLE}} CACHE INTERNAL ${VARIABLE})
if (${VARIABLE} STREQUAL "")
if (NOT ${FALLBACK} STREQUAL "")
set(${VARIABLE} $ENV{${FALLBACK}} CACHE INTERNAL ${VARIABLE})
endif ()
endif ()
endif ()
endfunction()

set_with_fallback(CUSTATEVEC_ROOT CUQUANTUM_ROOT)

if (CUSTATEVEC_ROOT STREQUAL "")
message(FATAL_ERROR "Please set the environment variables CUSTATEVEC_ROOT or CUQUANTUM_ROOT to the path of the cuQuantum installation.")
endif ()

message(STATUS "Using CUSTATEVEC_ROOT = ${CUSTATEVEC_ROOT}")

set(CMAKE_CUDA_FLAGS_ARCH_SM70 "-gencode arch=compute_70,code=sm_70")
set(CMAKE_CUDA_FLAGS_ARCH_SM75 "-gencode arch=compute_75,code=sm_75")
set(CMAKE_CUDA_FLAGS_ARCH_SM80 "-gencode arch=compute_80,code=sm_80 -gencode arch=compute_80,code=compute_80")
set(CMAKE_CUDA_FLAGS_ARCH_SM90 "-gencode arch=compute_90,code=sm_90 -gencode arch=compute_90,code=compute_90")
set(CMAKE_CUDA_FLAGS_ARCH "${CMAKE_CUDA_FLAGS_ARCH_SM70} ${CMAKE_CUDA_FLAGS_ARCH_SM75} ${CMAKE_CUDA_FLAGS_ARCH_SM80} ${CMAKE_CUDA_FLAGS_ARCH_SM90}")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${CMAKE_CUDA_FLAGS_ARCH}")

target_include_directories(${PROJECT_NAME} PUBLIC ${CUDA_INCLUDE_DIRS} ${CUSTATEVEC_ROOT}/include)
target_link_directories(${PROJECT_NAME} PUBLIC ${CUSTATEVEC_ROOT}/lib ${CUSTATEVEC_ROOT}/lib64)
target_link_libraries(${PROJECT_NAME} PUBLIC -lcustatevec_static -lcublas )
endif()
endif()

install(TARGETS ${PROJECT_NAME} DESTINATION .)
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ pip install -r requirements.txt
python setup.py install
```

## GPU support
To install PyQuafu with GPU-based circuit simulator, you need build from the source and make sure that [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads) is installed. You can run

```
python setup.py install -DUSE_GPU=ON
```
to install the GPU version. If you further have [cuQuantum](https://developer.nvidia.com/cuquantum-sdk) installed, you can install PyQuafu with cuQuantum support.
```
python setup.py install -DUSE_GPU=ON -DUSE_CUQUANTUM=ON
```


## Document
Please see the website [docs](https://scq-cloud.github.io/).

Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ scipy>=1.8.1
setuptools>=58.0.4
sparse>=0.13.0
scikit-build>=0.16.1
pybind11>=2.10.3
pybind11>=2.10.3
ply~=3.11
Pillow~=10.0.0
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

setup(
name="pyquafu",
version="0.2.11",
version="0.3.0",
author="ssli",
author_email="[email protected]",
url="https://github.com/ScQ-Cloud/pyquafu",
Expand All @@ -45,4 +45,4 @@
zip_safe=False,
setup_cfg=True,
license="Apache-2.0 License"
)
)
12 changes: 10 additions & 2 deletions src/qfvm/circuit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class Circuit{
private:
uint qubit_num_;
vector<QuantumOperator> gates_{0};

uint max_targe_num_;
public:
Circuit();
explicit Circuit(uint qubit_num);
Expand All @@ -99,6 +99,7 @@ class Circuit{
void add_gate(QuantumOperator &gate);
void compress_gates();
uint qubit_num() const { return qubit_num_; }
uint max_targe_num() const {return max_targe_num_;}
vector<QuantumOperator>gates() const { return gates_; }

};
Expand All @@ -121,16 +122,21 @@ void Circuit::add_gate(QuantumOperator &gate){

Circuit::Circuit(vector<QuantumOperator> &gates)
:
gates_(gates){
gates_(gates),
max_targe_num_(0){
qubit_num_ = 0;
for (auto gate : gates){
for (pos_t pos : gate.positions()){
if (gate.targe_num() > max_targe_num_)
max_targe_num_ = gate.targe_num();
if (pos+1 > qubit_num_){ qubit_num_ = pos+1; }
}
}
}

Circuit::Circuit(py::object const&pycircuit)
:
max_targe_num_(0)
{
auto pygates = pycircuit.attr("gates");
auto used_qubits = pycircuit.attr("used_qubits").cast<vector<pos_t>>();
Expand All @@ -139,6 +145,8 @@ Circuit::Circuit(py::object const&pycircuit)
py::object pygate = py::reinterpret_borrow<py::object>(pygate_h);
QuantumOperator gate = from_pyops(pygate);
if (gate){
if (gate.targe_num() > max_targe_num_)
max_targe_num_ = gate.targe_num();
gates_.push_back(std::move(gate));
}
}
Expand Down
101 changes: 85 additions & 16 deletions src/qfvm/qfvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,111 @@
#include <pybind11/numpy.h>
#include "simulator.hpp"

#ifdef _USE_GPU
#include <cuda_simulator.cuh>
#endif

#ifdef _USE_CUQUANTUM
#include <custate_simu.cuh>
#endif

namespace py = pybind11;

template <typename T>
py::array_t<T> to_numpy(std::vector<T> &&src) {
vector<T>* src_ptr = new std::vector<T>(std::move(src));
auto capsule = py::capsule(src_ptr, [](void* p) { delete reinterpret_cast<std::vector<T>*>(p); });
return py::array_t<T>(
src_ptr->size(), // shape of array
src_ptr->data(), // c-style contiguous strides for vector
capsule // numpy array references this parent
);
}
py::array_t<T> to_numpy(const std::tuple<T*, size_t> &src) {
auto src_ptr = std::get<0>(src);
auto src_size = std::get<1>(src);

auto capsule = py::capsule(src_ptr, [](void* p) {
delete [] reinterpret_cast<T*>(p);
});
return py::array_t<T>(
src_size,
src_ptr,
capsule
);
}

py::object execute(string qasm){
return to_numpy(simulate(qasm).move_data());
return to_numpy(simulate(qasm).move_data_to_python());
}

py::object simulate_circuit(py::object const&pycircuit, vector<complex<double>> const&inputstate){
auto circuit = Circuit(pycircuit);
if (inputstate.size() == 0){
py::object simulate_circuit(py::object const&pycircuit, py::array_t<complex<double>> &np_inputstate){
auto circuit = Circuit(pycircuit);
py::buffer_info buf = np_inputstate.request();
auto* data_ptr = reinterpret_cast<std::complex<double>*>(buf.ptr);
size_t data_size = buf.size;

if (data_size == 0){
StateVector<double> state;
simulate(circuit, state);
return to_numpy(state.move_data());
return to_numpy(state.move_data_to_python());
}
else{
StateVector<double> state{inputstate};
StateVector<double> state(data_ptr, buf.size);
simulate(circuit, state);
return to_numpy(state.move_data());
state.move_data_to_python();
return np_inputstate;
}
}

#ifdef _USE_GPU
py::object simulate_circuit_gpu(py::object const&pycircuit, py::array_t<complex<double>> &np_inputstate){
auto circuit = Circuit(pycircuit);
py::buffer_info buf = np_inputstate.request();
auto* data_ptr = reinterpret_cast<std::complex<double>*>(buf.ptr);
size_t data_size = buf.size;


if (data_size == 0){
StateVector<double> state;
simulate_gpu(circuit, state);
return to_numpy(state.move_data_to_python());
}
else{
StateVector<double> state(data_ptr, buf.size);
simulate_gpu(circuit, state);
state.move_data_to_python();
return np_inputstate;
}
}
#endif

#ifdef _USE_CUQUANTUM
py::object simulate_circuit_custate(py::object const&pycircuit, py::array_t<complex<double>> &np_inputstate){
auto circuit = Circuit(pycircuit);
py::buffer_info buf = np_inputstate.request();
auto* data_ptr = reinterpret_cast<std::complex<double>*>(buf.ptr);
size_t data_size = buf.size;


if (data_size == 0){
StateVector<double> state;
simulate_custate(circuit, state);
return to_numpy(state.move_data_to_python());
}
else{
StateVector<double> state(data_ptr, buf.size);
simulate_custate(circuit, state);
state.move_data_to_python();
return np_inputstate;
}
}
#endif



PYBIND11_MODULE(qfvm, m) {
m.doc() = "Qfvm simulator";
m.def("execute", &execute, "Simulate with qasm");
m.def("simulate_circuit", &simulate_circuit, "Simulate with circuit", py::arg("circuit"), py::arg("inputstate")= py::array_t<complex<double>>(0));

#ifdef _USE_GPU
m.def("simulate_circuit_gpu", &simulate_circuit_gpu, "Simulate with circuit", py::arg("circuit"), py::arg("inputstate")= py::array_t<complex<double>>(0));
#endif

#ifdef _USE_CUQUANTUM
m.def("simulate_circuit_custate", &simulate_circuit_custate, "Simulate with circuit", py::arg("circuit"), py::arg("inputstate")= py::array_t<complex<double>>(0));
#endif
}

8 changes: 4 additions & 4 deletions src/qfvm/simulator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ void simulate(Circuit const& circuit, StateVector<data_t> & state){
state.apply_rz(op.positions()[1], op.paras()[0]);
state.apply_cnot(op.positions()[0], op.positions()[1]);
break;


//Other general gate
default:
{
default:
{
if (op.targe_num() == 1){
auto mat_temp = op.mat();
complex<double> *mat = mat_temp.data();
Expand Down Expand Up @@ -176,7 +176,7 @@ void simulate(string qasm, StateVector<double> & state){
StateVector<double> simulate(string qasm){
StateVector<double>state;
simulate(qasm, state);
return std::move(state);
return std::move(state);
}


Loading

0 comments on commit ba66449

Please sign in to comment.