diff --git a/apps.go b/apps.go index 01b976e..6d7552f 100644 --- a/apps.go +++ b/apps.go @@ -55,6 +55,7 @@ func RegisterApp(ctx context.Context, appConfig *AppConfig) (*Application, error if err != nil { return nil, err } + req = req.WithContext(ctx) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := appConfig.Do(req) if err != nil { diff --git a/apps_test.go b/apps_test.go index 177440e..7848e0e 100644 --- a/apps_test.go +++ b/apps_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" ) func TestRegisterApp(t *testing.T) { @@ -41,3 +42,25 @@ func TestRegisterApp(t *testing.T) { t.Fatalf("want %q but %q", "bar", app.ClientSecret) } } + +func TestRegisterAppWithCancel(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(3 * time.Second) + fmt.Fprintln(w, `{"client_id": "foo", "client_secret": "bar"}`) + return + })) + defer ts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + go cancel() + _, err := RegisterApp(ctx, &AppConfig{ + Server: ts.URL, + Scopes: "read write follow", + }) + if err == nil { + t.Fatalf("should be fail: %v", err) + } + if want := "Post " + ts.URL + "/api/v1/apps: context canceled"; want != err.Error() { + t.Fatalf("want %q but %q", want, err.Error()) + } +} diff --git a/mastodon.go b/mastodon.go index b7c8824..df97f22 100644 --- a/mastodon.go +++ b/mastodon.go @@ -129,7 +129,7 @@ func (c *Client) doAPI(ctx context.Context, method string, uri string, params in } else { req, err = http.NewRequest(method, u.String(), nil) } - req.WithContext(ctx) + req = req.WithContext(ctx) req.Header.Set("Authorization", "Bearer "+c.config.AccessToken) if params != nil { req.Header.Set("Content-Type", ct) @@ -191,6 +191,7 @@ func (c *Client) Authenticate(ctx context.Context, username, password string) er if err != nil { return err } + req = req.WithContext(ctx) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := c.Do(req) if err != nil { diff --git a/mastodon_test.go b/mastodon_test.go index 1dfb53b..69c93c4 100644 --- a/mastodon_test.go +++ b/mastodon_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "reflect" "testing" + "time" ) func TestAuthenticate(t *testing.T) { @@ -42,6 +43,29 @@ func TestAuthenticate(t *testing.T) { } } +func TestAuthenticateWithCancel(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(3 * time.Second) + return + })) + defer ts.Close() + + client := NewClient(&Config{ + Server: ts.URL, + ClientID: "foo", + ClientSecret: "bar", + }) + ctx, cancel := context.WithCancel(context.Background()) + go cancel() + err := client.Authenticate(ctx, "invalid", "user") + if err == nil { + t.Fatalf("should be fail: %v", err) + } + if want := "Post " + ts.URL + "/oauth/token: context canceled"; want != err.Error() { + t.Fatalf("want %q but %q", want, err.Error()) + } +} + func TestPostStatus(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Authorization") != "Bearer zoo" { @@ -79,6 +103,31 @@ func TestPostStatus(t *testing.T) { } } +func TestPostStatusWithCancel(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(3 * time.Second) + return + })) + defer ts.Close() + + client := NewClient(&Config{ + Server: ts.URL, + ClientID: "foo", + ClientSecret: "bar", + }) + ctx, cancel := context.WithCancel(context.Background()) + go cancel() + _, err := client.PostStatus(ctx, &Toot{ + Status: "foobar", + }) + if err == nil { + t.Fatalf("should be fail: %v", err) + } + if want := "Post " + ts.URL + "/api/v1/statuses: context canceled"; want != err.Error() { + t.Fatalf("want %q but %q", want, err.Error()) + } +} + func TestGetTimelineHome(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, `[{"Content": "foo"}, {"Content": "bar"}]`) @@ -119,6 +168,30 @@ func TestGetTimelineHome(t *testing.T) { } } +func TestGetTimelineHomeWithCancel(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(3 * time.Second) + return + })) + defer ts.Close() + + client := NewClient(&Config{ + Server: ts.URL, + ClientID: "foo", + ClientSecret: "bar", + AccessToken: "zoo", + }) + ctx, cancel := context.WithCancel(context.Background()) + go cancel() + _, err := client.GetTimelineHome(ctx) + if err == nil { + t.Fatalf("should be fail: %v", err) + } + if want := "Get " + ts.URL + "/api/v1/timelines/home: context canceled"; want != err.Error() { + t.Fatalf("want %q but %q", want, err.Error()) + } +} + func TestForTheCoverages(t *testing.T) { (*UpdateEvent)(nil).event() (*NotificationEvent)(nil).event() diff --git a/streaming.go b/streaming.go index 3b4007a..2a5295d 100644 --- a/streaming.go +++ b/streaming.go @@ -89,6 +89,7 @@ func (c *Client) streaming(ctx context.Context, p string, params url.Values) (ch } req, err := http.NewRequest(http.MethodGet, u.String(), in) if err == nil { + req = req.WithContext(ctx) req.Header.Set("Authorization", "Bearer "+c.config.AccessToken) resp, err = c.Do(req) if resp != nil && resp.StatusCode != http.StatusOK { @@ -114,7 +115,6 @@ func (c *Client) streaming(ctx context.Context, p string, params url.Values) (ch } }() return q, nil - } // StreamingPublic return channel to read events on public.