Improve ask_line tool (add conversation history context)

This commit is contained in:
benedict.lee
2025-04-09 23:45:04 +09:00
parent edaab4b6b1
commit b53d2773a9
5 changed files with 218 additions and 0 deletions

View File

@ -285,6 +285,15 @@ class GitProvider(ABC):
def get_comment_url(self, comment) -> str:
    """
    Returns the URL of a comment.
    This base implementation ignores `comment` and always returns an empty string.
    Args:
        comment: Comment object (unused here)
    Returns:
        An empty string
    """
    return ""
def get_review_comment_by_id(self, comment_id: int):
    """
    Retrieves a review comment by its ID.
    This base implementation performs no lookup.
    Args:
        comment_id: Review comment ID (unused here)
    Returns:
        None
    """
    return None
def get_review_id_by_comment_id(self, comment_id: int):
    """
    Finds the review ID that a comment belongs to.
    This base implementation performs no lookup.
    Args:
        comment_id: Review comment ID (unused here)
    Returns:
        None
    """
    return None
def get_review_thread_comments(self, comment_id: int):
    """
    Retrieves all comments in the thread that a specific comment belongs to.
    This base implementation performs no lookup. It returns an empty list
    rather than None so that callers can iterate over the result directly.
    Args:
        comment_id: Review comment ID (unused here)
    Returns:
        An empty list
    """
    return []
#### labels operations ####
@abstractmethod

View File

@ -428,6 +428,130 @@ class GithubProvider(GitProvider):
except Exception as e:
get_logger().error(f"Failed to publish inline code comments fallback, error: {e}")
raise e
def get_review_comment_by_id(self, comment_id: int):
    """
    Retrieves a review comment of the current PR by its ID.
    Args:
        comment_id: Review comment ID. IDs are compared by string value, so an ID
            that arrives as a string (e.g. "123") matches the same comment as 123.
    Returns:
        Review comment object, or None if it is not found or the lookup fails
    """
    try:
        wanted_id = str(comment_id)
        # Iterate the result of get_comments() directly instead of copying it into a
        # list first, so the scan stops at the first match.
        for comment in self.pr.get_comments():
            if str(comment.id) == wanted_id:
                return comment
        return None
    except Exception as e:
        # Best-effort lookup: report the failure and behave as "not found"
        get_logger().warning(f"Failed to get review comment {comment_id}, error: {e}")
        return None
def get_review_id_by_comment_id(self, comment_id: int):
    """
    Looks up the ID of the review that a review comment belongs to.
    Args:
        comment_id: Review comment ID
    Returns:
        The comment's `pull_request_review_id`, or None if the comment is not found,
        has no such attribute, or the lookup fails
    """
    try:
        review_comment = self.get_review_comment_by_id(comment_id)
        # Guard clause: nothing to report when the comment could not be resolved
        if not review_comment:
            return None
        return getattr(review_comment, 'pull_request_review_id', None)
    except Exception as e:
        get_logger().warning(f"Failed to get review ID for comment {comment_id}, error: {e}")
        return None
def get_review_thread_comments(self, comment_id: int):
    """
    Retrieves all comments in the thread that a specific comment belongs to.
    The thread is the whole reply tree containing the comment: the comment itself,
    every ancestor reached by following `in_reply_to_id`, and every direct or
    indirect reply to the topmost ancestor.
    Args:
        comment_id: Review comment ID. IDs are compared by string value, so "123"
            and 123 refer to the same comment.
    Returns:
        List of comments in the same thread sorted by `created_at`, or an empty
        list if the comment is not found or the lookup fails
    """
    try:
        # Fetch the PR review comments once; every step below works on this snapshot
        all_comments = list(self.pr.get_comments())

        wanted_id = str(comment_id)
        target = next((c for c in all_comments if str(c.id) == wanted_id), None)
        if target is None:
            return []

        # Map each reply to its parent, and each parent to its direct replies
        parent_of = {}
        replies_of = {}
        for c in all_comments:
            parent_id = getattr(c, 'in_reply_to_id', None)
            if parent_id:
                parent_of[c.id] = parent_id
                replies_of.setdefault(parent_id, []).append(c.id)

        # Walk up to the topmost ancestor. The root may be an ID with no matching
        # comment object; `visited` stops the walk if the reply chain is cyclic.
        root_id = target.id
        visited = {root_id}
        while root_id in parent_of and parent_of[root_id] not in visited:
            root_id = parent_of[root_id]
            visited.add(root_id)

        # Collect the root and every comment reachable below it
        thread_ids = {root_id}
        pending = [root_id]
        while pending:
            for reply_id in replies_of.get(pending.pop(), ()):
                if reply_id not in thread_ids:
                    thread_ids.add(reply_id)
                    pending.append(reply_id)

        # Keep only the comments of this thread, in chronological order
        thread_comments = [c for c in all_comments if c.id in thread_ids]
        thread_comments.sort(key=lambda c: c.created_at)
        return thread_comments
    except Exception as e:
        # Best-effort lookup: report the failure and behave as "no thread"
        get_logger().warning(f"Failed to get review thread comments for comment {comment_id}, error: {e}")
        return []
def _publish_inline_comments_fallback_with_verification(self, comments: list[dict]):
"""

View File

@ -119,6 +119,7 @@ async_ai_calls=true
[pr_questions] # /ask #
enable_help_text=false
use_conversation_history=true
[pr_code_suggestions] # /improve #

View File

@ -43,6 +43,15 @@ Now focus on the selected lines from the hunk:
======
Note that lines in the diff body are prefixed with a symbol that represents the type of change: '-' for deletions, '+' for additions, and ' ' (a space) for unchanged lines
{%- if conversation_history %}
Previous discussion on this code:
======
{{ conversation_history|trim }}
======
Consider the previous review comments from authors and reviewers, as well as any earlier questions and AI answers about this code. Each entry is prefixed with its role and author: "User (<login>)" marks a human comment or question, and "AI (<login>)" marks an earlier AI answer. Use this context to provide a more informed and consistent answer.
{%- endif %}
A question about the selected lines:
======

View File

@ -35,6 +35,7 @@ class PR_LineQuestions:
"question": self.question_str,
"full_hunk": "",
"selected_lines": "",
"conversation_history": "",
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
@ -42,6 +43,9 @@ class PR_LineQuestions:
get_settings().pr_line_questions_prompt.user)
self.patches_diff = None
self.prediction = None
# get settings for use conversation history
self.use_conversation_history = get_settings().pr_questions.use_conversation_history
def parse_args(self, args):
if args and len(args) > 0:
@ -56,6 +60,10 @@ class PR_LineQuestions:
# if get_settings().config.publish_output:
# self.git_provider.publish_comment("Preparing answer...", is_temporary=True)
# set conversation history if enabled
if self.use_conversation_history:
self._load_conversation_history()
self.patch_with_lines = ""
ask_diff = get_settings().get('ask_diff_hunk', "")
line_start = get_settings().get('line_start', '')
@ -92,6 +100,73 @@ class PR_LineQuestions:
self.git_provider.publish_comment(model_answer_sanitized)
return ""
def _load_conversation_history(self):
    """
    Generate conversation history from the code review thread.

    Reads the `comment_id` setting, asks the git provider for the review thread
    that comment belongs to, and stores the formatted thread in
    self.vars["conversation_history"] as "<role> (<author>): <text>" entries
    separated by blank lines. The comment identified by `comment_id` (the current
    question) is excluded. Any failure is logged and leaves the history empty.
    When there is no comment ID, or the provider has no
    `get_review_thread_comments`, self.vars is left untouched.
    """
    try:
        comment_id = get_settings().get('comment_id', '')
        # The thread can only be located through the ID of the triggering comment
        if not comment_id or not hasattr(self.git_provider, 'get_review_thread_comments'):
            return

        try:
            # A provider without thread support may return None; treat it as "no history"
            thread_comments = self.git_provider.get_review_thread_comments(comment_id) or []
            current_question_id = str(comment_id)

            conversation_history = []
            for comment in thread_comments:
                # skip empty comments
                body = (getattr(comment, 'body', '') or '').strip()
                if not body:
                    continue

                # the current question is not part of the history
                if str(comment.id) == current_question_id:
                    continue

                # Remove only the leading command token (/ask etc) and keep the text
                # after it, so that earlier questions stay in the history
                if body.startswith('/'):
                    _, *remainder = body.split(None, 1)
                    body = remainder[0] if remainder else ''
                    if not body:
                        continue

                # author info; a missing user or login falls back to 'Unknown'
                user = getattr(comment, 'user', None)
                author = getattr(user, 'login', None) or 'Unknown'

                # NOTE(review): any login containing "bot" is labelled as AI; this
                # heuristic can misclassify human logins that contain "bot"
                role = 'AI' if 'bot' in author.lower() else 'User'

                conversation_history.append(f"{role} ({author}): {body}")

            # transform the conversation history to a string ("" when there are no entries)
            self.vars["conversation_history"] = "\n\n".join(conversation_history)
            if conversation_history:
                get_logger().info(f"Loaded {len(conversation_history)} comments from the code review thread")
        except Exception as e:
            get_logger().warning(f"Failed to get review thread comments: {e}")
            self.vars["conversation_history"] = ""
    except Exception as e:
        get_logger().error(f"Error loading conversation history: {e}")
        self.vars["conversation_history"] = ""
async def _get_prediction(self, model: str):
variables = copy.deepcopy(self.vars)