mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-06 22:00:40 +08:00
Code adjustment to support calling is library
This commit is contained in:
@ -1,19 +1,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import difflib
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Any, Callable, List, Tuple
|
from typing import Callable, List, Tuple
|
||||||
|
|
||||||
from github import RateLimitExceededException
|
from github import RateLimitExceededException
|
||||||
|
|
||||||
from pr_agent.algo import MAX_TOKENS
|
from pr_agent.algo import MAX_TOKENS
|
||||||
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
|
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
|
||||||
from pr_agent.algo.language_handler import sort_files_by_main_languages
|
from pr_agent.algo.language_handler import sort_files_by_main_languages
|
||||||
from pr_agent.algo.token_handler import TokenHandler, get_token_encoder
|
from pr_agent.algo.token_handler import TokenHandler
|
||||||
from pr_agent.config_loader import get_settings
|
from pr_agent.config_loader import get_settings
|
||||||
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider
|
from pr_agent.git_providers.git_provider import GitProvider
|
||||||
|
|
||||||
DELETED_FILES_ = "Deleted files:\n"
|
DELETED_FILES_ = "Deleted files:\n"
|
||||||
|
|
||||||
@ -247,99 +245,6 @@ def _get_all_deployments(all_models: List[str]) -> List[str]:
|
|||||||
return all_deployments
|
return all_deployments
|
||||||
|
|
||||||
|
|
||||||
def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
|
|
||||||
relevant_file: str,
|
|
||||||
relevant_line_in_file: str) -> Tuple[int, int]:
|
|
||||||
"""
|
|
||||||
Find the line number and absolute position of a relevant line in a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
diff_files (List[FilePatchInfo]): A list of FilePatchInfo objects representing the patches of files.
|
|
||||||
relevant_file (str): The name of the file where the relevant line is located.
|
|
||||||
relevant_line_in_file (str): The content of the relevant line.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[int, int]: A tuple containing the line number and absolute position of the relevant line in the file.
|
|
||||||
"""
|
|
||||||
position = -1
|
|
||||||
absolute_position = -1
|
|
||||||
re_hunk_header = re.compile(
|
|
||||||
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
|
|
||||||
|
|
||||||
for file in diff_files:
|
|
||||||
if file.filename.strip() == relevant_file:
|
|
||||||
patch = file.patch
|
|
||||||
patch_lines = patch.splitlines()
|
|
||||||
|
|
||||||
# try to find the line in the patch using difflib, with some margin of error
|
|
||||||
matches_difflib: list[str | Any] = difflib.get_close_matches(relevant_line_in_file,
|
|
||||||
patch_lines, n=3, cutoff=0.93)
|
|
||||||
if len(matches_difflib) == 1 and matches_difflib[0].startswith('+'):
|
|
||||||
relevant_line_in_file = matches_difflib[0]
|
|
||||||
|
|
||||||
delta = 0
|
|
||||||
start1, size1, start2, size2 = 0, 0, 0, 0
|
|
||||||
for i, line in enumerate(patch_lines):
|
|
||||||
if line.startswith('@@'):
|
|
||||||
delta = 0
|
|
||||||
match = re_hunk_header.match(line)
|
|
||||||
start1, size1, start2, size2 = map(int, match.groups()[:4])
|
|
||||||
elif not line.startswith('-'):
|
|
||||||
delta += 1
|
|
||||||
|
|
||||||
if relevant_line_in_file in line and line[0] != '-':
|
|
||||||
position = i
|
|
||||||
absolute_position = start2 + delta - 1
|
|
||||||
break
|
|
||||||
|
|
||||||
if position == -1 and relevant_line_in_file[0] == '+':
|
|
||||||
no_plus_line = relevant_line_in_file[1:].lstrip()
|
|
||||||
for i, line in enumerate(patch_lines):
|
|
||||||
if line.startswith('@@'):
|
|
||||||
delta = 0
|
|
||||||
match = re_hunk_header.match(line)
|
|
||||||
start1, size1, start2, size2 = map(int, match.groups()[:4])
|
|
||||||
elif not line.startswith('-'):
|
|
||||||
delta += 1
|
|
||||||
|
|
||||||
if no_plus_line in line and line[0] != '-':
|
|
||||||
# The model might add a '+' to the beginning of the relevant_line_in_file even if originally
|
|
||||||
# it's a context line
|
|
||||||
position = i
|
|
||||||
absolute_position = start2 + delta - 1
|
|
||||||
break
|
|
||||||
return position, absolute_position
|
|
||||||
|
|
||||||
|
|
||||||
def clip_tokens(text: str, max_tokens: int) -> str:
|
|
||||||
"""
|
|
||||||
Clip the number of tokens in a string to a maximum number of tokens.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text (str): The string to clip.
|
|
||||||
max_tokens (int): The maximum number of tokens allowed in the string.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The clipped string.
|
|
||||||
"""
|
|
||||||
if not text:
|
|
||||||
return text
|
|
||||||
|
|
||||||
try:
|
|
||||||
encoder = get_token_encoder()
|
|
||||||
num_input_tokens = len(encoder.encode(text))
|
|
||||||
if num_input_tokens <= max_tokens:
|
|
||||||
return text
|
|
||||||
num_chars = len(text)
|
|
||||||
chars_per_token = num_chars / num_input_tokens
|
|
||||||
num_output_chars = int(chars_per_token * max_tokens)
|
|
||||||
clipped_text = text[:num_output_chars]
|
|
||||||
return clipped_text
|
|
||||||
except Exception as e:
|
|
||||||
logging.warning(f"Failed to clip tokens: {e}")
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def get_pr_multi_diffs(git_provider: GitProvider,
|
def get_pr_multi_diffs(git_provider: GitProvider,
|
||||||
token_handler: TokenHandler,
|
token_handler: TokenHandler,
|
||||||
model: str,
|
model: str,
|
||||||
|
@ -21,7 +21,7 @@ class TokenHandler:
|
|||||||
method.
|
method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, pr, vars: dict, system, user):
|
def __init__(self, vars: dict, system, user):
|
||||||
"""
|
"""
|
||||||
Initializes the TokenHandler object.
|
Initializes the TokenHandler object.
|
||||||
|
|
||||||
@ -32,9 +32,9 @@ class TokenHandler:
|
|||||||
- user: The user string.
|
- user: The user string.
|
||||||
"""
|
"""
|
||||||
self.encoder = get_token_encoder()
|
self.encoder = get_token_encoder()
|
||||||
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
|
self.prompt_tokens = self._get_system_user_tokens(self.encoder, vars, system, user)
|
||||||
|
|
||||||
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
|
def _get_system_user_tokens(self, encoder, vars: dict, system, user):
|
||||||
"""
|
"""
|
||||||
Calculates the number of tokens in the system and user strings.
|
Calculates the number of tokens in the system and user strings.
|
||||||
|
|
||||||
|
@ -5,14 +5,24 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import textwrap
|
import textwrap
|
||||||
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, List
|
from enum import Enum
|
||||||
|
from typing import Any, List, Tuple, Optional
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from starlette_context import context
|
from starlette_context import context
|
||||||
|
|
||||||
|
from pr_agent.algo.token_handler import get_token_encoder
|
||||||
from pr_agent.config_loader import get_settings, global_settings
|
from pr_agent.config_loader import get_settings, global_settings
|
||||||
|
|
||||||
|
|
||||||
|
class EDIT_TYPE(Enum):
|
||||||
|
ADDED = 1
|
||||||
|
DELETED = 2
|
||||||
|
MODIFIED = 3
|
||||||
|
RENAMED = 4
|
||||||
|
|
||||||
def get_setting(key: str) -> Any:
|
def get_setting(key: str) -> Any:
|
||||||
try:
|
try:
|
||||||
key = key.upper()
|
key = key.upper()
|
||||||
@ -294,3 +304,108 @@ def try_fix_yaml(review_text: str) -> dict:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
def clip_tokens(text: str, max_tokens: int) -> str:
|
||||||
|
"""
|
||||||
|
Clip the number of tokens in a string to a maximum number of tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): The string to clip.
|
||||||
|
max_tokens (int): The maximum number of tokens allowed in the string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The clipped string.
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return text
|
||||||
|
|
||||||
|
try:
|
||||||
|
encoder = get_token_encoder()
|
||||||
|
num_input_tokens = len(encoder.encode(text))
|
||||||
|
if num_input_tokens <= max_tokens:
|
||||||
|
return text
|
||||||
|
num_chars = len(text)
|
||||||
|
chars_per_token = num_chars / num_input_tokens
|
||||||
|
num_output_chars = int(chars_per_token * max_tokens)
|
||||||
|
clipped_text = text[:num_output_chars]
|
||||||
|
return clipped_text
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed to clip tokens: {e}")
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FilePatchInfo:
|
||||||
|
base_file: str
|
||||||
|
head_file: str
|
||||||
|
patch: str
|
||||||
|
filename: str
|
||||||
|
tokens: int = -1
|
||||||
|
edit_type: EDIT_TYPE = EDIT_TYPE.MODIFIED
|
||||||
|
old_filename: str = None
|
||||||
|
language: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
|
||||||
|
relevant_file: str,
|
||||||
|
relevant_line_in_file: str) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Find the line number and absolute position of a relevant line in a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
diff_files (List[FilePatchInfo]): A list of FilePatchInfo objects representing the patches of files.
|
||||||
|
relevant_file (str): The name of the file where the relevant line is located.
|
||||||
|
relevant_line_in_file (str): The content of the relevant line.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[int, int]: A tuple containing the line number and absolute position of the relevant line in the file.
|
||||||
|
"""
|
||||||
|
position = -1
|
||||||
|
absolute_position = -1
|
||||||
|
re_hunk_header = re.compile(
|
||||||
|
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
|
||||||
|
|
||||||
|
for file in diff_files:
|
||||||
|
if file.filename.strip() == relevant_file:
|
||||||
|
patch = file.patch
|
||||||
|
patch_lines = patch.splitlines()
|
||||||
|
|
||||||
|
# try to find the line in the patch using difflib, with some margin of error
|
||||||
|
matches_difflib: list[str | Any] = difflib.get_close_matches(relevant_line_in_file,
|
||||||
|
patch_lines, n=3, cutoff=0.93)
|
||||||
|
if len(matches_difflib) == 1 and matches_difflib[0].startswith('+'):
|
||||||
|
relevant_line_in_file = matches_difflib[0]
|
||||||
|
|
||||||
|
delta = 0
|
||||||
|
start1, size1, start2, size2 = 0, 0, 0, 0
|
||||||
|
for i, line in enumerate(patch_lines):
|
||||||
|
if line.startswith('@@'):
|
||||||
|
delta = 0
|
||||||
|
match = re_hunk_header.match(line)
|
||||||
|
start1, size1, start2, size2 = map(int, match.groups()[:4])
|
||||||
|
elif not line.startswith('-'):
|
||||||
|
delta += 1
|
||||||
|
|
||||||
|
if relevant_line_in_file in line and line[0] != '-':
|
||||||
|
position = i
|
||||||
|
absolute_position = start2 + delta - 1
|
||||||
|
break
|
||||||
|
|
||||||
|
if position == -1 and relevant_line_in_file[0] == '+':
|
||||||
|
no_plus_line = relevant_line_in_file[1:].lstrip()
|
||||||
|
for i, line in enumerate(patch_lines):
|
||||||
|
if line.startswith('@@'):
|
||||||
|
delta = 0
|
||||||
|
match = re_hunk_header.match(line)
|
||||||
|
start1, size1, start2, size2 = map(int, match.groups()[:4])
|
||||||
|
elif not line.startswith('-'):
|
||||||
|
delta += 1
|
||||||
|
|
||||||
|
if no_plus_line in line and line[0] != '-':
|
||||||
|
# The model might add a '+' to the beginning of the relevant_line_in_file even if originally
|
||||||
|
# it's a context line
|
||||||
|
position = i
|
||||||
|
absolute_position = start2 + delta - 1
|
||||||
|
break
|
||||||
|
return position, absolute_position
|
||||||
|
|
||||||
|
@ -8,7 +8,8 @@ from atlassian.bitbucket import Cloud
|
|||||||
from starlette_context import context
|
from starlette_context import context
|
||||||
|
|
||||||
from ..config_loader import get_settings
|
from ..config_loader import get_settings
|
||||||
from .git_provider import FilePatchInfo, GitProvider
|
from .git_provider import GitProvider
|
||||||
|
from ..algo.utils import FilePatchInfo
|
||||||
|
|
||||||
|
|
||||||
class BitbucketProvider(GitProvider):
|
class BitbucketProvider(GitProvider):
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
import boto3
|
try: # Allow this module to be imported without requiring boto3
|
||||||
import botocore
|
import boto3
|
||||||
|
import botocore
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
boto3 = None
|
||||||
|
botocore = None
|
||||||
|
|
||||||
class CodeCommitDifferencesResponse:
|
class CodeCommitDifferencesResponse:
|
||||||
"""
|
"""
|
||||||
|
@ -4,13 +4,12 @@ from collections import Counter
|
|||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from ..algo.language_handler import is_valid_file, language_extension_map
|
|
||||||
from ..algo.pr_processing import clip_tokens
|
|
||||||
from ..algo.utils import load_large_diff
|
|
||||||
from ..config_loader import get_settings
|
|
||||||
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider, IncrementalPR
|
|
||||||
from pr_agent.git_providers.codecommit_client import CodeCommitClient
|
from pr_agent.git_providers.codecommit_client import CodeCommitClient
|
||||||
|
|
||||||
|
from ..algo.language_handler import is_valid_file, language_extension_map
|
||||||
|
from ..algo.utils import EDIT_TYPE, FilePatchInfo, load_large_diff
|
||||||
|
from .git_provider import GitProvider
|
||||||
|
|
||||||
|
|
||||||
class PullRequestCCMimic:
|
class PullRequestCCMimic:
|
||||||
"""
|
"""
|
||||||
|
@ -1,27 +1,9 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
# enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED)
|
# enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED)
|
||||||
from enum import Enum
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from pr_agent.algo.utils import FilePatchInfo
|
||||||
class EDIT_TYPE(Enum):
|
|
||||||
ADDED = 1
|
|
||||||
DELETED = 2
|
|
||||||
MODIFIED = 3
|
|
||||||
RENAMED = 4
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FilePatchInfo:
|
|
||||||
base_file: str
|
|
||||||
head_file: str
|
|
||||||
patch: str
|
|
||||||
filename: str
|
|
||||||
tokens: int = -1
|
|
||||||
edit_type: EDIT_TYPE = EDIT_TYPE.MODIFIED
|
|
||||||
old_filename: str = None
|
|
||||||
|
|
||||||
|
|
||||||
class GitProvider(ABC):
|
class GitProvider(ABC):
|
||||||
@ -87,7 +69,7 @@ class GitProvider(ABC):
|
|||||||
|
|
||||||
def get_pr_description(self) -> str:
|
def get_pr_description(self) -> str:
|
||||||
from pr_agent.config_loader import get_settings
|
from pr_agent.config_loader import get_settings
|
||||||
from pr_agent.algo.pr_processing import clip_tokens
|
from pr_agent.algo.utils import clip_tokens
|
||||||
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
|
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
|
||||||
description = self.get_pr_description_full()
|
description = self.get_pr_description_full()
|
||||||
if max_tokens:
|
if max_tokens:
|
||||||
|
@ -9,10 +9,9 @@ from github import AppAuthentication, Auth, Github, GithubException, Reaction
|
|||||||
from retry import retry
|
from retry import retry
|
||||||
from starlette_context import context
|
from starlette_context import context
|
||||||
|
|
||||||
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
|
from .git_provider import GitProvider, IncrementalPR
|
||||||
from ..algo.language_handler import is_valid_file
|
from ..algo.language_handler import is_valid_file
|
||||||
from ..algo.utils import load_large_diff
|
from ..algo.utils import load_large_diff, clip_tokens, find_line_number_of_relevant_line_in_file, FilePatchInfo
|
||||||
from ..algo.pr_processing import find_line_number_of_relevant_line_in_file, clip_tokens
|
|
||||||
from ..config_loader import get_settings
|
from ..config_loader import get_settings
|
||||||
from ..servers.utils import RateLimitExceeded
|
from ..servers.utils import RateLimitExceeded
|
||||||
|
|
||||||
|
@ -7,10 +7,9 @@ import gitlab
|
|||||||
from gitlab import GitlabGetError
|
from gitlab import GitlabGetError
|
||||||
|
|
||||||
from ..algo.language_handler import is_valid_file
|
from ..algo.language_handler import is_valid_file
|
||||||
from ..algo.pr_processing import clip_tokens
|
from ..algo.utils import load_large_diff, clip_tokens, EDIT_TYPE, FilePatchInfo
|
||||||
from ..algo.utils import load_large_diff
|
|
||||||
from ..config_loader import get_settings
|
from ..config_loader import get_settings
|
||||||
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
|
from .git_provider import GitProvider
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
@ -6,7 +6,8 @@ from typing import List
|
|||||||
from git import Repo
|
from git import Repo
|
||||||
|
|
||||||
from pr_agent.config_loader import _find_repository_root, get_settings
|
from pr_agent.config_loader import _find_repository_root, get_settings
|
||||||
from pr_agent.git_providers.git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
|
from pr_agent.git_providers.git_provider import GitProvider
|
||||||
|
from pr_agent.algo.utils import EDIT_TYPE, FilePatchInfo
|
||||||
|
|
||||||
|
|
||||||
class PullRequestMimic:
|
class PullRequestMimic:
|
||||||
|
@ -45,9 +45,7 @@ class PRCodeSuggestions:
|
|||||||
"extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
|
"extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
|
||||||
"commit_messages_str": self.git_provider.get_commit_messages(),
|
"commit_messages_str": self.git_provider.get_commit_messages(),
|
||||||
}
|
}
|
||||||
self.token_handler = TokenHandler(self.git_provider.pr,
|
self.token_handler = TokenHandler(self.vars, get_settings().pr_code_suggestions_prompt.system,
|
||||||
self.vars,
|
|
||||||
get_settings().pr_code_suggestions_prompt.system,
|
|
||||||
get_settings().pr_code_suggestions_prompt.user)
|
get_settings().pr_code_suggestions_prompt.user)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
|
@ -46,12 +46,8 @@ class PRDescription:
|
|||||||
self.user_description = self.git_provider.get_user_description()
|
self.user_description = self.git_provider.get_user_description()
|
||||||
|
|
||||||
# Initialize the token handler
|
# Initialize the token handler
|
||||||
self.token_handler = TokenHandler(
|
self.token_handler = TokenHandler(self.vars, get_settings().pr_description_prompt.system,
|
||||||
self.git_provider.pr,
|
get_settings().pr_description_prompt.user)
|
||||||
self.vars,
|
|
||||||
get_settings().pr_description_prompt.system,
|
|
||||||
get_settings().pr_description_prompt.user,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize patches_diff and prediction attributes
|
# Initialize patches_diff and prediction attributes
|
||||||
self.patches_diff = None
|
self.patches_diff = None
|
||||||
|
@ -26,9 +26,7 @@ class PRInformationFromUser:
|
|||||||
"diff": "", # empty diff for initial calculation
|
"diff": "", # empty diff for initial calculation
|
||||||
"commit_messages_str": self.git_provider.get_commit_messages(),
|
"commit_messages_str": self.git_provider.get_commit_messages(),
|
||||||
}
|
}
|
||||||
self.token_handler = TokenHandler(self.git_provider.pr,
|
self.token_handler = TokenHandler(self.vars, get_settings().pr_information_from_user_prompt.system,
|
||||||
self.vars,
|
|
||||||
get_settings().pr_information_from_user_prompt.system,
|
|
||||||
get_settings().pr_information_from_user_prompt.user)
|
get_settings().pr_information_from_user_prompt.user)
|
||||||
self.patches_diff = None
|
self.patches_diff = None
|
||||||
self.prediction = None
|
self.prediction = None
|
||||||
|
@ -29,9 +29,7 @@ class PRQuestions:
|
|||||||
"questions": self.question_str,
|
"questions": self.question_str,
|
||||||
"commit_messages_str": self.git_provider.get_commit_messages(),
|
"commit_messages_str": self.git_provider.get_commit_messages(),
|
||||||
}
|
}
|
||||||
self.token_handler = TokenHandler(self.git_provider.pr,
|
self.token_handler = TokenHandler(self.vars, get_settings().pr_questions_prompt.system,
|
||||||
self.vars,
|
|
||||||
get_settings().pr_questions_prompt.system,
|
|
||||||
get_settings().pr_questions_prompt.user)
|
get_settings().pr_questions_prompt.user)
|
||||||
self.patches_diff = None
|
self.patches_diff = None
|
||||||
self.prediction = None
|
self.prediction = None
|
||||||
|
@ -9,8 +9,7 @@ from jinja2 import Environment, StrictUndefined
|
|||||||
from yaml import SafeLoader
|
from yaml import SafeLoader
|
||||||
|
|
||||||
from pr_agent.algo.ai_handler import AiHandler
|
from pr_agent.algo.ai_handler import AiHandler
|
||||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models, \
|
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||||
find_line_number_of_relevant_line_in_file, clip_tokens
|
|
||||||
from pr_agent.algo.token_handler import TokenHandler
|
from pr_agent.algo.token_handler import TokenHandler
|
||||||
from pr_agent.algo.utils import convert_to_markdown, try_fix_json, try_fix_yaml, load_yaml
|
from pr_agent.algo.utils import convert_to_markdown, try_fix_json, try_fix_yaml, load_yaml
|
||||||
from pr_agent.config_loader import get_settings
|
from pr_agent.config_loader import get_settings
|
||||||
@ -66,12 +65,8 @@ class PRReviewer:
|
|||||||
"commit_messages_str": self.git_provider.get_commit_messages(),
|
"commit_messages_str": self.git_provider.get_commit_messages(),
|
||||||
}
|
}
|
||||||
|
|
||||||
self.token_handler = TokenHandler(
|
self.token_handler = TokenHandler(self.vars, get_settings().pr_review_prompt.system,
|
||||||
self.git_provider.pr,
|
get_settings().pr_review_prompt.user)
|
||||||
self.vars,
|
|
||||||
get_settings().pr_review_prompt.system,
|
|
||||||
get_settings().pr_review_prompt.user
|
|
||||||
)
|
|
||||||
|
|
||||||
def parse_args(self, args: List[str]) -> None:
|
def parse_args(self, args: List[str]) -> None:
|
||||||
"""
|
"""
|
||||||
@ -217,8 +212,8 @@ class PRReviewer:
|
|||||||
markdown_text = convert_to_markdown(data)
|
markdown_text = convert_to_markdown(data)
|
||||||
user = self.git_provider.get_user_id()
|
user = self.git_provider.get_user_id()
|
||||||
|
|
||||||
# Add help text if not in CLI mode
|
# Add help text if not in CLI§ mode
|
||||||
if not get_settings().get("CONFIG.CLI_MODE", False):
|
if not get_settings().get("CONFIG.CLI§_MODE", False):
|
||||||
markdown_text += "\n### How to use\n"
|
markdown_text += "\n### How to use\n"
|
||||||
if user and '[bot]' not in user:
|
if user and '[bot]' not in user:
|
||||||
markdown_text += bot_help_text(user)
|
markdown_text += bot_help_text(user)
|
||||||
|
@ -40,9 +40,7 @@ class PRUpdateChangelog:
|
|||||||
"extra_instructions": get_settings().pr_update_changelog.extra_instructions,
|
"extra_instructions": get_settings().pr_update_changelog.extra_instructions,
|
||||||
"commit_messages_str": self.git_provider.get_commit_messages(),
|
"commit_messages_str": self.git_provider.get_commit_messages(),
|
||||||
}
|
}
|
||||||
self.token_handler = TokenHandler(self.git_provider.pr,
|
self.token_handler = TokenHandler(self.vars, get_settings().pr_update_changelog_prompt.system,
|
||||||
self.vars,
|
|
||||||
get_settings().pr_update_changelog_prompt.system,
|
|
||||||
get_settings().pr_update_changelog_prompt.user)
|
get_settings().pr_update_changelog_prompt.user)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from pr_agent.git_providers.codecommit_provider import CodeCommitFile
|
from pr_agent.git_providers.codecommit_provider import CodeCommitFile
|
||||||
from pr_agent.git_providers.codecommit_provider import CodeCommitProvider
|
from pr_agent.git_providers.codecommit_provider import CodeCommitProvider
|
||||||
from pr_agent.git_providers.git_provider import EDIT_TYPE
|
from pr_agent.algo.utils import EDIT_TYPE
|
||||||
|
|
||||||
|
|
||||||
class TestCodeCommitFile:
|
class TestCodeCommitFile:
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
|
|
||||||
# Generated by CodiumAI
|
# Generated by CodiumAI
|
||||||
from pr_agent.git_providers.git_provider import FilePatchInfo
|
from pr_agent.algo.utils import FilePatchInfo, find_line_number_of_relevant_line_in_file
|
||||||
from pr_agent.algo.pr_processing import find_line_number_of_relevant_line_in_file
|
|
||||||
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user