-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathnodes.py
164 lines (149 loc) · 5.3 KB
/
nodes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
try:  # Python 3.3+; the ABC aliases in ``collections`` were removed in 3.10
    from collections.abc import KeysView, ValuesView, ItemsView, MutableMapping
except ImportError:  # Python 2 fallback
    from collections import KeysView,ValuesView,ItemsView,MutableMapping
from networkx.exception import NetworkXError
class NodeKeys(KeysView):
    """Keys view over a node mapping, shown as a plain list of nodes."""

    def __repr__(self):
        # The inherited MappingView repr would print
        # '{0.__class__.__name__}({1})'.format(self, self._mapping);
        # a bare list of the nodes is friendlier to read.
        node_list = [node for node in self._mapping]
        return repr(node_list)
class NodeData(ValuesView):
    """Values view over a node mapping, shown as a plain list of the values."""

    def __repr__(self):
        values = self._mapping.values()
        return str(list(values))
class NodeItems(ItemsView):
    """Items view over a node mapping, shown as a plain list of (key, value) pairs."""

    def __repr__(self):
        pairs = [(node, data) for node, data in self._mapping.items()]
        return repr(pairs)
class Nodes(MutableMapping):
    """Mutable mapping of node -> attribute dict that also behaves like a set.

    ``Nodes`` wraps two dicts supplied by the caller and mutates them in place:

    * ``nodes`` maps each node to its attribute dict (``add`` stores the dict
      as the value; ``__getitem__`` returns it).
    * ``adj`` maps each node to a ``{neighbor: edge_data}`` dict.  Removing a
      node also deletes it from every neighbor's dict, so the adjacency is
      presumably expected to be symmetric (undirected) -- TODO confirm with
      the owning graph class.

    ``adj`` defaults to ``None``.  In that case, methods that need the
    adjacency (``add`` of a new node, ``discard``, ``remove``, ``clear``,
    ``degree``, ``selfloops`` and the in-place updates built on them) fail.
    """

    __slots__ = ('_nodes', '_adj')

    def __init__(self, nodes, adj=None):
        self._nodes = nodes
        self._adj = adj

    # -- both set and dict methods ---------------------------------------
    def __iter__(self):
        return iter(self._nodes)

    def __contains__(self, key):
        return key in self._nodes

    def __repr__(self):
        return '{0.__class__.__name__}({1})'.format(self, self._nodes)

    def __len__(self):
        return len(self._nodes)

    def clear(self):
        """Remove every node and every adjacency entry."""
        self._nodes.clear()
        self._adj.clear()

    # -- set methods: each returns a new plain ``set`` of nodes ----------
    def __and__(self, other):
        return set(self._nodes) & set(other)

    def __or__(self, other):
        return set(self._nodes) | set(other)

    def __xor__(self, other):
        return set(self._nodes) ^ set(other)

    def __sub__(self, other):
        return set(self._nodes) - set(other)

    # Reflected operators: Python calls these for ``other <op> self`` when
    # ``other`` cannot handle the operation itself (a list, or a ``set``,
    # whose operators only accept other sets).  &, | and ^ are symmetric,
    # but subtraction is not: ``other - self`` keeps elements of ``other``.
    def __rand__(self, other):
        return set(self._nodes) & set(other)

    def __ror__(self, other):
        return set(self._nodes) | set(other)

    def __rxor__(self, other):
        return set(self._nodes) ^ set(other)

    def __rsub__(self, other):
        return set(other) - set(self._nodes)

    # -- in-place mass adds and removes ----------------------------------
    def update(self, nodes, **attr):
        """Add every item of ``nodes``, applying keyword attributes ``attr``.

        Each item is either a node or a ``(node, attr_mapping)`` pair; a
        pair's own attributes override ``attr``.  An item that cannot be
        unpacked into two values, or whose second value cannot be merged
        into a dict, is treated as a node itself.  String nodes and
        string-pair nodes such as ``('a', 'b')`` are therefore added as-is
        instead of raising.

        Returns ``self`` so the method can serve as ``__ior__`` (``|=``).
        """
        for n in nodes:
            try:
                nn, ndict = n
                newdict = attr.copy()
                newdict.update(ndict)
            except (TypeError, ValueError):
                # Not a (node, attrs) pair: ``n`` itself is the node.  Copy
                # ``attr`` so every new node owns its attribute dict.
                self.add(n, attr_dict=attr.copy())
            else:
                # Pass the merged dict via ``attr_dict`` rather than
                # ``**newdict``: this accepts non-string attribute keys and
                # keys that clash with add()'s parameter names.
                self.add(nn, attr_dict=newdict)
        return self

    def intersection_update(self, nodes):
        """Keep only nodes that also appear in ``nodes``; return ``self``."""
        # ``self - set(...)`` builds a new set, so discarding is safe here.
        for n in self - set(nodes):
            self.discard(n)
        return self

    def symmetric_difference_update(self, nodes):
        """Toggle membership of each distinct item of ``nodes``; return ``self``.

        Duplicates are toggled once, matching ``set`` semantics.  ``nodes``
        is snapshotted first, so ``self ^= self`` empties the container
        instead of mutating it mid-iteration.
        """
        seen = set()
        for n in list(nodes):
            if n in seen:
                continue
            seen.add(n)
            if n in self:
                self.discard(n)
            else:
                self.add(n)
        return self

    def difference_update(self, nodes):
        """Discard every item of ``nodes`` (absent ones are ignored); return ``self``."""
        # Snapshot: ``nodes`` may be ``self`` (``self -= self``).
        for n in list(nodes):
            self.discard(n)
        return self

    __ior__ = update  # |=
    __iand__ = intersection_update  # &=
    __isub__ = difference_update  # -=
    __ixor__ = symmetric_difference_update  # ^=

    def discard(self, n):
        """Remove node ``n`` and its incident edges; do nothing if ``n`` is absent."""
        adj = self._adj
        try:
            # Copy the neighbor keys: with a self-loop, adj[n] itself is
            # mutated by the deletion loop below.
            nbrs = list(adj[n].keys())
            del self._nodes[n]
        except KeyError:  # silently ignore if n not in self
            return
        for u in nbrs:
            del adj[u][n]  # remove all edges n-u in graph
        del adj[n]  # now remove node

    def remove(self, n):
        """Remove node ``n`` and its incident edges.

        Raises
        ------
        NetworkXError
            If ``n`` is not in the node or adjacency dict.
        """
        adj = self._adj
        try:
            # Copy the neighbor keys: with a self-loop, adj[n] itself is
            # mutated by the deletion loop below.
            nbrs = list(adj[n].keys())
            del self._nodes[n]
        except KeyError:  # NetworkXError if n not in self
            raise NetworkXError("The node %s is not in the graph." % (n,))
        for u in nbrs:
            del adj[u][n]  # remove all edges n-u in graph
        del adj[n]  # now remove node

    # -- dictionary methods -----------------------------------------------
    def __delitem__(self, key):
        self.remove(key)

    def __setitem__(self, key, value):
        # Item assignment is deliberately unsupported: adding a node must
        # also create its adjacency entry, which add() takes care of.
        raise NetworkXError('Use the add() method')

    def __getitem__(self, key):
        return self._nodes[key]

    def keys(self):
        """Return a ``NodeKeys`` view of the nodes."""
        return NodeKeys(self._nodes)

    def items(self):
        """Return a ``NodeItems`` view of ``(node, attr_dict)`` pairs."""
        return NodeItems(self._nodes)

    def values(self):
        """Return a ``NodeData`` view of the attribute dicts."""
        return NodeData(self._nodes)

    def add(self, n, attr_dict=None, **attr):
        """Add node ``n`` or, if it already exists, update its attributes.

        Parameters
        ----------
        n : hashable
            The node.
        attr_dict : dict, optional
            Node attributes; keyword arguments in ``attr`` are merged into
            it and take precedence.
        **attr
            Additional node attributes.

        Raises
        ------
        NetworkXError
            If ``attr_dict`` is not dict-like (its ``update`` raises
            ``AttributeError``).
        """
        if attr_dict is None:
            attr_dict = attr
        else:
            try:
                # NOTE(review): this updates the caller's dict in place and,
                # for a new node, stores that same dict object as the node
                # data (no copy).
                attr_dict.update(attr)
            except AttributeError:
                raise NetworkXError(
                    "The attr_dict argument must be a dictionary.")
        if n not in self._nodes:
            self._adj[n] = {}  # FIXME factory
            self._nodes[n] = attr_dict
        else:  # update attr even if node already exists
            self._nodes[n].update(attr_dict)

    # -- extra methods: neither set nor dict ---------------------------------
    def data(self):
        """Return a ``NodeData`` view of the attribute dicts (same as ``values()``)."""
        return NodeData(self._nodes)

    def degree(self, weight=None):
        """Yield ``(node, degree)`` for every node in the adjacency.

        A self-loop counts twice.  If ``weight`` is given, each incident
        edge contributes ``edge_data.get(weight, 1)`` instead of 1.
        """
        # NOTE: this is a generator; it could become a view in the future.
        if weight is None:
            for n, nbrs in self._adj.items():
                yield (n, len(nbrs) + (1 if n in nbrs else 0))
        else:
            for n, nbrs in self._adj.items():
                yield (n, sum(nbrs[nbr].get(weight, 1) for nbr in nbrs) +
                       (nbrs[n].get(weight, 1) if n in nbrs else 0))

    def selfloops(self):
        """Return a generator over nodes that have an edge to themselves."""
        return (n for n, nbrs in self._adj.items() if n in nbrs)