# Listing metadata captured with this file: 171 lines, 6.0 KiB, Python
# utils_vlm.py
# 同济子豪兄 2024-5-22
# Multimodal large model (vision-language) calls and visualization

# Announce at import time that the vision-language module is being loaded.
print('导入视觉大模型模块')
import time
import cv2
import numpy as np
from PIL import Image
from PIL import ImageFont, ImageDraw
# Load the Chinese font (SimHei) once, at the size used for the object labels.
# The path is relative, so it is resolved against the current working directory.
font = ImageFont.truetype('asset/SimHei.ttf', 26)

# NOTE(review): wildcard import; presumably this is where YI_KEY comes from — confirm.
from API_KEY import *

# System prompt sent to the vision-language model ahead of the user's instruction.
# It asks the model to name the start and end objects, give each one's top-left and
# bottom-right pixel coordinates, estimate the gripper start/end heights (never below
# the 0.2 m default), and reply with the JSON object only.
# The user's instruction is concatenated directly after the last line.
SYSTEM_PROMPT = '''
我即将说一句给机械臂的指令,你帮我从这句话中提取出起始物体和终止物体,并从这张图中分别找到这两个物体左上角和右下角的像素坐标,输出json数据结构。

接着,根据物体的样式估计夹爪应该到达的起始高度和终止高度,默认为0.2m,也就是夹爪的长度。

你可以调整移动的开始高度HEIGHT_START和结束高度HEIGHT_END以抓取不同高度的物品,默认是0.2m
因此任何高度不应该小于0.2m,否则夹爪会与平面发生碰撞

例如,如果我的指令是:请帮我把红色方块放纸巾上。
估计红色方块本身高度为0.05m,所以开始高度设为0.2m;估计纸巾有10cm高,所以结束高度设为0.3m
你输出这样的格式:
{
"start": "红色方块",
"start_xyxy": [[102, 505], [324, 860]],
"start_height": 0.2,
"end": "纸巾",
"end_xyxy": [[300, 150], [476, 310]],
"end_height": 0.3
}

只回复json本身即可,不要回复其它内容

我现在的指令是:
'''

# Dependencies of the Yi-Vision call below
import openai
from openai import OpenAI
import base64

def _parse_vlm_reply(reply_text):
    '''
    Parse the model's text reply into a dict without executing it.

    reply_text: raw message content returned by the model (may be None)
    Returns the decoded dict.
    Raises ValueError when the reply holds no parseable JSON/dict literal.
    '''
    import ast
    import json

    text = (reply_text or '').strip()
    # Replies are sometimes wrapped in a markdown fence or a sentence of prose;
    # keep only the outermost {...} span when there is one.
    first, last = text.find('{'), text.rfind('}')
    if first != -1 and last > first:
        text = text[first:last + 1]

    try:
        parsed = json.loads(text)
    except ValueError:
        # Not strict JSON (e.g. single-quoted keys): accept a plain Python
        # literal, but never evaluate arbitrary expressions.
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError) as err:
            raise ValueError(
                'Cannot parse model reply as JSON: {!r}'.format(reply_text)
            ) from err

    if not isinstance(parsed, dict):
        raise ValueError(
            'Model reply is not a JSON object: {!r}'.format(reply_text)
        )
    return parsed


def yi_vision_api(PROMPT='帮我把红色方块放在钢笔上', img_path='temp/vl_now.jpg'):
    '''
    01.AI (零一万物) open platform: yi-vision vision-language multimodal model API.

    PROMPT: the user's instruction; it is appended to SYSTEM_PROMPT
    img_path: path of the image sent to the model (sent with a JPEG data-URL prefix)
    Returns the model's answer decoded into a dict.
    Raises ValueError when the reply cannot be decoded; errors from reading the
    image or from the API client propagate unchanged.

    SECURITY: the reply is external, untrusted text. It used to be passed to
    eval(); it is now decoded with json / ast.literal_eval only.
    '''
    client = OpenAI(
        api_key=YI_KEY,
        base_url="https://api.lingyiwanwu.com/v1"
    )

    # Encode the image as a base64 data URL
    with open(img_path, 'rb') as image_file:
        encoded = base64.b64encode(image_file.read()).decode('utf-8')
    image = 'data:image/jpeg;base64,' + encoded

    # Send the request: one user message carrying the text prompt and the image
    completion = client.chat.completions.create(
        model="yi-vision",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": SYSTEM_PROMPT + PROMPT
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": image
                        }
                    }
                ]
            },
        ]
    )

    # Decode the model's reply
    result = _parse_vlm_reply(completion.choices[0].message.content)
    print(' 大模型调用成功!')

    return result

def _scale_box(box, img_w, img_h, factor=999):
    '''
    Convert one bounding box from model coordinates to image pixels.

    box: [[x_min, y_min], [x_max, y_max]] as found in the model result
    img_w, img_h: image width and height in pixels
    factor: divisor applied to the model coordinates before scaling
            (presumably the model reports positions on a 0-999 grid — confirm)
    Returns (x_min, y_min, x_max, y_max, x_center, y_center) as ints.
    '''
    x_min = int(box[0][0] * img_w / factor)
    y_min = int(box[0][1] * img_h / factor)
    x_max = int(box[1][0] * img_w / factor)
    y_max = int(box[1][1] * img_h / factor)
    x_center = int((x_min + x_max) / 2)
    y_center = int((y_min + y_max) / 2)
    return x_min, y_min, x_max, y_max, x_center, y_center


def post_processing_viz(result, img_path, check=False):
    '''
    Post-process the vision model output and visualize it.

    result: dict with 'start', 'end', 'start_xyxy', 'end_xyxy' and, optionally,
            'start_height' / 'end_height' (each defaults to 0.2)
    img_path: path of the image the boxes refer to
    check: when True, wait for the operator to confirm the visualization on
           screen: 'c' continues, 'q' aborts
    Returns (START_X_CENTER, START_Y_CENTER, HEIGHT_START,
             END_X_CENTER, END_Y_CENTER, HEIGHT_END): box centres in pixels and
    the heights taken from result.
    Raises FileNotFoundError when the image cannot be read, RuntimeError when
    the operator presses 'q', KeyError when a required result key is missing.

    Side effects: writes temp/vl_now_viz.jpg and visualizations/<timestamp>.jpg
    and shows the annotated image in an OpenCV window named 'vlm_agent'.
    '''
    import os

    # Post-processing
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        # cv2.imread signals failure by returning None instead of raising
        raise FileNotFoundError(
            'Cannot read image (missing or not decodable): {}'.format(img_path)
        )
    img_h, img_w = img_bgr.shape[:2]

    start_name = result['start']
    end_name = result['end']
    (start_x_min, start_y_min, start_x_max, start_y_max,
     start_x_center, start_y_center) = _scale_box(result['start_xyxy'], img_w, img_h)
    height_start = result.get('start_height', 0.2)
    (end_x_min, end_y_min, end_x_max, end_y_max,
     end_x_center, end_y_center) = _scale_box(result['end_xyxy'], img_w, img_h)
    height_end = result.get('end_height', 0.2)

    # Visualization: start object in red, end object in blue (OpenCV uses BGR)
    red_bgr = (0, 0, 255)
    blue_bgr = (255, 0, 0)
    img_bgr = cv2.rectangle(img_bgr, (start_x_min, start_y_min), (start_x_max, start_y_max), red_bgr, thickness=3)
    img_bgr = cv2.circle(img_bgr, (start_x_center, start_y_center), 6, red_bgr, thickness=-1)
    img_bgr = cv2.rectangle(img_bgr, (end_x_min, end_y_min), (end_x_max, end_y_max), blue_bgr, thickness=3)
    img_bgr = cv2.circle(img_bgr, (end_x_center, end_y_center), 6, blue_bgr, thickness=-1)

    # The object names are drawn through PIL with the module-level font,
    # so the image goes BGR -> RGB -> PIL and back.
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img_pil = Image.fromarray(img_rgb)
    draw = ImageDraw.Draw(img_pil)
    # Labels sit 32 px above the box; clamp so a box at the top edge keeps its label
    draw.text((start_x_min, max(start_y_min - 32, 0)), start_name, font=font, fill=(255, 0, 0, 1))
    draw.text((end_x_min, max(end_y_min - 32, 0)), end_name, font=font, fill=(0, 0, 255, 1))
    img_bgr = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)

    # Save the annotated image; cv2.imwrite fails silently if the folder is missing
    os.makedirs('temp', exist_ok=True)
    cv2.imwrite('temp/vl_now_viz.jpg', img_bgr)

    os.makedirs('visualizations', exist_ok=True)
    formatted_time = time.strftime("%Y%m%d%H%M", time.localtime())
    cv2.imwrite('visualizations/{}.jpg'.format(formatted_time), img_bgr)

    # Show the annotated image on screen
    cv2.imshow('vlm_agent', img_bgr)

    if check:
        print(' 请确认可视化成功,按c键继续,按q键退出')
        while True:
            key = cv2.waitKey(10) & 0xFF
            if key == ord('c'):  # 'c': continue
                break
            if key == ord('q'):  # 'q': abort
                cv2.destroyAllWindows()
                # Without leaving the loop here the function spins forever
                # with no window left to receive a key press.
                raise RuntimeError('按q退出')
    else:
        # One short waitKey lets the HighGUI event loop draw the window
        cv2.waitKey(1)

    return start_x_center, start_y_center, height_start, end_x_center, end_y_center, height_end