Compare commits

..

43 Commits

Author SHA1 Message Date
c5a79ceedd Merge remote-tracking branch 'origin/main' into ok/settings_refactor 2023-08-01 16:01:04 +03:00
13547afc58 Merge pull request #163 from Codium-ai/tr/commit_messages
Adding Commit Messages Retrieval Functionality
2023-08-01 15:59:26 +03:00
8ae936e504 Bug fixes 2023-08-01 15:58:23 +03:00
e577d27f9b Update CHANGELOG.md 2023-08-01 12:38:31 +00:00
dfb73c963a get_commit_messages for gitlab 2023-08-01 15:30:14 +03:00
8c0370a166 Commit messages in pr-description 2023-08-01 15:15:59 +03:00
d7b77764c3 Support context aware settings (for each incoming request), support override of settings, refactor CLI to use pr_agent.py 2023-08-01 14:43:26 +03:00
6605f9c444 typos in 'commands_text' 2023-07-31 11:02:30 +03:00
2a8adcbbd6 update README.md 2023-07-30 22:16:56 +03:00
0b22c8d427 update README.md 2023-07-30 22:04:59 +03:00
dfa0d9fd43 update README.md 2023-07-30 22:01:14 +03:00
c8470645e2 add tests and update README.md 2023-07-30 21:54:07 +03:00
5a181e52d5 Merge pull request #159 from Codium-ai/tr/edit_any_config_setting
The Configurator Strikes Back
2023-07-30 15:19:07 +03:00
0ad8dcd2aa Merge remote-tracking branch 'origin/tr/edit_any_config_setting' into tr/edit_any_config_setting 2023-07-30 12:27:40 +03:00
e2d015a20c final 2023-07-30 12:27:32 +03:00
a0cfe4b48a Update CHANGELOG.md 2023-07-30 12:26:53 +03:00
a6ba8b614a Example args 2023-07-30 12:16:43 +03:00
4f0fabd2ca update_settings_from_args refactor 2023-07-30 12:14:26 +03:00
42b047a14e update_settings_from_args 2023-07-30 12:04:57 +03:00
3daf94954a update_settings_from_args 2023-07-30 11:43:44 +03:00
b564d8ac32 Merge pull request #147 from zmeir/zmeir-align_describe_styling
Minor improvements to describe command
2023-07-28 20:55:15 +03:00
d8e6da74db Update .dockerignore 2023-07-28 12:15:17 +03:00
278f1883fd Merge pull request #153 from marshally/fix_iteration_error_in_reflect_tmp
fix TypeError when iterating discussion_messages
2023-07-28 12:12:12 +03:00
ef71a7049e fix TypeError when iterating discussion_messages
When `pr-agent` is reviewing a long list of messages, a TypeError is thrown on the line

```python
for message in reversed(discussion_messages):
```

When reviewing the PyGithub library, they recommend an alternate syntax for iterating a paginated list in reverse.

https://github.com/PyGithub/PyGithub/blob/v1.59.0/github/PaginatedList.py#L122-L125

```
    If you want to iterate in reversed order, just do::

        for repo in user.get_repos().reversed:
            print(repo.name)
```

And here's a copy of the actual traceback

```
Traceback (most recent call last):
  File "/app/pr_agent/servers/github_action_runner.py", line 68, in <module>
    asyncio.run(run_action())
  File "/usr/local/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/app/pr_agent/servers/github_action_runner.py", line 64, in run_action
    await PRAgent().handle_request(pr_url, body)
  File "/app/pr_agent/agent/pr_agent.py", line 19, in handle_request
    await PRReviewer(pr_url, is_answer=True).review()
  File "/app/pr_agent/tools/pr_reviewer.py", line 49, in __init__
    answer_str, question_str = self._get_user_answers()
  File "/app/pr_agent/tools/pr_reviewer.py", line 253, in _get_user_answers
    for message in reversed(discussion_messages):
TypeError: object of type 'PaginatedList' has no len()
```
2023-07-28 11:04:46 +02:00
6fde87b3bd Merge pull request #152 from Codium-ai/tr/gitlab_fixes
Improvements and Error Handling for GitLab Provider
2023-07-28 11:40:53 +03:00
07fe91e57b Update CHANGELOG.md 2023-07-28 08:39:42 +00:00
01e2f3f0cd Merge pull request #150 from Codium-ai/ok/handle_installation_id_properly
Github App: handle concurrent requests from multiple installations of app
2023-07-28 11:38:14 +03:00
63a703c000 Handle marketplace hook 2023-07-28 11:30:51 +03:00
4664d91844 bug fixes in gitlab code suggestion 2023-07-28 11:24:14 +03:00
8f16c46012 try-except 2023-07-28 10:52:49 +03:00
a8780f722d Handle marketplace hook 2023-07-28 03:22:25 +03:00
1a8fce1505 Updated handling of installation id 2023-07-28 02:44:28 +03:00
8519b106f9 Updated .gitignore 2023-07-28 02:28:50 +03:00
d375dd62fe Merge pull request #141 from patryk-kowalski-ds/pg/pip_package
Transition to pip package with pyproject.toml
2023-07-28 02:23:06 +03:00
3770bf8031 Update setup.py 2023-07-28 02:22:38 +03:00
42388b1f8d Merge pull request #146 from idavidov/idsvidov/gitlabpaginator_fix
Fix for GitLab Paginator in GitLab Provider
2023-07-28 02:01:04 +03:00
0167003bbc handle no diffs 2023-07-28 01:59:10 +03:00
2ce91fbdf5 Merge pull request #148 from eltociear/patch-1
Fix typo in PR_COMPRESSION.md
2023-07-28 01:50:30 +03:00
aa7659d6bf Fix typo in PR_COMPRESSION.md
Withing -> Within
2023-07-28 00:18:58 +09:00
4aa54b9bd4 Add /describe -c option 2023-07-27 17:42:50 +03:00
c6d0bacc08 Match styling of both /describe modes 2023-07-27 17:31:31 +03:00
99ed9b22a1 latest documentation suggest get_all not all
https://python-gitlab.readthedocs.io/en/stable/api-usage.html#pagination
2023-07-27 15:39:19 +03:00
eee6d51b40 issue #145
get all diffs in merge request and not only gitlab default 20
2023-07-27 14:41:36 +03:00
45 changed files with 553 additions and 401 deletions

View File

@ -1,3 +1,5 @@
venv/ venv/
pr_agent/settings/.secrets.toml pr_agent/settings/.secrets.toml
pics/ pics/
pr_agent.egg-info/
build/

2
.gitignore vendored
View File

@ -4,3 +4,5 @@ pr_agent/settings/.secrets.toml
__pycache__ __pycache__
dist/ dist/
*.egg-info/ *.egg-info/
build/
review.md

View File

@ -1,3 +1,24 @@
## 2023-08-01
### Enhanced
- Introduced the ability to retrieve commit messages from pull requests across different git providers.
- Implemented commit messages retrieval for GitHub and GitLab providers.
- Updated the PR description template to include a section for commit messages if they exist.
## 2023-07-30
### Enhanced
- Added the ability to modify any configuration parameter from 'configuration.toml' on-the-fly.
- Updated the command line interface and bot commands to accept configuration changes as arguments.
- Improved the PR agent to handle additional arguments for each action.
## 2023-07-28
### Improved
- Enhanced error handling and logging in the GitLab provider.
- Improved handling of inline comments and code suggestions in GitLab.
- Fixed a bug where an additional unneeded line was added to code suggestions in GitLab.
## 2023-07-26 ## 2023-07-26
### Added ### Added

View File

@ -1,19 +1,12 @@
## Configuration ## Configuration
The different tools and sub-tools used by CodiumAI pr-agent are easily configurable via the configuration file: `/pr-agent/settings/configuration.toml`. The different tools and sub-tools used by CodiumAI pr-agent are adjustable via the configuration file: `/pr-agent/settings/configuration.toml`.
##### Git Provider:
You can select your git_provider with the flag `git_provider` in the `config` section
##### PR Reviewer: To edit the configuration of any tool, just add `--config_path=<value>` to your command.
For example if you want to edit online the `pr_reviewer` configurations, you can run:
```
/review --pr_reviewer.extra_instructions="focus on the file xyz" --pr_reviewer.require_score_review=false ...
```
Any configuration value in `configuration.toml` file can be similarly edited.
You can enable/disable the different PR Reviewer abilities with the following flags (`pr_reviewer` section):
```
require_focused_review=true
require_score_review=true
require_tests_review=true
require_security_review=true
```
You can control the number of suggestions returned by the PR Reviewer with the following flag:
```inline_code_comments=3```
And enable/disable the inline code suggestions with the following flag:
```inline_code_comments=true```

View File

@ -31,7 +31,7 @@ We prioritize additions over deletions:
- File patches are a list of hunks, remove all hunks of type deletion-only from the hunks in the file patch - File patches are a list of hunks, remove all hunks of type deletion-only from the hunks in the file patch
#### Adaptive and token-aware file patch fitting #### Adaptive and token-aware file patch fitting
We use [tiktoken](https://github.com/openai/tiktoken) to tokenize the patches after the modifications described above, and we use the following strategy to fit the patches into the prompt: We use [tiktoken](https://github.com/openai/tiktoken) to tokenize the patches after the modifications described above, and we use the following strategy to fit the patches into the prompt:
1. Withing each language we sort the files by the number of tokens in the file (in descending order): 1. Within each language we sort the files by the number of tokens in the file (in descending order):
* ```[[file2.py, file.py],[file4.jsx, file3.js],[readme.md]]``` * ```[[file2.py, file.py],[file4.jsx, file3.js],[readme.md]]```
2. Iterate through the patches in the order described above 2. Iterate through the patches in the order described above
2. Add the patches to the prompt until the prompt reaches a certain buffer from the max token length 2. Add the patches to the prompt until the prompt reaches a certain buffer from the max token length

View File

@ -23,7 +23,9 @@ CodiumAI `PR-Agent` is an open-source tool aiming to help developers review pull
\ \
**Question Answering**: Answering free-text questions about the PR. **Question Answering**: Answering free-text questions about the PR.
\ \
**Code Suggestion**: Committable code suggestions for improving the PR. **Code Suggestions**: Committable code suggestions for improving the PR.
\
**Update Changelog**: Automatically updating the CHANGELOG.md file with the PR changes.
<h3>Example results:</h3> <h3>Example results:</h3>
</div> </div>
@ -100,7 +102,7 @@ Examples for invoking the different tools via the CLI:
- **Improve**: python cli.py --pr-url=<pr_url> improve - **Improve**: python cli.py --pr-url=<pr_url> improve
- **Ask**: python cli.py --pr-url=<pr_url> ask "Write me a poem about this PR" - **Ask**: python cli.py --pr-url=<pr_url> ask "Write me a poem about this PR"
- **Reflect**: python cli.py --pr-url=<pr_url> reflect - **Reflect**: python cli.py --pr-url=<pr_url> reflect
- **Update changelog**: python cli.py --pr-url=<pr_url> update_changelog - **Update Changelog**: python cli.py --pr-url=<pr_url> update_changelog
"<pr_url>" is the url of the relevant PR (for example: https://github.com/Codium-ai/pr-agent/pull/50). "<pr_url>" is the url of the relevant PR (for example: https://github.com/Codium-ai/pr-agent/pull/50).
@ -135,13 +137,14 @@ There are several ways to use PR-Agent:
## Usage and Tools ## Usage and Tools
**PR-Agent** provides five types of interactions ("tools"): `"PR Reviewer"`, `"PR Q&A"`, `"PR Description"`, `"PR Code Suggestions"` and `"PR Reflect and Review"`. **PR-Agent** provides six types of interactions ("tools"): `"PR Reviewer"`, `"PR Q&A"`, `"PR Description"`, `"PR Code Suggestions"`, `"PR Reflect and Review"` and `"PR Update Changelog"`.
- The "PR Reviewer" tool automatically analyzes PRs, and provides various types of feedback. - The "PR Reviewer" tool automatically analyzes PRs, and provides various types of feedback.
- The "PR Q&A" tool answers free-text questions about the PR. - The "PR Q&A" tool answers free-text questions about the PR.
- The "PR Description" tool automatically sets the PR Title and body. - The "PR Description" tool automatically sets the PR Title and body.
- The "PR Code Suggestion" tool provides inline code suggestions for the PR that can be applied and committed. - The "PR Code Suggestion" tool provides inline code suggestions for the PR that can be applied and committed.
- The "PR Reflect and Review" tool initiates a dialog with the user, asks them to reflect on the PR, and then provides a more focused review. - The "PR Reflect and Review" tool initiates a dialog with the user, asks them to reflect on the PR, and then provides a more focused review.
- The "PR Update Changelog" tool automatically updates the CHANGELOG.md file with the PR changes.
## How it works ## How it works
@ -158,7 +161,7 @@ Here are some of the reasons why:
- We emphasize **real-life practical usage**. Each tool (review, improve, ask, ...) has a single GPT-4 call, no more. We feel that this is critical for realistic team usage - obtaining an answer quickly (~30 seconds) and affordably. - We emphasize **real-life practical usage**. Each tool (review, improve, ask, ...) has a single GPT-4 call, no more. We feel that this is critical for realistic team usage - obtaining an answer quickly (~30 seconds) and affordably.
- Our [PR Compression strategy](./PR_COMPRESSION.md) is a core ability that enables to effectively tackle both short and long PRs. - Our [PR Compression strategy](./PR_COMPRESSION.md) is a core ability that enables to effectively tackle both short and long PRs.
- Our JSON prompting strategy enables to have **modular, customizable tools**. For example, the '/review' tool categories can be controlled via the configuration file. Adding additional categories is easy and accessible. - Our JSON prompting strategy enables to have **modular, customizable tools**. For example, the '/review' tool categories can be controlled via the configuration file. Adding additional categories is easy and accessible.
- We support **multiple git providers** (GitHub, Gitlab, Bitbucket), and multiple ways to use the tool (CLI, GitHub Action, Docker, ...). - We support **multiple git providers** (GitHub, Gitlab, Bitbucket), and multiple ways to use the tool (CLI, GitHub Action, GitHub App, Docker, ...).
- We are open-source, and welcome contributions from the community. - We are open-source, and welcome contributions from the community.

View File

@ -1,6 +1,7 @@
import re import shlex
from pr_agent.config_loader import settings from pr_agent.algo.utils import update_settings_from_args
from pr_agent.config_loader import get_settings
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from pr_agent.tools.pr_description import PRDescription from pr_agent.tools.pr_description import PRDescription
from pr_agent.tools.pr_information_from_user import PRInformationFromUser from pr_agent.tools.pr_information_from_user import PRInformationFromUser
@ -8,29 +9,40 @@ from pr_agent.tools.pr_questions import PRQuestions
from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_reviewer import PRReviewer
from pr_agent.tools.pr_update_changelog import PRUpdateChangelog from pr_agent.tools.pr_update_changelog import PRUpdateChangelog
# Dispatch table: command name -> tool class that implements it.
# Several names are aliases for the same tool (e.g. "review" / "review_pr").
# PRAgent.handle_request strips the leading "/" and lower-cases the action
# before looking it up here, then instantiates the class and awaits `.run()`.
command2class = {
# "answer" is also handled by a dedicated branch in handle_request that
# passes is_answer=True to PRReviewer.
"answer": PRReviewer,
"review": PRReviewer,
"review_pr": PRReviewer,
"reflect": PRInformationFromUser,
"reflect_and_review": PRInformationFromUser,
"describe": PRDescription,
"describe_pr": PRDescription,
"improve": PRCodeSuggestions,
"improve_code": PRCodeSuggestions,
"ask": PRQuestions,
"ask_question": PRQuestions,
"update_changelog": PRUpdateChangelog,
}
# Every recognized command name, in the insertion order of the table above.
commands = list(command2class.keys())
class PRAgent: class PRAgent:
def __init__(self): def __init__(self):
pass pass
async def handle_request(self, pr_url, request) -> bool: async def handle_request(self, pr_url, request) -> bool:
action, *args = request.strip().split() request = request.replace("'", "\\'")
if any(cmd == action for cmd in ["/answer"]): lexer = shlex.shlex(request, posix=True)
await PRReviewer(pr_url, is_answer=True).review() lexer.whitespace_split = True
elif any(cmd == action for cmd in ["/review", "/review_pr", "/reflect_and_review"]): action, *args = list(lexer)
if settings.pr_reviewer.ask_and_reflect or "/reflect_and_review" in request: args = update_settings_from_args(args)
await PRInformationFromUser(pr_url).generate_questions() action = action.lstrip("/").lower()
else: if action == "reflect_and_review" and not get_settings().pr_reviewer.ask_and_reflect:
await PRReviewer(pr_url, args=args).review() action = "review"
elif any(cmd == action for cmd in ["/describe", "/describe_pr"]): if action == "answer":
await PRDescription(pr_url).describe() await PRReviewer(pr_url, is_answer=True, args=args).run()
elif any(cmd == action for cmd in ["/improve", "/improve_code"]): elif action in command2class:
await PRCodeSuggestions(pr_url).suggest() await command2class[action](pr_url, args=args).run()
elif any(cmd == action for cmd in ["/ask", "/ask_question"]):
await PRQuestions(pr_url, args=args).answer()
elif any(cmd == action for cmd in ["/update_changelog"]):
await PRUpdateChangelog(pr_url, args=args).update_changelog()
else: else:
return False return False
return True return True

View File

@ -1,10 +1,10 @@
import logging import logging
import openai import openai
from openai.error import APIError, Timeout, TryAgain, RateLimitError from openai.error import APIError, RateLimitError, Timeout, TryAgain
from retry import retry from retry import retry
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
OPENAI_RETRIES=5 OPENAI_RETRIES=5
@ -21,16 +21,16 @@ class AiHandler:
Raises a ValueError if the OpenAI key is missing. Raises a ValueError if the OpenAI key is missing.
""" """
try: try:
openai.api_key = settings.openai.key openai.api_key = get_settings().openai.key
if settings.get("OPENAI.ORG", None): if get_settings().get("OPENAI.ORG", None):
openai.organization = settings.openai.org openai.organization = get_settings().openai.org
self.deployment_id = settings.get("OPENAI.DEPLOYMENT_ID", None) self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None)
if settings.get("OPENAI.API_TYPE", None): if get_settings().get("OPENAI.API_TYPE", None):
openai.api_type = settings.openai.api_type openai.api_type = get_settings().openai.api_type
if settings.get("OPENAI.API_VERSION", None): if get_settings().get("OPENAI.API_VERSION", None):
openai.api_version = settings.openai.api_version openai.api_version = get_settings().openai.api_version
if settings.get("OPENAI.API_BASE", None): if get_settings().get("OPENAI.API_BASE", None):
openai.api_base = settings.openai.api_base openai.api_base = get_settings().openai.api_base
except AttributeError as e: except AttributeError as e:
raise ValueError("OpenAI key is required") from e raise ValueError("OpenAI key is required") from e

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import logging import logging
import re import re
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
def extend_patch(original_file_str, patch_str, num_lines) -> str: def extend_patch(original_file_str, patch_str, num_lines) -> str:
@ -55,7 +55,7 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
continue continue
extended_patch_lines.append(line) extended_patch_lines.append(line)
except Exception as e: except Exception as e:
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.error(f"Failed to extend patch: {e}") logging.error(f"Failed to extend patch: {e}")
return patch_str return patch_str
@ -126,14 +126,14 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,
""" """
if not new_file_content_str: if not new_file_content_str:
# logic for handling deleted files - don't show patch, just show that the file was deleted # logic for handling deleted files - don't show patch, just show that the file was deleted
if settings.config.verbosity_level > 0: if get_settings().config.verbosity_level > 0:
logging.info(f"Processing file: {file_name}, minimizing deletion file") logging.info(f"Processing file: {file_name}, minimizing deletion file")
patch = None # file was deleted patch = None # file was deleted
else: else:
patch_lines = patch.splitlines() patch_lines = patch.splitlines()
patch_new = omit_deletion_hunks(patch_lines) patch_new = omit_deletion_hunks(patch_lines)
if patch != patch_new: if patch != patch_new:
if settings.config.verbosity_level > 0: if get_settings().config.verbosity_level > 0:
logging.info(f"Processing file: {file_name}, hunks were deleted") logging.info(f"Processing file: {file_name}, hunks were deleted")
patch = patch_new patch = patch_new
return patch return patch
@ -141,7 +141,8 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,
def convert_to_hunks_with_lines_numbers(patch: str, file) -> str: def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
""" """
Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of the file. Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of
the file.
Args: Args:
patch (str): The patch string to be converted. patch (str): The patch string to be converted.

View File

@ -1,15 +1,15 @@
# Language Selection, source: https://github.com/bigcode-project/bigcode-dataset/blob/main/language_selection/programming-languages-to-file-extensions.json # noqa E501 # Language Selection, source: https://github.com/bigcode-project/bigcode-dataset/blob/main/language_selection/programming-languages-to-file-extensions.json # noqa E501
from typing import Dict from typing import Dict
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
language_extension_map_org = settings.language_extension_map_org language_extension_map_org = get_settings().language_extension_map_org
language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()} language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
# Bad Extensions, source: https://github.com/EleutherAI/github-downloader/blob/345e7c4cbb9e0dc8a0615fd995a08bf9d73b3fe6/download_repo_text.py # noqa: E501 # Bad Extensions, source: https://github.com/EleutherAI/github-downloader/blob/345e7c4cbb9e0dc8a0615fd995a08bf9d73b3fe6/download_repo_text.py # noqa: E501
bad_extensions = settings.bad_extensions.default bad_extensions = get_settings().bad_extensions.default
if settings.config.use_extra_bad_extensions: if get_settings().config.use_extra_bad_extensions:
bad_extensions += settings.bad_extensions.extra bad_extensions += get_settings().bad_extensions.extra
def filter_bad_extensions(files): def filter_bad_extensions(files):

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Tuple, Union, Callable, List from typing import Callable, Tuple
from github import RateLimitExceededException from github import RateLimitExceededException
@ -10,7 +10,7 @@ from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbe
from pr_agent.algo.language_handler import sort_files_by_main_languages from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_large_diff from pr_agent.algo.utils import load_large_diff
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import GitProvider from pr_agent.git_providers.git_provider import GitProvider
DELETED_FILES_ = "Deleted files:\n" DELETED_FILES_ = "Deleted files:\n"
@ -27,11 +27,15 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
Returns a string with the diff of the pull request, applying diff minimization techniques if needed. Returns a string with the diff of the pull request, applying diff minimization techniques if needed.
Args: Args:
git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull request. git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request. request.
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the
pull request.
model (str): The name of the model used for tokenization. model (str): The name of the model used for tokenization.
add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the diff. Defaults to False. add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the
disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with extra lines of context. Defaults to False. diff. Defaults to False.
disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with
extra lines of context. Defaults to False.
Returns: Returns:
str: A string with the diff of the pull request, applying diff minimization techniques if needed. str: A string with the diff of the pull request, applying diff minimization techniques if needed.
@ -76,10 +80,12 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
add_line_numbers_to_hunks: bool) -> \ add_line_numbers_to_hunks: bool) -> \
Tuple[list, int]: Tuple[list, int]:
""" """
Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff minimization techniques if needed. Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff
minimization techniques if needed.
Args: Args:
- pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding files. - pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding
files.
- token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request. - token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request.
- add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff. - add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff.
@ -119,10 +125,13 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str, def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list]: convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list]:
""" """
Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of tokens used. Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of
tokens used.
Args: Args:
top_langs (list): A list of dictionaries representing the languages used in the pull request and their corresponding files. top_langs (list): A list of dictionaries representing the languages used in the pull request and their
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request. corresponding files.
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the
pull request.
model (str): The model used for tokenization. model (str): The model used for tokenization.
convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff. convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff.
Returns: Returns:
@ -181,7 +190,7 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
# Current logic is to skip the patch if it's too large # Current logic is to skip the patch if it's too large
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens # TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
# until we meet the requirements # until we meet the requirements
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.warning(f"Patch too large, minimizing it, {file.filename}") logging.warning(f"Patch too large, minimizing it, {file.filename}")
if not modified_files_list: if not modified_files_list:
total_tokens += token_handler.count_tokens(MORE_MODIFIED_FILES_) total_tokens += token_handler.count_tokens(MORE_MODIFIED_FILES_)
@ -196,15 +205,15 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
patch_final = patch patch_final = patch
patches.append(patch_final) patches.append(patch_final)
total_tokens += token_handler.count_tokens(patch_final) total_tokens += token_handler.count_tokens(patch_final)
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"Tokens: {total_tokens}, last filename: {file.filename}") logging.info(f"Tokens: {total_tokens}, last filename: {file.filename}")
return patches, modified_files_list, deleted_files_list return patches, modified_files_list, deleted_files_list
async def retry_with_fallback_models(f: Callable): async def retry_with_fallback_models(f: Callable):
model = settings.config.model model = get_settings().config.model
fallback_models = settings.config.fallback_models fallback_models = get_settings().config.fallback_models
if not isinstance(fallback_models, list): if not isinstance(fallback_models, list):
fallback_models = [fallback_models] fallback_models = [fallback_models]
all_models = [model] + fallback_models all_models = [model] + fallback_models

View File

@ -1,8 +1,7 @@
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from tiktoken import encoding_for_model from tiktoken import encoding_for_model
from pr_agent.algo import MAX_TOKENS from pr_agent.config_loader import get_settings
from pr_agent.config_loader import settings
class TokenHandler: class TokenHandler:
@ -10,9 +9,12 @@ class TokenHandler:
A class for handling tokens in the context of a pull request. A class for handling tokens in the context of a pull request.
Attributes: Attributes:
- encoder: An object of the encoding_for_model class from the tiktoken module. Used to encode strings and count the number of tokens in them. - encoder: An object of the encoding_for_model class from the tiktoken module. Used to encode strings and count the
- limit: The maximum number of tokens allowed for the given model, as defined in the MAX_TOKENS dictionary in the pr_agent.algo module. number of tokens in them.
- prompt_tokens: The number of tokens in the system and user strings, as calculated by the _get_system_user_tokens method. - limit: The maximum number of tokens allowed for the given model, as defined in the MAX_TOKENS dictionary in the
pr_agent.algo module.
- prompt_tokens: The number of tokens in the system and user strings, as calculated by the _get_system_user_tokens
method.
""" """
def __init__(self, pr, vars: dict, system, user): def __init__(self, pr, vars: dict, system, user):
@ -25,7 +27,7 @@ class TokenHandler:
- system: The system string. - system: The system string.
- user: The user string. - user: The user string.
""" """
self.encoder = encoding_for_model(settings.config.model) self.encoder = encoding_for_model(get_settings().config.model)
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user) self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user): def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):

View File

@ -1,14 +1,24 @@
from __future__ import annotations from __future__ import annotations
import difflib import difflib
from datetime import datetime
import json import json
import logging import logging
import re import re
import textwrap import textwrap
from datetime import datetime
from typing import Any, List
from pr_agent.config_loader import settings from starlette_context import context
from pr_agent.config_loader import get_settings, global_settings
def get_setting(key: str) -> Any:
    """
    Look up a single setting by key, preferring the settings stored in the current context.

    The key is upper-cased before the lookup. The value is read from the "settings" entry of the context
    (falling back to the global settings when that entry is absent), and the global settings value for the
    same key is used as the default. If anything in that lookup raises, the global settings are queried
    directly.
    """
    try:
        key = key.upper()
        active_settings = context.get("settings", global_settings)
        return active_settings.get(key, global_settings.get(key, None))
    except Exception:
        return global_settings.get(key, None)
def convert_to_markdown(output_data: dict) -> str: def convert_to_markdown(output_data: dict) -> str:
""" """
@ -96,12 +106,16 @@ def try_fix_json(review, max_iter=10, code_suggestions=False):
- data: A dictionary containing the parsed JSON data. - data: A dictionary containing the parsed JSON data.
The function attempts to fix broken or incomplete JSON messages by parsing until the last valid code suggestion. The function attempts to fix broken or incomplete JSON messages by parsing until the last valid code suggestion.
If the JSON message ends with a closing bracket, the function calls the fix_json_escape_char function to fix the message. If the JSON message ends with a closing bracket, the function calls the fix_json_escape_char function to fix the
If code_suggestions is True and the JSON message contains code suggestions, the function tries to fix the JSON message by parsing until the last valid code suggestion. message.
The function uses regular expressions to find the last occurrence of "}," with any number of whitespaces or newlines. If code_suggestions is True and the JSON message contains code suggestions, the function tries to fix the JSON
message by parsing until the last valid code suggestion.
The function uses regular expressions to find the last occurrence of "}," with any number of whitespaces or
newlines.
It tries to parse the JSON message with the closing bracket and checks if it is valid. It tries to parse the JSON message with the closing bracket and checks if it is valid.
If the JSON message is valid, the parsed JSON data is returned. If the JSON message is valid, the parsed JSON data is returned.
If the JSON message is not valid, the last code suggestion is removed and the process is repeated until a valid JSON message is obtained or the maximum number of iterations is reached. If the JSON message is not valid, the last code suggestion is removed and the process is repeated until a valid JSON
message is obtained or the maximum number of iterations is reached.
If a valid JSON message is not obtained, an error is logged and an empty dictionary is returned. If a valid JSON message is not obtained, an error is logged and an empty dictionary is returned.
""" """
@ -183,7 +197,8 @@ def convert_str_to_datetime(date_str):
def load_large_diff(file, new_file_content_str: str, original_file_content_str: str, patch: str) -> str: def load_large_diff(file, new_file_content_str: str, original_file_content_str: str, patch: str) -> str:
""" """
Generate a patch for a modified file by comparing the original content of the file with the new content provided as input. Generate a patch for a modified file by comparing the original content of the file with the new content provided as
input.
Args: Args:
file: The file object for which the patch needs to be generated. file: The file object for which the patch needs to be generated.
@ -198,16 +213,55 @@ def load_large_diff(file, new_file_content_str: str, original_file_content_str:
None. None.
Additional Information: Additional Information:
- If 'patch' is not provided as input, the function generates a patch using the 'difflib' library and returns it as output. - If 'patch' is not provided as input, the function generates a patch using the 'difflib' library and returns it
- If the 'settings.config.verbosity_level' is greater than or equal to 2, a warning message is logged indicating that the file was modified but no patch was found, and a patch is manually created. as output.
- If the 'settings.config.verbosity_level' is greater than or equal to 2, a warning message is logged indicating
that the file was modified but no patch was found, and a patch is manually created.
""" """
if not patch: # to Do - also add condition for file extension if not patch: # to Do - also add condition for file extension
try: try:
diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True), diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True),
new_file_content_str.splitlines(keepends=True)) new_file_content_str.splitlines(keepends=True))
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.warning(f"File was modified, but no patch was found. Manually creating patch: {file.filename}.") logging.warning(f"File was modified, but no patch was found. Manually creating patch: {file.filename}.")
patch = ''.join(diff) patch = ''.join(diff)
except Exception: except Exception:
pass pass
return patch return patch
def update_settings_from_args(args: List[str]) -> List[str]:
    """
    Update the settings of the Dynaconf object based on the arguments passed to the function.

    Every argument of the form '--<key>=<value>' is applied through get_settings().set(), with the key
    upper-cased. Only the first '=' separates key from value, so the value itself may contain '='.

    Args:
        args: A list of arguments passed to the function. May be None or empty.
        Example args: ['--pr_code_suggestions.extra_instructions="be funny"',
                       '--pr_code_suggestions.num_code_suggestions=3']

    Returns:
        The arguments that were not consumed as settings: arguments that do not start with '--' (stripped of
        surrounding whitespace) and malformed '--' arguments (with the leading dashes removed). Malformed
        arguments are also reported through logging.error.
    """
    other_args = []
    for arg in args or []:
        arg = arg.strip()
        if not arg.startswith('--'):
            other_args.append(arg)
            continue

        # Only the leading dashes are syntax; a trailing '-' belongs to the value.
        arg = arg.lstrip('-').strip()
        # partition() splits on the first '=' only, so values such as 'a=b' stay intact.
        key, separator, value = arg.partition('=')
        key = key.strip().upper()
        if not separator or not key:
            logging.error(f'Invalid argument format: {arg}')
            other_args.append(arg)
            continue

        value = value.strip()
        get_settings().set(key, value)
        logging.info(f'Updated setting {key} to: "{value}"')
    return other_args

View File

@ -3,15 +3,11 @@ import asyncio
import logging import logging
import os import os
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.agent.pr_agent import PRAgent, commands
from pr_agent.tools.pr_description import PRDescription from pr_agent.config_loader import get_settings
from pr_agent.tools.pr_information_from_user import PRInformationFromUser
from pr_agent.tools.pr_questions import PRQuestions
from pr_agent.tools.pr_reviewer import PRReviewer
from pr_agent.tools.pr_update_changelog import PRUpdateChangelog
def run(args=None): def run(inargs=None):
parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage= parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage=
"""\ """\
Usage: cli.py --pr-url <URL on supported git hosting service> <command> [<args>]. Usage: cli.py --pr-url <URL on supported git hosting service> <command> [<args>].
@ -29,81 +25,21 @@ describe / describe_pr - Modify the PR title and description based on the PR's c
improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit. improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
reflect - Ask the PR author questions about the PR. reflect - Ask the PR author questions about the PR.
update_changelog - Update the changelog based on the PR's contents. update_changelog - Update the changelog based on the PR's contents.
To edit any configuration parameter from 'configuration.toml', just add -config_path=<value>.
For example: '- cli.py --pr-url=... review --pr_reviewer.extra_instructions="focus on the file: ..."'
""") """)
parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', required=True) parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', required=True)
parser.add_argument('command', type=str, help='The', choices=['review', 'review_pr', parser.add_argument('command', type=str, help='The', choices=commands, default='review')
'ask', 'ask_question',
'describe', 'describe_pr',
'improve', 'improve_code',
'reflect', 'review_after_reflect',
'update_changelog'],
default='review')
parser.add_argument('rest', nargs=argparse.REMAINDER, default=[]) parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
args = parser.parse_args(args) args = parser.parse_args(inargs)
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO")) logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
command = args.command.lower() command = args.command.lower()
commands = { get_settings().set("CONFIG.CLI_MODE", True)
'ask': _handle_ask_command, result = asyncio.run(PRAgent().handle_request(args.pr_url, command + " " + " ".join(args.rest)))
'ask_question': _handle_ask_command, if not result:
'describe': _handle_describe_command,
'describe_pr': _handle_describe_command,
'improve': _handle_improve_command,
'improve_code': _handle_improve_command,
'review': _handle_review_command,
'review_pr': _handle_review_command,
'reflect': _handle_reflect_command,
'review_after_reflect': _handle_review_after_reflect_command,
'update_changelog': _handle_update_changelog,
}
if command in commands:
commands[command](args.pr_url, args.rest)
else:
print(f"Unknown command: {command}")
parser.print_help() parser.print_help()
def _handle_ask_command(pr_url: str, rest: list):
if len(rest) == 0:
print("Please specify a question")
return
print(f"Question: {' '.join(rest)} about PR {pr_url}")
reviewer = PRQuestions(pr_url, rest)
asyncio.run(reviewer.answer())
def _handle_describe_command(pr_url: str, rest: list):
    """Announce the PR URL, then run PRDescription.describe() for it. `rest` is unused."""
    print(f"PR description: {pr_url}")
    asyncio.run(PRDescription(pr_url).describe())
def _handle_improve_command(pr_url: str, rest: list):
    """Announce the PR URL, then run PRCodeSuggestions.suggest() for it. `rest` is unused."""
    print(f"PR code suggestions: {pr_url}")
    asyncio.run(PRCodeSuggestions(pr_url).suggest())
def _handle_review_command(pr_url: str, rest: list):
    """Announce the PR URL, then run PRReviewer.review() in CLI mode, forwarding `rest` as its args."""
    print(f"Reviewing PR: {pr_url}")
    asyncio.run(PRReviewer(pr_url, cli_mode=True, args=rest).review())
def _handle_reflect_command(pr_url: str, rest: list):
    """Announce the PR URL, then run PRInformationFromUser.generate_questions() for it. `rest` is unused."""
    print(f"Asking the PR author questions: {pr_url}")
    asyncio.run(PRInformationFromUser(pr_url).generate_questions())
def _handle_review_after_reflect_command(pr_url: str, rest: list):
    """Announce the PR URL, then run PRReviewer.review() in CLI mode with is_answer=True. `rest` is unused."""
    print(f"Processing author's answers and sending review: {pr_url}")
    asyncio.run(PRReviewer(pr_url, cli_mode=True, is_answer=True).review())
def _handle_update_changelog(pr_url: str, rest: list):
    """Announce the PR URL, then run PRUpdateChangelog.update_changelog() in CLI mode, forwarding `rest`."""
    print(f"Updating changlog for: {pr_url}")
    asyncio.run(PRUpdateChangelog(pr_url, cli_mode=True, args=rest).update_changelog())
if __name__ == '__main__': if __name__ == '__main__':
run() run()

View File

@ -3,28 +3,36 @@ from pathlib import Path
from typing import Optional from typing import Optional
from dynaconf import Dynaconf from dynaconf import Dynaconf
from starlette_context import context
PR_AGENT_TOML_KEY = 'pr-agent' PR_AGENT_TOML_KEY = 'pr-agent'
current_dir = dirname(abspath(__file__)) current_dir = dirname(abspath(__file__))
settings = Dynaconf( global_settings = Dynaconf(
envvar_prefix=False, envvar_prefix=False,
merge_enabled=True, merge_enabled=True,
settings_files=[join(current_dir, f) for f in [ settings_files=[join(current_dir, f) for f in [
"settings/.secrets.toml", "settings/.secrets.toml",
"settings/configuration.toml", "settings/configuration.toml",
"settings/language_extensions.toml", "settings/language_extensions.toml",
"settings/pr_reviewer_prompts.toml", "settings/pr_reviewer_prompts.toml",
"settings/pr_questions_prompts.toml", "settings/pr_questions_prompts.toml",
"settings/pr_description_prompts.toml", "settings/pr_description_prompts.toml",
"settings/pr_code_suggestions_prompts.toml", "settings/pr_code_suggestions_prompts.toml",
"settings/pr_information_from_user_prompts.toml", "settings/pr_information_from_user_prompts.toml",
"settings/pr_update_changelog.toml", "settings/pr_update_changelog_prompts.toml",
"settings_prod/.secrets.toml" "settings_prod/.secrets.toml"
]] ]]
) )
def get_settings():
    """
    Return the settings object to use for the current call.

    The "settings" entry of the context is preferred; if reading it raises for any reason, the module-level
    global_settings object is returned instead.
    """
    try:
        active_settings = context["settings"]
    except Exception:
        active_settings = global_settings
    return active_settings
# Add local configuration from pyproject.toml of the project being reviewed # Add local configuration from pyproject.toml of the project being reviewed
def _find_repository_root() -> Path: def _find_repository_root() -> Path:
""" """
@ -39,6 +47,7 @@ def _find_repository_root() -> Path:
cwd = cwd.parent cwd = cwd.parent
return None return None
def _find_pyproject() -> Optional[Path]: def _find_pyproject() -> Optional[Path]:
""" """
Search for file pyproject.toml in the repository root. Search for file pyproject.toml in the repository root.
@ -49,6 +58,7 @@ def _find_pyproject() -> Optional[Path]:
return pyproject if pyproject.is_file() else None return pyproject if pyproject.is_file() else None
return None return None
pyproject_path = _find_pyproject() pyproject_path = _find_pyproject()
if pyproject_path is not None: if pyproject_path is not None:
settings.load_file(pyproject_path, env=f'tool.{PR_AGENT_TOML_KEY}') get_settings().load_file(pyproject_path, env=f'tool.{PR_AGENT_TOML_KEY}')

View File

@ -1,4 +1,4 @@
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers.bitbucket_provider import BitbucketProvider from pr_agent.git_providers.bitbucket_provider import BitbucketProvider
from pr_agent.git_providers.github_provider import GithubProvider from pr_agent.git_providers.github_provider import GithubProvider
from pr_agent.git_providers.gitlab_provider import GitLabProvider from pr_agent.git_providers.gitlab_provider import GitLabProvider
@ -13,7 +13,7 @@ _GIT_PROVIDERS = {
def get_git_provider(): def get_git_provider():
try: try:
provider_id = settings.config.git_provider provider_id = get_settings().config.git_provider
except AttributeError as e: except AttributeError as e:
raise ValueError("git_provider is a required attribute in the configuration file") from e raise ValueError("git_provider is a required attribute in the configuration file") from e
if provider_id not in _GIT_PROVIDERS: if provider_id not in _GIT_PROVIDERS:

View File

@ -5,15 +5,14 @@ from urllib.parse import urlparse
import requests import requests
from atlassian.bitbucket import Cloud from atlassian.bitbucket import Cloud
from pr_agent.config_loader import settings from ..config_loader import get_settings
from .git_provider import FilePatchInfo from .git_provider import FilePatchInfo
class BitbucketProvider: class BitbucketProvider:
def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False): def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False):
s = requests.Session() s = requests.Session()
s.headers['Authorization'] = f'Bearer {settings.get("BITBUCKET.BEARER_TOKEN", None)}' s.headers['Authorization'] = f'Bearer {get_settings().get("BITBUCKET.BEARER_TOKEN", None)}'
self.bitbucket_client = Cloud(session=s) self.bitbucket_client = Cloud(session=s)
self.workspace_slug = None self.workspace_slug = None
@ -121,3 +120,6 @@ class BitbucketProvider:
def _get_pr_file_content(self, remote_link: str): def _get_pr_file_content(self, remote_link: str):
return "" return ""
def get_commit_messages(self):
    """Return the commit messages of the pull request; currently always an empty string."""
    no_messages = ""  # not implemented yet
    return no_messages

View File

@ -5,19 +5,22 @@ from urllib.parse import urlparse
from github import AppAuthentication, Auth, Github, GithubException from github import AppAuthentication, Auth, Github, GithubException
from retry import retry from retry import retry
from starlette_context import context
from pr_agent.config_loader import settings
from ..algo.language_handler import is_valid_file from ..algo.language_handler import is_valid_file
from ..algo.utils import load_large_diff from ..algo.utils import load_large_diff
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR from ..config_loader import get_settings
from ..servers.utils import RateLimitExceeded from ..servers.utils import RateLimitExceeded
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
class GithubProvider(GitProvider): class GithubProvider(GitProvider):
def __init__(self, pr_url: Optional[str] = None, incremental=IncrementalPR(False)): def __init__(self, pr_url: Optional[str] = None, incremental=IncrementalPR(False)):
self.repo_obj = None self.repo_obj = None
self.installation_id = settings.get("GITHUB.INSTALLATION_ID") try:
self.installation_id = context.get("installation_id", None)
except Exception:
self.installation_id = None
self.github_client = self._get_github_client() self.github_client = self._get_github_client()
self.repo = None self.repo = None
self.pr_num = None self.pr_num = None
@ -81,7 +84,7 @@ class GithubProvider(GitProvider):
return self.pr.get_files() return self.pr.get_files()
@retry(exceptions=RateLimitExceeded, @retry(exceptions=RateLimitExceeded,
tries=settings.github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3)) tries=get_settings().github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3))
def get_diff_files(self) -> list[FilePatchInfo]: def get_diff_files(self) -> list[FilePatchInfo]:
try: try:
files = self.get_files() files = self.get_files()
@ -114,7 +117,7 @@ class GithubProvider(GitProvider):
# self.pr.create_issue_comment(pr_comment) # self.pr.create_issue_comment(pr_comment)
def publish_comment(self, pr_comment: str, is_temporary: bool = False): def publish_comment(self, pr_comment: str, is_temporary: bool = False):
if is_temporary and not settings.config.publish_output_progress: if is_temporary and not get_settings().config.publish_output_progress:
logging.debug(f"Skipping publish_comment for temporary comment: {pr_comment}") logging.debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
return return
response = self.pr.create_issue_comment(pr_comment) response = self.pr.create_issue_comment(pr_comment)
@ -145,7 +148,7 @@ class GithubProvider(GitProvider):
position = i position = i
break break
if position == -1: if position == -1:
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}") logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
subject_type = "FILE" subject_type = "FILE"
else: else:
@ -170,13 +173,13 @@ class GithubProvider(GitProvider):
relevant_lines_end = suggestion['relevant_lines_end'] relevant_lines_end = suggestion['relevant_lines_end']
if not relevant_lines_start or relevant_lines_start == -1: if not relevant_lines_start or relevant_lines_start == -1:
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.exception( logging.exception(
f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}") f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}")
continue continue
if relevant_lines_end < relevant_lines_start: if relevant_lines_end < relevant_lines_start:
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.exception(f"Failed to publish code suggestion, " logging.exception(f"Failed to publish code suggestion, "
f"relevant_lines_end is {relevant_lines_end} and " f"relevant_lines_end is {relevant_lines_end} and "
f"relevant_lines_start is {relevant_lines_start}") f"relevant_lines_start is {relevant_lines_start}")
@ -203,7 +206,7 @@ class GithubProvider(GitProvider):
self.pr.create_review(commit=self.last_commit_id, comments=post_parameters_list) self.pr.create_review(commit=self.last_commit_id, comments=post_parameters_list)
return True return True
except Exception as e: except Exception as e:
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.error(f"Failed to publish code suggestion, error: {e}") logging.error(f"Failed to publish code suggestion, error: {e}")
return False return False
@ -237,7 +240,7 @@ class GithubProvider(GitProvider):
return self.github_user_id return self.github_user_id
def get_notifications(self, since: datetime): def get_notifications(self, since: datetime):
deployment_type = settings.get("GITHUB.DEPLOYMENT_TYPE", "user") deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user")
if deployment_type != 'user': if deployment_type != 'user':
raise ValueError("Deployment mode must be set to 'user' to get notifications") raise ValueError("Deployment mode must be set to 'user' to get notifications")
@ -278,12 +281,12 @@ class GithubProvider(GitProvider):
return repo_name, pr_number return repo_name, pr_number
def _get_github_client(self): def _get_github_client(self):
deployment_type = settings.get("GITHUB.DEPLOYMENT_TYPE", "user") deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user")
if deployment_type == 'app': if deployment_type == 'app':
try: try:
private_key = settings.github.private_key private_key = get_settings().github.private_key
app_id = settings.github.app_id app_id = get_settings().github.app_id
except AttributeError as e: except AttributeError as e:
raise ValueError("GitHub app ID and private key are required when using GitHub app deployment") from e raise ValueError("GitHub app ID and private key are required when using GitHub app deployment") from e
if not self.installation_id: if not self.installation_id:
@ -294,7 +297,7 @@ class GithubProvider(GitProvider):
if deployment_type == 'user': if deployment_type == 'user':
try: try:
token = settings.github.user_token token = get_settings().github.user_token
except AttributeError as e: except AttributeError as e:
raise ValueError( raise ValueError(
"GitHub token is required when using user deployment. See: " "GitHub token is required when using user deployment. See: "
@ -323,7 +326,9 @@ class GithubProvider(GitProvider):
def publish_labels(self, pr_types): def publish_labels(self, pr_types):
try: try:
label_color_map = {"Bug fix": "1d76db", "Tests": "e99695", "Bug fix with tests": "c5def5", "Refactoring": "bfdadc", "Enhancement": "bfd4f2", "Documentation": "d4c5f9", "Other": "d1bcf9"} label_color_map = {"Bug fix": "1d76db", "Tests": "e99695", "Bug fix with tests": "c5def5",
"Refactoring": "bfdadc", "Enhancement": "bfd4f2", "Documentation": "d4c5f9",
"Other": "d1bcf9"}
post_parameters = [] post_parameters = []
for p in pr_types: for p in pr_types:
color = label_color_map.get(p, "d1bcf9") # default to "Other" color color = label_color_map.get(p, "d1bcf9") # default to "Other" color
@ -340,3 +345,18 @@ class GithubProvider(GitProvider):
except Exception as e: except Exception as e:
logging.exception(f"Failed to get labels, error: {e}") logging.exception(f"Failed to get labels, error: {e}")
return [] return []
def get_commit_messages(self) -> str:
    """
    Retrieves the commit messages of a pull request.

    Returns:
        str: The commit messages as a numbered list, one "<n>. <message>" entry per line. An empty string is
        returned (and the failure is logged) if the commits could not be read.
    """
    try:
        # Iterating inside the try keeps any error raised while reading the commits within the handler.
        commit_messages = [commit.commit.message for commit in self.pr.get_commits()]
    except Exception as e:
        # Best effort: commit messages are optional, but the failure should not be silent.
        logging.exception(f"Failed to get commit messages, error: {e}")
        return ""
    return "\n".join(f"{i}. {message}" for i, message in enumerate(commit_messages, start=1))

View File

@ -6,19 +6,20 @@ from urllib.parse import urlparse
import gitlab import gitlab
from gitlab import GitlabGetError from gitlab import GitlabGetError
from pr_agent.config_loader import settings
from ..algo.language_handler import is_valid_file from ..algo.language_handler import is_valid_file
from ..config_loader import get_settings
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
logger = logging.getLogger()
class GitLabProvider(GitProvider): class GitLabProvider(GitProvider):
def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False): def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False):
gitlab_url = settings.get("GITLAB.URL", None) gitlab_url = get_settings().get("GITLAB.URL", None)
if not gitlab_url: if not gitlab_url:
raise ValueError("GitLab URL is not set in the config file") raise ValueError("GitLab URL is not set in the config file")
gitlab_access_token = settings.get("GITLAB.PERSONAL_ACCESS_TOKEN", None) gitlab_access_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
if not gitlab_access_token: if not gitlab_access_token:
raise ValueError("GitLab personal access token is not set in the config file") raise ValueError("GitLab personal access token is not set in the config file")
self.gl = gitlab.Gitlab( self.gl = gitlab.Gitlab(
@ -48,7 +49,12 @@ class GitLabProvider(GitProvider):
def _set_merge_request(self, merge_request_url: str): def _set_merge_request(self, merge_request_url: str):
self.id_project, self.id_mr = self._parse_merge_request_url(merge_request_url) self.id_project, self.id_mr = self._parse_merge_request_url(merge_request_url)
self.mr = self._get_merge_request() self.mr = self._get_merge_request()
self.last_diff = self.mr.diffs.list()[-1] try:
self.last_diff = self.mr.diffs.list(get_all=True)[-1]
except IndexError as e:
logger.error(f"Could not get diff for merge request {self.id_mr}")
raise ValueError(f"Could not get diff for merge request {self.id_mr}") from e
def _get_pr_file_content(self, file_path: str, branch: str) -> str: def _get_pr_file_content(self, file_path: str, branch: str) -> str:
try: try:
@ -133,32 +139,42 @@ class GitLabProvider(GitProvider):
else: else:
pos_obj['new_line'] = target_line_no - 1 pos_obj['new_line'] = target_line_no - 1
pos_obj['old_line'] = source_line_no - 1 pos_obj['old_line'] = source_line_no - 1
logging.debug(f"Creating comment in {self.id_mr} with body {body} and position {pos_obj}")
self.mr.discussions.create({'body': body, self.mr.discussions.create({'body': body,
'position': pos_obj}) 'position': pos_obj})
def publish_code_suggestions(self, code_suggestions: list): def publish_code_suggestions(self, code_suggestions: list):
for suggestion in code_suggestions: for suggestion in code_suggestions:
body = suggestion['body'] try:
relevant_file = suggestion['relevant_file'] body = suggestion['body']
relevant_lines_start = suggestion['relevant_lines_start'] relevant_file = suggestion['relevant_file']
relevant_lines_end = suggestion['relevant_lines_end'] relevant_lines_start = suggestion['relevant_lines_start']
relevant_lines_end = suggestion['relevant_lines_end']
self.diff_files = self.diff_files if self.diff_files else self.get_diff_files() self.diff_files = self.diff_files if self.diff_files else self.get_diff_files()
target_file = None target_file = None
for file in self.diff_files: for file in self.diff_files:
if file.filename == relevant_file:
if file.filename == relevant_file: if file.filename == relevant_file:
target_file = file if file.filename == relevant_file:
break target_file = file
range = relevant_lines_end - relevant_lines_start + 1 break
body = body.replace('```suggestion', f'```suggestion:-0+{range}') range = relevant_lines_end - relevant_lines_start # no need to add 1
body = body.replace('```suggestion', f'```suggestion:-0+{range}')
lines = target_file.head_file.splitlines()
relevant_line_in_file = lines[relevant_lines_start - 1]
lines = target_file.head_file.splitlines() # edit_type, found, source_line_no, target_file, target_line_no = self.find_in_file(target_file,
relevant_line_in_file = lines[relevant_lines_start - 1] # relevant_line_in_file)
edit_type, found, source_line_no, target_file, target_line_no = self.find_in_file(target_file, # for code suggestions, we want to edit the new code
relevant_line_in_file) source_line_no = None
self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no, target_line_no = relevant_lines_start + 1
target_file, target_line_no) found = True
edit_type = 'addition'
self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
target_file, target_line_no)
except Exception as e:
logging.exception(f"Could not publish code suggestion:\nsuggestion: {suggestion}\nerror: {e}")
def search_line(self, relevant_file, relevant_line_in_file): def search_line(self, relevant_file, relevant_line_in_file):
target_file = None target_file = None
@ -281,3 +297,17 @@ class GitLabProvider(GitProvider):
def get_labels(self): def get_labels(self):
return self.mr.labels return self.mr.labels
def get_commit_messages(self) -> str:
    """
    Retrieves the commit messages of a pull request.

    Returns:
        str: The commit messages as a numbered list, one "<n>. <message>" entry per line. An empty string is
        returned (and the failure is logged) if the commits could not be read.
    """
    try:
        # NOTE(review): `_list` is a private attribute of the object returned by commits(); confirm it is
        # still available on the installed python-gitlab version.
        commit_messages_list = [commit['message'] for commit in self.mr.commits()._list]
    except Exception as e:
        # Best effort: commit messages are optional, but the failure should not be silent.
        logging.exception(f"Failed to get commit messages, error: {e}")
        return ""
    return "\n".join(f"{i}. {message}" for i, message in enumerate(commit_messages_list, start=1))

View File

@ -5,7 +5,7 @@ from typing import List
from git import Repo from git import Repo
from pr_agent.config_loader import _find_repository_root, settings from pr_agent.config_loader import _find_repository_root, get_settings
from pr_agent.git_providers.git_provider import EDIT_TYPE, FilePatchInfo, GitProvider from pr_agent.git_providers.git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
@ -38,12 +38,12 @@ class LocalGitProvider(GitProvider):
self._prepare_repo() self._prepare_repo()
self.diff_files = None self.diff_files = None
self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files()) self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files())
self.description_path = settings.get('local.description_path') \ self.description_path = get_settings().get('local.description_path') \
if settings.get('local.description_path') is not None else self.repo_path / 'description.md' if get_settings().get('local.description_path') is not None else self.repo_path / 'description.md'
self.review_path = settings.get('local.review_path') \ self.review_path = get_settings().get('local.review_path') \
if settings.get('local.review_path') is not None else self.repo_path / 'review.md' if get_settings().get('local.review_path') is not None else self.repo_path / 'review.md'
# inline code comments are not supported for local git repositories # inline code comments are not supported for local git repositories
settings.pr_reviewer.inline_code_comments = False get_settings().pr_reviewer.inline_code_comments = False
def _prepare_repo(self): def _prepare_repo(self):
""" """

View File

@ -3,7 +3,7 @@ import json
import os import os
from pr_agent.agent.pr_agent import PRAgent from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_reviewer import PRReviewer
@ -30,11 +30,11 @@ async def run_action():
return return
# Set the environment variables in the settings # Set the environment variables in the settings
settings.set("OPENAI.KEY", OPENAI_KEY) get_settings().set("OPENAI.KEY", OPENAI_KEY)
if OPENAI_ORG: if OPENAI_ORG:
settings.set("OPENAI.ORG", OPENAI_ORG) get_settings().set("OPENAI.ORG", OPENAI_ORG)
settings.set("GITHUB.USER_TOKEN", GITHUB_TOKEN) get_settings().set("GITHUB.USER_TOKEN", GITHUB_TOKEN)
settings.set("GITHUB.DEPLOYMENT_TYPE", "user") get_settings().set("GITHUB.DEPLOYMENT_TYPE", "user")
# Load the event payload # Load the event payload
try: try:
@ -50,7 +50,7 @@ async def run_action():
if action in ["opened", "reopened"]: if action in ["opened", "reopened"]:
pr_url = event_payload.get("pull_request", {}).get("url") pr_url = event_payload.get("pull_request", {}).get("url")
if pr_url: if pr_url:
await PRReviewer(pr_url).review() await PRReviewer(pr_url).run()
# Handle issue comment event # Handle issue comment event
elif GITHUB_EVENT_NAME == "issue_comment": elif GITHUB_EVENT_NAME == "issue_comment":

View File

@ -1,12 +1,16 @@
from typing import Dict, Any import copy
import logging import logging
import sys import sys
from typing import Any, Dict
import uvicorn import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
from starlette.middleware import Middleware
from starlette_context import context
from starlette_context.middleware import RawContextMiddleware
from pr_agent.agent.pr_agent import PRAgent from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings, global_settings
from pr_agent.servers.utils import verify_signature from pr_agent.servers.utils import verify_signature
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
@ -17,27 +21,40 @@ router = APIRouter()
async def handle_github_webhooks(request: Request, response: Response): async def handle_github_webhooks(request: Request, response: Response):
""" """
Receives and processes incoming GitHub webhook requests. Receives and processes incoming GitHub webhook requests.
Verifies the request signature, parses the request body, and passes it to the handle_request function for further processing. Verifies the request signature, parses the request body, and passes it to the handle_request function for further
processing.
""" """
logging.debug("Received a GitHub webhook") logging.debug("Received a GitHub webhook")
body = await get_body(request)
logging.debug(f'Request body:\n{body}')
installation_id = body.get("installation", {}).get("id")
context["installation_id"] = installation_id
context["settings"] = copy.deepcopy(global_settings)
return await handle_request(body)
@router.post("/api/v1/marketplace_webhooks")
async def handle_marketplace_webhooks(request: Request, response: Response):
body = await get_body(request)
logging.info(f'Request body:\n{body}')
async def get_body(request):
try: try:
body = await request.json() body = await request.json()
except Exception as e: except Exception as e:
logging.error("Error parsing request body", e) logging.error("Error parsing request body", e)
raise HTTPException(status_code=400, detail="Error parsing request body") from e raise HTTPException(status_code=400, detail="Error parsing request body") from e
body_bytes = await request.body() body_bytes = await request.body()
signature_header = request.headers.get('x-hub-signature-256', None) signature_header = request.headers.get('x-hub-signature-256', None)
webhook_secret = getattr(get_settings().github, 'webhook_secret', None)
webhook_secret = getattr(settings.github, 'webhook_secret', None)
if webhook_secret: if webhook_secret:
verify_signature(body_bytes, webhook_secret, signature_header) verify_signature(body_bytes, webhook_secret, signature_header)
return body
logging.debug(f'Request body:\n{body}')
return await handle_request(body)
async def handle_request(body: Dict[str, Any]): async def handle_request(body: Dict[str, Any]):
@ -48,8 +65,8 @@ async def handle_request(body: Dict[str, Any]):
body: The request body. body: The request body.
""" """
action = body.get("action") action = body.get("action")
installation_id = body.get("installation", {}).get("id") if not action:
settings.set("GITHUB.INSTALLATION_ID", installation_id) return {}
agent = PRAgent() agent = PRAgent()
if action == 'created': if action == 'created':
@ -65,7 +82,7 @@ async def handle_request(body: Dict[str, Any]):
api_url = pull_request.get("url") api_url = pull_request.get("url")
await agent.handle_request(api_url, comment_body) await agent.handle_request(api_url, comment_body)
elif action in ["opened"] or 'reopened' in action: elif action == "opened" or 'reopened' in action:
pull_request = body.get("pull_request") pull_request = body.get("pull_request")
if not pull_request: if not pull_request:
return {} return {}
@ -84,8 +101,9 @@ async def root():
def start(): def start():
# Override the deployment type to app # Override the deployment type to app
settings.set("GITHUB.DEPLOYMENT_TYPE", "app") get_settings().set("GITHUB.DEPLOYMENT_TYPE", "app")
app = FastAPI() middleware = [Middleware(RawContextMiddleware)]
app = FastAPI(middleware=middleware)
app.include_router(router) app.include_router(router)
uvicorn.run(app, host="0.0.0.0", port=3000) uvicorn.run(app, host="0.0.0.0", port=3000)

View File

@ -6,7 +6,7 @@ from datetime import datetime, timezone
import aiohttp import aiohttp
from pr_agent.agent.pr_agent import PRAgent from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider from pr_agent.git_providers import get_git_provider
from pr_agent.servers.help import bot_help_text from pr_agent.servers.help import bot_help_text
@ -38,8 +38,8 @@ async def polling_loop():
agent = PRAgent() agent = PRAgent()
try: try:
deployment_type = settings.github.deployment_type deployment_type = get_settings().github.deployment_type
token = settings.github.user_token token = get_settings().github.user_token
except AttributeError: except AttributeError:
deployment_type = 'none' deployment_type = 'none'
token = None token = None

View File

@ -7,7 +7,7 @@ from fastapi.responses import JSONResponse
from starlette.background import BackgroundTasks from starlette.background import BackgroundTasks
from pr_agent.agent.pr_agent import PRAgent from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
app = FastAPI() app = FastAPI()
router = APIRouter() router = APIRouter()
@ -29,13 +29,13 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"})) return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
def start(): def start():
gitlab_url = settings.get("GITLAB.URL", None) gitlab_url = get_settings().get("GITLAB.URL", None)
if not gitlab_url: if not gitlab_url:
raise ValueError("GITLAB.URL is not set") raise ValueError("GITLAB.URL is not set")
gitlab_token = settings.get("GITLAB.PERSONAL_ACCESS_TOKEN", None) gitlab_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
if not gitlab_token: if not gitlab_token:
raise ValueError("GITLAB.PERSONAL_ACCESS_TOKEN is not set") raise ValueError("GITLAB.PERSONAL_ACCESS_TOKEN is not set")
settings.config.git_provider = "gitlab" get_settings().config.git_provider = "gitlab"
app = FastAPI() app = FastAPI()
app.include_router(router) app.include_router(router)

View File

@ -1,9 +1,10 @@
commands_text = "> **/review [-i]**: Request a review of your Pull Request. For an incremental review, which only " \ commands_text = "> **/review [-i]**: Request a review of your Pull Request. For an incremental review, which only " \
"considers changes since the last review, include the '-i' option.\n" \ "considers changes since the last review, include the '-i' option.\n" \
"> **/describe**: Modify the PR title and description based on the contents of the PR.\n" \ "> **/describe**: Modify the PR title and description based on the contents of the PR.\n" \
"> **/improve**: Suggest improvements to the code in the PR. " \ "> **/improve**: Suggest improvements to the code in the PR. \n" \
"These will be provided as pull request comments, ready to commit.\n" \ "> **/ask \\<QUESTION\\>**: Pose a question about the PR.\n\n" \
"> **/ask \\<QUESTION\\>**: Pose a question about the PR.\n" ">To edit any configuration parameter from 'configuration.toml', add --config_path=new_value\n" \
">For example: /review --pr_reviewer.extra_instructions=\"focus on the file: ...\" " \
def bot_help_text(user: str): def bot_help_text(user: str):

View File

@ -7,7 +7,7 @@ publish_output_progress=true
verbosity_level=0 # 0,1,2 verbosity_level=0 # 0,1,2
use_extra_bad_extensions=false use_extra_bad_extensions=false
[pr_reviewer] [pr_reviewer] # /review #
require_focused_review=true require_focused_review=true
require_score_review=false require_score_review=false
require_tests_review=true require_tests_review=true
@ -15,17 +15,21 @@ require_security_review=true
num_code_suggestions=0 num_code_suggestions=0
inline_code_comments = true inline_code_comments = true
ask_and_reflect=false ask_and_reflect=false
extra_instructions = ""
[pr_description] [pr_description] # /describe #
publish_description_as_comment=false publish_description_as_comment=false
extra_instructions = ""
[pr_questions] [pr_questions] # /ask #
[pr_code_suggestions] [pr_code_suggestions] # /improve #
num_code_suggestions=4 num_code_suggestions=4
extra_instructions = ""
[pr_update_changelog] [pr_update_changelog] # /update_changelog #
push_changelog_changes=false push_changelog_changes=false
extra_instructions = ""
[github] [github]
# The type of deployment to create. Valid values are 'app' or 'user'. # The type of deployment to create. Valid values are 'app' or 'user'.

View File

@ -9,6 +9,12 @@ Your task is to provide meaningfull non-trivial code suggestions to improve the
- Make sure not to provide suggestions repeating modifications already implemented in the new PR code (the '+' lines). - Make sure not to provide suggestions repeating modifications already implemented in the new PR code (the '+' lines).
- Don't output line numbers in the 'improved code' snippets. - Don't output line numbers in the 'improved code' snippets.
{%- if extra_instructions %}
Extra instructions from the user:
{{ extra_instructions }}
{% endif %}
You must use the following JSON schema to format your answer: You must use the following JSON schema to format your answer:
```json ```json
{ {

View File

@ -3,6 +3,12 @@ system="""You are CodiumAI-PR-Reviewer, a language model designed to review git
Your task is to provide full description of the PR content. Your task is to provide full description of the PR content.
- Make sure not to focus the new PR code (the '+' lines). - Make sure not to focus the new PR code (the '+' lines).
{%- if extra_instructions %}
Extra instructions from the user:
{{ extra_instructions }}
{% endif %}
You must use the following JSON schema to format your answer: You must use the following JSON schema to format your answer:
```json ```json
{ {
@ -30,8 +36,14 @@ Don't repeat the prompt in the answer, and avoid outputting the 'type' and 'desc
user="""PR Info: user="""PR Info:
Branch: '{{branch}}' Branch: '{{branch}}'
{%- if language %} {%- if language %}
Main language: {{language}} Main language: {{language}}
{%- endif %} {%- endif %}
{%- if commit_messages_str %}
Commit messages:
{{commit_messages_str}}
{%- endif %}
The PR Git Diff: The PR Git Diff:

View File

@ -8,6 +8,12 @@ Your task is to provide constructive and concise feedback for the PR, and also p
- Make sure not to provide suggestions repeating modifications already implemented in the new PR code (the '+' lines). - Make sure not to provide suggestions repeating modifications already implemented in the new PR code (the '+' lines).
{%- endif %} {%- endif %}
{%- if extra_instructions %}
Extra instructions from the user:
{{ extra_instructions }}
{% endif %}
You must use the following JSON schema to format your answer: You must use the following JSON schema to format your answer:
```json ```json
{ {

View File

@ -4,6 +4,12 @@ Your task is to update the CHANGELOG.md file of the project, to shortly summariz
- The output should match the existing CHANGELOG.md format, style and conventions, so it will look like a natural part of the file. For example, if previous changes were summarized in a single line, you should do the same. - The output should match the existing CHANGELOG.md format, style and conventions, so it will look like a natural part of the file. For example, if previous changes were summarized in a single line, you should do the same.
- Don't repeat previous changes. Generate only new content, that is not already in the CHANGELOG.md file. - Don't repeat previous changes. Generate only new content, that is not already in the CHANGELOG.md file.
- Be general, and avoid specific details, files, etc. The output should be minimal, no more than 3-4 short lines. Ignore non-relevant subsections. - Be general, and avoid specific details, files, etc. The output should be minimal, no more than 3-4 short lines. Ignore non-relevant subsections.
{%- if extra_instructions %}
Extra instructions from the user:
{{ extra_instructions }}
{%- endif %}
""" """
user="""PR Info: user="""PR Info:

View File

@ -9,18 +9,19 @@ from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import try_fix_json from pr_agent.algo.utils import try_fix_json
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers import BitbucketProvider, get_git_provider from pr_agent.git_providers import BitbucketProvider, get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language from pr_agent.git_providers.git_provider import get_main_pr_language
class PRCodeSuggestions: class PRCodeSuggestions:
def __init__(self, pr_url: str, cli_mode=False): def __init__(self, pr_url: str, cli_mode=False, args: list = None):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
) )
self.ai_handler = AiHandler() self.ai_handler = AiHandler()
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
@ -31,23 +32,24 @@ class PRCodeSuggestions:
"description": self.git_provider.get_pr_description(), "description": self.git_provider.get_pr_description(),
"language": self.main_language, "language": self.main_language,
"diff": "", # empty diff for initial calculation "diff": "", # empty diff for initial calculation
'num_code_suggestions': settings.pr_code_suggestions.num_code_suggestions, "num_code_suggestions": get_settings().pr_code_suggestions.num_code_suggestions,
"extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
} }
self.token_handler = TokenHandler(self.git_provider.pr, self.token_handler = TokenHandler(self.git_provider.pr,
self.vars, self.vars,
settings.pr_code_suggestions_prompt.system, get_settings().pr_code_suggestions_prompt.system,
settings.pr_code_suggestions_prompt.user) get_settings().pr_code_suggestions_prompt.user)
async def suggest(self): async def run(self):
assert type(self.git_provider) != BitbucketProvider, "Bitbucket is not supported for now" assert type(self.git_provider) != BitbucketProvider, "Bitbucket is not supported for now"
logging.info('Generating code suggestions for PR...') logging.info('Generating code suggestions for PR...')
if settings.config.publish_output: if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing review...", is_temporary=True) self.git_provider.publish_comment("Preparing review...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction) await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing PR review...') logging.info('Preparing PR review...')
data = self._prepare_pr_code_suggestions() data = self._prepare_pr_code_suggestions()
if settings.config.publish_output: if get_settings().config.publish_output:
logging.info('Pushing PR review...') logging.info('Pushing PR review...')
self.git_provider.remove_initial_comment() self.git_provider.remove_initial_comment()
logging.info('Pushing inline code comments...') logging.info('Pushing inline code comments...')
@ -68,9 +70,9 @@ class PRCodeSuggestions:
variables = copy.deepcopy(self.vars) variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_code_suggestions_prompt.system).render(variables) system_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_code_suggestions_prompt.user).render(variables) user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables)
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2, response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
@ -83,7 +85,7 @@ class PRCodeSuggestions:
try: try:
data = json.loads(review) data = json.loads(review)
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"Could not parse json response: {review}") logging.info(f"Could not parse json response: {review}")
data = try_fix_json(review, code_suggestions=True) data = try_fix_json(review, code_suggestions=True)
return data return data
@ -91,22 +93,28 @@ class PRCodeSuggestions:
def push_inline_code_suggestions(self, data): def push_inline_code_suggestions(self, data):
code_suggestions = [] code_suggestions = []
for d in data['Code suggestions']: for d in data['Code suggestions']:
if settings.config.verbosity_level >= 2: try:
logging.info(f"suggestion: {d}") if get_settings().config.verbosity_level >= 2:
relevant_file = d['relevant file'].strip() logging.info(f"suggestion: {d}")
relevant_lines_str = d['relevant lines'].strip() relevant_file = d['relevant file'].strip()
relevant_lines_start = int(relevant_lines_str.split('-')[0]) # absolute position relevant_lines_str = d['relevant lines'].strip()
relevant_lines_end = int(relevant_lines_str.split('-')[-1]) if ',' in relevant_lines_str: # handling 'relevant lines': '181, 190' or '178-184, 188-194'
content = d['suggestion content'] relevant_lines_str = relevant_lines_str.split(',')[0]
new_code_snippet = d['improved code'] relevant_lines_start = int(relevant_lines_str.split('-')[0]) # absolute position
relevant_lines_end = int(relevant_lines_str.split('-')[-1])
content = d['suggestion content']
new_code_snippet = d['improved code']
if new_code_snippet: if new_code_snippet:
new_code_snippet = self.dedent_code(relevant_file, relevant_lines_start, new_code_snippet) new_code_snippet = self.dedent_code(relevant_file, relevant_lines_start, new_code_snippet)
body = f"**Suggestion:** {content}\n```suggestion\n" + new_code_snippet + "\n```" body = f"**Suggestion:** {content}\n```suggestion\n" + new_code_snippet + "\n```"
code_suggestions.append({'body': body,'relevant_file': relevant_file, code_suggestions.append({'body': body, 'relevant_file': relevant_file,
'relevant_lines_start': relevant_lines_start, 'relevant_lines_start': relevant_lines_start,
'relevant_lines_end': relevant_lines_end}) 'relevant_lines_end': relevant_lines_end})
except Exception:
if get_settings().config.verbosity_level >= 2:
logging.info(f"Could not parse suggestion: {d}")
self.git_provider.publish_code_suggestions(code_suggestions) self.git_provider.publish_code_suggestions(code_suggestions)
@ -127,7 +135,8 @@ class PRCodeSuggestions:
if delta_spaces > 0: if delta_spaces > 0:
new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n') new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')
except Exception as e: except Exception as e:
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"Could not dedent code snippet for file {relevant_file}, error: {e}") logging.info(f"Could not dedent code snippet for file {relevant_file}, error: {e}")
return new_code_snippet return new_code_snippet

View File

@ -1,31 +1,33 @@
import copy import copy
import json import json
import logging import logging
from typing import Tuple, List from typing import List, Tuple
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language from pr_agent.git_providers.git_provider import get_main_pr_language
class PRDescription: class PRDescription:
def __init__(self, pr_url: str): def __init__(self, pr_url: str, args: list = None):
""" """
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model. Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
using an AI model.
Args: Args:
pr_url (str): The URL of the pull request. pr_url (str): The URL of the pull request.
args (list, optional): List of arguments passed to the PRDescription class. Defaults to None.
""" """
# Initialize the git provider and main PR language # Initialize the git provider and main PR language
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language( self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
) )
commit_messages_str = self.git_provider.get_commit_messages()
# Initialize the AI handler # Initialize the AI handler
self.ai_handler = AiHandler() self.ai_handler = AiHandler()
@ -37,26 +39,28 @@ class PRDescription:
"description": self.git_provider.get_pr_description(), "description": self.git_provider.get_pr_description(),
"language": self.main_pr_language, "language": self.main_pr_language,
"diff": "", # empty diff for initial calculation "diff": "", # empty diff for initial calculation
"extra_instructions": get_settings().pr_description.extra_instructions,
"commit_messages_str": commit_messages_str
} }
# Initialize the token handler # Initialize the token handler
self.token_handler = TokenHandler( self.token_handler = TokenHandler(
self.git_provider.pr, self.git_provider.pr,
self.vars, self.vars,
settings.pr_description_prompt.system, get_settings().pr_description_prompt.system,
settings.pr_description_prompt.user, get_settings().pr_description_prompt.user,
) )
# Initialize patches_diff and prediction attributes # Initialize patches_diff and prediction attributes
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
async def describe(self): async def run(self):
""" """
Generates a PR description using an AI model and publishes it to the PR. Generates a PR description using an AI model and publishes it to the PR.
""" """
logging.info('Generating a PR description...') logging.info('Generating a PR description...')
if settings.config.publish_output: if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing pr description...", is_temporary=True) self.git_provider.publish_comment("Preparing pr description...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction) await retry_with_fallback_models(self._prepare_prediction)
@ -64,9 +68,9 @@ class PRDescription:
logging.info('Preparing answer...') logging.info('Preparing answer...')
pr_title, pr_body, pr_types, markdown_text = self._prepare_pr_answer() pr_title, pr_body, pr_types, markdown_text = self._prepare_pr_answer()
if settings.config.publish_output: if get_settings().config.publish_output:
logging.info('Pushing answer...') logging.info('Pushing answer...')
if settings.pr_description.publish_description_as_comment: if get_settings().pr_description.publish_description_as_comment:
self.git_provider.publish_comment(markdown_text) self.git_provider.publish_comment(markdown_text)
else: else:
self.git_provider.publish_description(pr_title, pr_body) self.git_provider.publish_description(pr_title, pr_body)
@ -112,10 +116,10 @@ class PRDescription:
variables["diff"] = self.patches_diff # update diff variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_description_prompt.system).render(variables) system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_description_prompt.user).render(variables) user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables)
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}")
@ -160,13 +164,13 @@ class PRDescription:
# Iterate over the remaining dictionary items and append the key and value to 'pr_body' in a markdown format, # Iterate over the remaining dictionary items and append the key and value to 'pr_body' in a markdown format,
# except for the items containing the word 'walkthrough' # except for the items containing the word 'walkthrough'
for key, value in data.items(): for key, value in data.items():
pr_body += f"{key}:\n" pr_body += f"## {key}:\n"
if 'walkthrough' in key.lower(): if 'walkthrough' in key.lower():
pr_body += f"{value}\n" pr_body += f"{value}\n"
else: else:
pr_body += f"**{value}**\n\n___\n" pr_body += f"{value}\n\n___\n"
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"title:\n{title}\n{pr_body}") logging.info(f"title:\n{title}\n{pr_body}")
return title, pr_body, pr_types, markdown_text return title, pr_body, pr_types, markdown_text

View File

@ -6,15 +6,13 @@ from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language from pr_agent.git_providers.git_provider import get_main_pr_language
class PRInformationFromUser: class PRInformationFromUser:
def __init__(self, pr_url: str): def __init__(self, pr_url: str, args: list = None):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language( self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
@ -29,19 +27,19 @@ class PRInformationFromUser:
} }
self.token_handler = TokenHandler(self.git_provider.pr, self.token_handler = TokenHandler(self.git_provider.pr,
self.vars, self.vars,
settings.pr_information_from_user_prompt.system, get_settings().pr_information_from_user_prompt.system,
settings.pr_information_from_user_prompt.user) get_settings().pr_information_from_user_prompt.user)
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
async def generate_questions(self): async def run(self):
logging.info('Generating question to the user...') logging.info('Generating question to the user...')
if settings.config.publish_output: if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing questions...", is_temporary=True) self.git_provider.publish_comment("Preparing questions...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction) await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing questions...') logging.info('Preparing questions...')
pr_comment = self._prepare_pr_answer() pr_comment = self._prepare_pr_answer()
if settings.config.publish_output: if get_settings().config.publish_output:
logging.info('Pushing questions...') logging.info('Pushing questions...')
self.git_provider.publish_comment(pr_comment) self.git_provider.publish_comment(pr_comment)
self.git_provider.remove_initial_comment() self.git_provider.remove_initial_comment()
@ -57,9 +55,9 @@ class PRInformationFromUser:
variables = copy.deepcopy(self.vars) variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_information_from_user_prompt.system).render(variables) system_prompt = environment.from_string(get_settings().pr_information_from_user_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_information_from_user_prompt.user).render(variables) user_prompt = environment.from_string(get_settings().pr_information_from_user_prompt.user).render(variables)
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2, response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
@ -68,7 +66,7 @@ class PRInformationFromUser:
def _prepare_pr_answer(self) -> str: def _prepare_pr_answer(self) -> str:
model_output = self.prediction.strip() model_output = self.prediction.strip()
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"answer_str:\n{model_output}") logging.info(f"answer_str:\n{model_output}")
answer_str = f"{model_output}\n\n Please respond to the questions above in the following format:\n\n" +\ answer_str = f"{model_output}\n\n Please respond to the questions above in the following format:\n\n" +\
"\n>/answer\n>1) ...\n>2) ...\n>...\n" "\n>/answer\n>1) ...\n>2) ...\n>...\n"

View File

@ -6,7 +6,7 @@ from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language from pr_agent.git_providers.git_provider import get_main_pr_language
@ -30,8 +30,8 @@ class PRQuestions:
} }
self.token_handler = TokenHandler(self.git_provider.pr, self.token_handler = TokenHandler(self.git_provider.pr,
self.vars, self.vars,
settings.pr_questions_prompt.system, get_settings().pr_questions_prompt.system,
settings.pr_questions_prompt.user) get_settings().pr_questions_prompt.user)
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
@ -42,14 +42,14 @@ class PRQuestions:
question_str = "" question_str = ""
return question_str return question_str
async def answer(self): async def run(self):
logging.info('Answering a PR question...') logging.info('Answering a PR question...')
if settings.config.publish_output: if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing answer...", is_temporary=True) self.git_provider.publish_comment("Preparing answer...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction) await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing answer...') logging.info('Preparing answer...')
pr_comment = self._prepare_pr_answer() pr_comment = self._prepare_pr_answer()
if settings.config.publish_output: if get_settings().config.publish_output:
logging.info('Pushing answer...') logging.info('Pushing answer...')
self.git_provider.publish_comment(pr_comment) self.git_provider.publish_comment(pr_comment)
self.git_provider.remove_initial_comment() self.git_provider.remove_initial_comment()
@ -65,9 +65,9 @@ class PRQuestions:
variables = copy.deepcopy(self.vars) variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_questions_prompt.system).render(variables) system_prompt = environment.from_string(get_settings().pr_questions_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_questions_prompt.user).render(variables) user_prompt = environment.from_string(get_settings().pr_questions_prompt.user).render(variables)
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2, response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
@ -77,6 +77,6 @@ class PRQuestions:
def _prepare_pr_answer(self) -> str: def _prepare_pr_answer(self) -> str:
answer_str = f"Question: {self.question_str}\n\n" answer_str = f"Question: {self.question_str}\n\n"
answer_str += f"Answer:\n{self.prediction.strip()}\n\n" answer_str += f"Answer:\n{self.prediction.strip()}\n\n"
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"answer_str:\n{answer_str}") logging.info(f"answer_str:\n{answer_str}")
return answer_str return answer_str

View File

@ -2,7 +2,7 @@ import copy
import json import json
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from typing import Tuple, List from typing import List, Tuple
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
@ -10,9 +10,9 @@ from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown, try_fix_json from pr_agent.algo.utils import convert_to_markdown, try_fix_json
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language, IncrementalPR from pr_agent.git_providers.git_provider import IncrementalPR, get_main_pr_language
from pr_agent.servers.help import actions_help_text, bot_help_text from pr_agent.servers.help import actions_help_text, bot_help_text
@ -20,17 +20,16 @@ class PRReviewer:
""" """
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
""" """
def __init__(self, pr_url: str, cli_mode: bool = False, is_answer: bool = False, args: list = None): def __init__(self, pr_url: str, is_answer: bool = False, args: list = None):
""" """
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
Args: Args:
pr_url (str): The URL of the pull request to be reviewed. pr_url (str): The URL of the pull request to be reviewed.
cli_mode (bool, optional): Indicates whether the review is being done in command-line interface mode. Defaults to False.
is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False. is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False.
args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None. args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None.
""" """
self.parse_args(args) self.parse_args(args) # -i command
self.git_provider = get_git_provider()(pr_url, incremental=self.incremental) self.git_provider = get_git_provider()(pr_url, incremental=self.incremental)
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(
@ -40,11 +39,10 @@ class PRReviewer:
self.is_answer = is_answer self.is_answer = is_answer
if self.is_answer and not self.git_provider.is_supported("get_issue_comments"): if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
raise Exception(f"Answer mode is not supported for {settings.config.git_provider} for now") raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now")
self.ai_handler = AiHandler() self.ai_handler = AiHandler()
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
self.cli_mode = cli_mode
answer_str, question_str = self._get_user_answers() answer_str, question_str = self._get_user_answers()
self.vars = { self.vars = {
@ -53,20 +51,21 @@ class PRReviewer:
"description": self.git_provider.get_pr_description(), "description": self.git_provider.get_pr_description(),
"language": self.main_language, "language": self.main_language,
"diff": "", # empty diff for initial calculation "diff": "", # empty diff for initial calculation
"require_score": settings.pr_reviewer.require_score_review, "require_score": get_settings().pr_reviewer.require_score_review,
"require_tests": settings.pr_reviewer.require_tests_review, "require_tests": get_settings().pr_reviewer.require_tests_review,
"require_security": settings.pr_reviewer.require_security_review, "require_security": get_settings().pr_reviewer.require_security_review,
"require_focused": settings.pr_reviewer.require_focused_review, "require_focused": get_settings().pr_reviewer.require_focused_review,
'num_code_suggestions': settings.pr_reviewer.num_code_suggestions, 'num_code_suggestions': get_settings().pr_reviewer.num_code_suggestions,
'question_str': question_str, 'question_str': question_str,
'answer_str': answer_str, 'answer_str': answer_str,
"extra_instructions": get_settings().pr_reviewer.extra_instructions,
} }
self.token_handler = TokenHandler( self.token_handler = TokenHandler(
self.git_provider.pr, self.git_provider.pr,
self.vars, self.vars,
settings.pr_review_prompt.system, get_settings().pr_review_prompt.system,
settings.pr_review_prompt.user get_settings().pr_review_prompt.user
) )
def parse_args(self, args: List[str]) -> None: def parse_args(self, args: List[str]) -> None:
@ -86,13 +85,13 @@ class PRReviewer:
is_incremental = True is_incremental = True
self.incremental = IncrementalPR(is_incremental) self.incremental = IncrementalPR(is_incremental)
async def review(self) -> None: async def run(self) -> None:
""" """
Review the pull request and generate feedback. Review the pull request and generate feedback.
""" """
logging.info('Reviewing PR...') logging.info('Reviewing PR...')
if settings.config.publish_output: if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing review...", is_temporary=True) self.git_provider.publish_comment("Preparing review...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction) await retry_with_fallback_models(self._prepare_prediction)
@ -100,12 +99,12 @@ class PRReviewer:
logging.info('Preparing PR review...') logging.info('Preparing PR review...')
pr_comment = self._prepare_pr_review() pr_comment = self._prepare_pr_review()
if settings.config.publish_output: if get_settings().config.publish_output:
logging.info('Pushing PR review...') logging.info('Pushing PR review...')
self.git_provider.publish_comment(pr_comment) self.git_provider.publish_comment(pr_comment)
self.git_provider.remove_initial_comment() self.git_provider.remove_initial_comment()
if settings.pr_reviewer.inline_code_comments: if get_settings().pr_reviewer.inline_code_comments:
logging.info('Pushing inline code comments...') logging.info('Pushing inline code comments...')
self._publish_inline_code_comments() self._publish_inline_code_comments()
@ -138,10 +137,10 @@ class PRReviewer:
variables["diff"] = self.patches_diff # update diff variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_review_prompt.system).render(variables) system_prompt = environment.from_string(get_settings().pr_review_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_review_prompt.user).render(variables) user_prompt = environment.from_string(get_settings().pr_review_prompt.user).render(variables)
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}")
@ -156,7 +155,8 @@ class PRReviewer:
def _prepare_pr_review(self) -> str: def _prepare_pr_review(self) -> str:
""" """
Prepare the PR review by processing the AI prediction and generating a markdown-formatted text that summarizes the feedback. Prepare the PR review by processing the AI prediction and generating a markdown-formatted text that summarizes
the feedback.
""" """
review = self.prediction.strip() review = self.prediction.strip()
@ -172,7 +172,8 @@ class PRReviewer:
data['PR Analysis']['Security concerns'] = val data['PR Analysis']['Security concerns'] = val
# Filter out code suggestions that can be submitted as inline comments # Filter out code suggestions that can be submitted as inline comments
if settings.config.git_provider != 'bitbucket' and settings.pr_reviewer.inline_code_comments and 'Code suggestions' in data['PR Feedback']: if get_settings().config.git_provider != 'bitbucket' and get_settings().pr_reviewer.inline_code_comments \
and 'Code suggestions' in data['PR Feedback']:
data['PR Feedback']['Code suggestions'] = [ data['PR Feedback']['Code suggestions'] = [
d for d in data['PR Feedback']['Code suggestions'] d for d in data['PR Feedback']['Code suggestions']
if any(key not in d for key in ('relevant file', 'relevant line in file', 'suggestion content')) if any(key not in d for key in ('relevant file', 'relevant line in file', 'suggestion content'))
@ -182,7 +183,8 @@ class PRReviewer:
# Add incremental review section # Add incremental review section
if self.incremental.is_incremental: if self.incremental.is_incremental:
last_commit_url = f"{self.git_provider.get_pr_url()}/commits/{self.git_provider.incremental.first_new_commit_sha}" last_commit_url = f"{self.git_provider.get_pr_url()}/commits/" \
f"{self.git_provider.incremental.first_new_commit_sha}"
data = OrderedDict(data) data = OrderedDict(data)
data.update({'Incremental PR Review': { data.update({'Incremental PR Review': {
"⏮️ Review for commits since previous PR-Agent review": f"Starting from commit {last_commit_url}"}}) "⏮️ Review for commits since previous PR-Agent review": f"Starting from commit {last_commit_url}"}})
@ -192,7 +194,7 @@ class PRReviewer:
user = self.git_provider.get_user_id() user = self.git_provider.get_user_id()
# Add help text if not in CLI mode # Add help text if not in CLI mode
if not self.cli_mode: if not get_settings().get("CONFIG.CLI_MODE", False):
markdown_text += "\n### How to use\n" markdown_text += "\n### How to use\n"
if user and '[bot]' not in user: if user and '[bot]' not in user:
markdown_text += bot_help_text(user) markdown_text += bot_help_text(user)
@ -200,7 +202,7 @@ class PRReviewer:
markdown_text += actions_help_text markdown_text += actions_help_text
# Log markdown response if verbosity level is high # Log markdown response if verbosity level is high
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"Markdown response:\n{markdown_text}") logging.info(f"Markdown response:\n{markdown_text}")
return markdown_text return markdown_text
@ -209,7 +211,7 @@ class PRReviewer:
""" """
Publishes inline comments on a pull request with code suggestions generated by the AI model. Publishes inline comments on a pull request with code suggestions generated by the AI model.
""" """
if settings.pr_reviewer.num_code_suggestions == 0: if get_settings().pr_reviewer.num_code_suggestions == 0:
return return
review = self.prediction.strip() review = self.prediction.strip()
@ -250,7 +252,7 @@ class PRReviewer:
if self.is_answer: if self.is_answer:
discussion_messages = self.git_provider.get_issue_comments() discussion_messages = self.git_provider.get_issue_comments()
for message in reversed(discussion_messages): for message in discussion_messages.reversed:
if "Questions to better understand the PR:" in message.body: if "Questions to better understand the PR:" in message.body:
question_str = message.body question_str = message.body
elif '/answer' in message.body: elif '/answer' in message.body:

View File

@ -9,8 +9,8 @@ from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider, GithubProvider from pr_agent.git_providers import GithubProvider, get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language from pr_agent.git_providers.git_provider import get_main_pr_language
CHANGELOG_LINES = 50 CHANGELOG_LINES = 50
@ -23,7 +23,7 @@ class PRUpdateChangelog:
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
) )
self.commit_changelog = self._parse_args(args, settings) self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes
self._get_changlog_file() # self.changelog_file_str self._get_changlog_file() # self.changelog_file_str
self.ai_handler = AiHandler() self.ai_handler = AiHandler()
self.patches_diff = None self.patches_diff = None
@ -37,22 +37,23 @@ class PRUpdateChangelog:
"diff": "", # empty diff for initial calculation "diff": "", # empty diff for initial calculation
"changelog_file_str": self.changelog_file_str, "changelog_file_str": self.changelog_file_str,
"today": date.today(), "today": date.today(),
"extra_instructions": get_settings().pr_update_changelog.extra_instructions,
} }
self.token_handler = TokenHandler(self.git_provider.pr, self.token_handler = TokenHandler(self.git_provider.pr,
self.vars, self.vars,
settings.pr_update_changelog_prompt.system, get_settings().pr_update_changelog_prompt.system,
settings.pr_update_changelog_prompt.user) get_settings().pr_update_changelog_prompt.user)
async def update_changelog(self): async def run(self):
assert type(self.git_provider) == GithubProvider, "Currently only Github is supported" assert type(self.git_provider) == GithubProvider, "Currently only Github is supported"
logging.info('Updating the changelog...') logging.info('Updating the changelog...')
if settings.config.publish_output: if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing changelog updates...", is_temporary=True) self.git_provider.publish_comment("Preparing changelog updates...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction) await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing PR changelog updates...') logging.info('Preparing PR changelog updates...')
new_file_content, answer = self._prepare_changelog_update() new_file_content, answer = self._prepare_changelog_update()
if settings.config.publish_output: if get_settings().config.publish_output:
self.git_provider.remove_initial_comment() self.git_provider.remove_initial_comment()
logging.info('Publishing changelog updates...') logging.info('Publishing changelog updates...')
if self.commit_changelog: if self.commit_changelog:
@ -72,9 +73,9 @@ class PRUpdateChangelog:
variables = copy.deepcopy(self.vars) variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_update_changelog_prompt.system).render(variables) system_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_update_changelog_prompt.user).render(variables) user_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.user).render(variables)
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2, response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
@ -83,7 +84,7 @@ class PRUpdateChangelog:
return response return response
def _prepare_changelog_update(self) -> Tuple[str, str]: def _prepare_changelog_update(self) -> Tuple[str, str]:
answer = self.prediction.strip().strip("```").strip() answer = self.prediction.strip().strip("```").strip() # noqa B005
if hasattr(self, "changelog_file"): if hasattr(self, "changelog_file"):
existing_content = self.changelog_file.decoded_content.decode() existing_content = self.changelog_file.decoded_content.decode()
else: else:
@ -95,9 +96,9 @@ class PRUpdateChangelog:
if not self.commit_changelog: if not self.commit_changelog:
answer += "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:" \ answer += "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:" \
"\n>'/update_changelog -commit'\n" "\n>'/update_changelog --pr_update_changelog.push_changelog_changes=true'\n"
if settings.config.verbosity_level >= 2: if get_settings().config.verbosity_level >= 2:
logging.info(f"answer:\n{answer}") logging.info(f"answer:\n{answer}")
return new_file_content, answer return new_file_content, answer
@ -117,7 +118,7 @@ class PRUpdateChangelog:
last_commit_id = list(self.git_provider.pr.get_commits())[-1] last_commit_id = list(self.git_provider.pr.get_commits())[-1]
try: try:
self.git_provider.pr.create_review(commit=last_commit_id, comments=[d]) self.git_provider.pr.create_review(commit=last_commit_id, comments=[d])
except: except Exception:
# we can't create a review for some reason, let's just publish a comment # we can't create a review for some reason, let's just publish a comment
self.git_provider.publish_comment(f"**Changelog updates:**\n\n{answer}") self.git_provider.publish_comment(f"**Changelog updates:**\n\n{answer}")
@ -137,19 +138,6 @@ Example:
""" """
return example_changelog return example_changelog
def _parse_args(self, args, setting):
commit_changelog = False
if args and len(args) >= 1:
try:
if args[0] == "-commit":
commit_changelog = True
except:
pass
else:
commit_changelog = setting.pr_update_changelog.push_changelog_changes
return commit_changelog
def _get_changlog_file(self): def _get_changlog_file(self):
try: try:
self.changelog_file = self.git_provider.repo_obj.get_contents("CHANGELOG.md", self.changelog_file = self.git_provider.repo_obj.get_contents("CHANGELOG.md",
@ -157,7 +145,7 @@ Example:
changelog_file_lines = self.changelog_file.decoded_content.decode().splitlines() changelog_file_lines = self.changelog_file.decoded_content.decode().splitlines()
changelog_file_lines = changelog_file_lines[:CHANGELOG_LINES] changelog_file_lines = changelog_file_lines[:CHANGELOG_LINES]
self.changelog_file_str = "\n".join(changelog_file_lines) self.changelog_file_str = "\n".join(changelog_file_lines)
except: except Exception:
self.changelog_file_str = "" self.changelog_file_str = ""
if self.commit_changelog: if self.commit_changelog:
logging.info("No CHANGELOG.md file found in the repository. Creating one...") logging.info("No CHANGELOG.md file found in the repository. Creating one...")

View File

@ -41,6 +41,7 @@ dependencies = [
"aiohttp~=3.8.4", "aiohttp~=3.8.4",
"atlassian-python-api==3.39.0", "atlassian-python-api==3.39.0",
"GitPython~=3.1.32", "GitPython~=3.1.32",
"starlette-context==0.3.6"
] ]
[project.urls] [project.urls]

View File

@ -2,7 +2,7 @@
import logging import logging
from pr_agent.algo.git_patch_processing import handle_patch_deletions from pr_agent.algo.git_patch_processing import handle_patch_deletions
from pr_agent.config_loader import settings from pr_agent.config_loader import get_settings
""" """
Code Analysis Code Analysis
@ -49,7 +49,7 @@ class TestHandlePatchDeletions:
original_file_content_str = 'foo\nbar\n' original_file_content_str = 'foo\nbar\n'
new_file_content_str = '' new_file_content_str = ''
file_name = 'file.py' file_name = 'file.py'
settings.config.verbosity_level = 1 get_settings().config.verbosity_level = 1
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file_name) handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file_name)