Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,23 @@ authors = ["Br0kenSmi1e"]
version = "1.0.0-DEV"

[deps]
HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SCIP = "82193955-e24f-5292-bf16-6f2c5261a85f"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
HiGHS = "1"
JuMP = "1"
LinearAlgebra = "1"
SCIP = "0.12"
SpecialFunctions = "2"
StaticArrays = "1"
julia = "1"

[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Documenter", "Test"]
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# CrystalStructurePrediction

[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://Br0kenSmi1e.github.io/CrystalStructurePrediction.jl/stable/)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://Br0kenSmi1e.github.io/CrystalStructurePrediction.jl/dev/)
[![Build Status](https://github.com/Br0kenSmi1e/CrystalStructurePrediction.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/Br0kenSmi1e/CrystalStructurePrediction.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/Br0kenSmi1e/CrystalStructurePrediction.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/Br0kenSmi1e/CrystalStructurePrediction.jl)

Expand Down
1 change: 1 addition & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
CrystalStructurePrediction = "4140a7b2-00d5-4ecf-8adb-c63b9db3f1fd"
SCIP = "82193955-e24f-5292-bf16-6f2c5261a85f"
Binary file modified examples/SrTiO3-structure.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
52 changes: 24 additions & 28 deletions examples/SrTiO3.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using CrystalStructurePrediction
using CairoMakie, SCIP

"""
setup_crystal_parameters()
Expand All @@ -10,9 +11,7 @@ function setup_crystal_parameters()
# Crystal structure parameters
grid_size = (2, 2, 2)
population_list = [1, 1, 3] # 1 Sr, 1 Ti, 3 O atoms
species_list = [:Sr, :Ti, :O]
charge_list = [+2, +4, -2]
radii_list = [1.18, 0.42, 1.35]
ion_types = [IonType(:Sr, +2, 1.18), IonType(:Ti, +4, 0.42), IonType(:O, -2, 1.35)]

# Lattice parameters
lattice_constant = 3.899 # Å
Expand All @@ -23,7 +22,7 @@ function setup_crystal_parameters()
depth = (4, 4, 4)
alpha = 2 / lattice_constant

return grid_size, population_list, species_list, charge_list, radii_list, lattice, depth, alpha
return grid_size, population_list, ion_types, lattice, depth, alpha
end

"""
Expand All @@ -36,16 +35,15 @@ Run the crystal structure prediction for SrTiO3.
"""
function run_crystal_structure_prediction(; use_quadratic_problem::Bool=false)
# Setup parameters
grid_size, population_list, species_list, charge_list, radii_list, lattice, depth, alpha = setup_crystal_parameters()
grid_size, population_list, ion_types, lattice, depth, alpha = setup_crystal_parameters()

@info "Setting up crystal structure prediction for SrTiO3"
@info "Grid size: $grid_size"
@info "Population: $population_list $species_list"
@info "Charges: $charge_list"
@info "Ionic radii: $radii_list"
@info "Population: $population_list"
@info "Ion types: $ion_types"

# Build ion list and proximal pairs
ion_list = build_ion_list(grid_size, species_list, charge_list, radii_list)
ion_list = build_ion_list(grid_size, ion_types)
@info "Created ion list with $(length(ion_list)) possible ion positions"

proximal_pairs = build_proximal_pairs(ion_list, lattice, 0.75)
Expand All @@ -58,15 +56,15 @@ function run_crystal_structure_prediction(; use_quadratic_problem::Bool=false)

# Solve the quadratic problem
@info "Solving quadratic optimization problem..."
energy, solution_x, csp = build_quadratic_problem(grid_size, population_list, matrix, proximal_pairs)
energy, solution_x, csp = build_quadratic_problem(grid_size, population_list, matrix; optimizer=SCIP.Optimizer)
else
# Build interaction vector
@info "Building interaction energy vector..."
vector = build_vector(ion_list, lattice, interaction_energy, (alpha, depth, depth, depth))

# Solve the linear problem
@info "Solving linear optimization problem..."
energy, solution_x, csp = build_linear_problem(grid_size, population_list, vector, proximal_pairs)
energy, solution_x, csp = build_linear_problem(grid_size, population_list, vector, proximal_pairs; optimizer=SCIP.Optimizer)
end

# Display results
Expand All @@ -87,14 +85,7 @@ function run_crystal_structure_prediction(; use_quadratic_problem::Bool=false)
return energy, selected_ions, csp
end

# Run the prediction
energy, selected_ions, csp = run_crystal_structure_prediction()

# (-6.061349350569213, Any[Ion{3, Float64}(:Sr, 2, 1.18, [0.5, 0.5, 0.0]), Ion{3, Float64}(:Ti, 4, 0.42, [0.0, 0.0, 0.5]), Ion{3, Float64}(:O, -2, 1.35, [0.0, 0.0, 0.0]), Ion{3, Float64}(:O, -2, 1.35, [0.5, 0.0, 0.5]), Ion{3, Float64}(:O, -2, 1.35, [0.0, 0.5, 0.5])], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

# Visualize the crystal structure
using CairoMakie

function visualize_crystal_structure(selected_ions, lattice, shift)
fig = Figure(; size = (300, 250))
ax = Axis3(fig[1, 1],
Expand Down Expand Up @@ -128,11 +119,12 @@ function visualize_crystal_structure(selected_ions, lattice, shift)
end

# Plot the ions
properties = Dict(:Sr => (color = :green, size = 30), :Ti => (color = :aqua, size = 25), :O => (color = :red, size = 10))
properties = Dict(:Sr => (color = :green,), :Ti => (color = :aqua,), :O => (color = :red,))

# Plot each ion
for ion in selected_ions
# Plot the ion at its position and all periodic images within the unit cell
coordinates = []
for dx in -1:1, dy in -1:1, dz in -1:1
# Add periodic image shift vector
offset = [dx, dy, dz] .+ shift
Expand All @@ -141,18 +133,19 @@ function visualize_crystal_structure(selected_ions, lattice, shift)
if all(0 .<= shifted_pos .<= 1)
# Convert to Cartesian coordinates
shifted_cart_pos = lattice.vectors * shifted_pos
scatter!(ax, [shifted_cart_pos[1]], [shifted_cart_pos[2]], [shifted_cart_pos[3]],
color = properties[ion.species].color,
markersize = properties[ion.species].size,
label = string(ion.species))
push!(coordinates, shifted_cart_pos)
end
end
scatter!(ax, coordinates,
color = properties[ion.type.species].color,
markersize = ion.type.radii * 20,
label = string(ion.type.species))
end

# Add legend with unique entries
unique_species = unique([ion.species for ion in selected_ions])
legend_elements = [MarkerElement(color = properties[sp].color, marker = :circle, markersize = properties[sp].size) for sp in unique_species]
legend_labels = [string(sp) for sp in unique_species]
unique_species = unique([ion.type for ion in selected_ions])
legend_elements = [MarkerElement(color = properties[sp.species].color, marker = :circle, markersize = sp.radii * 20) for sp in unique_species]
legend_labels = [string(sp.species) for sp in unique_species]

Legend(fig[1, 2], legend_elements, legend_labels, "Species", patchsize = (30, 30))
# Remove decorations and axis
Expand All @@ -161,9 +154,12 @@ function visualize_crystal_structure(selected_ions, lattice, shift)
return fig
end

# Run the prediction
energy, selected_ions, csp = run_crystal_structure_prediction(; use_quadratic_problem=false)

# Generate and save the visualization
lattice = setup_crystal_parameters()[6]
fig = visualize_crystal_structure(selected_ions, lattice, [0.5, 0.5, 0])
lattice = setup_crystal_parameters()[4]
fig = visualize_crystal_structure(selected_ions, lattice, [0.0, 0.0, 0.5])

filename = joinpath(@__DIR__, "SrTiO3-structure.png")
save(filename, fig, dpi=20)
Expand Down
6 changes: 3 additions & 3 deletions src/CrystalStructurePrediction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ module CrystalStructurePrediction
using SpecialFunctions
using LinearAlgebra
using JuMP
using HiGHS
using SCIP
using StaticArrays

export Lattice, Ion
export Lattice, Ion, IonType
export build_ion_list, build_vector
export build_matrix, interaction_energy
export build_linear_problem, build_quadratic_problem
export build_linear_problem, build_quadratic_problem, build_proximal_pairs

include("struct.jl")
include("interaction.jl")
Expand Down
17 changes: 4 additions & 13 deletions src/build_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,9 @@ return a list of ions in the format of:
`[ion(t1, p1), ion(t1, p2), ..., ion(t1,pn), ion(t2, p1), ..., ion(tm, pn)]`
"""
function build_ion_list(grid_size::NTuple{N, Int},
species_list::AbstractVector{Symbol},
charge_list::AbstractVector{Int},
radii_list::AbstractVector{Float64}
) where N
return [Ion(species_list[t], charge_list[t], radii_list[t], (ci.I .- 1) ./ grid_size) for t in range(1, length(species_list)) for ci in vec(CartesianIndices(grid_size))]
end

function interaction_energy(
ion_a::Ion{D, T}, ion_b::Ion{D, T}, lattice::Lattice{D, T},
alpha::T, real_depth::NTuple{D, Int}, reciprocal_depth::NTuple{D, Int}, buckingham_depth::NTuple{D, Int}
) where {D, T}
return real_space_sum(ion_a, ion_b, lattice, alpha, real_depth) + reciprocal_space_sum(ion_a, ion_b, lattice, alpha, reciprocal_depth) + buckingham_sum(ion_a, ion_b, lattice, buckingham_depth)# + radii_penalty(ion_a, ion_b, lattice, 0.75)
type_list::AbstractVector{IonType{T}},
) where {N, T}
return [Ion(type_list[t], (ci.I .- 1) ./ grid_size) for t in range(1, length(type_list)) for ci in vec(CartesianIndices(grid_size))]
end

function build_matrix(
Expand Down Expand Up @@ -48,6 +39,6 @@ function build_vector(
end

function build_proximal_pairs(ion_list::AbstractVector{Ion{D, T}}, lattice::Lattice{D, T}, c::Float64) where {D, T}
isProximal = (i, j) -> CrystalStructurePrediction.minimum_distance(ion_list[i], ion_list[j], lattice) < c * (ion_list[i].radii + ion_list[j].radii)
isProximal = (i, j) -> CrystalStructurePrediction.minimum_distance(ion_list[i].frac_pos, ion_list[j].frac_pos, lattice) < c * (radii(ion_list[i]) + radii(ion_list[j]))
return [(i,j) for i in range(1, length(ion_list)) for j in range(i+1, length(ion_list)) if isProximal(i, j)]
end
42 changes: 29 additions & 13 deletions src/build_problem.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# TODO: remove this function
function build_problem(
grid_size::AbstractVector{Int},
population_list::AbstractVector{Int},
interaction_matrix::AbstractMatrix{T};
optimizer = HiGHS.Optimizer
optimizer = SCIP.Optimizer
) where T<:Real
csp = Model(optimizer)
num_grid_points = prod(grid_size)
Expand All @@ -21,12 +22,26 @@ function build_problem(
return value.(x)
end

"""
build_linear_problem(grid_size, population_list, interaction_vector, proximal_pairs; optimizer = SCIP.Optimizer, optimizer_options = Dict())

Build a linear problem for crystal structure prediction. Suited for solvers not supporting quadratic constraints.

# Arguments
- `grid_size::NTuple{N, Int}`: The size of the grid.
- `population_list::AbstractVector{Int}`: The number of atoms of each species.
- `interaction_vector::AbstractVector{T}`: The interaction vector.
- `proximal_pairs::AbstractVector{Tuple{Int, Int}}`: The proximal pairs.
- `optimizer`: The optimizer.
- `optimizer_options`: The options for the optimizer, e.g. `optimizer_options = Dict("NodefileSave" => 1)` for Gurobi.
"""
function build_linear_problem(
grid_size::NTuple{N, Int},
population_list::AbstractVector{Int},
interaction_vector::AbstractVector{T},
proximal_pairs::AbstractVector{Tuple{Int, Int}};
optimizer = HiGHS.Optimizer
optimizer = SCIP.Optimizer,
optimizer_options = Dict()
) where {N, T<:Real}
csp = Model(optimizer)
num_grid_points = prod(grid_size)
Expand All @@ -48,21 +63,19 @@ function build_linear_problem(
@constraint(csp, s[i + (j-1)*(j-2)÷2] <= x[i])
@constraint(csp, s[i + (j-1)*(j-2)÷2] <= x[j])
@constraint(csp, s[i + (j-1)*(j-2)÷2] >= x[i] + x[j] - 1)
end
end
end
@objective(csp, Min, dot(interaction_vector, s))
for (key, value) in optimizer_options
set_optimizer_attribute(csp, key, value)
end
optimize!(csp)
assert_is_solved_and_feasible(csp)
return objective_value(csp), value.(x), value.(s)
end

"""
build_quadratic_problem(
grid_size::NTuple{N, Int},
population_list::AbstractVector{Int},
interaction_matrix::AbstractMatrix{T},
optimizer
) where {N, T<:Real}
build_quadratic_problem(grid_size, population_list, interaction_matrix; optimizer = SCIP.Optimizer, optimizer_options = Dict())

Build a quadratic problem for crystal structure prediction.

Expand All @@ -71,12 +84,14 @@ Build a quadratic problem for crystal structure prediction.
- `population_list::AbstractVector{Int}`: The number of atoms of each species.
- `interaction_matrix::AbstractMatrix{T}`: The interaction matrix.
- `optimizer`: The optimizer.
- `optimizer_options`: The options for the optimizer, e.g. `optimizer_options = Dict("NodefileSave" => 1)` for Gurobi.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this NodefileSave for?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""
function build_quadratic_problem(
grid_size::NTuple{N, Int},
population_list::AbstractVector{Int},
interaction_matrix::AbstractMatrix{T},
optimizer
interaction_matrix::AbstractMatrix{T};
optimizer = SCIP.Optimizer,
optimizer_options = Dict()
) where {N, T<:Real}

csp = Model(optimizer)
Expand All @@ -91,8 +106,9 @@ function build_quadratic_problem(
@constraint(csp, sum(x[num_grid_points*(t-1)+p] for t in range(1, num_species)) <= 1)
end
@objective(csp, Min, sum(interaction_matrix[i,j]*x[i]*x[j] for i in range(1, num_species*num_grid_points) for j in range(i+1, num_species*num_grid_points) if (j-i)%num_grid_points != 0))
set_optimizer_attribute(csp, "NodefileStart", 1)

for (key, value) in optimizer_options
set_optimizer_attribute(csp, key, value)
end
optimize!(csp)
assert_is_solved_and_feasible(csp)
return objective_value(csp), value.(x), csp
Expand Down
Loading