Compare commits

...

9 Commits

23 changed files with 252 additions and 189 deletions

1
MANIFEST.in Normal file
View File

@ -0,0 +1 @@
recursive-include pr_agent/settings/ *.toml

View File

@ -1,19 +1,17 @@
from __future__ import annotations from __future__ import annotations
import difflib
import logging import logging
import re
import traceback import traceback
from typing import Any, Callable, List, Tuple from typing import Callable, List, Tuple
from github import RateLimitExceededException from github import RateLimitExceededException
from pr_agent.algo import MAX_TOKENS from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
from pr_agent.algo.language_handler import sort_files_by_main_languages from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.token_handler import TokenHandler, get_token_encoder from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider from pr_agent.git_providers.git_provider import GitProvider
DELETED_FILES_ = "Deleted files:\n" DELETED_FILES_ = "Deleted files:\n"
@ -247,99 +245,6 @@ def _get_all_deployments(all_models: List[str]) -> List[str]:
return all_deployments return all_deployments
def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
relevant_file: str,
relevant_line_in_file: str) -> Tuple[int, int]:
"""
Find the line number and absolute position of a relevant line in a file.
Args:
diff_files (List[FilePatchInfo]): A list of FilePatchInfo objects representing the patches of files.
relevant_file (str): The name of the file where the relevant line is located.
relevant_line_in_file (str): The content of the relevant line.
Returns:
Tuple[int, int]: A tuple containing the line number and absolute position of the relevant line in the file.
"""
position = -1
absolute_position = -1
re_hunk_header = re.compile(
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
for file in diff_files:
if file.filename.strip() == relevant_file:
patch = file.patch
patch_lines = patch.splitlines()
# try to find the line in the patch using difflib, with some margin of error
matches_difflib: list[str | Any] = difflib.get_close_matches(relevant_line_in_file,
patch_lines, n=3, cutoff=0.93)
if len(matches_difflib) == 1 and matches_difflib[0].startswith('+'):
relevant_line_in_file = matches_difflib[0]
delta = 0
start1, size1, start2, size2 = 0, 0, 0, 0
for i, line in enumerate(patch_lines):
if line.startswith('@@'):
delta = 0
match = re_hunk_header.match(line)
start1, size1, start2, size2 = map(int, match.groups()[:4])
elif not line.startswith('-'):
delta += 1
if relevant_line_in_file in line and line[0] != '-':
position = i
absolute_position = start2 + delta - 1
break
if position == -1 and relevant_line_in_file[0] == '+':
no_plus_line = relevant_line_in_file[1:].lstrip()
for i, line in enumerate(patch_lines):
if line.startswith('@@'):
delta = 0
match = re_hunk_header.match(line)
start1, size1, start2, size2 = map(int, match.groups()[:4])
elif not line.startswith('-'):
delta += 1
if no_plus_line in line and line[0] != '-':
# The model might add a '+' to the beginning of the relevant_line_in_file even if originally
# it's a context line
position = i
absolute_position = start2 + delta - 1
break
return position, absolute_position
def clip_tokens(text: str, max_tokens: int) -> str:
"""
Clip the number of tokens in a string to a maximum number of tokens.
Args:
text (str): The string to clip.
max_tokens (int): The maximum number of tokens allowed in the string.
Returns:
str: The clipped string.
"""
if not text:
return text
try:
encoder = get_token_encoder()
num_input_tokens = len(encoder.encode(text))
if num_input_tokens <= max_tokens:
return text
num_chars = len(text)
chars_per_token = num_chars / num_input_tokens
num_output_chars = int(chars_per_token * max_tokens)
clipped_text = text[:num_output_chars]
return clipped_text
except Exception as e:
logging.warning(f"Failed to clip tokens: {e}")
return text
def get_pr_multi_diffs(git_provider: GitProvider, def get_pr_multi_diffs(git_provider: GitProvider,
token_handler: TokenHandler, token_handler: TokenHandler,
model: str, model: str,

View File

@ -21,7 +21,7 @@ class TokenHandler:
method. method.
""" """
def __init__(self, pr, vars: dict, system, user): def __init__(self, vars: dict, system, user):
""" """
Initializes the TokenHandler object. Initializes the TokenHandler object.
@ -32,9 +32,9 @@ class TokenHandler:
- user: The user string. - user: The user string.
""" """
self.encoder = get_token_encoder() self.encoder = get_token_encoder()
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user) self.prompt_tokens = self._get_system_user_tokens(self.encoder, vars, system, user)
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user): def _get_system_user_tokens(self, encoder, vars: dict, system, user):
""" """
Calculates the number of tokens in the system and user strings. Calculates the number of tokens in the system and user strings.

View File

@ -5,14 +5,24 @@ import json
import logging import logging
import re import re
import textwrap import textwrap
from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Any, List from enum import Enum
from typing import Any, List, Tuple, Optional
import yaml import yaml
from starlette_context import context from starlette_context import context
from pr_agent.algo.token_handler import get_token_encoder
from pr_agent.config_loader import get_settings, global_settings from pr_agent.config_loader import get_settings, global_settings
class EDIT_TYPE(Enum):
ADDED = 1
DELETED = 2
MODIFIED = 3
RENAMED = 4
def get_setting(key: str) -> Any: def get_setting(key: str) -> Any:
try: try:
key = key.upper() key = key.upper()
@ -294,3 +304,108 @@ def try_fix_yaml(review_text: str) -> dict:
except: except:
pass pass
return data return data
def clip_tokens(text: str, max_tokens: int) -> str:
"""
Clip the number of tokens in a string to a maximum number of tokens.
Args:
text (str): The string to clip.
max_tokens (int): The maximum number of tokens allowed in the string.
Returns:
str: The clipped string.
"""
if not text:
return text
try:
encoder = get_token_encoder()
num_input_tokens = len(encoder.encode(text))
if num_input_tokens <= max_tokens:
return text
num_chars = len(text)
chars_per_token = num_chars / num_input_tokens
num_output_chars = int(chars_per_token * max_tokens)
clipped_text = text[:num_output_chars]
return clipped_text
except Exception as e:
logging.warning(f"Failed to clip tokens: {e}")
return text
@dataclass
class FilePatchInfo:
base_file: str
head_file: str
patch: str
filename: str
tokens: int = -1
edit_type: EDIT_TYPE = EDIT_TYPE.MODIFIED
old_filename: str = None
language: Optional[str] = None
def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
relevant_file: str,
relevant_line_in_file: str) -> Tuple[int, int]:
"""
Find the line number and absolute position of a relevant line in a file.
Args:
diff_files (List[FilePatchInfo]): A list of FilePatchInfo objects representing the patches of files.
relevant_file (str): The name of the file where the relevant line is located.
relevant_line_in_file (str): The content of the relevant line.
Returns:
Tuple[int, int]: A tuple containing the line number and absolute position of the relevant line in the file.
"""
position = -1
absolute_position = -1
re_hunk_header = re.compile(
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
for file in diff_files:
if file.filename.strip() == relevant_file:
patch = file.patch
patch_lines = patch.splitlines()
# try to find the line in the patch using difflib, with some margin of error
matches_difflib: list[str | Any] = difflib.get_close_matches(relevant_line_in_file,
patch_lines, n=3, cutoff=0.93)
if len(matches_difflib) == 1 and matches_difflib[0].startswith('+'):
relevant_line_in_file = matches_difflib[0]
delta = 0
start1, size1, start2, size2 = 0, 0, 0, 0
for i, line in enumerate(patch_lines):
if line.startswith('@@'):
delta = 0
match = re_hunk_header.match(line)
start1, size1, start2, size2 = map(int, match.groups()[:4])
elif not line.startswith('-'):
delta += 1
if relevant_line_in_file in line and line[0] != '-':
position = i
absolute_position = start2 + delta - 1
break
if position == -1 and relevant_line_in_file[0] == '+':
no_plus_line = relevant_line_in_file[1:].lstrip()
for i, line in enumerate(patch_lines):
if line.startswith('@@'):
delta = 0
match = re_hunk_header.match(line)
start1, size1, start2, size2 = map(int, match.groups()[:4])
elif not line.startswith('-'):
delta += 1
if no_plus_line in line and line[0] != '-':
# The model might add a '+' to the beginning of the relevant_line_in_file even if originally
# it's a context line
position = i
absolute_position = start2 + delta - 1
break
return position, absolute_position

View File

@ -13,11 +13,9 @@ try:
except ImportError: except ImportError:
AZURE_DEVOPS_AVAILABLE = False AZURE_DEVOPS_AVAILABLE = False
from ..algo.pr_processing import clip_tokens
from ..config_loader import get_settings from ..config_loader import get_settings
from ..algo.utils import load_large_diff from ..algo.utils import load_large_diff, FilePatchInfo, EDIT_TYPE, clip_tokens
from ..algo.language_handler import is_valid_file from ..algo.language_handler import is_valid_file
from .git_provider import EDIT_TYPE, FilePatchInfo
class AzureDevopsProvider: class AzureDevopsProvider:

View File

@ -8,7 +8,8 @@ from atlassian.bitbucket import Cloud
from starlette_context import context from starlette_context import context
from ..config_loader import get_settings from ..config_loader import get_settings
from .git_provider import FilePatchInfo, GitProvider from .git_provider import GitProvider
from ..algo.utils import FilePatchInfo
class BitbucketProvider(GitProvider): class BitbucketProvider(GitProvider):

View File

@ -1,6 +1,9 @@
import boto3 try: # Allow this module to be imported without requiring boto3
import botocore import boto3
import botocore
except ModuleNotFoundError:
boto3 = None
botocore = None
class CodeCommitDifferencesResponse: class CodeCommitDifferencesResponse:
""" """

View File

@ -4,13 +4,12 @@ from collections import Counter
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
from ..algo.language_handler import is_valid_file, language_extension_map
from ..algo.pr_processing import clip_tokens
from ..algo.utils import load_large_diff
from ..config_loader import get_settings
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider, IncrementalPR
from pr_agent.git_providers.codecommit_client import CodeCommitClient from pr_agent.git_providers.codecommit_client import CodeCommitClient
from ..algo.language_handler import is_valid_file, language_extension_map
from ..algo.utils import EDIT_TYPE, FilePatchInfo, load_large_diff
from .git_provider import GitProvider
class PullRequestCCMimic: class PullRequestCCMimic:
""" """

View File

@ -1,28 +1,10 @@
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
# enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED) # enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED)
from enum import Enum
from typing import Optional from typing import Optional
from pr_agent.algo.utils import FilePatchInfo
class EDIT_TYPE(Enum):
ADDED = 1
DELETED = 2
MODIFIED = 3
RENAMED = 4
@dataclass
class FilePatchInfo:
base_file: str
head_file: str
patch: str
filename: str
tokens: int = -1
edit_type: EDIT_TYPE = EDIT_TYPE.MODIFIED
old_filename: str = None
class GitProvider(ABC): class GitProvider(ABC):
@ -88,7 +70,7 @@ class GitProvider(ABC):
def get_pr_description(self) -> str: def get_pr_description(self) -> str:
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
from pr_agent.algo.pr_processing import clip_tokens from pr_agent.algo.utils import clip_tokens
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None) max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
description = self.get_pr_description_full() description = self.get_pr_description_full()
if max_tokens: if max_tokens:

View File

@ -9,10 +9,9 @@ from github import AppAuthentication, Auth, Github, GithubException, Reaction
from retry import retry from retry import retry
from starlette_context import context from starlette_context import context
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR from .git_provider import GitProvider, IncrementalPR
from ..algo.language_handler import is_valid_file from ..algo.language_handler import is_valid_file
from ..algo.utils import load_large_diff from ..algo.utils import load_large_diff, clip_tokens, find_line_number_of_relevant_line_in_file, FilePatchInfo
from ..algo.pr_processing import find_line_number_of_relevant_line_in_file, clip_tokens
from ..config_loader import get_settings from ..config_loader import get_settings
from ..servers.utils import RateLimitExceeded from ..servers.utils import RateLimitExceeded

View File

@ -7,10 +7,9 @@ import gitlab
from gitlab import GitlabGetError from gitlab import GitlabGetError
from ..algo.language_handler import is_valid_file from ..algo.language_handler import is_valid_file
from ..algo.pr_processing import clip_tokens from ..algo.utils import load_large_diff, clip_tokens, EDIT_TYPE, FilePatchInfo
from ..algo.utils import load_large_diff
from ..config_loader import get_settings from ..config_loader import get_settings
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider from .git_provider import GitProvider
logger = logging.getLogger() logger = logging.getLogger()

View File

@ -0,0 +1,79 @@
import itertools
from collections import Counter
from typing import List, Optional
from pr_agent.algo.utils import FilePatchInfo
from pr_agent.git_providers.git_provider import GitProvider
class InMemoryProvider(GitProvider):
def __init__(self, head_branch: str, target_branch: str, files: List[FilePatchInfo]):
self.head_branch = head_branch
self.target_branch = target_branch
self.files = files
def is_supported(self, capability: str) -> bool:
pass
def get_files(self) -> list[FilePatchInfo]:
return self.files
def get_diff_files(self) -> list[FilePatchInfo]:
return self.get_files()
def publish_description(self, pr_title: str, pr_body: str):
pass
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
pass
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
pass
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
pass
def publish_inline_comments(self, comments: list[dict]):
pass
def publish_code_suggestions(self, code_suggestions: list) -> bool:
pass
def publish_labels(self, labels):
pass
def get_labels(self):
pass
def remove_initial_comment(self):
pass
def get_languages(self):
language_count = Counter(file.language for file in self.files)
return dict(language_count)
def get_pr_branch(self):
pass
def get_user_id(self):
pass
def get_pr_description_full(self) -> str:
pass
def get_issue_comments(self):
pass
def get_repo_settings(self):
pass
def add_eyes_reaction(self, issue_comment_id: int) -> Optional[int]:
pass
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
pass
def get_commit_messages(self):
pass

View File

@ -6,7 +6,8 @@ from typing import List
from git import Repo from git import Repo
from pr_agent.config_loader import _find_repository_root, get_settings from pr_agent.config_loader import _find_repository_root, get_settings
from pr_agent.git_providers.git_provider import EDIT_TYPE, FilePatchInfo, GitProvider from pr_agent.git_providers.git_provider import GitProvider
from pr_agent.algo.utils import EDIT_TYPE, FilePatchInfo
class PullRequestMimic: class PullRequestMimic:

View File

@ -42,9 +42,7 @@ class PRCodeSuggestions:
"extra_instructions": get_settings().pr_code_suggestions.extra_instructions, "extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(), "commit_messages_str": self.git_provider.get_commit_messages(),
} }
self.token_handler = TokenHandler(self.git_provider.pr, self.token_handler = TokenHandler(self.vars, get_settings().pr_code_suggestions_prompt.system,
self.vars,
get_settings().pr_code_suggestions_prompt.system,
get_settings().pr_code_suggestions_prompt.user) get_settings().pr_code_suggestions_prompt.user)
async def run(self): async def run(self):

View File

@ -46,12 +46,8 @@ class PRDescription:
self.user_description = self.git_provider.get_user_description() self.user_description = self.git_provider.get_user_description()
# Initialize the token handler # Initialize the token handler
self.token_handler = TokenHandler( self.token_handler = TokenHandler(self.vars, get_settings().pr_description_prompt.system,
self.git_provider.pr, get_settings().pr_description_prompt.user)
self.vars,
get_settings().pr_description_prompt.system,
get_settings().pr_description_prompt.user,
)
# Initialize patches_diff and prediction attributes # Initialize patches_diff and prediction attributes
self.patches_diff = None self.patches_diff = None

View File

@ -26,9 +26,7 @@ class PRInformationFromUser:
"diff": "", # empty diff for initial calculation "diff": "", # empty diff for initial calculation
"commit_messages_str": self.git_provider.get_commit_messages(), "commit_messages_str": self.git_provider.get_commit_messages(),
} }
self.token_handler = TokenHandler(self.git_provider.pr, self.token_handler = TokenHandler(self.vars, get_settings().pr_information_from_user_prompt.system,
self.vars,
get_settings().pr_information_from_user_prompt.system,
get_settings().pr_information_from_user_prompt.user) get_settings().pr_information_from_user_prompt.user)
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None

View File

@ -29,9 +29,7 @@ class PRQuestions:
"questions": self.question_str, "questions": self.question_str,
"commit_messages_str": self.git_provider.get_commit_messages(), "commit_messages_str": self.git_provider.get_commit_messages(),
} }
self.token_handler = TokenHandler(self.git_provider.pr, self.token_handler = TokenHandler(self.vars, get_settings().pr_questions_prompt.system,
self.vars,
get_settings().pr_questions_prompt.system,
get_settings().pr_questions_prompt.user) get_settings().pr_questions_prompt.user)
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None

View File

@ -9,8 +9,7 @@ from jinja2 import Environment, StrictUndefined
from yaml import SafeLoader from yaml import SafeLoader
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models, \ from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
find_line_number_of_relevant_line_in_file, clip_tokens
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown, try_fix_json, try_fix_yaml, load_yaml from pr_agent.algo.utils import convert_to_markdown, try_fix_json, try_fix_yaml, load_yaml
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
@ -66,12 +65,8 @@ class PRReviewer:
"commit_messages_str": self.git_provider.get_commit_messages(), "commit_messages_str": self.git_provider.get_commit_messages(),
} }
self.token_handler = TokenHandler( self.token_handler = TokenHandler(self.vars, get_settings().pr_review_prompt.system,
self.git_provider.pr, get_settings().pr_review_prompt.user)
self.vars,
get_settings().pr_review_prompt.system,
get_settings().pr_review_prompt.user
)
def parse_args(self, args: List[str]) -> None: def parse_args(self, args: List[str]) -> None:
""" """
@ -217,8 +212,8 @@ class PRReviewer:
markdown_text = convert_to_markdown(data) markdown_text = convert_to_markdown(data)
user = self.git_provider.get_user_id() user = self.git_provider.get_user_id()
# Add help text if not in CLI mode # Add help text if not in CLI§ mode
if not get_settings().get("CONFIG.CLI_MODE", False): if not get_settings().get("CONFIG.CLI§_MODE", False):
markdown_text += "\n### How to use\n" markdown_text += "\n### How to use\n"
if user and '[bot]' not in user: if user and '[bot]' not in user:
markdown_text += bot_help_text(user) markdown_text += bot_help_text(user)

View File

@ -40,9 +40,7 @@ class PRUpdateChangelog:
"extra_instructions": get_settings().pr_update_changelog.extra_instructions, "extra_instructions": get_settings().pr_update_changelog.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(), "commit_messages_str": self.git_provider.get_commit_messages(),
} }
self.token_handler = TokenHandler(self.git_provider.pr, self.token_handler = TokenHandler(self.vars, get_settings().pr_update_changelog_prompt.system,
self.vars,
get_settings().pr_update_changelog_prompt.system,
get_settings().pr_update_changelog_prompt.user) get_settings().pr_update_changelog_prompt.user)
async def run(self): async def run(self):

View File

@ -35,12 +35,12 @@ dependencies = {file = ["requirements.txt"]}
"Homepage" = "https://github.com/Codium-ai/pr-agent" "Homepage" = "https://github.com/Codium-ai/pr-agent"
[tool.setuptools] [tool.setuptools]
include-package-data = false include-package-data = true
license-files = ["LICENSE"] license-files = ["LICENSE"]
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
where = ["."] where = ["."]
include = ["pr_agent"] include = ["pr_agent", "pr_agent.*"]
[project.scripts] [project.scripts]
pr-agent = "pr_agent.cli:run" pr-agent = "pr_agent.cli:run"

View File

@ -1,19 +1,19 @@
dynaconf==3.1.12 dynaconf~=3.1.12
fastapi==0.99.0 fastapi~=0.103.0
PyGithub==1.59.* PyGithub~=1.59.0
retry==0.9.2 retry~=0.9.2
openai==0.27.8 openai~=0.27.8
Jinja2==3.1.2 Jinja2~=3.1.2
tiktoken==0.4.0 tiktoken~=0.4.0
uvicorn==0.22.0 uvicorn~=0.22.0
python-gitlab==3.15.0 python-gitlab~=3.15.0
pytest~=7.4.0 pytest~=7.4.0
aiohttp~=3.8.4 aiohttp~=3.8.4
atlassian-python-api==3.39.0 atlassian-python-api~=3.39.0
GitPython~=3.1.32 GitPython~=3.1.32
PyYAML==6.0 PyYAML~=6.0
starlette-context==0.3.6 starlette-context~=0.3.6
litellm~=0.1.445 litellm~=0.1.445
boto3~=1.28.25 boto3~=1.28.25
google-cloud-storage==2.10.0 google-cloud-storage~=2.10.0
ujson==5.8.0 ujson~=5.8.0

View File

@ -1,7 +1,7 @@
import pytest import pytest
from pr_agent.git_providers.codecommit_provider import CodeCommitFile from pr_agent.git_providers.codecommit_provider import CodeCommitFile
from pr_agent.git_providers.codecommit_provider import CodeCommitProvider from pr_agent.git_providers.codecommit_provider import CodeCommitProvider
from pr_agent.git_providers.git_provider import EDIT_TYPE from pr_agent.algo.utils import EDIT_TYPE
class TestCodeCommitFile: class TestCodeCommitFile:

View File

@ -1,8 +1,6 @@
# Generated by CodiumAI # Generated by CodiumAI
from pr_agent.git_providers.git_provider import FilePatchInfo from pr_agent.algo.utils import FilePatchInfo, find_line_number_of_relevant_line_in_file
from pr_agent.algo.pr_processing import find_line_number_of_relevant_line_in_file
import pytest import pytest