Compare commits

...

9 Commits

23 changed files with 252 additions and 189 deletions

1
MANIFEST.in Normal file
View File

@ -0,0 +1 @@
recursive-include pr_agent/settings/ *.toml

View File

@ -1,19 +1,17 @@
from __future__ import annotations
import difflib
import logging
import re
import traceback
from typing import Any, Callable, List, Tuple
from typing import Callable, List, Tuple
from github import RateLimitExceededException
from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.token_handler import TokenHandler, get_token_encoder
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider
from pr_agent.git_providers.git_provider import GitProvider
DELETED_FILES_ = "Deleted files:\n"
@ -247,99 +245,6 @@ def _get_all_deployments(all_models: List[str]) -> List[str]:
return all_deployments
# Unified-diff hunk header: "@@ -old_start[,old_size] +new_start[,new_size] @@ [section]".
# Both sizes are optional, so only the start groups are guaranteed to be present on a match.
_HUNK_HEADER_RE = re.compile(
    r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")


def _locate_line_in_patch(patch_lines: List[str], needle: str) -> Tuple[int, int]:
    """
    Scan a patch for the first non-deleted line that contains `needle`.

    Args:
        patch_lines (List[str]): The patch, already split into lines.
        needle (str): The text to look for inside each patch line.

    Returns:
        Tuple[int, int]: (index of the line inside the patch, line number in the new file), or (-1, -1) when
        no line matches. The second element is -1 when the enclosing hunk header could not be parsed.
    """
    new_start = 0  # first new-file line number of the current hunk; None if its header is unparseable
    delta = 0  # number of non-deleted lines seen since the current hunk header
    for index, line in enumerate(patch_lines):
        if line.startswith('@@'):
            delta = 0
            header = _HUNK_HEADER_RE.match(line)
            # Read only the new-file start: the size groups are None for headers such as "@@ -1 +1 @@"
            new_start = int(header.group(3)) if header else None
        elif not line.startswith('-'):
            delta += 1

        # startswith() instead of line[0] so that empty patch lines cannot raise IndexError
        if needle in line and not line.startswith('-'):
            if new_start is None:
                return index, -1
            return index, new_start + delta - 1
    return -1, -1


def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
                                              relevant_file: str,
                                              relevant_line_in_file: str) -> Tuple[int, int]:
    """
    Find the line number and absolute position of a relevant line in a file.

    Args:
        diff_files (List[FilePatchInfo]): A list of FilePatchInfo objects representing the patches of files.
        relevant_file (str): The name of the file where the relevant line is located.
        relevant_line_in_file (str): The content of the relevant line.

    Returns:
        Tuple[int, int]: The index of the matching line inside the patch and its line number in the new version
        of the file. Both are -1 when the line is not found.
    """
    position = -1
    absolute_position = -1

    for file in diff_files:
        if file.filename.strip() != relevant_file:
            continue
        if not file.patch:
            # nothing to search in (missing or empty patch)
            continue
        patch_lines = file.patch.splitlines()

        # try to find the line in the patch using difflib, with some margin of error
        close_matches = difflib.get_close_matches(relevant_line_in_file, patch_lines, n=3, cutoff=0.93)
        if len(close_matches) == 1 and close_matches[0].startswith('+'):
            relevant_line_in_file = close_matches[0]

        found = _locate_line_in_patch(patch_lines, relevant_line_in_file)
        if found[0] == -1 and position == -1 and relevant_line_in_file.startswith('+'):
            # The model might add a '+' to the beginning of the relevant_line_in_file even if originally
            # it's a context line
            found = _locate_line_in_patch(patch_lines, relevant_line_in_file[1:].lstrip())

        if found[0] != -1:
            position, absolute_position = found
    return position, absolute_position
def clip_tokens(text: str, max_tokens: int) -> str:
    """
    Truncate a string so that it holds roughly at most `max_tokens` tokens.

    The cut is made by characters: the average number of characters per token is measured on the full text
    and the text is sliced to that average times `max_tokens`.

    Args:
        text (str): The string to clip.
        max_tokens (int): The maximum number of tokens allowed in the string.

    Returns:
        str: The clipped string, or the input unchanged when it is empty, already short enough,
        or when clipping fails.
    """
    if not text:
        return text

    try:
        token_count = len(get_token_encoder().encode(text))
        if token_count > max_tokens:
            # Estimate the cut point from the average token length of this text
            avg_chars_per_token = len(text) / token_count
            return text[:int(avg_chars_per_token * max_tokens)]
        return text
    except Exception as e:
        # Best effort: on any failure keep the text as it was
        logging.warning(f"Failed to clip tokens: {e}")
        return text
def get_pr_multi_diffs(git_provider: GitProvider,
token_handler: TokenHandler,
model: str,

View File

@ -21,7 +21,7 @@ class TokenHandler:
method.
"""
def __init__(self, pr, vars: dict, system, user):
def __init__(self, vars: dict, system, user):
"""
Initializes the TokenHandler object.
@ -32,9 +32,9 @@ class TokenHandler:
- user: The user string.
"""
self.encoder = get_token_encoder()
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
self.prompt_tokens = self._get_system_user_tokens(self.encoder, vars, system, user)
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
def _get_system_user_tokens(self, encoder, vars: dict, system, user):
"""
Calculates the number of tokens in the system and user strings.

View File

@ -5,14 +5,24 @@ import json
import logging
import re
import textwrap
from dataclasses import dataclass
from datetime import datetime
from typing import Any, List
from enum import Enum
from typing import Any, List, Tuple, Optional
import yaml
from starlette_context import context
from pr_agent.algo.token_handler import get_token_encoder
from pr_agent.config_loader import get_settings, global_settings
class EDIT_TYPE(Enum):
    """Kinds of edit a file patch can be tagged with (see FilePatchInfo.edit_type)."""

    ADDED = 1
    DELETED = 2
    MODIFIED = 3
    RENAMED = 4
def get_setting(key: str) -> Any:
try:
key = key.upper()
@ -294,3 +304,108 @@ def try_fix_yaml(review_text: str) -> dict:
except:
pass
return data
def clip_tokens(text: str, max_tokens: int) -> str:
    """
    Truncate a string so that it holds roughly at most `max_tokens` tokens.

    The cut is made by characters: the average number of characters per token is measured on the full text
    and the text is sliced to that average times `max_tokens`.

    Args:
        text (str): The string to clip.
        max_tokens (int): The maximum number of tokens allowed in the string.

    Returns:
        str: The clipped string, or the input unchanged when it is empty, already short enough,
        or when clipping fails.
    """
    if not text:
        return text

    try:
        token_count = len(get_token_encoder().encode(text))
        if token_count > max_tokens:
            # Estimate the cut point from the average token length of this text
            avg_chars_per_token = len(text) / token_count
            return text[:int(avg_chars_per_token * max_tokens)]
        return text
    except Exception as e:
        # Best effort: on any failure keep the text as it was
        logging.warning(f"Failed to clip tokens: {e}")
        return text
@dataclass
class FilePatchInfo:
    """Patch of a single file, as handed to the diff-processing helpers."""

    # presumably the file content before / after the change -- TODO confirm against the providers
    base_file: str
    head_file: str
    # diff text of the file; it is split into lines and scanned for "@@" hunk headers
    patch: str
    filename: str
    # NOTE(review): -1 looks like a "not counted yet" marker -- confirm
    tokens: int = -1
    edit_type: EDIT_TYPE = EDIT_TYPE.MODIFIED
    # defaults to None, so it is annotated Optional like `language`
    old_filename: Optional[str] = None
    language: Optional[str] = None
# Unified-diff hunk header: "@@ -old_start[,old_size] +new_start[,new_size] @@ [section]".
# Both sizes are optional, so only the start groups are guaranteed to be present on a match.
_HUNK_HEADER_RE = re.compile(
    r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")


def _locate_line_in_patch(patch_lines: List[str], needle: str) -> Tuple[int, int]:
    """
    Scan a patch for the first non-deleted line that contains `needle`.

    Args:
        patch_lines (List[str]): The patch, already split into lines.
        needle (str): The text to look for inside each patch line.

    Returns:
        Tuple[int, int]: (index of the line inside the patch, line number in the new file), or (-1, -1) when
        no line matches. The second element is -1 when the enclosing hunk header could not be parsed.
    """
    new_start = 0  # first new-file line number of the current hunk; None if its header is unparseable
    delta = 0  # number of non-deleted lines seen since the current hunk header
    for index, line in enumerate(patch_lines):
        if line.startswith('@@'):
            delta = 0
            header = _HUNK_HEADER_RE.match(line)
            # Read only the new-file start: the size groups are None for headers such as "@@ -1 +1 @@"
            new_start = int(header.group(3)) if header else None
        elif not line.startswith('-'):
            delta += 1

        # startswith() instead of line[0] so that empty patch lines cannot raise IndexError
        if needle in line and not line.startswith('-'):
            if new_start is None:
                return index, -1
            return index, new_start + delta - 1
    return -1, -1


def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
                                              relevant_file: str,
                                              relevant_line_in_file: str) -> Tuple[int, int]:
    """
    Find the line number and absolute position of a relevant line in a file.

    Args:
        diff_files (List[FilePatchInfo]): A list of FilePatchInfo objects representing the patches of files.
        relevant_file (str): The name of the file where the relevant line is located.
        relevant_line_in_file (str): The content of the relevant line.

    Returns:
        Tuple[int, int]: The index of the matching line inside the patch and its line number in the new version
        of the file. Both are -1 when the line is not found.
    """
    position = -1
    absolute_position = -1

    for file in diff_files:
        if file.filename.strip() != relevant_file:
            continue
        if not file.patch:
            # nothing to search in (missing or empty patch)
            continue
        patch_lines = file.patch.splitlines()

        # try to find the line in the patch using difflib, with some margin of error
        close_matches = difflib.get_close_matches(relevant_line_in_file, patch_lines, n=3, cutoff=0.93)
        if len(close_matches) == 1 and close_matches[0].startswith('+'):
            relevant_line_in_file = close_matches[0]

        found = _locate_line_in_patch(patch_lines, relevant_line_in_file)
        if found[0] == -1 and position == -1 and relevant_line_in_file.startswith('+'):
            # The model might add a '+' to the beginning of the relevant_line_in_file even if originally
            # it's a context line
            found = _locate_line_in_patch(patch_lines, relevant_line_in_file[1:].lstrip())

        if found[0] != -1:
            position, absolute_position = found
    return position, absolute_position

View File

@ -13,11 +13,9 @@ try:
except ImportError:
AZURE_DEVOPS_AVAILABLE = False
from ..algo.pr_processing import clip_tokens
from ..config_loader import get_settings
from ..algo.utils import load_large_diff
from ..algo.utils import load_large_diff, FilePatchInfo, EDIT_TYPE, clip_tokens
from ..algo.language_handler import is_valid_file
from .git_provider import EDIT_TYPE, FilePatchInfo
class AzureDevopsProvider:

View File

@ -8,7 +8,8 @@ from atlassian.bitbucket import Cloud
from starlette_context import context
from ..config_loader import get_settings
from .git_provider import FilePatchInfo, GitProvider
from .git_provider import GitProvider
from ..algo.utils import FilePatchInfo
class BitbucketProvider(GitProvider):

View File

@ -1,6 +1,9 @@
import boto3
import botocore
try: # Allow this module to be imported without requiring boto3
import boto3
import botocore
except ModuleNotFoundError:
boto3 = None
botocore = None
class CodeCommitDifferencesResponse:
"""

View File

@ -4,13 +4,12 @@ from collections import Counter
from typing import List, Optional, Tuple
from urllib.parse import urlparse
from ..algo.language_handler import is_valid_file, language_extension_map
from ..algo.pr_processing import clip_tokens
from ..algo.utils import load_large_diff
from ..config_loader import get_settings
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider, IncrementalPR
from pr_agent.git_providers.codecommit_client import CodeCommitClient
from ..algo.language_handler import is_valid_file, language_extension_map
from ..algo.utils import EDIT_TYPE, FilePatchInfo, load_large_diff
from .git_provider import GitProvider
class PullRequestCCMimic:
"""

View File

@ -1,28 +1,10 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
# enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED)
from enum import Enum
from typing import Optional
class EDIT_TYPE(Enum):
    """Kinds of edit a file patch can be tagged with (see FilePatchInfo.edit_type)."""

    ADDED = 1
    DELETED = 2
    MODIFIED = 3
    RENAMED = 4
@dataclass
class FilePatchInfo:
    """Patch of a single file, as handed to the diff-processing helpers."""

    # presumably the file content before / after the change -- TODO confirm against the providers
    base_file: str
    head_file: str
    # diff text of the file
    patch: str
    filename: str
    # NOTE(review): -1 looks like a "not counted yet" marker -- confirm
    tokens: int = -1
    edit_type: EDIT_TYPE = EDIT_TYPE.MODIFIED
    # defaults to None, so the annotation is Optional rather than plain str
    old_filename: Optional[str] = None
from pr_agent.algo.utils import FilePatchInfo
class GitProvider(ABC):
@ -88,7 +70,7 @@ class GitProvider(ABC):
def get_pr_description(self) -> str:
from pr_agent.config_loader import get_settings
from pr_agent.algo.pr_processing import clip_tokens
from pr_agent.algo.utils import clip_tokens
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
description = self.get_pr_description_full()
if max_tokens:

View File

@ -9,10 +9,9 @@ from github import AppAuthentication, Auth, Github, GithubException, Reaction
from retry import retry
from starlette_context import context
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
from .git_provider import GitProvider, IncrementalPR
from ..algo.language_handler import is_valid_file
from ..algo.utils import load_large_diff
from ..algo.pr_processing import find_line_number_of_relevant_line_in_file, clip_tokens
from ..algo.utils import load_large_diff, clip_tokens, find_line_number_of_relevant_line_in_file, FilePatchInfo
from ..config_loader import get_settings
from ..servers.utils import RateLimitExceeded

View File

@ -7,10 +7,9 @@ import gitlab
from gitlab import GitlabGetError
from ..algo.language_handler import is_valid_file
from ..algo.pr_processing import clip_tokens
from ..algo.utils import load_large_diff
from ..algo.utils import load_large_diff, clip_tokens, EDIT_TYPE, FilePatchInfo
from ..config_loader import get_settings
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
from .git_provider import GitProvider
logger = logging.getLogger()

View File

@ -0,0 +1,79 @@
import itertools
from collections import Counter
from typing import List, Optional
from pr_agent.algo.utils import FilePatchInfo
from pr_agent.git_providers.git_provider import GitProvider
class InMemoryProvider(GitProvider):
    """
    GitProvider backed by a list of FilePatchInfo objects held in memory.

    Only the file-listing and language-counting methods do real work; every other method is an empty
    stub that returns None.
    """

    def __init__(self, head_branch: str, target_branch: str, files: List[FilePatchInfo]):
        self.files = files
        self.head_branch = head_branch
        self.target_branch = target_branch

    def is_supported(self, capability: str) -> bool:
        """Stub: does nothing and returns None."""

    def get_files(self) -> list[FilePatchInfo]:
        """Return the files given at construction time."""
        return self.files

    def get_diff_files(self) -> list[FilePatchInfo]:
        """Return the same list as get_files()."""
        return self.get_files()

    def publish_description(self, pr_title: str, pr_body: str):
        """Stub: does nothing and returns None."""

    def publish_comment(self, pr_comment: str, is_temporary: bool = False):
        """Stub: does nothing and returns None."""

    def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
        """Stub: does nothing and returns None."""

    def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
        """Stub: does nothing and returns None."""

    def publish_inline_comments(self, comments: list[dict]):
        """Stub: does nothing and returns None."""

    def publish_code_suggestions(self, code_suggestions: list) -> bool:
        """Stub: does nothing and returns None."""

    def publish_labels(self, labels):
        """Stub: does nothing and returns None."""

    def get_labels(self):
        """Stub: does nothing and returns None."""

    def remove_initial_comment(self):
        """Stub: does nothing and returns None."""

    def get_languages(self):
        """Return a dict mapping each file's `language` value to the number of files that have it."""
        return dict(Counter(file.language for file in self.files))

    def get_pr_branch(self):
        """Stub: does nothing and returns None."""

    def get_user_id(self):
        """Stub: does nothing and returns None."""

    def get_pr_description_full(self) -> str:
        """Stub: does nothing and returns None."""

    def get_issue_comments(self):
        """Stub: does nothing and returns None."""

    def get_repo_settings(self):
        """Stub: does nothing and returns None."""

    def add_eyes_reaction(self, issue_comment_id: int) -> Optional[int]:
        """Stub: does nothing and returns None."""

    def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
        """Stub: does nothing and returns None."""

    def get_commit_messages(self):
        """Stub: does nothing and returns None."""

View File

@ -6,7 +6,8 @@ from typing import List
from git import Repo
from pr_agent.config_loader import _find_repository_root, get_settings
from pr_agent.git_providers.git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
from pr_agent.git_providers.git_provider import GitProvider
from pr_agent.algo.utils import EDIT_TYPE, FilePatchInfo
class PullRequestMimic:

View File

@ -42,9 +42,7 @@ class PRCodeSuggestions:
"extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(),
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
get_settings().pr_code_suggestions_prompt.system,
self.token_handler = TokenHandler(self.vars, get_settings().pr_code_suggestions_prompt.system,
get_settings().pr_code_suggestions_prompt.user)
async def run(self):

View File

@ -46,12 +46,8 @@ class PRDescription:
self.user_description = self.git_provider.get_user_description()
# Initialize the token handler
self.token_handler = TokenHandler(
self.git_provider.pr,
self.vars,
get_settings().pr_description_prompt.system,
get_settings().pr_description_prompt.user,
)
self.token_handler = TokenHandler(self.vars, get_settings().pr_description_prompt.system,
get_settings().pr_description_prompt.user)
# Initialize patches_diff and prediction attributes
self.patches_diff = None

View File

@ -26,9 +26,7 @@ class PRInformationFromUser:
"diff": "", # empty diff for initial calculation
"commit_messages_str": self.git_provider.get_commit_messages(),
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
get_settings().pr_information_from_user_prompt.system,
self.token_handler = TokenHandler(self.vars, get_settings().pr_information_from_user_prompt.system,
get_settings().pr_information_from_user_prompt.user)
self.patches_diff = None
self.prediction = None

View File

@ -29,9 +29,7 @@ class PRQuestions:
"questions": self.question_str,
"commit_messages_str": self.git_provider.get_commit_messages(),
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
get_settings().pr_questions_prompt.system,
self.token_handler = TokenHandler(self.vars, get_settings().pr_questions_prompt.system,
get_settings().pr_questions_prompt.user)
self.patches_diff = None
self.prediction = None

View File

@ -9,8 +9,7 @@ from jinja2 import Environment, StrictUndefined
from yaml import SafeLoader
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models, \
find_line_number_of_relevant_line_in_file, clip_tokens
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown, try_fix_json, try_fix_yaml, load_yaml
from pr_agent.config_loader import get_settings
@ -66,12 +65,8 @@ class PRReviewer:
"commit_messages_str": self.git_provider.get_commit_messages(),
}
self.token_handler = TokenHandler(
self.git_provider.pr,
self.vars,
get_settings().pr_review_prompt.system,
get_settings().pr_review_prompt.user
)
self.token_handler = TokenHandler(self.vars, get_settings().pr_review_prompt.system,
get_settings().pr_review_prompt.user)
def parse_args(self, args: List[str]) -> None:
"""
@ -217,8 +212,8 @@ class PRReviewer:
markdown_text = convert_to_markdown(data)
user = self.git_provider.get_user_id()
# Add help text if not in CLI mode
if not get_settings().get("CONFIG.CLI_MODE", False):
# Add help text if not in CLI mode
if not get_settings().get("CONFIG.CLI_MODE", False):
markdown_text += "\n### How to use\n"
if user and '[bot]' not in user:
markdown_text += bot_help_text(user)

View File

@ -40,9 +40,7 @@ class PRUpdateChangelog:
"extra_instructions": get_settings().pr_update_changelog.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(),
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
get_settings().pr_update_changelog_prompt.system,
self.token_handler = TokenHandler(self.vars, get_settings().pr_update_changelog_prompt.system,
get_settings().pr_update_changelog_prompt.user)
async def run(self):

View File

@ -35,12 +35,12 @@ dependencies = {file = ["requirements.txt"]}
"Homepage" = "https://github.com/Codium-ai/pr-agent"
[tool.setuptools]
include-package-data = false
include-package-data = true
license-files = ["LICENSE"]
[tool.setuptools.packages.find]
where = ["."]
include = ["pr_agent"]
include = ["pr_agent", "pr_agent.*"]
[project.scripts]
pr-agent = "pr_agent.cli:run"

View File

@ -1,19 +1,19 @@
dynaconf==3.1.12
fastapi==0.99.0
PyGithub==1.59.*
retry==0.9.2
openai==0.27.8
Jinja2==3.1.2
tiktoken==0.4.0
uvicorn==0.22.0
python-gitlab==3.15.0
dynaconf~=3.1.12
fastapi~=0.103.0
PyGithub~=1.59.0
retry~=0.9.2
openai~=0.27.8
Jinja2~=3.1.2
tiktoken~=0.4.0
uvicorn~=0.22.0
python-gitlab~=3.15.0
pytest~=7.4.0
aiohttp~=3.8.4
atlassian-python-api==3.39.0
atlassian-python-api~=3.39.0
GitPython~=3.1.32
PyYAML==6.0
starlette-context==0.3.6
PyYAML~=6.0
starlette-context~=0.3.6
litellm~=0.1.445
boto3~=1.28.25
google-cloud-storage==2.10.0
ujson==5.8.0
google-cloud-storage~=2.10.0
ujson~=5.8.0

View File

@ -1,7 +1,7 @@
import pytest
from pr_agent.git_providers.codecommit_provider import CodeCommitFile
from pr_agent.git_providers.codecommit_provider import CodeCommitProvider
from pr_agent.git_providers.git_provider import EDIT_TYPE
from pr_agent.algo.utils import EDIT_TYPE
class TestCodeCommitFile:

View File

@ -1,8 +1,6 @@
# Generated by CodiumAI
from pr_agent.git_providers.git_provider import FilePatchInfo
from pr_agent.algo.pr_processing import find_line_number_of_relevant_line_in_file
from pr_agent.algo.utils import FilePatchInfo, find_line_number_of_relevant_line_in_file
import pytest