OpenCV Recipes:增强现实(AR)

In this post, you are going to learn about augmented reality and how you can use it to build cool applications.

AR 的前提是什么?

你可能会看到 AR 一词在各种环境中被使用,因此,在开始讨论实现细节之前,我们应该了解 AR 的前提。AR 是指在实际内容中叠加计算机生成的输入,如图像、声音、图形和文本。

AR 试图通过无缝融合信息来模糊真实和计算机生成之间的界限,以增强我们的感官体验。它实际上与一个被称为媒介现实(mediated reality)的概念密切相关,这项技术通过增强我们目前对现实的感知而发挥作用。现在,这里的挑战是让它对用户来说看起来无缝。在输入视频上叠加一些东西很容易,但是我们需要让它看起来像是视频的一部分。用户应该感觉到计算机生成的输入紧密反映了实际内容。这就是我们在构建 AR 系统时想要达到的目标。在这种背景下,计算机视觉研究探索了如何将计算机生成的图像应用于实时视频流,从而增强对现实世界的感知。

AR系统是什么样子的?

让我们考虑下图:

正如我们在这里看到的,摄像机捕捉实际视频以获取参考点。图形系统生成虚拟对象,视频合并模块是所有魔法发生的地方。这个块应该足够聪明,能够理解如何以最好的方式将虚拟对象覆盖在实际内容之上。

AR 的几何变换

AR 的结果是惊人的,但是有很多数学蕴含其中。当谈论 AR 实时视频时,我们需要在实际内容上精确地注册虚拟物体。为了更好地理解这一点,让我们把它想象成两台摄像机:一台是真实的,通过它我们可以看到世界;另一台是虚拟的,它投射了计算机生成的图形对象。

为了构建 AR 系统,需要建立以下几何变换:

  • 对象到场景(object-to-scene): 这种变换指的是转换虚拟对象的 3D 坐标,并在实际场景的坐标中表达它们。这确保我们能够将虚拟对象放置在正确的位置。
  • 场景到相机(scene-to-camera):这种变换是指相机在现实世界中的姿态(pose)。所谓姿势,指的是摄像机的方向和位置。我们需要估计摄像机的视角,以便知道如何覆盖虚拟物体。
  • 相机到图像(camera-to-image):这是指相机的校准参数,定义了我们如何将 3D 对象投影到 2D 图像平面上,形成我们最终会看到的图像。

一旦我们进行了这些转换,我们就可以构建完整的系统。

什么是姿态估计?

在继续之前,我们需要了解如何估计相机姿态。这是 AR 系统中非常关键的一步,如果我们想让我们的体验变得无缝,我们需要把它做好。在 AR 世界中,我们实时地将图形叠加在物体上。为了做到这一点,我们需要知道相机的位置和方向,并且我们需要迅速做到这一点。这是姿态估计变得非常重要的地方。如果你没有正确跟踪姿势,叠加的图形看起来会不自然。

如何跟踪平面目标?

既然你已经理解了什么是姿态估计,让我们看看如何使用它来跟踪平面目标。

当我们讨论几何变换以及全景成像时,我们详细讨论了透视变换。我们只需要使用两组点并提取单应矩阵(homography matrix)。这个单应矩阵将告诉我们平面目标是如何转动的。

首先我们使用 ROISelector 类选择兴趣区域,一旦完成,我们将把这些坐标传递给 PoseEstimator

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class ROISelector(object):
def __init__(self, win_name, init_frame, callback_func):
self.callback_func = callback_func
self.selected_rect = None
self.drag_start = None
self.tracking_state = 0
event_params = {"frame": init_frame}
cv2.namedWindow(win_name)
cv2.setMouseCallback(win_name, self.mouse_event, event_params)

def mouse_event(self, event, x, y, flags, param):
x, y = np.int16([x, y])

# Detecting the mouse button down event
if event == cv2.EVENT_LBUTTONDOWN:
self.drag_start = (x, y)
self.tracking_state = 0

if self.drag_start:
if event == cv2.EVENT_MOUSEMOVE:
h, w = param["frame"].shape[:2]
xo, yo = self.drag_start
x0, y0 = np.maximum(0, np.minimum([xo, yo], [x, y]))
x1, y1 = np.minimum([w, h], np.maximum([xo, yo], [x, y]))
self.selected_rect = None

if x1-x0 > 0 and y1-y0 > 0:
self.selected_rect = (x0, y0, x1, y1)

elif event == cv2.EVENT_LBUTTONUP:
self.drag_start = None
if self.selected_rect is not None:
self.callback_func(self.selected_rect)
self.selected_rect = None
self.tracking_state = 1

def draw_rect(self, img, rect):
if not rect: return False
x_start, y_start, x_end, y_end = rect
cv2.rectangle(img, (x_start, y_start), (x_end, y_end), (0, 255, 0), 2)
return True

然后,我们从这个兴趣区域提取特征点。因为我们跟踪平面目标,所以算法假设这个兴趣区域是一个平面。所以,当你选择这个兴趣区域时,确保你手中有一个纸板盒。此外,如果纸板盒有一堆图案和独特的点会更好,这样很容易检测和跟踪其上的特征点。

PoseEstimator 类的方法 add_target() 中接收兴趣区域,并将从这些区域提取特征点,我们就可以跟踪目标的运动:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
class PoseEstimator(object):
def __init__(self):
# Use locality sensitive hashing algorithm
flann_params = dict(algorithm = 6, table_number = 6, key_size = 12, multi_probe_level = 1)

self.min_matches = 10
self.cur_target = namedtuple('Current', 'image, rect, keypoints, descriptors, data')
self.tracked_target = namedtuple('Tracked', 'target, points_prev, points_cur, H, quad')

self.feature_detector = cv2.ORB_create()
self.feature_detector.setMaxFeatures(1000)
self.feature_matcher = cv2.FlannBasedMatcher(flann_params, {})
self.tracking_targets = []

# Function to add a new target for tracking
def add_target(self, image, rect, data=None):
x_start, y_start, x_end, y_end = rect
keypoints, descriptors = [], []
for keypoint, descriptor in zip(*self.detect_features(image)):
x, y = keypoint.pt
if x_start <= x <= x_end and y_start <= y <= y_end:
keypoints.append(keypoint)
descriptors.append(descriptor)

descriptors = np.array(descriptors, dtype='uint8')
self.feature_matcher.add([descriptors])
target = self.cur_target(image=image, rect=rect, keypoints=keypoints, descriptors=descriptors, data=None)
self.tracking_targets.append(target)

# To get a list of detected objects
def track_target(self, frame):
self.cur_keypoints, self.cur_descriptors = self.detect_features(frame)

if len(self.cur_keypoints) < self.min_matches: return []
try: matches = self.feature_matcher.knnMatch(self.cur_descriptors, k=2)
except Exception as e:
print('Invalid target, please select another with features to extract')
return []
matches = [match[0] for match in matches if len(match) == 2 and match[0].distance < match[1].distance * 0.75]
if len(matches) < self.min_matches: return []

matches_using_index = [[] for _ in range(len(self.tracking_targets))]
for match in matches:
matches_using_index[match.imgIdx].append(match)

tracked = []
for image_index, matches in enumerate(matches_using_index):
if len(matches) < self.min_matches: continue

target = self.tracking_targets[image_index]
points_prev = [target.keypoints[m.trainIdx].pt for m in matches]
points_cur = [self.cur_keypoints[m.queryIdx].pt for m in matches]
points_prev, points_cur = np.float32((points_prev, points_cur))
H, status = cv2.findHomography(points_prev, points_cur, cv2.RANSAC, 3.0)
status = status.ravel() != 0

if status.sum() < self.min_matches: continue

points_prev, points_cur = points_prev[status], points_cur[status]

x_start, y_start, x_end, y_end = target.rect
quad = np.float32([[x_start, y_start], [x_end, y_start], [x_end, y_end], [x_start, y_end]])
quad = cv2.perspectiveTransform(quad.reshape(1, -1, 2), H).reshape(-1, 2)
track = self.tracked_target(target=target, points_prev=points_prev, points_cur=points_cur, H=H, quad=quad)
tracked.append(track)

tracked.sort(key = lambda x: len(x.points_prev), reverse=True)
return tracked

# Detect features in the selected ROIs and return the keypoints and descriptors
def detect_features(self, frame):
keypoints, descriptors = self.feature_detector.detectAndCompute(frame, None)
if descriptors is None: descriptors = []
return keypoints, descriptors

# Function to clear all the existing targets
def clear_targets(self):
self.feature_matcher.clear()
self.tracking_targets = []

下面是剩下的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import sys
from collections import namedtuple

import cv2
import numpy as np


class VideoHandler(object):
def __init__(self, capId, scaling_factor, win_name):
self.cap = cv2.VideoCapture(capId)
self.pose_tracker = PoseEstimator()
self.win_name = win_name
self.scaling_factor = scaling_factor

ret, frame = self.cap.read()
self.rect = None
self.frame = cv2.resize(frame, None, fx=scaling_factor, fy=scaling_factor, interpolation=cv2.INTER_AREA)
self.roi_selector = ROISelector(win_name, self.frame, self.set_rect)

def set_rect(self, rect):
self.rect = rect
self.pose_tracker.add_target(self.frame, rect)

def start(self):
paused = False
while True:
if not paused or self.frame is None:
ret, frame = self.cap.read()
scaling_factor = self.scaling_factor
frame = cv2.resize(frame, None, fx=scaling_factor, fy=scaling_factor, interpolation=cv2.INTER_AREA)
if not ret: break
self.frame = frame.copy()

img = self.frame.copy()
if not paused and self.rect is not None:
tracked = self.pose_tracker.track_target(self.frame)
for item in tracked:
cv2.polylines(img, [np.int32(item.quad)], True, (255, 255, 255), 2)
for (x, y) in np.int32(item.points_cur):
cv2.circle(img, (x, y), 2, (255, 255, 255))

self.roi_selector.draw_rect(img, self.rect)
cv2.imshow(self.win_name, img)
ch = cv2.waitKey(1)
if ch == ord(' '): paused = not paused
if ch == ord('c'): self.pose_tracker.clear_targets()
if ch == 27: break

if __name__ == '__main__':
VideoHandler(0, 0.8, 'Tracker').start()

背后的细节

首先,我们有一个 PoseEstimator 类,在这里完成所有繁重的工作。我们需要检测图像中的特征,以及匹配连续图像之间的特征。因此,我们使用 ORB 特征检测器和 Flann 特征匹配器在提取的特征中进行快速最近邻搜索。

每当我们选择兴趣区域时,我们都会调用 add_target 方法将它添加到我们的跟踪目标列表中。这个方法从兴趣区域提取特征,并将其存储在一个类变量中。

track_target 方法处理所有跟踪。我们获取当前帧并提取所有关键点。然而,我们并不是对视频当前帧中的所有关键点都感兴趣,我们只是想要属于我们目标的关键点。所以现在我们的工作是在当前帧中找到最近的关键点。

现在,我们在当前帧中有一组关键点,在前一帧中,我们有另一组来自目标对象的关键点。下一步是从这些匹配点中提取单应矩阵。这个单应矩阵告诉我们如何变换重叠的矩形,使其与纸板盒的表面对齐。我们只需要把这个单应矩阵应用到覆盖的矩形上,就可以得到所有纸板盒点的新位置。

如何增强现实?

既然我们知道如何跟踪平面对象,让我们看看如何将 3D 对象叠加到实际内容。物体是 3D 的,但是我们的视频流是 2D 的。因此,这里的第一步是了解如何将这些 3D 对象映射到 2D 平面,使它们看起来更真实。我们只需要将这些 3D 点投影到平面上。

将坐标从 3D 映射到 2D

一旦我们估计了姿态,我们就将点从 3D 投影到 2D。

我们需要一种机制来将 3D 物体映射到 2D 平面上。这是 3D 到 2D 投影非常重要的地方。我们需要估计最初的相机姿态。现在,假设摄像机的固有参数已经知道了。所以,我们可以使用 OpenCV 中的 solvePnP 函数来估计相机的姿态。此函数使用一组点来估计物体的姿态,你可以参考此文档了解更多。

一旦我们这样做,我们需要将这些点投影到 2D 平面上。我们使用OpenCV projectPoints 函数来实现这一点。此函数计算这些 3D 点在 2D 平面上的投影。

如何在视频上叠加 3D 对象?

既然我们有了所有不同的模块,我们就准备构建最终的系统。

假设我们想在纸板盒上覆盖一座金字塔,让我们看看如何使用 OpenCV 实现这一点。确保将前一个文件保存为 pose_estimation.py,因为我们将从那里导入几个类:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import cv2
import numpy as np

from pose_estimation import PoseEstimator, ROISelector

class Tracker(object):
def __init__(self, capId, scaling_factor, win_name):
self.cap = cv2.VideoCapture(capId)
self.rect = None
self.win_name = win_name
self.scaling_factor = scaling_factor
self.tracker = PoseEstimator()

ret, frame = self.cap.read()
self.rect = None
self.frame = cv2.resize(frame, None, fx=scaling_factor, fy=scaling_factor, interpolation=cv2.INTER_AREA)

self.roi_selector = ROISelector(win_name, self.frame, self.set_rect)
self.overlay_vertices = np.float32([[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0], [0.5, 0.5, 4]])
self.overlay_edges = [(0, 1), (1, 2), (2, 3), (3, 0), (0,4), (1,4), (2,4), (3,4)]
self.color_base = (0, 255, 0)
self.color_lines = (0, 0, 0)

def set_rect(self, rect):
self.rect = rect
self.tracker.add_target(self.frame, rect)

def start(self):
paused = False
while True:
if not paused or self.frame is None:
ret, frame = self.cap.read()
scaling_factor = self.scaling_factor
frame = cv2.resize(frame, None, fx=scaling_factor, fy=scaling_factor,\
interpolation=cv2.INTER_AREA)
if not ret: break

self.frame = frame.copy()

img = self.frame.copy()
if not paused:
tracked = self.tracker.track_target(self.frame)
for item in tracked:
cv2.polylines(img, [np.int32(item.quad)],
True, self.color_lines, 2)
for (x, y) in np.int32(item.points_cur):
cv2.circle(img, (x, y), 2,
self.color_lines)

self.overlay_graphics(img, item)

self.roi_selector.draw_rect(img, self.rect)
cv2.imshow(self.win_name, img)
ch = cv2.waitKey(1)
if ch == ord(' '): self.paused = not self.paused
if ch == ord('c'): self.tracker.clear_targets()
if ch == 27: break

def overlay_graphics(self, img, tracked):
x_start, y_start, x_end, y_end = tracked.target.rect
quad_3d = np.float32([[x_start, y_start, 0], [x_end,
y_start, 0],
[x_end, y_end, 0], [x_start, y_end, 0]])
h, w = img.shape[:2]
K = np.float64([[w, 0, 0.5*(w-1)],
[0, w, 0.5*(h-1)],
[0, 0, 1.0]])
dist_coef = np.zeros(4)
ret, rvec, tvec = cv2.solvePnP(objectPoints=quad_3d, imagePoints=tracked.quad,
cameraMatrix=K, distCoeffs=dist_coef)
verts = self.overlay_vertices * \
[(x_end-x_start), (y_end-y_start), -(x_end-x_start)*0.3] + (x_start, y_start, 0)
verts = cv2.projectPoints(verts, rvec, tvec, cameraMatrix=K,
distCoeffs=dist_coef)[0].reshape(-1, 2)

verts_floor = np.int32(verts).reshape(-1,2)
cv2.drawContours(img, contours=[verts_floor[:4]],
contourIdx=-1, color=self.color_base, thickness=-3)
cv2.drawContours(img, contours=[np.vstack((verts_floor[:2],
verts_floor[4:5]))], contourIdx=-1, color=(0,255,0), thickness=-3)
cv2.drawContours(img, contours=[np.vstack((verts_floor[1:3],
verts_floor[4:5]))], contourIdx=-1, color=(255,0,0), thickness=-3)
cv2.drawContours(img, contours=[np.vstack((verts_floor[2:4],
verts_floor[4:5]))], contourIdx=-1, color=(0,0,150), thickness=-3)
cv2.drawContours(img, contours=[np.vstack((verts_floor[3:4],
verts_floor[0:1], verts_floor[4:5]))], contourIdx=-1, color=(255,255,0), thickness=-3)

for i, j in self.overlay_edges:
(x_start, y_start), (x_end, y_end) = verts[i], verts[j]
cv2.line(img, (int(x_start), int(y_start)), (int(x_end), int(y_end)), self.color_lines, 2)

if __name__ == '__main__':
Tracker(0, 0.8, 'Augmented Reality').start()

背后细节

Tracker 类用于执行这里的所有计算,我们在类初始化时定义金字塔的结构。我们用来跟踪平面的逻辑与我们前面讨论的相同,因为我们使用的是同一个类。我们只需要使用 solvePnPprojectPoints将 3D 金字塔映射到 2D 平面。

GreatX wechat
Subscribe to my blog by scanning my public wechat account