Merge remote-tracking branch 'origin/main' into ok/settings_refactor

This commit is contained in:
Ori Kotek
2023-08-01 16:01:04 +03:00
6 changed files with 49 additions and 2 deletions

View File

@ -1,3 +1,10 @@
## 2023-08-01
### Enhanced
- Introduced the ability to retrieve commit messages from pull requests across different git providers.
- Implemented commit messages retrieval for GitHub and GitLab providers.
- Updated the PR description template to include a section for commit messages if they exist.
## 2023-07-30 ## 2023-07-30
### Enhanced ### Enhanced

View File

@ -120,3 +120,6 @@ class BitbucketProvider:
def _get_pr_file_content(self, remote_link: str): def _get_pr_file_content(self, remote_link: str):
return "" return ""
def get_commit_messages(self):
return "" # not implemented yet

View File

@ -344,4 +344,19 @@ class GithubProvider(GitProvider):
return [label.name for label in self.pr.labels] return [label.name for label in self.pr.labels]
except Exception as e: except Exception as e:
logging.exception(f"Failed to get labels, error: {e}") logging.exception(f"Failed to get labels, error: {e}")
return [] return []
def get_commit_messages(self) -> str:
"""
Retrieves the commit messages of a pull request.
Returns:
str: A string containing the commit messages of the pull request.
"""
try:
commit_list = self.pr.get_commits()
commit_messages = [commit.commit.message for commit in commit_list]
commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages)])
except:
commit_messages_str = ""
return commit_messages_str

View File

@ -297,3 +297,17 @@ class GitLabProvider(GitProvider):
def get_labels(self): def get_labels(self):
return self.mr.labels return self.mr.labels
def get_commit_messages(self) -> str:
"""
Retrieves the commit messages of a pull request.
Returns:
str: A string containing the commit messages of the pull request.
"""
try:
commit_messages_list = [commit['message'] for commit in self.mr.commits()._list]
commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages_list)])
except:
commit_messages_str = ""
return commit_messages_str

View File

@ -36,8 +36,14 @@ Don't repeat the prompt in the answer, and avoid outputting the 'type' and 'desc
user="""PR Info: user="""PR Info:
Branch: '{{branch}}' Branch: '{{branch}}'
{%- if language %} {%- if language %}
Main language: {{language}} Main language: {{language}}
{%- endif %} {%- endif %}
{%- if commit_messages_str %}
Commit messages:
{{commit_messages_str}}
{%- endif %}
The PR Git Diff: The PR Git Diff:

View File

@ -27,7 +27,8 @@ class PRDescription:
self.main_pr_language = get_main_pr_language( self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
) )
commit_messages_str = self.git_provider.get_commit_messages()
# Initialize the AI handler # Initialize the AI handler
self.ai_handler = AiHandler() self.ai_handler = AiHandler()
@ -39,6 +40,7 @@ class PRDescription:
"language": self.main_pr_language, "language": self.main_pr_language,
"diff": "", # empty diff for initial calculation "diff": "", # empty diff for initial calculation
"extra_instructions": get_settings().pr_description.extra_instructions, "extra_instructions": get_settings().pr_description.extra_instructions,
"commit_messages_str": commit_messages_str
} }
# Initialize the token handler # Initialize the token handler