diff --git a/pr_agent/git_providers/codecommit_provider.py b/pr_agent/git_providers/codecommit_provider.py index 8bd1aedd..5fa8c873 100644 --- a/pr_agent/git_providers/codecommit_provider.py +++ b/pr_agent/git_providers/codecommit_provider.py @@ -233,7 +233,7 @@ class CodeCommitProvider(GitProvider): raise NotImplementedError("CodeCommit provider does not support publishing inline comments yet") def get_title(self): - return self.pr.get("title", "") + return self.pr.title def get_pr_id(self): """ diff --git a/tests/unittest/test_codecommit_client.py b/tests/unittest/test_codecommit_client.py index 0aa1ffa6..a81e4b32 100644 --- a/tests/unittest/test_codecommit_client.py +++ b/tests/unittest/test_codecommit_client.py @@ -110,7 +110,7 @@ class TestCodeCommitProvider: # Mock the response from the AWS client for get_pull_request method api.boto_client.get_pull_request.return_value = { "pullRequest": { - "pullRequestId": "3", + "pullRequestId": "321", "title": "My PR", "description": "My PR description", "pullRequestTargets": [ diff --git a/tests/unittest/test_codecommit_provider.py b/tests/unittest/test_codecommit_provider.py index 9de7c45c..6f187de7 100644 --- a/tests/unittest/test_codecommit_provider.py +++ b/tests/unittest/test_codecommit_provider.py @@ -1,6 +1,8 @@ import pytest +from unittest.mock import patch from pr_agent.git_providers.codecommit_provider import CodeCommitFile from pr_agent.git_providers.codecommit_provider import CodeCommitProvider +from pr_agent.git_providers.codecommit_provider import PullRequestCCMimic from pr_agent.git_providers.git_provider import EDIT_TYPE @@ -25,6 +27,21 @@ class TestCodeCommitFile: class TestCodeCommitProvider: + def test_get_title(self): + # Test that the get_title() function returns the PR title + with patch.object(CodeCommitProvider, "__init__", lambda x, y: None): + provider = CodeCommitProvider(None) + provider.pr = PullRequestCCMimic("My Test PR Title", []) + assert provider.get_title() == "My Test PR Title" + + def test_get_pr_id(self): + # Test that the get_pr_id() function returns the correct ID + with patch.object(CodeCommitProvider, "__init__", lambda x, y: None): + provider = CodeCommitProvider(None) + provider.repo_name = "my_test_repo" + provider.pr_num = 321 + assert provider.get_pr_id() == "my_test_repo/321" + def test_parse_pr_url(self): # Test that the _parse_pr_url() function can extract the repo name and PR number from a CodeCommit URL url = "https://us-east-1.console.aws.amazon.com/codesuite/codecommit/repositories/my_test_repo/pull-requests/321" @@ -169,4 +186,4 @@ class TestCodeCommitProvider: def test_remove_markdown_html(self): input = "## PR Feedback\n
Code feedback:\nfile foo\n\n" expect = "## PR Feedback\nCode feedback:\nfile foo\n\n" - assert CodeCommitProvider._remove_markdown_html(input) == expect \ No newline at end of file + assert CodeCommitProvider._remove_markdown_html(input) == expect