Implementation

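In the following, we add the second-order differencing (SOD) scheme to our class. SOD advances the wave function with $\psi(t+\Delta t) = \psi(t-\Delta t) - 2i\,\Delta t\,\hat{H}\,\psi(t)$ (in units with $\hbar = 1$). Because this update needs two previous time levels, the first one, $\psi(-\Delta t)$, is generated by a single backward midpoint step.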

class QuantumDynamics1D:
    """Grid-based propagation of a 1D wave packet under H = T + V.

    Units with hbar = 1 are implied: the kinetic energy is p**2 / (2 m).
    The kinetic operator is applied spectrally with FFTs; the potential is
    applied as a pointwise multiplication on the spatial grid.
    """

    def __init__(self, m):
        """Store the particle mass ``m`` used in the kinetic energy."""
        self.mass = m

    def set_potential(self, v):
        """Set the potential ``v`` sampled on the spatial grid.

        ``v`` is multiplied elementwise with the wave function, so it should
        match the grid length (or broadcast to it).
        """
        self.vgrid = v

    def setup_grid(self, xmin, xmax, n):
        """Create an ``n``-point grid on [xmin, xmax] and the matching momenta.

        Sets ``xgrid``/``dx`` (spatial grid), ``eo`` (alternating +1/-1 used
        to centre the FFT spectrum), ``pk`` (momenta aligned with that centred
        spectrum) and ``tk`` (kinetic energies). Uses ``self.mass``, so it
        must be called after ``__init__``.
        """
        self.xmin, self.xmax = xmin, xmax
        self.n = n
        self.xgrid, self.dx = np.linspace(self.xmin, self.xmax, self.n, retstep=True)
        # Multiplying by (-1)**j before the FFT shifts the spectrum by n/2, so
        # FFT index k corresponds to momentum (k - n/2) * 2*pi / (n*dx); the
        # same factor applied after the inverse FFT undoes the shift.
        # The periodic box length is n*dx (one spacing longer than
        # xmax - xmin, since both endpoints are grid points).
        # An even n is assumed; for odd n the momenta are offset by half a
        # momentum spacing from the usual DFT frequencies.
        self.eo = np.zeros(n)
        self.eo[::2] = 1
        self.eo[1::2] = -1
        self.pk = (np.arange(n) - n/2)*(2*np.pi/(n*self.dx))
        self.tk = self.pk**2 / (2.0*self.mass)

    def get_v_psi(self, psi):
        """Return V psi (pointwise product with the grid potential)."""
        return self.vgrid * psi

    def get_t_psi(self, psi):
        """Return T psi, applying p**2/(2m) in momentum space via FFT."""
        psi_eo = psi * self.eo
        ft = np.fft.fft(psi_eo)
        ft = ft * self.tk
        return self.eo * np.fft.ifft(ft)

    def get_h_psi(self, psi):
        """Return H psi = T psi + V psi as a complex array."""
        psi = np.asarray(psi, dtype=complex)
        tpsi = self.get_t_psi(psi)
        vpsi = self.get_v_psi(psi)
        return tpsi + vpsi

    def propagate(self, dt, nstep, psi, method="Forward Euler"):
        """Propagate ``psi`` for ``nstep`` steps of size ``dt``.

        The wave function at times 0, dt, ..., (nstep - 1)*dt is stored row by
        row in ``self.wavepackets`` (shape ``(nstep, len(psi))``). The state
        after the last step is computed but not stored.

        Parameters
        ----------
        dt : float
            Time step.
        nstep : int
            Number of stored snapshots; ``0`` yields an empty history.
        psi : array_like
            Initial wave function on the grid (not modified).
        method : {"Forward Euler", "SOD"}
            ``"Forward Euler"``: psi(t+dt) = psi(t) - i dt H psi(t).
            ``"SOD"`` (second-order differencing):
            psi(t+dt) = psi(t-dt) - 2 i dt H psi(t), with psi(-dt) obtained
            from one backward midpoint step.

        Raises
        ------
        ValueError
            If ``method`` is not one of the supported names.
        """
        if method not in ("Forward Euler", "SOD"):
            raise ValueError(
                f"unknown propagation method {method!r}; "
                "expected 'Forward Euler' or 'SOD'"
            )
        psi = np.asarray(psi, dtype=complex)
        self.wavepackets = np.zeros((nstep, psi.shape[0]), dtype=complex)

        if method == "Forward Euler":
            for i in range(nstep):
                self.wavepackets[i, :] = psi
                psi = psi - 1.0j*dt*self.get_h_psi(psi)
        else:  # "SOD"
            # SOD needs two previous time levels: estimate psi(-dt) with a
            # backward midpoint (second-order Runge-Kutta) step from psi(0).
            psi_half = psi + 0.5j*dt*self.get_h_psi(psi)
            psi_old = psi + 1.0j*dt*self.get_h_psi(psi_half)
            for i in range(nstep):
                self.wavepackets[i, :] = psi
                psi_new = psi_old - 2.0j*dt*self.get_h_psi(psi)
                psi_old, psi = psi, psi_new

    def animate(self, nstep=100, scalewf=1.0, delay=50):
        """Animate the stored probability density ``scalewf * |psi|**2``.

        Parameters
        ----------
        nstep : int
            Number of frames; clamped to the number of stored wave packets.
        scalewf : float
            Scale factor applied to the plotted density.
        delay : int
            Delay between frames in milliseconds (FuncAnimation ``interval``).

        Returns
        -------
        FuncAnimation
            Keep a reference to it, otherwise the animation may be
            garbage-collected before it is shown.
        """
        nframes = min(nstep, self.wavepackets.shape[0])

        fig, ax = plt.subplots()
        density0 = scalewf*np.abs(self.wavepackets[0, :])**2
        line, = ax.plot(self.xgrid, density0)
        fill = ax.fill_between(
            self.xgrid, density0,
            linewidth=3.0, color="blue", alpha=0.2,
        )
        ax.set_xlabel('$x$')

        def anim_func(frame_num):
            nonlocal fill
            density = scalewf*np.abs(self.wavepackets[frame_num, :])**2
            line.set_ydata(density)
            # fill_between artists cannot be updated in place; remove the old
            # one explicitly (Axes.collections can no longer be cleared
            # directly in recent Matplotlib) and draw a new one.
            fill.remove()
            fill = ax.fill_between(
                self.xgrid, density,
                linewidth=3.0, color="blue", alpha=0.2,
            )
            return line, fill

        anim = FuncAnimation(fig, anim_func, frames=nframes, interval=delay,
                             cache_frame_data=False)
        return anim