diff --git a/drivers/all.go b/drivers/all.go index 8b253a08558..1fd4fdd0a73 100644 --- a/drivers/all.go +++ b/drivers/all.go @@ -25,6 +25,7 @@ import ( _ "github.com/alist-org/alist/v3/drivers/febbox" _ "github.com/alist-org/alist/v3/drivers/ftp" _ "github.com/alist-org/alist/v3/drivers/github" + _ "github.com/alist-org/alist/v3/drivers/github_release" _ "github.com/alist-org/alist/v3/drivers/google_drive" _ "github.com/alist-org/alist/v3/drivers/google_photo" _ "github.com/alist-org/alist/v3/drivers/halalcloud" diff --git a/drivers/github_release/backoff.go b/drivers/github_release/backoff.go new file mode 100644 index 00000000000..224e783be97 --- /dev/null +++ b/drivers/github_release/backoff.go @@ -0,0 +1,45 @@ +package github_release + +import ( + "math/rand" + "time" +) + +const ( + initialRetryInterval = 500 * time.Millisecond + maxInterval = 10 * time.Second + maxElapsedTime = 30 * time.Second + randomizationFactor = 0.5 + multiplier = 1.5 +) + +// Backoff 提供了确定在重试操作之前等待的时间算法 +type Backoff struct { + interval time.Duration + elapsedTime time.Duration +} + +// Pause 返回重试操作之前等待的时间量,如果可以再次尝试则返回 true,否则返回 false,表示操作应该被放弃。 +func (b *Backoff) Pause() (time.Duration, bool) { + if b.interval == 0 { + // first time + b.interval = initialRetryInterval + b.elapsedTime = 0 + } + + // interval from [1 - randomizationFactor, 1 + randomizationFactor) + randomizedInterval := time.Duration((rand.Float64()*(2*randomizationFactor) + (1 - randomizationFactor)) * float64(b.interval)) + b.elapsedTime += randomizedInterval + + if b.elapsedTime > maxElapsedTime { + return 0, false + } + + // 将间隔增加到间隔上限 + b.interval = time.Duration(float64(b.interval) * multiplier) + if b.interval > maxInterval { + b.interval = maxInterval + } + + return randomizedInterval, true +} diff --git a/drivers/github_release/backoff_test.go b/drivers/github_release/backoff_test.go new file mode 100644 index 00000000000..288e57ad181 --- /dev/null +++ b/drivers/github_release/backoff_test.go @@ -0,0 +1,38 @@ +package github_release + +import ( + "testing" + "time" +) + +func TestBackoffMultiple(t *testing.T) { + b := &Backoff{} + for i := 0; i < 10; i++ { + p, ok := b.Pause() + t.Logf("iteration %d pausing for %s", i, p) + if !ok { + t.Logf("hit the pause timeout after %d pauses", i) + return + } + } +} + +func TestBackoffTimeout(t *testing.T) { + var elapsed time.Duration + b := &Backoff{} + for i := 0; i < 40; i++ { + p, ok := b.Pause() + elapsed += p + t.Logf("iteration %d pausing for %s (total %s)", i, p, elapsed) + if !ok { + break + } + } + if _, ok := b.Pause(); ok { + t.Fatalf("did not hit the pause timeout") + } + + if elapsed > maxElapsedTime { + t.Fatalf("waited too long: %s > %s", elapsed, maxElapsedTime) + } +} diff --git a/drivers/github_release/driver.go b/drivers/github_release/driver.go new file mode 100644 index 00000000000..18dd57b54b4 --- /dev/null +++ b/drivers/github_release/driver.go @@ -0,0 +1,238 @@ +package github_release + +import ( + "context" + "fmt" + "net/http" + "strconv" + "strings" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/pkg/errors" + "golang.org/x/sync/errgroup" +) + +// GithubRelease implements a driver for GitHub Release +type GithubRelease struct { + model.Storage + Addition + + api *APIContext + repo repository +} + +// Config returns the driver config +func (d *GithubRelease) Config() driver.Config { + return config +} + +func (d *GithubRelease) GetAddition() driver.Additional { + return &d.Addition +} + +// validate checks if the driver configuration is valid +func (d *GithubRelease) validate() error { + if d.Addition.Token == "" { + return errs.EmptyToken + } + + if d.Addition.MaxReleases < 1 { + return errors.New("max_releases must be greater than 0") + } + + if d.Addition.MaxReleases > 100 { + d.Addition.MaxReleases = 100 + } + + return nil +} + +// Init initializes the driver +func (d *GithubRelease) Init(ctx context.Context) error { + if err := d.validate(); err != nil { + return err + } + + d.api = NewAPIContext(d.Addition.Token, nil) + + repo, err := newRepository(d.Addition.Repo) + if err != nil { + return errors.Wrap(err, "failed to create repository") + } + d.repo = repo + + return nil +} + +// Drop deletes this driver +func (d *GithubRelease) Drop(ctx context.Context) error { + return nil +} + +// listReleases gets all releases +func (d *GithubRelease) listReleases(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + g, ctx := errgroup.WithContext(ctx) + + var releases []model.Obj + var latest model.Obj + + // Get latest release if enabled + if d.Addition.ShowLatest { + g.Go(func() error { + release, err := d.api.GetLatestRelease(ctx, d.repo) + if err != nil { + if err == ErrNoRelease { + // for no release, just return + return nil + } + return errors.Wrap(err, "failed to get latest release") + } + latest = release + return nil + }) + } + + // Get all releases + g.Go(func() error { + r, err := d.api.GetReleases(ctx, d.repo, d.Addition.MaxReleases) + if err != nil { + return errors.Wrap(err, "failed to get releases") + } + releases = r + return nil + }) + + // Wait for all goroutines to complete + if err := g.Wait(); err != nil { + return nil, err + } + + // Add latest release to the top if available + if latest != nil && releases != nil { + releases = append([]model.Obj{latest}, releases...) + } + + return releases, nil +} + +func (d *GithubRelease) listReleaseAssets(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + idStr := dir.GetID() + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + return nil, errors.Wrapf(err, "list release %s failed, id is not a number", idStr) + } + release, err := d.api.GetRelease(ctx, d.repo, id) + if err != nil { + return nil, err + } + return release.Children() +} + +// List returns the objects in the given directory +func (d *GithubRelease) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + // If dir is root, return all releases + if dir.GetPath() == "" { + return d.listReleases(ctx, dir, args) + } + + // Otherwise return release assets + idStr := dir.GetID() + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + return nil, errors.Wrapf(err, "list release %s failed, id is not a number", idStr) + } + + release, err := d.api.GetRelease(ctx, d.repo, id) + if err != nil { + return nil, errors.Wrap(err, "failed to get release") + } + + return release.Children() +} + +// proxyDownload checks if download should be proxied +func (d *GithubRelease) proxyDownload(file model.Obj, args model.LinkArgs) bool { + // Must proxy if configured + if d.Config().MustProxy() || d.GetStorage().WebProxy { + return true + } + + // Check if request path indicates proxy is needed + if req := args.HttpReq; req != nil && req.URL != nil { + proxyPath := fmt.Sprintf("/p%s", d.GetStorage().MountPath) + return strings.HasPrefix(req.URL.Path, proxyPath) + } + + return false +} + +// Link returns the download link for a file +func (d *GithubRelease) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + idStr := file.GetID() + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + return nil, errors.Wrapf(err, "get link of file %s failed, id is not a number", idStr) + } + + asset, err := d.api.GetReleaseAsset(ctx, d.repo, id) + if err != nil { + return nil, errors.Wrap(err, "failed to get release asset") + } + + if d.proxyDownload(file, args) { + header := http.Header{ + "User-Agent": {"Alist/" + conf.VERSION}, + "Accept": {"application/octet-stream"}, + } + d.api.SetAuthHeader(header) + + return &model.Link{ + URL: asset.URL, + Header: header, + }, nil + } + + return &model.Link{ + URL: asset.BrowserDownloadURL, + }, nil +} + +// MakeDir is not supported +func (d *GithubRelease) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) { + return nil, errs.NotSupport +} + +// Move is not supported +func (d *GithubRelease) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + return nil, errs.NotSupport +} + +// Rename is not supported +func (d *GithubRelease) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) { + return nil, errs.NotSupport +} + +// Copy is not supported +func (d *GithubRelease) Copy(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + return nil, errs.NotSupport +} + +// Remove is not supported +func (d *GithubRelease) Remove(ctx context.Context, obj model.Obj) error { + return errs.NotSupport +} + +// Put is not supported +func (d *GithubRelease) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { + return nil, errs.NotSupport +} + +// Other implements custom operations +func (d *GithubRelease) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { + return nil, errs.NotSupport +} + +var _ driver.Driver = (*GithubRelease)(nil) diff --git a/drivers/github_release/github.go b/drivers/github_release/github.go new file mode 100644 index 00000000000..0edbbd8b103 --- /dev/null +++ b/drivers/github_release/github.go @@ -0,0 +1,326 @@ +package github_release + +import ( + "context" + "fmt" + "io" + "net/http" + "strconv" + "time" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" +) + +const ( + GITHUB_API_VERSION = "2022-11-28" + DEFAULT_TIMEOUT = 10 * time.Second +) + +var ErrRateLimitExceeded = errors.New("rate limit exceeded") + +// RateLimit 表示 GitHub API 的速率限制信息 +type RateLimit struct { + Limit uint + Remaining uint + Reset time.Time +} + +// GitHubError 表示 GitHub API 返回的错误信息 +type GitHubError struct { + Message string `json:"message"` + DocumentationURL string `json:"documentation_url"` + StatusCode int +} + +func (e *GitHubError) Error() string { + return fmt.Sprintf("github api error: %s (status: %d)", e.Message, e.StatusCode) +} + +// parseHTTPError 解析 GitHub API 的错误响应 +func parseHTTPError(statusCode int, body []byte) error { + var v GitHubError + err := utils.Json.Unmarshal(body, &v) + if err != nil { + return &GitHubError{ + Message: string(body), + StatusCode: statusCode, + } + } + v.StatusCode = statusCode + return &v +} + +// parseRateLimit 从响应头中解析速率限制信息 +func parseRateLimit(header http.Header) *RateLimit { + limit, _ := strconv.Atoi(header.Get("X-RateLimit-Limit")) + remaining, _ := strconv.Atoi(header.Get("X-RateLimit-Remaining")) + reset, _ := strconv.ParseInt(header.Get("X-RateLimit-Reset"), 10, 64) + + return &RateLimit{ + Limit: uint(limit), + Remaining: uint(remaining), + Reset: time.Unix(reset, 0), + } +} + +// APIContext 表示 GitHub API 的上下文信息 +type APIContext struct { + token string + version string + client *http.Client + defaultTimeout time.Duration + rateLimit *RateLimit +} + +// NewAPIContext 创建一个新的 GitHub API 上下文 +func NewAPIContext(token string, client *http.Client) *APIContext { + ret := APIContext{ + token: token, + version: GITHUB_API_VERSION, + client: client, + defaultTimeout: DEFAULT_TIMEOUT, + } + + if ret.client == nil { + ret.client = &http.Client{ + Timeout: ret.defaultTimeout, + } + } + + return &ret +} + +// sleepWithContext 在指定的时间内等待, 如果 context 被取消则提前返回. +func sleepWithContext(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +// getWithRetry 获取 GitHub API 并重试. +func (a *APIContext) getWithRetry(ctx context.Context, url string) (*http.Response, error) { + backoff := Backoff{} + + for { + if err := ctx.Err(); err != nil { + return nil, err + } + + response, err := a.get(ctx, url) + + // non-2xx code does not cause error + if err != nil { + // 如果错误是速率限制错误, 则直接返回 + if errors.Is(err, ErrRateLimitExceeded) { + return nil, err + } + + // retry when error is not nil + p, retryAgain := backoff.Pause() + if !retryAgain { + return nil, errors.Wrap(err, "request failed") + } + utils.Log.Debugf("query github api error: %s, retry after %s", err, p) + + if err := sleepWithContext(ctx, p); err != nil { + return nil, err + } + continue + } + + // defensive check + if response == nil { + utils.Log.Errorf("query github api error: %s, will not retry", err) + return nil, errors.New("request failed: response is nil") + } + + if response.StatusCode == http.StatusOK { + return response, nil + } + + // We won't return the response to the caller here, but it's still better to read the response.Body completely even if we don't use it. + // see https://pkg.go.dev/net/http#Client.Do + body, err := io.ReadAll(response.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read response body") + } + + if response.StatusCode >= 500 && response.StatusCode <= 599 { + // retry when server error + p, retryAgain := backoff.Pause() + if !retryAgain { + return nil, parseHTTPError(response.StatusCode, body) + } + utils.Log.Debugf("query github api error: status code %d, retry after %s", response.StatusCode, p) + + if err := sleepWithContext(ctx, p); err != nil { + return nil, err + } + continue + } + + return nil, parseHTTPError(response.StatusCode, body) + } +} + +// SetAuthHeader 为请求头添加 GitHub API 所需的认证头. +// 这是一个副作用函数, 会直接修改传入的 header. +func (a *APIContext) SetAuthHeader(header http.Header) { + header.Set("Authorization", fmt.Sprintf("Bearer %s", a.token)) +} + +// get 获取 GitHub API. +func (a *APIContext) get(ctx context.Context, url string) (*http.Response, error) { + request, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + + request.Header.Set("Accept", "application/vnd.github+json") + a.SetAuthHeader(request.Header) + + response, err := a.client.Do(request) + if err != nil { + return nil, err + } + + // 更新速率限制信息 + a.rateLimit = parseRateLimit(response.Header) + + // 如果剩余请求数为 0, 等待到重置时间 + if a.rateLimit.Remaining == 0 { + waitTime := time.Until(a.rateLimit.Reset) + utils.Log.Warnf("rate limit exceeded, will wait for %s", waitTime) + return nil, ErrRateLimitExceeded + } + + return response, nil +} + +// GetReleases 获取仓库信息. +func (a *APIContext) GetReleases(ctx context.Context, repo repository, perPage int) ([]model.Obj, error) { + if perPage < 1 { + perPage = 30 + } + url := fmt.Sprintf("https://api.github.com/repos/%s/releases?per_page=%d", repo.UrlEncode(), perPage) + response, err := a.getWithRetry(ctx, url) + if err != nil { + return nil, err + } + defer response.Body.Close() + + body, err := io.ReadAll(response.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read response body") + } + + releases := []Release{} + err = utils.Json.Unmarshal(body, &releases) + if err != nil { + return nil, errors.Wrap(err, "failed to unmarshal releases") + } + + tree := make([]model.Obj, 0, len(releases)) + for _, release := range releases { + tree = append(tree, &release) + } + return tree, nil +} + +// GetLatestRelease 获取最新 release. +func (a *APIContext) GetLatestRelease(ctx context.Context, repo repository) (model.Obj, error) { + url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo.UrlEncode()) + response, err := a.getWithRetry(ctx, url) + if err != nil { + var githubErr *GitHubError + if errors.As(err, &githubErr) && githubErr.StatusCode == http.StatusNotFound { + return nil, ErrNoRelease + } + return nil, errors.Wrap(err, "get latest release") + } + defer response.Body.Close() + + body, err := io.ReadAll(response.Body) + if err != nil { + return nil, errors.Wrap(err, "read response body") + } + + if response.StatusCode == http.StatusNotFound { + return nil, ErrNoRelease + } + + if response.StatusCode != http.StatusOK { + err := parseHTTPError(response.StatusCode, body) + var githubErr *GitHubError + if errors.As(err, &githubErr) && githubErr.StatusCode == http.StatusNotFound { + return nil, ErrNoRelease + } + return nil, err + } + + var release Release + if err := utils.Json.Unmarshal(body, &release); err != nil { + return nil, errors.Wrap(err, "unmarshal release data") + } + + release.SetLatestFlag(true) + return &release, nil +} + +// GetRelease 获取指定 tag 的 release. +func (a *APIContext) GetRelease(ctx context.Context, repo repository, id int64) (*Release, error) { + url := fmt.Sprintf("https://api.github.com/repos/%s/releases/%d", repo.UrlEncode(), id) + response, err := a.getWithRetry(ctx, url) + if err != nil { + return nil, err + } + defer response.Body.Close() + + body, err := io.ReadAll(response.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read response body") + } + + release := Release{} + err = utils.Json.Unmarshal(body, &release) + if err != nil { + return nil, errors.Wrap(err, "failed to unmarshal release") + } + + return &release, nil +} + +// GetReleaseAsset 获取指定 tag 的 release 的 assets. +func (a *APIContext) GetReleaseAsset(ctx context.Context, repo repository, ID int64) (*Asset, error) { + url := fmt.Sprintf("https://api.github.com/repos/%s/releases/assets/%d", repo.UrlEncode(), ID) + response, err := a.getWithRetry(ctx, url) + if err != nil { + return nil, err + } + defer response.Body.Close() + + body, err := io.ReadAll(response.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read response body") + } + + asset := Asset{} + err = utils.Json.Unmarshal(body, &asset) + if err != nil { + return nil, errors.Wrap(err, "failed to unmarshal asset") + } + + return &asset, nil +} + +var ( + ErrNoRelease = errors.New("no release found") +) diff --git a/drivers/github_release/github_test.go b/drivers/github_release/github_test.go new file mode 100644 index 00000000000..cc4b03dbff6 --- /dev/null +++ b/drivers/github_release/github_test.go @@ -0,0 +1,155 @@ +package github_release + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestParseRateLimit(t *testing.T) { + header := http.Header{} + header.Set("X-RateLimit-Limit", "60") + header.Set("X-RateLimit-Remaining", "59") + header.Set("X-RateLimit-Reset", "1735689600") // 2025-01-01 00:00:00 UTC + + rateLimit := parseRateLimit(header) + + assert.Equal(t, uint(60), rateLimit.Limit) + assert.Equal(t, uint(59), rateLimit.Remaining) + assert.Equal(t, time.Unix(1735689600, 0), rateLimit.Reset) +} + +func TestGitHubError(t *testing.T) { + err := &GitHubError{ + Message: "API rate limit exceeded", + StatusCode: 403, + } + + assert.Equal(t, "github api error: API rate limit exceeded (status: 403)", err.Error()) +} + +func TestNewAPIContext(t *testing.T) { + token := "test-token" + client := &http.Client{} + ctx := NewAPIContext(token, client) + + assert.Equal(t, token, ctx.token) + assert.Equal(t, GITHUB_API_VERSION, ctx.version) + assert.Equal(t, client, ctx.client) + assert.Equal(t, DEFAULT_TIMEOUT, ctx.defaultTimeout) +} + +func TestAPIContext_SetAuthHeader(t *testing.T) { + token := "test-token" + ctx := NewAPIContext(token, nil) + header := http.Header{} + + ctx.SetAuthHeader(header) + assert.Equal(t, "Bearer "+token, header.Get("Authorization")) +} + +func TestAPIContext_GetWithRetry_RateLimit(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-RateLimit-Limit", "60") + w.Header().Set("X-RateLimit-Remaining", "0") + w.Header().Set("X-RateLimit-Reset", "1735689600") + w.WriteHeader(http.StatusForbidden) + w.Write([]byte(`{"message": "API rate limit exceeded"}`)) + })) + defer server.Close() + + ctx := NewAPIContext("test-token", server.Client()) + _, err := ctx.getWithRetry(context.Background(), server.URL) + + assert.ErrorIs(t, err, ErrRateLimitExceeded) +} + +type testRoundTripper struct { + handler http.HandlerFunc +} + +func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // 创建一个响应记录器 + w := httptest.NewRecorder() + // 调用处理函数 + t.handler.ServeHTTP(w, req) + // 将响应记录器转换为响应 + return w.Result(), nil +} + +func TestAPIContext_GetLatestRelease(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 验证请求路径 + assert.Equal(t, "/repos/test-owner/test-repo/releases/latest", r.URL.Path) + + // 验证请求头 + assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization")) + assert.Equal(t, "application/vnd.github+json", r.Header.Get("Accept")) + + // 设置速率限制头部 + w.Header().Set("X-RateLimit-Limit", "60") + w.Header().Set("X-RateLimit-Remaining", "59") + w.Header().Set("X-RateLimit-Reset", "1735689600") + + // 设置响应头和内容 + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ + "id": 1, + "tag_name": "v1.0.0", + "name": "Release 1.0.0", + "published_at": "2025-01-01T00:00:00Z", + "created_at": "2025-01-01T00:00:00Z", + "assets": [] + }`)) + }) + + // 创建一个自定义的 HTTP 客户端 + client := &http.Client{ + Transport: &testRoundTripper{handler: handler}, + } + + ctx := NewAPIContext("test-token", client) + repo := repository{owner: "test-owner", name: "test-repo"} + release, err := ctx.GetLatestRelease(context.Background(), repo) + + if assert.NoError(t, err) { + assert.NotNil(t, release) + assert.Equal(t, "latest(v1.0.0)", release.GetName()) + } +} + +func TestAPIContext_GetLatestRelease_NoRelease(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 验证请求路径 + assert.Equal(t, "/repos/test-owner/test-repo/releases/latest", r.URL.Path) + + // 验证请求头 + assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization")) + assert.Equal(t, "application/vnd.github+json", r.Header.Get("Accept")) + + // 设置速率限制头部 + w.Header().Set("X-RateLimit-Limit", "60") + w.Header().Set("X-RateLimit-Remaining", "59") + w.Header().Set("X-RateLimit-Reset", "1735689600") + + // 返回 404 状态码 + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(`{"message": "Not Found"}`)) + }) + + // 创建一个自定义的 HTTP 客户端 + client := &http.Client{ + Transport: &testRoundTripper{handler: handler}, + } + + ctx := NewAPIContext("test-token", client) + repo := repository{owner: "test-owner", name: "test-repo"} + _, err := ctx.GetLatestRelease(context.Background(), repo) + + assert.ErrorIs(t, err, ErrNoRelease) +} diff --git a/drivers/github_release/meta.go b/drivers/github_release/meta.go new file mode 100644 index 00000000000..00e495a32b1 --- /dev/null +++ b/drivers/github_release/meta.go @@ -0,0 +1,35 @@ +package github_release + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootID + // define other + Repo string `json:"repo" required:"true" default:"AlistGo/alist" help:"Repository name(owner/repo)"` + Token string `json:"token" required:"true" default:"" help:"Github personal access token"` + MaxReleases int `json:"max_releases" required:"true" type:"number" default:"30" help:"Max releases to list"` + ShowLatest bool `json:"show_latest" type:"bool" default:"true" help:"Show latest release on top"` +} + +var config = driver.Config{ + Name: "Github Release", + LocalSort: false, + OnlyLocal: false, + OnlyProxy: false, + NoCache: false, + NoUpload: true, + NeedMs: false, + DefaultRoot: "0", + CheckStatus: false, + Alert: "", + NoOverwriteUpload: false, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &GithubRelease{} + }) +} diff --git a/drivers/github_release/types.go b/drivers/github_release/types.go new file mode 100644 index 00000000000..4751c135104 --- /dev/null +++ b/drivers/github_release/types.go @@ -0,0 +1,262 @@ +package github_release + +import ( + "fmt" + "net/url" + "regexp" + "strings" + "time" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" +) + +type repository struct { + owner string + name string +} + +func newRepository(name string) (repository, error) { + parts := strings.Split(name, "/") + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return repository{}, errors.New("repo name must be in the format of owner/repo") + } + + return repository{ + owner: parts[0], + name: parts[1], + }, nil +} + +func (r *repository) String() string { + return fmt.Sprintf("%s/%s", r.owner, r.name) +} + +func (r *repository) UrlEncode() string { + ownerPart := url.QueryEscape(r.owner) + namePart := url.QueryEscape(r.name) + return fmt.Sprintf("%s/%s", ownerPart, namePart) +} + +type Release struct { + URL string `json:"url"` + HTMLURL string `json:"html_url"` + AssetsURL string `json:"assets_url"` + UploadURL string `json:"upload_url"` + TarballURL string `json:"tarball_url"` + ZipballURL string `json:"zipball_url"` + ID int64 `json:"id"` + NodeID string `json:"node_id"` + TagName string `json:"tag_name"` + TargetCommitish string `json:"target_commitish"` + Name string `json:"name"` + Body string `json:"body"` + Draft bool `json:"draft"` + Prerelease bool `json:"prerelease"` + CreatedAt time.Time `json:"created_at"` + PublishedAt time.Time `json:"published_at"` + Author User `json:"author"` + Assets []Asset `json:"assets"` + BodyHTML string `json:"body_html"` + BodyText string `json:"body_text"` + MentionsCount int `json:"mentions_count"` + DiscussionURL string `json:"discussion_url"` + + latest bool +} + +func (r *Release) UnmarshalJSON(data []byte) error { + type alias Release + aux := struct { + CreatedAt string `json:"created_at"` + PublishedAt string `json:"published_at"` + *alias + }{ + alias: (*alias)(r), + } + + if err := utils.Json.Unmarshal(data, &aux); err != nil { + return errors.Wrap(err, "failed to unmarshal release") + } + + createdAt, err := time.Parse(time.RFC3339, aux.CreatedAt) + if err != nil { + utils.Log.Error("failed to parse created_at in release", "error", err) + createdAt = time.Time{} + } else { + r.CreatedAt = createdAt + } + + publishedAt, err := time.Parse(time.RFC3339, aux.PublishedAt) + if err != nil { + utils.Log.Error("failed to parse published_at in release", "error", err) + publishedAt = time.Time{} + } else { + r.PublishedAt = publishedAt + } + + return nil +} + +func (r *Release) GetSize() int64 { + return 0 +} + +func (r *Release) SetLatestFlag(flag bool) { + r.latest = flag +} + +func (r *Release) GetName() string { + if r.latest { + return fmt.Sprintf("latest(%s)", r.TagName) + } + return r.TagName +} + +func (r *Release) ModTime() time.Time { + return r.PublishedAt +} + +func (r *Release) CreateTime() time.Time { + return r.CreatedAt +} + +func (r *Release) IsDir() bool { + return true +} + +func (r *Release) GetHash() utils.HashInfo { + return utils.HashInfo{} +} + +func (r *Release) GetID() string { + return fmt.Sprintf("%d", r.ID) +} + +func (r *Release) GetPath() string { + return r.TagName +} + +func (r *Release) Children() ([]model.Obj, error) { + return utils.SliceConvert(r.Assets, func(src Asset) (model.Obj, error) { + return &src, nil + }) +} + +type Asset struct { + URL string `json:"url"` + BrowserDownloadURL string `json:"browser_download_url"` + ID int64 `json:"id"` + NodeID string `json:"node_id"` + Name string `json:"name"` + Label string `json:"label"` + State string `json:"state"` + ContentType string `json:"content_type"` + Size int64 `json:"size"` + DownloadCount int64 `json:"download_count"` + CreatedAt *time.Time `json:"created_at"` + UpdatedAt *time.Time `json:"updated_at"` + Uploader *User `json:"uploader"` +} + +func (a *Asset) UnmarshalJSON(data []byte) error { + type alias Asset + aux := struct { + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + *alias + }{ + alias: (*alias)(a), + } + + if err := utils.Json.Unmarshal(data, &aux); err != nil { + return errors.Wrap(err, "failed to unmarshal asset") + } + + createdAt, err := time.Parse(time.RFC3339, aux.CreatedAt) + if err != nil { + return errors.Wrap(err, "failed to parse created_at in asset") + } + + a.CreatedAt = &createdAt + + updatedAt, err := time.Parse(time.RFC3339, aux.UpdatedAt) + if err != nil { + return errors.Wrap(err, "failed to parse updated_at in asset") + } + + a.UpdatedAt = &updatedAt + + return nil +} +func (a *Asset) GetSize() (_ int64) { + return a.Size +} + +func (a *Asset) GetName() (_ string) { + return a.Name +} + +func (a *Asset) ModTime() (_ time.Time) { + if a.UpdatedAt == nil { + return time.Time{} + } + return *a.UpdatedAt +} + +func (a *Asset) CreateTime() (_ time.Time) { + if a.CreatedAt == nil { + return time.Time{} + } + return *a.CreatedAt +} + +func (a *Asset) IsDir() bool { + return false +} + +// GetHash 获取文件的哈希值. github release api 不提供哈希值 +func (a *Asset) GetHash() utils.HashInfo { + return utils.HashInfo{} +} + +func (a *Asset) GetID() string { + return fmt.Sprintf("%d", a.ID) +} + +// GetPath 获取路径. 通过解析 Asset.BrowserDownloadURL 获取 +func (a *Asset) GetPath() string { + pattern := `https://github.com/[^/]+/[^/]+/releases/download/([^/]+)/([^/]+)` + re := regexp.MustCompile(pattern) + matches := re.FindStringSubmatch(a.BrowserDownloadURL) + if len(matches) != 3 { + return "" + } + tag := matches[1] + assetName := matches[2] + return fmt.Sprintf("%s/%s", tag, assetName) +} + +type User struct { + Name string `json:"name"` + Email string `json:"email"` + Login string `json:"login"` + ID int64 `json:"id"` + NodeID string `json:"node_id"` + AvatarURL string `json:"avatar_url"` + GravatarID string `json:"gravatar_id"` + URL string `json:"url"` + HTMLURL string `json:"html_url"` + FollowersURL string `json:"followers_url"` + FollowingURL string `json:"following_url"` + GistsURL string `json:"gists_url"` + StarredURL string `json:"starred_url"` + SubscriptionsURL string `json:"subscriptions_url"` + OrganizationsURL string `json:"organizations_url"` + ReposURL string `json:"repos_url"` + EventsURL string `json:"events_url"` + ReceivedEventsURL string `json:"received_events_url"` + Type string `json:"type"` + SiteAdmin bool `json:"site_admin"` +} diff --git a/drivers/github_release/types_test.go b/drivers/github_release/types_test.go new file mode 100644 index 00000000000..9bc43bfdb4c --- /dev/null +++ b/drivers/github_release/types_test.go @@ -0,0 +1,477 @@ +package github_release + +import ( + "testing" + "time" + + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/stretchr/testify/assert" +) + +func TestNewRepository(t *testing.T) { + tests := []struct { + name string + input string + want repository + wantErr bool + }{ + { + name: "正常的仓库名称", + input: "alist-org/alist", + want: repository{ + owner: "alist-org", + name: "alist", + }, + wantErr: false, + }, + { + name: "缺少斜杠的仓库名称", + input: "alist-org", + want: repository{}, + wantErr: true, + }, + { + name: "空的所有者", + input: "/alist", + want: repository{}, + wantErr: true, + }, + { + name: "空的仓库名", + input: "alist-org/", + want: repository{}, + wantErr: true, + }, + { + name: "完全空的输入", + input: "", + want: repository{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := newRepository(tt.input) + if tt.wantErr { + assert.Error(t, err) + assert.Equal(t, repository{}, got) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestRepository_String(t *testing.T) { + repo := repository{ + owner: "alist-org", + name: "alist", + } + assert.Equal(t, "alist-org/alist", repo.String()) +} + +func TestRepository_UrlEncode(t *testing.T) { + tests := []struct { + name string + repo repository + want string + }{ + { + name: "普通仓库名称", + repo: repository{ + owner: "alist-org", + name: "alist", + }, + want: "alist-org/alist", + }, + { + name: "包含特殊字符的仓库名称", + repo: repository{ + owner: "user name", + name: "repo name", + }, + want: "user+name/repo+name", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.repo.UrlEncode() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestRelease_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + json string + want *Release + invalidDatetime bool + }{ + { + name: "正常的发布数据", + json: `{ + "url": "https://api.github.com/repos/alist-org/alist/releases/1", + "html_url": "https://github.com/alist-org/alist/releases/tag/v1.0.0", + "tag_name": "v1.0.0", + "name": "Release v1.0.0", + "body": "Release notes", + "created_at": "2023-01-01T12:00:00Z", + "published_at": "2023-01-01T12:30:00Z", + "author": { + "login": "test-user", + "id": 1 + } + }`, + want: &Release{ + URL: "https://api.github.com/repos/alist-org/alist/releases/1", + HTMLURL: "https://github.com/alist-org/alist/releases/tag/v1.0.0", + TagName: "v1.0.0", + Name: "Release v1.0.0", + Body: "Release notes", + Author: User{ + Login: "test-user", + ID: 1, + }, + }, + invalidDatetime: false, + }, + { + name: "无效的时间格式", + json: `{ + "created_at": "invalid-time", + "published_at": "invalid-time" + }`, + want: nil, + invalidDatetime: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var release Release + err := release.UnmarshalJSON([]byte(tt.json)) + if tt.invalidDatetime { + assert.True(t, release.CreatedAt.IsZero()) + assert.True(t, release.PublishedAt.IsZero()) + } else { + assert.NoError(t, err) + // 验证时间字段 + assert.Equal(t, 2023, release.CreatedAt.Year()) + assert.Equal(t, 2023, release.PublishedAt.Year()) + // 验证其他字段 + assert.Equal(t, tt.want.URL, release.URL) + assert.Equal(t, tt.want.HTMLURL, release.HTMLURL) + assert.Equal(t, tt.want.TagName, release.TagName) + assert.Equal(t, tt.want.Name, release.Name) + assert.Equal(t, tt.want.Body, release.Body) + assert.Equal(t, tt.want.Author.Login, release.Author.Login) + assert.Equal(t, tt.want.Author.ID, release.Author.ID) + } + }) + } +} + +func TestAsset_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + json string + want *Asset + wantErr bool + }{ + { + name: "正常的资源数据", + json: `{ + "url": "https://api.github.com/repos/alist-org/alist/releases/assets/1", + "browser_download_url": "https://github.com/alist-org/alist/releases/download/v1.0.0/asset.zip", + "id": 1, + "name": "asset.zip", + "label": "Binary", + "state": "uploaded", + "content_type": "application/zip", + "size": 1024, + "download_count": 100, + "created_at": "2023-01-01T12:00:00Z", + "updated_at": "2023-01-01T12:30:00Z", + "uploader": { + "login": "test-user", + "id": 1 + } + }`, + want: &Asset{ + URL: "https://api.github.com/repos/alist-org/alist/releases/assets/1", + BrowserDownloadURL: "https://github.com/alist-org/alist/releases/download/v1.0.0/asset.zip", + ID: 1, + Name: "asset.zip", + Label: "Binary", + State: "uploaded", + ContentType: "application/zip", + Size: 1024, + DownloadCount: 100, + Uploader: &User{ + Login: "test-user", + ID: 1, + }, + }, + wantErr: false, + }, + { + name: "无效的时间格式", + json: `{ + "created_at": "invalid-time", + "updated_at": "2023-01-01T12:30:00Z" + }`, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var asset Asset + err := asset.UnmarshalJSON([]byte(tt.json)) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + // 验证时间字段 + assert.Equal(t, 2023, asset.CreatedAt.Year()) + assert.Equal(t, 2023, asset.UpdatedAt.Year()) + // 验证其他字段 + assert.Equal(t, tt.want.URL, asset.URL) + assert.Equal(t, tt.want.BrowserDownloadURL, asset.BrowserDownloadURL) + assert.Equal(t, tt.want.ID, asset.ID) + assert.Equal(t, tt.want.Name, asset.Name) + assert.Equal(t, tt.want.Label, asset.Label) + assert.Equal(t, tt.want.State, asset.State) + assert.Equal(t, tt.want.ContentType, asset.ContentType) + assert.Equal(t, tt.want.Size, asset.Size) + assert.Equal(t, tt.want.DownloadCount, asset.DownloadCount) + assert.Equal(t, tt.want.Uploader.Login, asset.Uploader.Login) + assert.Equal(t, tt.want.Uploader.ID, asset.Uploader.ID) + } + }) + } +} + +func TestReleases_UnmarshalJSON(t *testing.T) { + jsonData := `[ + { + "url": "https://api.github.com/repos/AlistGo/alist/releases/170718825", + "assets_url": "https://api.github.com/repos/AlistGo/alist/releases/170718825/assets", + "upload_url": "https://uploads.github.com/repos/AlistGo/alist/releases/170718825/assets{?name,label}", + "html_url": "https://github.com/AlistGo/alist/releases/tag/beta", + "id": 170718825, + "author": { + "login": "xhofe", + "id": 36558727, + "node_id": "MDQ6VXNlcjM2NTU4NzI3", + "avatar_url": "https://avatars.githubusercontent.com/u/36558727?v=4", + "url": "https://api.github.com/users/xhofe", + "html_url": "https://github.com/xhofe", + "type": "User", + "site_admin": false + }, + "node_id": "RE_kwDOE09S284KLPZp", + "tag_name": "beta", + "target_commitish": "main", + "name": "AList Beta Version", + "draft": false, + "prerelease": true, + "created_at": "2025-01-18T15:52:02Z", + "published_at": "2024-08-17T14:10:08Z", + "assets": [ + { + "url": "https://api.github.com/repos/AlistGo/alist/releases/assets/221414212", + "id": 221414212, + "name": "alist-android-386.tar.gz", + "content_type": "application/gzip", + "state": "uploaded", + "size": 31186443, + "download_count": 6, + "created_at": "2025-01-18T15:58:55Z", + "updated_at": "2025-01-18T15:58:56Z", + "browser_download_url": "https://github.com/AlistGo/alist/releases/download/beta/alist-android-386.tar.gz", + "uploader": { + "login": "github-actions[bot]", + "id": 41898282, + "type": "Bot", + "site_admin": false + } + }, + { + "url": "https://api.github.com/repos/AlistGo/alist/releases/assets/221414214", + "id": 221414214, + "name": "alist-android-amd64.tar.gz", + "content_type": "application/gzip", + "state": "uploaded", + "size": 31586093, + "download_count": 10, + "created_at": "2025-01-18T15:58:55Z", + "updated_at": "2025-01-18T15:58:56Z", + "browser_download_url": "https://github.com/AlistGo/alist/releases/download/beta/alist-android-amd64.tar.gz", + "uploader": { + "login": "github-actions[bot]", + "id": 41898282, + "type": "Bot", + "site_admin": false + } + } + ], + "body": "Test text" + } + ]` + + var releases []Release + err := utils.Json.Unmarshal([]byte(jsonData), &releases) + assert.NoError(t, err) + assert.Len(t, releases, 1) + + release := releases[0] + // 验证 Release 基本信息 + assert.Equal(t, int64(170718825), release.ID) + assert.Equal(t, "beta", release.TagName) + assert.Equal(t, "AList Beta Version", release.Name) + assert.Equal(t, "Test text", release.Body) + assert.False(t, release.Draft) + assert.True(t, release.Prerelease) + + // 验证时间 + assert.Equal(t, 2025, release.CreatedAt.Year()) + assert.Equal(t, 2024, release.PublishedAt.Year()) + + // 验证作者信息 + assert.Equal(t, "xhofe", release.Author.Login) + assert.Equal(t, int64(36558727), release.Author.ID) + assert.Equal(t, "User", release.Author.Type) + + // 验证资源信息 + assert.Len(t, release.Assets, 2) + + // 验证第一个资源 + asset1 := release.Assets[0] + assert.Equal(t, int64(221414212), asset1.ID) + assert.Equal(t, "alist-android-386.tar.gz", asset1.Name) + assert.Equal(t, "application/gzip", asset1.ContentType) + assert.Equal(t, int64(31186443), asset1.Size) + assert.Equal(t, int64(6), asset1.DownloadCount) + assert.Equal(t, "uploaded", asset1.State) + assert.Equal(t, "https://github.com/AlistGo/alist/releases/download/beta/alist-android-386.tar.gz", asset1.BrowserDownloadURL) + + // 验证第一个资源的上传者 + assert.Equal(t, "github-actions[bot]", asset1.Uploader.Login) + assert.Equal(t, int64(41898282), asset1.Uploader.ID) + assert.Equal(t, "Bot", asset1.Uploader.Type) + + // 验证第二个资源 + asset2 := release.Assets[1] + assert.Equal(t, int64(221414214), asset2.ID) + assert.Equal(t, "alist-android-amd64.tar.gz", asset2.Name) + assert.Equal(t, int64(31586093), asset2.Size) + assert.Equal(t, int64(10), asset2.DownloadCount) +} + +func TestRelease_InterfaceMethods(t *testing.T) { + release := &Release{ + ID: 123, + TagName: "v1.0.0", + CreatedAt: time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC), + PublishedAt: time.Date(2023, 1, 2, 0, 0, 0, 0, time.UTC), + Assets: []Asset{ + {Name: "asset1.zip"}, + {Name: "asset2.tar.gz"}, + }, + } + + // 测试基本方法 + t.Run("basic methods", func(t *testing.T) { + assert.Equal(t, int64(0), release.GetSize()) + assert.Equal(t, "v1.0.0", release.GetName()) + assert.Equal(t, time.Date(2023, 1, 2, 0, 0, 0, 0, time.UTC), release.ModTime()) + assert.Equal(t, time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC), release.CreateTime()) + assert.True(t, release.IsDir()) + assert.Equal(t, utils.HashInfo{}, release.GetHash()) + assert.Equal(t, "123", release.GetID()) + assert.Equal(t, "v1.0.0", release.GetPath()) + }) + + // 测试 Children 方法 + t.Run("children", func(t *testing.T) { + children, err := release.Children() + assert.NoError(t, err) + assert.Len(t, children, 2) + assert.Equal(t, "asset1.zip", children[0].GetName()) + assert.Equal(t, "asset2.tar.gz", children[1].GetName()) + }) +} + +func TestAsset_InterfaceMethods(t *testing.T) { + now := time.Now() + asset := &Asset{ + ID: 456, + Name: "test.zip", + Size: 12345, + CreatedAt: &now, + UpdatedAt: &now, + BrowserDownloadURL: "https://github.com/owner/repo/releases/download/v1.0.0/test.zip", + } + + t.Run("basic methods", func(t *testing.T) { + assert.Equal(t, int64(12345), asset.GetSize()) + assert.Equal(t, "test.zip", asset.GetName()) + assert.Equal(t, now, asset.ModTime()) + assert.Equal(t, now, asset.CreateTime()) + assert.False(t, asset.IsDir()) + assert.Equal(t, utils.HashInfo{}, asset.GetHash()) + assert.Equal(t, "456", asset.GetID()) + }) + + // 测试空时间的情况 + t.Run("nil time fields", func(t *testing.T) { + emptyAsset := &Asset{} + assert.Equal(t, time.Time{}, emptyAsset.ModTime()) + assert.Equal(t, time.Time{}, emptyAsset.CreateTime()) + }) +} + +func TestAsset_GetPath(t *testing.T) { + tests := []struct { + name string + browserDownloadURL string + want string + }{ + { + name: "valid url", + browserDownloadURL: "https://github.com/owner/repo/releases/download/v1.0.0/test.zip", + want: "v1.0.0/test.zip", + }, + { + name: "invalid url format", + browserDownloadURL: "https://github.com/invalid/url", + want: "", + }, + { + name: "empty url", + browserDownloadURL: "", + want: "", + }, + { + name: "url with special characters", + browserDownloadURL: "https://github.com/owner/repo/releases/download/v1.0.0-beta/test-file_1.2.3.zip", + want: "v1.0.0-beta/test-file_1.2.3.zip", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + asset := &Asset{ + BrowserDownloadURL: tt.browserDownloadURL, + } + got := asset.GetPath() + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/drivers/github_release/util.go b/drivers/github_release/util.go new file mode 100644 index 00000000000..eb1164d7976 --- /dev/null +++ b/drivers/github_release/util.go @@ -0,0 +1 @@ +package github_release