6
\$\begingroup\$

I'd like to plot an animation of Lissajous curves using Python and matplotlib's animate library. I really do not have a lot of experience with Python, so rather than performance increasements, I'm looking for best practices to improve (and/or shorten) my code.

The following code produces a .gif file when evaluated as a jupyter-lab cell:

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.animation import PillowWriter

x_data = []
y_data = []
max_range = 1.2

f1 = 3        # sets the frequency for the horizontal motion
f2 = 5        # sets the frequency for the vertical motion
d1 = 0.0      # sets the phase shift for the horizontal motion
d2 = 0.5      # sets the phase shift for the vertical motion
delta1 = d1 * np.pi # I define the phase shift like this in order to use 
delta2 = d2 * np.pi # ...d1 and d2 in the export file name

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(4,4), gridspec_kw={'width_ratios': [6, 1], 'height_ratios': [1, 6]})

for i in [ax1, ax2, ax3, ax4]: 
    i.set_yticklabels([])         # in order to remove the ticks and tick labels
    i.set_xticklabels([])
    i.set_xticks([])
    i.set_yticks([])
    i.set_xlim(-max_range, max_range)
    i.set_ylim(-max_range, max_range)
ax2.set_visible(False)

line, = ax3.plot(0, 0)        # line plot in the lower left
line2 = ax3.scatter(0, 0)     # moving dot in the lower left
linex = ax1.scatter(0, 0)     # moving dot on top
liney = ax4.scatter(0, 0)     # moving dot on the right

def animation_frame(i):
    ax1.clear()              # I tried to put this in a loop like this: 
    ax1.set_yticklabels([])  # for i in [ax1,ax3,ax4]: 
    ax1.set_xticklabels([])  #     i.clear() 
    ax1.set_xticks([])       #     (... etc.)
    ax1.set_yticks([])       # but this didn't work
    ax3.clear()
    ax3.set_yticklabels([])
    ax3.set_xticklabels([])
    ax3.set_xticks([])
    ax3.set_yticks([])
    ax4.clear()
    ax4.set_yticklabels([])
    ax4.set_xticklabels([])
    ax4.set_xticks([])
    ax4.set_yticks([])
    ax1.set_xlim(-max_range, max_range)     # after ax.clear() I apparently have to re-set these
    ax3.set_xlim(-max_range, max_range)
    ax3.set_ylim(-max_range, max_range)
    ax4.set_ylim(-max_range, max_range)
    x_data.append(np.sin(i * f1 + delta1)) # for the line plot
    y_data.append(np.sin(i * f2 + delta2))
    x_inst = np.sin(i * f1 + delta1)       # for the scatter plot
    y_inst = np.sin(i * f2 + delta2)
    line, = ax3.plot(x_data, y_data)
    line2 = ax3.scatter(x_inst, y_inst)
    linex = ax1.scatter(x_inst, 0)
    liney = ax4.scatter(0, y_inst)
    fig.canvas.draw()
    transFigure = fig.transFigure.inverted()       # in order to draw over 2 subplots
    coord1 = transFigure.transform(ax1.transData.transform([x_inst, 0]))
    coord2 = transFigure.transform(ax3.transData.transform([x_inst, y_inst]))
    my_line1 = matplotlib.lines.Line2D((coord1[0],coord2[0]),(coord1[1],coord2[1]), transform=fig.transFigure, linewidth=1, c='gray', alpha=0.5)
    coord1 = transFigure.transform(ax3.transData.transform([x_inst, y_inst]))
    coord2 = transFigure.transform(ax4.transData.transform([0, y_inst]))
    my_line2 = matplotlib.lines.Line2D((coord1[0],coord2[0]),(coord1[1],coord2[1]), transform=fig.transFigure, linewidth=1, c='gray', alpha=0.5)
    fig.lines = my_line1, my_line2,         # moving vertical and horizontal lines
    return line, line2, linex, liney

animation = FuncAnimation(fig, func=animation_frame, frames=np.linspace(0, 4*np.pi, num=800, endpoint=True), interval=1000)
animation.save('lissajous_{0}_{1}_{2:.2g}_{3:.2g}.gif'.format(f1,f2,d1,d2), writer='pillow', fps=50, dpi=200)
# This takes quite long, but since I'd like to have a smooth, slow animation, 
# ...I'm willing to accept a longer execution time. 

\$\endgroup\$

1 Answer 1

3
\$\begingroup\$

For demonstration purposes I disregard the .gif generation, and especially disregard Jupyter. Notebooks tend to produce a scope swamp. System parameters? Throw them in the swamp. Loops, matplotlib artists, initialization code? In the swamp.

plt.subplots() is a convenience wrapper, and not a helpful one here. It produces a useless ax2 that you then need to hide. Instead, just use a GridSpec so that you can pick the locations you want while still controlling ratios.

animation_frame recreates your artists from scratch every time! All of those calls to plot() and scatter() need to be removed, and instead you need to perform data updates.

You're missing my_line1, my_line2 from your tuple of returned updated artists.

Your axis and artist names are very difficult to understand; those need better names.

A one-second (1000 ms) update rate is quite slow; for the interactive plot let's speed it up a lot.

All together,

import typing

import numpy as np
import matplotlib
import matplotlib.animation
import matplotlib.pyplot as plt


class Lissajous(typing.NamedTuple):
    fig: plt.Figure
    ax_top: plt.Axes
    ax_mid: plt.Axes
    ax_right: plt.Axes
    mid_curve: matplotlib.lines.Line2D
    mid_scatter: matplotlib.collections.PathCollection
    x_scatter: matplotlib.collections.PathCollection
    y_scatter: matplotlib.collections.PathCollection
    vert_line: matplotlib.lines.Line2D
    horz_line: matplotlib.lines.Line2D

    max_range: float
    f1: float  # frequency for the horizontal motion
    f2: float  # frequency for the vertical motion
    d1: float  # phase shift for the horizontal motion
    d2: float  # phase shift for the vertical motion
    delta1: float  # horizontal phase (radians)
    delta2: float  # vertical phase (radians)
    x_data: list[float] = []
    y_data: list[float] = []

    @classmethod
    def new(
        cls,
        max_range: float = 1.2,
        f1: float = 3.,
        f2: float = 5.,
        d1: float = 0.,
        d2: float = 0.5,
    ) -> typing.Self:
        fig = plt.figure()
        grid = matplotlib.gridspec.GridSpec(figure=fig, nrows=2, ncols=2, width_ratios=(6, 1), height_ratios=(1, 6))
        ax_top = fig.add_subplot(grid[0,0])
        ax_mid = fig.add_subplot(grid[1,0])
        ax_right = fig.add_subplot(grid[1,1])

        for ax in (ax_top, ax_mid, ax_right):
            ax.set_yticklabels(())
            ax.set_xticklabels(())
            ax.set_xticks(())
            ax.set_yticks(())
            ax.set_xlim(-max_range, max_range)
            ax.set_ylim(-max_range, max_range)

        vert_line = matplotlib.lines.Line2D((),(), transform=fig.transFigure, linewidth=1, c='gray', alpha=0.5)
        horz_line = matplotlib.lines.Line2D((),(), transform=fig.transFigure, linewidth=1, c='gray', alpha=0.5)
        fig.add_artist(vert_line)
        fig.add_artist(horz_line)

        return cls(
            fig=fig, ax_top=ax_top, ax_mid=ax_mid, ax_right=ax_right,
            mid_curve=ax_mid.plot((), ())[0],
            mid_scatter=ax_mid.scatter((), ()),
            x_scatter=ax_top.scatter((), ()),
            y_scatter=ax_right.scatter((), ()),
            vert_line=vert_line, horz_line=horz_line, max_range=max_range,
            f1=f1, f2=f2, d1=d1, d2=d2, delta1=d1*np.pi, delta2=d2*np.pi,
        )

    def make_animation(self, n_frames: int = 800, interval: int = 1_000) -> matplotlib.animation.FuncAnimation:
        return matplotlib.animation.FuncAnimation(
            fig=self.fig, func=self.update, interval=interval,
            frames=np.linspace(start=0, stop=4*np.pi, num=n_frames, endpoint=True),
        )

    def update(self, t: float) -> tuple[matplotlib.artist.Artist, ...]:
        x_inst = np.sin(t*self.f1 + self.delta1)
        y_inst = np.sin(t*self.f2 + self.delta2)
        self.x_data.append(x_inst)
        self.y_data.append(y_inst)

        self.mid_curve.set_xdata(self.x_data)
        self.mid_curve.set_ydata(self.y_data)
        self.mid_scatter.set_offsets((x_inst, y_inst))
        self.x_scatter.set_offsets((x_inst, 0))
        self.y_scatter.set_offsets((0, y_inst))

        # This cannot be cached; it becomes invalid on resize
        trans_figure = self.fig.transFigure.inverted()

        ax, ay = trans_figure.transform(self.ax_top.transData.transform((x_inst, 0)))
        bx, by = trans_figure.transform(self.ax_mid.transData.transform((x_inst, y_inst)))
        self.vert_line.set_xdata((ax, bx))
        self.vert_line.set_ydata((ay, by))

        ax, ay = bx, by
        bx, by = trans_figure.transform(self.ax_right.transData.transform((0, y_inst)))
        self.horz_line.set_xdata((ax, bx))
        self.horz_line.set_ydata((ay, by))

        return self.mid_curve, self.mid_scatter, self.x_scatter, self.y_scatter, self.vert_line, self.horz_line


def main() -> None:
    lissajous = Lissajous.new()
    anim = lissajous.make_animation(interval=30)
    plt.show()


if __name__ == '__main__':
    main()

lissajous plot

\$\endgroup\$

You must log in to answer this question.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.