Skip to content

Commit 0eb8e02

Browse files
fix(versioning): infer correct default branch (#232)
1 parent 912d749 commit 0eb8e02

File tree

1 file changed

+18
-4
lines changed
  • src/encord_active/lib/versioning

1 file changed

+18
-4
lines changed

src/encord_active/lib/versioning/git.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,22 @@ def __init__(self, path: Path) -> None:
2828
self.repo.index.add(self.repo.untracked_files)
2929
self.repo.index.commit("init")
3030

31+
self._default_branch = _get_default_branch(self.repo)
32+
33+
@property
34+
def _default_head(self):
35+
return self.repo.heads.__getattr__(self._default_branch)
36+
3137
@property
3238
def current_version(self):
3339
return _commit_to_version(self.repo.head.commit)
3440

3541
def is_latest(self, version: Optional[Version] = None) -> bool:
36-
return (version or self.current_version).id == self.repo.heads.main.commit.hexsha
42+
return (version or self.current_version).id == self._default_head.commit.hexsha
3743

3844
@property
3945
def versions(self):
40-
return [_commit_to_version(commit) for commit in self.repo.iter_commits(self.repo.heads.main)]
46+
return [_commit_to_version(commit) for commit in self.repo.iter_commits(self._default_head)]
4147

4248
@property
4349
def has_changes(self):
@@ -52,8 +58,8 @@ def create_version(self, name: str):
5258
return new_version
5359

5460
def jump_to(self, version: Union[Version, Literal["latest"]]):
55-
if version == "latest" or version.id == self.repo.heads.main.commit.hexsha and not self.is_latest():
56-
self.repo.head.reference = self.repo.heads.main # type: ignore
61+
if version == "latest" or version.id == self._default_head.commit.hexsha and not self.is_latest():
62+
self.repo.head.reference = self._default_head # type: ignore
5763
self.discard_changes()
5864
elif self.repo.head.commit.hexsha != version.id:
5965
self.repo.head.reference = self.repo.rev_parse(version.id) # type: ignore
@@ -73,3 +79,11 @@ def unstash(self):
7379

7480
def _commit_to_version(commit: Commit) -> Version:
7581
return Version(name=str(commit.message), id=commit.hexsha)
82+
83+
84+
def _get_default_branch(repo: Repo):
85+
with repo.config_reader() as reader:
86+
possible_heads = [head.name for head in repo.heads]
87+
global_default = str(reader.get_value("init", "defaultBranch")).strip('"')
88+
89+
return global_default if global_default in possible_heads else possible_heads.pop()

0 commit comments

Comments
 (0)