
I am trying to animate multiple lines at once in matplotlib. To do this I am following the tutorial from the matplotlib.animation docs:

https://matplotlib.org/stable/api/animation_api.html

The idea in this tutorial is to create a line with ln, = plt.plot([], []) and update its data with ln.set_data to produce the animation. Whilst this works fine when the line data is a one-dimensional array (shape (n,)) of n data points, I am having trouble when the line data is a two-dimensional array (shape (n, k)) holding k lines to plot.

To be more precise, plt.plot accepts two-dimensional arrays as input and plots each column as a separate line. Here is a simple example with 3 lines plotted with a single plt.plot call:

import matplotlib.pyplot as plt
import numpy as np


x = np.linspace(0, 2 * np.pi, 100).reshape(-1, 1)
x = np.concatenate([x] * 3, axis=1)

# generate 3 curves
y = np.copy(x)
y[:, 0] = np.cos(y[:, 0])
y[:, 1] = np.sin(y[:, 1])
y[:, 2] = np.sin(y[:, 2]) + np.cos(y[:, 2])

fig, ax = plt.subplots()
plt.plot(x, y)
plt.show()

[Figure: the three curves drawn by plt.plot using arrays]
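As a quick check of the claim above, plt.plot (here via ax.plot) returns one Line2D per column when y is two-dimensional. A minimal sketch, assuming a recent Matplotlib; building y with np.column_stack and a one-dimensional x is a choice made for this illustration, not part of the question:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 100)
# Stack the three curves as the columns of a (100, 3) array
y = np.column_stack([np.cos(x), np.sin(x), np.sin(x) + np.cos(x)])

fig, ax = plt.subplots()
# A 1D x of length 100 is reused for every column of y,
# so this single call creates three separate Line2D objects
lines = ax.plot(x, y)
print(len(lines), [line.get_ydata().shape for line in lines])
```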

However, if I try to set the data using .set_data, as required for generating animations, I run into a problem:

import matplotlib.pyplot as plt
import numpy as np


x = np.linspace(0, 2 * np.pi, 100).reshape(-1, 1)
x = np.concatenate([x] * 3, axis=1)

# generate 3 curves
y = np.copy(x)
y[:, 0] = np.cos(y[:, 0])
y[:, 1] = np.sin(y[:, 1])
y[:, 2] = np.sin(y[:, 2]) + np.cos(y[:, 2])

fig, ax = plt.subplots()
p, = plt.plot([], [], color='b')
p.set_data(x, y)
plt.show()

[Figure: the problem, i.e. the plot produced by p.set_data(x, y)]
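The figure is consistent with how Line2D stores its data. As far as can be told from Matplotlib's source, Line2D flattens x and y with ravel() before building its single (N, 2) vertex array, so (100, 3) inputs become one 300-point line that hops between the three curves. A small sketch to check this (x2d and y2d are illustrative names, and the exact internals may differ between Matplotlib versions):

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 100)
x2d = np.column_stack([x] * 3)  # shape (100, 3), like the question's x
y2d = np.column_stack([np.cos(x), np.sin(x), np.sin(x) + np.cos(x)])

fig, ax = plt.subplots()
p, = ax.plot([], [], color='b')
p.set_data(x2d, y2d)

# The (100, 3) inputs end up flattened row by row into the
# 300 vertices of ONE line, not three separate lines
xy = p.get_xydata()
print(xy.shape)
```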

Is there a way to call set_data with two-dimensional arrays? Whilst I am aware that I could create three plots p1, p2, p3 and call set_data on each of them in a loop, my real data consists of 1,000 to 10,000 lines, and looping over them makes the animation too slow.
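A different technique, not used in either answer on this page, is matplotlib.collections.LineCollection: it draws many lines as a single artist, and its set_segments method takes all of them at once as a sequence of (N, 2) vertex arrays, for example one (k, N, 2) NumPy array. A minimal sketch, assuming k = 1000 synthetic sine curves drawn in one shared color; it has not been benchmarked here against looping over Line2D objects, so any speedup should be measured on the real data:

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation
from matplotlib.collections import LineCollection

n, k = 100, 1000  # points per line and number of lines (illustrative sizes)
x = np.linspace(0, 2 * np.pi, n)
# k phase-shifted sine curves, one per column: shape (n, k)
y = np.sin(x[:, None] + np.random.uniform(0, 2 * np.pi, (1, k)))

# Pair every y column with x and reorder to (k, n, 2):
# one (n, 2) vertex array per line
xy = np.stack([np.broadcast_to(x[:, None], y.shape), y], axis=-1).transpose(1, 0, 2)

fig, ax = plt.subplots()
ax.set(xlim=(0, 2 * np.pi), ylim=(-1.5, 1.5))
lc = LineCollection([], linewidths=0.5)  # a single artist for all k lines
ax.add_collection(lc)

def update(i):
    # Show the first i points of every line with one call
    lc.set_segments(xy[:, :i])
    return (lc,)

# Frames start at 2 so each segment has at least two points
anim = FuncAnimation(fig, update, frames=range(2, n + 1), interval=50,
                     blit=True, repeat=False)
plt.show()
```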

Many thanks for any help.

lmms

2 Answers


One approach is to create a list of Line2D objects, update each one with set_data in a loop, and return the whole list from the animation function. Note that ax.plot() always returns a list of lines, even when only one line is plotted.

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np

x = np.linspace(0, 2 * np.pi, 100)

# generate 10 curves
y = np.sin(x.reshape(-1, 1) + np.random.uniform(0, 2 * np.pi, (1, 10)))

fig, ax = plt.subplots()
ax.set(xlim=(0, 2 * np.pi), ylim=(-1.5, 1.5))
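# Alternative: lines = [ax.plot([], [], lw=2)[0] for _ in range(y.shape[1])]
# Empty arrays with y.shape[1] columns make ax.plot create one empty line per curve
lines = ax.plot(np.empty((0, y.shape[1])), np.empty((0, y.shape[1])), lw=2)

def animate(i):
    # Give every line the first i points of its own curve (column of y)
    for line_k, y_k in zip(lines, y.T):
        line_k.set_data(x[:i], y_k[:i])
    return lines  # the list of updated Line2D artists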

anim = FuncAnimation(fig, animate, frames=x.size, interval=200, repeat=False)
plt.show()
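The two facts the code above relies on can be checked without opening a window: ax.plot returns a list even for a single line, and empty inputs with k columns yield k empty lines. A minimal sketch:

```python
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
single = ax.plot([], [])  # one line, but still returned as a list
many = ax.plot(np.empty((0, 10)), np.empty((0, 10)))  # ten empty lines
print(type(single).__name__, len(single), len(many))
```

Because animate returns that list of artists, it is also compatible with blit=True in FuncAnimation, which redraws only the returned artists on each frame; whether that is fast enough for 1,000 to 10,000 lines is untested here.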
JohanC

set_data() takes the x and y data of a single line as two one-dimensional arrays, so in this case three set_data() calls are needed, one per line.

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np

x = np.linspace(0, 2 * np.pi, 100).reshape(-1, 1)
x = np.concatenate([x] * 3, axis=1)

# generate 3 curves
y = np.copy(x)
y[:, 0] = np.cos(y[:, 0])
y[:, 1] = np.sin(y[:, 1])
y[:, 2] = np.sin(y[:, 2]) + np.cos(y[:, 2])

fig, ax = plt.subplots()
# Set the limits on the existing Axes (plt.axes(...) here would add a second Axes on top)
ax.set(xlim=(0, 2 * np.pi), ylim=(-1.5, 1.5))
line1, = ax.plot([], [], lw=2)
line2, = ax.plot([], [], lw=2)
line3, = ax.plot([], [], lw=2)


def animate(i):
    line1.set_data(x[:i, 0], y[:i, 0])
    line2.set_data(x[:i, 1], y[:i, 1])
    line3.set_data(x[:i, 2], y[:i, 2])
    return line1, line2, line3

anim = FuncAnimation(fig, animate, frames=100, interval=200, repeat=False)
plt.show()

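For completeness, set_data also accepts a single (2, N) array instead of two separate arguments, but that array is still the x and y data of one line, so one call per line remains necessary. A small sketch of that form for the three curves (the loop and the np.vstack usage are illustrative, not from the answer):

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 100)
y = np.column_stack([np.cos(x), np.sin(x), np.sin(x) + np.cos(x)])  # (100, 3)

fig, ax = plt.subplots()
lines = ax.plot(np.empty((0, 3)), np.empty((0, 3)), lw=2)  # three empty lines

i = 50
for k, line in enumerate(lines):
    # np.vstack gives a (2, i) array: row 0 is x, row 1 is column k of y
    line.set_data(np.vstack([x[:i], y[:i, k]]))

print([len(line.get_xdata()) for line in lines])
```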

r-beginners
  • Animation is easy once you understand how it works. If my answer has helped you with your question, please accept this as the correct answer. – r-beginners Jun 07 '21 at 13:43
  • Hi, thanks for the response. As I said in the question: "Whilst I am aware that I could just create three plots p1, p2, p3 and call set_data on each of them in a loop, my real data consists of 1000-10,000 lines to plot..." I tried to make a list of all plots and then update their data in a loop in the `update` function. However, this does not work, as the outputs of `update` need to be plots and not a list of plots. – lmms Jun 07 '21 at 14:40
  • Wouldn't it be sufficient to put all lines in a list, and then return that list? Something like `lines = [ax.plot([], [], lw=2)[0] for _ in range(y.shape[1])]`, and use `set_data` in a loop. – JohanC Jun 07 '21 at 14:57
  • @JohanC I don't see how to write set_data() for a list of Line2D objects. After all, `lines[0].set_data()` would still have to be written out for each of the 3 lines. Please advise. – r-beginners Jun 08 '21 at 04:10
  • I added an answer with my comment translated into code. It's just your approach extended to a loop. (Note that `ax = plt.axes(...)` creates a second subplot on top of the subplot created by `plt.subplots()`. You might want to remove one of them.) – JohanC Jun 08 '21 at 06:58
  • @JohanC Your response helped me to understand your comment. And the point about `ax=plt.axes()` was very meaningful to me. Thanks. – r-beginners Jun 08 '21 at 07:28
  • Great, that makes a lot of sense. Thank you both for your time and help, it is very kind. Have a nice day! – lmms Jun 08 '21 at 08:37
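The remark about plt.axes(...) in the comments can be checked directly. A minimal sketch, assuming a recent Matplotlib where plt.axes() always adds a new Axes to the current figure:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
# plt.axes(...) adds a NEW Axes to the current figure, drawn over ax
ax2 = plt.axes(xlim=(0, 6), ylim=(-1.5, 1.5))
print(len(fig.axes), ax2 is ax)

# Setting the limits on the existing Axes keeps a single Axes
fig2, ax3 = plt.subplots()
ax3.set(xlim=(0, 6), ylim=(-1.5, 1.5))
print(len(fig2.axes))
```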