mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-06 05:40:38 +08:00
Code adjustment to support calling is library
This commit is contained in:
@ -5,14 +5,24 @@ import json
|
||||
import logging
|
||||
import re
|
||||
import textwrap
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, List
|
||||
from enum import Enum
|
||||
from typing import Any, List, Tuple, Optional
|
||||
|
||||
import yaml
|
||||
from starlette_context import context
|
||||
|
||||
from pr_agent.algo.token_handler import get_token_encoder
|
||||
from pr_agent.config_loader import get_settings, global_settings
|
||||
|
||||
|
||||
class EDIT_TYPE(Enum):
|
||||
ADDED = 1
|
||||
DELETED = 2
|
||||
MODIFIED = 3
|
||||
RENAMED = 4
|
||||
|
||||
def get_setting(key: str) -> Any:
|
||||
try:
|
||||
key = key.upper()
|
||||
@ -294,3 +304,108 @@ def try_fix_yaml(review_text: str) -> dict:
|
||||
except:
|
||||
pass
|
||||
return data
|
||||
|
||||
def clip_tokens(text: str, max_tokens: int) -> str:
|
||||
"""
|
||||
Clip the number of tokens in a string to a maximum number of tokens.
|
||||
|
||||
Args:
|
||||
text (str): The string to clip.
|
||||
max_tokens (int): The maximum number of tokens allowed in the string.
|
||||
|
||||
Returns:
|
||||
str: The clipped string.
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
try:
|
||||
encoder = get_token_encoder()
|
||||
num_input_tokens = len(encoder.encode(text))
|
||||
if num_input_tokens <= max_tokens:
|
||||
return text
|
||||
num_chars = len(text)
|
||||
chars_per_token = num_chars / num_input_tokens
|
||||
num_output_chars = int(chars_per_token * max_tokens)
|
||||
clipped_text = text[:num_output_chars]
|
||||
return clipped_text
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to clip tokens: {e}")
|
||||
return text
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilePatchInfo:
|
||||
base_file: str
|
||||
head_file: str
|
||||
patch: str
|
||||
filename: str
|
||||
tokens: int = -1
|
||||
edit_type: EDIT_TYPE = EDIT_TYPE.MODIFIED
|
||||
old_filename: str = None
|
||||
language: Optional[str] = None
|
||||
|
||||
|
||||
def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
|
||||
relevant_file: str,
|
||||
relevant_line_in_file: str) -> Tuple[int, int]:
|
||||
"""
|
||||
Find the line number and absolute position of a relevant line in a file.
|
||||
|
||||
Args:
|
||||
diff_files (List[FilePatchInfo]): A list of FilePatchInfo objects representing the patches of files.
|
||||
relevant_file (str): The name of the file where the relevant line is located.
|
||||
relevant_line_in_file (str): The content of the relevant line.
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: A tuple containing the line number and absolute position of the relevant line in the file.
|
||||
"""
|
||||
position = -1
|
||||
absolute_position = -1
|
||||
re_hunk_header = re.compile(
|
||||
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
|
||||
|
||||
for file in diff_files:
|
||||
if file.filename.strip() == relevant_file:
|
||||
patch = file.patch
|
||||
patch_lines = patch.splitlines()
|
||||
|
||||
# try to find the line in the patch using difflib, with some margin of error
|
||||
matches_difflib: list[str | Any] = difflib.get_close_matches(relevant_line_in_file,
|
||||
patch_lines, n=3, cutoff=0.93)
|
||||
if len(matches_difflib) == 1 and matches_difflib[0].startswith('+'):
|
||||
relevant_line_in_file = matches_difflib[0]
|
||||
|
||||
delta = 0
|
||||
start1, size1, start2, size2 = 0, 0, 0, 0
|
||||
for i, line in enumerate(patch_lines):
|
||||
if line.startswith('@@'):
|
||||
delta = 0
|
||||
match = re_hunk_header.match(line)
|
||||
start1, size1, start2, size2 = map(int, match.groups()[:4])
|
||||
elif not line.startswith('-'):
|
||||
delta += 1
|
||||
|
||||
if relevant_line_in_file in line and line[0] != '-':
|
||||
position = i
|
||||
absolute_position = start2 + delta - 1
|
||||
break
|
||||
|
||||
if position == -1 and relevant_line_in_file[0] == '+':
|
||||
no_plus_line = relevant_line_in_file[1:].lstrip()
|
||||
for i, line in enumerate(patch_lines):
|
||||
if line.startswith('@@'):
|
||||
delta = 0
|
||||
match = re_hunk_header.match(line)
|
||||
start1, size1, start2, size2 = map(int, match.groups()[:4])
|
||||
elif not line.startswith('-'):
|
||||
delta += 1
|
||||
|
||||
if no_plus_line in line and line[0] != '-':
|
||||
# The model might add a '+' to the beginning of the relevant_line_in_file even if originally
|
||||
# it's a context line
|
||||
position = i
|
||||
absolute_position = start2 + delta - 1
|
||||
break
|
||||
return position, absolute_position
|
||||
|
||||
|
Reference in New Issue
Block a user