Implementation

To incorporate light into our code, we add the methods set_dipole_coupling and set_efield, which set the dipole coupling \(\mu\) and the electric field \(E(t)\), respectively. We also add get_v_ext, which returns the total potential including the light–matter interaction, \(V(t) = V + \mu E(t)\).

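Since the potential energy changes with time, the potential energy propagator \(\exp(-\mathrm{i} \hat{V} \Delta t / 2)\) also becomes time-dependent. Therefore, we have to compute it anew at every time step: in the symmetric split-operator step, the first potential half-step uses \(\hat{V}(t)\) and the second uses \(\hat{V}(t + \Delta t)\).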

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from matplotlib.animation import FuncAnimation

class QuantumDynamics1D:
    """Grid-based wave-packet dynamics in one dimension on coupled states.

    Typical use: construct with a mass, call ``setup_grid``, provide the
    potential with ``set_potential`` (and, for light-driven dynamics,
    ``set_dipole_coupling`` and ``set_efield``), then call ``propagate`` and
    optionally ``animate``.

    Array layout (fixed by the einsum calls below): the potential grid is
    indexed ``[grid point, state, state]`` and wave functions are indexed
    ``[grid point, state]``. The kinetic energy is ``p**2 / (2*m)`` and the
    propagators are ``exp(-1j*dt*H)``, i.e. hbar = 1.
    """

    def __init__(self, m):
        """Store the particle mass ``m`` used in the kinetic energy."""
        self.mass = m

    def set_potential(self, v):
        """Set the static potential grid ``v`` ([point, state, state])."""
        self.vgrid = v

    def set_dipole_coupling(self, mu):
        """Set the dipole coupling ``mu``; it is combined with E(t) in ``get_v_ext``."""
        self.mu = mu

    def set_efield(self, e):
        """Set the electric field; ``e`` must be callable as ``e(t)``."""
        self.efield = e

    def setup_grid(self, xmin, xmax, n):
        """Create an ``n``-point grid on [xmin, xmax] plus its momentum-space data.

        ``eo`` holds alternating +1/-1 signs. Multiplying by it before the FFT
        shifts the spectrum (for even ``n``) so that FFT index k matches the
        momentum ``pk[k]``, which runs from -n/2 to n/2 - 1 in units of
        2*pi/(n*dx). ``tk`` is the kinetic energy at each of those momenta.
        """
        self.xmin, self.xmax = xmin, xmax
        self.n = n
        self.xgrid, self.dx = np.linspace(self.xmin, self.xmax, self.n, retstep=True)
        self.eo = np.zeros(n)
        self.eo[::2] = 1
        self.eo[1::2] = -1
        self.pk = (np.arange(n) - n/2)*(2*np.pi/(n*self.dx))
        self.tk = self.pk**2 / (2.0*self.mass)

    def get_v_psi(self, psi):
        """Apply the static potential matrix at every grid point to ``psi``."""
        return np.einsum('ijk,ik->ij', self.vgrid, psi)

    def get_v_ext(self, t):
        """Return the total potential ``vgrid + mu*E(t)`` at time ``t``.

        NOTE(review): assumes ``mu`` broadcasts against ``vgrid`` (presumably
        the same [point, state, state] shape) and ``efield(t)`` is a scalar.
        """
        v_ext_grid = self.vgrid + self.mu*self.efield(t)
        return v_ext_grid

    def _apply_in_momentum_space(self, factor, psi):
        """Multiply ``psi`` by the diagonal momentum-space ``factor`` (FFT round trip)."""
        psi = np.einsum('ij,i->ij', psi, self.eo)
        psift = np.fft.fft(psi, axis=0)
        psift = np.einsum('ij,i->ij', psift, factor)
        return np.einsum('ij,i->ij', np.fft.ifft(psift, axis=0), self.eo)

    @staticmethod
    def _potential_half_step_factors(v, dt):
        """Diagonalize ``v`` per grid point; return ``(u, exp(-1j*dt/2*eigv), u^H)``."""
        eigv, u = np.linalg.eigh(v)
        expv = np.exp(-0.5j*dt*eigv)
        uh = np.conjugate(np.transpose(u, axes=(0, 2, 1)))
        return u, expv, uh

    @staticmethod
    def _apply_potential_half_step(u, expv, uh, psi):
        """Apply ``u @ diag(expv) @ u^H`` to ``psi`` at every grid point."""
        psi = np.einsum('ijk,ik->ij', uh, psi)
        psi = np.einsum('ij,ij->ij', expv, psi)
        return np.einsum('ijk,ik->ij', u, psi)

    def get_t_psi(self, psi):
        """Apply the kinetic energy operator to ``psi`` via FFT."""
        return self._apply_in_momentum_space(self.tk, psi)

    def get_h_psi(self, psi):
        """Apply the static Hamiltonian (kinetic + ``vgrid``) to ``psi``."""
        tpsi = self.get_t_psi(psi)
        vpsi = self.get_v_psi(psi)
        return tpsi + vpsi

    def propagate(self, dt, nstep, psi, method="Forward Euler"):
        """Propagate ``psi`` for ``nstep`` steps of size ``dt``.

        The wave function *before* each step is stored. After the call,
        ``self.wavepackets[i]`` holds psi at time ``i*dt`` and has shape
        ``(nstep, psi.shape[0], psi.shape[1])``. The state after the last
        step is not stored.

        Methods:
          "Forward Euler"      explicit first-order Euler step.
          "SOD"                second-order differencing; the needed psi(-dt)
                               is bootstrapped with a backward half-step and
                               a midpoint step.
          "Split Operator"     symmetric splitting
                               exp(-iV dt/2) exp(-iT dt) exp(-iV dt/2)
                               with the static ``vgrid``.
          "TD Split Operator"  as above, but the first potential half-step
                               uses V(t) and the second V(t + dt), where
                               V(t) comes from ``get_v_ext``.

        Only "TD Split Operator" includes the external field. An unknown
        ``method`` prints a message and leaves ``self.wavepackets`` all zero.
        """
        psi = np.asarray(psi, dtype=complex)
        self.wavepackets = np.zeros((nstep, psi.shape[0], psi.shape[1]), dtype=complex)

        if method == "Forward Euler":
            for i in range(nstep):
                self.wavepackets[i] = psi
                psi = psi - 1.0j*dt*self.get_h_psi(psi)

        elif method == "SOD":
            # Bootstrap psi(-dt): Euler half-step backward, then a midpoint step.
            psi_half = psi + 0.5j*dt*self.get_h_psi(psi)
            psi_old = psi + 1.0j*dt*self.get_h_psi(psi_half)
            for i in range(nstep):
                self.wavepackets[i] = psi
                psi_new = psi_old - 2.0j*dt*self.get_h_psi(psi)
                psi_old = psi
                psi = psi_new

        elif method == "Split Operator":
            expt = np.exp(-1.0j*dt*self.tk)
            factors = self._potential_half_step_factors(self.vgrid, dt)
            for i in range(nstep):
                self.wavepackets[i] = psi
                psi = self._apply_potential_half_step(*factors, psi)
                psi = self._apply_in_momentum_space(expt, psi)
                psi = self._apply_potential_half_step(*factors, psi)

        elif method == "TD Split Operator":
            expt = np.exp(-1.0j*dt*self.tk)
            # The V(t + dt) factors that close step i are exactly the V(t)
            # factors that open step i + 1, so each time point is diagonalized
            # once. Built lazily so nstep == 0 does no work.
            start_factors = None
            for i in range(nstep):
                self.wavepackets[i] = psi
                if start_factors is None:
                    start_factors = self._potential_half_step_factors(
                        self.get_v_ext(i*dt), dt)
                end_factors = self._potential_half_step_factors(
                    self.get_v_ext((i + 1)*dt), dt)

                psi = self._apply_potential_half_step(*start_factors, psi)
                psi = self._apply_in_momentum_space(expt, psi)
                psi = self._apply_potential_half_step(*end_factors, psi)

                start_factors = end_factors

        else:
            print("Method {} not implemented!".format(method))

    def animate(self, nstep=100, scalewf=1.0, delay=50, plot_potential=False,
                colormap='jet', xlimits=(-10,10), ylimits=(-5, 50)):
        """Animate the stored densities ``scalewf*|psi|**2`` of every state.

        Each state's density is drawn on top of the minimum of its diagonal
        potential ``vgrid[:, s, s]``. At most ``nstep`` frames are shown,
        capped at the number of stored wave packets. ``delay`` is the frame
        interval passed to FuncAnimation. Returns the FuncAnimation object.
        """
        cmap = get_cmap(colormap)
        nstate = self.vgrid.shape[1]
        colors = cmap(np.linspace(0.0, 1.0, nstate))
        # Never index past the stored trajectory.
        nframes = min(nstep, self.wavepackets.shape[0])
        # Baseline of each state's density: the minimum of its diagonal potential.
        offsets = [np.min(self.vgrid[:, state, state]) for state in range(nstate)]

        def density(frame, state):
            return offsets[state] + scalewf*np.abs(self.wavepackets[frame, :, state])**2

        fig, ax = plt.subplots()
        lines = []
        for state in range(nstate):
            y = density(0, state)
            line, = ax.plot(self.xgrid, y, color=colors[state])
            ax.fill_between(self.xgrid, y, linewidth=3.0,
                            color=colors[state], alpha=0.2)
            lines.append(line)

        ax.set_xlabel('$x$')
        ax.set_xlim(xlimits)
        ax.set_ylim(ylimits)

        if plot_potential:
            for state in range(nstate):
                ax.plot(self.xgrid, self.vgrid[:, state, state], lw=3.0,
                        color=colors[state])

        def anim_func(frame_num):
            # ax.collections is a read-only ArtistList in Matplotlib >= 3.7,
            # so the previous frame's fills are removed one by one.
            for coll in list(ax.collections):
                coll.remove()
            for state in range(nstate):
                y = density(frame_num, state)
                lines[state].set_ydata(y)
                ax.fill_between(self.xgrid, y, linewidth=3.0, color=colors[state],
                                alpha=0.2)

        anim = FuncAnimation(fig, anim_func, frames=nframes, interval=delay,
                             cache_frame_data=False)
        return anim