Merge pull request #667 from Codium-ai/tr_ado

azure webhook
This commit is contained in:
Tal
2024-02-17 22:01:57 -08:00
committed by GitHub
7 changed files with 235 additions and 26 deletions

View File

@ -369,9 +369,9 @@ def try_fix_yaml(response_text: str, keys_fix_yaml: List[str] = []) -> dict:
pass
# third fallback - try to remove leading and trailing curly brackets
response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}')
response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n')
try:
data = yaml.safe_load(response_text_copy,)
data = yaml.safe_load(response_text_copy)
get_logger().info(f"Successfully parsed AI prediction after removing curly brackets")
return data
except:
@ -382,7 +382,7 @@ def try_fix_yaml(response_text: str, keys_fix_yaml: List[str] = []) -> dict:
for i in range(1, len(response_text_lines)):
response_text_lines_tmp = '\n'.join(response_text_lines[:-i])
try:
data = yaml.safe_load(response_text_lines_tmp,)
data = yaml.safe_load(response_text_lines_tmp)
get_logger().info(f"Successfully parsed AI prediction after removing {i} lines")
return data
except:

View File

@ -10,6 +10,7 @@ from .git_provider import GitProvider
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
AZURE_DEVOPS_AVAILABLE = True
MAX_PR_DESCRIPTION_AZURE_LENGTH = 4000-1
try:
# noinspection PyUnresolvedReferences
@ -38,7 +39,7 @@ class AzureDevopsProvider(GitProvider):
)
self.azure_devops_client = self._get_azure_devops_client()
self.diff_files = None
self.workspace_slug = None
self.repo_slug = None
self.repo = None
@ -124,6 +125,19 @@ class AzureDevopsProvider(GitProvider):
def get_pr_description_full(self) -> str:
return self.pr.description
def edit_comment(self, comment, body: str):
try:
self.azure_devops_client.update_comment(
repository_id=self.repo_slug,
pull_request_id=self.pr_num,
thread_id=comment["thread_id"],
comment_id=comment["comment_id"],
comment=Comment(content=body),
project=self.workspace_slug,
)
except Exception as e:
get_logger().exception(f"Failed to edit comment, error: {e}")
def remove_comment(self, comment):
try:
self.azure_devops_client.delete_comment(
@ -181,7 +195,7 @@ class AzureDevopsProvider(GitProvider):
include_content=True,
path=".pr_agent.toml",
)
return contents
return list(contents)[0]
except Exception as e:
if get_settings().config.verbosity_level >= 2:
get_logger().error(f"Failed to get repo settings, error: {e}")
@ -206,6 +220,10 @@ class AzureDevopsProvider(GitProvider):
def get_diff_files(self) -> list[FilePatchInfo]:
try:
if self.diff_files:
return self.diff_files
base_sha = self.pr.last_merge_target_commit
head_sha = self.pr.last_merge_source_commit
@ -303,7 +321,7 @@ class AzureDevopsProvider(GitProvider):
edit_type=edit_type,
)
)
self.diff_files = diff_files
return diff_files
except Exception as e:
print(f"Error: {str(e)}")
@ -318,12 +336,29 @@ class AzureDevopsProvider(GitProvider):
repository_id=self.repo_slug,
pull_request_id=self.pr_num,
)
response = {"thread_id": thread_response.id, "comment_id": thread_response.comments[0].id}
if is_temporary:
self.temp_comments.append(
{"thread_id": thread_response.id, "comment_id": thread_response.comments[0].id}
)
self.temp_comments.append(response)
return response
def publish_description(self, pr_title: str, pr_body: str):
if len(pr_body) > MAX_PR_DESCRIPTION_AZURE_LENGTH:
usage_guide_text='<details> <summary><strong>✨ Usage guide:</strong></summary><hr>'
ind = pr_body.find(usage_guide_text)
if ind != -1:
pr_body = pr_body[:ind]
if len(pr_body) > MAX_PR_DESCRIPTION_AZURE_LENGTH:
changes_walkthrough_text = '## **Changes walkthrough**'
ind = pr_body.find(changes_walkthrough_text)
if ind != -1:
pr_body = pr_body[:ind]
if len(pr_body) > MAX_PR_DESCRIPTION_AZURE_LENGTH:
trunction_message = " ... (description truncated due to length limit)"
pr_body = pr_body[:MAX_PR_DESCRIPTION_AZURE_LENGTH - len(trunction_message)] + trunction_message
get_logger().warning("PR description was truncated due to length limit")
try:
updated_pr = GitPullRequest()
updated_pr.title = pr_title

View File

@ -0,0 +1,135 @@
# This file contains the code for the Azure DevOps Server webhook server.
# The server listens for incoming webhooks from Azure DevOps Server and forwards them to the PR Agent.
# ADO webhook documentation: https://learn.microsoft.com/en-us/azure/devops/service-hooks/services/webhooks?view=azure-devops
import json
import os
import re
import secrets
import uvicorn
from fastapi import APIRouter, Depends, FastAPI, HTTPException
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.encoders import jsonable_encoder
from starlette import status
from starlette.background import BackgroundTasks
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette_context.middleware import RawContextMiddleware
from pr_agent.agent.pr_agent import PRAgent, command2class
from pr_agent.algo.utils import update_settings_from_args
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.utils import apply_repo_settings
from pr_agent.log import get_logger
from fastapi import Request, Depends
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from pr_agent.log import get_logger
security = HTTPBasic()
router = APIRouter()
available_commands_rgx = re.compile(r"^\/(" + "|".join(command2class.keys()) + r")\s*")
azure_devops_server = get_settings().get("azure_devops_server")
WEBHOOK_USERNAME = azure_devops_server.get("webhook_username")
WEBHOOK_PASSWORD = azure_devops_server.get("webhook_password")
def handle_request(
background_tasks: BackgroundTasks, url: str, body: str, log_context: dict
):
log_context["action"] = body
log_context["api_url"] = url
with get_logger().contextualize(**log_context):
background_tasks.add_task(PRAgent().handle_request, url, body)
# currently only basic auth is supported with azure webhooks
# for this reason, https must be enabled to ensure the credentials are not sent in clear text
def authorize(credentials: HTTPBasicCredentials = Depends(security)):
is_user_ok = secrets.compare_digest(credentials.username, WEBHOOK_USERNAME)
is_pass_ok = secrets.compare_digest(credentials.password, WEBHOOK_PASSWORD)
if not (is_user_ok and is_pass_ok):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='Incorrect username or password.',
headers={'WWW-Authenticate': 'Basic'},
)
async def _perform_commands(commands_conf: str, agent: PRAgent, body: dict, api_url: str, log_context: dict):
apply_repo_settings(api_url)
commands = get_settings().get(f"azure_devops_server.{commands_conf}")
for command in commands:
split_command = command.split(" ")
command = split_command[0]
args = split_command[1:]
other_args = update_settings_from_args(args)
new_command = ' '.join([command] + other_args)
if body:
get_logger().info(body)
get_logger().info(f"Performing command: {new_command}")
with get_logger().contextualize(**log_context):
await agent.handle_request(api_url, new_command)
@router.post("/", dependencies=[Depends(authorize)])
async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
log_context = {"server_type": "azure_devops_server"}
data = await request.json()
get_logger().info(json.dumps(data))
actions = []
if data["eventType"] == "git.pullrequest.created":
# API V1 (latest)
pr_url = data["resource"]["_links"]["web"]["href"].replace("_apis/git/repositories", "_git")
log_context["event"] = data["eventType"]
log_context["api_url"] = pr_url
await _perform_commands("pr_commands", PRAgent(), {}, pr_url, log_context)
return
elif data["eventType"] == "ms.vss-code.git-pullrequest-comment-event" and "content" in data["resource"]["comment"]:
if available_commands_rgx.match(data["resource"]["comment"]["content"]):
if(data["resourceVersion"] == "2.0"):
repo = data["resource"]["pullRequest"]["repository"]["webUrl"]
pr_url = f'{repo}/pullrequest/{data["resource"]["pullRequest"]["pullRequestId"]}'
actions = [data["resource"]["comment"]["content"]]
else:
# API V1 not supported as it does not contain the PR URL
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=json.dumps({"message": "version 1.0 webhook for Azure Devops PR comment is not supported. please upgrade to version 2.0"})),
else:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=json.dumps({"message": "Unsupported command"}),
)
else:
return JSONResponse(
status_code=status.HTTP_204_NO_CONTENT,
content=json.dumps({"message": "Unsupported event"}),
)
log_context["event"] = data["eventType"]
log_context["api_url"] = pr_url
for action in actions:
try:
handle_request(background_tasks, pr_url, action, log_context)
except Exception as e:
get_logger().error("Azure DevOps Trigger failed. Error:" + str(e))
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=json.dumps({"message": "Internal server error"}),
)
return JSONResponse(
status_code=status.HTTP_202_ACCEPTED, content=jsonable_encoder({"message": "webhook triggerd successfully"})
)
@router.get("/")
async def root():
return {"status": "ok"}
def start():
app = FastAPI(middleware=[Middleware(RawContextMiddleware)])
app.include_router(router)
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "3000")))
if __name__ == "__main__":
start()

View File

@ -76,3 +76,14 @@ base_url = ""
[litellm]
LITELLM_TOKEN = "" # see https://docs.litellm.ai/docs/debugging/hosted_debugging for details and instructions on how to get a token
[azure_devops]
# For Azure devops personal access token
org = ""
pat = ""
[azure_devops_server]
# For Azure devops Server basic auth - configured in the webhook creation
# Optional, uncomment if you want to use Azure devops webhooks. Value assinged when you create the webhook
# webhook_username = "<basic auth user>"
# webhook_password = "<basic auth password>"

View File

@ -109,7 +109,7 @@ class PRDescription:
# final markdown description
full_markdown_description = f"## Title\n\n{pr_title}\n\n___\n{pr_body}"
get_logger().debug(f"full_markdown_description:\n{full_markdown_description}")
# get_logger().debug(f"full_markdown_description:\n{full_markdown_description}")
if get_settings().config.publish_output:
get_logger().info(f"Pushing answer {self.pr_id}")