Source code for pysit.vis.vis

import itertools
import time

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import animation

__all__ = ['animate', 'plot']

[docs]def animate(data, mesh, display_rate=30,pause=1, scale=None, show=True, **kwargs): if mesh.dim == 1: fig = plt.figure() plt.clf() line = plot(data[0], mesh, **kwargs) title = plt.title('{0:4}/{1}'.format(0,len(data)-1)) if scale == 'fixed': pltscale = 1.1*min([d.min() for d in data]), 1.1*max([d.max() for d in data]) else: pltscale = scale def _animate(i, line, data, mesh, title, pltscale): line = plot(data[i], mesh, update=line, **kwargs) if pltscale is None: plt.gca().set_ylim(data[i].min(),data[i].max()) else: plt.gca().set_ylim(*pltscale) title.set_text('{0:4}/{1}'.format(i,len(data)-1)) plt.draw() # return line, _animate_args = (line, data, mesh, title, pltscale) if mesh.dim == 2: fig = plt.figure() plt.clf() im = plot(data[0], mesh, **kwargs) cbar = plt.colorbar() title = plt.title('{0:4}/{1}'.format(0,len(data)-1)) if scale == 'fixed': pltclim = 1.1*min([d.min() for d in data]), 1.1*max([d.max() for d in data]) else: pltclim = None def _animate(i, im, data, mesh, title, cbar, pltclim): im = plot(data[i], mesh, update=im, **kwargs) if pltclim is None: clim = (data[i].min(),data[i].max()) im.set_clim(clim) cbar.set_clim(clim) else: im.set_clim(pltclim) cbar.set_clim(pltclim) title.set_text('{0:4}/{1}'.format(i,len(data)-1)) plt.draw() # return im, _animate_args = (im, data, mesh, title, cbar, pltclim) if mesh.dim == 3: fig = plt.figure() plt.clf() ims = plot(data[0], mesh, **kwargs) cax=plt.subplot(2,2,4) cbar = plt.colorbar(cax=cax) title = plt.title('{0:4}/{1}'.format(0,len(data)-1)) if scale == 'fixed': pltclim = 1.1*min([d.min() for d in data]), 1.1*max([d.max() for d in data]) else: pltclim = None def _animate(i, ims, data, mesh, title, cbar, pltclim): ims = plot(data[i], mesh, update=ims, **kwargs) if pltclim is None: clim = (data[i].min(),data[i].max()) for im in ims: im.set_clim(clim) cbar.set_clim(clim) else: for im in ims: im.set_clim(pltclim) cbar.set_clim(pltclim) # plt.sca(cax) title.set_text('{0:4}/{1}'.format(i,len(data)-1)) plt.draw() # return ims, _animate_args = (ims, data, mesh, title, cbar, pltclim) anim = animation.FuncAnimation(fig, _animate, fargs=_animate_args, #init_func=init, interval=pause, frames=range(0,len(data),display_rate), blit=False, repeat=False) if show: plt.show() return anim
[docs]def plot(data, mesh, shade_pml=False, axis=None, ticks=True, update=None, slice3d=(0,0,0), **kwargs): """ Assumes that data has no ghost padding.""" data.shape = -1,1 sh_bc = mesh.shape(include_bc=True) sh_primary = mesh.shape() if data.shape == sh_bc: has_bc = True plot_shape = mesh.shape(include_bc=True, as_grid=True) elif data.shape == sh_primary: has_bc = False plot_shape = mesh.shape(as_grid=True) else: raise ValueError('Shape mismatch between domain and data.') if axis is None: ax = plt.gca() else: ax = axis data = data.reshape(plot_shape) if mesh.dim == 1: if update is None: zs, = mesh.mesh_coords() ret, = ax.plot(zs,data) if has_bc: ax.axvline(mesh.domain.parameters['z']['lbound'], color='r') ax.axvline(mesh.domain.parameters['z']['rbound'], color='r') else: update.set_ydata(data) ret = update if mesh.dim == 2: data = data.T if update is None: im = ax.imshow(data, interpolation='nearest', aspect='auto', **kwargs) if ticks: mesh_tickers(mesh) else: ax.xaxis.set_ticks([]) ax.yaxis.set_ticks([]) if has_bc: draw_pml(mesh, 'x', 'LR', shade_pml=shade_pml) draw_pml(mesh, 'z', 'TB', shade_pml=shade_pml) else: update.set_data(data) im = update # Update current image for the colorbar plt.sci(im) ret = im if mesh.dim == 3: if update is None: # X-Y plot ax = plt.subplot(2,2,1) imslice = int(slice3d[2]) # z slice pltdata = data[:,:,imslice:(imslice+1)].squeeze() imxy = ax.imshow(pltdata, interpolation='nearest', aspect='equal', **kwargs) if ticks: mesh_tickers(mesh, ('y', 'x')) else: ax.xaxis.set_ticks([]) ax.yaxis.set_ticks([]) if has_bc: draw_pml(mesh, 'y', 'LR', shade_pml=shade_pml) draw_pml(mesh, 'x', 'TB', shade_pml=shade_pml) # X-Z plot ax = plt.subplot(2,2,2) imslice = int(slice3d[1]) # y slice pltdata = data[:,imslice:(imslice+1),:].squeeze().T imxz = ax.imshow(pltdata, interpolation='nearest', aspect='equal', **kwargs) if ticks: mesh_tickers(mesh, ('x', 'z')) else: ax.xaxis.set_ticks([]) ax.yaxis.set_ticks([]) if has_bc: draw_pml(mesh, 'x', 'LR', shade_pml=shade_pml) draw_pml(mesh, 'z', 'TB', shade_pml=shade_pml) # Y-Z plot ax = plt.subplot(2,2,3) imslice = int(slice3d[0]) # x slice pltdata = data[imslice:(imslice+1),:,:].squeeze().T imyz = ax.imshow(pltdata, interpolation='nearest', aspect='equal', **kwargs) if ticks: mesh_tickers(mesh, ('y', 'z')) else: ax.xaxis.set_ticks([]) ax.yaxis.set_ticks([]) if has_bc: draw_pml(mesh, 'y', 'LR', shade_pml=shade_pml) draw_pml(mesh, 'z', 'TB', shade_pml=shade_pml) update = [imxy, imxz, imyz] plt.sci(imyz) else: imslice = int(slice3d[2]) # z slice pltdata = data[:,:,imslice:(imslice+1)].squeeze() update[0].set_data(pltdata) imslice = int(slice3d[1]) # y slice pltdata = data[:,imslice:(imslice+1),:].squeeze().T update[1].set_data(pltdata) imslice = int(slice3d[0]) # x slice pltdata = data[imslice:(imslice+1),:,:].squeeze().T update[2].set_data(pltdata) ret = update return ret
def mesh_tickers(mesh, dims=('x','z')): ax = plt.gca() xformatter = mpl.ticker.FuncFormatter(MeshFormatterHelper(mesh, dims[0])) ax.xaxis.set_major_formatter(xformatter) zformatter = mpl.ticker.FuncFormatter(MeshFormatterHelper(mesh, dims[1])) ax.yaxis.set_major_formatter(zformatter) def draw_pml(mesh, axis, orientation='TB', shade_pml=False): ax = plt.gca() dim = mesh.parameters[axis] S = 0 #start left L = 0 #end left if dim.lbc.type is 'pml': L = dim.lbc.n #end left R = dim.n #start right N = dim.n #end right if dim.rbc.type is 'pml': N += dim.rbc.n #end right if orientation == 'LR': if shade_pml: ax.axvspan(S,L, hatch='\\', fill=False, color='r') ax.axvspan(R,N, hatch='\\', fill=False, color='r') else: ax.axvline(L, color='r') ax.axvline(R, color='r') plt.xlim(0,dim['n']-1) if orientation == 'TB': if shade_pml: ax.axhspan(S,L, hatch='/', fill=False, color='r') ax.axhspan(R,N, hatch='/', fill=False, color='r') else: ax.axhline(L, color='r') ax.axhline(R, color='r') plt.ylim(dim['n']-1,0) class MeshFormatterHelper(object): def __init__(self, mesh, axis): self.lbound = mesh.domain.parameters[axis]['lbound'] self.delta = mesh.parameters[axis]['delta'] def __call__(self, grid_point, pos): return '{0:.3}'.format(self.lbound + self.delta*grid_point) if __name__ == '__main__': from pysit import Domain, PML pmlx = PML(0.1, 100) pmlz = PML(0.1, 100) x_config = (0.1, 1.0, 90, pmlx, pmlx) z_config = (0.1, 0.8, 70, pmlz, pmlz) d = Domain( (x_config, z_config) ) # Generate true wave speed # (M = C^-2 - C0^-2) M, C0, C = horizontal_reflector(d) XX, ZZ = d.generate_grid() display_on_grid(XX, domain=d) display_on_grid(XX, domain=d, shade_pml=True)