__init__.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import cv2
  4. import numpy as np
  5. import time
  6. import requests
  7. import websocket
  8. import threading
  9. __version__ = "0.0.1"
  10. class CVHandler(object):
  11. template_threshold = 0.95 # 模板匹配的阈值
  12. def show(self, img):
  13. ''' 显示一个图片 '''
  14. cv2.imshow('image', img)
  15. cv2.waitKey(0)
  16. cv2.destroyAllWindows()
  17. def imread(self, filename):
  18. '''
  19. Like cv2.imread
  20. This function will make sure filename exists
  21. '''
  22. im = cv2.imread(filename)
  23. if im is None:
  24. raise RuntimeError("file: '%s' not exists" % filename)
  25. return im
  26. def imdecode(self, img_data):
  27. '''
  28. Like cv2.imdecode
  29. This function will make sure filename exists
  30. 直接读取从网络下载的图片数据
  31. '''
  32. im = np.asarray(bytearray(img_data), dtype="uint8")
  33. im = cv2.imdecode(im, cv2.IMREAD_COLOR)
  34. if im is None:
  35. raise RuntimeError("img_data is can not decode")
  36. return im
  37. def find_template(self, im_source, im_search, threshold=template_threshold, rgb=False, bgremove=False):
  38. '''
  39. @return find location
  40. if not found; return None
  41. '''
  42. result = self.find_all_template(im_source, im_search, threshold, 1, rgb, bgremove)
  43. return result[0] if result else None
  44. def find_all_template(self, im_source, im_search, threshold=template_threshold, maxcnt=0, rgb=False,
  45. bgremove=False):
  46. '''
  47. Locate image position with cv2.templateFind
  48. Use pixel match to find pictures.
  49. Args:
  50. im_source(string): 图像、素材
  51. im_search(string): 需要查找的图片
  52. threshold: 阈值,当相识度小于该阈值的时候,就忽略掉
  53. Returns:
  54. A tuple of found [(point, score), ...]
  55. Raises:
  56. IOError: when file read error
  57. '''
  58. # method = cv2.TM_CCORR_NORMED
  59. # method = cv2.TM_SQDIFF_NORMED
  60. method = cv2.TM_CCOEFF_NORMED
  61. if rgb:
  62. s_bgr = cv2.split(im_search) # Blue Green Red
  63. i_bgr = cv2.split(im_source)
  64. weight = (0.3, 0.3, 0.4)
  65. resbgr = [0, 0, 0]
  66. for i in range(3): # bgr
  67. resbgr[i] = cv2.matchTemplate(i_bgr[i], s_bgr[i], method)
  68. res = resbgr[0] * weight[0] + resbgr[1] * weight[1] + resbgr[2] * weight[2]
  69. else:
  70. s_gray = cv2.cvtColor(im_search, cv2.COLOR_BGR2GRAY)
  71. i_gray = cv2.cvtColor(im_source, cv2.COLOR_BGR2GRAY)
  72. # 边界提取(来实现背景去除的功能)
  73. if bgremove:
  74. s_gray = cv2.Canny(s_gray, 100, 200)
  75. i_gray = cv2.Canny(i_gray, 100, 200)
  76. res = cv2.matchTemplate(i_gray, s_gray, method)
  77. w, h = im_search.shape[1], im_search.shape[0]
  78. result = []
  79. while True:
  80. min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res)
  81. if method in [cv2.TM_SQDIFF, cv2.TM_SQDIFF_NORMED]:
  82. top_left = min_loc
  83. else:
  84. top_left = max_loc
  85. if max_val < threshold:
  86. break
  87. # calculator middle point
  88. middle_point = (top_left[0] + w / 2, top_left[1] + h / 2)
  89. result.append(dict(
  90. result=middle_point,
  91. rectangle=(top_left, (top_left[0], top_left[1] + h), (top_left[0] + w, top_left[1]),
  92. (top_left[0] + w, top_left[1] + h)),
  93. confidence=max_val
  94. ))
  95. if maxcnt and len(result) >= maxcnt:
  96. break
  97. # floodfill the already found area
  98. cv2.floodFill(res, None, max_loc, (-1000,), max_val - threshold + 0.1, 1, flags=cv2.FLOODFILL_FIXED_RANGE)
  99. return result
  100. def _sift_instance(self, edge_threshold=100):
  101. if hasattr(cv2, 'SIFT'):
  102. return cv2.SIFT(edgeThreshold=edge_threshold)
  103. return cv2.xfeatures2d.SIFT_create(edgeThreshold=edge_threshold)
  104. def sift_count(self, img):
  105. sift = self._sift_instance()
  106. kp, des = sift.detectAndCompute(img, None)
  107. return len(kp)
  108. def find_sift(self, im_source, im_search, min_match_count=4):
  109. '''
  110. SIFT特征点匹配
  111. '''
  112. res = self.find_all_sift(im_source, im_search, min_match_count, maxcnt=1)
  113. if not res:
  114. return None
  115. return res[0]
  116. def find_all_sift(self, im_source, im_search, min_match_count=4, maxcnt=0):
  117. '''
  118. 使用sift算法进行多个相同元素的查找
  119. Args:
  120. im_source(string): 图像、素材
  121. im_search(string): 需要查找的图片
  122. threshold: 阈值,当相识度小于该阈值的时候,就忽略掉
  123. maxcnt: 限制匹配的数量
  124. Returns:
  125. A tuple of found [(point, rectangle), ...]
  126. A tuple of found [{"point": point, "rectangle": rectangle, "confidence": 0.76}, ...]
  127. rectangle is a 4 points list
  128. '''
  129. sift = self._sift_instance()
  130. flann = cv2.FlannBasedMatcher({'algorithm': self.FLANN_INDEX_KDTREE, 'trees': 5}, dict(checks=50))
  131. kp_sch, des_sch = sift.detectAndCompute(im_search, None)
  132. if len(kp_sch) < min_match_count:
  133. return None
  134. kp_src, des_src = sift.detectAndCompute(im_source, None)
  135. if len(kp_src) < min_match_count:
  136. return None
  137. h, w = im_search.shape[1:]
  138. result = []
  139. while True:
  140. # 匹配两个图片中的特征点,k=2表示每个特征点取2个最匹配的点
  141. matches = flann.knnMatch(des_sch, des_src, k=2)
  142. good = []
  143. for m, n in matches:
  144. # 剔除掉跟第二匹配太接近的特征点
  145. if m.distance < 0.9 * n.distance:
  146. good.append(m)
  147. if len(good) < min_match_count:
  148. break
  149. sch_pts = np.float32([kp_sch[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
  150. img_pts = np.float32([kp_src[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
  151. # M是转化矩阵
  152. M, mask = cv2.findHomography(sch_pts, img_pts, cv2.RANSAC, 5.0)
  153. matches_mask = mask.ravel().tolist()
  154. # 计算四个角矩阵变换后的坐标,也就是在大图中的坐标
  155. h, w = im_search.shape[:2]
  156. pts = np.float32([[0, 0], [0, h - 1], [w - 1, h - 1], [w - 1, 0]]).reshape(-1, 1, 2)
  157. dst = cv2.perspectiveTransform(pts, M)
  158. # trans numpy arrary to python list
  159. # [(a, b), (a1, b1), ...]
  160. pypts = []
  161. for npt in dst.astype(int).tolist():
  162. pypts.append(tuple(npt[0]))
  163. lt, br = pypts[0], pypts[2]
  164. middle_point = (lt[0] + br[0]) / 2, (lt[1] + br[1]) / 2
  165. result.append(dict(
  166. result=middle_point,
  167. rectangle=pypts,
  168. confidence=(matches_mask.count(1), len(good)) # min(1.0 * matches_mask.count(1) / 10, 1.0)
  169. ))
  170. if maxcnt and len(result) >= maxcnt:
  171. break
  172. # 从特征点中删掉那些已经匹配过的, 用于寻找多个目标
  173. qindexes, tindexes = [], []
  174. for m in good:
  175. qindexes.append(m.queryIdx) # need to remove from kp_sch
  176. tindexes.append(m.trainIdx) # need to remove from kp_img
  177. def filter_index(indexes, arr):
  178. r = np.ndarray(0, np.float32)
  179. for i, item in enumerate(arr):
  180. if i not in qindexes:
  181. r = np.append(r, item)
  182. return r
  183. kp_src = filter_index(tindexes, kp_src)
  184. des_src = filter_index(tindexes, des_src)
  185. return result
  186. def find_all(self, im_source, im_search, maxcnt=0):
  187. '''
  188. 优先Template,之后Sift
  189. @ return [(x,y), ...]
  190. '''
  191. result = self.find_all_template(im_source, im_search, maxcnt=maxcnt)
  192. if not result:
  193. result = self.find_all_sift(im_source, im_search, maxcnt=maxcnt)
  194. if not result:
  195. return []
  196. return [match["result"] for match in result]
  197. def find(self, im_source, im_search):
  198. '''
  199. Only find maximum one object
  200. '''
  201. r = self.find_all(im_source, im_search, maxcnt=1)
  202. return r[0] if r else None
  203. def brightness(self, im):
  204. '''
  205. Return the brightness of an image
  206. Args:
  207. im(numpy): image
  208. Returns:
  209. float, average brightness of an image
  210. '''
  211. im_hsv = cv2.cvtColor(im, cv2.COLOR_BGR2HSV)
  212. h, s, v = cv2.split(im_hsv)
  213. height, weight = v.shape[:2]
  214. total_bright = 0
  215. for i in v:
  216. total_bright = total_bright + sum(i)
  217. return float(total_bright) / (height * weight)
  218. class Aircv(object):
  219. timeout = 30
  220. wait_before_operation = 1 # 操作前等待时间 秒
  221. rcv_interval = 2 # 接收图片的间隔时间 秒
  222. # temporary_directory = "./" # 临时保存截图的目录路径
  223. support_network = False # 是否启用网络下载图片
  224. url = ""
  225. host = "127.0.0.1:8000"
  226. path = "/image_service/download/"
  227. def __init__(self, d):
  228. self.__rcv_interva_time_cache = 0
  229. self.d = d
  230. self.cvHandler = CVHandler()
  231. self.FLANN_INDEX_KDTREE = 0
  232. # self.aircv_cache_image_name = Aircv.temporary_directory + self.d._host + "_aircv_cache_image.jpg"
  233. self.debug = True
  234. self.aircv_cache_image = None
  235. self.ws_screen = None
  236. self.zoom_out = None
  237. # 下面三个函数放在最后,而且顺序不能变
  238. self.detection_screen()
  239. self.start_get_screen()
  240. self.get_scaling_ratio()
  241. def detection_screen(self):
  242. """检测设备屏幕比例,必须为 16:9"""
  243. display_height = self.d.info['displayHeight']
  244. display_width = self.d.info['displayWidth']
  245. if display_height / display_width != 16 / 9 and display_width / display_height != 16 / 9:
  246. raise RuntimeError("Does not support current mobile phones, The screen ratio is not 16:9")
  247. def get_scaling_ratio(self):
  248. """计算缩放比"""
  249. while True:
  250. if self.aircv_cache_image is not None:
  251. self.zoom_out = 1.0 * self.d.info['displayHeight'] / self.aircv_cache_image.shape[0]
  252. break
  253. def start_get_screen(self):
  254. def on_message(ws, message):
  255. this = self
  256. if isinstance(message, bytes):
  257. if int(time.time()) - this.__rcv_interva_time_cache >= Aircv.rcv_interval:
  258. # with open(this.aircv_cache_image_name, 'wb') as f:
  259. # f.write(message)
  260. # this.aircv_cache_image = this.cvHandler.imread(self.aircv_cache_image_name)
  261. this.aircv_cache_image = this.cvHandler.imdecode(message)
  262. this.__rcv_interva_time_cache = int(time.time())
  263. def on_error(ws, error):
  264. raise RuntimeError(error)
  265. def on_close(ws):
  266. print("### ws_screen closed ###")
  267. def on_open(ws):
  268. print("### ws_screen on_open ###")
  269. if not self.ws_screen or not self.ws_screen.keep_running:
  270. self.ws_screen = websocket.WebSocketApp("ws://" + self.d._host + ":" + str(self.d._port) + "/minicap",
  271. on_open=on_open,
  272. on_message=on_message,
  273. on_error=on_error,
  274. on_close=on_close)
  275. ws_thread = threading.Thread(target=self.ws_screen.run_forever)
  276. ws_thread.daemon = True
  277. ws_thread.start()
  278. def stop_get_scren(self):
  279. if self.ws_screen and self.ws_screen.keep_running:
  280. self.ws_screen.close()
    # ---- Operations: locate template images on screen and act on them ----
  282. def find_template_by_crop(self, img, area=None):
  283. if Aircv.support_network:
  284. img_url = "".join(["http://", Aircv.host, Aircv.path, img])
  285. data = requests.get(img_url)
  286. img_serch = self.cvHandler.imdecode(data.content)
  287. else:
  288. img_serch = self.cvHandler.imread(img)
  289. if area:
  290. crop_img = self.aircv_cache_image[area[1]:area[3], area[0]:area[2]]
  291. result = self.cvHandler.find_template(crop_img, img_serch)
  292. point = result['result'] if result else None
  293. if point:
  294. point = (point[0] + area[0], point[1] + area[1])
  295. else:
  296. crop_img = self.aircv_cache_image
  297. result = self.cvHandler.find_template(crop_img, img_serch)
  298. point = result['result'] if result else None
  299. return (int(point[0] * self.zoom_out), int(point[1] * self.zoom_out)) if point else None
  300. def exists(self, img, timeout=timeout, area=None):
  301. point = None
  302. is_exists = False
  303. while timeout:
  304. if self.debug:
  305. print(timeout)
  306. if self.aircv_cache_image is not None:
  307. point = self.find_template_by_crop(img, area)
  308. if point:
  309. is_exists = True
  310. break
  311. else:
  312. timeout -= 1
  313. time.sleep(1)
  314. return is_exists
  315. def click(self, img, timeout=timeout, area=None):
  316. point = None
  317. while timeout:
  318. if self.debug:
  319. print(timeout)
  320. if self.aircv_cache_image is not None:
  321. point = self.find_template_by_crop(img, area)
  322. if point:
  323. time.sleep(Aircv.wait_before_operation)
  324. self.d.click(point[0], point[1])
  325. break
  326. else:
  327. timeout -= 1
  328. time.sleep(1)
  329. if not timeout:
  330. raise RuntimeError('No image found')
  331. def click_index(self, img, index=1, maxcnt=20, timeout=timeout):
  332. point = None
  333. img_serch = self.cvHandler.imread(img)
  334. while timeout:
  335. if self.debug:
  336. print(timeout)
  337. if self.aircv_cache_image is not None:
  338. result_list = self.cvHandler.find_all_template(self.aircv_cache_image, img_serch, maxcnt=maxcnt)
  339. point = result_list[index - 1]['result'] if result_list else None
  340. if point:
  341. time.sleep(Aircv.wait_before_operation)
  342. self.d.click(point[0], point[1])
  343. break
  344. else:
  345. timeout -= 1
  346. time.sleep(1)
  347. if not timeout:
  348. raise RuntimeError('No image found')
  349. def long_click(self, img, duration=None, timeout=timeout, area=None):
  350. point = None
  351. while timeout:
  352. if self.debug:
  353. print(timeout)
  354. if self.aircv_cache_image is not None:
  355. point = self.find_template_by_crop(img, area)
  356. if point:
  357. time.sleep(Aircv.wait_before_operation)
  358. self.d.long_click(point[0], point[1], duration)
  359. break
  360. else:
  361. timeout -= 1
  362. time.sleep(1)
  363. if not timeout:
  364. raise RuntimeError('No image found')
  365. def swipe(self, img_from, img_to, duration=0.1, steps=None, timeout=timeout, area=None):
  366. point_from = None
  367. point_to = None
  368. while timeout:
  369. if self.debug:
  370. print(timeout)
  371. if self.aircv_cache_image is not None:
  372. point_from = self.find_template_by_crop(img_from, area)
  373. point_to = self.find_template_by_crop(img_to, area)
  374. if point_from and point_to:
  375. time.sleep(Aircv.wait_before_operation)
  376. self.d.swipe(point_from[0], point_from[1], point_to[0], point_to[1], duration, steps)
  377. break
  378. else:
  379. timeout -= 1
  380. time.sleep(1)
  381. if not timeout:
  382. raise RuntimeError('No image found')
  383. def swipe_points(self, img_list, duration=0.5, timeout=timeout, area=None):
  384. point_list = []
  385. while timeout:
  386. if self.debug:
  387. print(timeout)
  388. if self.aircv_cache_image is not None:
  389. for img in img_list:
  390. point = self.find_template_by_crop(img, area)
  391. if not point:
  392. break
  393. point_list.append(point)
  394. if len(point_list) == len(img_list):
  395. time.sleep(Aircv.wait_before_operation)
  396. self.d.swipe_points(point_list, duration)
  397. break
  398. else:
  399. timeout -= 1
  400. time.sleep(1)
  401. if not timeout:
  402. raise RuntimeError('No image found')
  403. def drag(self, img_from, img_to, duration=0.1, steps=None, timeout=timeout, area=None):
  404. point_from = None
  405. point_to = None
  406. while timeout:
  407. if self.debug:
  408. print(timeout)
  409. if self.aircv_cache_image is not None:
  410. point_from = self.find_template_by_crop(img_from, area)
  411. point_to = self.find_template_by_crop(img_to, area)
  412. if point_from and point_to:
  413. time.sleep(Aircv.wait_before_operation)
  414. self.d.drag(point_from[0], point_from[1], point_to[0], point_to[1], duration, steps)
  415. break
  416. else:
  417. timeout -= 1
  418. time.sleep(1)
  419. if not timeout:
  420. raise RuntimeError('No image found')
  421. def get_point(self, img, timeout=timeout, area=None):
  422. point = None
  423. while timeout:
  424. if self.debug:
  425. print(timeout)
  426. if self.aircv_cache_image is not None:
  427. point = self.find_template_by_crop(img, area)
  428. if point:
  429. break
  430. else:
  431. timeout -= 1
  432. time.sleep(1)
  433. if not timeout:
  434. raise RuntimeError('No image found')
  435. return point