feat: Extend AzureDevopsProvider with GitProvider and add placeholder methods

This commit is contained in:
Ori Kotek
2024-01-07 14:42:05 +02:00
parent faba5a224a
commit 95a8a28071

View File

@ -10,20 +10,44 @@ AZURE_DEVOPS_AVAILABLE = True
try: try:
from msrest.authentication import BasicAuthentication from msrest.authentication import BasicAuthentication
from azure.devops.connection import Connection from azure.devops.connection import Connection
from azure.devops.v7_1.git.models import Comment, CommentThread, GitVersionDescriptor, GitPullRequest from azure.devops.v7_1.git.models import (
Comment,
CommentThread,
GitVersionDescriptor,
GitPullRequest,
)
except ImportError: except ImportError:
AZURE_DEVOPS_AVAILABLE = False AZURE_DEVOPS_AVAILABLE = False
from ..config_loader import get_settings
from ..algo.utils import load_large_diff, clip_tokens
from ..algo.language_handler import is_valid_file from ..algo.language_handler import is_valid_file
from .git_provider import EDIT_TYPE, FilePatchInfo from ..algo.utils import clip_tokens, load_large_diff
from ..config_loader import get_settings
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
class AzureDevopsProvider: class AzureDevopsProvider(GitProvider):
def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False): def publish_code_suggestions(self, code_suggestions: list) -> bool:
pass
def get_pr_description_full(self) -> str:
pass
def remove_comment(self, comment):
pass
def publish_labels(self, labels):
pass
def get_pr_labels(self):
pass
def __init__(
self, pr_url: Optional[str] = None, incremental: Optional[bool] = False
):
if not AZURE_DEVOPS_AVAILABLE: if not AZURE_DEVOPS_AVAILABLE:
raise ImportError("Azure DevOps provider is not available. Please install the required dependencies.") raise ImportError(
"Azure DevOps provider is not available. Please install the required dependencies."
)
self.azure_devops_client = self._get_azure_devops_client() self.azure_devops_client = self._get_azure_devops_client()
@ -38,8 +62,13 @@ class AzureDevopsProvider:
self.set_pr(pr_url) self.set_pr(pr_url)
def is_supported(self, capability: str) -> bool: def is_supported(self, capability: str) -> bool:
if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments', 'get_labels', if capability in [
'remove_initial_comment', 'gfm_markdown']: "get_issue_comments",
"create_inline_comment",
"publish_inline_comments",
"get_labels",
"remove_initial_comment",
]:
return False return False
return True return True
@ -49,10 +78,14 @@ class AzureDevopsProvider:
def get_repo_settings(self): def get_repo_settings(self):
try: try:
contents = self.azure_devops_client.get_item_content(repository_id=self.repo_slug, contents = self.azure_devops_client.get_item_content(
project=self.workspace_slug, download=False, repository_id=self.repo_slug,
include_content_metadata=False, include_content=True, project=self.workspace_slug,
path=".pr_agent.toml") download=False,
include_content_metadata=False,
include_content=True,
path=".pr_agent.toml",
)
return contents return contents
except Exception as e: except Exception as e:
get_logger().exception("get repo settings error") get_logger().exception("get repo settings error")
@ -60,15 +93,19 @@ class AzureDevopsProvider:
def get_files(self): def get_files(self):
files = [] files = []
for i in self.azure_devops_client.get_pull_request_commits(project=self.workspace_slug, for i in self.azure_devops_client.get_pull_request_commits(
repository_id=self.repo_slug, project=self.workspace_slug,
pull_request_id=self.pr_num): repository_id=self.repo_slug,
pull_request_id=self.pr_num,
changes_obj = self.azure_devops_client.get_changes(project=self.workspace_slug, ):
repository_id=self.repo_slug, commit_id=i.commit_id) changes_obj = self.azure_devops_client.get_changes(
project=self.workspace_slug,
repository_id=self.repo_slug,
commit_id=i.commit_id,
)
for c in changes_obj.changes: for c in changes_obj.changes:
files.append(c['item']['path']) files.append(c["item"]["path"])
return list(set(files)) return list(set(files))
def get_diff_files(self) -> list[FilePatchInfo]: def get_diff_files(self) -> list[FilePatchInfo]:
@ -76,22 +113,27 @@ class AzureDevopsProvider:
base_sha = self.pr.last_merge_target_commit base_sha = self.pr.last_merge_target_commit
head_sha = self.pr.last_merge_source_commit head_sha = self.pr.last_merge_source_commit
commits = self.azure_devops_client.get_pull_request_commits(project=self.workspace_slug, commits = self.azure_devops_client.get_pull_request_commits(
repository_id=self.repo_slug, project=self.workspace_slug,
pull_request_id=self.pr_num) repository_id=self.repo_slug,
pull_request_id=self.pr_num,
)
diff_files = [] diff_files = []
diffs = [] diffs = []
diff_types = {} diff_types = {}
for c in commits: for c in commits:
changes_obj = self.azure_devops_client.get_changes(project=self.workspace_slug, changes_obj = self.azure_devops_client.get_changes(
repository_id=self.repo_slug, commit_id=c.commit_id) project=self.workspace_slug,
repository_id=self.repo_slug,
commit_id=c.commit_id,
)
for i in changes_obj.changes: for i in changes_obj.changes:
if(i['item']['gitObjectType'] == 'tree'): if i["item"]["gitObjectType"] == "tree":
continue continue
diffs.append(i['item']['path']) diffs.append(i["item"]["path"])
diff_types[i['item']['path']] = i['changeType'] diff_types[i["item"]["path"]] = i["changeType"]
diffs = list(set(diffs)) diffs = list(set(diffs))
@ -99,47 +141,72 @@ class AzureDevopsProvider:
if not is_valid_file(file): if not is_valid_file(file):
continue continue
version = GitVersionDescriptor(version=head_sha.commit_id, version_type='commit') version = GitVersionDescriptor(
version=head_sha.commit_id, version_type="commit"
)
try: try:
new_file_content_str = self.azure_devops_client.get_item(repository_id=self.repo_slug, new_file_content_str = self.azure_devops_client.get_item(
path=file, repository_id=self.repo_slug,
project=self.workspace_slug, path=file,
version_descriptor=version, project=self.workspace_slug,
download=False, version_descriptor=version,
include_content=True) download=False,
include_content=True,
)
new_file_content_str = new_file_content_str.content new_file_content_str = new_file_content_str.content
except Exception as error: except Exception as error:
get_logger().error("Failed to retrieve new file content of %s at version %s. Error: %s", file, version, str(error)) get_logger().error(
"Failed to retrieve new file content of %s at version %s. Error: %s",
file,
version,
str(error),
)
new_file_content_str = "" new_file_content_str = ""
edit_type = EDIT_TYPE.MODIFIED edit_type = EDIT_TYPE.MODIFIED
if diff_types[file] == 'add': if diff_types[file] == "add":
edit_type = EDIT_TYPE.ADDED edit_type = EDIT_TYPE.ADDED
elif diff_types[file] == 'delete': elif diff_types[file] == "delete":
edit_type = EDIT_TYPE.DELETED edit_type = EDIT_TYPE.DELETED
elif diff_types[file] == 'rename': elif diff_types[file] == "rename":
edit_type = EDIT_TYPE.RENAMED edit_type = EDIT_TYPE.RENAMED
version = GitVersionDescriptor(version=base_sha.commit_id, version_type='commit') version = GitVersionDescriptor(
version=base_sha.commit_id, version_type="commit"
)
try: try:
original_file_content_str = self.azure_devops_client.get_item(repository_id=self.repo_slug, original_file_content_str = self.azure_devops_client.get_item(
path=file, repository_id=self.repo_slug,
project=self.workspace_slug, path=file,
version_descriptor=version, project=self.workspace_slug,
download=False, version_descriptor=version,
include_content=True) download=False,
include_content=True,
)
original_file_content_str = original_file_content_str.content original_file_content_str = original_file_content_str.content
except Exception as error: except Exception as error:
get_logger().error("Failed to retrieve original file content of %s at version %s. Error: %s", file, version, str(error)) get_logger().error(
"Failed to retrieve original file content of %s at version %s. Error: %s",
file,
version,
str(error),
)
original_file_content_str = "" original_file_content_str = ""
patch = load_large_diff(file, new_file_content_str, original_file_content_str) patch = load_large_diff(
file, new_file_content_str, original_file_content_str
)
diff_files.append(FilePatchInfo(original_file_content_str, new_file_content_str, diff_files.append(
patch=patch, FilePatchInfo(
filename=file, original_file_content_str,
edit_type=edit_type)) new_file_content_str,
patch=patch,
filename=file,
edit_type=edit_type,
)
)
self.diff_files = diff_files self.diff_files = diff_files
return diff_files return diff_files
@ -150,64 +217,88 @@ class AzureDevopsProvider:
def publish_comment(self, pr_comment: str, is_temporary: bool = False): def publish_comment(self, pr_comment: str, is_temporary: bool = False):
comment = Comment(content=pr_comment) comment = Comment(content=pr_comment)
thread = CommentThread(comments=[comment]) thread = CommentThread(comments=[comment])
thread_response = self.azure_devops_client.create_thread(comment_thread=thread, project=self.workspace_slug, thread_response = self.azure_devops_client.create_thread(
repository_id=self.repo_slug, comment_thread=thread,
pull_request_id=self.pr_num) project=self.workspace_slug,
repository_id=self.repo_slug,
pull_request_id=self.pr_num,
)
if is_temporary: if is_temporary:
self.temp_comments.append({'thread_id': thread_response.id, 'comment_id': comment.id}) self.temp_comments.append(
{"thread_id": thread_response.id, "comment_id": comment.id}
)
def publish_description(self, pr_title: str, pr_body: str): def publish_description(self, pr_title: str, pr_body: str):
try: try:
updated_pr = GitPullRequest() updated_pr = GitPullRequest()
updated_pr.title = pr_title updated_pr.title = pr_title
updated_pr.description = pr_body updated_pr.description = pr_body
self.azure_devops_client.update_pull_request(project=self.workspace_slug, self.azure_devops_client.update_pull_request(
repository_id=self.repo_slug, project=self.workspace_slug,
pull_request_id=self.pr_num, repository_id=self.repo_slug,
git_pull_request_to_update=updated_pr) pull_request_id=self.pr_num,
git_pull_request_to_update=updated_pr,
)
except Exception as e: except Exception as e:
get_logger().exception(f"Could not update pull request {self.pr_num} description: {e}") get_logger().exception(
f"Could not update pull request {self.pr_num} description: {e}"
)
def remove_initial_comment(self): def remove_initial_comment(self):
return "" # not implemented yet return "" # not implemented yet
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str): def publish_inline_comment(
raise NotImplementedError("Azure DevOps provider does not support publishing inline comment yet") self, body: str, relevant_file: str, relevant_line_in_file: str
):
raise NotImplementedError(
"Azure DevOps provider does not support publishing inline comment yet"
)
def publish_inline_comments(self, comments: list[dict]): def publish_inline_comments(self, comments: list[dict]):
raise NotImplementedError("Azure DevOps provider does not support publishing inline comments yet") raise NotImplementedError(
"Azure DevOps provider does not support publishing inline comments yet"
)
def get_title(self): def get_title(self):
return self.pr.title return self.pr.title
def get_languages(self): def get_languages(self):
languages = [] languages = []
files = self.azure_devops_client.get_items(project=self.workspace_slug, repository_id=self.repo_slug, files = self.azure_devops_client.get_items(
recursion_level="Full", include_content_metadata=True, project=self.workspace_slug,
include_links=False, download=False) repository_id=self.repo_slug,
recursion_level="Full",
include_content_metadata=True,
include_links=False,
download=False,
)
for f in files: for f in files:
if f.git_object_type == 'blob': if f.git_object_type == "blob":
file_name, file_extension = os.path.splitext(f.path) file_name, file_extension = os.path.splitext(f.path)
languages.append(file_extension[1:]) languages.append(file_extension[1:])
extension_counts = {} extension_counts = {}
for ext in languages: for ext in languages:
if ext != '': if ext != "":
extension_counts[ext] = extension_counts.get(ext, 0) + 1 extension_counts[ext] = extension_counts.get(ext, 0) + 1
total_extensions = sum(extension_counts.values()) total_extensions = sum(extension_counts.values())
extension_percentages = {ext: (count / total_extensions) * 100 for ext, count in extension_counts.items()} extension_percentages = {
ext: (count / total_extensions) * 100
for ext, count in extension_counts.items()
}
return extension_percentages return extension_percentages
def get_pr_branch(self): def get_pr_branch(self):
pr_info = self.azure_devops_client.get_pull_request_by_id(project=self.workspace_slug, pr_info = self.azure_devops_client.get_pull_request_by_id(
pull_request_id=self.pr_num) project=self.workspace_slug, pull_request_id=self.pr_num
source_branch = pr_info.source_ref_name.split('/')[-1] )
source_branch = pr_info.source_ref_name.split("/")[-1]
return source_branch return source_branch
def get_pr_description(self): def get_pr_description(self, full=False):
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None) max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
if max_tokens: if max_tokens:
return clip_tokens(self.pr.description, max_tokens) return clip_tokens(self.pr.description, max_tokens)
@ -217,7 +308,9 @@ class AzureDevopsProvider:
return 0 return 0
def get_issue_comments(self): def get_issue_comments(self):
raise NotImplementedError("Azure DevOps provider does not support issue comments yet") raise NotImplementedError(
"Azure DevOps provider does not support issue comments yet"
)
def add_eyes_reaction(self, issue_comment_id: int) -> Optional[int]: def add_eyes_reaction(self, issue_comment_id: int) -> Optional[int]:
return True return True
@ -226,16 +319,20 @@ class AzureDevopsProvider:
return True return True
def get_issue_comments(self): def get_issue_comments(self):
raise NotImplementedError("Azure DevOps provider does not support issue comments yet") raise NotImplementedError(
"Azure DevOps provider does not support issue comments yet"
)
@staticmethod @staticmethod
def _parse_pr_url(pr_url: str) -> Tuple[str, int]: def _parse_pr_url(pr_url: str) -> Tuple[str, int]:
parsed_url = urlparse(pr_url) parsed_url = urlparse(pr_url)
path_parts = parsed_url.path.strip('/').split('/') path_parts = parsed_url.path.strip("/").split("/")
if len(path_parts) < 6 or path_parts[4] != 'pullrequest': if len(path_parts) < 6 or path_parts[4] != "pullrequest":
raise ValueError("The provided URL does not appear to be a Azure DevOps PR URL") raise ValueError(
"The provided URL does not appear to be a Azure DevOps PR URL"
)
workspace_slug = path_parts[1] workspace_slug = path_parts[1]
repo_slug = path_parts[3] repo_slug = path_parts[3]
@ -251,10 +348,9 @@ class AzureDevopsProvider:
pat = get_settings().azure_devops.pat pat = get_settings().azure_devops.pat
org = get_settings().azure_devops.org org = get_settings().azure_devops.org
except AttributeError as e: except AttributeError as e:
raise ValueError( raise ValueError("Azure DevOps PAT token is required ") from e
"Azure DevOps PAT token is required ") from e
credentials = BasicAuthentication('', pat) credentials = BasicAuthentication("", pat)
azure_devops_connection = Connection(base_url=org, creds=credentials) azure_devops_connection = Connection(base_url=org, creds=credentials)
azure_devops_client = azure_devops_connection.clients.get_git_client() azure_devops_client = azure_devops_connection.clients.get_git_client()
@ -262,13 +358,23 @@ class AzureDevopsProvider:
def _get_repo(self): def _get_repo(self):
if self.repo is None: if self.repo is None:
self.repo = self.azure_devops_client.get_repository(project=self.workspace_slug, self.repo = self.azure_devops_client.get_repository(
repository_id=self.repo_slug) project=self.workspace_slug, repository_id=self.repo_slug
)
return self.repo return self.repo
def _get_pr(self): def _get_pr(self):
self.pr = self.azure_devops_client.get_pull_request_by_id(pull_request_id=self.pr_num, project=self.workspace_slug) self.pr = self.azure_devops_client.get_pull_request_by_id(
pull_request_id=self.pr_num, project=self.workspace_slug
)
return self.pr return self.pr
def get_commit_messages(self): def get_commit_messages(self):
return "" # not implemented yet return "" # not implemented yet
def get_pr_id(self):
try:
pr_id = f"{self.workspace_slug}/{self.repo_slug}/{self.pr_num}"
return pr_id
except Exception:
return ""