mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-06 05:40:38 +08:00
Allow keeping the original user description
This commit is contained in:
@ -6,12 +6,11 @@ from urllib.parse import urlparse
|
|||||||
import requests
|
import requests
|
||||||
from atlassian.bitbucket import Cloud
|
from atlassian.bitbucket import Cloud
|
||||||
|
|
||||||
from ..algo.pr_processing import clip_tokens
|
|
||||||
from ..config_loader import get_settings
|
from ..config_loader import get_settings
|
||||||
from .git_provider import FilePatchInfo
|
from .git_provider import FilePatchInfo, GitProvider
|
||||||
|
|
||||||
|
|
||||||
class BitbucketProvider:
|
class BitbucketProvider(GitProvider):
|
||||||
def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False):
|
def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False):
|
||||||
s = requests.Session()
|
s = requests.Session()
|
||||||
s.headers['Authorization'] = f'Bearer {get_settings().get("BITBUCKET.BEARER_TOKEN", None)}'
|
s.headers['Authorization'] = f'Bearer {get_settings().get("BITBUCKET.BEARER_TOKEN", None)}'
|
||||||
@ -156,10 +155,7 @@ class BitbucketProvider:
|
|||||||
def get_pr_branch(self):
|
def get_pr_branch(self):
|
||||||
return self.pr.source_branch
|
return self.pr.source_branch
|
||||||
|
|
||||||
def get_pr_description(self):
|
def get_pr_description_full(self):
|
||||||
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
|
|
||||||
if max_tokens:
|
|
||||||
return clip_tokens(self.pr.description, max_tokens)
|
|
||||||
return self.pr.description
|
return self.pr.description
|
||||||
|
|
||||||
def get_user_id(self):
|
def get_user_id(self):
|
||||||
|
@ -82,9 +82,26 @@ class GitProvider(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_pr_description(self):
|
def get_pr_description_full(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_pr_description(self):
|
||||||
|
from pr_agent.config_loader import get_settings
|
||||||
|
from pr_agent.algo.pr_processing import clip_tokens
|
||||||
|
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
|
||||||
|
description = self.get_pr_description_full()
|
||||||
|
if max_tokens:
|
||||||
|
return clip_tokens(description, max_tokens)
|
||||||
|
return description
|
||||||
|
|
||||||
|
def get_user_description(self):
|
||||||
|
description = (self.get_pr_description_full() or "").strip()
|
||||||
|
if not description.startswith("## PR Type"):
|
||||||
|
return description
|
||||||
|
if "## User Description:" not in description:
|
||||||
|
return ""
|
||||||
|
return description.split("## User Description:", 1)[1].strip()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_issue_comments(self):
|
def get_issue_comments(self):
|
||||||
pass
|
pass
|
||||||
|
@ -233,10 +233,7 @@ class GithubProvider(GitProvider):
|
|||||||
def get_pr_branch(self):
|
def get_pr_branch(self):
|
||||||
return self.pr.head.ref
|
return self.pr.head.ref
|
||||||
|
|
||||||
def get_pr_description(self):
|
def get_pr_description_full(self):
|
||||||
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
|
|
||||||
if max_tokens:
|
|
||||||
return clip_tokens(self.pr.body, max_tokens)
|
|
||||||
return self.pr.body
|
return self.pr.body
|
||||||
|
|
||||||
def get_user_id(self):
|
def get_user_id(self):
|
||||||
|
@ -299,10 +299,7 @@ class GitLabProvider(GitProvider):
|
|||||||
def get_pr_branch(self):
|
def get_pr_branch(self):
|
||||||
return self.mr.source_branch
|
return self.mr.source_branch
|
||||||
|
|
||||||
def get_pr_description(self):
|
def get_pr_description_full(self):
|
||||||
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
|
|
||||||
if max_tokens:
|
|
||||||
return clip_tokens(self.mr.description, max_tokens)
|
|
||||||
return self.mr.description
|
return self.mr.description
|
||||||
|
|
||||||
def get_issue_comments(self):
|
def get_issue_comments(self):
|
||||||
|
@ -158,7 +158,7 @@ class LocalGitProvider(GitProvider):
|
|||||||
def get_user_id(self):
|
def get_user_id(self):
|
||||||
return -1 # Not used anywhere for the local provider, but required by the interface
|
return -1 # Not used anywhere for the local provider, but required by the interface
|
||||||
|
|
||||||
def get_pr_description(self):
|
def get_pr_description_full(self):
|
||||||
commits_diff = list(self.repo.iter_commits(self.target_branch_name + '..HEAD'))
|
commits_diff = list(self.repo.iter_commits(self.target_branch_name + '..HEAD'))
|
||||||
# Get the commit messages and concatenate
|
# Get the commit messages and concatenate
|
||||||
commit_messages = " ".join([commit.message for commit in commits_diff])
|
commit_messages = " ".join([commit.message for commit in commits_diff])
|
||||||
|
@ -24,6 +24,7 @@ extra_instructions = ""
|
|||||||
|
|
||||||
[pr_description] # /describe #
|
[pr_description] # /describe #
|
||||||
publish_description_as_comment=false
|
publish_description_as_comment=false
|
||||||
|
keep_user_description=true
|
||||||
extra_instructions = ""
|
extra_instructions = ""
|
||||||
|
|
||||||
[pr_questions] # /ask #
|
[pr_questions] # /ask #
|
||||||
|
@ -43,6 +43,8 @@ class PRDescription:
|
|||||||
"commit_messages_str": self.git_provider.get_commit_messages()
|
"commit_messages_str": self.git_provider.get_commit_messages()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.user_description = self.git_provider.get_user_description()
|
||||||
|
|
||||||
# Initialize the token handler
|
# Initialize the token handler
|
||||||
self.token_handler = TokenHandler(
|
self.token_handler = TokenHandler(
|
||||||
self.git_provider.pr,
|
self.git_provider.pr,
|
||||||
@ -145,6 +147,9 @@ class PRDescription:
|
|||||||
# Load the AI prediction data into a dictionary
|
# Load the AI prediction data into a dictionary
|
||||||
data = load_yaml(self.prediction.strip())
|
data = load_yaml(self.prediction.strip())
|
||||||
|
|
||||||
|
if get_settings().pr_description.keep_user_description and self.user_description:
|
||||||
|
data["User Description"] = self.user_description
|
||||||
|
|
||||||
# Initialization
|
# Initialization
|
||||||
pr_types = []
|
pr_types = []
|
||||||
|
|
||||||
@ -167,7 +172,7 @@ class PRDescription:
|
|||||||
# Iterate over the remaining dictionary items and append the key and value to 'pr_body' in a markdown format,
|
# Iterate over the remaining dictionary items and append the key and value to 'pr_body' in a markdown format,
|
||||||
# except for the items containing the word 'walkthrough'
|
# except for the items containing the word 'walkthrough'
|
||||||
pr_body = ""
|
pr_body = ""
|
||||||
for key, value in data.items():
|
for idx, (key, value) in enumerate(data.items()):
|
||||||
pr_body += f"## {key}:\n"
|
pr_body += f"## {key}:\n"
|
||||||
if 'walkthrough' in key.lower():
|
if 'walkthrough' in key.lower():
|
||||||
# for filename, description in value.items():
|
# for filename, description in value.items():
|
||||||
@ -179,7 +184,9 @@ class PRDescription:
|
|||||||
# if the value is a list, join its items by comma
|
# if the value is a list, join its items by comma
|
||||||
if type(value) == list:
|
if type(value) == list:
|
||||||
value = ', '.join(v for v in value)
|
value = ', '.join(v for v in value)
|
||||||
pr_body += f"{value}\n\n___\n"
|
pr_body += f"{value}\n"
|
||||||
|
if idx < len(data) - 1:
|
||||||
|
pr_body += "\n___\n"
|
||||||
|
|
||||||
if get_settings().config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"title:\n{title}\n{pr_body}")
|
logging.info(f"title:\n{title}\n{pr_body}")
|
||||||
|
Reference in New Issue
Block a user