
I'm using the following code to produce an animation with matplotlib that is intended to visualize my experiments.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import ArtistAnimation, PillowWriter

plt.rcParams['animation.html'] = 'jshtml'

def make_grid(X, description=None, labels=None, title_fmt="label: {}", cmap='gray', ncols=3, colors=None):
    L = len(X)
    nrows = -(-L // ncols)
    frame_plot = []
    for i in range(L):
        plt.subplot(nrows, ncols, i + 1)
        im = plt.imshow(X[i].squeeze(), cmap=cmap, interpolation='none')
        if labels is not None:
            color = 'k' if colors is None else colors[i]
            plt.title(title_fmt.format(labels[i]), color=color)
        plt.xticks([])
        plt.yticks([])
        frame_plot.append(im)
    return frame_plot


def animate_step(X):
    return X ** 2

n_splots = 6
X = np.random.random((n_splots,32,32,3))

Y = X
X_t = []

for i in range(10):
    Y = animate_step(Y)
    X_t.append((Y, i))

frames = []
for X, step in X_t:
    frame = make_grid(X,
                    description="step={}".format(step),
                    labels=range(n_splots),
                    title_fmt="target: {}")
    frames.append(frame)

anim = ArtistAnimation(plt.gcf(), frames,
                        interval=300, repeat_delay=8000, blit=True)
plt.close()
anim.save("test.gif", writer=PillowWriter())
anim

The result can be seen here: https://i.stack.imgur.com/OaOsf.gif

It works fine so far, but I'm having trouble adding a shared xlabel that describes all 6 subplots in the animation. It is supposed to show which step the image is on, e.g. "step=5". Since it is an animation, I cannot use xlabel or set_title (the label would stay constant over the whole animation), so I have to draw the text myself. I've tried something along the lines of:

def make_grid(X, description=None, labels=None, title_fmt="label: {}", cmap='gray', ncols=3, colors=None):
    L = len(X)
    nrows = -(-L // ncols)
    frame_plot = []
    desc = plt.text(0.5, .04, description,
                    size=plt.rcParams["axes.titlesize"],
                    ha="center",
                    transform=plt.gca().transAxes
                    )
    frame_plot.append(desc)
...

This, of course, won't work, because the axes have not been created yet. I tried drawing on the axes of another subplot, subplot(nrows, 1, nrows), but then it is drawn over the existing images.

Does anyone have a solution to this?

Edit:

Unclean, hacky solution for now: wait until the axes of the middle image in the last row have been created and use them to draw the text. In the for loop:

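...
        # assumes the loop keeps the axes: ax = plt.subplot(nrows, ncols, i + 1)
        # int((nrows - 0.5) * ncols) is (roughly) the middle subplot of the last row
        if i == int((nrows - 0.5) * ncols):
            title = ax.text(0.25, -.3, description,
                            size=plt.rcParams["axes.titlesize"],
                            # ha="center",
                            transform=ax.transAxes
                            )
            frame_plot.append(title)
...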
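The edit works because a frame passed to ArtistAnimation can hold any artist, not only images. A variant that does not depend on picking a particular subplot is to create the per-frame label with fig.text, which is positioned in figure coordinates and therefore needs no axes at all. This is a minimal sketch, assuming a label at the bottom center of the figure is acceptable; the position (0.5, 0.02) and blit=False are my own choices, not taken from the question:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, only so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import ArtistAnimation

n_splots, ncols, n_steps = 6, 3, 10
nrows = -(-n_splots // ncols)  # ceiling division, as in the question

fig, axes = plt.subplots(nrows, ncols)
for ax in axes.ravel():
    ax.set_xticks([])
    ax.set_yticks([])

data = np.random.random((n_splots, 32, 32, 3))
frames = []
for step in range(n_steps):
    data = data ** 2  # same toy update as animate_step in the question
    # one image artist per subplot for this frame
    frame = [ax.imshow(img, interpolation="none")
             for ax, img in zip(axes.ravel(), data)]
    # figure-level text: placed in figure coordinates (0..1), needs no axes
    frame.append(fig.text(0.5, 0.02, "step={}".format(step), ha="center"))
    frames.append(frame)

# blit=False is an assumption here: a figure-level Text has no parent axes,
# and blitting redraws artists through their axes
anim = ArtistAnimation(fig, frames, interval=300, repeat_delay=8000, blit=False)
```

Saving works as in the question, e.g. anim.save("test.gif", writer=PillowWriter()).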

1 Answer


To me, your case is easier to solve with FuncAnimation than with ArtistAnimation, even though you already have the full list of data you want to animate (see this thread for a discussion of the difference between the two functions).
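To make the difference concrete: ArtistAnimation replays a fixed list of frames, each a list of artists created in advance, whereas FuncAnimation calls a function once per frame, and that function can modify artists that were created only once. Here is a minimal sketch of the latter on a single subplot; the toy frames_data and the manual update(3) call are illustrative assumptions, not part of the answer's code:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, only so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

# hypothetical toy data: 5 grayscale frames standing in for the snapshots
frames_data = [np.random.rand(8, 8) ** (k + 1) for k in range(5)]

fig, ax = plt.subplots()
# 1) create every animated artist exactly once
#    (vmin/vmax pinned so every frame uses the same gray scale)
img = ax.imshow(frames_data[0], cmap="gray", vmin=0, vmax=1, interpolation="none")
txt = fig.text(0.5, 0.95, "step: 0", ha="center")


# 2) on each frame, only mutate the existing artists
def update(i):
    img.set_data(frames_data[i])
    txt.set_text("step: {}".format(i))
    return img, txt  # the returned artists are only used when blit=True


anim = FuncAnimation(fig, update, frames=len(frames_data), interval=300)

# FuncAnimation calls update(i) for i = 0 .. len(frames_data) - 1;
# calling it by hand shows the state that a single frame leaves behind
update(3)
```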

Inspired by this FuncAnimation example, I wrote the code below, which does what you need (the same code with ArtistAnimation and the corresponding list of arguments does not work).

The main idea is to initialize all elements to be animated at the beginning and then update them on every animation frame. This is done both for the text object in charge of displaying the current step (step_txt = fig.text(...)) and for the images returned by ax.imshow. The same recipe works for any other object you want to animate.

Note that the technique also works if you want the text to be an xlabel, or any other text you choose to show; see the commented-out step_txt = axes[4].set_xlabel(...) line in the code.
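The question asks specifically for a label shared by all subplots. Assuming Matplotlib 3.4 or newer, Figure.supxlabel is designed for that, and since it also returns a Text artist, the same create-once, update-per-frame idea applies. This is a minimal sketch of just the label part, with the image updates left out:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, only so the sketch runs anywhere
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 3)

# Figure.supxlabel (Matplotlib >= 3.4) draws one label centered below all
# subplots and returns its Text artist, so it can be created once...
step_txt = fig.supxlabel("step: 0")


def animate(i):
    # ...and then updated in place on every frame
    step_txt.set_text("step: {}".format(i))
    # (update the images here, as in the answer's animate function)


animate(5)  # what FuncAnimation would do when it reaches frame 5
```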

#!/usr/bin/env python
import numpy as np
import matplotlib.pyplot as plt

from matplotlib.animation import FuncAnimation, PillowWriter

# parameters
n_frames = 10
n_splots = 6
n_cols = 3
n_rows = n_splots // n_cols


def update_data(x):
    return x ** 2


# create all snapshots
snapshots = [np.random.rand(n_splots, 32, 32, 3)]
for _ in range(n_frames):
    snapshots.append(update_data(snapshots[-1]))

# initialize figure and static elements
fig, axes = plt.subplots(n_rows, n_cols)
axes = axes.ravel()  # so we can access all axes with a single index
for i, ax in enumerate(axes):
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title("target: {}".format(i))

# initialize elements to be animated
step_txt = fig.text(0.5, 0.95, "step: 0", ha="center", weight="bold")
# step_txt = axes[4].set_xlabel("step: 0")  # also works with an xlabel
imgs = list()
for a, s in zip(axes, snapshots[0]):
    imgs.append(a.imshow(s, interpolation="none", cmap="gray"))


# animation function
def animate(i):

    # update images
    for img, s in zip(imgs, snapshots[i]):
        img.set_data(s)

    # update text
    step_txt.set_text("step: {}".format(i))

    # update any other animated artist here in the same way


anim = FuncAnimation(fig, animate, frames=n_frames, interval=300)
anim.save("test.gif", writer=PillowWriter())

Here is the output I got from the above code:

(Image: animated with step display)

Leonard
This was one of those problems where I thought it would be cool to have this small feature and that it should take me 15 minutes at most. Two days later, I'm about to smash my computer. Thanks a lot. – Kurt Willis Sep 30 '20 at 15:58