Compare commits

..

26 Commits

Author SHA1 Message Date
d23daf880f Change gitlab API to use oauth_token instead of PAT (PAT shuold work as well) 2023-07-25 13:58:48 +03:00
adb3f17258 Merge pull request #131 from Codium-ai/ok/gitlab_webook
GitLab Webhook Integration and Provider Enhancements
2023-07-24 16:01:17 +03:00
2c03a67312 Add labels 2023-07-24 16:00:51 +03:00
55eb741965 Merge pull request #125 from Codium-ai/tr/code_enhancment
Code Enhancement in PR Agent
2023-07-24 15:37:53 +03:00
c9c95d60d4 Implement gitlab webhook 2023-07-24 15:05:24 +03:00
cca809e91c run_action 2023-07-24 12:45:24 +03:00
57ff46ecc1 stable 2023-07-24 12:41:00 +03:00
3819d52eb0 Merge remote-tracking branch 'origin/tr/code_enhancment' into tr/code_enhancment 2023-07-24 12:15:17 +03:00
3072325d2c PRDescription 2023-07-24 12:14:53 +03:00
abca2fdcb7 Merge remote-tracking branch 'origin/main' into tr/code_enhancment 2023-07-24 12:04:54 +03:00
4d84f76948 _get_prediction 2023-07-24 11:31:35 +03:00
dd8f6eb923 Merge pull request #126 from Codium-ai/ok/preserve_labels
Add functionality to preserve existing labels in PRs
2023-07-24 10:22:51 +03:00
b9c25e487a On /describe, preserve the current labels 2023-07-24 10:17:26 +03:00
1bf27c38a7 _prepare_pr_answer 2023-07-24 09:15:45 +03:00
1f987380ed Merge pull request #124 from Xyand/bugfix/mising_model
Bugfix - missing function argument
2023-07-24 07:36:21 +03:00
cd8bbbf889 bugfix 2023-07-24 00:58:21 +03:00
8e5498ee97 Merge pull request #122 from Codium-ai/update-readme-gifs-2
Update README.md
2023-07-23 17:40:26 +03:00
0412d7aca0 Update README.md 2023-07-23 17:38:08 +03:00
1eac3245d9 Merge pull request #121 from Codium-ai/update-gifs
Update GIF URLs in README
2023-07-23 17:33:47 +03:00
cd51bef7f7 Merge pull request #119 from zmeir/zmeir-code_suggestions_single_api_call
Optimize Code Suggestions API Calls
2023-07-23 17:30:37 +03:00
e8aa33fa0b Update README.md 2023-07-23 17:27:26 +03:00
54b021b02c Merge pull request #120 from Codium-ai/ok/remove_gitlab_polling
Temporarily remove gitlab polling server until a rewrite is ready
2023-07-23 17:07:59 +03:00
32151e3d9a Temporarily remove gitlab polling server until a rewrite is ready 2023-07-23 17:04:41 +03:00
32358678e6 Reduce the number of GitHub API calls when pushing code suggestions 2023-07-23 16:59:08 +03:00
42e32664a1 Merge pull request #118 from Codium-ai/ok/fallback_models
Handling exceptions in fallback models
2023-07-23 16:43:30 +03:00
321f7bce46 Merge pull request #117 from Codium-ai/ok/fallback_models
Implementing Fallback Models for Tokenization
2023-07-23 16:20:10 +03:00
10 changed files with 217 additions and 131 deletions

View File

@ -30,31 +30,31 @@ CodiumAI `PR-Agent` is an open-source tool aiming to help developers review pull
<h4>/describe:</h4> <h4>/describe:</h4>
<div align="center"> <div align="center">
<p float="center"> <p float="center">
<img src="https://www.codium.ai/wp-content/uploads/2023/07/describe.gif" width="800"> <img src="https://www.codium.ai/images/describe-2.gif" width="800">
</p> </p>
</div> </div>
<h4>/review:</h4> <h4>/review:</h4>
<div align="center"> <div align="center">
<p float="center"> <p float="center">
<img src="https://www.codium.ai/wp-content/uploads/2023/07/review.gif" width="800"> <img src="https://www.codium.ai/images/review-2.gif" width="800">
</p> </p>
</div> </div>
<h4>/reflect and review:</h4> <h4>/reflect_and_review:</h4>
<div align="center"> <div align="center">
<p float="center"> <p float="center">
<img src="https://www.codium.ai/wp-content/uploads/2023/07/reflect_and_review.gif" width="800"> <img src="https://www.codium.ai/images/reflect_and_review.gif" width="800">
</p> </p>
</div> </div>
<h4>/ask:</h4> <h4>/ask:</h4>
<div align="center"> <div align="center">
<p float="center"> <p float="center">
<img src="https://www.codium.ai/wp-content/uploads/2023/07/ask.gif" width="800"> <img src="https://www.codium.ai/images/ask-2.gif" width="800">
</p> </p>
</div> </div>
<h4>/improve:</h4> <h4>/improve:</h4>
<div align="center"> <div align="center">
<p float="center"> <p float="center">
<img src="https://www.codium.ai/wp-content/uploads/2023/07/improve-1.gif" width="800"> <img src="https://www.codium.ai/images/improve-2.gif" width="800">
</p> </p>
</div> </div>
<div align="left"> <div align="left">
@ -83,7 +83,8 @@ CodiumAI `PR-Agent` is an open-source tool aiming to help developers review pull
| | Reflect and Review | :white_check_mark: | | | | | Reflect and Review | :white_check_mark: | | |
| | | | | | | | | | | |
| USAGE | CLI | :white_check_mark: | :white_check_mark: | :white_check_mark: | | USAGE | CLI | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| | Tagging bot | :white_check_mark: | :white_check_mark: | | | | App / webhook | :white_check_mark: | :white_check_mark: | |
| | Tagging bot | :white_check_mark: | | |
| | Actions | :white_check_mark: | | | | | Actions | :white_check_mark: | | |
| | | | | | | | | | | |
| CORE | PR compression | :white_check_mark: | :white_check_mark: | :white_check_mark: | | CORE | PR compression | :white_check_mark: | :white_check_mark: | :white_check_mark: |
@ -106,7 +107,7 @@ In the [configuration](./CONFIGURATION.md) file you can select your git provider
Try GPT-4 powered PR-Agent on your public GitHub repository for free. Just mention `@CodiumAI-Agent` and add the desired command in any PR comment! The agent will generate a response based on your command. Try GPT-4 powered PR-Agent on your public GitHub repository for free. Just mention `@CodiumAI-Agent` and add the desired command in any PR comment! The agent will generate a response based on your command.
![Review generation process](https://www.codium.ai/wp-content/uploads/2023/07/demo.gif) ![Review generation process](https://www.codium.ai/images/demo-2.gif)
To set up your own PR-Agent, see the [Installation](#installation) section To set up your own PR-Agent, see the [Installation](#installation) section

View File

@ -55,7 +55,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
# if we are over the limit, start pruning # if we are over the limit, start pruning
patches_compressed, modified_file_names, deleted_file_names = \ patches_compressed, modified_file_names, deleted_file_names = \
pr_generate_compressed_diff(pr_languages, token_handler, add_line_numbers_to_hunks) pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks)
final_diff = "\n".join(patches_compressed) final_diff = "\n".join(patches_compressed)
if modified_file_names: if modified_file_names:

View File

@ -27,7 +27,7 @@ class BitbucketProvider:
self.set_pr(pr_url) self.set_pr(pr_url)
def is_supported(self, capability: str) -> bool: def is_supported(self, capability: str) -> bool:
if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments']: if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments', 'get_labels']:
return False return False
return True return True

View File

@ -60,6 +60,10 @@ class GitProvider(ABC):
def publish_labels(self, labels): def publish_labels(self, labels):
pass pass
@abstractmethod
def get_labels(self):
pass
@abstractmethod @abstractmethod
def remove_initial_comment(self): def remove_initial_comment(self):
pass pass

View File

@ -152,10 +152,8 @@ class GithubProvider(GitProvider):
def publish_code_suggestions(self, code_suggestions: list): def publish_code_suggestions(self, code_suggestions: list):
""" """
Publishes code suggestions as comments on the PR. Publishes code suggestions as comments on the PR.
In practice current APU enables to send only one code suggestion per comment. Might change in the future.
""" """
post_parameters_list = [] post_parameters_list = []
import github.PullRequestComment
for suggestion in code_suggestions: for suggestion in code_suggestions:
body = suggestion['body'] body = suggestion['body']
relevant_file = suggestion['relevant_file'] relevant_file = suggestion['relevant_file']
@ -178,7 +176,6 @@ class GithubProvider(GitProvider):
if relevant_lines_end > relevant_lines_start: if relevant_lines_end > relevant_lines_start:
post_parameters = { post_parameters = {
"body": body, "body": body,
"commit_id": self.last_commit_id._identity,
"path": relevant_file, "path": relevant_file,
"line": relevant_lines_end, "line": relevant_lines_end,
"start_line": relevant_lines_start, "start_line": relevant_lines_start,
@ -187,19 +184,14 @@ class GithubProvider(GitProvider):
else: # API is different for single line comments else: # API is different for single line comments
post_parameters = { post_parameters = {
"body": body, "body": body,
"commit_id": self.last_commit_id._identity,
"path": relevant_file, "path": relevant_file,
"line": relevant_lines_start, "line": relevant_lines_start,
"side": "RIGHT", "side": "RIGHT",
} }
post_parameters_list.append(post_parameters)
try: try:
headers, data = self.pr._requester.requestJsonAndCheck( self.pr.create_review(commit=self.last_commit_id, comments=post_parameters_list)
"POST", f"{self.pr.url}/comments", input=post_parameters
)
github.PullRequestComment.PullRequestComment(
self.pr._requester, headers, data, completed=True
)
return True return True
except Exception as e: except Exception as e:
if settings.config.verbosity_level >= 2: if settings.config.verbosity_level >= 2:
@ -330,5 +322,12 @@ class GithubProvider(GitProvider):
headers, data = self.pr._requester.requestJsonAndCheck( headers, data = self.pr._requester.requestJsonAndCheck(
"PUT", f"{self.pr.issue_url}/labels", input=post_parameters "PUT", f"{self.pr.issue_url}/labels", input=post_parameters
) )
except: except Exception as e:
logging.exception("Failed to publish labels") logging.exception(f"Failed to publish labels, error: {e}")
def get_labels(self):
try:
return [label.name for label in self.pr.labels]
except Exception as e:
logging.exception(f"Failed to get labels, error: {e}")
return []

View File

@ -8,11 +8,12 @@ from gitlab import GitlabGetError
from pr_agent.config_loader import settings from pr_agent.config_loader import settings
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
from ..algo.language_handler import is_valid_file from ..algo.language_handler import is_valid_file
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
class GitLabProvider(GitProvider): class GitLabProvider(GitProvider):
def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False): def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False):
gitlab_url = settings.get("GITLAB.URL", None) gitlab_url = settings.get("GITLAB.URL", None)
if not gitlab_url: if not gitlab_url:
@ -21,8 +22,8 @@ class GitLabProvider(GitProvider):
if not gitlab_access_token: if not gitlab_access_token:
raise ValueError("GitLab personal access token is not set in the config file") raise ValueError("GitLab personal access token is not set in the config file")
self.gl = gitlab.Gitlab( self.gl = gitlab.Gitlab(
gitlab_url, url=gitlab_url,
gitlab_access_token oauth_token=gitlab_access_token
) )
self.id_project = None self.id_project = None
self.id_mr = None self.id_mr = None
@ -112,7 +113,7 @@ class GitLabProvider(GitProvider):
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str): def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
raise NotImplementedError("Gitlab provider does not support creating inline comments yet") raise NotImplementedError("Gitlab provider does not support creating inline comments yet")
def create_inline_comment(self, comments: list[dict]): def create_inline_comments(self, comments: list[dict]):
raise NotImplementedError("Gitlab provider does not support publishing inline comments yet") raise NotImplementedError("Gitlab provider does not support publishing inline comments yet")
def send_inline_comment(self, body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no, def send_inline_comment(self, body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
@ -258,8 +259,15 @@ class GitLabProvider(GitProvider):
def get_user_id(self): def get_user_id(self):
return None return None
def publish_labels(self, labels): def publish_labels(self, pr_types):
pass try:
self.mr.labels = list(set(pr_types))
self.mr.save()
except Exception as e:
logging.exception(f"Failed to publish labels, error: {e}")
def publish_inline_comments(self, comments: list[dict]): def publish_inline_comments(self, comments: list[dict]):
pass pass
def get_labels(self):
return self.mr.labels

View File

@ -8,46 +8,57 @@ from pr_agent.tools.pr_reviewer import PRReviewer
async def run_action(): async def run_action():
GITHUB_EVENT_NAME = os.environ.get('GITHUB_EVENT_NAME', None) # Get environment variables
GITHUB_EVENT_NAME = os.environ.get('GITHUB_EVENT_NAME')
GITHUB_EVENT_PATH = os.environ.get('GITHUB_EVENT_PATH')
OPENAI_KEY = os.environ.get('OPENAI_KEY')
OPENAI_ORG = os.environ.get('OPENAI_ORG')
GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN')
# Check if required environment variables are set
if not GITHUB_EVENT_NAME: if not GITHUB_EVENT_NAME:
print("GITHUB_EVENT_NAME not set") print("GITHUB_EVENT_NAME not set")
return return
GITHUB_EVENT_PATH = os.environ.get('GITHUB_EVENT_PATH', None)
if not GITHUB_EVENT_PATH: if not GITHUB_EVENT_PATH:
print("GITHUB_EVENT_PATH not set") print("GITHUB_EVENT_PATH not set")
return return
try:
event_payload = json.load(open(GITHUB_EVENT_PATH, 'r'))
except json.decoder.JSONDecodeError as e:
print(f"Failed to parse JSON: {e}")
return
OPENAI_KEY = os.environ.get('OPENAI_KEY', None)
if not OPENAI_KEY: if not OPENAI_KEY:
print("OPENAI_KEY not set") print("OPENAI_KEY not set")
return return
OPENAI_ORG = os.environ.get('OPENAI_ORG', None)
GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN', None)
if not GITHUB_TOKEN: if not GITHUB_TOKEN:
print("GITHUB_TOKEN not set") print("GITHUB_TOKEN not set")
return return
# Set the environment variables in the settings
settings.set("OPENAI.KEY", OPENAI_KEY) settings.set("OPENAI.KEY", OPENAI_KEY)
if OPENAI_ORG: if OPENAI_ORG:
settings.set("OPENAI.ORG", OPENAI_ORG) settings.set("OPENAI.ORG", OPENAI_ORG)
settings.set("GITHUB.USER_TOKEN", GITHUB_TOKEN) settings.set("GITHUB.USER_TOKEN", GITHUB_TOKEN)
settings.set("GITHUB.DEPLOYMENT_TYPE", "user") settings.set("GITHUB.DEPLOYMENT_TYPE", "user")
# Load the event payload
try:
with open(GITHUB_EVENT_PATH, 'r') as f:
event_payload = json.load(f)
except json.decoder.JSONDecodeError as e:
print(f"Failed to parse JSON: {e}")
return
# Handle pull request event
if GITHUB_EVENT_NAME == "pull_request": if GITHUB_EVENT_NAME == "pull_request":
action = event_payload.get("action", None) action = event_payload.get("action")
if action in ["opened", "reopened"]: if action in ["opened", "reopened"]:
pr_url = event_payload.get("pull_request", {}).get("url", None) pr_url = event_payload.get("pull_request", {}).get("url")
if pr_url: if pr_url:
await PRReviewer(pr_url).review() await PRReviewer(pr_url).review()
# Handle issue comment event
elif GITHUB_EVENT_NAME == "issue_comment": elif GITHUB_EVENT_NAME == "issue_comment":
action = event_payload.get("action", None) action = event_payload.get("action")
if action in ["created", "edited"]: if action in ["created", "edited"]:
comment_body = event_payload.get("comment", {}).get("body", None) comment_body = event_payload.get("comment", {}).get("body")
if comment_body: if comment_body:
pr_url = event_payload.get("issue", {}).get("pull_request", {}).get("url", None) pr_url = event_payload.get("issue", {}).get("pull_request", {}).get("url")
if pr_url: if pr_url:
body = comment_body.strip().lower() body = comment_body.strip().lower()
await PRAgent().handle_request(pr_url, body) await PRAgent().handle_request(pr_url, body)

View File

@ -1,64 +0,0 @@
import asyncio
import time
import gitlab
from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import settings
gl = gitlab.Gitlab(
settings.get("GITLAB.URL"),
private_token=settings.get("GITLAB.PERSONAL_ACCESS_TOKEN")
)
# Set the list of projects to monitor
projects_to_monitor = settings.get("GITLAB.PROJECTS_TO_MONITOR")
magic_word = settings.get("GITLAB.MAGIC_WORD")
# Hold the previous seen comments
previous_comments = set()
def check_comments():
print('Polling')
new_comments = {}
for project in projects_to_monitor:
project = gl.projects.get(project)
merge_requests = project.mergerequests.list(state='opened')
for mr in merge_requests:
notes = mr.notes.list(get_all=True)
for note in notes:
if note.id not in previous_comments and note.body.startswith(magic_word):
new_comments[note.id] = dict(
body=note.body[len(magic_word):],
project=project.name,
mr=mr
)
previous_comments.add(note.id)
print(f"New comment in project {project.name}, merge request {mr.title}: {note.body}")
return new_comments
def handle_new_comments(new_comments):
print('Handling new comments')
agent = PRAgent()
for _, comment in new_comments.items():
print(f"Handling comment: {comment['body']}")
asyncio.run(agent.handle_request(comment['mr'].web_url, comment['body']))
def run():
assert settings.get('CONFIG.GIT_PROVIDER') == 'gitlab', 'This script is only for GitLab'
# Initial run to populate previous_comments
check_comments()
# Run the check every minute
while True:
time.sleep(settings.get("GITLAB.POLLING_INTERVAL_SECONDS"))
new_comments = check_comments()
if new_comments:
handle_new_comments(new_comments)
if __name__ == '__main__':
run()

View File

@ -0,0 +1,47 @@
import logging
import uvicorn
from fastapi import APIRouter, FastAPI, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from starlette.background import BackgroundTasks
from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import settings
app = FastAPI()
router = APIRouter()
@router.post("/webhook")
async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):
data = await request.json()
if data.get('object_kind') == 'merge_request' and data['object_attributes'].get('action') in ['open', 'reopen']:
logging.info(f"A merge request has been opened: {data['object_attributes'].get('title')}")
url = data['object_attributes'].get('url')
background_tasks.add_task(PRAgent().handle_request, url, "/review")
elif data.get('object_kind') == 'note' and data['event_type'] == 'note':
if 'merge_request' in data:
mr = data['merge_request']
url = mr.get('url')
body = data.get('object_attributes', {}).get('note')
background_tasks.add_task(PRAgent().handle_request, url, body)
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
def start():
gitlab_url = settings.get("GITLAB.URL", None)
if not gitlab_url:
raise ValueError("GITLAB.URL is not set")
gitlab_token = settings.get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
if not gitlab_token:
raise ValueError("GITLAB.PERSONAL_ACCESS_TOKEN is not set")
settings.config.git_provider = "gitlab"
app = FastAPI()
app.include_router(router)
uvicorn.run(app, host="0.0.0.0", port=3000)
if __name__ == '__main__':
start()

View File

@ -1,6 +1,7 @@
import copy import copy
import json import json
import logging import logging
from typing import Tuple, List
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
@ -14,11 +15,22 @@ from pr_agent.git_providers.git_provider import get_main_pr_language
class PRDescription: class PRDescription:
def __init__(self, pr_url: str): def __init__(self, pr_url: str):
"""
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model.
Args:
pr_url (str): The URL of the pull request.
"""
# Initialize the git provider and main PR language
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language( self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
) )
# Initialize the AI handler
self.ai_handler = AiHandler() self.ai_handler = AiHandler()
# Initialize the variables dictionary
self.vars = { self.vars = {
"title": self.git_provider.pr.title, "title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(), "branch": self.git_provider.get_pr_branch(),
@ -26,67 +38,135 @@ class PRDescription:
"language": self.main_pr_language, "language": self.main_pr_language,
"diff": "", # empty diff for initial calculation "diff": "", # empty diff for initial calculation
} }
self.token_handler = TokenHandler(self.git_provider.pr,
# Initialize the token handler
self.token_handler = TokenHandler(
self.git_provider.pr,
self.vars, self.vars,
settings.pr_description_prompt.system, settings.pr_description_prompt.system,
settings.pr_description_prompt.user) settings.pr_description_prompt.user,
)
# Initialize patches_diff and prediction attributes
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
async def describe(self): async def describe(self):
"""
Generates a PR description using an AI model and publishes it to the PR.
"""
logging.info('Generating a PR description...') logging.info('Generating a PR description...')
if settings.config.publish_output: if settings.config.publish_output:
self.git_provider.publish_comment("Preparing pr description...", is_temporary=True) self.git_provider.publish_comment("Preparing pr description...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction) await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing answer...') logging.info('Preparing answer...')
pr_title, pr_body, pr_types, markdown_text = self._prepare_pr_answer() pr_title, pr_body, pr_types, markdown_text = self._prepare_pr_answer()
if settings.config.publish_output: if settings.config.publish_output:
logging.info('Pushing answer...') logging.info('Pushing answer...')
if settings.pr_description.publish_description_as_comment: if settings.pr_description.publish_description_as_comment:
self.git_provider.publish_comment(markdown_text) self.git_provider.publish_comment(markdown_text)
else: else:
self.git_provider.publish_description(pr_title, pr_body) self.git_provider.publish_description(pr_title, pr_body)
self.git_provider.publish_labels(pr_types) if self.git_provider.is_supported("get_labels"):
current_labels = self.git_provider.get_labels()
if current_labels is None:
current_labels = []
self.git_provider.publish_labels(pr_types + current_labels)
self.git_provider.remove_initial_comment() self.git_provider.remove_initial_comment()
return "" return ""
async def _prepare_prediction(self, model: str): async def _prepare_prediction(self, model: str) -> None:
"""
Prepare the AI prediction for the PR description based on the provided model.
Args:
model (str): The name of the model to be used for generating the prediction.
Returns:
None
Raises:
Any exceptions raised by the 'get_pr_diff' and '_get_prediction' functions.
"""
logging.info('Getting PR diff...') logging.info('Getting PR diff...')
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model) self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
logging.info('Getting AI prediction...') logging.info('Getting AI prediction...')
self.prediction = await self._get_prediction(model) self.prediction = await self._get_prediction(model)
async def _get_prediction(self, model: str): async def _get_prediction(self, model: str) -> str:
"""
Generate an AI prediction for the PR description based on the provided model.
Args:
model (str): The name of the model to be used for generating the prediction.
Returns:
str: The generated AI prediction.
"""
variables = copy.deepcopy(self.vars) variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_description_prompt.system).render(variables) system_prompt = environment.from_string(settings.pr_description_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_description_prompt.user).render(variables) user_prompt = environment.from_string(settings.pr_description_prompt.user).render(variables)
if settings.config.verbosity_level >= 2: if settings.config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
system=system_prompt, user=user_prompt) response, finish_reason = await self.ai_handler.chat_completion(
model=model,
temperature=0.2,
system=system_prompt,
user=user_prompt
)
return response return response
def _prepare_pr_answer(self): def _prepare_pr_answer(self) -> Tuple[str, str, List[str], str]:
"""
Prepare the PR description based on the AI prediction data.
Returns:
- title: a string containing the PR title.
- pr_body: a string containing the PR body in a markdown format.
- pr_types: a list of strings containing the PR types.
- markdown_text: a string containing the AI prediction data in a markdown format.
"""
# Load the AI prediction data into a dictionary
data = json.loads(self.prediction) data = json.loads(self.prediction)
markdown_text = ""
# Initialization
markdown_text = pr_body = ""
pr_types = []
# Iterate over the dictionary items and append the key and value to 'markdown_text' in a markdown format
for key, value in data.items(): for key, value in data.items():
markdown_text += f"## {key}\n\n" markdown_text += f"## {key}\n\n"
markdown_text += f"{value}\n\n" markdown_text += f"{value}\n\n"
pr_body = ""
pr_types = [] # If the 'PR Type' key is present in the dictionary, split its value by comma and assign it to 'pr_types'
if 'PR Type' in data: if 'PR Type' in data:
pr_types = data['PR Type'].split(',') pr_types = data['PR Type'].split(',')
title = data['PR Title']
del data['PR Title'] # Assign the value of the 'PR Title' key to 'title' variable and remove it from the dictionary
title = data.pop('PR Title')
# Iterate over the remaining dictionary items and append the key and value to 'pr_body' in a markdown format,
# except for the items containing the word 'walkthrough'
for key, value in data.items(): for key, value in data.items():
pr_body += f"{key}:\n" pr_body += f"{key}:\n"
if 'walkthrough' in key.lower(): if 'walkthrough' in key.lower():
pr_body += f"{value}\n" pr_body += f"{value}\n"
else: else:
pr_body += f"**{value}**\n\n___\n" pr_body += f"**{value}**\n\n___\n"
if settings.config.verbosity_level >= 2: if settings.config.verbosity_level >= 2:
logging.info(f"title:\n{title}\n{pr_body}") logging.info(f"title:\n{title}\n{pr_body}")
return title, pr_body, pr_types, markdown_text return title, pr_body, pr_types, markdown_text