init Azure DevOps git provider

This commit is contained in:
szecsip
2023-08-23 16:01:10 +00:00
parent 52ba2793cd
commit 01d1cf98f4

View File

@ -1,23 +1,25 @@
import json
import logging import logging
from typing import Optional, Tuple from typing import Optional, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
import os import os
import requests
from msrest.authentication import BasicAuthentication from msrest.authentication import BasicAuthentication
from azure.devops.connection import Connection from azure.devops.connection import Connection
from azure.devops.v7_0.git.models import Comment, CommentThread, GitVersionDescriptor, GitPullRequest
from ..algo.pr_processing import clip_tokens from ..algo.pr_processing import clip_tokens
from ..config_loader import get_settings from ..config_loader import get_settings
from .git_provider import FilePatchInfo from ..algo.utils import load_large_diff
from ..algo.language_handler import is_valid_file
from .git_provider import EDIT_TYPE, FilePatchInfo
class AzureDevopsProvider: class AzureDevopsProvider:
def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False): def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False):
self.azure_devops_client = self._get_azure_devops_client() self.azure_devops_client = self._get_azure_devops_client()
logging.info(self.azure_devops_client)
self.workspace_slug = None self.workspace_slug = None
self.repo_slug = None self.repo_slug = None
@ -40,9 +42,10 @@ class AzureDevopsProvider:
def get_repo_settings(self): def get_repo_settings(self):
try: try:
contents = self.azure_devops_client.get_item_content(repository_id=self.repo_slug, project=self.workspace_slug, download=False, include_content_metadata=False, include_content=True, path=".pr_agent.toml") contents = self.azure_devops_client.get_item_content(repository_id=self.repo_slug,
logging.info("get repo settings") project=self.workspace_slug, download=False,
logging.info(contents) include_content_metadata=False, include_content=True,
path=".pr_agent.toml")
return contents return contents
except Exception as e: except Exception as e:
logging.info("get repo settings error") logging.info("get repo settings error")
@ -51,42 +54,121 @@ class AzureDevopsProvider:
def get_files(self): def get_files(self):
files = [] files = []
for i in self.azure_devops_client.get_pull_request_commits(project=self.workspace_slug, repository_id=self.repo_slug, pull_request_id=self.pr_num): for i in self.azure_devops_client.get_pull_request_commits(project=self.workspace_slug,
#logging.info(i) repository_id=self.repo_slug,
changes_obj = self.azure_devops_client.get_changes(project=self.workspace_slug, repository_id=self.repo_slug, commit_id=i.commit_id) pull_request_id=self.pr_num):
#logging.info(changes_obj)
#logging.info("***********") changes_obj = self.azure_devops_client.get_changes(project=self.workspace_slug,
repository_id=self.repo_slug, commit_id=i.commit_id)
for c in changes_obj.changes: for c in changes_obj.changes:
files.append(c['item']['path']) files.append(c['item']['path'])
#logging.info("###########") return list(set(files))
return files
def get_diff_files(self) -> list[FilePatchInfo]: def get_diff_files(self) -> list[FilePatchInfo]:
diffs = self.pr.diffstat() try:
diff_split = ['diff --git%s' % x for x in self.pr.diff().split('diff --git') if x.strip()] base_sha = self.pr.last_merge_target_commit
head_sha = self.pr.last_merge_source_commit
diff_files = []
for index, diff in enumerate(diffs): commits = self.azure_devops_client.get_pull_request_commits(project=self.workspace_slug,
original_file_content_str = self._get_pr_file_content(diff.old.get_data('links')) repository_id=self.repo_slug,
new_file_content_str = self._get_pr_file_content(diff.new.get_data('links')) pull_request_id=self.pr_num)
diff_files.append(FilePatchInfo(original_file_content_str, new_file_content_str,
diff_split[index], diff.new.path)) diff_files = []
return diff_files diffs = []
diff_types = {}
for c in commits:
changes_obj = self.azure_devops_client.get_changes(project=self.workspace_slug,
repository_id=self.repo_slug, commit_id=c.commit_id)
for i in changes_obj.changes:
logging.info(i)
diffs.append(i['item']['path'])
diff_types[i['item']['path']] = i['changeType']
diffs = list(set(diffs))
for file in diffs:
if not is_valid_file(file):
continue
version = GitVersionDescriptor(version=head_sha.commit_id, version_type='commit')
new_file_content_str = self.azure_devops_client.get_item(repository_id=self.repo_slug,
path=file,
project=self.workspace_slug,
version_descriptor=version,
download=False,
include_content=True)
new_file_content_str = new_file_content_str.content
edit_type = EDIT_TYPE.MODIFIED
if diff_types[file] == 'add':
edit_type = EDIT_TYPE.ADDED
elif diff_types[file] == 'delete':
edit_type = EDIT_TYPE.DELETED
elif diff_types[file] == 'rename':
edit_type = EDIT_TYPE.RENAMED
version = GitVersionDescriptor(version=base_sha.commit_id, version_type='commit')
original_file_content_str = self.azure_devops_client.get_item(repository_id=self.repo_slug,
path=file,
project=self.workspace_slug,
version_descriptor=version,
download=False,
include_content=True)
original_file_content_str = original_file_content_str.content
patch = load_large_diff(file, new_file_content_str, original_file_content_str)
diff_files.append(FilePatchInfo(original_file_content_str, new_file_content_str,
patch=patch,
filename=file,
edit_type=edit_type))
self.diff_files = diff_files
return diff_files
except Exception as e:
print(f"Error: {str(e)}")
return []
def publish_comment(self, pr_comment: str, is_temporary: bool = False): def publish_comment(self, pr_comment: str, is_temporary: bool = False):
comment = self.pr.comment(pr_comment) comment = Comment(content=pr_comment)
thread = CommentThread(comments=[comment])
thread_response = self.azure_devops_client.create_thread(comment_thread=thread, project=self.workspace_slug,
repository_id=self.repo_slug,
pull_request_id=self.pr_num)
if is_temporary: if is_temporary:
self.temp_comments.append(comment['id']) self.temp_comments.append({'thread_id': thread_response.id, 'comment_id': comment.id})
def publish_description(self, pr_title: str, pr_body: str):
try:
updated_pr = GitPullRequest()
updated_pr.title = pr_title
updated_pr.description = pr_body
self.azure_devops_client.update_pull_request(project=self.workspace_slug,
repository_id=self.repo_slug,
pull_request_id=self.pr_num,
git_pull_request_to_update=updated_pr)
except Exception as e:
logging.exception(f"Could not update pull request {self.pr_num} description: {e}")
def remove_initial_comment(self): def remove_initial_comment(self):
try: try:
for comment in self.temp_comments: for comment in self.temp_comments:
self.pr.delete(f'comments/{comment}') new_comment_thread = CommentThread(comments=[Comment(content='bumm')])
# self.azure_devops_client.delete_comment(project=self.workspace_slug, repository_id=self.repo_slug, thread_id=comment['thread_id'], comment_id=comment['comment_id'], pull_request_id=self.pr_num)
res = self.azure_devops_client.update_thread(project=self.workspace_slug, repository_id=self.repo_slug,
thread_id=comment['thread_id'],
pull_request_id=self.pr_num,
comment_thread=new_comment_thread)
logging.info(res)
except Exception as e: except Exception as e:
logging.exception(f"Failed to remove temp comments, error: {e}") logging.exception(f"Failed to remove temp comments, error: {e}")
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str): def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
pass raise NotImplementedError("Azure DevOps provider does not support publishing inline comment yet")
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str): def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
raise NotImplementedError("Azure DevOps provider does not support creating inline comments yet") raise NotImplementedError("Azure DevOps provider does not support creating inline comments yet")
@ -99,7 +181,9 @@ class AzureDevopsProvider:
def get_languages(self): def get_languages(self):
languages = [] languages = []
files = self.azure_devops_client.get_items(project=self.workspace_slug, repository_id=self.repo_slug, recursion_level="Full", include_content_metadata=True, include_links=False, download=False) files = self.azure_devops_client.get_items(project=self.workspace_slug, repository_id=self.repo_slug,
recursion_level="Full", include_content_metadata=True,
include_links=False, download=False)
for f in files: for f in files:
if f.git_object_type == 'blob': if f.git_object_type == 'blob':
file_name, file_extension = os.path.splitext(f.path) file_name, file_extension = os.path.splitext(f.path)
@ -113,12 +197,14 @@ class AzureDevopsProvider:
total_extensions = sum(extension_counts.values()) total_extensions = sum(extension_counts.values())
extension_percentages = {ext: (count / total_extensions) * 100 for ext, count in extension_counts.items()} extension_percentages = {ext: (count / total_extensions) * 100 for ext, count in extension_counts.items()}
logging.info(extension_percentages)
return extension_percentages return extension_percentages
def get_pr_branch(self): def get_pr_branch(self):
return self.pr.source_branch pr_info = self.azure_devops_client.get_pull_request_by_id(project=self.workspace_slug,
pull_request_id=self.pr_num)
source_branch = pr_info.source_ref_name.split('/')[-1]
return source_branch
def get_pr_description(self): def get_pr_description(self):
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None) max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
@ -141,13 +227,12 @@ class AzureDevopsProvider:
@staticmethod @staticmethod
def _parse_pr_url(pr_url: str) -> Tuple[str, int]: def _parse_pr_url(pr_url: str) -> Tuple[str, int]:
parsed_url = urlparse(pr_url) parsed_url = urlparse(pr_url)
if 'azure.com' not in parsed_url.netloc: if 'azure.com' not in parsed_url.netloc:
raise ValueError("The provided URL is not a valid Azure DevOps URL") raise ValueError("The provided URL is not a valid Azure DevOps URL")
path_parts = parsed_url.path.strip('/').split('/') path_parts = parsed_url.path.strip('/').split('/')
logging.info(path_parts)
if len(path_parts) < 6 or path_parts[4] != 'pullrequest': if len(path_parts) < 6 or path_parts[4] != 'pullrequest':
raise ValueError("The provided URL does not appear to be a Azure DevOps PR URL") raise ValueError("The provided URL does not appear to be a Azure DevOps PR URL")
@ -176,13 +261,13 @@ class AzureDevopsProvider:
def _get_repo(self): def _get_repo(self):
if self.repo is None: if self.repo is None:
self.repo = self.azure_devops_client.get_repository(project=self.workspace_slug, repository_id=self.repo_slug) self.repo = self.azure_devops_client.get_repository(project=self.workspace_slug,
#logging.info(self.repo) repository_id=self.repo_slug)
return self.repo return self.repo
def _get_pr(self): def _get_pr(self):
logging.info(self.azure_devops_client.get_pull_request_by_id(pull_request_id=self.pr_num, project=self.workspace_slug)) self.pr = self.azure_devops_client.get_pull_request_by_id(pull_request_id=self.pr_num, project=self.workspace_slug)
return self.azure_devops_client.get_pull_request_by_id(pull_request_id=self.pr_num, project=self.workspace_slug) return self.pr
def _get_pr_file_content(self, remote_link: str): def _get_pr_file_content(self, remote_link: str):
return "" return ""