Format files by pre-commit run -a

Signed-off-by: Yu Ishikawa <yu-iskw@users.noreply.github.com>
This commit is contained in:
Yu Ishikawa
2024-10-30 09:56:03 +09:00
parent a3d572fb69
commit 81dea65856
122 changed files with 428 additions and 396 deletions

View File

@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
class BaseAiHandler(ABC):
"""
This class defines the interface for an AI handler to be used by the PR Agents.
This class defines the interface for an AI handler to be used by the PR Agents.
"""
@abstractmethod
@ -23,6 +23,6 @@ class BaseAiHandler(ABC):
model (str): the name of the model to use for the chat completion
system (str): the system message string to use for the chat completion
user (str): the user message string to use for the chat completion
temperature (float): the temperature to use for the chat completion
temperature (float): the temperature to use for the chat completion
"""
pass

View File

@ -1,17 +1,18 @@
try:
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import AzureChatOpenAI, ChatOpenAI
except: # we don't enforce langchain as a dependency, so if it's not installed, just move on
pass
import functools
from openai import APIError, RateLimitError, Timeout
from retry import retry
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger
from openai import APIError, RateLimitError, Timeout
from retry import retry
import functools
OPENAI_RETRIES = 5
@ -73,4 +74,3 @@ class LangChainOpenAIHandler(BaseAiHandler):
raise ValueError(f"OpenAI {e.name} is required") from e
else:
raise e

View File

@ -1,7 +1,8 @@
import os
import requests
import litellm
import openai
import requests
from litellm import acompletion
from tenacity import retry, retry_if_exception_type, stop_after_attempt

View File

@ -1,8 +1,8 @@
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
import openai
from openai.error import APIError, RateLimitError, Timeout, TryAgain
from retry import retry
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger

View File

@ -3,8 +3,8 @@ from __future__ import annotations
import re
import traceback
from pr_agent.config_loader import get_settings
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger
@ -388,4 +388,4 @@ def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, s
if not line.startswith('-'): # currently we don't support /ask line for deleted lines
selected_lines_num += 1
return patch_with_lines_str.rstrip(), selected_lines.rstrip()
return patch_with_lines_str.rstrip(), selected_lines.rstrip()

View File

@ -4,8 +4,6 @@ from typing import Dict
from pr_agent.config_loader import get_settings
def filter_bad_extensions(files):
# Bad Extensions, source: https://github.com/EleutherAI/github-downloader/blob/345e7c4cbb9e0dc8a0615fd995a08bf9d73b3fe6/download_repo_text.py # noqa: E501
bad_extensions = get_settings().bad_extensions.default

View File

@ -5,14 +5,15 @@ from typing import Callable, List, Tuple
from github import RateLimitExceededException
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.file_filter import filter_ignored
from pr_agent.algo.git_patch_processing import (
convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions)
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import get_max_tokens, clip_tokens, ModelType
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
from pr_agent.algo.utils import ModelType, clip_tokens, get_max_tokens
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import GitProvider
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
from pr_agent.log import get_logger
DELETED_FILES_ = "Deleted files:\n"

View File

@ -1,8 +1,9 @@
from jinja2 import Environment, StrictUndefined
from tiktoken import encoding_for_model, get_encoding
from pr_agent.config_loader import get_settings
from threading import Lock
from jinja2 import Environment, StrictUndefined
from tiktoken import encoding_for_model, get_encoding
from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger
@ -85,4 +86,4 @@ class TokenHandler:
Returns:
The number of tokens in the patch string.
"""
return len(self.encoder.encode(patch, disallowed_special=()))
return len(self.encoder.encode(patch, disallowed_special=()))

View File

@ -14,7 +14,6 @@ from datetime import datetime
from enum import Enum
from typing import Any, List, Tuple
import html2text
import requests
import yaml
@ -23,10 +22,11 @@ from starlette_context import context
from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.token_handler import TokenEncoder
from pr_agent.config_loader import get_settings, global_settings
from pr_agent.algo.types import FilePatchInfo
from pr_agent.config_loader import get_settings, global_settings
from pr_agent.log import get_logger
class Range(BaseModel):
line_start: int # should be 0-indexed
line_end: int