Compare commits

..

10 Commits

6 changed files with 135 additions and 11 deletions

View File

@ -14,6 +14,13 @@ class AiHandler:
openai.api_key = settings.openai.key
if settings.get("OPENAI.ORG", None):
openai.organization = settings.openai.org
self.deployment_id = settings.get("OPENAI.DEPLOYMENT_ID", None)
if settings.get("OPENAI.API_TYPE", None):
openai.api_type = settings.openai.api_type
# The configured API version must be stored in `openai.api_version`.
# Assigning it to `openai.engine` left the API version unset, so the
# OPENAI.API_VERSION setting had no effect on Azure requests.
if settings.get("OPENAI.API_VERSION", None):
    openai.api_version = settings.openai.api_version
if settings.get("OPENAI.API_BASE", None):
openai.api_base = settings.openai.api_base
except AttributeError as e:
raise ValueError("OpenAI key is required") from e
@ -23,6 +30,7 @@ class AiHandler:
try:
response = await openai.ChatCompletion.acreate(
model=model,
deployment_id=self.deployment_id,
messages=[
{"role": "system", "content": system},
{"role": "user", "content": user}

View File

@ -1,5 +1,8 @@
from __future__ import annotations
import json
import logging
import re
import textwrap
@ -61,3 +64,25 @@ def parse_code_suggestion(code_suggestions: dict) -> str:
markdown_text += "\n"
return markdown_text
def _last_suggestion_end(text):
    """Return the index of the comma in the last ``},`` found in *text*, or -1.

    Any amount of whitespace (including newlines) may sit between the closing
    brace and the comma.
    """
    comma_ind = -1
    for match in re.finditer(r"\}\s*,", text):
        comma_ind = match.end() - 1
    return comma_ind


def try_fix_json(review, max_iter=10):
    """Try to recover a dict from a broken/incomplete JSON review.

    The text is cut back to the last ``},`` boundary, closed with ``]}}`` and
    parsed. If that fails, the cut moves to the previous ``},`` boundary, and
    so on. In effect, trailing incomplete code suggestions are dropped until
    the remainder parses.

    Args:
        review: Raw (possibly truncated) JSON text returned by the AI.
        max_iter: Maximum number of cut-and-parse attempts.

    Returns:
        The parsed dict, or ``{}`` when the text has no "Code suggestions"
        list or no cut point yields valid JSON.
    """
    has_code_suggestions = (
        review.rfind("'Code suggestions': [") > 0
        or review.rfind('"Code suggestions": [') > 0
    )
    if not has_code_suggestions:
        return {}

    for _ in range(max_iter):
        cut = _last_suggestion_end(review)
        if cut <= 0:
            # No "}," boundary left to cut at. Indexing [-1] on the empty
            # match list used to raise IndexError here.
            break
        # Drop everything from the comma onwards, then close the list and the
        # two enclosing objects.
        review = review[:cut]
        try:
            return json.loads(review + "]}}")
        except json.decoder.JSONDecodeError:
            continue

    logging.error("Unable to decode JSON response from AI")
    return {}

View File

@ -15,11 +15,11 @@ def run():
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
if args.question:
print(f"Question: {args.question} about PR {args.pr_url}")
reviewer = PRQuestions(args.pr_url, args.question, installation_id=None)
reviewer = PRQuestions(args.pr_url, args.question)
asyncio.run(reviewer.answer())
else:
print(f"Reviewing PR: {args.pr_url}")
reviewer = PRReviewer(args.pr_url, installation_id=None, cli_mode=True)
reviewer = PRReviewer(args.pr_url, cli_mode=True)
asyncio.run(reviewer.review())

View File

@ -9,6 +9,11 @@
[openai]
key = "<API_KEY>" # Acquire through https://platform.openai.com
org = "<ORGANIZATION>" # Optional, may be commented out.
# Uncomment the following for Azure OpenAI
#api_type = "azure"
#api_version = '2023-05-15' # Check Azure documentation for the current API version
#api_base = "<API_BASE>" # The base URL for your Azure OpenAI resource. e.g. "https://<your resource name>.openai.azure.com"
#deployment_id = "<DEPLOYMENT_ID>" # The deployment name you chose when you deployed the engine
[github]
# ---- Set the following only for deployment type == "user"

View File

@ -7,7 +7,7 @@ from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown
from pr_agent.algo.utils import convert_to_markdown, try_fix_json
from pr_agent.config_loader import settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language
@ -69,11 +69,7 @@ class PRReviewer:
model = settings.config.model
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
system=system_prompt, user=user_prompt)
try:
json.loads(response)
except json.decoder.JSONDecodeError:
logging.warning("Could not decode JSON")
response = {}
return response
def _prepare_pr_review(self) -> str:
@ -81,8 +77,7 @@ class PRReviewer:
try:
data = json.loads(review)
except json.decoder.JSONDecodeError:
logging.error("Unable to decode JSON response from AI")
data = {}
data = try_fix_json(review)
# reordering for nicer display
if 'PR Feedback' in data:
@ -108,4 +103,4 @@ class PRReviewer:
if settings.config.verbosity_level >= 2:
logging.info(f"Markdown response:\n{markdown_text}")
return markdown_text
return markdown_text

View File

@ -0,0 +1,91 @@
# Generated by CodiumAI
from pr_agent.algo.utils import try_fix_json
import pytest
class TestTryFixJson:
    """Checks that try_fix_json recovers a review whose 'Code suggestions' list is cut off.

    Every fixture starts with the same well-formed text (one complete
    suggestion) followed by a differently truncated second suggestion. In all
    cases the expected result keeps only the first suggestion.
    """

    # Shared well-formed start of each fixture, up to and including the closing
    # brace of the first (complete) code suggestion.
    _COMPLETE_PART = (
        '{"PR Analysis": {"Main theme": "xxx", "Description and title": "Yes", "Type of PR": "Bug fix"}, '
        '"PR Feedback": {"General PR suggestions": "..., `xxx`...", '
        '"Code suggestions": [{"suggestion number": 1, "relevant file": "xxx.py", '
        '"suggestion content": "xxx [important]"}'
    )

    # Expected parse result for every fixture: the second, incomplete
    # suggestion is absent.
    _EXPECTED = {
        'PR Analysis': {
            'Main theme': 'xxx',
            'Description and title': 'Yes',
            'Type of PR': 'Bug fix',
        },
        'PR Feedback': {
            'General PR suggestions': '..., `xxx`...',
            'Code suggestions': [
                {
                    'suggestion number': 1,
                    'relevant file': 'xxx.py',
                    'suggestion content': 'xxx [important]',
                },
            ],
        },
    }

    # Second suggestion is cut off inside its "suggestion content" value.
    def test_incomplete_code_suggestions(self):
        truncated_tail = ', {"suggestion number": 2, "relevant file": "yyy.py", "suggestion content": "yyy [incomp...'
        review = self._COMPLETE_PART + truncated_tail
        assert try_fix_json(review) == self._EXPECTED

    # Same truncation, but whitespace (space, newline, tab) separates "}" from ",".
    def test_incomplete_code_suggestions_new_line(self):
        truncated_tail = ' \n\t, {"suggestion number": 2, "relevant file": "yyy.py", "suggestion content": "yyy [incomp...'
        review = self._COMPLETE_PART + truncated_tail
        assert try_fix_json(review) == self._EXPECTED

    # The truncated content itself contains several misleading "}," sequences.
    def test_incomplete_code_suggestions_many_close_brackets(self):
        truncated_tail = ' \n, {"suggestion number": 2, "relevant file": "yyy.py", "suggestion content": "yyy }, [}\n ,incomp.} ,..'
        review = self._COMPLETE_PART + truncated_tail
        assert try_fix_json(review) == self._EXPECTED

    # Second suggestion is cut off earlier, inside its "relevant file" value.
    def test_incomplete_code_suggestions_relevant_file(self):
        truncated_tail = ', {"suggestion number": 2, "relevant file": "yyy.p'
        review = self._COMPLETE_PART + truncated_tail
        assert try_fix_json(review) == self._EXPECTED