Skip to content

Commit bdce30e

Browse files
improved loss function, added dictionaries for fast searching
1 parent 1fe985c commit bdce30e

7 files changed

+3035
-2448
lines changed

Diff for: algo_improved.c

+2,595-2,414
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: algo_improved.cp39-win_amd64.pyd

7.5 KB
Binary file not shown.

Diff for: algo_improved.pyx

+29-18
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
from libc.math cimport sin , cos, sqrt, fabs
32
import numpy as np
43
cimport numpy as cnp
@@ -14,30 +13,32 @@ cdef:
1413

1514
@cython.boundscheck(False)
1615
@cython.wraparound(False)
17-
def findDepth(dict specificNode , list parsedLinkData):
18-
cdef int temp = 0
16+
def findDepth(dict specificNode , list parsedLinkData, dict nodeIDMapToParsedLinkDataIndexTarget):
17+
cdef int temp = 1
1918
cdef float length = 0
2019
cdef list linkWithTarget
2120
cdef dict currentNode = specificNode
2221
while True:
2322
## this can be imporved using hashmap
24-
linkWithTarget = [x for x in parsedLinkData if x['target']['id'] == currentNode["id"]]
23+
## linkWithTarget = [x for x in parsedLinkData if x['target']['id'] == currentNode["id"]]
24+
linkWithTarget = [parsedLinkData[nodeIDMapToParsedLinkDataIndexTarget[currentNode["id"]]]] if nodeIDMapToParsedLinkDataIndexTarget.get(currentNode["id"], None) is not None else []
2525
if len(linkWithTarget)==0:
2626
return (temp, length)
2727
else:
28-
temp = temp +1
28+
temp += 1
2929
length = length + linkWithTarget[0]['len']
3030
currentNode = linkWithTarget[0]['source']
3131

3232

3333
@cython.boundscheck(False)
3434
@cython.wraparound(False)
35-
def findRootNode(dict randomNode, list parsedLinkData):
35+
def findRootNode(dict randomNode, list parsedLinkData, dict nodeIDMapToParsedLinkDataIndexTarget):
3636
cdef dict currentNode = randomNode
3737
cdef list linkWithTarget
3838
while True:
3939
## this can be imporved using hashmap
40-
linkWithTarget = [x for x in parsedLinkData if x['target']['id'] == currentNode["id"]]
40+
## linkWithTarget = [x for x in parsedLinkData if x['target']['id'] == currentNode["id"]]
41+
linkWithTarget = [parsedLinkData[nodeIDMapToParsedLinkDataIndexTarget[currentNode["id"]]]] if nodeIDMapToParsedLinkDataIndexTarget.get(currentNode["id"], None) is not None else []
4142
if len(linkWithTarget)==0:
4243
return currentNode
4344
else:
@@ -46,12 +47,13 @@ def findRootNode(dict randomNode, list parsedLinkData):
4647
## find the nodes from "endNode" to node "startNode" without the startNode itself
4748
@cython.boundscheck(False)
4849
@cython.wraparound(False)
49-
def reversebfs(dict startNode, list parsedLinkData):
50+
def reversebfs(dict startNode, list parsedLinkData, dict nodeIDMapToParsedLinkDataIndexTarget):
5051
cdef list path = []
5152
cdef dict currentNode = startNode , currentLink
5253
while True:
5354
## this can be imporved using hashmap
54-
linkWithTarget = [x for x in parsedLinkData if x['target']['id'] == currentNode["id"]]
55+
## linkWithTarget = [x for x in parsedLinkData if x['target']['id'] == currentNode["id"]]
56+
linkWithTarget = [parsedLinkData[nodeIDMapToParsedLinkDataIndexTarget[currentNode["id"]]]] if nodeIDMapToParsedLinkDataIndexTarget.get(currentNode["id"], None) is not None else []
5557
if len(linkWithTarget)==0:
5658
break
5759
else:
@@ -74,6 +76,7 @@ def updateCoordinatesX(double[:] x , double[:] y, float cX ,float cY, float rad,
7476
return new_x
7577

7678

79+
7780
@cython.boundscheck(False)
7881
@cython.wraparound(False)
7982
def updateCoordinatesY(double[:] x, double[:] y, float cX, float cY, float rad, int l):
@@ -126,6 +129,10 @@ def lossFunction(dict node,float realTheta,list parsedLinkData, dict nodeIDMapTo
126129
deltaY = matchedLinkObject['target']['y'] - matchedLinkObject['source']['y']
127130
return ((middleX - cX)* (deltaX) + (middleY - cY)* (deltaY)) / sqrt(a*b)
128131

132+
@cython.boundscheck(False)
133+
@cython.wraparound(False)
134+
def lossFunctionWithDepth(dict node,float realTheta, int depth, list parsedLinkData, dict nodeIDMapToParsedLinkDataIndexTarget):
135+
pass
129136

130137
## can be destructured to function(hashTable, requiredNode)
131138
@cython.boundscheck(False)
@@ -225,7 +232,6 @@ cdef bint checkInfiniteCollision(float sourceX, float sourceY, float targetX, fl
225232
dist = fabs(m*centreX - centreY + c) / sqrt(m*m +1)
226233
return (dist <= R)
227234

228-
229235
@cython.boundscheck(False)
230236
@cython.wraparound(False)
231237
def calIntersectionNum(list links, list requiredLinkIDs, list filteredLinksIDs ):
@@ -288,15 +294,17 @@ def findIntersection(list links):
288294

289295
## all nodes inputed should not be root node here
290296
@cython.boundscheck(False)
297+
@cython.wraparound(False)
291298
def mainAlgo(dict node, dict link, list parsedNodeData, list parsedLinkData, float THETA, float LAMBDA,
292299
dict nodeIDMapToParsedLinkDataIndexSource, dict nodeIDMapToParsedLinkDataIndexTarget, dict nodeIDMapToParsedNodeDataIndex,
293-
dict linkIDMapToParsedNodeDataIndexSource, dict linkIDMapToParsedNodeDataIndexTarget, dict linkIDMapToParsedLinkDataIndex
300+
dict linkIDMapToParsedNodeDataIndexSource, dict linkIDMapToParsedNodeDataIndexTarget, dict linkIDMapToParsedLinkDataIndex,
301+
dict depth
294302
):
295303

296304
cdef:
297305
int numOfIntersections, iterations = 0, currentNodeIndex, nodeIndex, breakFromWhile = 0, tempIndex = 1
298306
list requiredUpdateNode = search(node['id'], nodeIDMapToParsedLinkDataIndexSource, nodeIDMapToParsedLinkDataIndexTarget, parsedLinkData)
299-
int l = len(requiredUpdateNode), N = len(parsedLinkData), recordLength
307+
int l = len(requiredUpdateNode), N = len(parsedLinkData), recordLength = 0, nodeDepth = depth[node["id"]]
300308
unsigned int indexing, index
301309
long[:] orderedNodeIndex = np.zeros(l, dtype=long)
302310
## a numpy float is a C double.
@@ -312,10 +320,12 @@ def mainAlgo(dict node, dict link, list parsedNodeData, list parsedLinkData, flo
312320
for index in range(l):
313321
linkObjects[index] = nodeIDMapToParsedLinkDataIndexTarget[requiredUpdateNode[index]]
314322
orderedNodeIndex[index] = nodeIDMapToParsedNodeDataIndex[requiredUpdateNode[index]]
315-
orderedNodeX[index] = float(parsedNodeData[orderedNodeIndex[index]]["x"])
316-
orderedNodeY[index] = float(parsedNodeData[orderedNodeIndex[index]]["y"])
323+
## orderedNodeX[index] = float(parsedNodeData[orderedNodeIndex[index]]["x"])
324+
## orderedNodeY[index] = float(parsedNodeData[orderedNodeIndex[index]]["y"])
325+
orderedNodeX[index] = parsedNodeData[orderedNodeIndex[index]]["x"]
326+
orderedNodeY[index] = parsedNodeData[orderedNodeIndex[index]]["y"]
317327

318-
lastLinkObject = parsedLinkData[linkObjects[-1]]
328+
lastLinkObject = parsedLinkData[linkObjects[l-1]]
319329
R = sqrt((link["source"]["x"] - lastLinkObject["target"]["x"])**2 + (link["source"]["y"] - lastLinkObject["target"]["y"])**2)
320330

321331
for index in range(N):
@@ -344,16 +354,17 @@ def mainAlgo(dict node, dict link, list parsedNodeData, list parsedLinkData, flo
344354
if indexing == 0:
345355
records.append({'root' : { 'id' : requiredUpdateNode[indexing] ,'pos' :[new_x, new_y]} , 'childNodes' : {}})
346356
parsedNodeData[currentNodeIndex].update([("x", new_x), ("y", new_y)])
357+
recordLength += 1
347358
else:
348-
records[-1]['childNodes'].update([(requiredUpdateNode[indexing],[new_x, new_y])])
359+
records[recordLength -1]['childNodes'].update([(requiredUpdateNode[indexing],[new_x, new_y])])
349360
parsedNodeData[currentNodeIndex].update([("x",new_x), ("y",new_y)])
350-
361+
351362
numOfIntersections = calIntersectionNum(parsedLinkData, linkObjects, linkFilteredList)
352363
if numOfIntersections == 0:
353364
breakFromWhile = 1
354365
break
355366
"""The loss function is defined as : #Intersections - lambda * dotproduct (which is for measuring the degree of parallelism)"""
356-
records[-1]["loss"] = numOfIntersections - LAMBDA*lossFunction(node, binaryVariable, parsedLinkData, nodeIDMapToParsedLinkDataIndexTarget)
367+
records[recordLength - 1]["loss"] = numOfIntersections - LAMBDA*lossFunction(node, binaryVariable, parsedLinkData, nodeIDMapToParsedLinkDataIndexTarget) / nodeDepth
357368
iterations += 1
358369
realTheta = iterations* THETA
359370

0 commit comments

Comments
 (0)