mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-03 20:30:41 +08:00
Separate output token threshold to soft and hard instead of implicit hard = soft/2
This commit is contained in:
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import difflib
|
import difflib
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Tuple, Union
|
from typing import Any, Tuple, Union
|
||||||
|
|
||||||
from pr_agent.algo.git_patch_processing import extend_patch, handle_patch_deletions
|
from pr_agent.algo.git_patch_processing import extend_patch, handle_patch_deletions
|
||||||
from pr_agent.algo.language_handler import sort_files_by_main_languages
|
from pr_agent.algo.language_handler import sort_files_by_main_languages
|
||||||
@ -14,7 +14,8 @@ DELETED_FILES_ = "Deleted files:\n"
|
|||||||
|
|
||||||
MORE_MODIFIED_FILES_ = "More modified files:\n"
|
MORE_MODIFIED_FILES_ = "More modified files:\n"
|
||||||
|
|
||||||
OUTPUT_BUFFER_TOKENS = 800
|
OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1000
|
||||||
|
OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
|
||||||
PATCH_EXTRA_LINES = 3
|
PATCH_EXTRA_LINES = 3
|
||||||
|
|
||||||
|
|
||||||
@ -32,11 +33,12 @@ def get_pr_diff(git_provider: Union[GithubProvider, Any], token_handler: TokenHa
|
|||||||
patches_extended, total_tokens = pr_generate_extended_diff(pr_languages, token_handler)
|
patches_extended, total_tokens = pr_generate_extended_diff(pr_languages, token_handler)
|
||||||
|
|
||||||
# if we are under the limit, return the full diff
|
# if we are under the limit, return the full diff
|
||||||
if total_tokens + OUTPUT_BUFFER_TOKENS < token_handler.limit:
|
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < token_handler.limit:
|
||||||
return "\n".join(patches_extended)
|
return "\n".join(patches_extended)
|
||||||
|
|
||||||
# if we are over the limit, start pruning
|
# if we are over the limit, start pruning
|
||||||
patches_compressed, modified_file_names, deleted_file_names = pr_generate_compressed_diff(pr_languages, token_handler)
|
patches_compressed, modified_file_names, deleted_file_names = pr_generate_compressed_diff(pr_languages,
|
||||||
|
token_handler)
|
||||||
final_diff = "\n".join(patches_compressed)
|
final_diff = "\n".join(patches_compressed)
|
||||||
if modified_file_names:
|
if modified_file_names:
|
||||||
modified_list_str = MORE_MODIFIED_FILES_ + "\n".join(modified_file_names)
|
modified_list_str = MORE_MODIFIED_FILES_ + "\n".join(modified_file_names)
|
||||||
@ -115,12 +117,12 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler) ->
|
|||||||
new_patch_tokens = token_handler.count_tokens(patch)
|
new_patch_tokens = token_handler.count_tokens(patch)
|
||||||
|
|
||||||
# Hard Stop, no more tokens
|
# Hard Stop, no more tokens
|
||||||
if total_tokens > token_handler.limit - OUTPUT_BUFFER_TOKENS // 2:
|
if total_tokens > token_handler.limit - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
|
||||||
logging.warning(f"File was fully skipped, no more tokens: {file.filename}.")
|
logging.warning(f"File was fully skipped, no more tokens: {file.filename}.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# If the patch is too large, just show the file name
|
# If the patch is too large, just show the file name
|
||||||
if total_tokens + new_patch_tokens > token_handler.limit - OUTPUT_BUFFER_TOKENS:
|
if total_tokens + new_patch_tokens > token_handler.limit - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
|
||||||
# Current logic is to skip the patch if it's too large
|
# Current logic is to skip the patch if it's too large
|
||||||
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
|
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
|
||||||
# until we meet the requirements
|
# until we meet the requirements
|
||||||
|
@ -38,6 +38,7 @@ async def polling_loop():
|
|||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
await asyncio.sleep(5)
|
||||||
headers = {
|
headers = {
|
||||||
"Accept": "application/vnd.github.v3+json",
|
"Accept": "application/vnd.github.v3+json",
|
||||||
"Authorization": f"Bearer {token}"
|
"Authorization": f"Bearer {token}"
|
||||||
@ -86,10 +87,8 @@ async def polling_loop():
|
|||||||
elif response.status != 304:
|
elif response.status != 304:
|
||||||
print(f"Failed to fetch notifications. Status code: {response.status}")
|
print(f"Failed to fetch notifications. Status code: {response.status}")
|
||||||
|
|
||||||
await asyncio.sleep(5)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Exception during processing of a notification: {e}")
|
logging.error(f"Exception during processing of a notification: {e}")
|
||||||
await asyncio.sleep(5)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
asyncio.run(polling_loop())
|
asyncio.run(polling_loop())
|
||||||
|
Reference in New Issue
Block a user