Skip to content

Commit

Permalink
Add buffon graph generation (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudon authored Sep 19, 2022
1 parent ace3942 commit c1fa32c
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 1 deletion.
15 changes: 15 additions & 0 deletions examples/make_buffon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import matplotlib.pyplot as plt
import networkx as nx
from netsalt.utils import make_buffon_graph

import numpy as np
if __name__ == "__main__":
np.random.seed(42)
buffon, pos = make_buffon_graph(n_lines=20, size=(-100.0, 100.0), resolution=1.0)

plt.figure()
nx.draw(buffon, pos=pos, node_size=0.00, width=0.2)
ax = plt.gca()
ax.set_axis_on()
ax.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True)
plt.savefig("buffon_graph.pdf")
2 changes: 1 addition & 1 deletion netsalt/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def _plot_single_mode(graph, mode, ax=None, colorbar=True, edge_vmin=None, edge_
if colorbar:
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=edge_vmin, vmax=edge_vmax))
sm.set_array([])
plt.colorbar(sm, label=r"$|E|^2$ (a.u)", shrink=0.5)
plt.colorbar(sm, ax=ax, label=r"$|E|^2$ (a.u)", shrink=0.5)
return ax


Expand Down
124 changes: 124 additions & 0 deletions netsalt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,127 @@ def remove_pixel(graph, center, size):
graph = nx.convert_node_labels_to_integers(graph)
pump = [graph[e[0]][e[1]]["pump"] for e in graph.edges]
return graph, pump


def _to_points(point, angle, t, size):
"""Convert a point/angle to set of points."""
points = point + np.array([np.cos(angle) * t, np.sin(angle) * t]).T
points = points[(points[:, 0] > size[0]) & (points[:, 0] < size[1])]
return points[(points[:, 1] > size[0]) & (points[:, 1] < size[1])].tolist()


def _get_line_points(points, angles, t, size):
"""For each line, we create the points are edge list.
We return a dict with keys are line index.
"""
all_points = {}
edge_list = {}
for i, (point, angle) in enumerate(zip(points, angles)):
_points = _to_points(point, angle, t, size)
edge_list[i] = [(i, i + 1) for i in range(len(_points) - 1)]
all_points[i] = _points
return edge_list, all_points


def _get_intersection_points(points, angles, size):
"""Find the intersection points between intersecting lines.
For each point, we return a 2-tuple with the point and indices of the intersecting lines.
"""
intersection_points = []
for i, (point1, angle1) in enumerate(zip(points, angles)):
for j, (point2, angle2) in enumerate(zip(points[i:], angles[i:])):
x = (
point1[1] - point2[1] - np.tan(angle1) * point1[0] + np.tan(angle2) * point2[0]
) / (np.tan(angle2) - np.tan(angle1))
y = point1[1] + np.tan(angle1) * (x - point1[0])
if size[0] < x < size[1] and size[0] < y < size[1]:
intersection_points.append([[x, y], (i, i + j)])

return intersection_points


def _add_intersection_points(edge_list, all_points, intersection_points):
"""We add intersections point to each line by adding a new point and updating edge_list."""
edges = []
for intersection_point in intersection_points:
inter_id = {}
for i in intersection_point[1]:
edges = edge_list[i]
points = np.array(all_points[i])

# search for correct segment (where intersection is in the middle)
index = None
for j, edge in enumerate(edges):
x = intersection_point[0] - points[edge[0]]
y = intersection_point[0] - points[edge[1]]
z = points[edge[1]] - points[edge[0]]

if abs(np.linalg.norm(x) + np.linalg.norm(y) - np.linalg.norm(z)) < 1e-10:
index = j

if index is not None and inter_id is not None:
e = edge_list[i].pop(index)
edge_list[i].append([e[0], len(points)])
edge_list[i].append([len(points), e[1]])
inter_id[i] = len(points)
all_points[i].append(intersection_point[0])
else:
inter_id = None

if inter_id is not None:
intersection_point.append(inter_id)


def _get_graph(edge_list, all_points, intersection_points):
"""We create the buffon graph by making line subgraph, and merging each intersection point.
We return the graph and list of node positions.
"""
graph = nx.Graph()
shift = 0
pos = []
last_ids = {}
# create the graph
for i in edge_list:
edges, points = edge_list[i], all_points[i]
for edge in edges:
graph.add_edge(edge[0] + shift, edge[1] + shift)

last_ids[i] = shift
shift += len(points)
pos += points

# merge intersection points
for intersection_point in intersection_points:
edge_i = intersection_point[1][0]
edge_j = intersection_point[1][1]
if len(intersection_point) == 3:
i = last_ids[edge_i] + intersection_point[2][edge_i]
j = last_ids[edge_j] + intersection_point[2][edge_j]
graph = nx.contracted_nodes(graph, i, j)
return graph, pos


def make_buffon_graph(n_lines, size, resolution=1.0):
"""Make a buffon graph.
Args:
n_lines (int): number of lines to draw randomly
size (2-tuple): min and max extent of the graph (it will be square only)
resolution (float): distance between each points along lines
Warning: it is not exactly the same graph as in the Nat. Comm. Paper, which was done with
a matlab code.
"""
diag = np.sqrt(2) * (size[1] - size[0])
t = np.arange(-diag, diag, resolution)
points = np.random.uniform(size[0], size[1], size=(n_lines, 2))
angles = np.random.uniform(0, np.pi, n_lines)

edge_list, all_points = _get_line_points(points, angles, t, size)
intersection_points = _get_intersection_points(points, angles, size)
_add_intersection_points(edge_list, all_points, intersection_points)
graph, pos = _get_graph(edge_list, all_points, intersection_points)
return graph, pos
Binary file modified tests/data/run_simple/out/quantum_graph.gpickle
Binary file not shown.

0 comments on commit c1fa32c

Please sign in to comment.