孪生神经网络实现的点选验证码

孪生网络训练

孪生神经网络(Siamese network),该网络常常用于检测输入进来的两张图片的相似性。

之前我们已经使用过yolov5对目标识别坐标了,

具体查看http://kjol.cc/yolov5-win10-gpu.html

标注了100个样本做训练。体力活啊。

但是点选验证码还需要确定先点哪个、后点哪个。

我的思路是:对识别回来的0类目标,根据X坐标就可以判断先后关系。

网上的思路我感觉过于复杂,还得把大小图片分开训练。我感觉直接一股脑扔进去训练就完事了。

然后把识别出来的题目小图片,通过孪生神经网络跟2类对比相似度,接着排序好,返回具体要点击的3个坐标。

由于博主菜鸟一个,基本功不够扎实,加上时间有限,一时没弄出来,提前占坑。有空待续……

孪生网络样本分类

这2天空余时间人工分类,为后面的孪生网络做训练使用

使用源码:https://github.com/bubbliiiing/siamese-pytorch/tree/bilibili

最后经过一顿猛操作加一通乱改,

终于实现了:YOLOv5 识别目标坐标 + 孪生网络对比相似度,然后按顺序输出坐标。

目测成功率灰常高,比我用的打码平台都还高。毕竟自己亲自训练的

主要代码如下:

  1. import argparse
  2. import csv
  3. import os
  4. import platform
  5. import sys
  6. from pathlib import Path
  7. import numpy as np
  8. import torch
  9.  
  FILE = Path(__file__).resolve()
  ROOT = FILE.parent  # YOLOv5 root directory: the folder that holds this script
  if not any(entry == str(ROOT) for entry in sys.path):
      # Make the sibling packages (models/, utils/) importable
      sys.path.append(str(ROOT))
  # Re-express ROOT relative to the current working directory
  ROOT = Path(os.path.relpath(ROOT, Path.cwd()))
  15.  
  16. from ultralytics.utils.plotting import Annotator, colors, save_one_box
  17.  
  18. from models.common import DetectMultiBackend
  19. from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
  20. from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
  21. increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
  22. from utils.torch_utils import select_device, smart_inference_mode
  23.  
  24. from models.experimental import attempt_load
  25. import base64
  26.  
  27. from flask import Flask, request, Response,render_template
  28. import json
  29. import cv2
  30. import time
  31.  
  32. from PIL import Image
  app = Flask(__name__)

  # Folder where debug crops are written (relative to the current working directory)
  output_folder = 'runs/output'

  # Create the output folder if needed; exist_ok=True avoids the
  # check-then-create race of `if not os.path.exists(...): os.makedirs(...)`.
  os.makedirs(output_folder, exist_ok=True)

  # NOTE(review): UPLOAD_FOLDER is configured but not read anywhere in the visible code.
  UPLOAD_FOLDER = r'./uploads'
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
  def base64_to_image(base64_code):
      """Decode a base64-encoded image into an OpenCV (BGR, 3-channel) array.

      Args:
          base64_code: base64 payload as ``str`` or ``bytes``. An optional
              ``data:<mime>;base64,`` header (as produced by browsers) is
              stripped before decoding.

      Returns:
          The image decoded by ``cv2.imdecode`` with ``IMREAD_COLOR``, or
          ``None`` if the bytes are not a decodable image.
      """
      if isinstance(base64_code, str):
          base64_code = base64_code.encode('ascii')
      # Strip a data-URI header. b64decode would otherwise silently drop the
      # ':' ';' ',' characters and decode the header letters as image data.
      if base64_code.startswith(b'data:'):
          base64_code = base64_code.partition(b',')[2]
      img_data = base64.b64decode(base64_code)
      # np.fromstring is deprecated for binary input; frombuffer is the replacement.
      img_array = np.frombuffer(img_data, dtype=np.uint8)
      # The second imdecode argument is an IMREAD_* flag. The previous code passed
      # cv2.COLOR_RGB2BGR (a cvtColor code whose value 4 means IMREAD_ANYCOLOR),
      # which can return single-channel images. IMREAD_COLOR always yields BGR.
      return cv2.imdecode(img_array, cv2.IMREAD_COLOR)
  52.  
  53.  
  # Detections whose top edge is above this y (in the resized imgsz x imgsz image)
  # are treated as the prompt icons; everything else is a click candidate.
  _PROMPT_MAX_TOP_Y = 50


  def _prepare_input(source_path, imgsz):
      """Read *source_path* and build the detector input.

      Returns ``(rgb, im, orig_hw)``:
      - ``rgb``: the RGB image resized to ``imgsz x imgsz`` (all boxes and crops live in this space);
      - ``im``: the normalized ``1xCxHxW`` tensor fed to the model;
      - ``orig_hw``: the original image's ``(height, width)``.

      Raises:
          ValueError: if OpenCV cannot read the file.
      """
      bgr = cv2.imread(source_path)
      if bgr is None:
          # Explicit raise instead of assert: asserts are stripped under `python -O`.
          raise ValueError("Failed to load image at path {}".format(source_path))
      orig_hw = bgr.shape[:2]
      rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
      # Plain (non-letterboxed) resize to the square model input size
      rgb = cv2.resize(rgb, (imgsz, imgsz))
      # Match the model's parameter dtype. The old check `model.half` tested the
      # nn.Module.half *method*, which is always truthy.
      dtype = next(model.parameters()).dtype
      im = torch.from_numpy(rgb).to(device=device, dtype=dtype)
      im /= 255
      im = im.permute(2, 0, 1)[None]  # HWC -> 1xCxHxW
      return rgb, im, orig_hw


  def _split_detections(pred, im, rgb):
      """Split NMS output into (prompts, candidates) as ``((x1, y1), (x2, y2))`` boxes.

      Prompts are sorted left-to-right (their click order); candidates are
      sorted by their top edge.
      """
      prompts, candidates = [], []
      for det in pred:
          if not len(det):
              continue
          # im and rgb have the same size, so this only clips boxes to the image bounds
          det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], rgb.shape).round()
          for *xyxy, _conf, _cls in reversed(det):
              box = ((int(xyxy[0]), int(xyxy[1])), (int(xyxy[2]), int(xyxy[3])))
              if box[0][1] < _PROMPT_MAX_TOP_Y:
                  prompts.append(box)
              else:
                  candidates.append(box)
      prompts.sort(key=lambda b: b[0][0])
      candidates.sort(key=lambda b: b[0][1])
      return prompts, candidates


  def _best_match(prompt_img, page, candidates):
      """Return ``(box, similarity)`` of the candidate most similar to *prompt_img*.

      Returns ``(None, None)`` when *candidates* is empty.
      """
      best_box, best_sim = None, None
      for (x1, y1), (x2, y2) in candidates:
          candidate_img = page.crop((x1, y1, x2, y2))
          similarity = siamese_model.detect_image(prompt_img, candidate_img)
          if best_sim is None or similarity > best_sim:
              best_box, best_sim = ((x1, y1), (x2, y2)), similarity
      return best_box, best_sim


  def redirect(source_path):
      """Solve a click captcha image and return the click points in order.

      YOLOv5 detects every icon. Boxes whose top edge is above y=50 (in the
      resized image) are the prompt icons; the rest are click candidates.
      Prompts are ordered left to right, and each one is matched to its most
      similar candidate with the Siamese network. Debug crops are saved to
      ``output_folder`` as ``output_<k>.bmp`` / ``top_similar_<k>.bmp``.

      Relies on module globals created at startup: ``model``, ``device``,
      ``conf_thres``, ``iou_thres`` and ``siamese_model``.

      Args:
          source_path: path of the captcha image on disk.

      Returns:
          A JSON string holding a list of ``[x, y]`` click points, in click order.
          Each point is the centre of a matched candidate, mapped back to the
          original image's pixel coordinates. Prompts with no candidate are
          skipped. (Previously this always returned ``''``.)

      Raises:
          ValueError: if the image cannot be read.
      """
      stride = int(model.stride.max())
      imgsz = check_img_size(640, s=stride)
      rgb, im, (orig_h, orig_w) = _prepare_input(source_path, imgsz)

      # Inference only: no_grad avoids building an autograd graph on every request.
      with torch.no_grad():
          pred = model(im, augment=False, visualize=False)
      pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=1000)
      prompts, candidates = _split_detections(pred, im, rgb)

      # Boxes live in the imgsz x imgsz resized image; these factors map back to the original.
      scale_x, scale_y = orig_w / imgsz, orig_h / imgsz
      page = Image.fromarray(rgb)  # build the PIL image once, not once per crop
      clicks = []
      for k, ((x1, y1), (x2, y2)) in enumerate(prompts, start=1):
          prompt_img = page.crop((x1, y1, x2, y2))
          prompt_img.save(os.path.join(output_folder, f'output_{k}.bmp'))
          # NOTE(review): the same candidate may be chosen for two prompts.
          best, _similarity = _best_match(prompt_img, page, candidates)
          if best is None:
              print(f"No candidate found for prompt {k}")
              continue
          (bx1, by1), (bx2, by2) = best
          print(f"坐标: ({bx1}, {by1}) - ({bx2}, {by2})")
          page.crop((bx1, by1, bx2, by2)).save(os.path.join(output_folder, f'top_similar_{k}.bmp'))
          clicks.append([round((bx1 + bx2) / 2 * scale_x), round((by1 + by2) / 2 * scale_y)])
      return json.dumps(clicks)
  129.  
  130. from siamese import Siamese
  131.  
  # Upload extensions accepted by allowed_file (image formats cv2.imread can read).
  # Previously this name was referenced but never defined, so any call raised NameError.
  ALLOWED_EXTENSIONS = {'bmp', 'jpeg', 'jpg', 'png'}


  def allowed_file(filename):
      """Return True if *filename* ends in an extension from ``ALLOWED_EXTENSIONS``.

      The comparison is case-insensitive, so ``captcha.PNG`` is accepted.
      """
      _, dot, ext = filename.rpartition('.')
      return bool(dot) and ext.lower() in ALLOWED_EXTENSIONS
  134.  
  135.  
  136. import tempfile
  @app.route('/detect', methods=['POST'])
  def detect():
      """Run the captcha solver on the uploaded multipart field ``file``.

      Returns the result of ``redirect`` for the uploaded image, or
      ``('No file uploaded.', 400)`` when the field is missing or empty.
      """
      # .get() so a missing field yields our 400 message instead of a KeyError
      file = request.files.get('file')
      if not file:
          return 'No file uploaded.', 400
      # cv2.imread needs a path, so spool the upload to a temporary file first.
      with tempfile.NamedTemporaryFile(delete=False) as tmp:
          tmp.write(file.read())
          tmp_path = tmp.name
      try:
          return redirect(tmp_path)
      finally:
          # Delete the temp file even when detection raises (previously leaked).
          os.remove(tmp_path)
  if __name__ == "__main__":
      # NOTE(review): redirect() reads these globals, but they are only created
      # here, so the service works only when this file is run as a script
      # (not under `flask run` or a WSGI server).
      device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
      model = attempt_load(ROOT / 'sx.pt')
      model.to(device).eval()
      # NOTE(review): FP16 is applied on CPU too; confirm your torch build supports it.
      model.half()
      conf_thres = 0.65  # NMS confidence threshold
      iou_thres = 0.45  # NMS IoU threshold
      siamese_model = Siamese()
      # Never enable the Werkzeug debugger on a publicly bound host: its interactive
      # console allows arbitrary code execution, and its reloader loads the models
      # twice. Opt in with FLASK_DEBUG=1 for local development only.
      app.run(host='0.0.0.0', port=5002, debug=os.environ.get('FLASK_DEBUG') == '1')

代码搓得难看,各位大神看看就好。

还有,最后图片坐标需要归一化处理:检测是在缩放到 640×640 的图片上做的,得到的坐标要按原图尺寸换算回去,不然坐标是不准的。

参考链接

https://www.sohu.com/a/751469426_120818776

https://zhuanlan.zhihu.com/p/678161348

感谢工具:https://kimi.moonshot.cn/

发表评论

邮箱地址不会被公开。 必填项已用*标注