Implementation

To extend our class QuantumDynamics1D to multiple states, we have to, as described previously, modify the operators for the kinetic and potential energy. Because the wave function is now a 2D array (grid points × states), the way both operators are applied to it also has to change. For the split-operator propagator, the exponential of the potential energy matrix is obtained by diagonalizing that matrix at every grid point.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from matplotlib.animation import FuncAnimation

class QuantumDynamics1D:
    """Grid-based 1D wave packet dynamics on several coupled potentials.

    The wave function ``psi`` is a 2D array of shape ``(n, nstate)`` (grid
    point, state).  The potential ``vgrid`` is a stack of ``nstate x nstate``
    matrices of shape ``(n, nstate, nstate)``, one matrix per grid point.
    Units with hbar = 1 are used: the propagators approximate
    ``exp(-i H dt)``.

    Typical use: ``setup_grid`` and ``set_potential``, then ``propagate``,
    then ``animate``.
    """

    def __init__(self, m):
        """Create a propagator for a particle of mass ``m``."""
        self.mass = m

    def set_potential(self, v):
        """Store the potential matrices ``v`` of shape ``(n, nstate, nstate)``.

        Each ``v[i]`` should be Hermitian: the split-operator propagator
        diagonalizes it with ``np.linalg.eigh``, which assumes this.
        """
        self.vgrid = v

    def setup_grid(self, xmin, xmax, n):
        """Set up ``n`` equidistant points on [xmin, xmax] plus the momentum
        grid used by the FFT-based kinetic energy operator.

        ``n`` is presumably expected to be even.  Only then does the
        alternating-sign trick below shift zero momentum exactly to FFT
        index ``n/2``.
        """
        self.xmin, self.xmax = xmin, xmax
        self.n = n
        self.xgrid, self.dx = np.linspace(self.xmin, self.xmax, self.n,
                                          retstep=True)
        # (-1)^j.  Multiplying by it before the FFT (and again after the
        # inverse FFT) makes FFT index k correspond to momentum self.pk[k].
        self.eo = np.zeros(n)
        self.eo[::2] = 1
        self.eo[1::2] = -1
        # Ascending momentum grid with spacing 2*pi / (n*dx).
        self.pk = (np.arange(n) - n/2)*(2*np.pi/(n*self.dx))
        # Kinetic energy p^2 / (2m) on the momentum grid.
        self.tk = self.pk**2 / (2.0*self.mass)

    def get_v_psi(self, psi):
        """Return V psi, applying the state matrix ``vgrid[i]`` at each grid point."""
        return np.einsum('ijk,ik->ij', self.vgrid, psi)

    def _apply_momentum_factor(self, psi, factor):
        """Multiply every state of ``psi`` by ``factor`` (indexed like ``self.pk``) in momentum space."""
        sign = self.eo[:, None]
        ft = np.fft.fft(psi * sign, axis=0) * factor[:, None]
        return np.fft.ifft(ft, axis=0) * sign

    def get_t_psi(self, psi):
        """Return T psi, the kinetic energy applied to each state via FFT."""
        return self._apply_momentum_factor(psi, self.tk)

    def get_h_psi(self, psi):
        """Return H psi = T psi + V psi."""
        tpsi = self.get_t_psi(psi)
        vpsi = self.get_v_psi(psi)
        return tpsi + vpsi

    def propagate(self, dt, nstep, psi, method="Forward Euler"):
        """Propagate ``psi`` for ``nstep`` steps of size ``dt``.

        Results go to ``self.wavepackets`` with shape
        ``(nstep, n, nstate)``.  ``wavepackets[i]`` is the wave function
        *before* step ``i``, so ``wavepackets[0]`` is the initial ``psi``.

        ``method`` is one of "Forward Euler", "SOD" or "Split Operator".
        For any other value a message is printed and ``wavepackets`` is
        left all zeros.
        """
        psi = np.asarray(psi, dtype=complex)
        self.wavepackets = np.zeros((nstep, psi.shape[0], psi.shape[1]),
                                    dtype=complex)

        if method == "Forward Euler":
            self._propagate_forward_euler(dt, nstep, psi)
        elif method == "SOD":
            self._propagate_sod(dt, nstep, psi)
        elif method == "Split Operator":
            self._propagate_split_operator(dt, nstep, psi)
        else:
            print("Method {} not implemented!".format(method))

    def _propagate_forward_euler(self, dt, nstep, psi):
        """Explicit first-order steps psi <- psi - i dt H psi (not norm-conserving)."""
        for i in range(nstep):
            self.wavepackets[i] = psi
            psi = psi - 1.0j*dt*self.get_h_psi(psi)

    def _propagate_sod(self, dt, nstep, psi):
        """Second-order differencing: psi(t+dt) = psi(t-dt) - 2 i dt H psi(t)."""
        # SOD needs psi(-dt).  Estimate it with a backward midpoint step.
        psi_half = psi + 0.5j*dt*self.get_h_psi(psi)
        psi_old = psi + 1.0j*dt*self.get_h_psi(psi_half)
        for i in range(nstep):
            self.wavepackets[i] = psi
            psi_old, psi = psi, psi_old - 2.0j*dt*self.get_h_psi(psi)

    def _propagate_split_operator(self, dt, nstep, psi):
        """Strang splitting: exp(-iV dt/2) exp(-iT dt) exp(-iV dt/2) per step."""
        # exp(-i V dt/2) at every grid point, built from V = U diag(e) U^H
        # as U diag(exp(-i e dt/2)) U^H.
        eigv, u = np.linalg.eigh(self.vgrid)
        expv = np.einsum('ijk,ik,ilk->ijl', u, np.exp(-0.5j*dt*eigv),
                         np.conjugate(u))
        expt = np.exp(-1.0j*dt*self.tk)

        for i in range(nstep):
            self.wavepackets[i] = psi
            psi = np.einsum('ijk,ik->ij', expv, psi)
            psi = self._apply_momentum_factor(psi, expt)
            psi = np.einsum('ijk,ik->ij', expv, psi)

    def animate(self, nstep=100, scalewf=1.0, delay=50, plot_potential=False,
                colormap='jet', xlimits=(-10,10), ylimits=(-5, 50)):
        """Animate the stored wave packets.

        Each state's density ``scalewf*|psi|^2`` is drawn on top of the
        minimum of its diagonal potential ``vgrid[:, s, s]`` and shaded down
        to that baseline.  At most ``nstep`` frames are shown, and never more
        than ``propagate`` stored.  ``delay`` is the frame interval in ms.

        Returns the ``FuncAnimation``.  Keep a reference to it, because
        Matplotlib needs one for the animation to keep running.
        """
        cmap = plt.get_cmap(colormap)
        nstate = self.vgrid.shape[1]
        colors = cmap(np.linspace(0.0, 1.0, nstate))
        baselines = [np.min(self.vgrid[:, s, s]) for s in range(nstate)]
        nframes = min(nstep, self.wavepackets.shape[0])

        def densities(frame):
            return [base + scalewf*np.abs(self.wavepackets[frame, :, s])**2
                    for s, base in enumerate(baselines)]

        fig, ax = plt.subplots()
        lines = []
        fills = []
        for s, y in enumerate(densities(0)):
            line, = ax.plot(self.xgrid, y, color=colors[s])
            lines.append(line)
            fills.append(ax.fill_between(self.xgrid, baselines[s], y,
                                         linewidth=3.0, color=colors[s],
                                         alpha=0.2))

        ax.set_xlabel('$x$')
        ax.set_xlim(xlimits)
        ax.set_ylim(ylimits)

        if plot_potential:
            for s in range(nstate):
                ax.plot(self.xgrid, self.vgrid[:, s, s], lw=3.0,
                        color=colors[s])

        def anim_func(frame_num):
            # Axes.collections is no longer mutable in place (deprecated in
            # Matplotlib 3.5, later removed), so drop the previous frame's
            # fills one by one.
            for fill in fills:
                fill.remove()
            fills.clear()
            for s, y in enumerate(densities(frame_num)):
                lines[s].set_ydata(y)
                fills.append(ax.fill_between(self.xgrid, baselines[s], y,
                                             linewidth=3.0, color=colors[s],
                                             alpha=0.2))
            return lines + fills

        anim = FuncAnimation(fig, anim_func, frames=nframes, interval=delay,
                             cache_frame_data=False)
        return anim