Source code for Perception.Grasp_Pose_Estimation.AnyGrasp.Anygrasp

# !/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
# @FileName       : Anygrasp.py
# @Time           : 2024-08-03 15:07:35
# @Author         : yk
# @Email          : yangkui1127@gmail.com
# @Description    : AnyGrasp: Grasp pose estimation algorithm
"""

import os
import sys

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import argparse
import copy
import math

import cv2
import matplotlib.pyplot as plt
import numpy as np
from graspnetAPI import GraspGroup
from gsnet import AnyGrasp
from PIL import Image

from Perception.Object_detection import Lang_SAM
from RoboticsToolBox import Pose
from Visualization import Camera

from .utils import Bbox, draw_rectangle, sample_points, visualize_cloud_geometries


class Anygrasp:
    """Grasp pose estimation built on top of the AnyGrasp model.

    Attributes:
        cfgs: Configuration namespace. ``output_dir`` and ``checkpoint_path``
            are rewritten in ``__init__`` to be relative to this module's
            directory.
        grasping_model: ``AnyGrasp`` instance with its network loaded.
    """

    def __init__(self, cfgs):
        """Resolve paths against this module's directory and load the model.

        Args:
            cfgs: Configuration namespace. It is modified in place:
                ``output_dir`` and ``checkpoint_path`` are joined onto the
                directory that contains this file (absolute paths are left
                unchanged by ``os.path.join``).
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))

        self.cfgs = cfgs
        self.cfgs.output_dir = os.path.join(module_dir, self.cfgs.output_dir)
        self.cfgs.checkpoint_path = os.path.join(
            module_dir, self.cfgs.checkpoint_path
        )
        # Images are written into output_dir later on; cv2.imwrite does not
        # create missing directories, so make sure the folder exists.
        os.makedirs(self.cfgs.output_dir, exist_ok=True)

        self.grasping_model = AnyGrasp(self.cfgs)
        self.grasping_model.load_net()

    @staticmethod
    def _project_to_pixel(camera, point):
        """Project a 3D point onto the image plane, clamped to the image.

        Args:
            camera: Object exposing ``fx``, ``fy``, ``cx``, ``cy``, ``width``
                and ``height``.
            point: Indexable ``(x, y, z)`` coordinates.

        Returns:
            ``(ix, iy)`` integer pixel coordinates clamped to
            ``[0, width - 1]`` x ``[0, height - 1]``, or ``None`` when the
            point cannot be projected (zero or non-finite depth).
        """
        depth = point[2]
        if depth == 0 or not math.isfinite(depth):
            return None

        u = ((point[0] * camera.fx) / depth) + camera.cx
        # The y coordinate is negated before projection, as in the original
        # formula.
        v = ((-point[1] * camera.fy) / depth) + camera.cy
        if not (math.isfinite(u) and math.isfinite(v)):
            return None

        ix = max(0, min(camera.width - 1, int(u)))
        iy = max(0, min(camera.height - 1, int(v)))
        return ix, iy

    def Grasp_Pose_Estimation(
        self,
        camera: "Camera",
        seg_mask: np.ndarray,
        bbox: "Bbox",
        crop_flag: bool = False,
    ):
        """Estimate the best grasp pose for the segmented object.

        Args:
            camera: Camera object. Must provide ``sim_get_3d_points()``
                (returning points and colors), ``image``, ``width``,
                ``height``, ``fx``, ``fy``, ``cx`` and ``cy``.
            seg_mask: Object mask, indexed as ``seg_mask[row, col]``; a truthy
                value marks a pixel that belongs to the object.
            bbox: Object bounding box, drawn on the projection image.
            crop_flag: If True, the point cloud is assumed to be already
                cropped to the object and every predicted grasp is kept
                without checking it against ``seg_mask``. Defaults to False.

        Returns:
            A ``Pose`` built from the translation and rotation matrix of the
            top-ranked grasp (after NMS and score sorting), or ``None`` when
            no usable grasp is found.
        """
        # get 3d points
        points, colors = camera.sim_get_3d_points()

        # If the sampling rate is less than 1, the point cloud is downsampled
        if self.cfgs.sampling_rate < 1:
            points, indices = sample_points(points, self.cfgs.sampling_rate)
            colors = colors[indices]

        if len(points) == 0:
            # min()/max() below would raise on an empty cloud.
            print(
                "[AnyGrasp] \033[33mwarning\033[0m: Empty point cloud, no grasp can be estimated!"
            )
            return None

        # Workspace limits: axis-aligned bounds of the point cloud.
        lims = [
            points[:, 0].min(),
            points[:, 0].max(),
            points[:, 1].min(),
            points[:, 1].max(),
            points[:, 2].min(),
            points[:, 2].max(),
        ]

        # Grasp prediction, return grasp group and point cloud
        gg, cloud = self.grasping_model.get_grasp(
            points,
            colors,
            lims,
            apply_object_mask=True,
            dense_grasp=False,
            collision_detection=True,
        )

        # Flip the z axis; applied to the cloud and to the gripper meshes that
        # are rendered in debug mode.
        trans_mat = np.array(
            [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
        )
        cloud.transform(trans_mat)

        if len(gg) == 0:
            print(
                "[AnyGrasp] \033[33mwarning\033[0m: No Grasp detected after collision detection!"
            )
            return None

        # The grasped groups are subjected to non-maximum suppression (NMS)
        # and sorted by scores.
        gg = gg.nms().sort_by_score()
        print(
            "[AnyGrasp] \033[34mInfo\033[0m: Grasp point number of all objects:",
            len(gg),
        )

        if self.cfgs.debug:
            grippers = gg.to_open3d_geometry_list()
            for gripper in grippers:
                gripper.transform(trans_mat)
            visualize_cloud_geometries(
                cloud,
                grippers,
                save_file=os.path.join(self.cfgs.output_dir, "poses.png"),
            )

        filter_gg = GraspGroup()

        # visualize the grip points associated with the given object
        image = copy.deepcopy(camera.image)
        img_drw = draw_rectangle(image, bbox)

        # Keep only the grasps whose centre projects inside the object mask
        # (or all of them when the cloud was already cropped to the object).
        for g in gg:
            if crop_flag:
                filter_gg.add(g)
                continue

            pixel = self._project_to_pixel(camera, g.translation)
            if pixel is None:
                # Degenerate grasp centre (zero / non-finite depth): it cannot
                # be tested against the mask, so it is dropped.
                continue

            ix, iy = pixel
            marker = [(ix - 2, iy - 2), (ix + 2, iy + 2)]
            if seg_mask[iy, ix]:
                # Grasp point lies within the object area.
                img_drw.ellipse(marker, fill="green")
                filter_gg.add(g)
            else:
                img_drw.ellipse(marker, fill="red")

        if len(filter_gg) == 0:
            print(
                "[AnyGrasp] \033[33mwarning\033[0m: No grasp poses detected for this object try to move the object a little and try again"
            )
            return None

        # show and save grasp projections result
        projections_file_name = os.path.join(
            self.cfgs.output_dir, "grasp_projections.png"
        )
        if self.cfgs.debug:
            plt.imshow(image)
            plt.title("visualize grasp projections")
            plt.axis("off")
            plt.show()

        # NOTE(review): the flattened source does not show whether the save
        # below belonged to the debug branch; it is done unconditionally here
        # -- confirm against the upstream file.
        # Channel order is swapped before handing the array to OpenCV
        # (presumably `image` is an RGB PIL image -- TODO confirm).
        image = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
        if cv2.imwrite(projections_file_name, image):
            print(
                f"[AnyGrasp] \033[34mInfo\033[0m: Saved projections of grasps at {projections_file_name}"
            )
        else:
            print(
                f"[AnyGrasp] \033[33mwarning\033[0m: Failed to save projections of grasps at {projections_file_name}"
            )

        filter_gg = filter_gg.nms().sort_by_score()
        print(
            "[AnyGrasp] \033[34mInfo\033[0m: Filter grasp point number of grasp object:",
            len(filter_gg),
        )
        self.print_filter_gg(filter_gg)

        if self.cfgs.debug:
            filter_grippers = filter_gg.to_open3d_geometry_list()
            for gripper in filter_grippers:
                gripper.transform(trans_mat)
            # Only the top-ranked gripper is rendered, painted red.
            visualize_cloud_geometries(
                cloud,
                [filter_grippers[0].paint_uniform_color([1.0, 0.0, 0.0])],
                save_file=os.path.join(self.cfgs.output_dir, "best_pose.png"),
            )

        best_pose = Pose(filter_gg[0].translation, filter_gg[0].rotation_matrix)
        return best_pose

    def print_filter_gg(self, filter_gg):
        """Print translation, z-axis vector and score of every grasp.

        Grasps are printed in the order in which ``filter_gg`` yields them.

        Args:
            filter_gg: Iterable of grasps exposing ``translation``,
                ``rotation_matrix`` and ``score``.
        """
        print("[AnyGrasp] \033[34mInfo\033[0m: AnyGrasp output pose about object:")
        for g in filter_gg:
            print(
                f"[AnyGrasp] \033[34mInfo\033[0m: translation: {g.translation}, z_vec: {g.rotation_matrix[:, 2]}, score: {g.score}"
            )
def build_arg_parser():
    """Build the command-line parser for the standalone AnyGrasp demo.

    Returns:
        argparse.ArgumentParser: Parser exposing the model, camera-intrinsics
        and input/output options used by the demo script.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint_path",
        default="./checkpoints/checkpoint_detection.tar",
        help="Model checkpoint path",
    )
    parser.add_argument(
        "--max_gripper_width",
        type=float,
        default=0.1,
        help="Maximum gripper width (<=0.1m)",
    )
    parser.add_argument(
        "--gripper_height", type=float, default=0.03, help="Gripper height"
    )
    parser.add_argument("--debug", action="store_true", help="Enable visualization")
    parser.add_argument(
        "--max_depth", type=float, default=1.0, help="Maximum depth of point cloud"
    )
    parser.add_argument(
        "--min_depth", type=float, default=0.0, help="Minimum depth of point cloud"
    )
    parser.add_argument(
        "--sampling_rate",
        type=float,
        default=1.0,
        help="Sampling rate of points [<= 1]",
    )
    parser.add_argument(
        "--input_dir", default="./test_data/example2", help="Input data dir"
    )
    parser.add_argument(
        "--output_dir", default="./output/example2", help="Output data dir"
    )
    parser.add_argument(
        "--fx", type=float, default=306, help="Focal length in the x direction"
    )
    parser.add_argument(
        "--fy", type=float, default=306, help="Focal length in the y direction"
    )
    parser.add_argument(
        "--cx",
        type=float,
        default=118,
        help="Center of the optical axis in the x-direction",
    )
    parser.add_argument(
        "--cy",
        type=float,
        default=211,
        help="Center of the optical axis in the y-direction",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=0.001,
        help="Convert the original depth value in the depth image to the actual depth value",
    )
    parser.add_argument(
        "--head_tilt",
        type=float,
        default=-0.45,
        help="Tilt angle for ideal gripping pose",
    )
    return parser


def main():
    """Run the interactive demo: detect an object and load the grasp model.

    The grasp estimation call itself is still disabled; see the TODO at the
    end of the function.
    """
    # set work dir to AnyGrasp
    os.chdir(os.path.dirname(os.path.abspath(__file__)))

    cfgs = build_arg_parser().parse_args()
    # NOTE(review): the help text advertises <=0.1m but the clamp allows up to
    # 0.2m -- confirm which limit is intended.
    cfgs.max_gripper_width = max(0, min(0.2, cfgs.max_gripper_width))

    # makedirs (not mkdir): the default "./output/example2" has a parent
    # directory that may not exist yet.
    os.makedirs(cfgs.output_dir, exist_ok=True)

    # Input frames. `colors` and `depths` are the inputs of the camera object
    # that still has to be wired in (see TODO below).
    image = Image.open(os.path.join(cfgs.input_dir, "color.png"))
    colors = np.array(image)
    with Image.open(os.path.join(cfgs.input_dir, "depth.png")) as depth_image:
        depths = np.array(depth_image)

    # object detection
    lang_sam = Lang_SAM()
    query = input("Enter a Object name in the image: ")
    seg_mask, bbox = lang_sam.detect_obj(image, query, save_box=False, save_mask=False)

    anygrasp = Anygrasp(cfgs)

    # TODO: build a camera object from cfgs.fx/fy/cx/cy, cfgs.head_tilt,
    # image, colors and depths, then call
    # anygrasp.Grasp_Pose_Estimation(camera, seg_mask, bbox).


if __name__ == "__main__":
    main()