feat: Refactor AzureDevopsProvider class in azuredevops_provider.py

- Reorder class methods and constructor for better readability
- Add error logging for failed operations
- Implement get_pr_description_full method
- Update get_pr_description method to always return full description
- Modify _parse_pr_url method to return workspace_slug, repo_slug, and pr_number
- Make _get_azure_devops_client a static method
- Add error handling in get_pr_id method
This commit is contained in:
Sagi Medina
2024-01-08 09:15:34 +02:00
parent c8bca487e5
commit b776e5069c

View File

@ -3,27 +3,51 @@ from typing import Optional, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
from ..log import get_logger from ..log import get_logger
from ..algo.language_handler import is_valid_file
from ..algo.utils import clip_tokens, load_large_diff
from ..config_loader import get_settings
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
AZURE_DEVOPS_AVAILABLE = True AZURE_DEVOPS_AVAILABLE = True
try: try:
# noinspection PyUnresolvedReferences
from msrest.authentication import BasicAuthentication from msrest.authentication import BasicAuthentication
# noinspection PyUnresolvedReferences
from azure.devops.connection import Connection from azure.devops.connection import Connection
# noinspection PyUnresolvedReferences
from azure.devops.v7_1.git.models import ( from azure.devops.v7_1.git.models import (
Comment, Comment,
CommentThread, CommentThread,
GitVersionDescriptor, GitVersionDescriptor,
GitPullRequest, GitPullRequest,
) )
except ImportError as e: except ImportError:
AZURE_DEVOPS_AVAILABLE = False AZURE_DEVOPS_AVAILABLE = False
from ..algo.language_handler import is_valid_file
from ..algo.utils import clip_tokens, load_large_diff
from ..config_loader import get_settings
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
class AzureDevopsProvider(GitProvider): class AzureDevopsProvider(GitProvider):
def __init__(
self, pr_url: Optional[str] = None, incremental: Optional[bool] = False
):
if not AZURE_DEVOPS_AVAILABLE:
raise ImportError(
"Azure DevOps provider is not available. Please install the required dependencies."
)
self.azure_devops_client = self._get_azure_devops_client()
self.workspace_slug = None
self.repo_slug = None
self.repo = None
self.pr_num = None
self.pr = None
self.temp_comments = []
self.incremental = incremental
if pr_url:
self.set_pr(pr_url)
def publish_code_suggestions(self, code_suggestions: list) -> bool: def publish_code_suggestions(self, code_suggestions: list) -> bool:
""" """
Publishes code suggestions as comments on the PR. Publishes code suggestions as comments on the PR.
@ -97,7 +121,7 @@ class AzureDevopsProvider(GitProvider):
return False return False
def get_pr_description_full(self) -> str: def get_pr_description_full(self) -> str:
pass return self.pr.description
def remove_comment(self, comment): def remove_comment(self, comment):
try: try:
@ -135,26 +159,6 @@ class AzureDevopsProvider(GitProvider):
get_logger().exception(f"Failed to get labels, error: {e}") get_logger().exception(f"Failed to get labels, error: {e}")
return [] return []
def __init__(
self, pr_url: Optional[str] = None, incremental: Optional[bool] = False
):
if not AZURE_DEVOPS_AVAILABLE:
raise ImportError(
"Azure DevOps provider is not available. Please install the required dependencies."
)
self.azure_devops_client = self._get_azure_devops_client()
self.workspace_slug = None
self.repo_slug = None
self.repo = None
self.pr_num = None
self.pr = None
self.temp_comments = []
self.incremental = incremental
if pr_url:
self.set_pr(pr_url)
def is_supported(self, capability: str) -> bool: def is_supported(self, capability: str) -> bool:
if capability in [ if capability in [
"get_issue_comments", "get_issue_comments",
@ -180,7 +184,8 @@ class AzureDevopsProvider(GitProvider):
) )
return contents return contents
except Exception as e: except Exception as e:
get_logger().exception("get repo settings error") if get_settings().config.verbosity_level >= 2:
get_logger().error(f"Failed to get repo settings, error: {e}")
return "" return ""
def get_files(self): def get_files(self):
@ -300,7 +305,6 @@ class AzureDevopsProvider(GitProvider):
) )
) )
self.diff_files = diff_files
return diff_files return diff_files
except Exception as e: except Exception as e:
print(f"Error: {str(e)}") print(f"Error: {str(e)}")
@ -394,7 +398,7 @@ class AzureDevopsProvider(GitProvider):
source_branch = pr_info.source_ref_name.split("/")[-1] source_branch = pr_info.source_ref_name.split("/")[-1]
return source_branch return source_branch
def get_pr_description(self, full=False): def get_pr_description(self, *, full: bool = True) -> str:
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None) max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
if max_tokens: if max_tokens:
return clip_tokens(self.pr.description, max_tokens) return clip_tokens(self.pr.description, max_tokens)
@ -414,13 +418,8 @@ class AzureDevopsProvider(GitProvider):
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool: def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
return True return True
def get_issue_comments(self):
raise NotImplementedError(
"Azure DevOps provider does not support issue comments yet"
)
@staticmethod @staticmethod
def _parse_pr_url(pr_url: str) -> Tuple[str, int]: def _parse_pr_url(pr_url: str) -> Tuple[str, str, int]:
parsed_url = urlparse(pr_url) parsed_url = urlparse(pr_url)
path_parts = parsed_url.path.strip("/").split("/") path_parts = parsed_url.path.strip("/").split("/")
@ -439,7 +438,8 @@ class AzureDevopsProvider(GitProvider):
return workspace_slug, repo_slug, pr_number return workspace_slug, repo_slug, pr_number
def _get_azure_devops_client(self): @staticmethod
def _get_azure_devops_client():
try: try:
pat = get_settings().azure_devops.pat pat = get_settings().azure_devops.pat
org = get_settings().azure_devops.org org = get_settings().azure_devops.org
@ -472,5 +472,7 @@ class AzureDevopsProvider(GitProvider):
try: try:
pr_id = f"{self.workspace_slug}/{self.repo_slug}/{self.pr_num}" pr_id = f"{self.workspace_slug}/{self.repo_slug}/{self.pr_num}"
return pr_id return pr_id
except Exception: except Exception as e:
if get_settings().config.verbosity_level >= 2:
get_logger().error(f"Failed to get pr id, error: {e}")
return "" return ""