Compare commits

3 Commits

Author SHA1 Message Date
0167003bbc handle no diffs 2023-07-28 01:59:10 +03:00
99ed9b22a1 latest documentation suggest get_all not all (https://python-gitlab.readthedocs.io/en/stable/api-usage.html#pagination) 2023-07-27 15:39:19 +03:00
eee6d51b40 issue #145: get all diffs in merge request and not only gitlab default 20 2023-07-27 14:41:36 +03:00
26 changed files with 117 additions and 760 deletions
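
Note on commits eee6d51b40, 99ed9b22a1 and 0167003bbc: the messages say that only GitLab's default of 20 diffs was being fetched, that `get_all` is the documented way to fetch every page, and that a merge request without diffs needs handling. The sketch below imitates that behaviour with a hypothetical in-memory endpoint. `FakeDiffEndpoint`, `PER_PAGE` and `last_diff` are illustrative names, not part of python-gitlab or of this repository.

```python
from typing import List

PER_PAGE = 20  # assumed page size, matching the "default 20" in the commit message


class FakeDiffEndpoint:
    """Hypothetical stand-in for a paginated list endpoint (not python-gitlab)."""

    def __init__(self, total: int) -> None:
        self._items = [f"diff-{i}" for i in range(1, total + 1)]

    def _page(self, number: int) -> List[str]:
        start = (number - 1) * PER_PAGE
        return self._items[start:start + PER_PAGE]

    def list(self, get_all: bool = False) -> List[str]:
        if not get_all:
            return self._page(1)  # only the first page is returned
        collected: List[str] = []
        number = 1
        while True:
            batch = self._page(number)
            if not batch:
                return collected
            collected.extend(batch)
            number += 1


def last_diff(endpoint: FakeDiffEndpoint) -> str:
    """Return the final diff, raising ValueError when there are none."""
    diffs = endpoint.list(get_all=True)
    if not diffs:
        raise ValueError("Could not get diff: the merge request has no diffs")
    return diffs[-1]
```

With 45 items, `list()` stops at the 20th entry, so indexing `[-1]` on it would not reach the most recent diff, while `last_diff` walks every page first.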

4
.gitignore vendored
@ -1,6 +1,4 @@
 .idea/
 venv/
 pr_agent/settings/.secrets.toml
 __pycache__
-dist/
-*.egg-info/

@ -1,6 +0,0 @@
## 2023-07-26
### Added
- New feature for updating the CHANGELOG.md based on the contents of a PR.
- Added support for this feature for the Github provider.
- New configuration settings and prompts for the changelog update feature.

@ -1,8 +1,8 @@
 FROM python:3.10 as base
 WORKDIR /app
-ADD pyproject.toml .
-RUN pip install . && rm pyproject.toml
+ADD requirements.txt .
+RUN pip install -r requirements.txt && rm requirements.txt
 ENV PYTHONPATH=/app
 ADD pr_agent pr_agent
 ADD github_action/entrypoint.sh /

@ -66,7 +66,6 @@ CodiumAI `PR-Agent` is an open-source tool aiming to help developers review pull
- [Usage and tools](#usage-and-tools) - [Usage and tools](#usage-and-tools)
- [Configuration](./CONFIGURATION.md) - [Configuration](./CONFIGURATION.md)
- [How it works](#how-it-works) - [How it works](#how-it-works)
- [Why use PR-Agent](#why-use-pr-agent)
- [Roadmap](#roadmap) - [Roadmap](#roadmap)
- [Similar projects](#similar-projects) - [Similar projects](#similar-projects)
</div> </div>
@ -82,7 +81,6 @@ CodiumAI `PR-Agent` is an open-source tool aiming to help developers review pull
| | Auto-Description | :white_check_mark: | :white_check_mark: | | | | Auto-Description | :white_check_mark: | :white_check_mark: | |
| | Improve Code | :white_check_mark: | :white_check_mark: | | | | Improve Code | :white_check_mark: | :white_check_mark: | |
| | Reflect and Review | :white_check_mark: | | | | | Reflect and Review | :white_check_mark: | | |
| | Update CHANGELOG.md | :white_check_mark: | | |
| | | | | | | | | | | |
| USAGE | CLI | :white_check_mark: | :white_check_mark: | :white_check_mark: | | USAGE | CLI | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| | App / webhook | :white_check_mark: | :white_check_mark: | | | | App / webhook | :white_check_mark: | :white_check_mark: | |
@ -100,7 +98,6 @@ Examples for invoking the different tools via the CLI:
- **Improve**: python cli.py --pr-url=<pr_url> improve - **Improve**: python cli.py --pr-url=<pr_url> improve
- **Ask**: python cli.py --pr-url=<pr_url> ask "Write me a poem about this PR" - **Ask**: python cli.py --pr-url=<pr_url> ask "Write me a poem about this PR"
- **Reflect**: python cli.py --pr-url=<pr_url> reflect - **Reflect**: python cli.py --pr-url=<pr_url> reflect
- **Update changelog**: python cli.py --pr-url=<pr_url> update_changelog
"<pr_url>" is the url of the relevant PR (for example: https://github.com/Codium-ai/pr-agent/pull/50). "<pr_url>" is the url of the relevant PR (for example: https://github.com/Codium-ai/pr-agent/pull/50).
@ -149,19 +146,6 @@ There are several ways to use PR-Agent:
Check out the [PR Compression strategy](./PR_COMPRESSION.md) page for more details on how we convert a code diff to a manageable LLM prompt Check out the [PR Compression strategy](./PR_COMPRESSION.md) page for more details on how we convert a code diff to a manageable LLM prompt
## Why use PR-Agent?
A reasonable question that can be asked is: `"Why use PR-Agent? What makes it stand out from existing tools?"`
Here are some of the reasons why:
- We emphasize **real-life practical usage**. Each tool (review, improve, ask, ...) has a single GPT-4 call, no more. We feel that this is critical for realistic team usage - obtaining an answer quickly (~30 seconds) and affordably.
- Our [PR Compression strategy](./PR_COMPRESSION.md) is a core ability that enables us to effectively tackle both short and long PRs.
- Our JSON prompting strategy enables us to have **modular, customizable tools**. For example, the '/review' tool categories can be controlled via the configuration file. Adding additional categories is easy and accessible.
- We support **multiple git providers** (GitHub, Gitlab, Bitbucket), and multiple ways to use the tool (CLI, GitHub Action, Docker, ...).
- We are open-source, and welcome contributions from the community.
## Roadmap ## Roadmap
- [ ] Support open-source models, as a replacement for OpenAI models. (Note - a minimal requirement for each open-source model is to have 8k+ context, and good support for generating JSON as an output) - [ ] Support open-source models, as a replacement for OpenAI models. (Note - a minimal requirement for each open-source model is to have 8k+ context, and good support for generating JSON as an output)

@ -1,8 +1,8 @@
 FROM python:3.10 as base
 WORKDIR /app
-ADD pyproject.toml .
-RUN pip install . && rm pyproject.toml
+ADD requirements.txt .
+RUN pip install -r requirements.txt && rm requirements.txt
 ENV PYTHONPATH=/app
 ADD pr_agent pr_agent

@ -4,9 +4,9 @@ RUN yum update -y && \
 yum install -y gcc python3-devel && \
 yum clean all
-ADD pyproject.toml .
-RUN pip install . && rm pyproject.toml
-RUN pip install mangum==0.17.0
+ADD requirements.txt .
+RUN pip install -r requirements.txt && rm requirements.txt
+RUN pip install mangum==16.0.0
 COPY pr_agent/ ${LAMBDA_TASK_ROOT}/pr_agent/
 CMD ["pr_agent.servers.serverless.serverless"]

@ -6,7 +6,6 @@ from pr_agent.tools.pr_description import PRDescription
from pr_agent.tools.pr_information_from_user import PRInformationFromUser from pr_agent.tools.pr_information_from_user import PRInformationFromUser
from pr_agent.tools.pr_questions import PRQuestions from pr_agent.tools.pr_questions import PRQuestions
from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_reviewer import PRReviewer
from pr_agent.tools.pr_update_changelog import PRUpdateChangelog
class PRAgent: class PRAgent:
@ -27,9 +26,7 @@ class PRAgent:
elif any(cmd == action for cmd in ["/improve", "/improve_code"]): elif any(cmd == action for cmd in ["/improve", "/improve_code"]):
await PRCodeSuggestions(pr_url).suggest() await PRCodeSuggestions(pr_url).suggest()
elif any(cmd == action for cmd in ["/ask", "/ask_question"]): elif any(cmd == action for cmd in ["/ask", "/ask_question"]):
await PRQuestions(pr_url, args=args).answer() await PRQuestions(pr_url, args).answer()
elif any(cmd == action for cmd in ["/update_changelog"]):
await PRUpdateChangelog(pr_url, args=args).update_changelog()
else: else:
return False return False

@ -13,7 +13,7 @@ if settings.config.use_extra_bad_extensions:
 def filter_bad_extensions(files):
-    return [f for f in files if f.filename is not None and is_valid_file(f.filename)]
+    return [f for f in files if is_valid_file(f.filename)]
 def is_valid_file(filename):

@ -3,8 +3,6 @@ from __future__ import annotations
import logging import logging
from typing import Tuple, Union, Callable, List from typing import Tuple, Union, Callable, List
from github import RateLimitExceededException
from pr_agent.algo import MAX_TOKENS from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
from pr_agent.algo.language_handler import sort_files_by_main_languages from pr_agent.algo.language_handler import sort_files_by_main_languages
@ -21,6 +19,7 @@ OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1000
OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600 OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
PATCH_EXTRA_LINES = 3 PATCH_EXTRA_LINES = 3
def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str, def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str: add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str:
""" """
@ -41,11 +40,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
global PATCH_EXTRA_LINES global PATCH_EXTRA_LINES
PATCH_EXTRA_LINES = 0 PATCH_EXTRA_LINES = 0
try: diff_files = list(git_provider.get_diff_files())
diff_files = list(git_provider.get_diff_files())
except RateLimitExceededException as e:
logging.error(f"Rate limit exceeded for git provider API. original message {e}")
raise
# get pr languages # get pr languages
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files) pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)

@ -8,7 +8,6 @@ from pr_agent.tools.pr_description import PRDescription
from pr_agent.tools.pr_information_from_user import PRInformationFromUser from pr_agent.tools.pr_information_from_user import PRInformationFromUser
from pr_agent.tools.pr_questions import PRQuestions from pr_agent.tools.pr_questions import PRQuestions
from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_reviewer import PRReviewer
from pr_agent.tools.pr_update_changelog import PRUpdateChangelog
def run(args=None): def run(args=None):
@ -28,15 +27,13 @@ ask / ask_question [question] - Ask a question about the PR.
describe / describe_pr - Modify the PR title and description based on the PR's contents. describe / describe_pr - Modify the PR title and description based on the PR's contents.
improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit. improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
reflect - Ask the PR author questions about the PR. reflect - Ask the PR author questions about the PR.
update_changelog - Update the changelog based on the PR's contents.
""") """)
parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', required=True) parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', required=True)
parser.add_argument('command', type=str, help='The', choices=['review', 'review_pr', parser.add_argument('command', type=str, help='The', choices=['review', 'review_pr',
'ask', 'ask_question', 'ask', 'ask_question',
'describe', 'describe_pr', 'describe', 'describe_pr',
'improve', 'improve_code', 'improve', 'improve_code',
'reflect', 'review_after_reflect', 'reflect', 'review_after_reflect'],
'update_changelog'],
default='review') default='review')
parser.add_argument('rest', nargs=argparse.REMAINDER, default=[]) parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
args = parser.parse_args(args) args = parser.parse_args(args)
@ -52,8 +49,7 @@ update_changelog - Update the changelog based on the PR's contents.
'review': _handle_review_command, 'review': _handle_review_command,
'review_pr': _handle_review_command, 'review_pr': _handle_review_command,
'reflect': _handle_reflect_command, 'reflect': _handle_reflect_command,
'review_after_reflect': _handle_review_after_reflect_command, 'review_after_reflect': _handle_review_after_reflect_command
'update_changelog': _handle_update_changelog,
} }
if command in commands: if command in commands:
commands[command](args.pr_url, args.rest) commands[command](args.pr_url, args.rest)
@ -100,10 +96,6 @@ def _handle_review_after_reflect_command(pr_url: str, rest: list):
reviewer = PRReviewer(pr_url, cli_mode=True, is_answer=True) reviewer = PRReviewer(pr_url, cli_mode=True, is_answer=True)
asyncio.run(reviewer.review()) asyncio.run(reviewer.review())
def _handle_update_changelog(pr_url: str, rest: list):
print(f"Updating changelog for: {pr_url}")
reviewer = PRUpdateChangelog(pr_url, cli_mode=True, args=rest)
asyncio.run(reviewer.update_changelog())
if __name__ == '__main__': if __name__ == '__main__':
run() run()

@ -1,11 +1,7 @@
from os.path import abspath, dirname, join from os.path import abspath, dirname, join
from pathlib import Path
from typing import Optional
from dynaconf import Dynaconf from dynaconf import Dynaconf
PR_AGENT_TOML_KEY = 'pr-agent'
current_dir = dirname(abspath(__file__)) current_dir = dirname(abspath(__file__))
settings = Dynaconf( settings = Dynaconf(
envvar_prefix=False, envvar_prefix=False,
@ -19,36 +15,6 @@ settings = Dynaconf(
"settings/pr_description_prompts.toml", "settings/pr_description_prompts.toml",
"settings/pr_code_suggestions_prompts.toml", "settings/pr_code_suggestions_prompts.toml",
"settings/pr_information_from_user_prompts.toml", "settings/pr_information_from_user_prompts.toml",
"settings/pr_update_changelog.toml",
"settings_prod/.secrets.toml" "settings_prod/.secrets.toml"
]] ]]
) )
# Add local configuration from pyproject.toml of the project being reviewed
def _find_repository_root() -> Path:
"""
Identify project root directory by recursively searching for the .git directory in the parent directories.
"""
cwd = Path.cwd().resolve()
no_way_up = False
while not no_way_up:
no_way_up = cwd == cwd.parent
if (cwd / ".git").is_dir():
return cwd
cwd = cwd.parent
return None
def _find_pyproject() -> Optional[Path]:
"""
Search for file pyproject.toml in the repository root.
"""
repo_root = _find_repository_root()
if repo_root:
pyproject = _find_repository_root() / "pyproject.toml"
return pyproject if pyproject.is_file() else None
return None
pyproject_path = _find_pyproject()
if pyproject_path is not None:
settings.load_file(pyproject_path, env=f'tool.{PR_AGENT_TOML_KEY}')
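
The left-hand column above walks up from the working directory until it finds a `.git` directory. A compact sketch of the same idea, written independently of the project code: the name `find_repository_root` and its `start` parameter are assumptions made for illustration, and the return type is `Optional[Path]` because the search can come back empty.

```python
import tempfile
from pathlib import Path
from typing import Optional


def find_repository_root(start: Path) -> Optional[Path]:
    """Return the nearest ancestor of `start` (inclusive) that contains a .git directory."""
    current = start.resolve()
    for candidate in (current, *current.parents):
        if (candidate / ".git").is_dir():
            return candidate
    return None  # reached the filesystem root without finding a repository


# Small self-check in a throwaway directory tree.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp).resolve()
    (root / ".git").mkdir()
    nested = root / "src" / "pkg"
    nested.mkdir(parents=True)
    found_from_nested = find_repository_root(nested) == root
    found_from_root = find_repository_root(root) == root
```

Iterating over `Path.parents` avoids the explicit loop flag and makes the "no repository found" case an ordinary fall-through.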

@ -2,13 +2,11 @@ from pr_agent.config_loader import settings
 from pr_agent.git_providers.bitbucket_provider import BitbucketProvider
 from pr_agent.git_providers.github_provider import GithubProvider
 from pr_agent.git_providers.gitlab_provider import GitLabProvider
-from pr_agent.git_providers.local_git_provider import LocalGitProvider
 _GIT_PROVIDERS = {
     'github': GithubProvider,
     'gitlab': GitLabProvider,
     'bitbucket': BitbucketProvider,
-    'local' : LocalGitProvider
 }
 def get_git_provider():

@ -136,4 +136,3 @@ class IncrementalPR:
self.commits_range = None self.commits_range = None
self.first_new_commit_sha = None self.first_new_commit_sha = None
self.last_seen_commit_sha = None self.last_seen_commit_sha = None

@ -3,15 +3,13 @@ from datetime import datetime
from typing import Optional, Tuple from typing import Optional, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
from github import AppAuthentication, Auth, Github, GithubException from github import AppAuthentication, Github, Auth
from retry import retry
from pr_agent.config_loader import settings from pr_agent.config_loader import settings
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
from ..algo.language_handler import is_valid_file from ..algo.language_handler import is_valid_file
from ..algo.utils import load_large_diff from ..algo.utils import load_large_diff
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
from ..servers.utils import RateLimitExceeded
class GithubProvider(GitProvider): class GithubProvider(GitProvider):
@ -80,34 +78,27 @@ class GithubProvider(GitProvider):
return self.file_set.values() return self.file_set.values()
return self.pr.get_files() return self.pr.get_files()
@retry(exceptions=RateLimitExceeded,
tries=settings.github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3))
def get_diff_files(self) -> list[FilePatchInfo]: def get_diff_files(self) -> list[FilePatchInfo]:
try: files = self.get_files()
files = self.get_files() diff_files = []
diff_files = [] for file in files:
for file in files: if is_valid_file(file.filename):
if is_valid_file(file.filename): new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha)
new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha) patch = file.patch
patch = file.patch if self.incremental.is_incremental and self.file_set:
if self.incremental.is_incremental and self.file_set: original_file_content_str = self._get_pr_file_content(file, self.incremental.last_seen_commit_sha)
original_file_content_str = self._get_pr_file_content(file, patch = load_large_diff(file,
self.incremental.last_seen_commit_sha) new_file_content_str,
patch = load_large_diff(file, original_file_content_str,
new_file_content_str, None)
original_file_content_str, self.file_set[file.filename] = patch
None) else:
self.file_set[file.filename] = patch original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
else:
original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
diff_files.append( diff_files.append(
FilePatchInfo(original_file_content_str, new_file_content_str, patch, file.filename)) FilePatchInfo(original_file_content_str, new_file_content_str, patch, file.filename))
self.diff_files = diff_files self.diff_files = diff_files
return diff_files return diff_files
except GithubException.RateLimitExceededException as e:
logging.error(f"Rate limit exceeded for GitHub API. Original message: {e}")
raise RateLimitExceeded("Rate limit exceeded for GitHub API.") from e
def publish_description(self, pr_title: str, pr_body: str): def publish_description(self, pr_title: str, pr_body: str):
self.pr.edit(title=pr_title, body=pr_body) self.pr.edit(title=pr_title, body=pr_body)

@ -11,9 +11,12 @@ from pr_agent.config_loader import settings
from ..algo.language_handler import is_valid_file from ..algo.language_handler import is_valid_file
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
logger = logging.getLogger()
class GitLabProvider(GitProvider): class GitLabProvider(GitProvider):
def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False): def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False):
gitlab_url = settings.get("GITLAB.URL", None) gitlab_url = settings.get("GITLAB.URL", None)
if not gitlab_url: if not gitlab_url:
@ -22,8 +25,8 @@ class GitLabProvider(GitProvider):
if not gitlab_access_token: if not gitlab_access_token:
raise ValueError("GitLab personal access token is not set in the config file") raise ValueError("GitLab personal access token is not set in the config file")
self.gl = gitlab.Gitlab( self.gl = gitlab.Gitlab(
url=gitlab_url, gitlab_url,
oauth_token=gitlab_access_token gitlab_access_token
) )
self.id_project = None self.id_project = None
self.id_mr = None self.id_mr = None
@ -48,7 +51,12 @@ class GitLabProvider(GitProvider):
def _set_merge_request(self, merge_request_url: str): def _set_merge_request(self, merge_request_url: str):
self.id_project, self.id_mr = self._parse_merge_request_url(merge_request_url) self.id_project, self.id_mr = self._parse_merge_request_url(merge_request_url)
self.mr = self._get_merge_request() self.mr = self._get_merge_request()
self.last_diff = self.mr.diffs.list()[-1] try:
self.last_diff = self.mr.diffs.list(get_all=True)[-1]
except IndexError as e:
logger.error(f"Could not get diff for merge request {self.id_mr}")
raise ValueError(f"Could not get diff for merge request {self.id_mr}") from e
def _get_pr_file_content(self, file_path: str, branch: str) -> str: def _get_pr_file_content(self, file_path: str, branch: str) -> str:
try: try:
@ -237,30 +245,20 @@ class GitLabProvider(GitProvider):
def get_issue_comments(self): def get_issue_comments(self):
raise NotImplementedError("GitLab provider does not support issue comments yet") raise NotImplementedError("GitLab provider does not support issue comments yet")
def _parse_merge_request_url(self, merge_request_url: str) -> Tuple[str, int]: def _parse_merge_request_url(self, merge_request_url: str) -> Tuple[int, int]:
parsed_url = urlparse(merge_request_url) parsed_url = urlparse(merge_request_url)
path_parts = parsed_url.path.strip('/').split('/') path_parts = parsed_url.path.strip('/').split('/')
if 'merge_requests' not in path_parts: if path_parts[-2] != 'merge_requests':
raise ValueError("The provided URL does not appear to be a GitLab merge request URL") raise ValueError("The provided URL does not appear to be a GitLab merge request URL")
mr_index = path_parts.index('merge_requests')
# Ensure there is an ID after 'merge_requests'
if len(path_parts) <= mr_index + 1:
raise ValueError("The provided URL does not contain a merge request ID")
try: try:
mr_id = int(path_parts[mr_index + 1]) mr_id = int(path_parts[-1])
except ValueError as e: except ValueError as e:
raise ValueError("Unable to convert merge request ID to integer") from e raise ValueError("Unable to convert merge request ID to integer") from e
# Handle special delimiter (-) # Gitlab supports access by both project numeric ID as well as 'namespace/project_name'
project_path = "/".join(path_parts[:mr_index]) return "/".join(path_parts[:2]), mr_id
if project_path.endswith('/-'):
project_path = project_path[:-2]
# Return the path before 'merge_requests' and the ID
return project_path, mr_id
def _get_merge_request(self): def _get_merge_request(self):
mr = self.gl.projects.get(self.id_project).mergerequests.get(self.id_mr) mr = self.gl.projects.get(self.id_project).mergerequests.get(self.id_mr)
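
The two columns of `_parse_merge_request_url` differ in how much of the URL path they keep. As a rough standalone sketch of the longer (left-hand) variant, here it is as a free function. The name `parse_merge_request_url` and the example URLs in the note that follows are made up for illustration.

```python
from typing import Tuple
from urllib.parse import urlparse


def parse_merge_request_url(merge_request_url: str) -> Tuple[str, int]:
    """Split a merge request URL into (project path, merge request ID)."""
    path_parts = urlparse(merge_request_url).path.strip('/').split('/')
    if 'merge_requests' not in path_parts:
        raise ValueError("The provided URL does not appear to be a GitLab merge request URL")

    mr_index = path_parts.index('merge_requests')
    # There must be an ID segment after 'merge_requests'.
    if len(path_parts) <= mr_index + 1:
        raise ValueError("The provided URL does not contain a merge request ID")

    try:
        mr_id = int(path_parts[mr_index + 1])
    except ValueError as e:
        raise ValueError("Unable to convert merge request ID to integer") from e

    # Drop the '-' delimiter that can sit between the project path and 'merge_requests'.
    project_path = "/".join(path_parts[:mr_index])
    if project_path.endswith('/-'):
        project_path = project_path[:-2]
    return project_path, mr_id
```

For a URL such as `https://gitlab.example.com/group/subgroup/project/-/merge_requests/7` this yields `("group/subgroup/project", 7)`. The shorter right-hand variant keeps only the first two path segments, so for the same URL it would appear to return `group/subgroup` as the project path.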

@ -1,178 +0,0 @@
import logging
from collections import Counter
from pathlib import Path
from typing import List
from git import Repo
from pr_agent.config_loader import _find_repository_root, settings
from pr_agent.git_providers.git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
class PullRequestMimic:
"""
This class mimics the PullRequest class from the PyGithub library for the LocalGitProvider.
"""
def __init__(self, title: str, diff_files: List[FilePatchInfo]):
self.title = title
self.diff_files = diff_files
class LocalGitProvider(GitProvider):
"""
This class implements the GitProvider interface for local git repositories.
It mimics the PR functionality of the GitProvider interface,
but does not require a hosted git repository.
Instead of providing a PR url, the user provides a local branch path to generate a diff-patch.
For the MVP it only supports the /review and /describe capabilities.
"""
def __init__(self, target_branch_name, incremental=False):
self.repo_path = _find_repository_root()
if self.repo_path is None:
raise ValueError('Could not find repository root')
self.repo = Repo(self.repo_path)
self.head_branch_name = self.repo.head.ref.name
self.target_branch_name = target_branch_name
self._prepare_repo()
self.diff_files = None
self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files())
self.description_path = settings.get('local.description_path') \
if settings.get('local.description_path') is not None else self.repo_path / 'description.md'
self.review_path = settings.get('local.review_path') \
if settings.get('local.review_path') is not None else self.repo_path / 'review.md'
# inline code comments are not supported for local git repositories
settings.pr_reviewer.inline_code_comments = False
def _prepare_repo(self):
"""
Prepare the repository for PR-mimic generation.
"""
logging.debug('Preparing repository for PR-mimic generation...')
if self.repo.is_dirty():
raise ValueError('The repository is not in a clean state. Please commit or stash pending changes.')
if self.target_branch_name not in self.repo.heads:
raise KeyError(f'Branch: {self.target_branch_name} does not exist')
def is_supported(self, capability: str) -> bool:
if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments', 'get_labels']:
return False
return True
def get_diff_files(self) -> list[FilePatchInfo]:
diffs = self.repo.head.commit.diff(
self.repo.merge_base(self.repo.head, self.repo.branches[self.target_branch_name]),
create_patch=True,
R=True
)
diff_files = []
for diff_item in diffs:
if diff_item.a_blob is not None:
original_file_content_str = diff_item.a_blob.data_stream.read().decode('utf-8')
else:
original_file_content_str = "" # empty file
if diff_item.b_blob is not None:
new_file_content_str = diff_item.b_blob.data_stream.read().decode('utf-8')
else:
new_file_content_str = "" # empty file
edit_type = EDIT_TYPE.MODIFIED
if diff_item.new_file:
edit_type = EDIT_TYPE.ADDED
elif diff_item.deleted_file:
edit_type = EDIT_TYPE.DELETED
elif diff_item.renamed_file:
edit_type = EDIT_TYPE.RENAMED
diff_files.append(
FilePatchInfo(original_file_content_str,
new_file_content_str,
diff_item.diff.decode('utf-8'),
diff_item.b_path,
edit_type=edit_type,
old_filename=None if diff_item.a_path == diff_item.b_path else diff_item.a_path
)
)
self.diff_files = diff_files
return diff_files
def get_files(self) -> List[str]:
"""
Returns a list of files with changes in the diff.
"""
diff_index = self.repo.head.commit.diff(
self.repo.merge_base(self.repo.head, self.repo.branches[self.target_branch_name]),
R=True
)
# Get the list of changed files
diff_files = [item.a_path for item in diff_index]
return diff_files
def publish_description(self, pr_title: str, pr_body: str):
with open(self.description_path, "w") as file:
# Write the string to the file
file.write(pr_title + '\n' + pr_body)
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
with open(self.review_path, "w") as file:
# Write the string to the file
file.write(pr_comment)
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
raise NotImplementedError('Publishing inline comments is not implemented for the local git provider')
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
raise NotImplementedError('Creating inline comments is not implemented for the local git provider')
def publish_inline_comments(self, comments: list[dict]):
raise NotImplementedError('Publishing inline comments is not implemented for the local git provider')
def publish_code_suggestion(self, body: str, relevant_file: str,
relevant_lines_start: int, relevant_lines_end: int):
raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')
def publish_code_suggestions(self, code_suggestions: list):
raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')
def publish_labels(self, labels):
pass # Not applicable to the local git provider, but required by the interface
def remove_initial_comment(self):
pass # Not applicable to the local git provider, but required by the interface
def get_languages(self):
"""
Calculate percentage of languages in repository. Used for hunk prioritisation.
"""
# Get all files in repository
filepaths = [Path(item.path) for item in self.repo.tree().traverse() if item.type == 'blob']
# Identify language by file extension and count
lang_count = Counter(ext.lstrip('.') for filepath in filepaths for ext in [filepath.suffix.lower()])
# Convert counts to percentages
total_files = len(filepaths)
lang_percentage = {lang: count / total_files * 100 for lang, count in lang_count.items()}
return lang_percentage
def get_pr_branch(self):
return self.repo.head
def get_user_id(self):
return -1 # Not used anywhere for the local provider, but required by the interface
def get_pr_description(self):
commits_diff = list(self.repo.iter_commits(self.target_branch_name + '..HEAD'))
# Get the commit messages and concatenate
commit_messages = " ".join([commit.message for commit in commits_diff])
# TODO Handle the description better - maybe use gpt-3.5 summarisation here?
return commit_messages[:200] # Use max 200 characters
def get_pr_title(self):
"""
Substitutes the branch-name as the PR-mimic title.
"""
return self.head_branch_name
def get_issue_comments(self):
raise NotImplementedError('Getting issue comments is not implemented for the local git provider')
def get_labels(self):
raise NotImplementedError('Getting labels is not implemented for the local git provider')

@ -1,4 +1,3 @@
from typing import Dict, Any
import logging import logging
import sys import sys
@ -15,66 +14,51 @@ router = APIRouter()
@router.post("/api/v1/github_webhooks") @router.post("/api/v1/github_webhooks")
async def handle_github_webhooks(request: Request, response: Response): async def handle_github_webhooks(request: Request, response: Response):
""" logging.debug("Received a github webhook")
Receives and processes incoming GitHub webhook requests.
Verifies the request signature, parses the request body, and passes it to the handle_request function for further processing.
"""
logging.debug("Received a GitHub webhook")
try: try:
body = await request.json() body = await request.json()
except Exception as e: except Exception as e:
logging.error("Error parsing request body", e) logging.error("Error parsing request body", e)
raise HTTPException(status_code=400, detail="Error parsing request body") from e raise HTTPException(status_code=400, detail="Error parsing request body") from e
body_bytes = await request.body() body_bytes = await request.body()
signature_header = request.headers.get('x-hub-signature-256', None) signature_header = request.headers.get('x-hub-signature-256', None)
try:
webhook_secret = getattr(settings.github, 'webhook_secret', None) webhook_secret = settings.github.webhook_secret
except AttributeError:
webhook_secret = None
if webhook_secret: if webhook_secret:
verify_signature(body_bytes, webhook_secret, signature_header) verify_signature(body_bytes, webhook_secret, signature_header)
logging.debug(f'Request body:\n{body}') logging.debug(f'Request body:\n{body}')
return await handle_request(body) return await handle_request(body)
async def handle_request(body: Dict[str, Any]): async def handle_request(body):
""" action = body.get("action", None)
Handle incoming GitHub webhook requests. installation_id = body.get("installation", {}).get("id", None)
Args:
body: The request body.
"""
action = body.get("action")
installation_id = body.get("installation", {}).get("id")
settings.set("GITHUB.INSTALLATION_ID", installation_id) settings.set("GITHUB.INSTALLATION_ID", installation_id)
agent = PRAgent() agent = PRAgent()
if action == 'created': if action == 'created':
if "comment" not in body: if "comment" not in body:
return {} return {}
comment_body = body.get("comment", {}).get("body") comment_body = body.get("comment", {}).get("body", None)
sender = body.get("sender", {}).get("login") if 'sender' in body and 'login' in body['sender'] and 'bot' in body['sender']['login']:
if sender and 'bot' in sender:
return {} return {}
if "issue" not in body or "pull_request" not in body["issue"]: if "issue" not in body and "pull_request" not in body["issue"]:
return {} return {}
pull_request = body["issue"]["pull_request"] pull_request = body["issue"]["pull_request"]
api_url = pull_request.get("url") api_url = pull_request.get("url", None)
await agent.handle_request(api_url, comment_body) await agent.handle_request(api_url, comment_body)
elif action in ["opened"] or 'reopened' in action: elif action in ["opened"] or 'reopened' in action:
pull_request = body.get("pull_request") pull_request = body.get("pull_request", None)
if not pull_request: if not pull_request:
return {} return {}
api_url = pull_request.get("url") api_url = pull_request.get("url", None)
if not api_url: if api_url is None:
return {} return {}
await agent.handle_request(api_url, "/review") await agent.handle_request(api_url, "/review")
else:
return {} return {}
@router.get("/") @router.get("/")
@ -92,4 +76,4 @@ def start():
if __name__ == '__main__': if __name__ == '__main__':
start() start()
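
Note on the comment-event guard in `handle_request`: the two sides of this hunk differ in `or` versus `and` when checking for `issue` and `pull_request`. A minimal sketch of why that matters follows. It assumes a hypothetical helper `extract_pr_api_url`, which is not part of the repository and mirrors only the checks visible in this hunk.

```python
from typing import Any, Dict, Optional


def extract_pr_api_url(body: Dict[str, Any]) -> Optional[str]:
    """Return the PR API URL for an issue-comment event, or None if the event should be ignored.

    Hypothetical helper for illustration only.
    """
    if "comment" not in body:
        return None
    sender = body.get("sender", {}).get("login")
    if sender and "bot" in sender:
        return None
    # `or` is needed here. With `and`, a body that has no "issue" key still
    # evaluates body["issue"] and raises KeyError, and a plain issue comment
    # (no "pull_request" key) slips past the guard entirely.
    if "issue" not in body or "pull_request" not in body["issue"]:
        return None
    return body["issue"]["pull_request"].get("url")
```

With this shape, comments on plain issues and comments from bot accounts are dropped before any nested key is indexed.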


@ -15,40 +15,28 @@ NOTIFICATION_URL = "https://api.github.com/notifications"
def now() -> str: def now() -> str:
"""
Get the current UTC time in ISO 8601 format.
Returns:
str: The current UTC time in ISO 8601 format.
"""
now_utc = datetime.now(timezone.utc).isoformat() now_utc = datetime.now(timezone.utc).isoformat()
now_utc = now_utc.replace("+00:00", "Z") now_utc = now_utc.replace("+00:00", "Z")
return now_utc return now_utc
async def polling_loop(): async def polling_loop():
"""
Polls for notifications and handles them accordingly.
"""
handled_ids = set() handled_ids = set()
since = [now()] since = [now()]
last_modified = [None] last_modified = [None]
git_provider = get_git_provider()() git_provider = get_git_provider()()
user_id = git_provider.get_user_id() user_id = git_provider.get_user_id()
agent = PRAgent() agent = PRAgent()
try: try:
deployment_type = settings.github.deployment_type deployment_type = settings.github.deployment_type
token = settings.github.user_token token = settings.github.user_token
except AttributeError: except AttributeError:
deployment_type = 'none' deployment_type = 'none'
token = None token = None
if deployment_type != 'user': if deployment_type != 'user':
raise ValueError("Deployment mode must be set to 'user' to get notifications") raise ValueError("Deployment mode must be set to 'user' to get notifications")
if not token: if not token:
raise ValueError("User token must be set to get notifications") raise ValueError("User token must be set to get notifications")
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
while True: while True:
try: try:
@ -64,7 +52,6 @@ async def polling_loop():
params["since"] = since[0] params["since"] = since[0]
if last_modified[0]: if last_modified[0]:
headers["If-Modified-Since"] = last_modified[0] headers["If-Modified-Since"] = last_modified[0]
async with session.get(NOTIFICATION_URL, headers=headers, params=params) as response: async with session.get(NOTIFICATION_URL, headers=headers, params=params) as response:
if response.status == 200: if response.status == 200:
if 'Last-Modified' in response.headers: if 'Last-Modified' in response.headers:
@ -113,4 +100,4 @@ async def polling_loop():
if __name__ == '__main__': if __name__ == '__main__':
asyncio.run(polling_loop()) asyncio.run(polling_loop())
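
Note on the polling request: the loop sends a `since` timestamp and, once one has been seen, an `If-Modified-Since` header. The sketch below isolates those two pieces. `build_poll_request` is a hypothetical name, and the headers and params that the real loop sets in the elided part of the hunk are deliberately left out rather than guessed.

```python
from datetime import datetime, timezone
from typing import Dict, Optional, Tuple


def now() -> str:
    """Current UTC time in ISO 8601 form with a trailing 'Z', as in the hunk above."""
    return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")


def build_poll_request(since: str, last_modified: Optional[str]) -> Tuple[Dict[str, str], Dict[str, str]]:
    """Hypothetical helper: assemble the conditional-request parts of one poll."""
    headers: Dict[str, str] = {}
    params: Dict[str, str] = {"since": since}
    if last_modified:
        # Lets the server skip the response body when nothing has changed.
        headers["If-Modified-Since"] = last_modified
    return headers, params
```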


@ -21,7 +21,3 @@ def verify_signature(payload_body, secret_token, signature_header):
if not hmac.compare_digest(expected_signature, signature_header): if not hmac.compare_digest(expected_signature, signature_header):
raise HTTPException(status_code=403, detail="Request signatures didn't match!") raise HTTPException(status_code=403, detail="Request signatures didn't match!")
class RateLimitExceeded(Exception):
"""Raised when the git provider API rate limit has been exceeded."""
pass
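
Note on `verify_signature`: only the final comparison is visible in this hunk. The sketch below fills in the rest under stated assumptions: it follows GitHub's documented `X-Hub-Signature-256` scheme (`sha256=` followed by the hex HMAC-SHA256 of the raw body), and it raises `PermissionError` instead of FastAPI's `HTTPException` so that it runs with the standard library alone. It is not the repository's implementation.

```python
import hashlib
import hmac


def verify_signature(payload_body: bytes, secret_token: str, signature_header: str) -> None:
    """Raise PermissionError unless signature_header matches the HMAC of payload_body."""
    if not signature_header:
        raise PermissionError("signature header is missing")
    digest = hmac.new(secret_token.encode("utf-8"), msg=payload_body, digestmod=hashlib.sha256).hexdigest()
    expected_signature = "sha256=" + digest
    # compare_digest runs in constant time, which avoids leaking how many
    # leading characters of the signature were correct.
    if not hmac.compare_digest(expected_signature, signature_header):
        raise PermissionError("Request signatures didn't match!")
```

The signature has to be computed over the raw request bytes, not over a re-serialized JSON body, since re-serialization can change whitespace and key order.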


@ -1,6 +1,6 @@
[config] [config]
model="gpt-4" model="gpt-4"
fallback_models=["gpt-3.5-turbo-16k"] fallback-models=["gpt-3.5-turbo-16k", "gpt-3.5-turbo"]
git_provider="github" git_provider="github"
publish_output=true publish_output=true
publish_output_progress=true publish_output_progress=true
@ -24,13 +24,9 @@ publish_description_as_comment=false
[pr_code_suggestions] [pr_code_suggestions]
num_code_suggestions=4 num_code_suggestions=4
[pr_update_changelog]
push_changelog_changes=false
[github] [github]
# The type of deployment to create. Valid values are 'app' or 'user'. # The type of deployment to create. Valid values are 'app' or 'user'.
deployment_type = "user" deployment_type = "user"
ratelimit_retries = 5
[gitlab] [gitlab]
# URL to the gitlab service # URL to the gitlab service
@ -44,8 +40,3 @@ magic_word = "AutoReview"
# Polling interval # Polling interval
polling_interval_seconds = 30 polling_interval_seconds = 30
[local]
# LocalGitProvider settings - uncomment to use paths other than default
# description_path= "path/to/description.md"
# review_path= "path/to/review.md"


@ -1,34 +0,0 @@
[pr_update_changelog_prompt]
system="""You are a language model called CodiumAI-PR-Changlog-summarizer.
Your task is to update the CHANGELOG.md file of the project, to shortly summarize important changes introduced in this PR (the '+' lines).
- The output should match the existing CHANGELOG.md format, style and conventions, so it will look like a natural part of the file. For example, if previous changes were summarized in a single line, you should do the same.
- Don't repeat previous changes. Generate only new content, that is not already in the CHANGELOG.md file.
- Be general, and avoid specific details, files, etc. The output should be minimal, no more than 3-4 short lines. Ignore non-relevant subsections.
"""
user="""PR Info:
Title: '{{title}}'
Branch: '{{branch}}'
Description: '{{description}}'
{%- if language %}
Main language: {{language}}
{%- endif %}
The PR Diff:
```
{{diff}}
```
Current date:
```
{{today}}
```
The current CHANGELOG.md:
```
{{ changelog_file_str }}
```
Response:
"""


@ -2,7 +2,6 @@ import copy
import json import json
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from typing import Tuple, List
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
@ -17,19 +16,7 @@ from pr_agent.servers.help import actions_help_text, bot_help_text
class PRReviewer: class PRReviewer:
""" def __init__(self, pr_url: str, cli_mode=False, is_answer: bool = False, args=None):
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
"""
def __init__(self, pr_url: str, cli_mode: bool = False, is_answer: bool = False, args: list = None):
"""
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
Args:
pr_url (str): The URL of the pull request to be reviewed.
cli_mode (bool, optional): Indicates whether the review is being done in command-line interface mode. Defaults to False.
is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False.
args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None.
"""
self.parse_args(args) self.parse_args(args)
self.git_provider = get_git_provider()(pr_url, incremental=self.incremental) self.git_provider = get_git_provider()(pr_url, incremental=self.incremental)
@ -38,15 +25,13 @@ class PRReviewer:
) )
self.pr_url = pr_url self.pr_url = pr_url
self.is_answer = is_answer self.is_answer = is_answer
if self.is_answer and not self.git_provider.is_supported("get_issue_comments"): if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
raise Exception(f"Answer mode is not supported for {settings.config.git_provider} for now") raise Exception(f"Answer mode is not supported for {settings.config.git_provider} for now")
question_str, answer_str = self._get_user_answers()
self.ai_handler = AiHandler() self.ai_handler = AiHandler()
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
self.cli_mode = cli_mode self.cli_mode = cli_mode
question_str, answer_str = self._get_user_answers()
self.vars = { self.vars = {
"title": self.git_provider.pr.title, "title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(), "branch": self.git_provider.get_pr_branch(),
@ -58,27 +43,16 @@ class PRReviewer:
"require_security": settings.pr_reviewer.require_security_review, "require_security": settings.pr_reviewer.require_security_review,
"require_focused": settings.pr_reviewer.require_focused_review, "require_focused": settings.pr_reviewer.require_focused_review,
'num_code_suggestions': settings.pr_reviewer.num_code_suggestions, 'num_code_suggestions': settings.pr_reviewer.num_code_suggestions,
#
'question_str': question_str, 'question_str': question_str,
'answer_str': answer_str, 'answer_str': answer_str,
} }
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
settings.pr_review_prompt.system,
settings.pr_review_prompt.user)
self.token_handler = TokenHandler( def parse_args(self, args):
self.git_provider.pr,
self.vars,
settings.pr_review_prompt.system,
settings.pr_review_prompt.user
)
def parse_args(self, args: List[str]) -> None:
"""
Parse the arguments passed to the PRReviewer class and set the 'incremental' attribute accordingly.
Args:
args: A list of arguments passed to the PRReviewer class.
Returns:
None
"""
is_incremental = False is_incremental = False
if args and len(args) >= 1: if args and len(args) >= 1:
arg = args[0] arg = args[0]
@ -86,93 +60,60 @@ class PRReviewer:
is_incremental = True is_incremental = True
self.incremental = IncrementalPR(is_incremental) self.incremental = IncrementalPR(is_incremental)
async def review(self) -> None: async def review(self):
"""
Review the pull request and generate feedback.
"""
logging.info('Reviewing PR...') logging.info('Reviewing PR...')
if settings.config.publish_output: if settings.config.publish_output:
self.git_provider.publish_comment("Preparing review...", is_temporary=True) self.git_provider.publish_comment("Preparing review...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction) await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing PR review...') logging.info('Preparing PR review...')
pr_comment = self._prepare_pr_review() pr_comment = self._prepare_pr_review()
if settings.config.publish_output: if settings.config.publish_output:
logging.info('Pushing PR review...') logging.info('Pushing PR review...')
self.git_provider.publish_comment(pr_comment) self.git_provider.publish_comment(pr_comment)
self.git_provider.remove_initial_comment() self.git_provider.remove_initial_comment()
if settings.pr_reviewer.inline_code_comments: if settings.pr_reviewer.inline_code_comments:
logging.info('Pushing inline code comments...') logging.info('Pushing inline code comments...')
self._publish_inline_code_comments() self._publish_inline_code_comments()
return ""
async def _prepare_prediction(self, model: str) -> None: async def _prepare_prediction(self, model: str):
"""
Prepare the AI prediction for the pull request review.
Args:
model: A string representing the AI model to be used for the prediction.
Returns:
None
"""
logging.info('Getting PR diff...') logging.info('Getting PR diff...')
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model) self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
logging.info('Getting AI prediction...') logging.info('Getting AI prediction...')
self.prediction = await self._get_prediction(model) self.prediction = await self._get_prediction(model)
async def _get_prediction(self, model: str) -> str: async def _get_prediction(self, model: str):
"""
Generate an AI prediction for the pull request review.
Args:
model: A string representing the AI model to be used for the prediction.
Returns:
A string representing the AI prediction for the pull request review.
"""
variables = copy.deepcopy(self.vars) variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_review_prompt.system).render(variables) system_prompt = environment.from_string(settings.pr_review_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_review_prompt.user).render(variables) user_prompt = environment.from_string(settings.pr_review_prompt.user).render(variables)
if settings.config.verbosity_level >= 2: if settings.config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
response, finish_reason = await self.ai_handler.chat_completion( system=system_prompt, user=user_prompt)
model=model,
temperature=0.2,
system=system_prompt,
user=user_prompt
)
return response return response
def _prepare_pr_review(self) -> str: def _prepare_pr_review(self) -> str:
"""
Prepare the PR review by processing the AI prediction and generating a markdown-formatted text that summarizes the feedback.
"""
review = self.prediction.strip() review = self.prediction.strip()
try: try:
data = json.loads(review) data = json.loads(review)
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
data = try_fix_json(review) data = try_fix_json(review)
# Move 'Security concerns' key to 'PR Analysis' section for better display # reordering for nicer display
if 'PR Feedback' in data and 'Security concerns' in data['PR Feedback']: if 'PR Feedback' in data:
val = data['PR Feedback']['Security concerns'] if 'Security concerns' in data['PR Feedback']:
del data['PR Feedback']['Security concerns'] val = data['PR Feedback']['Security concerns']
data['PR Analysis']['Security concerns'] = val del data['PR Feedback']['Security concerns']
data['PR Analysis']['Security concerns'] = val
# Filter out code suggestions that can be submitted as inline comments if settings.config.git_provider != 'bitbucket' and \
if settings.config.git_provider != 'bitbucket' and settings.pr_reviewer.inline_code_comments and 'Code suggestions' in data['PR Feedback']: settings.pr_reviewer.inline_code_comments and \
'Code suggestions' in data['PR Feedback']:
# keeping only code suggestions that can't be submitted as inline comments
data['PR Feedback']['Code suggestions'] = [ data['PR Feedback']['Code suggestions'] = [
d for d in data['PR Feedback']['Code suggestions'] d for d in data['PR Feedback']['Code suggestions']
if any(key not in d for key in ('relevant file', 'relevant line in file', 'suggestion content')) if any(key not in d for key in ('relevant file', 'relevant line in file', 'suggestion content'))
@ -180,8 +121,8 @@ class PRReviewer:
if not data['PR Feedback']['Code suggestions']: if not data['PR Feedback']['Code suggestions']:
del data['PR Feedback']['Code suggestions'] del data['PR Feedback']['Code suggestions']
# Add incremental review section
if self.incremental.is_incremental: if self.incremental.is_incremental:
# Rename title when incremental review - Add to the beginning of the dict
last_commit_url = f"{self.git_provider.get_pr_url()}/commits/{self.git_provider.incremental.first_new_commit_sha}" last_commit_url = f"{self.git_provider.get_pr_url()}/commits/{self.git_provider.incremental.first_new_commit_sha}"
data = OrderedDict(data) data = OrderedDict(data)
data.update({'Incremental PR Review': { data.update({'Incremental PR Review': {
@ -191,7 +132,6 @@ class PRReviewer:
markdown_text = convert_to_markdown(data) markdown_text = convert_to_markdown(data)
user = self.git_provider.get_user_id() user = self.git_provider.get_user_id()
# Add help text if not in CLI mode
if not self.cli_mode: if not self.cli_mode:
markdown_text += "\n### How to use\n" markdown_text += "\n### How to use\n"
if user and '[bot]' not in user: if user and '[bot]' not in user:
@ -199,16 +139,11 @@ class PRReviewer:
else: else:
markdown_text += actions_help_text markdown_text += actions_help_text
# Log markdown response if verbosity level is high
if settings.config.verbosity_level >= 2: if settings.config.verbosity_level >= 2:
logging.info(f"Markdown response:\n{markdown_text}") logging.info(f"Markdown response:\n{markdown_text}")
return markdown_text return markdown_text
def _publish_inline_code_comments(self) -> None: def _publish_inline_code_comments(self):
"""
Publishes inline comments on a pull request with code suggestions generated by the AI model.
"""
if settings.pr_reviewer.num_code_suggestions == 0: if settings.pr_reviewer.num_code_suggestions == 0:
return return
@ -218,11 +153,11 @@ class PRReviewer:
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
data = try_fix_json(review) data = try_fix_json(review)
comments: List[str] = [] comments = []
for suggestion in data.get('PR Feedback', {}).get('Code suggestions', []): for d in data['PR Feedback']['Code suggestions']:
relevant_file = suggestion.get('relevant file', '').strip() relevant_file = d.get('relevant file', '').strip()
relevant_line_in_file = suggestion.get('relevant line in file', '').strip() relevant_line_in_file = d.get('relevant line in file', '').strip()
content = suggestion.get('suggestion content', '') content = d.get('suggestion content', '')
if not relevant_file or not relevant_line_in_file or not content: if not relevant_file or not relevant_line_in_file or not content:
logging.info("Skipping inline comment with missing file/line/content") logging.info("Skipping inline comment with missing file/line/content")
continue continue
@ -237,26 +172,15 @@ class PRReviewer:
if comments: if comments:
self.git_provider.publish_inline_comments(comments) self.git_provider.publish_inline_comments(comments)
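
Note on how suggestions are divided: `_prepare_pr_review` keeps in the summary only the suggestions that lack one of the three keys, while `_publish_inline_code_comments` skips any suggestion whose file, line, or content is empty. A minimal sketch of that division as a single pass follows. `split_suggestions` is a hypothetical helper, and unlike the summary filter in the diff it treats an empty value the same as a missing key, which is a design choice of this sketch rather than the repository's behavior.

```python
from typing import Any, Dict, List, Tuple

INLINE_KEYS = ("relevant file", "relevant line in file", "suggestion content")


def split_suggestions(suggestions: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Return (inline, summary): inline items have all three keys with non-empty values."""
    inline: List[Dict[str, Any]] = []
    summary: List[Dict[str, Any]] = []
    for suggestion in suggestions:
        if all(str(suggestion.get(key, "")).strip() for key in INLINE_KEYS):
            inline.append(suggestion)
        else:
            summary.append(suggestion)
    return inline, summary
```

Splitting once means a suggestion with a present but empty field ends up in the summary instead of being dropped by both filters.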
def _get_user_answers(self) -> Tuple[str, str]: def _get_user_answers(self):
""" answer_str = question_str = ""
Retrieves the question and answer strings from the discussion messages related to a pull request.
Returns:
A tuple containing the question and answer strings.
"""
question_str = ""
answer_str = ""
if self.is_answer: if self.is_answer:
discussion_messages = self.git_provider.get_issue_comments() discussion_messages = self.git_provider.get_issue_comments()
for message in discussion_messages.reversed:
for message in reversed(discussion_messages):
if "Questions to better understand the PR:" in message.body: if "Questions to better understand the PR:" in message.body:
question_str = message.body question_str = message.body
elif '/answer' in message.body: elif '/answer' in message.body:
answer_str = message.body answer_str = message.body
if answer_str and question_str: if answer_str and question_str:
break break
return question_str, answer_str return question_str, answer_str
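
Note on `_get_user_answers`: the method walks the discussion from newest to oldest and returns `(question_str, answer_str)`, so the call site has to unpack in that same order. A minimal standalone sketch follows, assuming a hypothetical `Comment` dataclass in place of the provider's comment objects and a plain list in place of the provider's result type.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Comment:
    """Hypothetical stand-in for a provider comment; only `body` is used."""
    body: str


def get_user_answers(comments: List[Comment]) -> Tuple[str, str]:
    """Return (question_str, answer_str) taken from the most recent matching comments."""
    question_str = answer_str = ""
    for message in reversed(comments):  # newest first
        if "Questions to better understand the PR:" in message.body:
            question_str = message.body
        elif "/answer" in message.body:
            answer_str = message.body
        if answer_str and question_str:
            break
    return question_str, answer_str
```

Usage: `question_str, answer_str = get_user_answers(comments)`. Swapping the two names at the call site would silently feed the answer into the question slot of the prompt.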


@ -1,171 +0,0 @@
import copy
import logging
from datetime import date
from time import sleep
from typing import Tuple
from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import settings
from pr_agent.git_providers import get_git_provider, GithubProvider
from pr_agent.git_providers.git_provider import get_main_pr_language
CHANGELOG_LINES = 50
class PRUpdateChangelog:
def __init__(self, pr_url: str, cli_mode=False, args=None):
self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.commit_changelog = self._parse_args(args, settings)
self._get_changlog_file() # self.changelog_file_str
self.ai_handler = AiHandler()
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
"description": self.git_provider.get_pr_description(),
"language": self.main_language,
"diff": "", # empty diff for initial calculation
"changelog_file_str": self.changelog_file_str,
"today": date.today(),
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
settings.pr_update_changelog_prompt.system,
settings.pr_update_changelog_prompt.user)
async def update_changelog(self):
assert type(self.git_provider) == GithubProvider, "Currently only Github is supported"
logging.info('Updating the changelog...')
if settings.config.publish_output:
self.git_provider.publish_comment("Preparing changelog updates...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing PR changelog updates...')
new_file_content, answer = self._prepare_changelog_update()
if settings.config.publish_output:
self.git_provider.remove_initial_comment()
logging.info('Publishing changelog updates...')
if self.commit_changelog:
logging.info('Pushing PR changelog updates to repo...')
self._push_changelog_update(new_file_content, answer)
else:
logging.info('Publishing PR changelog as comment...')
self.git_provider.publish_comment(f"**Changelog updates:**\n\n{answer}")
async def _prepare_prediction(self, model: str):
logging.info('Getting PR diff...')
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
logging.info('Getting AI prediction...')
self.prediction = await self._get_prediction(model)
async def _get_prediction(self, model: str):
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_update_changelog_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_update_changelog_prompt.user).render(variables)
if settings.config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
system=system_prompt, user=user_prompt)
return response
def _prepare_changelog_update(self) -> Tuple[str, str]:
answer = self.prediction.strip().strip("```").strip()
if hasattr(self, "changelog_file"):
existing_content = self.changelog_file.decoded_content.decode()
else:
existing_content = ""
if existing_content:
new_file_content = answer + "\n\n" + self.changelog_file.decoded_content.decode()
else:
new_file_content = answer
if not self.commit_changelog:
answer += "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:" \
"\n>'/update_changelog -commit'\n"
if settings.config.verbosity_level >= 2:
logging.info(f"answer:\n{answer}")
return new_file_content, answer
def _push_changelog_update(self, new_file_content, answer):
self.git_provider.repo_obj.update_file(path=self.changelog_file.path,
message="Update CHANGELOG.md",
content=new_file_content,
sha=self.changelog_file.sha,
branch=self.git_provider.get_pr_branch())
d = dict(body="CHANGELOG.md update",
path=self.changelog_file.path,
line=max(2, len(answer.splitlines())),
start_line=1)
sleep(5) # wait for the file to be updated
last_commit_id = list(self.git_provider.pr.get_commits())[-1]
try:
self.git_provider.pr.create_review(commit=last_commit_id, comments=[d])
except:
# we can't create a review for some reason, let's just publish a comment
self.git_provider.publish_comment(f"**Changelog updates:**\n\n{answer}")
def _get_default_changelog(self):
example_changelog = \
"""
Example:
## <current_date>
### Added
...
### Changed
...
### Fixed
...
"""
return example_changelog
def _parse_args(self, args, setting):
commit_changelog = False
if args and len(args) >= 1:
try:
if args[0] == "-commit":
commit_changelog = True
except:
pass
else:
commit_changelog = setting.pr_update_changelog.push_changelog_changes
return commit_changelog
def _get_changlog_file(self):
try:
self.changelog_file = self.git_provider.repo_obj.get_contents("CHANGELOG.md",
ref=self.git_provider.get_pr_branch())
changelog_file_lines = self.changelog_file.decoded_content.decode().splitlines()
changelog_file_lines = changelog_file_lines[:CHANGELOG_LINES]
self.changelog_file_str = "\n".join(changelog_file_lines)
except:
self.changelog_file_str = ""
if self.commit_changelog:
logging.info("No CHANGELOG.md file found in the repository. Creating one...")
changelog_file = self.git_provider.repo_obj.create_file(path="CHANGELOG.md",
message='add CHANGELOG.md',
content="",
branch=self.git_provider.get_pr_branch())
self.changelog_file = changelog_file['content']
if not self.changelog_file_str:
self.changelog_file_str = self._get_default_changelog()
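
Note on `_prepare_changelog_update`: the model's answer is stripped of surrounding code fences and then placed above the existing changelog. A minimal sketch of that step follows, using a hypothetical function `prepend_changelog_entry` that takes the existing content as a plain string instead of reading it from the provider.

```python
def prepend_changelog_entry(prediction: str, existing_content: str) -> str:
    """Return new CHANGELOG.md content with the cleaned prediction placed on top."""
    # str.strip("```") treats its argument as a set of characters, so it is
    # equivalent to strip("`"): it removes any run of backticks at either end.
    answer = prediction.strip().strip("`").strip()
    if existing_content:
        return answer + "\n\n" + existing_content
    return answer
```

One caveat of this approach: a fence that carries a language tag, such as a leading "```markdown", would leave the word `markdown` at the top of the entry, because only the backticks are removed.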


@ -1,63 +1,3 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "pr_agent"
version = "0.0.1"
authors = [
{name = "Itamar Friedman", email = "itamar.f@codium.ai"},
]
maintainers = [
{name = "Ori Kotek", email = "ori.k@codium.ai"},
{name = "Tal Ridnik", email = "tal.r@codium.ai"},
{name = "Hussam Lawen", email = "hussam.l@codium.ai"},
{name = "Sagi Medina", email = "sagi.m@codium.ai"}
]
description = "CodiumAI PR-Agent is an open-source tool to automatically analyze a pull request and provide several types of feedback"
readme = "README.md"
requires-python = ">=3.9"
keywords = ["ai", "tool", "developer", "review", "agent"]
license = {file = "LICENSE", name = "Apache 2.0 License"}
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Operating System :: Independent",
"Programming Language :: Python :: 3",
]
dependencies = [
"dynaconf==3.1.12",
"fastapi==0.99.0",
"PyGithub==1.59.*",
"retry==0.9.2",
"openai==0.27.8",
"Jinja2==3.1.2",
"tiktoken==0.4.0",
"uvicorn==0.22.0",
"python-gitlab==3.15.0",
"pytest~=7.4.0",
"aiohttp~=3.8.4",
"atlassian-python-api==3.39.0",
"GitPython~=3.1.32",
]
[project.urls]
"Homepage" = "https://github.com/Codium-ai/pr-agent"
[tool.setuptools]
include-package-data = false
license-files = ["LICENSE"]
[tool.setuptools.packages.find]
where = ["."]
include = ["pr_agent"]
[project.scripts]
pr-agent = "pr_agent.cli:run"
[tool.ruff] [tool.ruff]
line-length = 120 line-length = 120


@ -1 +1,12 @@
-e . dynaconf==3.1.12
fastapi==0.99.0
PyGithub==1.59.*
retry==0.9.2
openai==0.27.8
Jinja2==3.1.2
tiktoken==0.4.0
uvicorn==0.22.0
python-gitlab==3.15.0
pytest~=7.4.0
aiohttp~=3.8.4
atlassian-python-api==3.39.0


@ -1,5 +0,0 @@
# for compatibility with legacy tools
# see: https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html
from setuptools import setup
setup()