|
| 1 | +# This file defines the postprocess module |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import sys, os, re, shutil |
| 5 | +import matplotlib.pyplot as plt |
| 6 | +import matplotlib.colors as col |
| 7 | +import matplotlib.cm as cm |
| 8 | + |
| 9 | +class dftpp: |
| 10 | + ''' |
| 11 | + **** ABSTRACT **** |
| 12 | + This module defines the core class of postprocessing the calculated |
| 13 | + results of DFT such as band plot, fatband plot and spectral plot, etc. |
| 14 | + This module can be used either independent or inherited by other |
| 15 | + package specified classes. |
| 16 | + **** INITLALIZE **** |
| 17 | + N/A |
| 18 | + **** METHODS **** |
| 19 | + bicolormap(self,colorcode): |
| 20 | + * generate three default colormaps for fatband plot. |
| 21 | + => [colorcode]: str, 'r'/'g'/'b' |
| 22 | + three default bicolormaps can be generated by simply input a |
| 23 | + single keyword: 'r': red, 'g': green, 'b': blue |
| 24 | +
|
| 25 | + band_plot(self,spin,Ek,Ef,kdiv,klabel='default',Ebound='default',lw=2,fontsize=18) |
| 26 | + * this method plots the band structures |
| 27 | + => [spin]: int, 1 / 2 |
| 28 | + if two spin are separate, spin=2, else spin=1 |
| 29 | + => [Ek]: numpy array, tot_k x tot_band / tot_k x 2*tot_band |
| 30 | + if spin=1, tot_k x tot_band. if spin=2, tot_k x 2*tot_band |
| 31 | + => [Ef]: float |
| 32 | + Fermi energy |
| 33 | + => [kdiv]: list, int, e.g. [0,25,42] |
| 34 | + location of the k-point to label high symmetry point. the first one |
| 35 | + must be 0 and the last one must be tot_k-1 |
| 36 | + => [klabel]: list, string, e.g. ['$\Gamma','X','M'] |
| 37 | + name of the high symmetry point. must the same length as kdiv |
| 38 | + default=['k1','k2',...,'kn'] |
| 39 | + => [Ebound]: list, two float values, e.g. [-3.5, 3.5] |
| 40 | + upper and lower energy bounds |
| 41 | + default=min(Ek) ~ max(Ek) |
| 42 | + => [lw]: int, default=2 |
| 43 | + the line width |
| 44 | + [fontsize]: int, default=18 |
| 45 | + the fontsize of the plot |
| 46 | + |
| 47 | + fatband_plot(self,Ek,Ek_weight,Ef,state_info,state_grp,kdiv,klabel='default'\ |
| 48 | + ,Ebound='default',marker_size=30,colorcode='b'): |
| 49 | + * this method plots the fatband sturctures |
| 50 | + => [Ek]: numpy array, tot_k x tot_band (all spin counts) |
| 51 | + the energy bands, tot_band are spin included |
| 52 | + => [Ek_weight]: numpy array, tot_k x tot_band x tot_state |
| 53 | + the weight of each state at each eigenvalues |
| 54 | + => [Ef]: float |
| 55 | + the Fermi level. Ek will be shift based on it, so the output |
| 56 | + Ef is always zero. |
| 57 | + => [state_info]: list, string, size: tot_state |
| 58 | + a list of strings that describes each state. each element corresponds |
| 59 | + to each state. |
| 60 | + => [state_grp]: list, e.g:[[1,2,3],[7,8,10],[12,17]] |
| 61 | + a list which shows the states to be grouped into a fatband plot. |
| 62 | + therefore, the above example will output three fatband plots where |
| 63 | + the first plot shows the wieght of sum over state 1 2 and 3. |
| 64 | + => [kdiv], [klabel], [Ebound] |
| 65 | + the same as band_plot |
| 66 | + => [ini_fig_num]: int, default=1 |
| 67 | + in case you plot many plots in your code, you don't want it conflict with |
| 68 | + others. so you can set the ini_fig_num to tell the code the figure number |
| 69 | + of each fatband plot. e.g. if ini_fig_num=3 and you have 3 state_grp, then |
| 70 | + your fatband plot will be fig.3 ~ fig.5 |
| 71 | + => [marker_size]: int, default=30 |
| 72 | + the size of your fatband markers |
| 73 | + => [colorcode]: str, 'r' / 'g' / 'b' , default: 'b' |
| 74 | + the default colormap for fatband plot. 'r': red, 'g': green, 'b': blue |
| 75 | + |
| 76 | + spectral_plot(self,E,PDOS,state_grp='default',xlabel='Energy',ylabel='PDOS'\ |
| 77 | + ,llabel='default',Ebound='default',lw=2,fontsize=18): |
| 78 | + * this method plots all kinds of spectral functions such as DOS, PDOS, |
| 79 | + DMFT spectral, etc. |
| 80 | + =>[E]: Nx1 numpy array, |
| 81 | + the x-axis of your plot |
| 82 | + =>[PDOS]: numpy array, len(E) X tot_state |
| 83 | + the PDOS of different states. each column corresponds to each state |
| 84 | + => [state_grp]: a list of list |
| 85 | + combine the data to plot a line in the spectral, e.g, [[1,2,3],[4,5]] |
| 86 | + means combine the data of state [1,2,3] to plot a line and [4,5] for |
| 87 | + another line. |
| 88 | + |
| 89 | + **** Version **** |
| 90 | + 02/01/2016: first built |
| 91 | + **** Comment **** |
| 92 | + 1. Run /test/postprocess_test.py to test this module |
| 93 | + ''' |
| 94 | + def __str__(self): |
| 95 | + return 'DFTtooolbox: Postprocess Object' |
| 96 | + |
| 97 | + def grep(self,txtlines,kws,reg=False): |
| 98 | + # search for line numbers from a textlines where contains the keywords |
| 99 | + if reg: |
| 100 | + ln=[ n for n,txt in enumerate(txtlines) if (re.search(kws,txt)!=None)] |
| 101 | + else: |
| 102 | + ln=[ n for n,txt in enumerate(txtlines) if (txt.find(kws)!=-1)] |
| 103 | + |
| 104 | + return ln |
| 105 | + |
| 106 | + def bicolormap(self,colorcode): |
| 107 | + if colorcode=='r': |
| 108 | + colorlist=['#e6e6e6','#cc0000'] |
| 109 | + elif colorcode=='g': |
| 110 | + colorlist=['#e6e6e6','#009900'] |
| 111 | + elif colorcode=='b': |
| 112 | + colorlist=['#e6e6e6','#000099'] |
| 113 | + |
| 114 | + # generate a colormap |
| 115 | + my_cmap = col.LinearSegmentedColormap.from_list('cmap_tmp',colorlist,N=32) |
| 116 | + |
| 117 | + # return the colormap |
| 118 | + return my_cmap |
| 119 | + |
| 120 | + def state_grp_trans(self,state_info,state_grp): |
| 121 | + # this method turn the string format of states in DFTtoolbox into the |
| 122 | + # correct state label. To use it, one must orgainze the state info |
| 123 | + # as ' 1 => xxx ( atn / l / ml / ms)\n'. In this format, each state |
| 124 | + # is labeled by 4 numbers within the the bracket only. The first one |
| 125 | + # is atom number the other three are quantum numbers. |
| 126 | + # if some of them are not available, use 'a' instead. |
| 127 | + # xxx is extra information about the state which will not be used for filter. |
| 128 | + state_grp_new=[] |
| 129 | + if type(state_grp[0][0])==str: |
| 130 | + state_info=np.array(\ |
| 131 | + [ state.replace(')',' ').replace('(',' ').replace('/',' ').split()[-4:] for state in state_info]) |
| 132 | + def state_translator(state_str): |
| 133 | + state=state_str.replace(':',' ').replace('/',' ').split() |
| 134 | + |
| 135 | + # pick atoms |
| 136 | + s=set(range(0,len(state_info))) |
| 137 | + s=s & set(np.nonzero(\ |
| 138 | + (state_info[:,0].astype(np.int) >= int(state[0]))\ |
| 139 | + & (state_info[:,0].astype(np.int) <= int(state[1]))\ |
| 140 | + )[0].tolist()) |
| 141 | + |
| 142 | + # pick index (other than atom label) |
| 143 | + for n in range(0,3): |
| 144 | + if (state[n+2] is not 'a') & (state_info[:,n+1] is not 'a'): |
| 145 | + s=s & set(np.nonzero(state_info[:,n+1].astype(np.float)==float(state[n+2]))[0].tolist()) |
| 146 | + |
| 147 | + s=list(s) |
| 148 | + if s==[]: |
| 149 | + print('Error: assigned state {0} not exist!'.format(state_str)) |
| 150 | + sys.exit() |
| 151 | + return s |
| 152 | + |
| 153 | + for grp in state_grp: |
| 154 | + for subgrp in grp: |
| 155 | + state_grp_new.append(state_translator(subgrp)) |
| 156 | + |
| 157 | + elif type(state_grp[0][0])==int: |
| 158 | + state_grp_new=state_grp |
| 159 | + |
| 160 | + return state_grp_new |
| 161 | + |
| 162 | + def band_plot(self,spin,Ek,Ef,kdiv='default',klabel='default',Ebound='default',\ |
| 163 | + lw=2,fontsize=18,savefig_dir='default'): |
| 164 | + |
| 165 | + # give input parameter default values |
| 166 | + if (kdiv is 'default'): |
| 167 | + kdiv=[0,Ek.shape[0]-1] |
| 168 | + if (klabel is 'default'): |
| 169 | + klabel=['k'+str(n) for n, val in enumerate(kdiv)] |
| 170 | + if (Ebound is 'default'): |
| 171 | + Ebound=[np.min(Ek),np.max(Ek)] |
| 172 | + |
| 173 | + tot_k=Ek.shape[0] |
| 174 | + tot_ban=Ek.shape[1] |
| 175 | + Ek=Ek-Ef |
| 176 | + |
| 177 | + # plot normal band structure |
| 178 | + if spin==1: |
| 179 | + plt.plot(range(0,tot_k),Ek,'b',lw=lw) |
| 180 | + elif spin==2: |
| 181 | + plt.plot(range(0,tot_k),Ek[:,0:int(tot_ban/2)],'b'\ |
| 182 | + ,range(0,tot_k),Ek[:,int(tot_ban/2):],'g',lw=lw) |
| 183 | + |
| 184 | + # plot k divider |
| 185 | + if len(kdiv)>=3: |
| 186 | + for val in kdiv[1:-1]: |
| 187 | + plt.plot(val*np.ones(10),np.linspace(np.min(Ek),np.max(Ek),10),'k--') |
| 188 | + |
| 189 | + # plot Fermi level |
| 190 | + plt.plot(np.linspace(0,tot_k-1,10),np.zeros(10),'r--') |
| 191 | + |
| 192 | + # tweak figure |
| 193 | + plt.xticks(kdiv,klabel,fontsize=fontsize) |
| 194 | + plt.yticks(fontsize=fontsize) |
| 195 | + plt.ylabel('Energy (eV)',fontsize=fontsize) |
| 196 | + plt.title('bands',fontsize=fontsize) |
| 197 | + plt.xlim(0,tot_k-1) |
| 198 | + plt.ylim(Ebound[0],Ebound[1]) |
| 199 | + |
| 200 | + # save figure |
| 201 | + if savefig_dir is not 'default': |
| 202 | + plt.savefig(savefig_dir+'band.png',bbox_inches='tight') |
| 203 | + |
| 204 | + # display figure |
| 205 | + plt.show() |
| 206 | + |
| 207 | + def fatband_plot(self,Ek,Ek_weight,Ef,state_info,state_grp,kdiv='default',klabel='default'\ |
| 208 | + ,Ebound='default',ini_fig_num=1,marker_size=30,colorcode='b',fontsize=18,savefig_dir='default'): |
| 209 | + |
| 210 | + # give input parameter default values |
| 211 | + if (kdiv is 'default'): |
| 212 | + kdiv=[0,Ek.shape[0]-1] |
| 213 | + if (klabel is 'default'): |
| 214 | + klabel=['k'+str(n) for n, val in enumerate(kdiv)] |
| 215 | + if (Ebound is 'default'): |
| 216 | + Ebound=[np.min(Ek),np.max(Ek)] |
| 217 | + |
| 218 | + tot_k=Ek.shape[0] |
| 219 | + tot_ban=Ek.shape[1] |
| 220 | + |
| 221 | + #organize data to scatter form |
| 222 | + x=np.tile(np.array(range(0,tot_k)),[tot_ban,1]).transpose().flatten() |
| 223 | + y=Ek.flatten()-Ef |
| 224 | + |
| 225 | + |
| 226 | + # screen out data out of E_bound |
| 227 | + my_cmap=self.bicolormap(colorcode) |
| 228 | + for n, state_list in enumerate(state_grp): |
| 229 | + w=np.zeros(len(Ek_weight[:,:,0].flatten())) |
| 230 | + |
| 231 | + print('\n==== fatband-'+str(n)+' ====') |
| 232 | + # sum over projected states |
| 233 | + for state in state_list: |
| 234 | + print(state_info[state]) |
| 235 | + w=w+Ek_weight[:,:,state].flatten() |
| 236 | + |
| 237 | + |
| 238 | + # plot colorbar |
| 239 | + plt.figure(n+ini_fig_num) |
| 240 | + cbar_obj = plt.contourf([[0,0],[0,0]],\ |
| 241 | + np.linspace(min(w),max(w),32), cmap=my_cmap) |
| 242 | + plt.clf() |
| 243 | + plt.colorbar(cbar_obj) |
| 244 | + |
| 245 | + # lexsort data based on w, so low intensity will be plot first |
| 246 | + w_sort=np.argsort(w) |
| 247 | + plt.scatter(x=x[w_sort], y=y[w_sort],s=marker_size, c=w[w_sort], cmap=my_cmap,edgecolor='face') |
| 248 | + |
| 249 | + # plot k divider |
| 250 | + for val in kdiv[1:-1]: |
| 251 | + plt.plot(val*np.ones(10),np.linspace(np.min(y),\ |
| 252 | + np.max(y),10),'k--') |
| 253 | + |
| 254 | + # plot Fermi level |
| 255 | + plt.plot(np.linspace(0,np.max(x),10),np.zeros(10),'r--') |
| 256 | + |
| 257 | + # tweak figure |
| 258 | + plt.xticks(kdiv,klabel,fontsize=fontsize) |
| 259 | + plt.yticks(fontsize=fontsize) |
| 260 | + plt.ylabel('Energy (eV)',fontsize=fontsize) |
| 261 | + plt.title('fatband-'+str(n),fontsize=fontsize) |
| 262 | + plt.xlim(0,np.max(x)) |
| 263 | + plt.ylim( Ebound[0],Ebound[1]) |
| 264 | + |
| 265 | + # save figure |
| 266 | + if savefig_dir is not 'default': |
| 267 | + plt.savefig(savefig_dir+'fatband-'+str(n)+'.png',bbox_inches='tight') |
| 268 | + |
| 269 | + # display figure |
| 270 | + plt.show() |
| 271 | + |
| 272 | + def spectral_plot(self,E,PDOS,state_info,state_grp='default',xlabel='Energy',ylabel='PDOS'\ |
| 273 | + ,llabel='default',Ebound='default',savefig_path='default',lw=2,fontsize=18): |
| 274 | + if Ebound=='default': |
| 275 | + Ebound=[np.min(E),np.max(E)] |
| 276 | + |
| 277 | + if state_grp is 'default': |
| 278 | + state_grp=[n for n in range(0,PDOS.shape[1])] |
| 279 | + |
| 280 | + PDOS_grp=np.zeros((E.shape[0],len(state_grp))) |
| 281 | + for n, state_list in enumerate(state_grp): |
| 282 | + print('spectral data-{0}'.format(n)) |
| 283 | + for state in state_list: |
| 284 | + print(' {0}'.format(state_info[state])) |
| 285 | + PDOS_grp[:,n]+=PDOS[:,state] |
| 286 | + |
| 287 | + if llabel is 'default': |
| 288 | + plt.plot(E,PDOS_grp[:,n],lw=lw,label='data-'+str(n)) |
| 289 | + else: |
| 290 | + plt.plot(E,PDOS_grp[:,n],lw=lw,label=llabel[n]) |
| 291 | + |
| 292 | + |
| 293 | + # tweak figure |
| 294 | + plt.plot(np.zeros(10),np.linspace(1.01*np.min(PDOS_grp),1.01*np.max(PDOS_grp),10),'r--') |
| 295 | + plt.legend(loc='upper right', shadow=True, fontsize='large',framealpha=0.0) |
| 296 | + plt.xlabel(xlabel,fontsize=fontsize) |
| 297 | + plt.ylabel(ylabel,fontsize=fontsize) |
| 298 | + plt.title('spectral',fontsize=fontsize) |
| 299 | + plt.xlim(Ebound[0],Ebound[1]) |
| 300 | + plt.xticks(fontsize=fontsize) |
| 301 | + plt.yticks(fontsize=fontsize) |
| 302 | + if savefig_path!='default': |
| 303 | + plt.savefig(savefig_path,bbox_inches='tight') |
| 304 | + print('spectral has been saved in: '+savefig_path) |
| 305 | + plt.show() |
| 306 | + |
| 307 | + |
| 308 | + |
| 309 | + |
| 310 | + |
| 311 | + |
0 commit comments