vlm_agent/utils_vlm.py

171 lines
6.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# utils_vlm.py
# Author: 同济子豪兄, 2024-5-22
# Vision-language model (VLM) API call, plus post-processing and visualization
# of the model's output.
# Printed once, when this module is imported ("importing the VLM module").
print('导入视觉大模型模块')
import ast
import json
import os
import time

import cv2
import numpy as np
from PIL import Image
from PIL import ImageFont, ImageDraw
# Chinese-capable TrueType font (SimHei, size 26) used to write object names on
# the visualization. It is loaded once, at import time; the relative path is
# resolved against the current working directory.
font = ImageFont.truetype(font='asset/SimHei.ttf', size=26)
from API_KEY import *
# System prompt; yi_vision_api() sends SYSTEM_PROMPT + <user instruction> as
# the text part of its request. It asks the model to:
#   - pick the start and end objects named in the instruction,
#   - give each object's top-left / bottom-right corners,
#   - estimate gripper heights (never below 0.2 m, the gripper length),
#   - reply with the bare JSON object only, in the format of the example.
# NOTE(review): the prompt asks for "pixel coordinates", but
# post_processing_viz() rescales every coordinate by image_size / 999, i.e.
# treats them as normalized to [0, 999] -- confirm which convention the model
# actually follows.
SYSTEM_PROMPT = '''
我即将说一句给机械臂的指令你帮我从这句话中提取出起始物体和终止物体并从这张图中分别找到这两个物体左上角和右下角的像素坐标输出json数据结构。
接着根据物体的样式估计夹爪应该到达的起始高度和终止高度默认为0.2m,也就是夹爪的长度。
你可以调整移动的开始高度HEIGHT_START和结束高度HEIGHT_END以抓取不同高度的物品默认是0.2m
因此任何高度不应该小于0.2m,否则夹爪会与平面发生碰撞
例如,如果我的指令是:请帮我把红色方块放纸巾上。
估计红色方块本身高度为0.05m所以开始高度设为0.2m估计纸巾有10cm高所以结束高度设为0.3m
你输出这样的格式:
{
"start": "红色方块",
"start_xyxy": [[102, 505], [324, 860]],
"start_height": 0.2,
"end": "纸巾",
"end_xyxy": [[300, 150], [476, 310]],
"end_height": 0.3
}
只回复json本身即可不要回复其它内容
我现在的指令是:
'''
# Yi-Vision调用函数
import openai
from openai import OpenAI
import base64
def yi_vision_api(PROMPT='帮我把红色方块放在钢笔上', img_path='temp/vl_now.jpg'):
    '''
    Query the yi-vision multimodal model on the 01.AI (零一万物) open platform.

    SYSTEM_PROMPT + PROMPT is sent together with the image through the
    OpenAI-compatible chat-completions endpoint, and the reply is parsed
    into a dict.

    Args:
        PROMPT: natural-language instruction for the robot arm.
        img_path: path of the image to send.

    Returns:
        dict parsed from the model's reply; per SYSTEM_PROMPT it should hold
        start, start_xyxy, start_height, end, end_xyxy and end_height.

    Raises:
        ValueError: if no dict can be parsed from the reply.
    '''
    client = OpenAI(
        api_key=YI_KEY,
        base_url="https://api.lingyiwanwu.com/v1"
    )
    # Embed the image as a base64 data URL (always labelled image/jpeg,
    # whatever the file's real format)
    with open(img_path, 'rb') as image_file:
        image = 'data:image/jpeg;base64,' + base64.b64encode(image_file.read()).decode('utf-8')
    # Send the instruction and the image in a single user message
    completion = client.chat.completions.create(
        model="yi-vision",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": SYSTEM_PROMPT + PROMPT},
                    {"type": "image_url", "image_url": {"url": image}},
                ]
            },
        ]
    )
    # The reply is remote, untrusted text: parse it as data, never eval() it
    result = _parse_vlm_json(completion.choices[0].message.content)
    print(' 大模型调用成功!')
    return result


def _parse_vlm_json(text):
    '''
    Parse the model's reply into a dict without executing it.

    Only the outermost {...} span is parsed, so Markdown code fences or
    commentary around the JSON are tolerated. Strict JSON is tried first;
    Python-literal syntax (e.g. single-quoted strings) is accepted as a
    fallback via ast.literal_eval.

    Raises:
        ValueError: if the reply is empty or does not contain a dict.
    '''
    # message.content can be None, e.g. when the model returns no text
    text = (text or '').strip()
    # Models often wrap the JSON in ```json fences despite being told not to
    left, right = text.find('{'), text.rfind('}')
    if left != -1 and right > left:
        text = text[left:right + 1]
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError) as err:
            raise ValueError('Cannot parse model reply: {!r}'.format(text)) from err
    if not isinstance(parsed, dict):
        raise ValueError('Model reply is not a JSON object: {!r}'.format(text))
    return parsed
def post_processing_viz(result, img_path, check=False, factor=999):
    '''
    Post-process the VLM output and visualize it.

    The start/end boxes in `result` are rescaled from the model's [0, factor]
    coordinate range to pixels of the image at `img_path`. The boxes, their
    centre points and the object names are drawn, and the drawing is saved to
    temp/vl_now_viz.jpg and visualizations/<YYYYmmddHHMM>.jpg. It is also
    shown in the OpenCV window 'vlm_agent'.

    Args:
        result: dict from yi_vision_api() with keys start, end, start_xyxy,
            end_xyxy ([[x1, y1], [x2, y2]]) and optionally start_height /
            end_height (default 0.2).
        img_path: path of the image the model looked at.
        check: if True, block until the user presses c (continue) or
            q (quit) in the OpenCV window.
        factor: full-scale value of the model's coordinates (default 999).

    Returns:
        (start_x_center, start_y_center, height_start,
         end_x_center, end_y_center, height_end): pixel centres of the start
        and end boxes and the gripper heights.

    Raises:
        FileNotFoundError: if img_path cannot be read as an image.
        SystemExit: if check is True and the user presses q.
    '''
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        # cv2.imread signals a missing/unreadable file by returning None
        raise FileNotFoundError('Cannot read image: {}'.format(img_path))
    img_h, img_w = img_bgr.shape[:2]

    start_name = result['start']
    end_name = result['end']
    start_box = _scale_box(result['start_xyxy'], img_w, img_h, factor)
    end_box = _scale_box(result['end_xyxy'], img_w, img_h, factor)
    # Gripper heights in metres; 0.2 m when the model omits them
    height_start = result.get('start_height', 0.2)
    height_end = result.get('end_height', 0.2)

    # Boxes and centre points: start object in red, end object in blue (BGR)
    for (x1, y1, x2, y2, xc, yc), color in ((start_box, [0, 0, 255]), (end_box, [255, 0, 0])):
        img_bgr = cv2.rectangle(img_bgr, (x1, y1), (x2, y2), color, thickness=3)
        img_bgr = cv2.circle(img_bgr, (xc, yc), 6, color, thickness=-1)

    # Names may be Chinese, so they are drawn with PIL and the SimHei `font`
    img_pil = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(img_pil)
    # Labels sit 32 px above each box; clamp so boxes near the top edge
    # still get a visible label
    draw.text((start_box[0], max(0, start_box[1] - 32)), start_name, font=font, fill=(255, 0, 0, 1))  # red (RGB)
    draw.text((end_box[0], max(0, end_box[1] - 32)), end_name, font=font, fill=(0, 0, 255, 1))  # blue (RGB)
    img_bgr = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)

    # Save the visualization; cv2.imwrite does not create missing folders
    formatted_time = time.strftime("%Y%m%d%H%M", time.localtime())
    for out_path in ('temp/vl_now_viz.jpg', 'visualizations/{}.jpg'.format(formatted_time)):
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        if not cv2.imwrite(out_path, img_bgr):
            print('    Warning: failed to save visualization to {}'.format(out_path))

    cv2.imshow('vlm_agent', img_bgr)
    if check:
        print(' 请确认可视化成功按c键继续按q键退出')
        while True:
            key = cv2.waitKey(10) & 0xFF
            if key == ord('c'):  # continue
                break
            if key == ord('q'):  # quit
                cv2.destroyAllWindows()
                # Must leave the loop: with no window open, waitKey returns
                # immediately and the loop would spin forever
                raise SystemExit('按q退出')
    else:
        # Give HighGUI a moment to process events so the window is painted
        cv2.waitKey(1)

    start_x_center, start_y_center = start_box[4:]
    end_x_center, end_y_center = end_box[4:]
    return start_x_center, start_y_center, height_start, end_x_center, end_y_center, height_end


def _scale_box(xyxy, img_w, img_h, factor):
    '''
    Map a [[x1, y1], [x2, y2]] box from the model's [0, factor] range to pixels.

    Returns (x1, y1, x2, y2, x_center, y_center) as ints.
    '''
    x1 = int(xyxy[0][0] * img_w / factor)
    y1 = int(xyxy[0][1] * img_h / factor)
    x2 = int(xyxy[1][0] * img_w / factor)
    y2 = int(xyxy[1][1] * img_h / factor)
    return x1, y1, x2, y2, int((x1 + x2) / 2), int((y1 + y2) / 2)