2

I am currently working on a project that uses MediaPipe's body pose estimation library. I know that Plotly can create animated 3D scatter plots, so I was wondering whether it is possible to create an animated plot of the body pose landmarks and connections provided by MediaPipe using Plotly.

It would work the same way as MediaPipe's built-in functionality:

mp_draw.plot_landmarks()

But with plotly's interactivity and animated timeline.

Atharva Katre
  • 457
  • 1
  • 6
  • 17

1 Answers1

5
  • Take a look at the code for mp_draw.plot_landmarks(): it uses matplotlib, so a plotly version can be written along the same lines.
  • I have done this below; this version returns a plotly figure. I'm sure the code can be cleaned up further.
  • Animations could be built by assembling frames from multiple results.

setup

import cv2
import math
import numpy as np
import plotly.express as px
from pathlib import Path
import mediapipe as mp

# Input images, keyed by their path as a string (the mapped value is unused).
uploaded = {
    str(Path.home() / "Downloads" / "thao-le-hoang-v4zceVZ5HK8-unsplash.jpg"): ""
}

# Load each listed file with OpenCV; the result is keyed by the same path.
images = {path: cv2.imread(path) for path in uploaded}

# Short aliases for the mediapipe solution modules used below.
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles

# Run MediaPipe Pose on every image. OpenCV loads pixels as BGR, so each
# image is converted to RGB before being handed to the pose model.
# `results` keeps the output of the last processed image.
with mp_pose.Pose(
    static_image_mode=True,
    min_detection_confidence=0.5,
    model_complexity=2,
) as pose:
    for name, image in images.items():
        results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

# Plot the 3D pose world landmarks with mediapipe's own plotting helper.
mp_drawing.plot_landmarks(results.pose_world_landmarks, mp_pose.POSE_CONNECTIONS)

enter image description here

plotly plot_landmarks()

import pandas as pd
import plotly.graph_objects as go

# A landmark that carries a "presence" field below this value is not plotted.
_PRESENCE_THRESHOLD = 0.5
# A landmark that carries a "visibility" field below this value is not plotted.
_VISIBILITY_THRESHOLD = 0.5


def _visible_landmark_points(landmark_list):
    """Return ``{index: (-z, x, -y)}`` for every landmark that passes the thresholds.

    A landmark is dropped when it carries a ``visibility`` field below
    ``_VISIBILITY_THRESHOLD`` or a ``presence`` field below
    ``_PRESENCE_THRESHOLD``. Landmarks without those fields are always kept.
    """
    points = {}
    for idx, landmark in enumerate(landmark_list.landmark):
        hidden = (
            landmark.HasField("visibility")
            and landmark.visibility < _VISIBILITY_THRESHOLD
        )
        absent = (
            landmark.HasField("presence")
            and landmark.presence < _PRESENCE_THRESHOLD
        )
        if hidden or absent:
            continue
        # Axes are reordered/negated here; the same tuple order is relied on
        # by both the scatter columns and the connection line coordinates.
        points[idx] = (-landmark.z, landmark.x, -landmark.y)
    return points


def _connection_line_coords(points, connections, num_landmarks):
    """Return ``{"xs", "ys", "zs"}`` coordinate lists for the connection lines.

    Each connection whose two endpoints are both in ``points`` contributes
    ``start, end, None`` to every axis list; the ``None`` entry separates one
    segment from the next so all segments fit in a single line trace.

    Raises:
        ValueError: if a connection references an index outside
            ``range(num_landmarks)``.
    """
    coords = {"xs": [], "ys": [], "zs": []}
    for connection in connections:
        start_idx, end_idx = connection[0], connection[1]
        if not (0 <= start_idx < num_landmarks and 0 <= end_idx < num_landmarks):
            raise ValueError(
                f"Landmark index is out of range. Invalid connection "
                f"from landmark #{start_idx} to landmark #{end_idx}."
            )
        # Only draw a connection when both of its endpoints were plotted.
        if start_idx not in points or end_idx not in points:
            continue
        start, end = points[start_idx], points[end_idx]
        for axis, key in enumerate(("xs", "ys", "zs")):
            coords[key].extend((start[axis], end[axis], None))
    return coords


def plot_landmarks(
    landmark_list,
    connections=None,
):
    """Build a plotly 3D figure of landmarks and, optionally, their connections.

    Args:
        landmark_list: object exposing a ``landmark`` sequence whose items have
            ``x``, ``y``, ``z`` attributes and a ``HasField`` method (for
            example ``results.pose_world_landmarks``). Landmark names shown on
            hover are looked up through ``mp_pose.PoseLandmark``, so indices
            are expected to be pose landmark indices.
        connections: optional iterable of ``(start_idx, end_idx)`` pairs. When
            omitted or empty, only the landmark markers are drawn.

    Returns:
        The plotly figure, or ``None`` when ``landmark_list`` is falsy.

    Raises:
        ValueError: if a connection references a landmark index that is out of
            range for ``landmark_list``.
    """
    if not landmark_list:
        return

    plotted_landmarks = _visible_landmark_points(landmark_list)

    # Previously the line coordinates were only defined inside
    # `if connections:`, so calling with the default `connections=None`
    # failed with a NameError when the trace was added.
    line_coords = None
    if connections:
        line_coords = _connection_line_coords(
            plotted_landmarks, connections, len(landmark_list.landmark)
        )

    # Explicit column names keep the frame well-formed (columns z/x/y present)
    # even when no landmark passed the thresholds.
    df = pd.DataFrame.from_dict(
        plotted_landmarks, orient="index", columns=["z", "x", "y"]
    )
    df["lm"] = df.index.map(lambda s: mp_pose.PoseLandmark(s).name).values

    fig = (
        px.scatter_3d(df, x="z", y="x", z="y", hover_name="lm")
        .update_traces(marker={"color": "red"})
        .update_layout(
            margin={"l": 0, "r": 0, "t": 0, "b": 0},
            scene={"camera": {"eye": {"x": 2.1, "y": 0, "z": 0}}},
        )
    )

    if line_coords is not None:
        fig.add_traces(
            [
                go.Scatter3d(
                    x=line_coords["xs"],
                    y=line_coords["ys"],
                    z=line_coords["zs"],
                    mode="lines",
                    line={"color": "black", "width": 5},
                    name="connections",
                )
            ]
        )

    return fig

try it

# Build the plotly figure from the pose world landmarks computed in the setup
# step, using mp_pose.POSE_CONNECTIONS for the connecting lines.
plot_landmarks(
        results.pose_world_landmarks,  mp_pose.POSE_CONNECTIONS)

enter image description here

Rob Raymond
  • 29,118
  • 3
  • 14
  • 30
  • 1
    That's some neat piece of code you have written it'll take a while for my puny brain to comprehend this. Thanks for your help :) – Atharva Katre Sep 23 '21 at 08:32
  • Do you have a proper solution to get this `mp_drawing.plot_landmarks(results.pose_world_landmarks, mp_holistic.POSE_CONNECTIONS)` in real time to webcam video feed ? – Daniel Jun 02 '22 at 10:31