Code adjustment to support calling is library

This commit is contained in:
Ori Kotek
2023-08-30 10:29:51 +03:00
parent 56828f0170
commit d51e7ee5ad
18 changed files with 155 additions and 170 deletions

View File

@ -5,14 +5,24 @@ import json
import logging
import re
import textwrap
from dataclasses import dataclass
from datetime import datetime
from typing import Any, List
from enum import Enum
from typing import Any, List, Tuple, Optional
import yaml
from starlette_context import context
from pr_agent.algo.token_handler import get_token_encoder
from pr_agent.config_loader import get_settings, global_settings
class EDIT_TYPE(Enum):
ADDED = 1
DELETED = 2
MODIFIED = 3
RENAMED = 4
def get_setting(key: str) -> Any:
try:
key = key.upper()
@ -294,3 +304,108 @@ def try_fix_yaml(review_text: str) -> dict:
except:
pass
return data
def clip_tokens(text: str, max_tokens: int) -> str:
"""
Clip the number of tokens in a string to a maximum number of tokens.
Args:
text (str): The string to clip.
max_tokens (int): The maximum number of tokens allowed in the string.
Returns:
str: The clipped string.
"""
if not text:
return text
try:
encoder = get_token_encoder()
num_input_tokens = len(encoder.encode(text))
if num_input_tokens <= max_tokens:
return text
num_chars = len(text)
chars_per_token = num_chars / num_input_tokens
num_output_chars = int(chars_per_token * max_tokens)
clipped_text = text[:num_output_chars]
return clipped_text
except Exception as e:
logging.warning(f"Failed to clip tokens: {e}")
return text
@dataclass
class FilePatchInfo:
base_file: str
head_file: str
patch: str
filename: str
tokens: int = -1
edit_type: EDIT_TYPE = EDIT_TYPE.MODIFIED
old_filename: str = None
language: Optional[str] = None
def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
relevant_file: str,
relevant_line_in_file: str) -> Tuple[int, int]:
"""
Find the line number and absolute position of a relevant line in a file.
Args:
diff_files (List[FilePatchInfo]): A list of FilePatchInfo objects representing the patches of files.
relevant_file (str): The name of the file where the relevant line is located.
relevant_line_in_file (str): The content of the relevant line.
Returns:
Tuple[int, int]: A tuple containing the line number and absolute position of the relevant line in the file.
"""
position = -1
absolute_position = -1
re_hunk_header = re.compile(
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
for file in diff_files:
if file.filename.strip() == relevant_file:
patch = file.patch
patch_lines = patch.splitlines()
# try to find the line in the patch using difflib, with some margin of error
matches_difflib: list[str | Any] = difflib.get_close_matches(relevant_line_in_file,
patch_lines, n=3, cutoff=0.93)
if len(matches_difflib) == 1 and matches_difflib[0].startswith('+'):
relevant_line_in_file = matches_difflib[0]
delta = 0
start1, size1, start2, size2 = 0, 0, 0, 0
for i, line in enumerate(patch_lines):
if line.startswith('@@'):
delta = 0
match = re_hunk_header.match(line)
start1, size1, start2, size2 = map(int, match.groups()[:4])
elif not line.startswith('-'):
delta += 1
if relevant_line_in_file in line and line[0] != '-':
position = i
absolute_position = start2 + delta - 1
break
if position == -1 and relevant_line_in_file[0] == '+':
no_plus_line = relevant_line_in_file[1:].lstrip()
for i, line in enumerate(patch_lines):
if line.startswith('@@'):
delta = 0
match = re_hunk_header.match(line)
start1, size1, start2, size2 = map(int, match.groups()[:4])
elif not line.startswith('-'):
delta += 1
if no_plus_line in line and line[0] != '-':
# The model might add a '+' to the beginning of the relevant_line_in_file even if originally
# it's a context line
position = i
absolute_position = start2 + delta - 1
break
return position, absolute_position