|
5 | 5 | "io"
|
6 | 6 | "net/http"
|
7 | 7 | "net/http/httptest"
|
| 8 | + "net/url" |
8 | 9 | "testing"
|
9 | 10 | "time"
|
10 | 11 | )
|
@@ -154,3 +155,51 @@ func TestExpiredWithExpiry(t *testing.T) {
|
154 | 155 | func newMockServer(handler func(w http.ResponseWriter, r *http.Request)) *httptest.Server {
|
155 | 156 | return httptest.NewServer(http.HandlerFunc(handler))
|
156 | 157 | }
|
| 158 | + |
| 159 | +// TestTransportWithNilHeader tests that the Transport.RoundTrip method |
| 160 | +// correctly handles requests with nil Headers. |
| 161 | +func TestTransportWithNilHeader(t *testing.T) { |
| 162 | + // Create a mock token source that returns a fixed token |
| 163 | + tokenSource := StaticTokenSource(&Token{ |
| 164 | + AccessToken: "test-access-token", |
| 165 | + TokenType: "Bearer", |
| 166 | + }) |
| 167 | + |
| 168 | + // Create a mock http server to verify the request |
| 169 | + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 170 | + // Check that the Authorization header was correctly set |
| 171 | + authHeader := r.Header.Get("Authorization") |
| 172 | + expectedHeader := "Bearer test-access-token" |
| 173 | + if authHeader != expectedHeader { |
| 174 | + t.Errorf("expected authorization header %q, got %q", expectedHeader, authHeader) |
| 175 | + } |
| 176 | + w.WriteHeader(http.StatusOK) |
| 177 | + })) |
| 178 | + defer server.Close() |
| 179 | + |
| 180 | + // Create Transport with our token source |
| 181 | + transport := &Transport{ |
| 182 | + Source: tokenSource, |
| 183 | + Base: http.DefaultTransport, |
| 184 | + } |
| 185 | + |
| 186 | + // Create a request with nil Header |
| 187 | + reqURL, _ := url.Parse(server.URL) |
| 188 | + req := &http.Request{ |
| 189 | + Method: "GET", |
| 190 | + URL: reqURL, |
| 191 | + // Header is intentionally nil |
| 192 | + } |
| 193 | + |
| 194 | + // Make the request using our Transport |
| 195 | + resp, err := transport.RoundTrip(req) |
| 196 | + if err != nil { |
| 197 | + t.Fatalf("roundTrip failed with nil Header: %v", err) |
| 198 | + } |
| 199 | + defer resp.Body.Close() |
| 200 | + |
| 201 | + // Verify response status |
| 202 | + if resp.StatusCode != http.StatusOK { |
| 203 | + t.Errorf("expected status code %d, got %d", http.StatusOK, resp.StatusCode) |
| 204 | + } |
| 205 | +} |
0 commit comments