Skip to content
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions src/lando/utils/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ def _get_token(repo_owner: str, repo_name: str) -> str | None:
session = AppInstallationAuth(app_auth, repo_owner, repositories=[repo_name])
return asyncio.run(session.get_token())

def get(self, path: str, *args, **kwargs) -> dict:
def get(self, path: str, *args, **kwargs) -> requests.Response:
"""Send a GET request to the GitHub API with given args and kwargs."""
url = f"{self.GITHUB_BASE_URL}/{path}"
return self.session.get(url, *args, **kwargs)

def post(self, path: str, *args, **kwargs) -> dict:
def post(self, path: str, *args, **kwargs) -> requests.Response:
"""Send a POST request to the GitHub API with given args and kwargs."""
url = f"{self.GITHUB_BASE_URL}/{path}"
return self.session.post(url, *args, **kwargs)
Expand All @@ -69,17 +69,30 @@ def post(self, path: str, *args, **kwargs) -> dict:
class GitHubAPIClient:
"""A convenience client that provides various methods to interact with the GitHub API."""

client = None
_api: GitHubAPI
Comment thread
zzzeid marked this conversation as resolved.

def __init__(self, repo: Repo):
self.client = GitHubAPI(repo)
self._api = GitHubAPI(repo)
self.repo = repo
self.repo_base_url = (
f"repos/{self.repo._github_repo_org}/{self.repo.git_repo_name}"
)

def _get(self, path: str, *args, **kwargs) -> dict:
Comment thread
zzzeid marked this conversation as resolved.
result = self.client.get(path, *args, **kwargs)
def _repo_get(self, subpath: str, *args, **kwargs) -> dict | list:
"""Get API endpoint scoped to the repo_base_url.

Parameters:

subpath: str
Relative path without leading `/`.

Return:
dist | list: decoded JSON from the response
"""
return self._get(f"{self.repo_base_url}/{subpath}", *args, **kwargs)

def _get(self, path: str, *args, **kwargs) -> dict | list | str | None:
result = self._api.get(path, *args, **kwargs)
content_type = result.headers["content-type"]
if content_type == "application/json; charset=utf-8":
return result.json()
Expand All @@ -89,16 +102,16 @@ def _get(self, path: str, *args, **kwargs) -> dict:
return result.text

def _post(self, path: str, *args, **kwargs):
result = self.client.post(path, *args, **kwargs)
return result.json()
result = self._api.post(path, *args, **kwargs)
return result
Comment thread
shtrom marked this conversation as resolved.
Outdated

def list_pull_requests(self) -> list:
"""List all pull requests in the repo."""
return self._get(f"{self.repo_base_url}/pulls")
return self._repo_get("pulls")

def get_pull_request(self, pull_number: int) -> dict:
"""Get a specific pull request from the repo."""
return self._get(f"{self.repo_base_url}/pulls/{pull_number}")
return self._repo_get(f"pulls/{pull_number}")

def get_diff(self, pull_number: int) -> str:
"""Fetch a diff, given a pull request number."""
Expand Down