I have trained a U-Net for semantic segmentation based on this library (it appears as keras_segmentation in the imports below).
I am currently trying to run the network in real time in a ROS node that takes the raw image from one topic, produces the segmentation mask, and publishes it on another topic.
The code I have written so far works in principle, but it is very slow: about 0.5 FPS, i.e. roughly 2 s per frame. I am aiming for at least 15 FPS, i.e. about 67 ms per frame.
I am not sure whether the callback() function takes so long because inference is simply too slow on my computer (i5-6500 with tensorflow-cpu, since I don't have an NVIDIA GPU) or, as I think is more likely, because the set_session(sess) call inside the callback slows it down.
How can I define this session outside the callback function so that it doesn't slow my code down? Note that the code after the model inference is only further image processing; the function is still slow without it.
TL;DR: How do I call set_session() outside the callback() function so that it does not slow down my code?
#!/usr/bin/env python
#This code is partly based on an example found at:
#https://github.com/isarlab-department-engineering/ros_dt_lane_follower/blob/master/src/lane_detection.py
import rospy
import numpy as np
import cv2
import math
import os
import tensorflow as tf
from cv_bridge import CvBridge
from sensor_msgs.msg import Image
from keras_segmentation.predict import predict
from keras_segmentation.models.unet import vgg_unet
from keras_segmentation.predict import model_from_checkpoint_path
from tensorflow.python.keras.backend import set_session
from tensorflow.python.keras.models import load_model
#create one session and remember the default graph so the callback thread can reuse them
sess = tf.Session()
graph = tf.get_default_graph()
# IMPORTANT: models have to be loaded AFTER SETTING THE SESSION for keras!
# Otherwise their weights will be unavailable in threads that set the session afterwards
set_session(sess)
seg = model_from_checkpoint_path("vgg_unet_1")
#build the predict function up front, before the first callback runs
seg._make_predict_function()
#definitions and declarations
bridge = CvBridge()
pub_image = rospy.Publisher('/Segmentation_image',Image,queue_size=1)
with graph.as_default():
    set_session(sess)
#callback is executed once for each frame
def callback(data):
    #make OpenCV able to process image
    image = bridge.imgmsg_to_cv2(data)
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    image = cv2.resize(image, (608, 416))
    #sess and graph are module-level globals; they are only read here
    with graph.as_default():
        set_session(sess)
        #alternative via the library helper (disabled):
        #erg = predict(model=seg, inp=image)  #out_fname=None
        #run the model on a batch of one image and take the first result
        erg = seg.predict(np.array([image]))[0]
        print(erg.shape)
        erg = erg.astype(np.uint8)
def lane_detect():
    rospy.init_node('Segmentation', anonymous=True)
    #rospy.Subscriber("/cv_camera/image_raw",Image,callback,queue_size=1,buff_size=2**24)
    rospy.Subscriber("/movie_raw", Image, callback, queue_size=1, buff_size=2**24)
    try:
        rospy.loginfo("Entering ROS Spin")
        rospy.spin()
    except KeyboardInterrupt:
        print("Shutting down")

if __name__ == '__main__':
    try:
        lane_detect()
    except rospy.ROSInterruptException:
        pass
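To find out which of the two suspects is actually responsible, I could time each stage of the callback separately instead of guessing. Below is a minimal sketch of such a measurement. It deliberately uses no ROS or TensorFlow so that it runs anywhere: timed(), report(), fake_set_session() and fake_predict() are hypothetical names of my own, and the two fake functions are only stand-ins for set_session(sess) and seg.predict(...).

```python
from collections import defaultdict
from contextlib import contextmanager
from timeit import default_timer

# Accumulated wall-clock seconds and call counts per labelled stage.
# (Hypothetical helper state, not part of rospy or TensorFlow.)
stage_totals = defaultdict(float)
stage_counts = defaultdict(int)


@contextmanager
def timed(label):
    """Add the wall-clock time spent inside the with-block to stage_totals[label]."""
    start = default_timer()
    try:
        yield
    finally:
        stage_totals[label] += default_timer() - start
        stage_counts[label] += 1


def report():
    """Return {label: mean seconds per call} for every stage measured so far."""
    return {label: stage_totals[label] / stage_counts[label]
            for label in stage_totals}


# Stand-ins for the real calls so the sketch runs without ROS or TensorFlow.
def fake_set_session():
    pass


def fake_predict():
    sum(i * i for i in range(200000))


# Simulate three frames; in the real node this would be the callback body.
for _ in range(3):
    with timed("set_session"):
        fake_set_session()
    with timed("predict"):
        fake_predict()

means = report()
for label in sorted(means):
    print("%s: %.1f ms per call" % (label, means[label] * 1000.0))
```

In the real node, the two with-blocks would wrap set_session(sess) and seg.predict(np.array([image])) inside callback(), and the means could be logged every few frames. If the predict stage alone already takes far longer than the roughly 67 ms budget for 15 FPS, moving set_session() elsewhere will not be enough.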