123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- import cv2
- import numpy as np
- import time
- import requests
- import websocket
- import threading
__version__ = "0.0.1"  # module version string


class CVHandler(object):
    """Thin OpenCV wrapper for reading images and locating a sub-image in a larger one.

    Two strategies are offered: pixel template matching (``find_template`` /
    ``find_all_template``) and SIFT feature matching (``find_sift`` /
    ``find_all_sift``). ``find`` / ``find_all`` try template matching first
    and fall back to SIFT.
    """

    template_threshold = 0.95  # minimum template-matching score for a hit
    # FLANN index algorithm id (KD-tree) used by find_all_sift. It must live on
    # this class: previously it was only set on Aircv instances, so
    # find_all_sift raised AttributeError.
    FLANN_INDEX_KDTREE = 0

    def show(self, img):
        """Display ``img`` in a window named 'image' and block until a key is pressed."""
        cv2.imshow('image', img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    def imread(self, filename):
        """Like cv2.imread, but raise instead of silently returning None.

        Raises:
            RuntimeError: if the file cannot be read or decoded.
        """
        im = cv2.imread(filename)
        if im is None:
            raise RuntimeError("file: '%s' not exists" % filename)
        return im

    def imdecode(self, img_data):
        """Like cv2.imdecode: decode encoded image bytes (e.g. downloaded data) into a BGR image.

        Raises:
            RuntimeError: if the bytes cannot be decoded as an image.
        """
        im = np.asarray(bytearray(img_data), dtype="uint8")
        im = cv2.imdecode(im, cv2.IMREAD_COLOR)
        if im is None:
            raise RuntimeError("img_data is can not decode")
        return im

    def find_template(self, im_source, im_search, threshold=template_threshold, rgb=False, bgremove=False):
        """Return the best template match (a dict, see find_all_template), or None if not found."""
        result = self.find_all_template(im_source, im_search, threshold, 1, rgb, bgremove)
        return result[0] if result else None

    def find_all_template(self, im_source, im_search, threshold=template_threshold, maxcnt=0, rgb=False,
                          bgremove=False):
        """Locate occurrences of ``im_search`` in ``im_source`` with cv2.matchTemplate.

        Args:
            im_source: BGR image to search in.
            im_search: BGR template image to look for.
            threshold: matches scoring below this value are ignored.
            maxcnt: stop after this many matches (0 means no limit).
            rgb: match the B/G/R channels separately and combine the score
                maps with weights (0.3, 0.3, 0.4) instead of matching grayscale.
            bgremove: (grayscale mode only) match Canny edge maps, which
                reduces the influence of the background.

        Returns:
            A list of dicts, best match first::

                {'result': (cx, cy),                  # centre of the match
                 'rectangle': (tl, bl, tr, br),       # four corner points
                 'confidence': score}

            The list is empty when nothing reaches ``threshold``.
        """
        # TM_CCOEFF_NORMED: a higher score means a better match, so the best
        # location is always max_loc and the score is max_val.
        method = cv2.TM_CCOEFF_NORMED
        if rgb:
            s_bgr = cv2.split(im_search)
            i_bgr = cv2.split(im_source)
            weight = (0.3, 0.3, 0.4)
            resbgr = [cv2.matchTemplate(i_bgr[i], s_bgr[i], method) for i in range(3)]
            res = resbgr[0] * weight[0] + resbgr[1] * weight[1] + resbgr[2] * weight[2]
        else:
            s_gray = cv2.cvtColor(im_search, cv2.COLOR_BGR2GRAY)
            i_gray = cv2.cvtColor(im_source, cv2.COLOR_BGR2GRAY)
            if bgremove:
                # Edge extraction acts as a crude background removal.
                s_gray = cv2.Canny(s_gray, 100, 200)
                i_gray = cv2.Canny(i_gray, 100, 200)
            res = cv2.matchTemplate(i_gray, s_gray, method)

        h, w = im_search.shape[:2]
        result = []
        while True:
            _, max_val, _, max_loc = cv2.minMaxLoc(res)
            if max_val < threshold:
                break
            top_left = max_loc
            middle_point = (top_left[0] + w / 2, top_left[1] + h / 2)
            result.append(dict(
                result=middle_point,
                rectangle=(top_left, (top_left[0], top_left[1] + h), (top_left[0] + w, top_left[1]),
                           (top_left[0] + w, top_left[1] + h)),
                confidence=max_val
            ))
            if maxcnt and len(result) >= maxcnt:
                break
            # Flood-fill the peak just reported with a very low score so the
            # next minMaxLoc finds a different location.
            cv2.floodFill(res, None, max_loc, (-1000,), max_val - threshold + 0.1, 1,
                          flags=cv2.FLOODFILL_FIXED_RANGE)
        return result

    def _sift_instance(self, edge_threshold=100):
        """Create a SIFT detector on whichever OpenCV version is installed."""
        if hasattr(cv2, 'SIFT_create'):  # OpenCV >= 4.4: SIFT is in the main module
            return cv2.SIFT_create(edgeThreshold=edge_threshold)
        if hasattr(cv2, 'xfeatures2d'):  # OpenCV 3.x / older 4.x with contrib modules
            return cv2.xfeatures2d.SIFT_create(edgeThreshold=edge_threshold)
        return cv2.SIFT(edgeThreshold=edge_threshold)  # OpenCV 2.4 API

    def sift_count(self, img):
        """Return the number of SIFT keypoints detected in ``img``."""
        sift = self._sift_instance()
        kp, _ = sift.detectAndCompute(img, None)
        return len(kp)

    def find_sift(self, im_source, im_search, min_match_count=4):
        """Return the first SIFT match (a dict, see find_all_sift), or None."""
        res = self.find_all_sift(im_source, im_search, min_match_count, maxcnt=1)
        if not res:
            return None
        return res[0]

    @staticmethod
    def _drop_matched(keypoints, descriptors, used_indexes):
        """Return (keypoints, descriptors) without the entries whose index is in ``used_indexes``."""
        used = set(used_indexes)
        keep = [i for i in range(len(keypoints)) if i not in used]
        return [keypoints[i] for i in keep], descriptors[np.asarray(keep, dtype=np.intp)]

    def find_all_sift(self, im_source, im_search, min_match_count=4, maxcnt=0):
        """Find one or more occurrences of ``im_search`` in ``im_source`` via SIFT + homography.

        After each hit, the source keypoints used by that hit are removed so
        the next round can find another occurrence.

        Args:
            im_source: image to search in.
            im_search: image to look for.
            min_match_count: minimum number of keypoints / good matches required.
            maxcnt: stop after this many matches (0 means no limit).

        Returns:
            None if either image has fewer than ``min_match_count`` keypoints.
            Otherwise a (possibly empty) list of dicts::

                {'result': (cx, cy),          # centre of the projected box
                 'rectangle': [p0, p1, p2, p3],
                 'confidence': (inliers, good_matches)}
        """
        sift = self._sift_instance()
        flann = cv2.FlannBasedMatcher({'algorithm': self.FLANN_INDEX_KDTREE, 'trees': 5}, dict(checks=50))
        kp_sch, des_sch = sift.detectAndCompute(im_search, None)
        if len(kp_sch) < min_match_count:
            return None
        kp_src, des_src = sift.detectAndCompute(im_source, None)
        if len(kp_src) < min_match_count:
            return None

        h, w = im_search.shape[:2]
        # Corners of the search image, projected into the source each round.
        corners = np.float32([[0, 0], [0, h - 1], [w - 1, h - 1], [w - 1, 0]]).reshape(-1, 1, 2)
        result = []
        while len(kp_src) >= min_match_count:
            # k=2: take the two nearest source descriptors for each search descriptor.
            matches = flann.knnMatch(des_sch, des_src, k=2)
            # Ratio test: drop matches that are too close to the runner-up.
            # knnMatch may return fewer than 2 neighbours; those pairs are skipped.
            good = [pair[0] for pair in matches
                    if len(pair) == 2 and pair[0].distance < 0.9 * pair[1].distance]
            if len(good) < min_match_count:
                break
            sch_pts = np.float32([kp_sch[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
            img_pts = np.float32([kp_src[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
            # M is the search -> source perspective transform.
            M, mask = cv2.findHomography(sch_pts, img_pts, cv2.RANSAC, 5.0)
            if M is None or mask is None:  # no consistent homography could be estimated
                break
            matches_mask = mask.ravel().tolist()
            dst = cv2.perspectiveTransform(corners, M)
            pypts = [tuple(npt[0]) for npt in dst.astype(int).tolist()]
            lt, br = pypts[0], pypts[2]
            middle_point = (lt[0] + br[0]) / 2, (lt[1] + br[1]) / 2
            result.append(dict(
                result=middle_point,
                rectangle=pypts,
                confidence=(matches_mask.count(1), len(good))
            ))
            if maxcnt and len(result) >= maxcnt:
                break
            # Remove the source keypoints/descriptors consumed by this hit so
            # the next round can locate another occurrence.
            kp_src, des_src = self._drop_matched(kp_src, des_src, [m.trainIdx for m in good])
        return result

    def find_all(self, im_source, im_search, maxcnt=0):
        """Return the centres [(x, y), ...] of all matches: template matching first, then SIFT."""
        result = self.find_all_template(im_source, im_search, maxcnt=maxcnt)
        if not result:
            result = self.find_all_sift(im_source, im_search, maxcnt=maxcnt)
        if not result:
            return []
        return [match["result"] for match in result]

    def find(self, im_source, im_search):
        """Return the centre (x, y) of at most one match, or None."""
        r = self.find_all(im_source, im_search, maxcnt=1)
        return r[0] if r else None

    def brightness(self, im):
        """Return the average brightness of a BGR image.

        Args:
            im: BGR image (numpy array).

        Returns:
            float: the mean of the HSV V channel.
        """
        im_hsv = cv2.cvtColor(im, cv2.COLOR_BGR2HSV)
        v = cv2.split(im_hsv)[2]
        # Accumulate in float64. Summing uint8 rows with builtin sum() can
        # overflow in uint8 under NumPy 2 promotion rules.
        return float(v.mean(dtype=np.float64))
class Aircv(object):
    """Image-driven automation on top of a device client ``d``.

    Screen frames are streamed from the device's ``/minicap`` websocket into
    ``aircv_cache_image``. The action methods locate a template image in the
    latest frame and act on the matching device-screen coordinates.
    """

    timeout = 30  # default number of ~1 s polling attempts per action
    wait_before_operation = 1  # seconds to wait before performing an action
    rcv_interval = 2  # minimum seconds between two decoded frames
    support_network = False  # download template images over HTTP instead of reading local files
    url = ""
    host = "127.0.0.1:8000"
    path = "/image_service/download/"

    def __init__(self, d):
        """
        Args:
            d: device client. Must provide ``info`` (with 'displayHeight' and
               'displayWidth'), ``_host``, ``_port`` and the click/long_click/
               swipe/swipe_points/drag methods used below.
        """
        self.__rcv_interva_time_cache = 0
        self.d = d
        self.cvHandler = CVHandler()
        self.FLANN_INDEX_KDTREE = 0  # kept for backward compatibility
        self.debug = True
        self.aircv_cache_image = None
        self.ws_screen = None
        self.zoom_out = None
        # These three calls must stay last and in this order: validate the
        # screen, start streaming frames, then wait for the first frame to
        # compute the frame -> screen scaling ratio.
        self.detection_screen()
        self.start_get_screen()
        self.get_scaling_ratio()

    def detection_screen(self):
        """Check that the device screen is 16:9 (in either orientation).

        Raises:
            RuntimeError: if the aspect ratio is not 16:9.
        """
        info = self.d.info
        display_height = info['displayHeight']
        display_width = info['displayWidth']
        # Exact integer cross-multiplication instead of comparing quotients
        # (which floors under Python 2 and is float-based under Python 3).
        if display_height * 9 != display_width * 16 and display_width * 9 != display_height * 16:
            raise RuntimeError("Does not support current mobile phones, The screen ratio is not 16:9")

    def get_scaling_ratio(self, timeout=None):
        """Wait for the first streamed frame, then set ``zoom_out`` (screen px per frame px).

        Args:
            timeout: seconds to wait for a frame; defaults to ``Aircv.timeout``.

        Raises:
            RuntimeError: if no frame arrives in time.
        """
        limit = Aircv.timeout if timeout is None else timeout
        deadline = time.time() + limit
        # Poll with a short sleep instead of spinning a CPU core forever.
        while self.aircv_cache_image is None:
            if time.time() >= deadline:
                raise RuntimeError('No screen frame received from minicap')
            time.sleep(0.05)
        frame = self.aircv_cache_image
        self.zoom_out = 1.0 * self.d.info['displayHeight'] / frame.shape[0]

    def start_get_screen(self):
        """Start (if not already running) a daemon thread streaming minicap frames."""
        def on_message(ws, message):
            # Only binary messages are frames; decode at most one per rcv_interval seconds.
            if isinstance(message, bytes):
                if int(time.time()) - self.__rcv_interva_time_cache >= Aircv.rcv_interval:
                    self.aircv_cache_image = self.cvHandler.imdecode(message)
                    self.__rcv_interva_time_cache = int(time.time())

        def on_error(ws, error):
            # NOTE(review): websocket-client may catch and log exceptions raised
            # inside callbacks rather than propagate them -- confirm.
            raise RuntimeError(error)

        def on_close(ws, *args):
            # Newer websocket-client versions also pass (close_status_code, close_msg).
            print("### ws_screen closed ###")

        def on_open(ws):
            print("### ws_screen on_open ###")

        if not self.ws_screen or not self.ws_screen.keep_running:
            ws_url = "ws://" + self.d._host + ":" + str(self.d._port) + "/minicap"
            self.ws_screen = websocket.WebSocketApp(ws_url,
                                                    on_open=on_open,
                                                    on_message=on_message,
                                                    on_error=on_error,
                                                    on_close=on_close)
            ws_thread = threading.Thread(target=self.ws_screen.run_forever)
            ws_thread.daemon = True
            ws_thread.start()

    def stop_get_screen(self):
        """Close the frame-streaming websocket if it is running."""
        if self.ws_screen and self.ws_screen.keep_running:
            self.ws_screen.close()

    stop_get_scren = stop_get_screen  # backward-compatible alias for the original (misspelled) name

    # operating

    def _load_search_image(self, img):
        """Load template ``img`` from the image service (if support_network) or from local disk."""
        if Aircv.support_network:
            img_url = "".join(["http://", Aircv.host, Aircv.path, img])
            data = requests.get(img_url, timeout=Aircv.timeout)
            return self.cvHandler.imdecode(data.content)
        return self.cvHandler.imread(img)

    def _poll(self, timeout, locate):
        """Call ``locate()`` about once per second until it returns non-None.

        Each attempt (including attempts made while no frame has arrived yet)
        consumes one unit of ``timeout``.

        Returns:
            The first non-None result of ``locate()``, or None if ``timeout``
            ran out.
        """
        while timeout > 0:
            if self.debug:
                print(timeout)
            if self.aircv_cache_image is not None:
                found = locate()
                if found is not None:
                    return found
            timeout -= 1
            time.sleep(1)
        return None

    def find_template_by_crop(self, img, area=None):
        """Locate template ``img`` in the latest frame, optionally inside ``area``.

        Args:
            img: template path (local file, or a path on the image service
                when ``support_network`` is set).
            area: optional (x1, y1, x2, y2) crop box in frame pixel coordinates.

        Returns:
            (x, y) centre of the best match in device-screen coordinates, or None.
        """
        img_search = self._load_search_image(img)
        frame = self.aircv_cache_image
        if area:
            frame = frame[area[1]:area[3], area[0]:area[2]]
        result = self.cvHandler.find_template(frame, img_search)
        if not result:
            return None
        x, y = result['result']
        if area:
            # Translate from crop coordinates back to full-frame coordinates.
            x, y = x + area[0], y + area[1]
        return int(x * self.zoom_out), int(y * self.zoom_out)

    def exists(self, img, timeout=timeout, area=None):
        """Return True if ``img`` appears on screen within ``timeout`` attempts."""
        return self._poll(timeout, lambda: self.find_template_by_crop(img, area)) is not None

    def click(self, img, timeout=timeout, area=None):
        """Click the centre of ``img`` once it appears.

        Raises:
            RuntimeError: if it is not found in time.
        """
        point = self._poll(timeout, lambda: self.find_template_by_crop(img, area))
        if point is None:
            raise RuntimeError('No image found')
        time.sleep(Aircv.wait_before_operation)
        self.d.click(point[0], point[1])

    def click_index(self, img, index=1, maxcnt=20, timeout=timeout):
        """Click the ``index``-th (1-based, best first) of up to ``maxcnt`` matches of ``img``.

        Raises:
            RuntimeError: if fewer than ``index`` matches are found in time.
        """
        img_search = self._load_search_image(img)

        def locate():
            matches = self.cvHandler.find_all_template(self.aircv_cache_image, img_search, maxcnt=maxcnt)
            if not 1 <= index <= len(matches):
                return None
            x, y = matches[index - 1]['result']
            # Matches are in frame pixels; scale to device-screen coordinates.
            return int(x * self.zoom_out), int(y * self.zoom_out)

        point = self._poll(timeout, locate)
        if point is None:
            raise RuntimeError('No image found')
        time.sleep(Aircv.wait_before_operation)
        self.d.click(point[0], point[1])

    def long_click(self, img, duration=None, timeout=timeout, area=None):
        """Long-click the centre of ``img`` once it appears.

        Raises:
            RuntimeError: if it is not found in time.
        """
        point = self._poll(timeout, lambda: self.find_template_by_crop(img, area))
        if point is None:
            raise RuntimeError('No image found')
        time.sleep(Aircv.wait_before_operation)
        self.d.long_click(point[0], point[1], duration)

    def _locate_pair(self, img_from, img_to, area):
        """Return (point_from, point_to) if both templates are found, else None."""
        point_from = self.find_template_by_crop(img_from, area)
        point_to = self.find_template_by_crop(img_to, area)
        if point_from and point_to:
            return point_from, point_to
        return None

    def swipe(self, img_from, img_to, duration=0.1, steps=None, timeout=timeout, area=None):
        """Swipe from the centre of ``img_from`` to the centre of ``img_to``.

        Raises:
            RuntimeError: if both are not found in time.
        """
        pair = self._poll(timeout, lambda: self._locate_pair(img_from, img_to, area))
        if pair is None:
            raise RuntimeError('No image found')
        point_from, point_to = pair
        time.sleep(Aircv.wait_before_operation)
        self.d.swipe(point_from[0], point_from[1], point_to[0], point_to[1], duration, steps)

    def swipe_points(self, img_list, duration=0.5, timeout=timeout, area=None):
        """Swipe through the centres of all images in ``img_list``, in order.

        Raises:
            RuntimeError: if they are not all found in the same frame in time.
        """
        def locate():
            # A fresh list for every attempt: previously partial results
            # leaked between attempts and the length check could never pass.
            points = []
            for img in img_list:
                point = self.find_template_by_crop(img, area)
                if not point:
                    return None
                points.append(point)
            return points

        point_list = self._poll(timeout, locate)
        if point_list is None:
            raise RuntimeError('No image found')
        time.sleep(Aircv.wait_before_operation)
        self.d.swipe_points(point_list, duration)

    def drag(self, img_from, img_to, duration=0.1, steps=None, timeout=timeout, area=None):
        """Drag from the centre of ``img_from`` to the centre of ``img_to``.

        Raises:
            RuntimeError: if both are not found in time.
        """
        pair = self._poll(timeout, lambda: self._locate_pair(img_from, img_to, area))
        if pair is None:
            raise RuntimeError('No image found')
        point_from, point_to = pair
        time.sleep(Aircv.wait_before_operation)
        self.d.drag(point_from[0], point_from[1], point_to[0], point_to[1], duration, steps)

    def get_point(self, img, timeout=timeout, area=None):
        """Return the screen (x, y) centre of ``img`` once it appears.

        Raises:
            RuntimeError: if it is not found in time.
        """
        point = self._poll(timeout, lambda: self.find_template_by_crop(img, area))
        if point is None:
            raise RuntimeError('No image found')
        return point
|