-
Notifications
You must be signed in to change notification settings - Fork 130
/
Copy pathplot_tree.py
166 lines (133 loc) · 5.39 KB
/
plot_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import argparse
import json
import matplotlib as mpl
mpl.use("Agg")
from matplotlib import gridspec
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
import numpy as np
import sys
sys.path.append('..')
from base.io_util import json_to_tree
def plot_tree(tree, figure_name, color_by_trait, initial_branch_width, tip_size):
    """Plot a BioPython Phylo tree in the BALTIC-style and save it to a file.

    Every node yielded by ``tree.find_clades()`` must provide an ``attr``
    dict with a ``"num_date"`` entry (used as the x position), a ``yvalue``
    attribute (used as the y position), an ``up`` reference to its parent
    (``None`` for the root), a ``clades`` list of children and an
    ``is_terminal()`` method.

    Args:
        tree: root of the tree to draw.
        figure_name: destination passed to matplotlib's ``savefig``.
        color_by_trait: key looked up in each node's ``attr`` dict; its value
            is mapped through the viridis colormap to color the node's branch
            and tip. Nodes that lack the key are drawn in black.
        initial_branch_width: width added (on top of a base width of 2) to an
            internal branch, scaled by the fraction of all tips that descend
            from it.
        tip_size: scatter marker size of the colored tip circles; the black
            circle drawn underneath each tip uses twice this size.

    Raises:
        KeyError: if no node in the tree carries ``color_by_trait``, so no
            color scale can be built.

    Note:
        This function overwrites several global ``mpl.rcParams`` entries
        (dpi and font settings).
    """
    mpl.rcParams['savefig.dpi'] = 120
    mpl.rcParams['figure.dpi'] = 100
    mpl.rcParams['font.weight'] = 300
    mpl.rcParams['axes.labelweight'] = 300
    mpl.rcParams['font.size'] = 14

    # Traverse the tree once; the same node list feeds the color scale, the
    # tip count and the drawing loop.
    nodes = list(tree.find_clades())

    # Setup colors. Only nodes that actually carry the trait contribute to the
    # color scale, matching the per-node fallback to black in the loop below.
    trait_name = color_by_trait
    traits = [node.attr[trait_name] for node in nodes if trait_name in node.attr]
    if not traits:
        raise KeyError("no node in the tree has the trait '%s'" % trait_name)

    norm = mpl.colors.Normalize(min(traits), max(traits))
    cmap = mpl.cm.viridis

    #
    # Setup the figure grid.
    #
    fig = plt.figure(figsize=(8, 6), facecolor='w')
    gs = gridspec.GridSpec(2, 1, height_ratios=[14, 1], width_ratios=[1], hspace=0.1, wspace=0.1)
    ax = fig.add_subplot(gs[0])
    colorbar_ax = fig.add_subplot(gs[1])

    total_tips = sum(1 for node in nodes if node.is_terminal())

    # Setup arrays for tip and internal node coordinates.
    tip_circles_x = []
    tip_circles_y = []
    tip_circles_color = []
    tip_circle_sizes = []
    node_circles_x = []
    node_circles_y = []
    node_circles_color = []
    branch_line_segments = []
    branch_line_widths = []
    branch_line_colors = []

    for node in nodes:
        x = node.attr["num_date"]
        y = node.yvalue
        parent_x = None if node.up is None else node.up.attr["num_date"]

        # matplotlib won't plot Nones; the root has no parent, so its branch
        # collapses to a zero-length segment at its own x position.
        if x is None:
            x = 0.0
        if parent_x is None:
            parent_x = x

        # `in` works on Python 2 and 3; dict.has_key() exists only on Python 2.
        if trait_name in node.attr:
            color = cmap(norm(node.attr[trait_name]))
        else:
            color = 'k'

        branch_width = 2
        if node.is_terminal():
            tip_circle_sizes.append(tip_size)
            tip_circles_x.append(x)
            tip_circles_y.append(y)
            tip_circles_color.append(color)
        else:
            # Scale branch widths by the number of tips below this node.
            descendant_tips = sum(
                1 for child in node.find_clades() if child.is_terminal()
            )
            branch_width += initial_branch_width * descendant_tips / float(total_tips)

            # Nodes with a single child get an explicit marker.
            if len(node.clades) == 1:
                node_circles_x.append(x)
                node_circles_y.append(y)
                node_circles_color.append(color)

            # Vertical connector spanning from the first to the last child.
            ax.plot(
                [x, x],
                [node.clades[-1].yvalue, node.clades[0].yvalue],
                lw=branch_width,
                color=color,
                ls='-',
                zorder=9,
                solid_capstyle='round'
            )

        # Horizontal branch from the parent's x position to this node.
        branch_line_segments.append([(parent_x, y), (x, y)])
        branch_line_widths.append(branch_width)
        branch_line_colors.append(color)

    branch_lc = LineCollection(branch_line_segments, zorder=9)
    branch_lc.set_color(branch_line_colors)
    branch_lc.set_linewidth(branch_line_widths)
    branch_lc.set_linestyle("-")
    ax.add_collection(branch_lc)

    # Add circles for tips and internal nodes.
    tip_circle_sizes = np.array(tip_circle_sizes)
    # Colored circle for every tip, with a larger black circle underneath.
    ax.scatter(tip_circles_x, tip_circles_y, s=tip_circle_sizes, facecolor=tip_circles_color, edgecolor='none', zorder=11)
    ax.scatter(tip_circles_x, tip_circles_y, s=tip_circle_sizes*2, facecolor='k', edgecolor='none', zorder=10)
    # Vertical tick on every single-child internal node.
    ax.scatter(node_circles_x, node_circles_y, facecolor=node_circles_color, s=50, edgecolor='none', zorder=10, lw=2, marker='|')

    # Hide all spines except the bottom (x) axis.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)

    ax.grid(axis='x', ls='-', color='grey')
    ax.tick_params(axis='y', size=0)
    ax.set_yticklabels([])

    cb1 = mpl.colorbar.ColorbarBase(
        colorbar_ax,
        cmap=cmap,
        norm=norm,
        orientation='horizontal'
    )
    cb1.set_label(color_by_trait)

    gs.tight_layout(fig)
    fig.savefig(figure_name)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
if __name__ == "__main__":
    # Command-line entry point: read a tree JSON file, convert it to a tree
    # object and write the rendered figure to the requested output path.
    cli = argparse.ArgumentParser()
    cli.add_argument("tree", help="auspice tree JSON")
    cli.add_argument("output", help="plotted tree figure")
    cli.add_argument("--colorby", default="num_date", help="trait in tree to color by")
    cli.add_argument("--branch_width", type=int, default=10, help="initial branch width")
    cli.add_argument("--tip_size", type=int, default=10, help="tip size")
    options = cli.parse_args()

    with open(options.tree, "r") as handle:
        tree_json = json.load(handle)

    # Convert the JSON tree layout with json_to_tree, then plot the result.
    plot_tree(
        tree=json_to_tree(tree_json),
        figure_name=options.output,
        color_by_trait=options.colorby,
        initial_branch_width=options.branch_width,
        tip_size=options.tip_size
    )