updated dependencies
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
Paul 2024-04-28 13:08:41 +02:00
parent d7738f79e1
commit fa74f7ad4c
422 changed files with 4629 additions and 4353 deletions

View File

@ -1,5 +1,77 @@
# Changelog # Changelog
## v4.12.0 - 2024-04-15
**Security**
* Update golang.org/x/net dep because of [GO-2024-2687](https://pkg.go.dev/vuln/GO-2024-2687) by @aldas in https://github.com/labstack/echo/pull/2625
**Enhancements**
* binder: make binding to Map work better with string destinations by @aldas in https://github.com/labstack/echo/pull/2554
* README.md: add Encore as sponsor by @marcuskohlberg in https://github.com/labstack/echo/pull/2579
* Reorder paragraphs in README.md by @aldas in https://github.com/labstack/echo/pull/2581
* CI: upgrade actions/checkout to v4 by @aldas in https://github.com/labstack/echo/pull/2584
* Remove default charset from 'application/json' Content-Type header by @doortts in https://github.com/labstack/echo/pull/2568
* CI: Use Go 1.22 by @aldas in https://github.com/labstack/echo/pull/2588
* binder: allow binding to a nil map by @georgmu in https://github.com/labstack/echo/pull/2574
* Add Skipper Unit Test In BasicBasicAuthConfig and Add More Detail Explanation regarding BasicAuthValidator by @RyoKusnadi in https://github.com/labstack/echo/pull/2461
* fix some typos by @teslaedison in https://github.com/labstack/echo/pull/2603
* fix: some typos by @pomadev in https://github.com/labstack/echo/pull/2596
* Allow ResponseWriters to unwrap writers when flushing/hijacking by @aldas in https://github.com/labstack/echo/pull/2595
* Add SPDX licence comments to files. by @aldas in https://github.com/labstack/echo/pull/2604
* Upgrade deps by @aldas in https://github.com/labstack/echo/pull/2605
* Change type definition blocks to single declarations. This helps copy… by @aldas in https://github.com/labstack/echo/pull/2606
* Fix Real IP logic by @cl-bvl in https://github.com/labstack/echo/pull/2550
* Default binder can use `UnmarshalParams(params []string) error` inter… by @aldas in https://github.com/labstack/echo/pull/2607
* Default binder can bind pointer to slice as struct field. For example `*[]string` by @aldas in https://github.com/labstack/echo/pull/2608
* Remove maxparam dependence from Context by @aldas in https://github.com/labstack/echo/pull/2611
* When route is registered with empty path it is normalized to `/`. by @aldas in https://github.com/labstack/echo/pull/2616
* proxy middleware should use httputil.ReverseProxy for SSE requests by @aldas in https://github.com/labstack/echo/pull/2624
## v4.11.4 - 2023-12-20
**Security**
* Upgrade golang.org/x/crypto to v0.17.0 to fix vulnerability [issue](https://pkg.go.dev/vuln/GO-2023-2402) [#2562](https://github.com/labstack/echo/pull/2562)
**Enhancements**
* Update deps and mark Go version to 1.18 as this is what golang.org/x/* use [#2563](https://github.com/labstack/echo/pull/2563)
* Request logger: add example for Slog https://pkg.go.dev/log/slog [#2543](https://github.com/labstack/echo/pull/2543)
## v4.11.3 - 2023-11-07
**Security**
* 'c.Attachment' and 'c.Inline' should escape filename in 'Content-Disposition' header to avoid 'Reflect File Download' vulnerability. [#2541](https://github.com/labstack/echo/pull/2541)
**Enhancements**
* Tests: refactor context tests to be separate functions [#2540](https://github.com/labstack/echo/pull/2540)
* Proxy middleware: reuse echo request context [#2537](https://github.com/labstack/echo/pull/2537)
* Mark unmarshallable yaml struct tags as ignored [#2536](https://github.com/labstack/echo/pull/2536)
## v4.11.2 - 2023-10-11
**Security**
* Bump golang.org/x/net to prevent CVE-2023-39325 / CVE-2023-44487 HTTP/2 Rapid Reset Attack [#2527](https://github.com/labstack/echo/pull/2527)
* fix(sec): randomString bias introduced by #2490 [#2492](https://github.com/labstack/echo/pull/2492)
* CSRF/RequestID mw: switch math/random usage to crypto/random [#2490](https://github.com/labstack/echo/pull/2490)
**Enhancements**
* Delete unused context in body_limit.go [#2483](https://github.com/labstack/echo/pull/2483)
* Use Go 1.21 in CI [#2505](https://github.com/labstack/echo/pull/2505)
* Fix some typos [#2511](https://github.com/labstack/echo/pull/2511)
* Allow CORS middleware to send Access-Control-Max-Age: 0 [#2518](https://github.com/labstack/echo/pull/2518)
* Bump dependancies [#2522](https://github.com/labstack/echo/pull/2522)
## v4.11.1 - 2023-07-16 ## v4.11.1 - 2023-07-16
**Fixes** **Fixes**

View File

@ -31,6 +31,6 @@ benchmark: ## Run benchmarks
help: ## Display this help screen help: ## Display this help screen
@grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
goversion ?= "1.17" goversion ?= "1.19"
test_version: ## Run tests inside Docker with given version (defaults to 1.17 oldest supported). Example: make test_version goversion=1.17 test_version: ## Run tests inside Docker with given version (defaults to 1.19 oldest supported). Example: make test_version goversion=1.19
@docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check"

View File

@ -3,26 +3,24 @@
[![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge) [![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge)
[![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo/v4) [![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo/v4)
[![Go Report Card](https://goreportcard.com/badge/github.com/labstack/echo?style=flat-square)](https://goreportcard.com/report/github.com/labstack/echo) [![Go Report Card](https://goreportcard.com/badge/github.com/labstack/echo?style=flat-square)](https://goreportcard.com/report/github.com/labstack/echo)
[![Build Status](http://img.shields.io/travis/labstack/echo.svg?style=flat-square)](https://travis-ci.org/labstack/echo) [![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/labstack/echo/echo.yml?style=flat-square)](https://github.com/labstack/echo/actions)
[![Codecov](https://img.shields.io/codecov/c/github/labstack/echo.svg?style=flat-square)](https://codecov.io/gh/labstack/echo) [![Codecov](https://img.shields.io/codecov/c/github/labstack/echo.svg?style=flat-square)](https://codecov.io/gh/labstack/echo)
[![Forum](https://img.shields.io/badge/community-forum-00afd1.svg?style=flat-square)](https://github.com/labstack/echo/discussions) [![Forum](https://img.shields.io/badge/community-forum-00afd1.svg?style=flat-square)](https://github.com/labstack/echo/discussions)
[![Twitter](https://img.shields.io/badge/twitter-@labstack-55acee.svg?style=flat-square)](https://twitter.com/labstack) [![Twitter](https://img.shields.io/badge/twitter-@labstack-55acee.svg?style=flat-square)](https://twitter.com/labstack)
[![License](http://img.shields.io/badge/license-mit-blue.svg?style=flat-square)](https://raw.githubusercontent.com/labstack/echo/master/LICENSE) [![License](http://img.shields.io/badge/license-mit-blue.svg?style=flat-square)](https://raw.githubusercontent.com/labstack/echo/master/LICENSE)
## Supported Go versions ## Echo
Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with High performance, extensible, minimalist Go web framework.
older versions.
As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules). * [Official website](https://echo.labstack.com)
Therefore a Go version capable of understanding /vN suffixed imports is required: * [Quick start](https://echo.labstack.com/docs/quick-start)
* [Middlewares](https://echo.labstack.com/docs/category/middleware)
Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended Help and questions: [Github Discussions](https://github.com/labstack/echo/discussions)
way of using Echo going forward.
For older versions, please use the latest v3 tag.
## Feature Overview ### Feature Overview
- Optimized HTTP router which smartly prioritize routes - Optimized HTTP router which smartly prioritize routes
- Build robust and scalable RESTful APIs - Build robust and scalable RESTful APIs
@ -38,6 +36,18 @@ For older versions, please use the latest v3 tag.
- Automatic TLS via Lets Encrypt - Automatic TLS via Lets Encrypt
- HTTP/2 support - HTTP/2 support
## Sponsors
<div>
<a href="https://encore.dev" style="display: inline-flex; align-items: center; gap: 10px">
<img src="https://user-images.githubusercontent.com/78424526/214602214-52e0483a-b5fc-4d4c-b03e-0b7b23e012df.svg" height="28px" alt="encore icon"></img>
<b>Encore the platform for building Go-based cloud backends</b>
</a>
</div>
<br/>
Click [here](https://github.com/sponsors/labstack) for more information on sponsorship.
## Benchmarks ## Benchmarks
Date: 2020/11/11<br> Date: 2020/11/11<br>
@ -57,6 +67,7 @@ The benchmarks above were run on an Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
// go get github.com/labstack/echo/{version} // go get github.com/labstack/echo/{version}
go get github.com/labstack/echo/v4 go get github.com/labstack/echo/v4
``` ```
Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with older versions.
### Example ### Example
@ -117,10 +128,6 @@ of middlewares in this list.
Please send a PR to add your own library here. Please send a PR to add your own library here.
## Help
- [Forum](https://github.com/labstack/echo/discussions)
## Contribute ## Contribute
**Use issues for everything** **Use issues for everything**

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo package echo
import ( import (
@ -11,23 +14,28 @@ import (
"strings" "strings"
) )
type ( // Binder is the interface that wraps the Bind method.
// Binder is the interface that wraps the Bind method. type Binder interface {
Binder interface { Bind(i interface{}, c Context) error
Bind(i interface{}, c Context) error }
}
// DefaultBinder is the default implementation of the Binder interface. // DefaultBinder is the default implementation of the Binder interface.
DefaultBinder struct{} type DefaultBinder struct{}
// BindUnmarshaler is the interface used to wrap the UnmarshalParam method. // BindUnmarshaler is the interface used to wrap the UnmarshalParam method.
// Types that don't implement this, but do implement encoding.TextUnmarshaler // Types that don't implement this, but do implement encoding.TextUnmarshaler
// will use that interface instead. // will use that interface instead.
BindUnmarshaler interface { type BindUnmarshaler interface {
// UnmarshalParam decodes and assigns a value from an form or query param. // UnmarshalParam decodes and assigns a value from an form or query param.
UnmarshalParam(param string) error UnmarshalParam(param string) error
} }
)
// bindMultipleUnmarshaler is used by binder to unmarshal multiple values from request at once to
// type implementing this interface. For example request could have multiple query fields `?a=1&a=2&b=test` in that case
// for `a` following slice `["1", "2"] will be passed to unmarshaller.
type bindMultipleUnmarshaler interface {
UnmarshalParams(params []string) error
}
// BindPathParams binds path params to bindable object // BindPathParams binds path params to bindable object
func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error { func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error {
@ -131,10 +139,29 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri
typ := reflect.TypeOf(destination).Elem() typ := reflect.TypeOf(destination).Elem()
val := reflect.ValueOf(destination).Elem() val := reflect.ValueOf(destination).Elem()
// Map // Support binding to limited Map destinations:
if typ.Kind() == reflect.Map { // - map[string][]string,
// - map[string]string <-- (binds first value from data slice)
// - map[string]interface{}
// You are better off binding to struct but there are user who want this map feature. Source of data for these cases are:
// params,query,header,form as these sources produce string values, most of the time slice of strings, actually.
if typ.Kind() == reflect.Map && typ.Key().Kind() == reflect.String {
k := typ.Elem().Kind()
isElemInterface := k == reflect.Interface
isElemString := k == reflect.String
isElemSliceOfStrings := k == reflect.Slice && typ.Elem().Elem().Kind() == reflect.String
if !(isElemSliceOfStrings || isElemString || isElemInterface) {
return nil
}
if val.IsNil() {
val.Set(reflect.MakeMap(typ))
}
for k, v := range data { for k, v := range data {
val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0])) if isElemString {
val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0]))
} else {
val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v))
}
} }
return nil return nil
} }
@ -161,14 +188,14 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri
} }
structFieldKind := structField.Kind() structFieldKind := structField.Kind()
inputFieldName := typeField.Tag.Get(tag) inputFieldName := typeField.Tag.Get(tag)
if typeField.Anonymous && structField.Kind() == reflect.Struct && inputFieldName != "" { if typeField.Anonymous && structFieldKind == reflect.Struct && inputFieldName != "" {
// if anonymous struct with query/param/form tags, report an error // if anonymous struct with query/param/form tags, report an error
return errors.New("query/param/form tags are not allowed with anonymous struct field") return errors.New("query/param/form tags are not allowed with anonymous struct field")
} }
if inputFieldName == "" { if inputFieldName == "" {
// If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contains fields with tags). // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contains fields with tags).
// structs that implement BindUnmarshaler are binded only when they have explicit tag // structs that implement BindUnmarshaler are bound only when they have explicit tag
if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct { if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct {
if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil { if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil {
return err return err
@ -197,27 +224,46 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri
continue continue
} }
// Call this first, in case we're dealing with an alias to an array type // NOTE: algorithm here is not particularly sophisticated. It probably does not work with absurd types like `**[]*int`
if ok, err := unmarshalField(typeField.Type.Kind(), inputValue[0], structField); ok { // but it is smart enough to handle niche cases like `*int`,`*[]string`,`[]*int` .
// try unmarshalling first, in case we're dealing with an alias to an array type
if ok, err := unmarshalInputsToField(typeField.Type.Kind(), inputValue, structField); ok {
if err != nil { if err != nil {
return err return err
} }
continue continue
} }
numElems := len(inputValue) if ok, err := unmarshalInputToField(typeField.Type.Kind(), inputValue[0], structField); ok {
if structFieldKind == reflect.Slice && numElems > 0 { if err != nil {
return err
}
continue
}
// we could be dealing with pointer to slice `*[]string` so dereference it. There are wierd OpenAPI generators
// that could create struct fields like that.
if structFieldKind == reflect.Pointer {
structFieldKind = structField.Elem().Kind()
structField = structField.Elem()
}
if structFieldKind == reflect.Slice {
sliceOf := structField.Type().Elem().Kind() sliceOf := structField.Type().Elem().Kind()
numElems := len(inputValue)
slice := reflect.MakeSlice(structField.Type(), numElems, numElems) slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
for j := 0; j < numElems; j++ { for j := 0; j < numElems; j++ {
if err := setWithProperType(sliceOf, inputValue[j], slice.Index(j)); err != nil { if err := setWithProperType(sliceOf, inputValue[j], slice.Index(j)); err != nil {
return err return err
} }
} }
val.Field(i).Set(slice) structField.Set(slice)
} else if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil { continue
return err }
if err := setWithProperType(structFieldKind, inputValue[0], structField); err != nil {
return err
} }
} }
return nil return nil
@ -225,7 +271,7 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri
func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error { func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
// But also call it here, in case we're dealing with an array of BindUnmarshalers // But also call it here, in case we're dealing with an array of BindUnmarshalers
if ok, err := unmarshalField(valueKind, val, structField); ok { if ok, err := unmarshalInputToField(valueKind, val, structField); ok {
return err return err
} }
@ -266,35 +312,41 @@ func setWithProperType(valueKind reflect.Kind, val string, structField reflect.V
return nil return nil
} }
func unmarshalField(valueKind reflect.Kind, val string, field reflect.Value) (bool, error) { func unmarshalInputsToField(valueKind reflect.Kind, values []string, field reflect.Value) (bool, error) {
switch valueKind { if valueKind == reflect.Ptr {
case reflect.Ptr: if field.IsNil() {
return unmarshalFieldPtr(val, field) field.Set(reflect.New(field.Type().Elem()))
default: }
return unmarshalFieldNonPtr(val, field) field = field.Elem()
} }
fieldIValue := field.Addr().Interface()
unmarshaler, ok := fieldIValue.(bindMultipleUnmarshaler)
if !ok {
return false, nil
}
return true, unmarshaler.UnmarshalParams(values)
} }
func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) { func unmarshalInputToField(valueKind reflect.Kind, val string, field reflect.Value) (bool, error) {
fieldIValue := field.Addr().Interface() if valueKind == reflect.Ptr {
if unmarshaler, ok := fieldIValue.(BindUnmarshaler); ok { if field.IsNil() {
return true, unmarshaler.UnmarshalParam(value) field.Set(reflect.New(field.Type().Elem()))
}
field = field.Elem()
} }
if unmarshaler, ok := fieldIValue.(encoding.TextUnmarshaler); ok {
return true, unmarshaler.UnmarshalText([]byte(value)) fieldIValue := field.Addr().Interface()
switch unmarshaler := fieldIValue.(type) {
case BindUnmarshaler:
return true, unmarshaler.UnmarshalParam(val)
case encoding.TextUnmarshaler:
return true, unmarshaler.UnmarshalText([]byte(val))
} }
return false, nil return false, nil
} }
func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) {
if field.IsNil() {
// Initialize the pointer to a nil value
field.Set(reflect.New(field.Type().Elem()))
}
return unmarshalFieldNonPtr(value, field.Elem())
}
func setIntField(value string, bitSize int, field reflect.Value) error { func setIntField(value string, bitSize int, field reflect.Value) error {
if value == "" { if value == "" {
value = "0" value = "0"

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo package echo
import ( import (
@ -1323,7 +1326,7 @@ func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExi
case time.Second: case time.Second:
*dest = time.Unix(n, 0) *dest = time.Unix(n, 0)
case time.Millisecond: case time.Millisecond:
*dest = time.Unix(n/1e3, (n%1e3)*1e6) // TODO: time.UnixMilli(n) exists since Go1.17 switch to that when min version allows *dest = time.UnixMilli(n)
case time.Nanosecond: case time.Nanosecond:
*dest = time.Unix(0, n) *dest = time.Unix(0, n)
} }

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo package echo
import ( import (
@ -13,204 +16,216 @@ import (
"sync" "sync"
) )
type ( // Context represents the context of the current HTTP request. It holds request and
// Context represents the context of the current HTTP request. It holds request and // response objects, path, path parameters, data and registered handler.
// response objects, path, path parameters, data and registered handler. type Context interface {
Context interface { // Request returns `*http.Request`.
// Request returns `*http.Request`. Request() *http.Request
Request() *http.Request
// SetRequest sets `*http.Request`. // SetRequest sets `*http.Request`.
SetRequest(r *http.Request) SetRequest(r *http.Request)
// SetResponse sets `*Response`. // SetResponse sets `*Response`.
SetResponse(r *Response) SetResponse(r *Response)
// Response returns `*Response`. // Response returns `*Response`.
Response() *Response Response() *Response
// IsTLS returns true if HTTP connection is TLS otherwise false. // IsTLS returns true if HTTP connection is TLS otherwise false.
IsTLS() bool IsTLS() bool
// IsWebSocket returns true if HTTP connection is WebSocket otherwise false. // IsWebSocket returns true if HTTP connection is WebSocket otherwise false.
IsWebSocket() bool IsWebSocket() bool
// Scheme returns the HTTP protocol scheme, `http` or `https`. // Scheme returns the HTTP protocol scheme, `http` or `https`.
Scheme() string Scheme() string
// RealIP returns the client's network address based on `X-Forwarded-For` // RealIP returns the client's network address based on `X-Forwarded-For`
// or `X-Real-IP` request header. // or `X-Real-IP` request header.
// The behavior can be configured using `Echo#IPExtractor`. // The behavior can be configured using `Echo#IPExtractor`.
RealIP() string RealIP() string
// Path returns the registered path for the handler. // Path returns the registered path for the handler.
Path() string Path() string
// SetPath sets the registered path for the handler. // SetPath sets the registered path for the handler.
SetPath(p string) SetPath(p string)
// Param returns path parameter by name. // Param returns path parameter by name.
Param(name string) string Param(name string) string
// ParamNames returns path parameter names. // ParamNames returns path parameter names.
ParamNames() []string ParamNames() []string
// SetParamNames sets path parameter names. // SetParamNames sets path parameter names.
SetParamNames(names ...string) SetParamNames(names ...string)
// ParamValues returns path parameter values. // ParamValues returns path parameter values.
ParamValues() []string ParamValues() []string
// SetParamValues sets path parameter values. // SetParamValues sets path parameter values.
SetParamValues(values ...string) SetParamValues(values ...string)
// QueryParam returns the query param for the provided name. // QueryParam returns the query param for the provided name.
QueryParam(name string) string QueryParam(name string) string
// QueryParams returns the query parameters as `url.Values`. // QueryParams returns the query parameters as `url.Values`.
QueryParams() url.Values QueryParams() url.Values
// QueryString returns the URL query string. // QueryString returns the URL query string.
QueryString() string QueryString() string
// FormValue returns the form field value for the provided name. // FormValue returns the form field value for the provided name.
FormValue(name string) string FormValue(name string) string
// FormParams returns the form parameters as `url.Values`. // FormParams returns the form parameters as `url.Values`.
FormParams() (url.Values, error) FormParams() (url.Values, error)
// FormFile returns the multipart form file for the provided name. // FormFile returns the multipart form file for the provided name.
FormFile(name string) (*multipart.FileHeader, error) FormFile(name string) (*multipart.FileHeader, error)
// MultipartForm returns the multipart form. // MultipartForm returns the multipart form.
MultipartForm() (*multipart.Form, error) MultipartForm() (*multipart.Form, error)
// Cookie returns the named cookie provided in the request. // Cookie returns the named cookie provided in the request.
Cookie(name string) (*http.Cookie, error) Cookie(name string) (*http.Cookie, error)
// SetCookie adds a `Set-Cookie` header in HTTP response. // SetCookie adds a `Set-Cookie` header in HTTP response.
SetCookie(cookie *http.Cookie) SetCookie(cookie *http.Cookie)
// Cookies returns the HTTP cookies sent with the request. // Cookies returns the HTTP cookies sent with the request.
Cookies() []*http.Cookie Cookies() []*http.Cookie
// Get retrieves data from the context. // Get retrieves data from the context.
Get(key string) interface{} Get(key string) interface{}
// Set saves data in the context. // Set saves data in the context.
Set(key string, val interface{}) Set(key string, val interface{})
// Bind binds path params, query params and the request body into provided type `i`. The default binder // Bind binds path params, query params and the request body into provided type `i`. The default binder
// binds body based on Content-Type header. // binds body based on Content-Type header.
Bind(i interface{}) error Bind(i interface{}) error
// Validate validates provided `i`. It is usually called after `Context#Bind()`. // Validate validates provided `i`. It is usually called after `Context#Bind()`.
// Validator must be registered using `Echo#Validator`. // Validator must be registered using `Echo#Validator`.
Validate(i interface{}) error Validate(i interface{}) error
// Render renders a template with data and sends a text/html response with status // Render renders a template with data and sends a text/html response with status
// code. Renderer must be registered using `Echo.Renderer`. // code. Renderer must be registered using `Echo.Renderer`.
Render(code int, name string, data interface{}) error Render(code int, name string, data interface{}) error
// HTML sends an HTTP response with status code. // HTML sends an HTTP response with status code.
HTML(code int, html string) error HTML(code int, html string) error
// HTMLBlob sends an HTTP blob response with status code. // HTMLBlob sends an HTTP blob response with status code.
HTMLBlob(code int, b []byte) error HTMLBlob(code int, b []byte) error
// String sends a string response with status code. // String sends a string response with status code.
String(code int, s string) error String(code int, s string) error
// JSON sends a JSON response with status code. // JSON sends a JSON response with status code.
JSON(code int, i interface{}) error JSON(code int, i interface{}) error
// JSONPretty sends a pretty-print JSON with status code. // JSONPretty sends a pretty-print JSON with status code.
JSONPretty(code int, i interface{}, indent string) error JSONPretty(code int, i interface{}, indent string) error
// JSONBlob sends a JSON blob response with status code. // JSONBlob sends a JSON blob response with status code.
JSONBlob(code int, b []byte) error JSONBlob(code int, b []byte) error
// JSONP sends a JSONP response with status code. It uses `callback` to construct // JSONP sends a JSONP response with status code. It uses `callback` to construct
// the JSONP payload. // the JSONP payload.
JSONP(code int, callback string, i interface{}) error JSONP(code int, callback string, i interface{}) error
// JSONPBlob sends a JSONP blob response with status code. It uses `callback` // JSONPBlob sends a JSONP blob response with status code. It uses `callback`
// to construct the JSONP payload. // to construct the JSONP payload.
JSONPBlob(code int, callback string, b []byte) error JSONPBlob(code int, callback string, b []byte) error
// XML sends an XML response with status code. // XML sends an XML response with status code.
XML(code int, i interface{}) error XML(code int, i interface{}) error
// XMLPretty sends a pretty-print XML with status code. // XMLPretty sends a pretty-print XML with status code.
XMLPretty(code int, i interface{}, indent string) error XMLPretty(code int, i interface{}, indent string) error
// XMLBlob sends an XML blob response with status code. // XMLBlob sends an XML blob response with status code.
XMLBlob(code int, b []byte) error XMLBlob(code int, b []byte) error
// Blob sends a blob response with status code and content type. // Blob sends a blob response with status code and content type.
Blob(code int, contentType string, b []byte) error Blob(code int, contentType string, b []byte) error
// Stream sends a streaming response with status code and content type. // Stream sends a streaming response with status code and content type.
Stream(code int, contentType string, r io.Reader) error Stream(code int, contentType string, r io.Reader) error
// File sends a response with the content of the file. // File sends a response with the content of the file.
File(file string) error File(file string) error
// Attachment sends a response as attachment, prompting client to save the // Attachment sends a response as attachment, prompting client to save the
// file. // file.
Attachment(file string, name string) error Attachment(file string, name string) error
// Inline sends a response as inline, opening the file in the browser. // Inline sends a response as inline, opening the file in the browser.
Inline(file string, name string) error Inline(file string, name string) error
// NoContent sends a response with no body and a status code. // NoContent sends a response with no body and a status code.
NoContent(code int) error NoContent(code int) error
// Redirect redirects the request to a provided URL with status code. // Redirect redirects the request to a provided URL with status code.
Redirect(code int, url string) error Redirect(code int, url string) error
// Error invokes the registered global HTTP error handler. Generally used by middleware. // Error invokes the registered global HTTP error handler. Generally used by middleware.
// A side-effect of calling global error handler is that now Response has been committed (sent to the client) and // A side-effect of calling global error handler is that now Response has been committed (sent to the client) and
// middlewares up in chain can not change Response status code or Response body anymore. // middlewares up in chain can not change Response status code or Response body anymore.
// //
// Avoid using this method in handlers as no middleware will be able to effectively handle errors after that. // Avoid using this method in handlers as no middleware will be able to effectively handle errors after that.
Error(err error) Error(err error)
// Handler returns the matched handler by router. // Handler returns the matched handler by router.
Handler() HandlerFunc Handler() HandlerFunc
// SetHandler sets the matched handler by router. // SetHandler sets the matched handler by router.
SetHandler(h HandlerFunc) SetHandler(h HandlerFunc)
// Logger returns the `Logger` instance. // Logger returns the `Logger` instance.
Logger() Logger Logger() Logger
// SetLogger Set the logger // SetLogger Set the logger
SetLogger(l Logger) SetLogger(l Logger)
// Echo returns the `Echo` instance. // Echo returns the `Echo` instance.
Echo() *Echo Echo() *Echo
// Reset resets the context after request completes. It must be called along // Reset resets the context after request completes. It must be called along
// with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`.
// See `Echo#ServeHTTP()` // See `Echo#ServeHTTP()`
Reset(r *http.Request, w http.ResponseWriter) Reset(r *http.Request, w http.ResponseWriter)
} }
context struct { type context struct {
request *http.Request request *http.Request
response *Response response *Response
path string query url.Values
pnames []string echo *Echo
pvalues []string logger Logger
query url.Values
handler HandlerFunc store Map
store Map lock sync.RWMutex
echo *Echo
logger Logger // following fields are set by Router
lock sync.RWMutex
} // path is route path that Router matched. It is empty string where there is no route match.
) // Route registered with RouteNotFound is considered as a match and path therefore is not empty.
path string
// pnames length is tied to param count for the matched route
pnames []string
// Usually echo.Echo is sizing pvalues but there could be user created middlewares that decide to
// overwrite parameter by calling SetParamNames + SetParamValues.
// When echo.Echo allocated that slice it length/capacity is tied to echo.Echo.maxParam value.
//
// It is important that pvalues size is always equal or bigger to pnames length.
pvalues []string
handler HandlerFunc
}
const ( const (
// ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain. // ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain.
@ -329,13 +344,9 @@ func (c *context) SetParamNames(names ...string) {
c.pnames = names c.pnames = names
l := len(names) l := len(names)
if *c.echo.maxParam < l {
*c.echo.maxParam = l
}
if len(c.pvalues) < l { if len(c.pvalues) < l {
// Keeping the old pvalues just for backward compatibility, but it sounds that doesn't make sense to keep them, // Keeping the old pvalues just for backward compatibility, but it sounds that doesn't make sense to keep them,
// probably those values will be overriden in a Context#SetParamValues // probably those values will be overridden in a Context#SetParamValues
newPvalues := make([]string, l) newPvalues := make([]string, l)
copy(newPvalues, c.pvalues) copy(newPvalues, c.pvalues)
c.pvalues = newPvalues c.pvalues = newPvalues
@ -347,11 +358,11 @@ func (c *context) ParamValues() []string {
} }
func (c *context) SetParamValues(values ...string) { func (c *context) SetParamValues(values ...string) {
// NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam at all times // NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam (or bigger) at all times
// It will brake the Router#Find code // It will brake the Router#Find code
limit := len(values) limit := len(values)
if limit > *c.echo.maxParam { if limit > len(c.pvalues) {
limit = *c.echo.maxParam c.pvalues = make([]string, limit)
} }
for i := 0; i < limit; i++ { for i := 0; i < limit; i++ {
c.pvalues[i] = values[i] c.pvalues[i] = values[i]
@ -489,7 +500,7 @@ func (c *context) jsonPBlob(code int, callback string, i interface{}) (err error
} }
func (c *context) json(code int, i interface{}, indent string) error { func (c *context) json(code int, i interface{}, indent string) error {
c.writeContentType(MIMEApplicationJSONCharsetUTF8) c.writeContentType(MIMEApplicationJSON)
c.response.Status = code c.response.Status = code
return c.echo.JSONSerializer.Serialize(c, i, indent) return c.echo.JSONSerializer.Serialize(c, i, indent)
} }
@ -507,7 +518,7 @@ func (c *context) JSONPretty(code int, i interface{}, indent string) (err error)
} }
func (c *context) JSONBlob(code int, b []byte) (err error) { func (c *context) JSONBlob(code int, b []byte) (err error) {
return c.Blob(code, MIMEApplicationJSONCharsetUTF8, b) return c.Blob(code, MIMEApplicationJSON, b)
} }
func (c *context) JSONP(code int, callback string, i interface{}) (err error) { func (c *context) JSONP(code int, callback string, i interface{}) (err error) {
@ -584,8 +595,10 @@ func (c *context) Inline(file, name string) error {
return c.contentDisposition(file, name, "inline") return c.contentDisposition(file, name, "inline")
} }
var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
func (c *context) contentDisposition(file, name, dispositionType string) error { func (c *context) contentDisposition(file, name, dispositionType string) error {
c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf("%s; filename=%q", dispositionType, name)) c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf(`%s; filename="%s"`, dispositionType, quoteEscaper.Replace(name)))
return c.File(file) return c.File(file)
} }
@ -640,8 +653,8 @@ func (c *context) Reset(r *http.Request, w http.ResponseWriter) {
c.path = "" c.path = ""
c.pnames = nil c.pnames = nil
c.logger = nil c.logger = nil
// NOTE: Don't reset because it has to have length c.echo.maxParam at all times // NOTE: Don't reset because it has to have length c.echo.maxParam (or bigger) at all times
for i := 0; i < *c.echo.maxParam; i++ { for i := 0; i < len(c.pvalues); i++ {
c.pvalues[i] = "" c.pvalues[i] = ""
} }
} }

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo package echo
import ( import (

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
/* /*
Package echo implements high performance, minimalist Go web framework. Package echo implements high performance, minimalist Go web framework.
@ -60,97 +63,95 @@ import (
"golang.org/x/net/http2/h2c" "golang.org/x/net/http2/h2c"
) )
type ( // Echo is the top-level framework instance.
// Echo is the top-level framework instance. //
// // Goroutine safety: Do not mutate Echo instance fields after server has started. Accessing these
// Goroutine safety: Do not mutate Echo instance fields after server has started. Accessing these // fields from handlers/middlewares and changing field values at the same time leads to data-races.
// fields from handlers/middlewares and changing field values at the same time leads to data-races. // Adding new routes after the server has been started is also not safe!
// Adding new routes after the server has been started is also not safe! type Echo struct {
Echo struct { filesystem
filesystem common
common // startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get
// startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get // listener address info (on which interface/port was listener bound) without having data races.
// listener address info (on which interface/port was listener binded) without having data races. startupMutex sync.RWMutex
startupMutex sync.RWMutex colorer *color.Color
colorer *color.Color
// premiddleware are middlewares that are run before routing is done. In case a pre-middleware returns // premiddleware are middlewares that are run before routing is done. In case a pre-middleware returns
// an error the router is not executed and the request will end up in the global error handler. // an error the router is not executed and the request will end up in the global error handler.
premiddleware []MiddlewareFunc premiddleware []MiddlewareFunc
middleware []MiddlewareFunc middleware []MiddlewareFunc
maxParam *int maxParam *int
router *Router router *Router
routers map[string]*Router routers map[string]*Router
pool sync.Pool pool sync.Pool
StdLogger *stdLog.Logger StdLogger *stdLog.Logger
Server *http.Server Server *http.Server
TLSServer *http.Server TLSServer *http.Server
Listener net.Listener Listener net.Listener
TLSListener net.Listener TLSListener net.Listener
AutoTLSManager autocert.Manager AutoTLSManager autocert.Manager
DisableHTTP2 bool DisableHTTP2 bool
Debug bool Debug bool
HideBanner bool HideBanner bool
HidePort bool HidePort bool
HTTPErrorHandler HTTPErrorHandler HTTPErrorHandler HTTPErrorHandler
Binder Binder Binder Binder
JSONSerializer JSONSerializer JSONSerializer JSONSerializer
Validator Validator Validator Validator
Renderer Renderer Renderer Renderer
Logger Logger Logger Logger
IPExtractor IPExtractor IPExtractor IPExtractor
ListenerNetwork string ListenerNetwork string
// OnAddRouteHandler is called when Echo adds new route to specific host router. // OnAddRouteHandler is called when Echo adds new route to specific host router.
OnAddRouteHandler func(host string, route Route, handler HandlerFunc, middleware []MiddlewareFunc) OnAddRouteHandler func(host string, route Route, handler HandlerFunc, middleware []MiddlewareFunc)
} }
// Route contains a handler and information for matching against requests. // Route contains a handler and information for matching against requests.
Route struct { type Route struct {
Method string `json:"method"` Method string `json:"method"`
Path string `json:"path"` Path string `json:"path"`
Name string `json:"name"` Name string `json:"name"`
} }
// HTTPError represents an error that occurred while handling a request. // HTTPError represents an error that occurred while handling a request.
HTTPError struct { type HTTPError struct {
Code int `json:"-"` Code int `json:"-"`
Message interface{} `json:"message"` Message interface{} `json:"message"`
Internal error `json:"-"` // Stores the error returned by an external dependency Internal error `json:"-"` // Stores the error returned by an external dependency
} }
// MiddlewareFunc defines a function to process middleware. // MiddlewareFunc defines a function to process middleware.
MiddlewareFunc func(next HandlerFunc) HandlerFunc type MiddlewareFunc func(next HandlerFunc) HandlerFunc
// HandlerFunc defines a function to serve HTTP requests. // HandlerFunc defines a function to serve HTTP requests.
HandlerFunc func(c Context) error type HandlerFunc func(c Context) error
// HTTPErrorHandler is a centralized HTTP error handler. // HTTPErrorHandler is a centralized HTTP error handler.
HTTPErrorHandler func(err error, c Context) type HTTPErrorHandler func(err error, c Context)
// Validator is the interface that wraps the Validate function. // Validator is the interface that wraps the Validate function.
Validator interface { type Validator interface {
Validate(i interface{}) error Validate(i interface{}) error
} }
// JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. // JSONSerializer is the interface that encodes and decodes JSON to and from interfaces.
JSONSerializer interface { type JSONSerializer interface {
Serialize(c Context, i interface{}, indent string) error Serialize(c Context, i interface{}, indent string) error
Deserialize(c Context, i interface{}) error Deserialize(c Context, i interface{}) error
} }
// Renderer is the interface that wraps the Render function. // Renderer is the interface that wraps the Render function.
Renderer interface { type Renderer interface {
Render(io.Writer, string, interface{}, Context) error Render(io.Writer, string, interface{}, Context) error
} }
// Map defines a generic map of type `map[string]interface{}`. // Map defines a generic map of type `map[string]interface{}`.
Map map[string]interface{} type Map map[string]interface{}
// Common struct for Echo & Group. // Common struct for Echo & Group.
common struct{} type common struct{}
)
// HTTP methods // HTTP methods
// NOTE: Deprecated, please use the stdlib constants directly instead. // NOTE: Deprecated, please use the stdlib constants directly instead.
@ -169,7 +170,12 @@ const (
// MIME types // MIME types
const ( const (
MIMEApplicationJSON = "application/json" // MIMEApplicationJSON JavaScript Object Notation (JSON) https://www.rfc-editor.org/rfc/rfc8259
MIMEApplicationJSON = "application/json"
// Deprecated: Please use MIMEApplicationJSON instead. JSON should be encoded using UTF-8 by default.
// No "charset" parameter is defined for this registration.
// Adding one really has no effect on compliant recipients.
// See RFC 8259, section 8.1. https://datatracker.ietf.org/doc/html/rfc8259#section-8.1
MIMEApplicationJSONCharsetUTF8 = MIMEApplicationJSON + "; " + charsetUTF8 MIMEApplicationJSONCharsetUTF8 = MIMEApplicationJSON + "; " + charsetUTF8
MIMEApplicationJavaScript = "application/javascript" MIMEApplicationJavaScript = "application/javascript"
MIMEApplicationJavaScriptCharsetUTF8 = MIMEApplicationJavaScript + "; " + charsetUTF8 MIMEApplicationJavaScriptCharsetUTF8 = MIMEApplicationJavaScript + "; " + charsetUTF8
@ -259,7 +265,7 @@ const (
const ( const (
// Version of Echo // Version of Echo
Version = "4.11.1" Version = "4.12.0"
website = "https://echo.labstack.com" website = "https://echo.labstack.com"
// http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo
banner = ` banner = `
@ -274,21 +280,19 @@ ____________________________________O/_______
` `
) )
var ( var methods = [...]string{
methods = [...]string{ http.MethodConnect,
http.MethodConnect, http.MethodDelete,
http.MethodDelete, http.MethodGet,
http.MethodGet, http.MethodHead,
http.MethodHead, http.MethodOptions,
http.MethodOptions, http.MethodPatch,
http.MethodPatch, http.MethodPost,
http.MethodPost, PROPFIND,
PROPFIND, http.MethodPut,
http.MethodPut, http.MethodTrace,
http.MethodTrace, REPORT,
REPORT, }
}
)
// Errors // Errors
var ( var (
@ -341,22 +345,23 @@ var (
ErrInvalidListenerNetwork = errors.New("invalid listener network") ErrInvalidListenerNetwork = errors.New("invalid listener network")
) )
// Error handlers // NotFoundHandler is the handler that router uses in case there was no matching route found. Returns an error that results
var ( // HTTP 404 status code.
NotFoundHandler = func(c Context) error { var NotFoundHandler = func(c Context) error {
return ErrNotFound return ErrNotFound
} }
MethodNotAllowedHandler = func(c Context) error { // MethodNotAllowedHandler is the handler thar router uses in case there was no matching route found but there was
// See RFC 7231 section 7.4.1: An origin server MUST generate an Allow field in a 405 (Method Not Allowed) // another matching routes for that requested URL. Returns an error that results HTTP 405 Method Not Allowed status code.
// response and MAY do so in any other response. For disabled resources an empty Allow header may be returned var MethodNotAllowedHandler = func(c Context) error {
routerAllowMethods, ok := c.Get(ContextKeyHeaderAllow).(string) // See RFC 7231 section 7.4.1: An origin server MUST generate an Allow field in a 405 (Method Not Allowed)
if ok && routerAllowMethods != "" { // response and MAY do so in any other response. For disabled resources an empty Allow header may be returned
c.Response().Header().Set(HeaderAllow, routerAllowMethods) routerAllowMethods, ok := c.Get(ContextKeyHeaderAllow).(string)
} if ok && routerAllowMethods != "" {
return ErrMethodNotAllowed c.Response().Header().Set(HeaderAllow, routerAllowMethods)
} }
) return ErrMethodNotAllowed
}
// New creates an instance of Echo. // New creates an instance of Echo.
func New() (e *Echo) { func New() (e *Echo) {
@ -414,7 +419,7 @@ func (e *Echo) Routers() map[string]*Router {
// //
// NOTE: In case errors happens in middleware call-chain that is returning from handler (which did not return an error). // NOTE: In case errors happens in middleware call-chain that is returning from handler (which did not return an error).
// When handler has already sent response (ala c.JSON()) and there is error in middleware that is returning from // When handler has already sent response (ala c.JSON()) and there is error in middleware that is returning from
// handler. Then the error that global error handler received will be ignored because we have already "commited" the // handler. Then the error that global error handler received will be ignored because we have already "committed" the
// response and status code header has been sent to the client. // response and status code header has been sent to the client.
func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) {

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo package echo
import ( import (

View File

@ -1,21 +1,22 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo package echo
import ( import (
"net/http" "net/http"
) )
type ( // Group is a set of sub-routes for a specified route. It can be used for inner
// Group is a set of sub-routes for a specified route. It can be used for inner // routes that share a common middleware or functionality that should be separate
// routes that share a common middleware or functionality that should be separate // from the parent echo instance while still inheriting from it.
// from the parent echo instance while still inheriting from it. type Group struct {
Group struct { common
common host string
host string prefix string
prefix string middleware []MiddlewareFunc
middleware []MiddlewareFunc echo *Echo
echo *Echo }
}
)
// Use implements `Echo#Use()` for sub-routes within the Group. // Use implements `Echo#Use()` for sub-routes within the Group.
func (g *Group) Use(middleware ...MiddlewareFunc) { func (g *Group) Use(middleware ...MiddlewareFunc) {

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo package echo
import ( import (

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo package echo
import ( import (
@ -64,7 +67,7 @@ XFF: "x" "x, a" "x, a, b"
``` ```
In this case, use **first _untrustable_ IP reading from right**. Never use first one reading from left, as it is In this case, use **first _untrustable_ IP reading from right**. Never use first one reading from left, as it is
configurable by client. Here "trustable" means "you are sure the IP address belongs to your infrastructre". configurable by client. Here "trustable" means "you are sure the IP address belongs to your infrastructure".
In above example, if `b` and `c` are trustable, the IP address of the client is `a` for both cases, never be `x`. In above example, if `b` and `c` are trustable, the IP address of the client is `a` for both cases, never be `x`.
In Echo, use `ExtractIPFromXFFHeader(...TrustOption)`. In Echo, use `ExtractIPFromXFFHeader(...TrustOption)`.
@ -225,15 +228,21 @@ func extractIP(req *http.Request) string {
func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor { func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor {
checker := newIPChecker(options) checker := newIPChecker(options)
return func(req *http.Request) string { return func(req *http.Request) string {
directIP := extractIP(req)
realIP := req.Header.Get(HeaderXRealIP) realIP := req.Header.Get(HeaderXRealIP)
if realIP != "" { if realIP == "" {
return directIP
}
if checker.trust(net.ParseIP(directIP)) {
realIP = strings.TrimPrefix(realIP, "[") realIP = strings.TrimPrefix(realIP, "[")
realIP = strings.TrimSuffix(realIP, "]") realIP = strings.TrimSuffix(realIP, "]")
if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) { if rIP := net.ParseIP(realIP); rIP != nil {
return realIP return realIP
} }
} }
return extractIP(req)
return directIP
} }
} }

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo package echo
import ( import (

View File

@ -1,41 +1,41 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo package echo
import ( import (
"io"
"github.com/labstack/gommon/log" "github.com/labstack/gommon/log"
"io"
) )
type ( // Logger defines the logging interface.
// Logger defines the logging interface. type Logger interface {
Logger interface { Output() io.Writer
Output() io.Writer SetOutput(w io.Writer)
SetOutput(w io.Writer) Prefix() string
Prefix() string SetPrefix(p string)
SetPrefix(p string) Level() log.Lvl
Level() log.Lvl SetLevel(v log.Lvl)
SetLevel(v log.Lvl) SetHeader(h string)
SetHeader(h string) Print(i ...interface{})
Print(i ...interface{}) Printf(format string, args ...interface{})
Printf(format string, args ...interface{}) Printj(j log.JSON)
Printj(j log.JSON) Debug(i ...interface{})
Debug(i ...interface{}) Debugf(format string, args ...interface{})
Debugf(format string, args ...interface{}) Debugj(j log.JSON)
Debugj(j log.JSON) Info(i ...interface{})
Info(i ...interface{}) Infof(format string, args ...interface{})
Infof(format string, args ...interface{}) Infoj(j log.JSON)
Infoj(j log.JSON) Warn(i ...interface{})
Warn(i ...interface{}) Warnf(format string, args ...interface{})
Warnf(format string, args ...interface{}) Warnj(j log.JSON)
Warnj(j log.JSON) Error(i ...interface{})
Error(i ...interface{}) Errorf(format string, args ...interface{})
Errorf(format string, args ...interface{}) Errorj(j log.JSON)
Errorj(j log.JSON) Fatal(i ...interface{})
Fatal(i ...interface{}) Fatalj(j log.JSON)
Fatalj(j log.JSON) Fatalf(format string, args ...interface{})
Fatalf(format string, args ...interface{}) Panic(i ...interface{})
Panic(i ...interface{}) Panicj(j log.JSON)
Panicj(j log.JSON) Panicf(format string, args ...interface{})
Panicf(format string, args ...interface{}) }
}
)

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -9,37 +12,35 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // BasicAuthConfig defines the config for BasicAuth middleware.
// BasicAuthConfig defines the config for BasicAuth middleware. type BasicAuthConfig struct {
BasicAuthConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Validator is a function to validate BasicAuth credentials. // Validator is a function to validate BasicAuth credentials.
// Required. // Required.
Validator BasicAuthValidator Validator BasicAuthValidator
// Realm is a string to define realm attribute of BasicAuth. // Realm is a string to define realm attribute of BasicAuth.
// Default value "Restricted". // Default value "Restricted".
Realm string Realm string
} }
// BasicAuthValidator defines a function to validate BasicAuth credentials. // BasicAuthValidator defines a function to validate BasicAuth credentials.
BasicAuthValidator func(string, string, echo.Context) (bool, error) // The function should return a boolean indicating whether the credentials are valid,
) // and an error if any error occurs during the validation process.
type BasicAuthValidator func(string, string, echo.Context) (bool, error)
const ( const (
basic = "basic" basic = "basic"
defaultRealm = "Restricted" defaultRealm = "Restricted"
) )
var ( // DefaultBasicAuthConfig is the default BasicAuth middleware config.
// DefaultBasicAuthConfig is the default BasicAuth middleware config. var DefaultBasicAuthConfig = BasicAuthConfig{
DefaultBasicAuthConfig = BasicAuthConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, Realm: defaultRealm,
Realm: defaultRealm, }
}
)
// BasicAuth returns an BasicAuth middleware. // BasicAuth returns an BasicAuth middleware.
// //

View File

@ -1,8 +1,12 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"errors"
"io" "io"
"net" "net"
"net/http" "net/http"
@ -10,32 +14,28 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // BodyDumpConfig defines the config for BodyDump middleware.
// BodyDumpConfig defines the config for BodyDump middleware. type BodyDumpConfig struct {
BodyDumpConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Handler receives request and response payload. // Handler receives request and response payload.
// Required. // Required.
Handler BodyDumpHandler Handler BodyDumpHandler
} }
// BodyDumpHandler receives the request and response payload. // BodyDumpHandler receives the request and response payload.
BodyDumpHandler func(echo.Context, []byte, []byte) type BodyDumpHandler func(echo.Context, []byte, []byte)
bodyDumpResponseWriter struct { type bodyDumpResponseWriter struct {
io.Writer io.Writer
http.ResponseWriter http.ResponseWriter
} }
)
var ( // DefaultBodyDumpConfig is the default BodyDump middleware config.
// DefaultBodyDumpConfig is the default BodyDump middleware config. var DefaultBodyDumpConfig = BodyDumpConfig{
DefaultBodyDumpConfig = BodyDumpConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, }
}
)
// BodyDump returns a BodyDump middleware. // BodyDump returns a BodyDump middleware.
// //
@ -98,9 +98,16 @@ func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) {
} }
func (w *bodyDumpResponseWriter) Flush() { func (w *bodyDumpResponseWriter) Flush() {
w.ResponseWriter.(http.Flusher).Flush() err := responseControllerFlush(w.ResponseWriter)
if err != nil && errors.Is(err, http.ErrNotSupported) {
panic(errors.New("response writer flushing is not supported"))
}
} }
func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.ResponseWriter.(http.Hijacker).Hijack() return responseControllerHijack(w.ResponseWriter)
}
func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
} }

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -9,32 +12,27 @@ import (
"github.com/labstack/gommon/bytes" "github.com/labstack/gommon/bytes"
) )
type ( // BodyLimitConfig defines the config for BodyLimit middleware.
// BodyLimitConfig defines the config for BodyLimit middleware. type BodyLimitConfig struct {
BodyLimitConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Maximum allowed size for a request body, it can be specified // Maximum allowed size for a request body, it can be specified
// as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P. // as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P.
Limit string `yaml:"limit"` Limit string `yaml:"limit"`
limit int64 limit int64
} }
limitedReader struct { type limitedReader struct {
BodyLimitConfig BodyLimitConfig
reader io.ReadCloser reader io.ReadCloser
read int64 read int64
context echo.Context }
}
)
var ( // DefaultBodyLimitConfig is the default BodyLimit middleware config.
// DefaultBodyLimitConfig is the default BodyLimit middleware config. var DefaultBodyLimitConfig = BodyLimitConfig{
DefaultBodyLimitConfig = BodyLimitConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, }
}
)
// BodyLimit returns a BodyLimit middleware. // BodyLimit returns a BodyLimit middleware.
// //
@ -80,7 +78,7 @@ func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
// Based on content read // Based on content read
r := pool.Get().(*limitedReader) r := pool.Get().(*limitedReader)
r.Reset(req.Body, c) r.Reset(req.Body)
defer pool.Put(r) defer pool.Put(r)
req.Body = r req.Body = r
@ -102,9 +100,8 @@ func (r *limitedReader) Close() error {
return r.reader.Close() return r.reader.Close()
} }
func (r *limitedReader) Reset(reader io.ReadCloser, context echo.Context) { func (r *limitedReader) Reset(reader io.ReadCloser) {
r.reader = reader r.reader = reader
r.context = context
r.read = 0 r.read = 0
} }

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -13,54 +16,50 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // GzipConfig defines the config for Gzip middleware.
// GzipConfig defines the config for Gzip middleware. type GzipConfig struct {
GzipConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Gzip compression level. // Gzip compression level.
// Optional. Default value -1. // Optional. Default value -1.
Level int `yaml:"level"` Level int `yaml:"level"`
// Length threshold before gzip compression is applied. // Length threshold before gzip compression is applied.
// Optional. Default value 0. // Optional. Default value 0.
// //
// Most of the time you will not need to change the default. Compressing // Most of the time you will not need to change the default. Compressing
// a short response might increase the transmitted data because of the // a short response might increase the transmitted data because of the
// gzip format overhead. Compressing the response will also consume CPU // gzip format overhead. Compressing the response will also consume CPU
// and time on the server and the client (for decompressing). Depending on // and time on the server and the client (for decompressing). Depending on
// your use case such a threshold might be useful. // your use case such a threshold might be useful.
// //
// See also: // See also:
// https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits // https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits
MinLength int MinLength int
} }
gzipResponseWriter struct { type gzipResponseWriter struct {
io.Writer io.Writer
http.ResponseWriter http.ResponseWriter
wroteHeader bool wroteHeader bool
wroteBody bool wroteBody bool
minLength int minLength int
minLengthExceeded bool minLengthExceeded bool
buffer *bytes.Buffer buffer *bytes.Buffer
code int code int
} }
)
const ( const (
gzipScheme = "gzip" gzipScheme = "gzip"
) )
var ( // DefaultGzipConfig is the default Gzip middleware config.
// DefaultGzipConfig is the default Gzip middleware config. var DefaultGzipConfig = GzipConfig{
DefaultGzipConfig = GzipConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, Level: -1,
Level: -1, MinLength: 0,
MinLength: 0, }
}
)
// Gzip returns a middleware which compresses HTTP response using gzip compression // Gzip returns a middleware which compresses HTTP response using gzip compression
// scheme. // scheme.
@ -191,13 +190,15 @@ func (w *gzipResponseWriter) Flush() {
} }
w.Writer.(*gzip.Writer).Flush() w.Writer.(*gzip.Writer).Flush()
if flusher, ok := w.ResponseWriter.(http.Flusher); ok { _ = responseControllerFlush(w.ResponseWriter)
flusher.Flush() }
}
func (w *gzipResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
} }
func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.ResponseWriter.(http.Hijacker).Hijack() return responseControllerHijack(w.ResponseWriter)
} }
func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -13,7 +16,7 @@ type ContextTimeoutConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
// ErrorHandler is a function when error aries in middeware execution. // ErrorHandler is a function when error aries in middleware execution.
ErrorHandler func(err error, c echo.Context) error ErrorHandler func(err error, c echo.Context) error
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout // Timeout configures a timeout for the middleware, defaults to 0 for no timeout

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -9,112 +12,109 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // CORSConfig defines the config for CORS middleware.
// CORSConfig defines the config for CORS middleware. type CORSConfig struct {
CORSConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// AllowOrigins determines the value of the Access-Control-Allow-Origin // AllowOrigins determines the value of the Access-Control-Allow-Origin
// response header. This header defines a list of origins that may access the // response header. This header defines a list of origins that may access the
// resource. The wildcard characters '*' and '?' are supported and are // resource. The wildcard characters '*' and '?' are supported and are
// converted to regex fragments '.*' and '.' accordingly. // converted to regex fragments '.*' and '.' accordingly.
// //
// Security: use extreme caution when handling the origin, and carefully // Security: use extreme caution when handling the origin, and carefully
// validate any logic. Remember that attackers may register hostile domain names. // validate any logic. Remember that attackers may register hostile domain names.
// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
// //
// Optional. Default value []string{"*"}. // Optional. Default value []string{"*"}.
// //
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
AllowOrigins []string `yaml:"allow_origins"` AllowOrigins []string `yaml:"allow_origins"`
// AllowOriginFunc is a custom function to validate the origin. It takes the // AllowOriginFunc is a custom function to validate the origin. It takes the
// origin as an argument and returns true if allowed or false otherwise. If // origin as an argument and returns true if allowed or false otherwise. If
// an error is returned, it is returned by the handler. If this option is // an error is returned, it is returned by the handler. If this option is
// set, AllowOrigins is ignored. // set, AllowOrigins is ignored.
// //
// Security: use extreme caution when handling the origin, and carefully // Security: use extreme caution when handling the origin, and carefully
// validate any logic. Remember that attackers may register hostile domain names. // validate any logic. Remember that attackers may register hostile domain names.
// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
// //
// Optional. // Optional.
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"` AllowOriginFunc func(origin string) (bool, error) `yaml:"-"`
// AllowMethods determines the value of the Access-Control-Allow-Methods // AllowMethods determines the value of the Access-Control-Allow-Methods
// response header. This header specified the list of methods allowed when // response header. This header specified the list of methods allowed when
// accessing the resource. This is used in response to a preflight request. // accessing the resource. This is used in response to a preflight request.
// //
// Optional. Default value DefaultCORSConfig.AllowMethods. // Optional. Default value DefaultCORSConfig.AllowMethods.
// If `allowMethods` is left empty, this middleware will fill for preflight // If `allowMethods` is left empty, this middleware will fill for preflight
// request `Access-Control-Allow-Methods` header value // request `Access-Control-Allow-Methods` header value
// from `Allow` header that echo.Router set into context. // from `Allow` header that echo.Router set into context.
// //
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
AllowMethods []string `yaml:"allow_methods"` AllowMethods []string `yaml:"allow_methods"`
// AllowHeaders determines the value of the Access-Control-Allow-Headers // AllowHeaders determines the value of the Access-Control-Allow-Headers
// response header. This header is used in response to a preflight request to // response header. This header is used in response to a preflight request to
// indicate which HTTP headers can be used when making the actual request. // indicate which HTTP headers can be used when making the actual request.
// //
// Optional. Default value []string{}. // Optional. Default value []string{}.
// //
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
AllowHeaders []string `yaml:"allow_headers"` AllowHeaders []string `yaml:"allow_headers"`
// AllowCredentials determines the value of the // AllowCredentials determines the value of the
// Access-Control-Allow-Credentials response header. This header indicates // Access-Control-Allow-Credentials response header. This header indicates
// whether or not the response to the request can be exposed when the // whether or not the response to the request can be exposed when the
// credentials mode (Request.credentials) is true. When used as part of a // credentials mode (Request.credentials) is true. When used as part of a
// response to a preflight request, this indicates whether or not the actual // response to a preflight request, this indicates whether or not the actual
// request can be made using credentials. See also // request can be made using credentials. See also
// [MDN: Access-Control-Allow-Credentials]. // [MDN: Access-Control-Allow-Credentials].
// //
// Optional. Default value false, in which case the header is not set. // Optional. Default value false, in which case the header is not set.
// //
// Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`.
// See "Exploiting CORS misconfigurations for Bitcoins and bounties", // See "Exploiting CORS misconfigurations for Bitcoins and bounties",
// https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html // https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
// //
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
AllowCredentials bool `yaml:"allow_credentials"` AllowCredentials bool `yaml:"allow_credentials"`
// UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials // UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials
// flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header. // flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header.
// //
// This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties) // This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties)
// attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject. // attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject.
// //
// Optional. Default value is false. // Optional. Default value is false.
UnsafeWildcardOriginWithAllowCredentials bool `yaml:"unsafe_wildcard_origin_with_allow_credentials"` UnsafeWildcardOriginWithAllowCredentials bool `yaml:"unsafe_wildcard_origin_with_allow_credentials"`
// ExposeHeaders determines the value of Access-Control-Expose-Headers, which // ExposeHeaders determines the value of Access-Control-Expose-Headers, which
// defines a list of headers that clients are allowed to access. // defines a list of headers that clients are allowed to access.
// //
// Optional. Default value []string{}, in which case the header is not set. // Optional. Default value []string{}, in which case the header is not set.
// //
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header
ExposeHeaders []string `yaml:"expose_headers"` ExposeHeaders []string `yaml:"expose_headers"`
// MaxAge determines the value of the Access-Control-Max-Age response header. // MaxAge determines the value of the Access-Control-Max-Age response header.
// This header indicates how long (in seconds) the results of a preflight // This header indicates how long (in seconds) the results of a preflight
// request can be cached. // request can be cached.
// // The header is set only if MaxAge != 0, negative value sends "0" which instructs browsers not to cache that response.
// Optional. Default value 0. The header is set only if MaxAge > 0. //
// // Optional. Default value 0 - meaning header is not sent.
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age //
MaxAge int `yaml:"max_age"` // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
} MaxAge int `yaml:"max_age"`
) }
var ( // DefaultCORSConfig is the default CORS middleware config.
// DefaultCORSConfig is the default CORS middleware config. var DefaultCORSConfig = CORSConfig{
DefaultCORSConfig = CORSConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, AllowOrigins: []string{"*"},
AllowOrigins: []string{"*"}, AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, }
}
)
// CORS returns a Cross-Origin Resource Sharing (CORS) middleware. // CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
// See also [MDN: Cross-Origin Resource Sharing (CORS)]. // See also [MDN: Cross-Origin Resource Sharing (CORS)].
@ -159,7 +159,11 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
allowMethods := strings.Join(config.AllowMethods, ",") allowMethods := strings.Join(config.AllowMethods, ",")
allowHeaders := strings.Join(config.AllowHeaders, ",") allowHeaders := strings.Join(config.AllowHeaders, ",")
exposeHeaders := strings.Join(config.ExposeHeaders, ",") exposeHeaders := strings.Join(config.ExposeHeaders, ",")
maxAge := strconv.Itoa(config.MaxAge)
maxAge := "0"
if config.MaxAge > 0 {
maxAge = strconv.Itoa(config.MaxAge)
}
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
@ -282,7 +286,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
res.Header().Set(echo.HeaderAccessControlAllowHeaders, h) res.Header().Set(echo.HeaderAccessControlAllowHeaders, h)
} }
} }
if config.MaxAge > 0 { if config.MaxAge != 0 {
res.Header().Set(echo.HeaderAccessControlMaxAge, maxAge) res.Header().Set(echo.HeaderAccessControlMaxAge, maxAge)
} }
return c.NoContent(http.StatusNoContent) return c.NoContent(http.StatusNoContent)

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -6,85 +9,80 @@ import (
"time" "time"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
) )
type ( // CSRFConfig defines the config for CSRF middleware.
// CSRFConfig defines the config for CSRF middleware. type CSRFConfig struct {
CSRFConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// TokenLength is the length of the generated token. // TokenLength is the length of the generated token.
TokenLength uint8 `yaml:"token_length"` TokenLength uint8 `yaml:"token_length"`
// Optional. Default value 32. // Optional. Default value 32.
// TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used // TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
// to extract token from the request. // to extract token from the request.
// Optional. Default value "header:X-CSRF-Token". // Optional. Default value "header:X-CSRF-Token".
// Possible values: // Possible values:
// - "header:<name>" or "header:<name>:<cut-prefix>" // - "header:<name>" or "header:<name>:<cut-prefix>"
// - "query:<name>" // - "query:<name>"
// - "form:<name>" // - "form:<name>"
// Multiple sources example: // Multiple sources example:
// - "header:X-CSRF-Token,query:csrf" // - "header:X-CSRF-Token,query:csrf"
TokenLookup string `yaml:"token_lookup"` TokenLookup string `yaml:"token_lookup"`
// Context key to store generated CSRF token into context. // Context key to store generated CSRF token into context.
// Optional. Default value "csrf". // Optional. Default value "csrf".
ContextKey string `yaml:"context_key"` ContextKey string `yaml:"context_key"`
// Name of the CSRF cookie. This cookie will store CSRF token. // Name of the CSRF cookie. This cookie will store CSRF token.
// Optional. Default value "csrf". // Optional. Default value "csrf".
CookieName string `yaml:"cookie_name"` CookieName string `yaml:"cookie_name"`
// Domain of the CSRF cookie. // Domain of the CSRF cookie.
// Optional. Default value none. // Optional. Default value none.
CookieDomain string `yaml:"cookie_domain"` CookieDomain string `yaml:"cookie_domain"`
// Path of the CSRF cookie. // Path of the CSRF cookie.
// Optional. Default value none. // Optional. Default value none.
CookiePath string `yaml:"cookie_path"` CookiePath string `yaml:"cookie_path"`
// Max age (in seconds) of the CSRF cookie. // Max age (in seconds) of the CSRF cookie.
// Optional. Default value 86400 (24hr). // Optional. Default value 86400 (24hr).
CookieMaxAge int `yaml:"cookie_max_age"` CookieMaxAge int `yaml:"cookie_max_age"`
// Indicates if CSRF cookie is secure. // Indicates if CSRF cookie is secure.
// Optional. Default value false. // Optional. Default value false.
CookieSecure bool `yaml:"cookie_secure"` CookieSecure bool `yaml:"cookie_secure"`
// Indicates if CSRF cookie is HTTP only. // Indicates if CSRF cookie is HTTP only.
// Optional. Default value false. // Optional. Default value false.
CookieHTTPOnly bool `yaml:"cookie_http_only"` CookieHTTPOnly bool `yaml:"cookie_http_only"`
// Indicates SameSite mode of the CSRF cookie. // Indicates SameSite mode of the CSRF cookie.
// Optional. Default value SameSiteDefaultMode. // Optional. Default value SameSiteDefaultMode.
CookieSameSite http.SameSite `yaml:"cookie_same_site"` CookieSameSite http.SameSite `yaml:"cookie_same_site"`
// ErrorHandler defines a function which is executed for returning custom errors. // ErrorHandler defines a function which is executed for returning custom errors.
ErrorHandler CSRFErrorHandler ErrorHandler CSRFErrorHandler
} }
// CSRFErrorHandler is a function which is executed for creating custom errors. // CSRFErrorHandler is a function which is executed for creating custom errors.
CSRFErrorHandler func(err error, c echo.Context) error type CSRFErrorHandler func(err error, c echo.Context) error
)
// ErrCSRFInvalid is returned when CSRF check fails // ErrCSRFInvalid is returned when CSRF check fails
var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
var ( // DefaultCSRFConfig is the default CSRF middleware config.
// DefaultCSRFConfig is the default CSRF middleware config. var DefaultCSRFConfig = CSRFConfig{
DefaultCSRFConfig = CSRFConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, TokenLength: 32,
TokenLength: 32, TokenLookup: "header:" + echo.HeaderXCSRFToken,
TokenLookup: "header:" + echo.HeaderXCSRFToken, ContextKey: "csrf",
ContextKey: "csrf", CookieName: "_csrf",
CookieName: "_csrf", CookieMaxAge: 86400,
CookieMaxAge: 86400, CookieSameSite: http.SameSiteDefaultMode,
CookieSameSite: http.SameSiteDefaultMode, }
}
)
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware. // CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
@ -103,6 +101,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
if config.TokenLength == 0 { if config.TokenLength == 0 {
config.TokenLength = DefaultCSRFConfig.TokenLength config.TokenLength = DefaultCSRFConfig.TokenLength
} }
if config.TokenLookup == "" { if config.TokenLookup == "" {
config.TokenLookup = DefaultCSRFConfig.TokenLookup config.TokenLookup = DefaultCSRFConfig.TokenLookup
} }
@ -132,7 +131,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
token := "" token := ""
if k, err := c.Cookie(config.CookieName); err != nil { if k, err := c.Cookie(config.CookieName); err != nil {
token = random.String(config.TokenLength) // Generate token token = randomString(config.TokenLength)
} else { } else {
token = k.Value // Reuse token token = k.Value // Reuse token
} }

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -9,16 +12,14 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // DecompressConfig defines the config for Decompress middleware.
// DecompressConfig defines the config for Decompress middleware. type DecompressConfig struct {
DecompressConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
GzipDecompressPool Decompressor GzipDecompressPool Decompressor
} }
)
// GZIPEncoding content-encoding header if set to "gzip", decompress body contents. // GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
const GZIPEncoding string = "gzip" const GZIPEncoding string = "gzip"
@ -28,13 +29,11 @@ type Decompressor interface {
gzipDecompressPool() sync.Pool gzipDecompressPool() sync.Pool
} }
var ( // DefaultDecompressConfig defines the config for decompress middleware
//DefaultDecompressConfig defines the config for decompress middleware var DefaultDecompressConfig = DecompressConfig{
DefaultDecompressConfig = DecompressConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, GzipDecompressPool: &DefaultGzipDecompressPool{},
GzipDecompressPool: &DefaultGzipDecompressPool{}, }
}
)
// DefaultGzipDecompressPool is the default implementation of Decompressor interface // DefaultGzipDecompressPool is the default implementation of Decompressor interface
type DefaultGzipDecompressPool struct { type DefaultGzipDecompressPool struct {

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
//go:build go1.15 //go:build go1.15
// +build go1.15 // +build go1.15
@ -12,139 +15,135 @@ import (
"reflect" "reflect"
) )
type ( // JWTConfig defines the config for JWT middleware.
// JWTConfig defines the config for JWT middleware. type JWTConfig struct {
JWTConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// BeforeFunc defines a function which is executed just before the middleware. // BeforeFunc defines a function which is executed just before the middleware.
BeforeFunc BeforeFunc BeforeFunc BeforeFunc
// SuccessHandler defines a function which is executed for a valid token before middleware chain continues with next // SuccessHandler defines a function which is executed for a valid token before middleware chain continues with next
// middleware or handler. // middleware or handler.
SuccessHandler JWTSuccessHandler SuccessHandler JWTSuccessHandler
// ErrorHandler defines a function which is executed for an invalid token. // ErrorHandler defines a function which is executed for an invalid token.
// It may be used to define a custom JWT error. // It may be used to define a custom JWT error.
ErrorHandler JWTErrorHandler ErrorHandler JWTErrorHandler
// ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context. // ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context.
ErrorHandlerWithContext JWTErrorHandlerWithContext ErrorHandlerWithContext JWTErrorHandlerWithContext
// ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandlerWithContext decides to // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandlerWithContext decides to
// ignore the error (by returning `nil`). // ignore the error (by returning `nil`).
// This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality.
// In that case you can use ErrorHandlerWithContext to set a default public JWT token value in the request context // In that case you can use ErrorHandlerWithContext to set a default public JWT token value in the request context
// and continue. Some logic down the remaining execution chain needs to check that (public) token value then. // and continue. Some logic down the remaining execution chain needs to check that (public) token value then.
ContinueOnIgnoredError bool ContinueOnIgnoredError bool
// Signing key to validate token. // Signing key to validate token.
// This is one of the three options to provide a token validation key. // This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither user-defined KeyFunc nor SigningKeys is provided. // Required if neither user-defined KeyFunc nor SigningKeys is provided.
SigningKey interface{} SigningKey interface{}
// Map of signing keys to validate token with kid field usage. // Map of signing keys to validate token with kid field usage.
// This is one of the three options to provide a token validation key. // This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither user-defined KeyFunc nor SigningKey is provided. // Required if neither user-defined KeyFunc nor SigningKey is provided.
SigningKeys map[string]interface{} SigningKeys map[string]interface{}
// Signing method used to check the token's signing algorithm. // Signing method used to check the token's signing algorithm.
// Optional. Default value HS256. // Optional. Default value HS256.
SigningMethod string SigningMethod string
// Context key to store user information from the token into context. // Context key to store user information from the token into context.
// Optional. Default value "user". // Optional. Default value "user".
ContextKey string ContextKey string
// Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation. // Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation.
// Not used if custom ParseTokenFunc is set. // Not used if custom ParseTokenFunc is set.
// Optional. Default value jwt.MapClaims // Optional. Default value jwt.MapClaims
Claims jwt.Claims Claims jwt.Claims
// TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used // TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
// to extract token from the request. // to extract token from the request.
// Optional. Default value "header:Authorization". // Optional. Default value "header:Authorization".
// Possible values: // Possible values:
// - "header:<name>" or "header:<name>:<cut-prefix>" // - "header:<name>" or "header:<name>:<cut-prefix>"
// `<cut-prefix>` is argument value to cut/trim prefix of the extracted value. This is useful if header // `<cut-prefix>` is argument value to cut/trim prefix of the extracted value. This is useful if header
// value has static prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we // value has static prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we
// want to cut is `<auth-scheme> ` note the space at the end. // want to cut is `<auth-scheme> ` note the space at the end.
// In case of JWT tokens `Authorization: Bearer <token>` prefix we cut is `Bearer `. // In case of JWT tokens `Authorization: Bearer <token>` prefix we cut is `Bearer `.
// If prefix is left empty the whole value is returned. // If prefix is left empty the whole value is returned.
// - "query:<name>" // - "query:<name>"
// - "param:<name>" // - "param:<name>"
// - "cookie:<name>" // - "cookie:<name>"
// - "form:<name>" // - "form:<name>"
// Multiple sources example: // Multiple sources example:
// - "header:Authorization,cookie:myowncookie" // - "header:Authorization,cookie:myowncookie"
TokenLookup string TokenLookup string
// TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context. // TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context.
// This is one of the two options to provide a token extractor. // This is one of the two options to provide a token extractor.
// The order of precedence is user-defined TokenLookupFuncs, and TokenLookup. // The order of precedence is user-defined TokenLookupFuncs, and TokenLookup.
// You can also provide both if you want. // You can also provide both if you want.
TokenLookupFuncs []ValuesExtractor TokenLookupFuncs []ValuesExtractor
// AuthScheme to be used in the Authorization header. // AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer". // Optional. Default value "Bearer".
AuthScheme string AuthScheme string
// KeyFunc defines a user-defined function that supplies the public key for a token validation. // KeyFunc defines a user-defined function that supplies the public key for a token validation.
// The function shall take care of verifying the signing algorithm and selecting the proper key. // The function shall take care of verifying the signing algorithm and selecting the proper key.
// A user-defined KeyFunc can be useful if tokens are issued by an external party. // A user-defined KeyFunc can be useful if tokens are issued by an external party.
// Used by default ParseTokenFunc implementation. // Used by default ParseTokenFunc implementation.
// //
// When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored. // When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored.
// This is one of the three options to provide a token validation key. // This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither SigningKeys nor SigningKey is provided. // Required if neither SigningKeys nor SigningKey is provided.
// Not used if custom ParseTokenFunc is set. // Not used if custom ParseTokenFunc is set.
// Default to an internal implementation verifying the signing algorithm and selecting the proper key. // Default to an internal implementation verifying the signing algorithm and selecting the proper key.
KeyFunc jwt.Keyfunc KeyFunc jwt.Keyfunc
// ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token
// parsing fails or parsed token is invalid. // parsing fails or parsed token is invalid.
// Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library // Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library
ParseTokenFunc func(auth string, c echo.Context) (interface{}, error) ParseTokenFunc func(auth string, c echo.Context) (interface{}, error)
} }
// JWTSuccessHandler defines a function which is executed for a valid token. // JWTSuccessHandler defines a function which is executed for a valid token.
JWTSuccessHandler func(c echo.Context) type JWTSuccessHandler func(c echo.Context)
// JWTErrorHandler defines a function which is executed for an invalid token. // JWTErrorHandler defines a function which is executed for an invalid token.
JWTErrorHandler func(err error) error type JWTErrorHandler func(err error) error
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context. // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context.
JWTErrorHandlerWithContext func(err error, c echo.Context) error type JWTErrorHandlerWithContext func(err error, c echo.Context) error
)
// Algorithms // Algorithms
const ( const (
AlgorithmHS256 = "HS256" AlgorithmHS256 = "HS256"
) )
// Errors // ErrJWTMissing is error that is returned when no JWToken was extracted from the request.
var ( var ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt")
ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt")
ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt")
)
var ( // ErrJWTInvalid is error that is returned when middleware could not parse JWT correctly.
// DefaultJWTConfig is the default JWT auth middleware config. var ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt")
DefaultJWTConfig = JWTConfig{
Skipper: DefaultSkipper, // DefaultJWTConfig is the default JWT auth middleware config.
SigningMethod: AlgorithmHS256, var DefaultJWTConfig = JWTConfig{
ContextKey: "user", Skipper: DefaultSkipper,
TokenLookup: "header:" + echo.HeaderAuthorization, SigningMethod: AlgorithmHS256,
TokenLookupFuncs: nil, ContextKey: "user",
AuthScheme: "Bearer", TokenLookup: "header:" + echo.HeaderAuthorization,
Claims: jwt.MapClaims{}, TokenLookupFuncs: nil,
KeyFunc: nil, AuthScheme: "Bearer",
} Claims: jwt.MapClaims{},
) KeyFunc: nil,
}
// JWT returns a JSON Web Token (JWT) auth middleware. // JWT returns a JSON Web Token (JWT) auth middleware.
// //

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -6,69 +9,65 @@ import (
"net/http" "net/http"
) )
type ( // KeyAuthConfig defines the config for KeyAuth middleware.
// KeyAuthConfig defines the config for KeyAuth middleware. type KeyAuthConfig struct {
KeyAuthConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// KeyLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used // KeyLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
// to extract key from the request. // to extract key from the request.
// Optional. Default value "header:Authorization". // Optional. Default value "header:Authorization".
// Possible values: // Possible values:
// - "header:<name>" or "header:<name>:<cut-prefix>" // - "header:<name>" or "header:<name>:<cut-prefix>"
// `<cut-prefix>` is argument value to cut/trim prefix of the extracted value. This is useful if header // `<cut-prefix>` is argument value to cut/trim prefix of the extracted value. This is useful if header
// value has static prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we // value has static prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we
// want to cut is `<auth-scheme> ` note the space at the end. // want to cut is `<auth-scheme> ` note the space at the end.
// In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove is `Basic `. // In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove is `Basic `.
// - "query:<name>" // - "query:<name>"
// - "form:<name>" // - "form:<name>"
// - "cookie:<name>" // - "cookie:<name>"
// Multiple sources example: // Multiple sources example:
// - "header:Authorization,header:X-Api-Key" // - "header:Authorization,header:X-Api-Key"
KeyLookup string KeyLookup string
// AuthScheme to be used in the Authorization header. // AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer". // Optional. Default value "Bearer".
AuthScheme string AuthScheme string
// Validator is a function to validate key. // Validator is a function to validate key.
// Required. // Required.
Validator KeyAuthValidator Validator KeyAuthValidator
// ErrorHandler defines a function which is executed for an invalid key. // ErrorHandler defines a function which is executed for an invalid key.
// It may be used to define a custom error. // It may be used to define a custom error.
ErrorHandler KeyAuthErrorHandler ErrorHandler KeyAuthErrorHandler
// ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to
// ignore the error (by returning `nil`). // ignore the error (by returning `nil`).
// This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality.
// In that case you can use ErrorHandler to set a default public key auth value in the request context // In that case you can use ErrorHandler to set a default public key auth value in the request context
// and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then. // and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then.
ContinueOnIgnoredError bool ContinueOnIgnoredError bool
} }
// KeyAuthValidator defines a function to validate KeyAuth credentials. // KeyAuthValidator defines a function to validate KeyAuth credentials.
KeyAuthValidator func(auth string, c echo.Context) (bool, error) type KeyAuthValidator func(auth string, c echo.Context) (bool, error)
// KeyAuthErrorHandler defines a function which is executed for an invalid key. // KeyAuthErrorHandler defines a function which is executed for an invalid key.
KeyAuthErrorHandler func(err error, c echo.Context) error type KeyAuthErrorHandler func(err error, c echo.Context) error
)
var (
// DefaultKeyAuthConfig is the default KeyAuth middleware config.
DefaultKeyAuthConfig = KeyAuthConfig{
Skipper: DefaultSkipper,
KeyLookup: "header:" + echo.HeaderAuthorization,
AuthScheme: "Bearer",
}
)
// ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups // ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups
type ErrKeyAuthMissing struct { type ErrKeyAuthMissing struct {
Err error Err error
} }
// DefaultKeyAuthConfig is the default KeyAuth middleware config.
var DefaultKeyAuthConfig = KeyAuthConfig{
Skipper: DefaultSkipper,
KeyLookup: "header:" + echo.HeaderAuthorization,
AuthScheme: "Bearer",
}
// Error returns errors text // Error returns errors text
func (e *ErrKeyAuthMissing) Error() string { func (e *ErrKeyAuthMissing) Error() string {
return e.Err.Error() return e.Err.Error()

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -14,77 +17,73 @@ import (
"github.com/valyala/fasttemplate" "github.com/valyala/fasttemplate"
) )
type ( // LoggerConfig defines the config for Logger middleware.
// LoggerConfig defines the config for Logger middleware. type LoggerConfig struct {
LoggerConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Tags to construct the logger format. // Tags to construct the logger format.
// //
// - time_unix // - time_unix
// - time_unix_milli // - time_unix_milli
// - time_unix_micro // - time_unix_micro
// - time_unix_nano // - time_unix_nano
// - time_rfc3339 // - time_rfc3339
// - time_rfc3339_nano // - time_rfc3339_nano
// - time_custom // - time_custom
// - id (Request ID) // - id (Request ID)
// - remote_ip // - remote_ip
// - uri // - uri
// - host // - host
// - method // - method
// - path // - path
// - route // - route
// - protocol // - protocol
// - referer // - referer
// - user_agent // - user_agent
// - status // - status
// - error // - error
// - latency (In nanoseconds) // - latency (In nanoseconds)
// - latency_human (Human readable) // - latency_human (Human readable)
// - bytes_in (Bytes received) // - bytes_in (Bytes received)
// - bytes_out (Bytes sent) // - bytes_out (Bytes sent)
// - header:<NAME> // - header:<NAME>
// - query:<NAME> // - query:<NAME>
// - form:<NAME> // - form:<NAME>
// - custom (see CustomTagFunc field) // - custom (see CustomTagFunc field)
// //
// Example "${remote_ip} ${status}" // Example "${remote_ip} ${status}"
// //
// Optional. Default value DefaultLoggerConfig.Format. // Optional. Default value DefaultLoggerConfig.Format.
Format string `yaml:"format"` Format string `yaml:"format"`
// Optional. Default value DefaultLoggerConfig.CustomTimeFormat. // Optional. Default value DefaultLoggerConfig.CustomTimeFormat.
CustomTimeFormat string `yaml:"custom_time_format"` CustomTimeFormat string `yaml:"custom_time_format"`
// CustomTagFunc is function called for `${custom}` tag to output user implemented text by writing it to buf. // CustomTagFunc is function called for `${custom}` tag to output user implemented text by writing it to buf.
// Make sure that outputted text creates valid JSON string with other logged tags. // Make sure that outputted text creates valid JSON string with other logged tags.
// Optional. // Optional.
CustomTagFunc func(c echo.Context, buf *bytes.Buffer) (int, error) CustomTagFunc func(c echo.Context, buf *bytes.Buffer) (int, error)
// Output is a writer where logs in JSON format are written. // Output is a writer where logs in JSON format are written.
// Optional. Default value os.Stdout. // Optional. Default value os.Stdout.
Output io.Writer Output io.Writer
template *fasttemplate.Template template *fasttemplate.Template
colorer *color.Color colorer *color.Color
pool *sync.Pool pool *sync.Pool
} }
)
var ( // DefaultLoggerConfig is the default Logger middleware config.
// DefaultLoggerConfig is the default Logger middleware config. var DefaultLoggerConfig = LoggerConfig{
DefaultLoggerConfig = LoggerConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` +
Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` + `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` +
`"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` + `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` +
`"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` + `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n",
`,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", CustomTimeFormat: "2006-01-02 15:04:05.00000",
CustomTimeFormat: "2006-01-02 15:04:05.00000", colorer: color.New(),
colorer: color.New(), }
}
)
// Logger returns a middleware that logs HTTP requests. // Logger returns a middleware that logs HTTP requests.
func Logger() echo.MiddlewareFunc { func Logger() echo.MiddlewareFunc {

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -6,28 +9,24 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // MethodOverrideConfig defines the config for MethodOverride middleware.
// MethodOverrideConfig defines the config for MethodOverride middleware. type MethodOverrideConfig struct {
MethodOverrideConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Getter is a function that gets overridden method from the request. // Getter is a function that gets overridden method from the request.
// Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride). // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride).
Getter MethodOverrideGetter Getter MethodOverrideGetter
} }
// MethodOverrideGetter is a function that gets overridden method from the request // MethodOverrideGetter is a function that gets overridden method from the request
MethodOverrideGetter func(echo.Context) string type MethodOverrideGetter func(echo.Context) string
)
var ( // DefaultMethodOverrideConfig is the default MethodOverride middleware config.
// DefaultMethodOverrideConfig is the default MethodOverride middleware config. var DefaultMethodOverrideConfig = MethodOverrideConfig{
DefaultMethodOverrideConfig = MethodOverrideConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride),
Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), }
}
)
// MethodOverride returns a MethodOverride middleware. // MethodOverride returns a MethodOverride middleware.
// MethodOverride middleware checks for the overridden method from the request and // MethodOverride middleware checks for the overridden method from the request and

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -9,14 +12,12 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // Skipper defines a function to skip middleware. Returning true skips processing
// Skipper defines a function to skip middleware. Returning true skips processing // the middleware.
// the middleware. type Skipper func(c echo.Context) bool
Skipper func(c echo.Context) bool
// BeforeFunc defines a function which is executed just before the middleware. // BeforeFunc defines a function which is executed just before the middleware.
BeforeFunc func(c echo.Context) type BeforeFunc func(c echo.Context)
)
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
groups := pattern.FindAllStringSubmatch(input, -1) groups := pattern.FindAllStringSubmatch(input, -1)
@ -53,7 +54,7 @@ func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error
return nil return nil
} }
// Depending how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path. // Depending on how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path.
// We only want to use path part for rewriting and therefore trim prefix if it exists // We only want to use path part for rewriting and therefore trim prefix if it exists
rawURI := req.RequestURI rawURI := req.RequestURI
if rawURI != "" && rawURI[0] != '/' { if rawURI != "" && rawURI[0] != '/' {

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -19,117 +22,113 @@ import (
// TODO: Handle TLS proxy // TODO: Handle TLS proxy
type ( // ProxyConfig defines the config for Proxy middleware.
// ProxyConfig defines the config for Proxy middleware. type ProxyConfig struct {
ProxyConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Balancer defines a load balancing technique. // Balancer defines a load balancing technique.
// Required. // Required.
Balancer ProxyBalancer Balancer ProxyBalancer
// RetryCount defines the number of times a failed proxied request should be retried // RetryCount defines the number of times a failed proxied request should be retried
// using the next available ProxyTarget. Defaults to 0, meaning requests are never retried. // using the next available ProxyTarget. Defaults to 0, meaning requests are never retried.
RetryCount int RetryCount int
// RetryFilter defines a function used to determine if a failed request to a // RetryFilter defines a function used to determine if a failed request to a
// ProxyTarget should be retried. The RetryFilter will only be called when the number // ProxyTarget should be retried. The RetryFilter will only be called when the number
// of previous retries is less than RetryCount. If the function returns true, the // of previous retries is less than RetryCount. If the function returns true, the
// request will be retried. The provided error indicates the reason for the request // request will be retried. The provided error indicates the reason for the request
// failure. When the ProxyTarget is unavailable, the error will be an instance of // failure. When the ProxyTarget is unavailable, the error will be an instance of
// echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error // echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error
// will indicate an internal error in the Proxy middleware. When a RetryFilter is not // will indicate an internal error in the Proxy middleware. When a RetryFilter is not
// specified, all requests that fail with http.StatusBadGateway will be retried. A custom // specified, all requests that fail with http.StatusBadGateway will be retried. A custom
// RetryFilter can be provided to only retry specific requests. Note that RetryFilter is // RetryFilter can be provided to only retry specific requests. Note that RetryFilter is
// only called when the request to the target fails, or an internal error in the Proxy // only called when the request to the target fails, or an internal error in the Proxy
// middleware has occurred. Successful requests that return a non-200 response code cannot // middleware has occurred. Successful requests that return a non-200 response code cannot
// be retried. // be retried.
RetryFilter func(c echo.Context, e error) bool RetryFilter func(c echo.Context, e error) bool
// ErrorHandler defines a function which can be used to return custom errors from // ErrorHandler defines a function which can be used to return custom errors from
// the Proxy middleware. ErrorHandler is only invoked when there has been // the Proxy middleware. ErrorHandler is only invoked when there has been
// either an internal error in the Proxy middleware or the ProxyTarget is // either an internal error in the Proxy middleware or the ProxyTarget is
// unavailable. Due to the way requests are proxied, ErrorHandler is not invoked // unavailable. Due to the way requests are proxied, ErrorHandler is not invoked
// when a ProxyTarget returns a non-200 response. In these cases, the response // when a ProxyTarget returns a non-200 response. In these cases, the response
// is already written so errors cannot be modified. ErrorHandler is only // is already written so errors cannot be modified. ErrorHandler is only
// invoked after all retry attempts have been exhausted. // invoked after all retry attempts have been exhausted.
ErrorHandler func(c echo.Context, err error) error ErrorHandler func(c echo.Context, err error) error
// Rewrite defines URL path rewrite rules. The values captured in asterisk can be // Rewrite defines URL path rewrite rules. The values captured in asterisk can be
// retrieved by index e.g. $1, $2 and so on. // retrieved by index e.g. $1, $2 and so on.
// Examples: // Examples:
// "/old": "/new", // "/old": "/new",
// "/api/*": "/$1", // "/api/*": "/$1",
// "/js/*": "/public/javascripts/$1", // "/js/*": "/public/javascripts/$1",
// "/users/*/orders/*": "/user/$1/order/$2", // "/users/*/orders/*": "/user/$1/order/$2",
Rewrite map[string]string Rewrite map[string]string
// RegexRewrite defines rewrite rules using regexp.Rexexp with captures // RegexRewrite defines rewrite rules using regexp.Rexexp with captures
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
// Example: // Example:
// "^/old/[0.9]+/": "/new", // "^/old/[0.9]+/": "/new",
// "^/api/.+?/(.*)": "/v2/$1", // "^/api/.+?/(.*)": "/v2/$1",
RegexRewrite map[*regexp.Regexp]string RegexRewrite map[*regexp.Regexp]string
// Context key to store selected ProxyTarget into context. // Context key to store selected ProxyTarget into context.
// Optional. Default value "target". // Optional. Default value "target".
ContextKey string ContextKey string
// To customize the transport to remote. // To customize the transport to remote.
// Examples: If custom TLS certificates are required. // Examples: If custom TLS certificates are required.
Transport http.RoundTripper Transport http.RoundTripper
// ModifyResponse defines function to modify response from ProxyTarget. // ModifyResponse defines function to modify response from ProxyTarget.
ModifyResponse func(*http.Response) error ModifyResponse func(*http.Response) error
} }
// ProxyTarget defines the upstream target. // ProxyTarget defines the upstream target.
ProxyTarget struct { type ProxyTarget struct {
Name string Name string
URL *url.URL URL *url.URL
Meta echo.Map Meta echo.Map
} }
// ProxyBalancer defines an interface to implement a load balancing technique. // ProxyBalancer defines an interface to implement a load balancing technique.
ProxyBalancer interface { type ProxyBalancer interface {
AddTarget(*ProxyTarget) bool AddTarget(*ProxyTarget) bool
RemoveTarget(string) bool RemoveTarget(string) bool
Next(echo.Context) *ProxyTarget Next(echo.Context) *ProxyTarget
} }
// TargetProvider defines an interface that gives the opportunity for balancer // TargetProvider defines an interface that gives the opportunity for balancer
// to return custom errors when selecting target. // to return custom errors when selecting target.
TargetProvider interface { type TargetProvider interface {
NextTarget(echo.Context) (*ProxyTarget, error) NextTarget(echo.Context) (*ProxyTarget, error)
} }
commonBalancer struct { type commonBalancer struct {
targets []*ProxyTarget targets []*ProxyTarget
mutex sync.Mutex mutex sync.Mutex
} }
// RandomBalancer implements a random load balancing technique. // RandomBalancer implements a random load balancing technique.
randomBalancer struct { type randomBalancer struct {
commonBalancer commonBalancer
random *rand.Rand random *rand.Rand
} }
// RoundRobinBalancer implements a round-robin load balancing technique. // RoundRobinBalancer implements a round-robin load balancing technique.
roundRobinBalancer struct { type roundRobinBalancer struct {
commonBalancer commonBalancer
// tracking the index on `targets` slice for the next `*ProxyTarget` to be used // tracking the index on `targets` slice for the next `*ProxyTarget` to be used
i int i int
} }
)
var ( // DefaultProxyConfig is the default Proxy middleware config.
// DefaultProxyConfig is the default Proxy middleware config. var DefaultProxyConfig = ProxyConfig{
DefaultProxyConfig = ProxyConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, ContextKey: "target",
ContextKey: "target", }
}
)
func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -359,12 +358,15 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
c.Set("_error", nil) c.Set("_error", nil)
} }
// This is needed for ProxyConfig.ModifyResponse and/or ProxyConfig.Transport to be able to process the Request
// that Balancer may have replaced with c.SetRequest.
req = c.Request()
// Proxy // Proxy
switch { switch {
case c.IsWebSocket(): case c.IsWebSocket():
proxyRaw(tgt, c).ServeHTTP(res, req) proxyRaw(tgt, c).ServeHTTP(res, req)
case req.Header.Get(echo.HeaderAccept) == "text/event-stream": default: // even SSE requests
default:
proxyHTTP(tgt, c, config).ServeHTTP(res, req) proxyHTTP(tgt, c, config).ServeHTTP(res, req)
} }

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -9,39 +12,34 @@ import (
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
type ( // RateLimiterStore is the interface to be implemented by custom stores.
// RateLimiterStore is the interface to be implemented by custom stores. type RateLimiterStore interface {
RateLimiterStore interface { // Stores for the rate limiter have to implement the Allow method
// Stores for the rate limiter have to implement the Allow method Allow(identifier string) (bool, error)
Allow(identifier string) (bool, error) }
}
)
type ( // RateLimiterConfig defines the configuration for the rate limiter
// RateLimiterConfig defines the configuration for the rate limiter type RateLimiterConfig struct {
RateLimiterConfig struct { Skipper Skipper
Skipper Skipper BeforeFunc BeforeFunc
BeforeFunc BeforeFunc // IdentifierExtractor uses echo.Context to extract the identifier for a visitor
// IdentifierExtractor uses echo.Context to extract the identifier for a visitor IdentifierExtractor Extractor
IdentifierExtractor Extractor // Store defines a store for the rate limiter
// Store defines a store for the rate limiter Store RateLimiterStore
Store RateLimiterStore // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error
// ErrorHandler provides a handler to be called when IdentifierExtractor returns an error ErrorHandler func(context echo.Context, err error) error
ErrorHandler func(context echo.Context, err error) error // DenyHandler provides a handler to be called when RateLimiter denies access
// DenyHandler provides a handler to be called when RateLimiter denies access DenyHandler func(context echo.Context, identifier string, err error) error
DenyHandler func(context echo.Context, identifier string, err error) error }
}
// Extractor is used to extract data from echo.Context
Extractor func(context echo.Context) (string, error)
)
// errors // Extractor is used to extract data from echo.Context
var ( type Extractor func(context echo.Context) (string, error)
// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded
ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") // ErrRateLimitExceeded denotes an error raised when rate limit is exceeded
// ErrExtractorError denotes an error raised when extractor function is unsuccessful var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier")
) // ErrExtractorError denotes an error raised when extractor function is unsuccessful
var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier")
// DefaultRateLimiterConfig defines default values for RateLimiterConfig // DefaultRateLimiterConfig defines default values for RateLimiterConfig
var DefaultRateLimiterConfig = RateLimiterConfig{ var DefaultRateLimiterConfig = RateLimiterConfig{
@ -150,25 +148,24 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
} }
} }
type ( // RateLimiterMemoryStore is the built-in store implementation for RateLimiter
// RateLimiterMemoryStore is the built-in store implementation for RateLimiter type RateLimiterMemoryStore struct {
RateLimiterMemoryStore struct { visitors map[string]*Visitor
visitors map[string]*Visitor mutex sync.Mutex
mutex sync.Mutex rate rate.Limit // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.
rate rate.Limit // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.
burst int burst int
expiresIn time.Duration expiresIn time.Duration
lastCleanup time.Time lastCleanup time.Time
timeNow func() time.Time timeNow func() time.Time
} }
// Visitor signifies a unique user's limiter details
Visitor struct { // Visitor signifies a unique user's limiter details
*rate.Limiter type Visitor struct {
lastSeen time.Time *rate.Limiter
} lastSeen time.Time
) }
/* /*
NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -9,56 +12,52 @@ import (
"github.com/labstack/gommon/log" "github.com/labstack/gommon/log"
) )
type ( // LogErrorFunc defines a function for custom logging in the middleware.
type LogErrorFunc func(c echo.Context, err error, stack []byte) error
// RecoverConfig defines the config for Recover middleware.
type RecoverConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Size of the stack to be printed.
// Optional. Default value 4KB.
StackSize int `yaml:"stack_size"`
// DisableStackAll disables formatting stack traces of all other goroutines
// into buffer after the trace for the current goroutine.
// Optional. Default value false.
DisableStackAll bool `yaml:"disable_stack_all"`
// DisablePrintStack disables printing stack trace.
// Optional. Default value as false.
DisablePrintStack bool `yaml:"disable_print_stack"`
// LogLevel is log level to printing stack trace.
// Optional. Default value 0 (Print).
LogLevel log.Lvl
// LogErrorFunc defines a function for custom logging in the middleware. // LogErrorFunc defines a function for custom logging in the middleware.
LogErrorFunc func(c echo.Context, err error, stack []byte) error // If it's set you don't need to provide LogLevel for config.
// If this function returns nil, the centralized HTTPErrorHandler will not be called.
LogErrorFunc LogErrorFunc
// RecoverConfig defines the config for Recover middleware. // DisableErrorHandler disables the call to centralized HTTPErrorHandler.
RecoverConfig struct { // The recovered error is then passed back to upstream middleware, instead of swallowing the error.
// Skipper defines a function to skip middleware. // Optional. Default value false.
Skipper Skipper DisableErrorHandler bool `yaml:"disable_error_handler"`
}
// Size of the stack to be printed. // DefaultRecoverConfig is the default Recover middleware config.
// Optional. Default value 4KB. var DefaultRecoverConfig = RecoverConfig{
StackSize int `yaml:"stack_size"` Skipper: DefaultSkipper,
StackSize: 4 << 10, // 4 KB
// DisableStackAll disables formatting stack traces of all other goroutines DisableStackAll: false,
// into buffer after the trace for the current goroutine. DisablePrintStack: false,
// Optional. Default value false. LogLevel: 0,
DisableStackAll bool `yaml:"disable_stack_all"` LogErrorFunc: nil,
DisableErrorHandler: false,
// DisablePrintStack disables printing stack trace. }
// Optional. Default value as false.
DisablePrintStack bool `yaml:"disable_print_stack"`
// LogLevel is log level to printing stack trace.
// Optional. Default value 0 (Print).
LogLevel log.Lvl
// LogErrorFunc defines a function for custom logging in the middleware.
// If it's set you don't need to provide LogLevel for config.
// If this function returns nil, the centralized HTTPErrorHandler will not be called.
LogErrorFunc LogErrorFunc
// DisableErrorHandler disables the call to centralized HTTPErrorHandler.
// The recovered error is then passed back to upstream middleware, instead of swallowing the error.
// Optional. Default value false.
DisableErrorHandler bool `yaml:"disable_error_handler"`
}
)
var (
// DefaultRecoverConfig is the default Recover middleware config.
DefaultRecoverConfig = RecoverConfig{
Skipper: DefaultSkipper,
StackSize: 4 << 10, // 4 KB
DisableStackAll: false,
DisablePrintStack: false,
LogLevel: 0,
LogErrorFunc: nil,
DisableErrorHandler: false,
}
)
// Recover returns a middleware which recovers from panics anywhere in the chain // Recover returns a middleware which recovers from panics anywhere in the chain
// and handles the control to the centralized HTTPErrorHandler. // and handles the control to the centralized HTTPErrorHandler.

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (

View File

@ -1,36 +1,34 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
) )
type ( // RequestIDConfig defines the config for RequestID middleware.
// RequestIDConfig defines the config for RequestID middleware. type RequestIDConfig struct {
RequestIDConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Generator defines a function to generate an ID. // Generator defines a function to generate an ID.
// Optional. Default value random.String(32). // Optional. Defaults to generator for random string of length 32.
Generator func() string Generator func() string
// RequestIDHandler defines a function which is executed for a request id. // RequestIDHandler defines a function which is executed for a request id.
RequestIDHandler func(echo.Context, string) RequestIDHandler func(echo.Context, string)
// TargetHeader defines what header to look for to populate the id // TargetHeader defines what header to look for to populate the id
TargetHeader string TargetHeader string
} }
)
var ( // DefaultRequestIDConfig is the default RequestID middleware config.
// DefaultRequestIDConfig is the default RequestID middleware config. var DefaultRequestIDConfig = RequestIDConfig{
DefaultRequestIDConfig = RequestIDConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, Generator: generator,
Generator: generator, TargetHeader: echo.HeaderXRequestID,
TargetHeader: echo.HeaderXRequestID, }
}
)
// RequestID returns a X-Request-ID middleware. // RequestID returns a X-Request-ID middleware.
func RequestID() echo.MiddlewareFunc { func RequestID() echo.MiddlewareFunc {
@ -73,5 +71,5 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
} }
func generator() string { func generator() string {
return random.String(32) return randomString(32)
} }

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -8,6 +11,30 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
// Example for `slog` https://pkg.go.dev/log/slog
// logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
// LogStatus: true,
// LogURI: true,
// LogError: true,
// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
// if v.Error == nil {
// logger.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST",
// slog.String("uri", v.URI),
// slog.Int("status", v.Status),
// )
// } else {
// logger.LogAttrs(context.Background(), slog.LevelError, "REQUEST_ERROR",
// slog.String("uri", v.URI),
// slog.Int("status", v.Status),
// slog.String("err", v.Error.Error()),
// )
// }
// return nil
// },
// }))
//
// Example for `fmt.Printf` // Example for `fmt.Printf`
// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
// LogStatus: true, // LogStatus: true,

View File

@ -0,0 +1,44 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
//go:build !go1.20
package middleware
import (
"bufio"
"fmt"
"net"
"net/http"
)
// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
func responseControllerFlush(rw http.ResponseWriter) error {
for {
switch t := rw.(type) {
case interface{ FlushError() error }:
return t.FlushError()
case http.Flusher:
t.Flush()
return nil
case interface{ Unwrap() http.ResponseWriter }:
rw = t.Unwrap()
default:
return fmt.Errorf("%w", http.ErrNotSupported)
}
}
}
// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
for {
switch t := rw.(type) {
case http.Hijacker:
return t.Hijack()
case interface{ Unwrap() http.ResponseWriter }:
rw = t.Unwrap()
default:
return nil, nil, fmt.Errorf("%w", http.ErrNotSupported)
}
}
}

View File

@ -0,0 +1,20 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
//go:build go1.20
package middleware
import (
"bufio"
"net"
"net/http"
)
func responseControllerFlush(rw http.ResponseWriter) error {
return http.NewResponseController(rw).Flush()
}
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
return http.NewResponseController(rw).Hijack()
}

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -6,37 +9,33 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // RewriteConfig defines the config for Rewrite middleware.
// RewriteConfig defines the config for Rewrite middleware. type RewriteConfig struct {
RewriteConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Rules defines the URL path rewrite rules. The values captured in asterisk can be // Rules defines the URL path rewrite rules. The values captured in asterisk can be
// retrieved by index e.g. $1, $2 and so on. // retrieved by index e.g. $1, $2 and so on.
// Example: // Example:
// "/old": "/new", // "/old": "/new",
// "/api/*": "/$1", // "/api/*": "/$1",
// "/js/*": "/public/javascripts/$1", // "/js/*": "/public/javascripts/$1",
// "/users/*/orders/*": "/user/$1/order/$2", // "/users/*/orders/*": "/user/$1/order/$2",
// Required. // Required.
Rules map[string]string `yaml:"rules"` Rules map[string]string `yaml:"rules"`
// RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
// Example: // Example:
// "^/old/[0.9]+/": "/new", // "^/old/[0.9]+/": "/new",
// "^/api/.+?/(.*)": "/v2/$1", // "^/api/.+?/(.*)": "/v2/$1",
RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"` RegexRules map[*regexp.Regexp]string `yaml:"-"`
} }
)
var ( // DefaultRewriteConfig is the default Rewrite middleware config.
// DefaultRewriteConfig is the default Rewrite middleware config. var DefaultRewriteConfig = RewriteConfig{
DefaultRewriteConfig = RewriteConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, }
}
)
// Rewrite returns a Rewrite middleware. // Rewrite returns a Rewrite middleware.
// //

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -6,84 +9,80 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // SecureConfig defines the config for Secure middleware.
// SecureConfig defines the config for Secure middleware. type SecureConfig struct {
SecureConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// XSSProtection provides protection against cross-site scripting attack (XSS) // XSSProtection provides protection against cross-site scripting attack (XSS)
// by setting the `X-XSS-Protection` header. // by setting the `X-XSS-Protection` header.
// Optional. Default value "1; mode=block". // Optional. Default value "1; mode=block".
XSSProtection string `yaml:"xss_protection"` XSSProtection string `yaml:"xss_protection"`
// ContentTypeNosniff provides protection against overriding Content-Type // ContentTypeNosniff provides protection against overriding Content-Type
// header by setting the `X-Content-Type-Options` header. // header by setting the `X-Content-Type-Options` header.
// Optional. Default value "nosniff". // Optional. Default value "nosniff".
ContentTypeNosniff string `yaml:"content_type_nosniff"` ContentTypeNosniff string `yaml:"content_type_nosniff"`
// XFrameOptions can be used to indicate whether or not a browser should // XFrameOptions can be used to indicate whether or not a browser should
// be allowed to render a page in a <frame>, <iframe> or <object> . // be allowed to render a page in a <frame>, <iframe> or <object> .
// Sites can use this to avoid clickjacking attacks, by ensuring that their // Sites can use this to avoid clickjacking attacks, by ensuring that their
// content is not embedded into other sites.provides protection against // content is not embedded into other sites.provides protection against
// clickjacking. // clickjacking.
// Optional. Default value "SAMEORIGIN". // Optional. Default value "SAMEORIGIN".
// Possible values: // Possible values:
// - "SAMEORIGIN" - The page can only be displayed in a frame on the same origin as the page itself. // - "SAMEORIGIN" - The page can only be displayed in a frame on the same origin as the page itself.
// - "DENY" - The page cannot be displayed in a frame, regardless of the site attempting to do so. // - "DENY" - The page cannot be displayed in a frame, regardless of the site attempting to do so.
// - "ALLOW-FROM uri" - The page can only be displayed in a frame on the specified origin. // - "ALLOW-FROM uri" - The page can only be displayed in a frame on the specified origin.
XFrameOptions string `yaml:"x_frame_options"` XFrameOptions string `yaml:"x_frame_options"`
// HSTSMaxAge sets the `Strict-Transport-Security` header to indicate how // HSTSMaxAge sets the `Strict-Transport-Security` header to indicate how
// long (in seconds) browsers should remember that this site is only to // long (in seconds) browsers should remember that this site is only to
// be accessed using HTTPS. This reduces your exposure to some SSL-stripping // be accessed using HTTPS. This reduces your exposure to some SSL-stripping
// man-in-the-middle (MITM) attacks. // man-in-the-middle (MITM) attacks.
// Optional. Default value 0. // Optional. Default value 0.
HSTSMaxAge int `yaml:"hsts_max_age"` HSTSMaxAge int `yaml:"hsts_max_age"`
// HSTSExcludeSubdomains won't include subdomains tag in the `Strict Transport Security` // HSTSExcludeSubdomains won't include subdomains tag in the `Strict Transport Security`
// header, excluding all subdomains from security policy. It has no effect // header, excluding all subdomains from security policy. It has no effect
// unless HSTSMaxAge is set to a non-zero value. // unless HSTSMaxAge is set to a non-zero value.
// Optional. Default value false. // Optional. Default value false.
HSTSExcludeSubdomains bool `yaml:"hsts_exclude_subdomains"` HSTSExcludeSubdomains bool `yaml:"hsts_exclude_subdomains"`
// ContentSecurityPolicy sets the `Content-Security-Policy` header providing // ContentSecurityPolicy sets the `Content-Security-Policy` header providing
// security against cross-site scripting (XSS), clickjacking and other code // security against cross-site scripting (XSS), clickjacking and other code
// injection attacks resulting from execution of malicious content in the // injection attacks resulting from execution of malicious content in the
// trusted web page context. // trusted web page context.
// Optional. Default value "". // Optional. Default value "".
ContentSecurityPolicy string `yaml:"content_security_policy"` ContentSecurityPolicy string `yaml:"content_security_policy"`
// CSPReportOnly would use the `Content-Security-Policy-Report-Only` header instead // CSPReportOnly would use the `Content-Security-Policy-Report-Only` header instead
// of the `Content-Security-Policy` header. This allows iterative updates of the // of the `Content-Security-Policy` header. This allows iterative updates of the
// content security policy by only reporting the violations that would // content security policy by only reporting the violations that would
// have occurred instead of blocking the resource. // have occurred instead of blocking the resource.
// Optional. Default value false. // Optional. Default value false.
CSPReportOnly bool `yaml:"csp_report_only"` CSPReportOnly bool `yaml:"csp_report_only"`
// HSTSPreloadEnabled will add the preload tag in the `Strict Transport Security` // HSTSPreloadEnabled will add the preload tag in the `Strict Transport Security`
// header, which enables the domain to be included in the HSTS preload list // header, which enables the domain to be included in the HSTS preload list
// maintained by Chrome (and used by Firefox and Safari): https://hstspreload.org/ // maintained by Chrome (and used by Firefox and Safari): https://hstspreload.org/
// Optional. Default value false. // Optional. Default value false.
HSTSPreloadEnabled bool `yaml:"hsts_preload_enabled"` HSTSPreloadEnabled bool `yaml:"hsts_preload_enabled"`
// ReferrerPolicy sets the `Referrer-Policy` header providing security against // ReferrerPolicy sets the `Referrer-Policy` header providing security against
// leaking potentially sensitive request paths to third parties. // leaking potentially sensitive request paths to third parties.
// Optional. Default value "". // Optional. Default value "".
ReferrerPolicy string `yaml:"referrer_policy"` ReferrerPolicy string `yaml:"referrer_policy"`
} }
)
var ( // DefaultSecureConfig is the default Secure middleware config.
// DefaultSecureConfig is the default Secure middleware config. var DefaultSecureConfig = SecureConfig{
DefaultSecureConfig = SecureConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, XSSProtection: "1; mode=block",
XSSProtection: "1; mode=block", ContentTypeNosniff: "nosniff",
ContentTypeNosniff: "nosniff", XFrameOptions: "SAMEORIGIN",
XFrameOptions: "SAMEORIGIN", HSTSPreloadEnabled: false,
HSTSPreloadEnabled: false, }
}
)
// Secure returns a Secure middleware. // Secure returns a Secure middleware.
// Secure middleware provides protection against cross-site scripting (XSS) attack, // Secure middleware provides protection against cross-site scripting (XSS) attack,

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -6,24 +9,20 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // TrailingSlashConfig defines the config for TrailingSlash middleware.
// TrailingSlashConfig defines the config for TrailingSlash middleware. type TrailingSlashConfig struct {
TrailingSlashConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Status code to be used when redirecting the request. // Status code to be used when redirecting the request.
// Optional, but when provided the request is redirected using this code. // Optional, but when provided the request is redirected using this code.
RedirectCode int `yaml:"redirect_code"` RedirectCode int `yaml:"redirect_code"`
} }
)
var ( // DefaultTrailingSlashConfig is the default TrailingSlash middleware config.
// DefaultTrailingSlashConfig is the default TrailingSlash middleware config. var DefaultTrailingSlashConfig = TrailingSlashConfig{
DefaultTrailingSlashConfig = TrailingSlashConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, }
}
)
// AddTrailingSlash returns a root level (before router) middleware which adds a // AddTrailingSlash returns a root level (before router) middleware which adds a
// trailing slash to the request `URL#Path`. // trailing slash to the request `URL#Path`.

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -14,40 +17,38 @@ import (
"github.com/labstack/gommon/bytes" "github.com/labstack/gommon/bytes"
) )
type ( // StaticConfig defines the config for Static middleware.
// StaticConfig defines the config for Static middleware. type StaticConfig struct {
StaticConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Root directory from where the static content is served. // Root directory from where the static content is served.
// Required. // Required.
Root string `yaml:"root"` Root string `yaml:"root"`
// Index file for serving a directory. // Index file for serving a directory.
// Optional. Default value "index.html". // Optional. Default value "index.html".
Index string `yaml:"index"` Index string `yaml:"index"`
// Enable HTML5 mode by forwarding all not-found requests to root so that // Enable HTML5 mode by forwarding all not-found requests to root so that
// SPA (single-page application) can handle the routing. // SPA (single-page application) can handle the routing.
// Optional. Default value false. // Optional. Default value false.
HTML5 bool `yaml:"html5"` HTML5 bool `yaml:"html5"`
// Enable directory browsing. // Enable directory browsing.
// Optional. Default value false. // Optional. Default value false.
Browse bool `yaml:"browse"` Browse bool `yaml:"browse"`
// Enable ignoring of the base of the URL path. // Enable ignoring of the base of the URL path.
// Example: when assigning a static middleware to a non root path group, // Example: when assigning a static middleware to a non root path group,
// the filesystem path is not doubled // the filesystem path is not doubled
// Optional. Default value false. // Optional. Default value false.
IgnoreBase bool `yaml:"ignoreBase"` IgnoreBase bool `yaml:"ignoreBase"`
// Filesystem provides access to the static content. // Filesystem provides access to the static content.
// Optional. Defaults to http.Dir(config.Root) // Optional. Defaults to http.Dir(config.Root)
Filesystem http.FileSystem `yaml:"-"` Filesystem http.FileSystem `yaml:"-"`
} }
)
const html = ` const html = `
<!DOCTYPE html> <!DOCTYPE html>
@ -121,13 +122,11 @@ const html = `
</html> </html>
` `
var ( // DefaultStaticConfig is the default Static middleware config.
// DefaultStaticConfig is the default Static middleware config. var DefaultStaticConfig = StaticConfig{
DefaultStaticConfig = StaticConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, Index: "index.html",
Index: "index.html", }
}
)
// Static returns a Static middleware to serves static content from the provided // Static returns a Static middleware to serves static content from the provided
// root directory. // root directory.

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
//go:build !windows //go:build !windows
package middleware package middleware

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
@ -77,14 +80,12 @@ type TimeoutConfig struct {
Timeout time.Duration Timeout time.Duration
} }
var ( // DefaultTimeoutConfig is the default Timeout middleware config.
// DefaultTimeoutConfig is the default Timeout middleware config. var DefaultTimeoutConfig = TimeoutConfig{
DefaultTimeoutConfig = TimeoutConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, Timeout: 0,
Timeout: 0, ErrorMessage: "",
ErrorMessage: "", }
}
)
// Timeout returns a middleware which returns error (503 Service Unavailable error) to client immediately when handler // Timeout returns a middleware which returns error (503 Service Unavailable error) to client immediately when handler
// call runs for longer than its time limit. NB: timeout does not stop handler execution. // call runs for longer than its time limit. NB: timeout does not stop handler execution.

View File

@ -1,7 +1,14 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware package middleware
import ( import (
"bufio"
"crypto/rand"
"io"
"strings" "strings"
"sync"
) )
func matchScheme(domain, pattern string) bool { func matchScheme(domain, pattern string) bool {
@ -52,3 +59,45 @@ func matchSubdomain(domain, pattern string) bool {
} }
return false return false
} }
// https://tip.golang.org/doc/go1.19#:~:text=Read%20no%20longer%20buffers%20random%20data%20obtained%20from%20the%20operating%20system%20between%20calls
var randomReaderPool = sync.Pool{New: func() interface{} {
return bufio.NewReader(rand.Reader)
}}
const randomStringCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
const randomStringCharsetLen = 52 // len(randomStringCharset)
const randomStringMaxByte = 255 - (256 % randomStringCharsetLen)
func randomString(length uint8) string {
reader := randomReaderPool.Get().(*bufio.Reader)
defer randomReaderPool.Put(reader)
b := make([]byte, length)
r := make([]byte, length+(length/4)) // perf: avoid read from rand.Reader many times
var i uint8 = 0
// security note:
// we can't just simply do b[i]=randomStringCharset[rb%len(randomStringCharset)],
// len(len(randomStringCharset)) is 52, and rb is [0, 255], 256 = 52 * 4 + 48.
// make the first 48 characters more possibly to be generated then others.
// So we have to skip bytes when rb > randomStringMaxByte
for {
_, err := io.ReadFull(reader, r)
if err != nil {
panic("unexpected error happened when reading from bufio.NewReader(crypto/rand.Reader)")
}
for _, rb := range r {
if rb > randomStringMaxByte {
// Skip this number to avoid bias.
continue
}
b[i] = randomStringCharset[rb%randomStringCharsetLen]
i++
if i == length {
return string(b)
}
}
}
}

View File

@ -1,25 +1,27 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo package echo
import ( import (
"bufio" "bufio"
"errors"
"net" "net"
"net/http" "net/http"
) )
type ( // Response wraps an http.ResponseWriter and implements its interface to be used
// Response wraps an http.ResponseWriter and implements its interface to be used // by an HTTP handler to construct an HTTP response.
// by an HTTP handler to construct an HTTP response. // See: https://golang.org/pkg/net/http/#ResponseWriter
// See: https://golang.org/pkg/net/http/#ResponseWriter type Response struct {
Response struct { echo *Echo
echo *Echo beforeFuncs []func()
beforeFuncs []func() afterFuncs []func()
afterFuncs []func() Writer http.ResponseWriter
Writer http.ResponseWriter Status int
Status int Size int64
Size int64 Committed bool
Committed bool }
}
)
// NewResponse creates a new instance of Response. // NewResponse creates a new instance of Response.
func NewResponse(w http.ResponseWriter, e *Echo) (r *Response) { func NewResponse(w http.ResponseWriter, e *Echo) (r *Response) {
@ -84,14 +86,17 @@ func (r *Response) Write(b []byte) (n int, err error) {
// buffered data to the client. // buffered data to the client.
// See [http.Flusher](https://golang.org/pkg/net/http/#Flusher) // See [http.Flusher](https://golang.org/pkg/net/http/#Flusher)
func (r *Response) Flush() { func (r *Response) Flush() {
r.Writer.(http.Flusher).Flush() err := responseControllerFlush(r.Writer)
if err != nil && errors.Is(err, http.ErrNotSupported) {
panic(errors.New("response writer flushing is not supported"))
}
} }
// Hijack implements the http.Hijacker interface to allow an HTTP handler to // Hijack implements the http.Hijacker interface to allow an HTTP handler to
// take over the connection. // take over the connection.
// See [http.Hijacker](https://golang.org/pkg/net/http/#Hijacker) // See [http.Hijacker](https://golang.org/pkg/net/http/#Hijacker)
func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return r.Writer.(http.Hijacker).Hijack() return responseControllerHijack(r.Writer)
} }
// Unwrap returns the original http.ResponseWriter. // Unwrap returns the original http.ResponseWriter.

View File

@ -0,0 +1,44 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
//go:build !go1.20
package echo
import (
"bufio"
"fmt"
"net"
"net/http"
)
// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
func responseControllerFlush(rw http.ResponseWriter) error {
for {
switch t := rw.(type) {
case interface{ FlushError() error }:
return t.FlushError()
case http.Flusher:
t.Flush()
return nil
case interface{ Unwrap() http.ResponseWriter }:
rw = t.Unwrap()
default:
return fmt.Errorf("%w", http.ErrNotSupported)
}
}
}
// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
for {
switch t := rw.(type) {
case http.Hijacker:
return t.Hijack()
case interface{ Unwrap() http.ResponseWriter }:
rw = t.Unwrap()
default:
return nil, nil, fmt.Errorf("%w", http.ErrNotSupported)
}
}
}

View File

@ -0,0 +1,20 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
//go:build go1.20
package echo
import (
"bufio"
"net"
"net/http"
)
func responseControllerFlush(rw http.ResponseWriter) error {
return http.NewResponseController(rw).Flush()
}
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
return http.NewResponseController(rw).Hijack()
}

View File

@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo package echo
import ( import (
@ -6,56 +9,58 @@ import (
"net/http" "net/http"
) )
type ( // Router is the registry of all registered routes for an `Echo` instance for
// Router is the registry of all registered routes for an `Echo` instance for // request matching and URL path parameter parsing.
// request matching and URL path parameter parsing. type Router struct {
Router struct { tree *node
tree *node routes map[string]*Route
routes map[string]*Route echo *Echo
echo *Echo }
}
node struct {
kind kind
label byte
prefix string
parent *node
staticChildren children
originalPath string
methods *routeMethods
paramChild *node
anyChild *node
paramsCount int
// isLeaf indicates that node does not have child routes
isLeaf bool
// isHandler indicates that node has at least one handler registered to it
isHandler bool
// notFoundHandler is handler registered with RouteNotFound method and is executed for 404 cases type node struct {
notFoundHandler *routeMethod kind kind
} label byte
kind uint8 prefix string
children []*node parent *node
routeMethod struct { staticChildren children
ppath string originalPath string
pnames []string methods *routeMethods
handler HandlerFunc paramChild *node
} anyChild *node
routeMethods struct { paramsCount int
connect *routeMethod // isLeaf indicates that node does not have child routes
delete *routeMethod isLeaf bool
get *routeMethod // isHandler indicates that node has at least one handler registered to it
head *routeMethod isHandler bool
options *routeMethod
patch *routeMethod // notFoundHandler is handler registered with RouteNotFound method and is executed for 404 cases
post *routeMethod notFoundHandler *routeMethod
propfind *routeMethod }
put *routeMethod
trace *routeMethod type kind uint8
report *routeMethod type children []*node
anyOther map[string]*routeMethod
allowHeader string type routeMethod struct {
} ppath string
) pnames []string
handler HandlerFunc
}
type routeMethods struct {
connect *routeMethod
delete *routeMethod
get *routeMethod
head *routeMethod
options *routeMethod
patch *routeMethod
post *routeMethod
propfind *routeMethod
put *routeMethod
trace *routeMethod
report *routeMethod
anyOther map[string]*routeMethod
allowHeader string
}
const ( const (
staticKind kind = iota staticKind kind = iota
@ -180,8 +185,18 @@ func (r *Router) Reverse(name string, params ...interface{}) string {
return uri.String() return uri.String()
} }
func normalizePathSlash(path string) string {
if path == "" {
path = "/"
} else if path[0] != '/' {
path = "/" + path
}
return path
}
func (r *Router) add(method, path, name string, h HandlerFunc) *Route { func (r *Router) add(method, path, name string, h HandlerFunc) *Route {
r.Add(method, path, h) path = normalizePathSlash(path)
r.insert(method, path, h)
route := &Route{ route := &Route{
Method: method, Method: method,
@ -194,13 +209,11 @@ func (r *Router) add(method, path, name string, h HandlerFunc) *Route {
// Add registers a new route for method and path with matching handler. // Add registers a new route for method and path with matching handler.
func (r *Router) Add(method, path string, h HandlerFunc) { func (r *Router) Add(method, path string, h HandlerFunc) {
// Validate path r.insert(method, normalizePathSlash(path), h)
if path == "" { }
path = "/"
} func (r *Router) insert(method, path string, h HandlerFunc) {
if path[0] != '/' { path = normalizePathSlash(path)
path = "/" + path
}
pnames := []string{} // Param names pnames := []string{} // Param names
ppath := path // Pristine path ppath := path // Pristine path
@ -219,7 +232,7 @@ func (r *Router) Add(method, path string, h HandlerFunc) {
} }
j := i + 1 j := i + 1
r.insert(method, path[:i], staticKind, routeMethod{}) r.insertNode(method, path[:i], staticKind, routeMethod{})
for ; i < lcpIndex && path[i] != '/'; i++ { for ; i < lcpIndex && path[i] != '/'; i++ {
} }
@ -229,21 +242,21 @@ func (r *Router) Add(method, path string, h HandlerFunc) {
if i == lcpIndex { if i == lcpIndex {
// path node is last fragment of route path. ie. `/users/:id` // path node is last fragment of route path. ie. `/users/:id`
r.insert(method, path[:i], paramKind, routeMethod{ppath, pnames, h}) r.insertNode(method, path[:i], paramKind, routeMethod{ppath, pnames, h})
} else { } else {
r.insert(method, path[:i], paramKind, routeMethod{}) r.insertNode(method, path[:i], paramKind, routeMethod{})
} }
} else if path[i] == '*' { } else if path[i] == '*' {
r.insert(method, path[:i], staticKind, routeMethod{}) r.insertNode(method, path[:i], staticKind, routeMethod{})
pnames = append(pnames, "*") pnames = append(pnames, "*")
r.insert(method, path[:i+1], anyKind, routeMethod{ppath, pnames, h}) r.insertNode(method, path[:i+1], anyKind, routeMethod{ppath, pnames, h})
} }
} }
r.insert(method, path, staticKind, routeMethod{ppath, pnames, h}) r.insertNode(method, path, staticKind, routeMethod{ppath, pnames, h})
} }
func (r *Router) insert(method, path string, t kind, rm routeMethod) { func (r *Router) insertNode(method, path string, t kind, rm routeMethod) {
// Adjust max param // Adjust max param
paramLen := len(rm.pnames) paramLen := len(rm.pnames)
if *r.echo.maxParam < paramLen { if *r.echo.maxParam < paramLen {

View File

@ -1,48 +0,0 @@
package random
import (
"math/rand"
"strings"
"time"
)
type (
Random struct {
}
)
// Charsets
const (
Uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
Lowercase = "abcdefghijklmnopqrstuvwxyz"
Alphabetic = Uppercase + Lowercase
Numeric = "0123456789"
Alphanumeric = Alphabetic + Numeric
Symbols = "`" + `~!@#$%^&*()-_+={}[]|\;:"<>,./?`
Hex = Numeric + "abcdef"
)
var (
global = New()
)
func New() *Random {
rand.Seed(time.Now().UnixNano())
return new(Random)
}
func (r *Random) String(length uint8, charsets ...string) string {
charset := strings.Join(charsets, "")
if charset == "" {
charset = Alphanumeric
}
b := make([]byte, length)
for i := range b {
b[i] = charset[rand.Int63()%int64(len(charset))]
}
return string(b)
}
func String(length uint8, charsets ...string) string {
return global.String(length, charsets...)
}

View File

@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier same "printed page" as the copyright notice for easier
identification within third-party archives. identification within third-party archives.
Copyright 2012-2023 Li Kexian Copyright 2012-2024 Li Kexian
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -200,6 +200,6 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
APPENDIX: Copyright 2012-2023 Li Kexian APPENDIX: Copyright 2012-2024 Li Kexian
https://www.likexian.com/ https://www.likexian.com/

View File

@ -85,7 +85,7 @@ b := assert.If(a == 1, true, false)
## License ## License
Copyright 2012-2023 [Li Kexian](https://www.likexian.com/) Copyright 2012-2024 [Li Kexian](https://www.likexian.com/)
Licensed under the Apache License 2.0 Licensed under the Apache License 2.0

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2012-2023 Li Kexian * Copyright 2012-2024 Li Kexian
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2012-2023 Li Kexian * Copyright 2012-2024 Li Kexian
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -43,7 +43,7 @@ if err != nil {
## License ## License
Copyright 2012-2023 [Li Kexian](https://www.likexian.com/) Copyright 2012-2024 [Li Kexian](https://www.likexian.com/)
Licensed under the Apache License 2.0 Licensed under the Apache License 2.0

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2012-2023 Li Kexian * Copyright 2012-2024 Li Kexian
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -34,7 +34,7 @@ fmt.Println("new array:", array)
## License ## License
Copyright 2012-2023 [Li Kexian](https://www.likexian.com/) Copyright 2012-2024 [Li Kexian](https://www.likexian.com/)
Licensed under the Apache License 2.0 Licensed under the Apache License 2.0

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2012-2023 Li Kexian * Copyright 2012-2024 Li Kexian
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier same "printed page" as the copyright notice for easier
identification within third-party archives. identification within third-party archives.
Copyright 2014-2023 Li Kexian Copyright 2014-2024 Li Kexian
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -200,6 +200,6 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
APPENDIX: Copyright 2014-2023 Li Kexian APPENDIX: Copyright 2014-2024 Li Kexian
https://www.likexian.com/ https://www.likexian.com/

View File

@ -71,7 +71,7 @@ Please refer to [whois](https://github.com/likexian/whois)
## License ## License
Copyright 2014-2023 [Li Kexian](https://www.likexian.com/) Copyright 2014-2024 [Li Kexian](https://www.likexian.com/)
Licensed under the Apache License 2.0 Licensed under the Apache License 2.0

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2014-2023 Li Kexian * Copyright 2014-2024 Li Kexian
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2014-2023 Li Kexian * Copyright 2014-2024 Li Kexian
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -153,6 +153,8 @@ func Parse(text string) (whoisInfo WhoisInfo, err error) { //nolint:cyclop
if !strings.Contains(name, " ") { if !strings.Contains(name, " ") {
if name == "registrar" { if name == "registrar" {
name += " name" name += " name"
} else if domain.Extension == "dk" {
name = "registrant " + name
} else { } else {
name += " organization" name += " organization"
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2014-2023 Li Kexian * Copyright 2014-2024 Li Kexian
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -91,6 +91,8 @@ func Prepare(text, ext string) (string, bool) { //nolint:cyclop
return prepareBY(text), true return prepareBY(text), true
case "ua": case "ua":
return prepareUA(text), true return prepareUA(text), true
case "at":
return prepareAT(text), true
default: default:
return text, false return text, false
} }
@ -685,7 +687,7 @@ func prepareRU(text string) string {
return result return result
} }
var prepareJPreplacerRx = regexp.MustCompile(`\n\[(.+?)\][\ ]*(.+?)?`) var prepareJPreplacerRx = regexp.MustCompile(`\n(?:\w+\.\s)?\[(.+?)\][\ ]*(.+?)?`)
// prepareJP do prepare the .jp domain // prepareJP do prepare the .jp domain
func prepareJP(text string) string { func prepareJP(text string) string {
@ -712,6 +714,7 @@ func prepareJP(text string) string {
if strings.ToLower(token) == "registrant" { if strings.ToLower(token) == "registrant" {
v = fmt.Sprintf("registrant name: %s", vs[1]) v = fmt.Sprintf("registrant name: %s", vs[1])
} }
v = prepareSecondLevelJP(v, token, vs[1])
} else { } else {
if token == addressToken { if token == addressToken {
result += ", " + v result += ", " + v
@ -724,6 +727,29 @@ func prepareJP(text string) string {
return result return result
} }
// prepareJP prepares specific mappings for second level .jp domains
// examples include:
// - co.jp
// - ac.jp
// - go.jp
// - or.jp
// - ad.jp
// - ne.jp
// - gr.jp
// - ed.jp
func prepareSecondLevelJP(original string, token string, value string) string {
if strings.ToLower(token) == "administrative contact" {
return fmt.Sprintf("Administrative Contact ID: %s", strings.TrimSpace(value))
}
if strings.ToLower(token) == "technical contact" {
return fmt.Sprintf("Technical Contact ID: %s", strings.TrimSpace(value))
}
if strings.ToLower(token) == "organization" || strings.ToLower(token) == "network service name" {
return fmt.Sprintf("Registrant Organization: %s", strings.TrimSpace(value))
}
return original
}
// prepareUK do prepare the .uk domain // prepareUK do prepare the .uk domain
func prepareUK(text string) string { func prepareUK(text string) string {
tokens := map[string]string{ tokens := map[string]string{
@ -1335,3 +1361,78 @@ func prepareUA(text string) string {
return result return result
} }
// prepareAT prepares the .at domain
func prepareAT(text string) string {
result := ""
registrantID := ""
techID := ""
tokens := map[string]string{
"street address": "address",
"postal code": "address",
"city": "address",
"country": "address",
"e-mail": "email",
"nic-hdl": "id",
"personname": "name",
}
formatLine := func(line, token string) string {
before, after, _ := strings.Cut(line, ":")
key := strings.TrimSpace(before)
if t, ok := tokens[key]; ok {
key = t
}
val := strings.TrimSpace(after)
return fmt.Sprintf("%s %s: %s", token, key, val)
}
for _, v := range strings.Split(text, "\n\n") {
v = strings.TrimSpace(v)
if strings.HasPrefix(v, "%") {
continue
}
if strings.Contains(v, ":") {
b := strings.Split(v, "\n")
if strings.HasPrefix(b[0], "domain") {
for _, l := range b {
w := ""
if before, after, ok := strings.Cut(l, ":"); ok {
key := strings.TrimSpace(before)
val := strings.TrimSpace(after)
switch key {
case "domain":
w = fmt.Sprintf("%s: %s", "domain name", val)
case "registrant":
registrantID = val
case "tech-c":
techID = val
case "changed":
w = fmt.Sprintf("%s: %s", "updated_date", val)
case "nserver":
w = fmt.Sprintf("%s: %s", "name_servers", val)
default:
w = fmt.Sprintf("domain %s: %s", key, val)
}
if w != "" {
result += w + "\n"
}
}
}
} else if strings.HasPrefix(b[0], "personname") {
token := ""
if strings.Contains(v, registrantID) {
token = "registrant"
} else if strings.Contains(v, techID) {
token = "technical contact"
}
for _, l := range strings.Split(v, "\n") {
result += formatLine(l, token) + "\n"
}
}
}
}
return result
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2014-2023 Li Kexian * Copyright 2014-2024 Li Kexian
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -152,7 +152,9 @@ var (
"registrant contact state province": "registrant_state_province", "registrant contact state province": "registrant_state_province",
"registrant zipcode": "registrant_postal_code", "registrant zipcode": "registrant_postal_code",
"registrant zip code": "registrant_postal_code", "registrant zip code": "registrant_postal_code",
"registrant postalcode": "registrant_postal_code",
"registrant postal code": "registrant_postal_code", "registrant postal code": "registrant_postal_code",
"registrant contact postalcode": "registrant_postal_code",
"registrant contact postal code": "registrant_postal_code", "registrant contact postal code": "registrant_postal_code",
"registrant country": "registrant_country", "registrant country": "registrant_country",
"registrant country economy": "registrant_country", "registrant country economy": "registrant_country",

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2014-2023 Li Kexian * Copyright 2014-2024 Li Kexian
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2014-2023 Li Kexian * Copyright 2014-2024 Li Kexian
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -118,16 +118,20 @@ func keys(m map[string]string) []string {
// parseDateString attempts to parse a given date using a collection of common // parseDateString attempts to parse a given date using a collection of common
// format strings. Date formats containing time components are tried first // format strings. Date formats containing time components are tried first
// before attempts are made using date-only formats. // before attempts are made using date-only formats.
func parseDateString(dateString string) (time.Time, error) { func parseDateString(datetime string) (time.Time, error) {
formats := [...]string{ datetime = strings.Trim(datetime, ".")
datetime = strings.ReplaceAll(datetime, ". ", "-")
formats := [...]string{
// Date & time formats // Date & time formats
"2006-01-02 15:04:05", "2006-01-02 15:04:05",
"2006.01.02 15:04:05",
"02/01/2006 15:04:05", "02/01/2006 15:04:05",
"02.01.2006 15:04:05", "02.01.2006 15:04:05",
"02.1.2006 15:04:05", "02.1.2006 15:04:05",
"2.1.2006 15:04:05", "2.1.2006 15:04:05",
"02-Jan-2006 15:04:05", "02-Jan-2006 15:04:05",
"20060102 15:04:05",
time.ANSIC, time.ANSIC,
time.Stamp, time.Stamp,
time.StampMilli, time.StampMilli,
@ -151,24 +155,25 @@ func parseDateString(dateString string) (time.Time, error) {
// Date only formats // Date only formats
"2006-01-02", "2006-01-02",
"2006. 01. 02.",
"02-Jan-2006", "02-Jan-2006",
"02.01.2006", "02.01.2006",
"02-01-2006", "02-01-2006",
"January _2 2006", "January _2 2006",
"Mon Jan _2 2006",
"02/01/2006", "02/01/2006",
"01/02/2006", "01/02/2006",
"2006/01/02",
"2006-Jan-02", "2006-Jan-02",
"2006-Jan-02.", "before Jan-2006",
} }
for _, format := range formats { for _, format := range formats {
result, err := time.Parse(format, dateString) result, err := time.Parse(format, datetime)
if err != nil { if err != nil {
continue continue
} }
return result, nil return result, nil
} }
return time.Now(), fmt.Errorf("could not parse %s as a date", dateString) return time.Now(), fmt.Errorf("could not parse %s as a date", datetime)
} }

View File

@ -55,7 +55,7 @@ type Client struct {
// Version returns package version // Version returns package version
func Version() string { func Version() string {
return "1.15.0" return "1.15.2"
} }
// Author returns package author // Author returns package author

View File

@ -1,6 +1,7 @@
//go:build (darwin || freebsd || openbsd || netbsd || dragonfly || hurd) && !appengine //go:build (darwin || freebsd || openbsd || netbsd || dragonfly || hurd) && !appengine && !tinygo
// +build darwin freebsd openbsd netbsd dragonfly hurd // +build darwin freebsd openbsd netbsd dragonfly hurd
// +build !appengine // +build !appengine
// +build !tinygo
package isatty package isatty

View File

@ -1,5 +1,6 @@
//go:build appengine || js || nacl || wasm //go:build (appengine || js || nacl || tinygo || wasm) && !windows
// +build appengine js nacl wasm // +build appengine js nacl tinygo wasm
// +build !windows
package isatty package isatty

View File

@ -1,6 +1,7 @@
//go:build (linux || aix || zos) && !appengine //go:build (linux || aix || zos) && !appengine && !tinygo
// +build linux aix zos // +build linux aix zos
// +build !appengine // +build !appengine
// +build !tinygo
package isatty package isatty

View File

@ -3,7 +3,6 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build go1.12 //go:build go1.12
// +build go1.12
package acme package acme

View File

@ -20,41 +20,44 @@ import (
// TODO: Benchmark to determine if the pools are necessary. The GC may have // TODO: Benchmark to determine if the pools are necessary. The GC may have
// improved enough that we can instead allocate chunks like this: // improved enough that we can instead allocate chunks like this:
// make([]byte, max(16<<10, expectedBytesRemaining)) // make([]byte, max(16<<10, expectedBytesRemaining))
var ( var dataChunkPools = [...]sync.Pool{
dataChunkSizeClasses = []int{ {New: func() interface{} { return new([1 << 10]byte) }},
1 << 10, {New: func() interface{} { return new([2 << 10]byte) }},
2 << 10, {New: func() interface{} { return new([4 << 10]byte) }},
4 << 10, {New: func() interface{} { return new([8 << 10]byte) }},
8 << 10, {New: func() interface{} { return new([16 << 10]byte) }},
16 << 10, }
}
dataChunkPools = [...]sync.Pool{
{New: func() interface{} { return make([]byte, 1<<10) }},
{New: func() interface{} { return make([]byte, 2<<10) }},
{New: func() interface{} { return make([]byte, 4<<10) }},
{New: func() interface{} { return make([]byte, 8<<10) }},
{New: func() interface{} { return make([]byte, 16<<10) }},
}
)
func getDataBufferChunk(size int64) []byte { func getDataBufferChunk(size int64) []byte {
i := 0 switch {
for ; i < len(dataChunkSizeClasses)-1; i++ { case size <= 1<<10:
if size <= int64(dataChunkSizeClasses[i]) { return dataChunkPools[0].Get().(*[1 << 10]byte)[:]
break case size <= 2<<10:
} return dataChunkPools[1].Get().(*[2 << 10]byte)[:]
case size <= 4<<10:
return dataChunkPools[2].Get().(*[4 << 10]byte)[:]
case size <= 8<<10:
return dataChunkPools[3].Get().(*[8 << 10]byte)[:]
default:
return dataChunkPools[4].Get().(*[16 << 10]byte)[:]
} }
return dataChunkPools[i].Get().([]byte)
} }
func putDataBufferChunk(p []byte) { func putDataBufferChunk(p []byte) {
for i, n := range dataChunkSizeClasses { switch len(p) {
if len(p) == n { case 1 << 10:
dataChunkPools[i].Put(p) dataChunkPools[0].Put((*[1 << 10]byte)(p))
return case 2 << 10:
} dataChunkPools[1].Put((*[2 << 10]byte)(p))
case 4 << 10:
dataChunkPools[2].Put((*[4 << 10]byte)(p))
case 8 << 10:
dataChunkPools[3].Put((*[8 << 10]byte)(p))
case 16 << 10:
dataChunkPools[4].Put((*[16 << 10]byte)(p))
default:
panic(fmt.Sprintf("unexpected buffer len=%v", len(p)))
} }
panic(fmt.Sprintf("unexpected buffer len=%v", len(p)))
} }
// dataBuffer is an io.ReadWriter backed by a list of data chunks. // dataBuffer is an io.ReadWriter backed by a list of data chunks.

View File

@ -1510,13 +1510,12 @@ func (mh *MetaHeadersFrame) checkPseudos() error {
} }
func (fr *Framer) maxHeaderStringLen() int { func (fr *Framer) maxHeaderStringLen() int {
v := fr.maxHeaderListSize() v := int(fr.maxHeaderListSize())
if uint32(int(v)) == v { if v < 0 {
return int(v) // If maxHeaderListSize overflows an int, use no limit (0).
return 0
} }
// They had a crazy big number for MaxHeaderBytes anyway, return v
// so give them unlimited header lengths:
return 0
} }
// readMetaFrame returns 0 or more CONTINUATION frames from fr and // readMetaFrame returns 0 or more CONTINUATION frames from fr and
@ -1565,6 +1564,7 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
if size > remainSize { if size > remainSize {
hdec.SetEmitEnabled(false) hdec.SetEmitEnabled(false)
mh.Truncated = true mh.Truncated = true
remainSize = 0
return return
} }
remainSize -= size remainSize -= size
@ -1577,6 +1577,36 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
var hc headersOrContinuation = hf var hc headersOrContinuation = hf
for { for {
frag := hc.HeaderBlockFragment() frag := hc.HeaderBlockFragment()
// Avoid parsing large amounts of headers that we will then discard.
// If the sender exceeds the max header list size by too much,
// skip parsing the fragment and close the connection.
//
// "Too much" is either any CONTINUATION frame after we've already
// exceeded the max header list size (in which case remainSize is 0),
// or a frame whose encoded size is more than twice the remaining
// header list bytes we're willing to accept.
if int64(len(frag)) > int64(2*remainSize) {
if VerboseLogs {
log.Printf("http2: header list too large")
}
// It would be nice to send a RST_STREAM before sending the GOAWAY,
// but the structure of the server's frame writer makes this difficult.
return nil, ConnectionError(ErrCodeProtocol)
}
// Also close the connection after any CONTINUATION frame following an
// invalid header, since we stop tracking the size of the headers after
// an invalid one.
if invalid != nil {
if VerboseLogs {
log.Printf("http2: invalid header: %v", invalid)
}
// It would be nice to send a RST_STREAM before sending the GOAWAY,
// but the structure of the server's frame writer makes this difficult.
return nil, ConnectionError(ErrCodeProtocol)
}
if _, err := hdec.Write(frag); err != nil { if _, err := hdec.Write(frag); err != nil {
return nil, ConnectionError(ErrCodeCompression) return nil, ConnectionError(ErrCodeCompression)
} }

View File

@ -1,30 +0,0 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.11
// +build go1.11
package http2
import (
"net/http/httptrace"
"net/textproto"
)
func traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool {
return trace != nil && trace.WroteHeaderField != nil
}
func traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) {
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField(k, []string{v})
}
}
func traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error {
if trace != nil {
return trace.Got1xxResponse
}
return nil
}

View File

@ -1,27 +0,0 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.15
// +build go1.15
package http2
import (
"context"
"crypto/tls"
)
// dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS
// connection.
func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (*tls.Conn, error) {
dialer := &tls.Dialer{
Config: cfg,
}
cn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed
return tlsCn, nil
}

View File

@ -1,17 +0,0 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.18
// +build go1.18
package http2
import (
"crypto/tls"
"net"
)
func tlsUnderlyingConn(tc *tls.Conn) net.Conn {
return tc.NetConn()
}

View File

@ -1,21 +0,0 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.11
// +build !go1.11
package http2
import (
"net/http/httptrace"
"net/textproto"
)
func traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool { return false }
func traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) {}
func traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error {
return nil
}

View File

@ -1,31 +0,0 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.15
// +build !go1.15
package http2
import (
"context"
"crypto/tls"
)
// dialTLSWithContext opens a TLS connection.
func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (*tls.Conn, error) {
cn, err := tls.Dial(network, addr, cfg)
if err != nil {
return nil, err
}
if err := cn.Handshake(); err != nil {
return nil, err
}
if cfg.InsecureSkipVerify {
return cn, nil
}
if err := cn.VerifyHostname(cfg.ServerName); err != nil {
return nil, err
}
return cn, nil
}

View File

@ -1,17 +0,0 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.18
// +build !go1.18
package http2
import (
"crypto/tls"
"net"
)
func tlsUnderlyingConn(tc *tls.Conn) net.Conn {
return nil
}

View File

@ -77,7 +77,10 @@ func (p *pipe) Read(d []byte) (n int, err error) {
} }
} }
var errClosedPipeWrite = errors.New("write on closed buffer") var (
errClosedPipeWrite = errors.New("write on closed buffer")
errUninitializedPipeWrite = errors.New("write on uninitialized buffer")
)
// Write copies bytes from p into the buffer and wakes a reader. // Write copies bytes from p into the buffer and wakes a reader.
// It is an error to write more data than the buffer can hold. // It is an error to write more data than the buffer can hold.
@ -91,6 +94,12 @@ func (p *pipe) Write(d []byte) (n int, err error) {
if p.err != nil || p.breakErr != nil { if p.err != nil || p.breakErr != nil {
return 0, errClosedPipeWrite return 0, errClosedPipeWrite
} }
// pipe.setBuffer is never invoked, leaving the buffer uninitialized.
// We shouldn't try to write to an uninitialized pipe,
// but returning an error is better than panicking.
if p.b == nil {
return 0, errUninitializedPipeWrite
}
return p.b.Write(d) return p.b.Write(d)
} }

View File

@ -124,6 +124,7 @@ type Server struct {
// IdleTimeout specifies how long until idle clients should be // IdleTimeout specifies how long until idle clients should be
// closed with a GOAWAY frame. PING frames are not considered // closed with a GOAWAY frame. PING frames are not considered
// activity for the purposes of IdleTimeout. // activity for the purposes of IdleTimeout.
// If zero or negative, there is no timeout.
IdleTimeout time.Duration IdleTimeout time.Duration
// MaxUploadBufferPerConnection is the size of the initial flow // MaxUploadBufferPerConnection is the size of the initial flow
@ -434,7 +435,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
// passes the connection off to us with the deadline already set. // passes the connection off to us with the deadline already set.
// Write deadlines are set per stream in serverConn.newStream. // Write deadlines are set per stream in serverConn.newStream.
// Disarm the net.Conn write deadline here. // Disarm the net.Conn write deadline here.
if sc.hs.WriteTimeout != 0 { if sc.hs.WriteTimeout > 0 {
sc.conn.SetWriteDeadline(time.Time{}) sc.conn.SetWriteDeadline(time.Time{})
} }
@ -581,9 +582,11 @@ type serverConn struct {
advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client
curClientStreams uint32 // number of open streams initiated by the client curClientStreams uint32 // number of open streams initiated by the client
curPushedStreams uint32 // number of open streams initiated by server push curPushedStreams uint32 // number of open streams initiated by server push
curHandlers uint32 // number of running handler goroutines
maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests
maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes
streams map[uint32]*stream streams map[uint32]*stream
unstartedHandlers []unstartedHandler
initialStreamSendWindowSize int32 initialStreamSendWindowSize int32
maxFrameSize int32 maxFrameSize int32
peerMaxHeaderListSize uint32 // zero means unknown (default) peerMaxHeaderListSize uint32 // zero means unknown (default)
@ -922,7 +925,7 @@ func (sc *serverConn) serve() {
sc.setConnState(http.StateActive) sc.setConnState(http.StateActive)
sc.setConnState(http.StateIdle) sc.setConnState(http.StateIdle)
if sc.srv.IdleTimeout != 0 { if sc.srv.IdleTimeout > 0 {
sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
defer sc.idleTimer.Stop() defer sc.idleTimer.Stop()
} }
@ -981,6 +984,8 @@ func (sc *serverConn) serve() {
return return
case gracefulShutdownMsg: case gracefulShutdownMsg:
sc.startGracefulShutdownInternal() sc.startGracefulShutdownInternal()
case handlerDoneMsg:
sc.handlerDone()
default: default:
panic("unknown timer") panic("unknown timer")
} }
@ -1020,6 +1025,7 @@ var (
idleTimerMsg = new(serverMessage) idleTimerMsg = new(serverMessage)
shutdownTimerMsg = new(serverMessage) shutdownTimerMsg = new(serverMessage)
gracefulShutdownMsg = new(serverMessage) gracefulShutdownMsg = new(serverMessage)
handlerDoneMsg = new(serverMessage)
) )
func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) } func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) }
@ -1632,7 +1638,7 @@ func (sc *serverConn) closeStream(st *stream, err error) {
delete(sc.streams, st.id) delete(sc.streams, st.id)
if len(sc.streams) == 0 { if len(sc.streams) == 0 {
sc.setConnState(http.StateIdle) sc.setConnState(http.StateIdle)
if sc.srv.IdleTimeout != 0 { if sc.srv.IdleTimeout > 0 {
sc.idleTimer.Reset(sc.srv.IdleTimeout) sc.idleTimer.Reset(sc.srv.IdleTimeout)
} }
if h1ServerKeepAlivesDisabled(sc.hs) { if h1ServerKeepAlivesDisabled(sc.hs) {
@ -1892,9 +1898,11 @@ func (st *stream) copyTrailersToHandlerRequest() {
// onReadTimeout is run on its own goroutine (from time.AfterFunc) // onReadTimeout is run on its own goroutine (from time.AfterFunc)
// when the stream's ReadTimeout has fired. // when the stream's ReadTimeout has fired.
func (st *stream) onReadTimeout() { func (st *stream) onReadTimeout() {
// Wrap the ErrDeadlineExceeded to avoid callers depending on us if st.body != nil {
// returning the bare error. // Wrap the ErrDeadlineExceeded to avoid callers depending on us
st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded)) // returning the bare error.
st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded))
}
} }
// onWriteTimeout is run on its own goroutine (from time.AfterFunc) // onWriteTimeout is run on its own goroutine (from time.AfterFunc)
@ -2010,15 +2018,12 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
// similar to how the http1 server works. Here it's // similar to how the http1 server works. Here it's
// technically more like the http1 Server's ReadHeaderTimeout // technically more like the http1 Server's ReadHeaderTimeout
// (in Go 1.8), though. That's a more sane option anyway. // (in Go 1.8), though. That's a more sane option anyway.
if sc.hs.ReadTimeout != 0 { if sc.hs.ReadTimeout > 0 {
sc.conn.SetReadDeadline(time.Time{}) sc.conn.SetReadDeadline(time.Time{})
if st.body != nil { st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
}
} }
go sc.runHandler(rw, req, handler) return sc.scheduleHandler(id, rw, req, handler)
return nil
} }
func (sc *serverConn) upgradeRequest(req *http.Request) { func (sc *serverConn) upgradeRequest(req *http.Request) {
@ -2034,10 +2039,14 @@ func (sc *serverConn) upgradeRequest(req *http.Request) {
// Disable any read deadline set by the net/http package // Disable any read deadline set by the net/http package
// prior to the upgrade. // prior to the upgrade.
if sc.hs.ReadTimeout != 0 { if sc.hs.ReadTimeout > 0 {
sc.conn.SetReadDeadline(time.Time{}) sc.conn.SetReadDeadline(time.Time{})
} }
// This is the first request on the connection,
// so start the handler directly rather than going
// through scheduleHandler.
sc.curHandlers++
go sc.runHandler(rw, req, sc.handler.ServeHTTP) go sc.runHandler(rw, req, sc.handler.ServeHTTP)
} }
@ -2108,7 +2117,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream
st.flow.conn = &sc.flow // link to conn-level counter st.flow.conn = &sc.flow // link to conn-level counter
st.flow.add(sc.initialStreamSendWindowSize) st.flow.add(sc.initialStreamSendWindowSize)
st.inflow.init(sc.srv.initialStreamRecvWindowSize()) st.inflow.init(sc.srv.initialStreamRecvWindowSize())
if sc.hs.WriteTimeout != 0 { if sc.hs.WriteTimeout > 0 {
st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
} }
@ -2278,8 +2287,62 @@ func (sc *serverConn) newResponseWriter(st *stream, req *http.Request) *response
return &responseWriter{rws: rws} return &responseWriter{rws: rws}
} }
type unstartedHandler struct {
streamID uint32
rw *responseWriter
req *http.Request
handler func(http.ResponseWriter, *http.Request)
}
// scheduleHandler starts a handler goroutine,
// or schedules one to start as soon as an existing handler finishes.
func (sc *serverConn) scheduleHandler(streamID uint32, rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) error {
sc.serveG.check()
maxHandlers := sc.advMaxStreams
if sc.curHandlers < maxHandlers {
sc.curHandlers++
go sc.runHandler(rw, req, handler)
return nil
}
if len(sc.unstartedHandlers) > int(4*sc.advMaxStreams) {
return sc.countError("too_many_early_resets", ConnectionError(ErrCodeEnhanceYourCalm))
}
sc.unstartedHandlers = append(sc.unstartedHandlers, unstartedHandler{
streamID: streamID,
rw: rw,
req: req,
handler: handler,
})
return nil
}
func (sc *serverConn) handlerDone() {
sc.serveG.check()
sc.curHandlers--
i := 0
maxHandlers := sc.advMaxStreams
for ; i < len(sc.unstartedHandlers); i++ {
u := sc.unstartedHandlers[i]
if sc.streams[u.streamID] == nil {
// This stream was reset before its goroutine had a chance to start.
continue
}
if sc.curHandlers >= maxHandlers {
break
}
sc.curHandlers++
go sc.runHandler(u.rw, u.req, u.handler)
sc.unstartedHandlers[i] = unstartedHandler{} // don't retain references
}
sc.unstartedHandlers = sc.unstartedHandlers[i:]
if len(sc.unstartedHandlers) == 0 {
sc.unstartedHandlers = nil
}
}
// Run on its own goroutine. // Run on its own goroutine.
func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) { func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) {
defer sc.sendServeMsg(handlerDoneMsg)
didPanic := true didPanic := true
defer func() { defer func() {
rw.rws.stream.cancelCtx() rw.rws.stream.cancelCtx()
@ -2487,7 +2550,6 @@ type responseWriterState struct {
wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet. wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet.
sentHeader bool // have we sent the header frame? sentHeader bool // have we sent the header frame?
handlerDone bool // handler has finished handlerDone bool // handler has finished
dirty bool // a Write failed; don't reuse this responseWriterState
sentContentLen int64 // non-zero if handler set a Content-Length header sentContentLen int64 // non-zero if handler set a Content-Length header
wroteBytes int64 wroteBytes int64
@ -2607,7 +2669,6 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
date: date, date: date,
}) })
if err != nil { if err != nil {
rws.dirty = true
return 0, err return 0, err
} }
if endStream { if endStream {
@ -2628,7 +2689,6 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
if len(p) > 0 || endStream { if len(p) > 0 || endStream {
// only send a 0 byte DATA frame if we're ending the stream. // only send a 0 byte DATA frame if we're ending the stream.
if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil { if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil {
rws.dirty = true
return 0, err return 0, err
} }
} }
@ -2640,9 +2700,6 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
trailers: rws.trailers, trailers: rws.trailers,
endStream: true, endStream: true,
}) })
if err != nil {
rws.dirty = true
}
return len(p), err return len(p), err
} }
return len(p), nil return len(p), nil
@ -2858,14 +2915,12 @@ func (rws *responseWriterState) writeHeader(code int) {
h.Del("Transfer-Encoding") h.Del("Transfer-Encoding")
} }
if rws.conn.writeHeaders(rws.stream, &writeResHeaders{ rws.conn.writeHeaders(rws.stream, &writeResHeaders{
streamID: rws.stream.id, streamID: rws.stream.id,
httpResCode: code, httpResCode: code,
h: h, h: h,
endStream: rws.handlerDone && !rws.hasTrailers(), endStream: rws.handlerDone && !rws.hasTrailers(),
}) != nil { })
rws.dirty = true
}
return return
} }
@ -2930,19 +2985,10 @@ func (w *responseWriter) write(lenData int, dataB []byte, dataS string) (n int,
func (w *responseWriter) handlerDone() { func (w *responseWriter) handlerDone() {
rws := w.rws rws := w.rws
dirty := rws.dirty
rws.handlerDone = true rws.handlerDone = true
w.Flush() w.Flush()
w.rws = nil w.rws = nil
if !dirty { responseWriterStatePool.Put(rws)
// Only recycle the pool if all prior Write calls to
// the serverConn goroutine completed successfully. If
// they returned earlier due to resets from the peer
// there might still be write goroutines outstanding
// from the serverConn referencing the rws memory. See
// issue 20704.
responseWriterStatePool.Put(rws)
}
} }
// Push errors. // Push errors.
@ -3125,6 +3171,7 @@ func (sc *serverConn) startPush(msg *startPushRequest) {
panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err)) panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err))
} }
sc.curHandlers++
go sc.runHandler(rw, req, sc.handler.ServeHTTP) go sc.runHandler(rw, req, sc.handler.ServeHTTP)
return promisedID, nil return promisedID, nil
} }

331
vendor/golang.org/x/net/http2/testsync.go generated vendored Normal file
View File

@ -0,0 +1,331 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"context"
"sync"
"time"
)
// testSyncHooks coordinates goroutines in tests.
//
// For example, a call to ClientConn.RoundTrip involves several goroutines, including:
// - the goroutine running RoundTrip;
// - the clientStream.doRequest goroutine, which writes the request; and
// - the clientStream.readLoop goroutine, which reads the response.
//
// Using testSyncHooks, a test can start a RoundTrip and identify when all these goroutines
// are blocked waiting for some condition such as reading the Request.Body or waiting for
// flow control to become available.
//
// The testSyncHooks also manage timers and synthetic time in tests.
// This permits us to, for example, start a request and cause it to time out waiting for
// response headers without resorting to time.Sleep calls.
type testSyncHooks struct {
// active/inactive act as a mutex and condition variable.
//
// - neither chan contains a value: testSyncHooks is locked.
// - active contains a value: unlocked, and at least one goroutine is not blocked
// - inactive contains a value: unlocked, and all goroutines are blocked
active chan struct{}
inactive chan struct{}
// goroutine counts
total int // total goroutines
condwait map[*sync.Cond]int // blocked in sync.Cond.Wait
blocked []*testBlockedGoroutine // otherwise blocked
// fake time
now time.Time
timers []*fakeTimer
// Transport testing: Report various events.
newclientconn func(*ClientConn)
newstream func(*clientStream)
}
// testBlockedGoroutine is a blocked goroutine.
type testBlockedGoroutine struct {
f func() bool // blocked until f returns true
ch chan struct{} // closed when unblocked
}
func newTestSyncHooks() *testSyncHooks {
h := &testSyncHooks{
active: make(chan struct{}, 1),
inactive: make(chan struct{}, 1),
condwait: map[*sync.Cond]int{},
}
h.inactive <- struct{}{}
h.now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
return h
}
// lock acquires the testSyncHooks mutex.
func (h *testSyncHooks) lock() {
select {
case <-h.active:
case <-h.inactive:
}
}
// waitInactive waits for all goroutines to become inactive.
func (h *testSyncHooks) waitInactive() {
for {
<-h.inactive
if !h.unlock() {
break
}
}
}
// unlock releases the testSyncHooks mutex.
// It reports whether any goroutines are active.
func (h *testSyncHooks) unlock() (active bool) {
// Look for a blocked goroutine which can be unblocked.
blocked := h.blocked[:0]
unblocked := false
for _, b := range h.blocked {
if !unblocked && b.f() {
unblocked = true
close(b.ch)
} else {
blocked = append(blocked, b)
}
}
h.blocked = blocked
// Count goroutines blocked on condition variables.
condwait := 0
for _, count := range h.condwait {
condwait += count
}
if h.total > condwait+len(blocked) {
h.active <- struct{}{}
return true
} else {
h.inactive <- struct{}{}
return false
}
}
// goRun starts a new goroutine.
func (h *testSyncHooks) goRun(f func()) {
h.lock()
h.total++
h.unlock()
go func() {
defer func() {
h.lock()
h.total--
h.unlock()
}()
f()
}()
}
// blockUntil indicates that a goroutine is blocked waiting for some condition to become true.
// It waits until f returns true before proceeding.
//
// Example usage:
//
// h.blockUntil(func() bool {
// // Is the context done yet?
// select {
// case <-ctx.Done():
// default:
// return false
// }
// return true
// })
// // Wait for the context to become done.
// <-ctx.Done()
//
// The function f passed to blockUntil must be non-blocking and idempotent.
func (h *testSyncHooks) blockUntil(f func() bool) {
if f() {
return
}
ch := make(chan struct{})
h.lock()
h.blocked = append(h.blocked, &testBlockedGoroutine{
f: f,
ch: ch,
})
h.unlock()
<-ch
}
// broadcast is sync.Cond.Broadcast.
func (h *testSyncHooks) condBroadcast(cond *sync.Cond) {
h.lock()
delete(h.condwait, cond)
h.unlock()
cond.Broadcast()
}
// broadcast is sync.Cond.Wait.
func (h *testSyncHooks) condWait(cond *sync.Cond) {
h.lock()
h.condwait[cond]++
h.unlock()
}
// newTimer creates a new fake timer.
func (h *testSyncHooks) newTimer(d time.Duration) timer {
h.lock()
defer h.unlock()
t := &fakeTimer{
hooks: h,
when: h.now.Add(d),
c: make(chan time.Time),
}
h.timers = append(h.timers, t)
return t
}
// afterFunc creates a new fake AfterFunc timer.
func (h *testSyncHooks) afterFunc(d time.Duration, f func()) timer {
h.lock()
defer h.unlock()
t := &fakeTimer{
hooks: h,
when: h.now.Add(d),
f: f,
}
h.timers = append(h.timers, t)
return t
}
func (h *testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(ctx)
t := h.afterFunc(d, cancel)
return ctx, func() {
t.Stop()
cancel()
}
}
func (h *testSyncHooks) timeUntilEvent() time.Duration {
h.lock()
defer h.unlock()
var next time.Time
for _, t := range h.timers {
if next.IsZero() || t.when.Before(next) {
next = t.when
}
}
if d := next.Sub(h.now); d > 0 {
return d
}
return 0
}
// advance advances time and causes synthetic timers to fire.
func (h *testSyncHooks) advance(d time.Duration) {
h.lock()
defer h.unlock()
h.now = h.now.Add(d)
timers := h.timers[:0]
for _, t := range h.timers {
t := t // remove after go.mod depends on go1.22
t.mu.Lock()
switch {
case t.when.After(h.now):
timers = append(timers, t)
case t.when.IsZero():
// stopped timer
default:
t.when = time.Time{}
if t.c != nil {
close(t.c)
}
if t.f != nil {
h.total++
go func() {
defer func() {
h.lock()
h.total--
h.unlock()
}()
t.f()
}()
}
}
t.mu.Unlock()
}
h.timers = timers
}
// A timer wraps a time.Timer, or a synthetic equivalent in tests.
// Unlike time.Timer, timer is single-use: The timer channel is closed when the timer expires.
type timer interface {
C() <-chan time.Time
Stop() bool
Reset(d time.Duration) bool
}
// timeTimer implements timer using real time.
type timeTimer struct {
t *time.Timer
c chan time.Time
}
// newTimeTimer creates a new timer using real time.
func newTimeTimer(d time.Duration) timer {
ch := make(chan time.Time)
t := time.AfterFunc(d, func() {
close(ch)
})
return &timeTimer{t, ch}
}
// newTimeAfterFunc creates an AfterFunc timer using real time.
func newTimeAfterFunc(d time.Duration, f func()) timer {
return &timeTimer{
t: time.AfterFunc(d, f),
}
}
func (t timeTimer) C() <-chan time.Time { return t.c }
func (t timeTimer) Stop() bool { return t.t.Stop() }
func (t timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) }
// fakeTimer implements timer using fake time.
type fakeTimer struct {
hooks *testSyncHooks
mu sync.Mutex
when time.Time // when the timer will fire
c chan time.Time // closed when the timer fires; mutually exclusive with f
f func() // called when the timer fires; mutually exclusive with c
}
func (t *fakeTimer) C() <-chan time.Time { return t.c }
func (t *fakeTimer) Stop() bool {
t.mu.Lock()
defer t.mu.Unlock()
stopped := t.when.IsZero()
t.when = time.Time{}
return stopped
}
func (t *fakeTimer) Reset(d time.Duration) bool {
if t.c != nil || t.f == nil {
panic("fakeTimer only supports Reset on AfterFunc timers")
}
t.mu.Lock()
defer t.mu.Unlock()
t.hooks.lock()
defer t.hooks.unlock()
active := !t.when.IsZero()
t.when = t.hooks.now.Add(d)
if !active {
t.hooks.timers = append(t.hooks.timers, t)
}
return active
}

View File

@ -147,6 +147,12 @@ type Transport struct {
// waiting for their turn. // waiting for their turn.
StrictMaxConcurrentStreams bool StrictMaxConcurrentStreams bool
// IdleConnTimeout is the maximum amount of time an idle
// (keep-alive) connection will remain idle before closing
// itself.
// Zero means no limit.
IdleConnTimeout time.Duration
// ReadIdleTimeout is the timeout after which a health check using ping // ReadIdleTimeout is the timeout after which a health check using ping
// frame will be carried out if no frame is received on the connection. // frame will be carried out if no frame is received on the connection.
// Note that a ping response will is considered a received frame, so if // Note that a ping response will is considered a received frame, so if
@ -178,6 +184,8 @@ type Transport struct {
connPoolOnce sync.Once connPoolOnce sync.Once
connPoolOrDef ClientConnPool // non-nil version of ConnPool connPoolOrDef ClientConnPool // non-nil version of ConnPool
syncHooks *testSyncHooks
} }
func (t *Transport) maxHeaderListSize() uint32 { func (t *Transport) maxHeaderListSize() uint32 {
@ -302,7 +310,7 @@ type ClientConn struct {
readerErr error // set before readerDone is closed readerErr error // set before readerDone is closed
idleTimeout time.Duration // or 0 for never idleTimeout time.Duration // or 0 for never
idleTimer *time.Timer idleTimer timer
mu sync.Mutex // guards following mu sync.Mutex // guards following
cond *sync.Cond // hold mu; broadcast on flow/closed changes cond *sync.Cond // hold mu; broadcast on flow/closed changes
@ -344,6 +352,60 @@ type ClientConn struct {
werr error // first write error that has occurred werr error // first write error that has occurred
hbuf bytes.Buffer // HPACK encoder writes into this hbuf bytes.Buffer // HPACK encoder writes into this
henc *hpack.Encoder henc *hpack.Encoder
syncHooks *testSyncHooks // can be nil
}
// Hook points used for testing.
// Outside of tests, cc.syncHooks is nil and these all have minimal implementations.
// Inside tests, see the testSyncHooks function docs.
// goRun starts a new goroutine.
func (cc *ClientConn) goRun(f func()) {
if cc.syncHooks != nil {
cc.syncHooks.goRun(f)
return
}
go f()
}
// condBroadcast is cc.cond.Broadcast.
func (cc *ClientConn) condBroadcast() {
if cc.syncHooks != nil {
cc.syncHooks.condBroadcast(cc.cond)
}
cc.cond.Broadcast()
}
// condWait is cc.cond.Wait.
func (cc *ClientConn) condWait() {
if cc.syncHooks != nil {
cc.syncHooks.condWait(cc.cond)
}
cc.cond.Wait()
}
// newTimer creates a new time.Timer, or a synthetic timer in tests.
func (cc *ClientConn) newTimer(d time.Duration) timer {
if cc.syncHooks != nil {
return cc.syncHooks.newTimer(d)
}
return newTimeTimer(d)
}
// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
func (cc *ClientConn) afterFunc(d time.Duration, f func()) timer {
if cc.syncHooks != nil {
return cc.syncHooks.afterFunc(d, f)
}
return newTimeAfterFunc(d, f)
}
func (cc *ClientConn) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
if cc.syncHooks != nil {
return cc.syncHooks.contextWithTimeout(ctx, d)
}
return context.WithTimeout(ctx, d)
} }
// clientStream is the state for a single HTTP/2 stream. One of these // clientStream is the state for a single HTTP/2 stream. One of these
@ -425,7 +487,7 @@ func (cs *clientStream) abortStreamLocked(err error) {
// TODO(dneil): Clean up tests where cs.cc.cond is nil. // TODO(dneil): Clean up tests where cs.cc.cond is nil.
if cs.cc.cond != nil { if cs.cc.cond != nil {
// Wake up writeRequestBody if it is waiting on flow control. // Wake up writeRequestBody if it is waiting on flow control.
cs.cc.cond.Broadcast() cs.cc.condBroadcast()
} }
} }
@ -435,7 +497,7 @@ func (cs *clientStream) abortRequestBodyWrite() {
defer cc.mu.Unlock() defer cc.mu.Unlock()
if cs.reqBody != nil && cs.reqBodyClosed == nil { if cs.reqBody != nil && cs.reqBodyClosed == nil {
cs.closeReqBodyLocked() cs.closeReqBodyLocked()
cc.cond.Broadcast() cc.condBroadcast()
} }
} }
@ -445,10 +507,10 @@ func (cs *clientStream) closeReqBodyLocked() {
} }
cs.reqBodyClosed = make(chan struct{}) cs.reqBodyClosed = make(chan struct{})
reqBodyClosed := cs.reqBodyClosed reqBodyClosed := cs.reqBodyClosed
go func() { cs.cc.goRun(func() {
cs.reqBody.Close() cs.reqBody.Close()
close(reqBodyClosed) close(reqBodyClosed)
}() })
} }
type stickyErrWriter struct { type stickyErrWriter struct {
@ -537,15 +599,6 @@ func authorityAddr(scheme string, authority string) (addr string) {
return net.JoinHostPort(host, port) return net.JoinHostPort(host, port)
} }
var retryBackoffHook func(time.Duration) *time.Timer
func backoffNewTimer(d time.Duration) *time.Timer {
if retryBackoffHook != nil {
return retryBackoffHook(d)
}
return time.NewTimer(d)
}
// RoundTripOpt is like RoundTrip, but takes options. // RoundTripOpt is like RoundTrip, but takes options.
func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) {
@ -573,13 +626,27 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
backoff := float64(uint(1) << (uint(retry) - 1)) backoff := float64(uint(1) << (uint(retry) - 1))
backoff += backoff * (0.1 * mathrand.Float64()) backoff += backoff * (0.1 * mathrand.Float64())
d := time.Second * time.Duration(backoff) d := time.Second * time.Duration(backoff)
timer := backoffNewTimer(d) var tm timer
if t.syncHooks != nil {
tm = t.syncHooks.newTimer(d)
t.syncHooks.blockUntil(func() bool {
select {
case <-tm.C():
case <-req.Context().Done():
default:
return false
}
return true
})
} else {
tm = newTimeTimer(d)
}
select { select {
case <-timer.C: case <-tm.C():
t.vlogf("RoundTrip retrying after failure: %v", roundTripErr) t.vlogf("RoundTrip retrying after failure: %v", roundTripErr)
continue continue
case <-req.Context().Done(): case <-req.Context().Done():
timer.Stop() tm.Stop()
err = req.Context().Err() err = req.Context().Err()
} }
} }
@ -658,6 +725,9 @@ func canRetryError(err error) bool {
} }
func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) { func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) {
if t.syncHooks != nil {
return t.newClientConn(nil, singleUse, t.syncHooks)
}
host, _, err := net.SplitHostPort(addr) host, _, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -666,7 +736,7 @@ func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse b
if err != nil { if err != nil {
return nil, err return nil, err
} }
return t.newClientConn(tconn, singleUse) return t.newClientConn(tconn, singleUse, nil)
} }
func (t *Transport) newTLSConfig(host string) *tls.Config { func (t *Transport) newTLSConfig(host string) *tls.Config {
@ -732,10 +802,10 @@ func (t *Transport) maxEncoderHeaderTableSize() uint32 {
} }
func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) { func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
return t.newClientConn(c, t.disableKeepAlives()) return t.newClientConn(c, t.disableKeepAlives(), nil)
} }
func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) { func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHooks) (*ClientConn, error) {
cc := &ClientConn{ cc := &ClientConn{
t: t, t: t,
tconn: c, tconn: c,
@ -750,10 +820,15 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
wantSettingsAck: true, wantSettingsAck: true,
pings: make(map[[8]byte]chan struct{}), pings: make(map[[8]byte]chan struct{}),
reqHeaderMu: make(chan struct{}, 1), reqHeaderMu: make(chan struct{}, 1),
syncHooks: hooks,
}
if hooks != nil {
hooks.newclientconn(cc)
c = cc.tconn
} }
if d := t.idleConnTimeout(); d != 0 { if d := t.idleConnTimeout(); d != 0 {
cc.idleTimeout = d cc.idleTimeout = d
cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) cc.idleTimer = cc.afterFunc(d, cc.onIdleTimeout)
} }
if VerboseLogs { if VerboseLogs {
t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
@ -818,7 +893,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
return nil, cc.werr return nil, cc.werr
} }
go cc.readLoop() cc.goRun(cc.readLoop)
return cc, nil return cc, nil
} }
@ -826,7 +901,7 @@ func (cc *ClientConn) healthCheck() {
pingTimeout := cc.t.pingTimeout() pingTimeout := cc.t.pingTimeout()
// We don't need to periodically ping in the health check, because the readLoop of ClientConn will // We don't need to periodically ping in the health check, because the readLoop of ClientConn will
// trigger the healthCheck again if there is no frame received. // trigger the healthCheck again if there is no frame received.
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) ctx, cancel := cc.contextWithTimeout(context.Background(), pingTimeout)
defer cancel() defer cancel()
cc.vlogf("http2: Transport sending health check") cc.vlogf("http2: Transport sending health check")
err := cc.Ping(ctx) err := cc.Ping(ctx)
@ -1018,7 +1093,7 @@ func (cc *ClientConn) forceCloseConn() {
if !ok { if !ok {
return return
} }
if nc := tlsUnderlyingConn(tc); nc != nil { if nc := tc.NetConn(); nc != nil {
nc.Close() nc.Close()
} }
} }
@ -1056,7 +1131,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
// Wait for all in-flight streams to complete or connection to close // Wait for all in-flight streams to complete or connection to close
done := make(chan struct{}) done := make(chan struct{})
cancelled := false // guarded by cc.mu cancelled := false // guarded by cc.mu
go func() { cc.goRun(func() {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
for { for {
@ -1068,9 +1143,9 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
if cancelled { if cancelled {
break break
} }
cc.cond.Wait() cc.condWait()
} }
}() })
shutdownEnterWaitStateHook() shutdownEnterWaitStateHook()
select { select {
case <-done: case <-done:
@ -1080,7 +1155,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
cc.mu.Lock() cc.mu.Lock()
// Free the goroutine above // Free the goroutine above
cancelled = true cancelled = true
cc.cond.Broadcast() cc.condBroadcast()
cc.mu.Unlock() cc.mu.Unlock()
return ctx.Err() return ctx.Err()
} }
@ -1118,7 +1193,7 @@ func (cc *ClientConn) closeForError(err error) {
for _, cs := range cc.streams { for _, cs := range cc.streams {
cs.abortStreamLocked(err) cs.abortStreamLocked(err)
} }
cc.cond.Broadcast() cc.condBroadcast()
cc.mu.Unlock() cc.mu.Unlock()
cc.closeConn() cc.closeConn()
} }
@ -1215,6 +1290,10 @@ func (cc *ClientConn) decrStreamReservationsLocked() {
} }
func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
return cc.roundTrip(req, nil)
}
func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) (*http.Response, error) {
ctx := req.Context() ctx := req.Context()
cs := &clientStream{ cs := &clientStream{
cc: cc, cc: cc,
@ -1229,9 +1308,23 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
respHeaderRecv: make(chan struct{}), respHeaderRecv: make(chan struct{}),
donec: make(chan struct{}), donec: make(chan struct{}),
} }
go cs.doRequest(req) cc.goRun(func() {
cs.doRequest(req)
})
waitDone := func() error { waitDone := func() error {
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-cs.donec:
case <-ctx.Done():
case <-cs.reqCancel:
default:
return false
}
return true
})
}
select { select {
case <-cs.donec: case <-cs.donec:
return nil return nil
@ -1292,7 +1385,24 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
return err return err
} }
if streamf != nil {
streamf(cs)
}
for { for {
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-cs.respHeaderRecv:
case <-cs.abort:
case <-ctx.Done():
case <-cs.reqCancel:
default:
return false
}
return true
})
}
select { select {
case <-cs.respHeaderRecv: case <-cs.respHeaderRecv:
return handleResponseHeaders() return handleResponseHeaders()
@ -1348,6 +1458,21 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
if cc.reqHeaderMu == nil { if cc.reqHeaderMu == nil {
panic("RoundTrip on uninitialized ClientConn") // for tests panic("RoundTrip on uninitialized ClientConn") // for tests
} }
var newStreamHook func(*clientStream)
if cc.syncHooks != nil {
newStreamHook = cc.syncHooks.newstream
cc.syncHooks.blockUntil(func() bool {
select {
case cc.reqHeaderMu <- struct{}{}:
<-cc.reqHeaderMu
case <-cs.reqCancel:
case <-ctx.Done():
default:
return false
}
return true
})
}
select { select {
case cc.reqHeaderMu <- struct{}{}: case cc.reqHeaderMu <- struct{}{}:
case <-cs.reqCancel: case <-cs.reqCancel:
@ -1372,6 +1497,10 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
} }
cc.mu.Unlock() cc.mu.Unlock()
if newStreamHook != nil {
newStreamHook(cs)
}
// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
if !cc.t.disableCompression() && if !cc.t.disableCompression() &&
req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Accept-Encoding") == "" &&
@ -1452,15 +1581,30 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
var respHeaderTimer <-chan time.Time var respHeaderTimer <-chan time.Time
var respHeaderRecv chan struct{} var respHeaderRecv chan struct{}
if d := cc.responseHeaderTimeout(); d != 0 { if d := cc.responseHeaderTimeout(); d != 0 {
timer := time.NewTimer(d) timer := cc.newTimer(d)
defer timer.Stop() defer timer.Stop()
respHeaderTimer = timer.C respHeaderTimer = timer.C()
respHeaderRecv = cs.respHeaderRecv respHeaderRecv = cs.respHeaderRecv
} }
// Wait until the peer half-closes its end of the stream, // Wait until the peer half-closes its end of the stream,
// or until the request is aborted (via context, error, or otherwise), // or until the request is aborted (via context, error, or otherwise),
// whichever comes first. // whichever comes first.
for { for {
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-cs.peerClosed:
case <-respHeaderTimer:
case <-respHeaderRecv:
case <-cs.abort:
case <-ctx.Done():
case <-cs.reqCancel:
default:
return false
}
return true
})
}
select { select {
case <-cs.peerClosed: case <-cs.peerClosed:
return nil return nil
@ -1609,7 +1753,7 @@ func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error {
return nil return nil
} }
cc.pendingRequests++ cc.pendingRequests++
cc.cond.Wait() cc.condWait()
cc.pendingRequests-- cc.pendingRequests--
select { select {
case <-cs.abort: case <-cs.abort:
@ -1871,10 +2015,26 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error)
cs.flow.take(take) cs.flow.take(take)
return take, nil return take, nil
} }
cc.cond.Wait() cc.condWait()
} }
} }
func validateHeaders(hdrs http.Header) string {
for k, vv := range hdrs {
if !httpguts.ValidHeaderFieldName(k) {
return fmt.Sprintf("name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
// Don't include the value in the error,
// because it may be sensitive.
return fmt.Sprintf("value for header %q", k)
}
}
}
return ""
}
var errNilRequestURL = errors.New("http2: Request.URI is nil") var errNilRequestURL = errors.New("http2: Request.URI is nil")
// requires cc.wmu be held. // requires cc.wmu be held.
@ -1912,19 +2072,14 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
} }
} }
// Check for any invalid headers and return an error before we // Check for any invalid headers+trailers and return an error before we
// potentially pollute our hpack state. (We want to be able to // potentially pollute our hpack state. (We want to be able to
// continue to reuse the hpack encoder for future requests) // continue to reuse the hpack encoder for future requests)
for k, vv := range req.Header { if err := validateHeaders(req.Header); err != "" {
if !httpguts.ValidHeaderFieldName(k) { return nil, fmt.Errorf("invalid HTTP header %s", err)
return nil, fmt.Errorf("invalid HTTP header name %q", k) }
} if err := validateHeaders(req.Trailer); err != "" {
for _, v := range vv { return nil, fmt.Errorf("invalid HTTP trailer %s", err)
if !httpguts.ValidHeaderFieldValue(v) {
// Don't include the value in the error, because it may be sensitive.
return nil, fmt.Errorf("invalid HTTP header value for header %q", k)
}
}
} }
enumerateHeaders := func(f func(name, value string)) { enumerateHeaders := func(f func(name, value string)) {
@ -2143,7 +2298,7 @@ func (cc *ClientConn) forgetStreamID(id uint32) {
} }
// Wake up writeRequestBody via clientStream.awaitFlowControl and // Wake up writeRequestBody via clientStream.awaitFlowControl and
// wake up RoundTrip if there is a pending request. // wake up RoundTrip if there is a pending request.
cc.cond.Broadcast() cc.condBroadcast()
closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil
if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 { if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 {
@ -2231,7 +2386,7 @@ func (rl *clientConnReadLoop) cleanup() {
cs.abortStreamLocked(err) cs.abortStreamLocked(err)
} }
} }
cc.cond.Broadcast() cc.condBroadcast()
cc.mu.Unlock() cc.mu.Unlock()
} }
@ -2266,10 +2421,9 @@ func (rl *clientConnReadLoop) run() error {
cc := rl.cc cc := rl.cc
gotSettings := false gotSettings := false
readIdleTimeout := cc.t.ReadIdleTimeout readIdleTimeout := cc.t.ReadIdleTimeout
var t *time.Timer var t timer
if readIdleTimeout != 0 { if readIdleTimeout != 0 {
t = time.AfterFunc(readIdleTimeout, cc.healthCheck) t = cc.afterFunc(readIdleTimeout, cc.healthCheck)
defer t.Stop()
} }
for { for {
f, err := cc.fr.ReadFrame() f, err := cc.fr.ReadFrame()
@ -2684,7 +2838,7 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error {
}) })
return nil return nil
} }
if !cs.firstByte { if !cs.pastHeaders {
cc.logf("protocol error: received DATA before a HEADERS frame") cc.logf("protocol error: received DATA before a HEADERS frame")
rl.endStreamError(cs, StreamError{ rl.endStreamError(cs, StreamError{
StreamID: f.StreamID, StreamID: f.StreamID,
@ -2867,7 +3021,7 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error {
for _, cs := range cc.streams { for _, cs := range cc.streams {
cs.flow.add(delta) cs.flow.add(delta)
} }
cc.cond.Broadcast() cc.condBroadcast()
cc.initialWindowSize = s.Val cc.initialWindowSize = s.Val
case SettingHeaderTableSize: case SettingHeaderTableSize:
@ -2911,9 +3065,18 @@ func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error {
fl = &cs.flow fl = &cs.flow
} }
if !fl.add(int32(f.Increment)) { if !fl.add(int32(f.Increment)) {
// For stream, the sender sends RST_STREAM with an error code of FLOW_CONTROL_ERROR
if cs != nil {
rl.endStreamError(cs, StreamError{
StreamID: f.StreamID,
Code: ErrCodeFlowControl,
})
return nil
}
return ConnectionError(ErrCodeFlowControl) return ConnectionError(ErrCodeFlowControl)
} }
cc.cond.Broadcast() cc.condBroadcast()
return nil return nil
} }
@ -2955,24 +3118,38 @@ func (cc *ClientConn) Ping(ctx context.Context) error {
} }
cc.mu.Unlock() cc.mu.Unlock()
} }
errc := make(chan error, 1) var pingError error
go func() { errc := make(chan struct{})
cc.goRun(func() {
cc.wmu.Lock() cc.wmu.Lock()
defer cc.wmu.Unlock() defer cc.wmu.Unlock()
if err := cc.fr.WritePing(false, p); err != nil { if pingError = cc.fr.WritePing(false, p); pingError != nil {
errc <- err close(errc)
return return
} }
if err := cc.bw.Flush(); err != nil { if pingError = cc.bw.Flush(); pingError != nil {
errc <- err close(errc)
return return
} }
}() })
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-c:
case <-errc:
case <-ctx.Done():
case <-cc.readerDone:
default:
return false
}
return true
})
}
select { select {
case <-c: case <-c:
return nil return nil
case err := <-errc: case <-errc:
return err return pingError
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case <-cc.readerDone: case <-cc.readerDone:
@ -3141,9 +3318,17 @@ func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, err
} }
func (t *Transport) idleConnTimeout() time.Duration { func (t *Transport) idleConnTimeout() time.Duration {
// to keep things backwards compatible, we use non-zero values of
// IdleConnTimeout, followed by using the IdleConnTimeout on the underlying
// http1 transport, followed by 0
if t.IdleConnTimeout != 0 {
return t.IdleConnTimeout
}
if t.t1 != nil { if t.t1 != nil {
return t.t1.IdleConnTimeout return t.t1.IdleConnTimeout
} }
return 0 return 0
} }
@ -3201,3 +3386,34 @@ func traceFirstResponseByte(trace *httptrace.ClientTrace) {
trace.GotFirstResponseByte() trace.GotFirstResponseByte()
} }
} }
func traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool {
return trace != nil && trace.WroteHeaderField != nil
}
func traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) {
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField(k, []string{v})
}
}
func traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error {
if trace != nil {
return trace.Got1xxResponse
}
return nil
}
// dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS
// connection.
func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (*tls.Conn, error) {
dialer := &tls.Dialer{
Config: cfg,
}
cn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed
return tlsCn, nil
}

View File

@ -5,7 +5,6 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build go1.18 //go:build go1.18
// +build go1.18
package idna package idna

View File

@ -5,7 +5,6 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build go1.10 //go:build go1.10
// +build go1.10
// Package idna implements IDNA2008 using the compatibility processing // Package idna implements IDNA2008 using the compatibility processing
// defined by UTS (Unicode Technical Standard) #46, which defines a standard to // defined by UTS (Unicode Technical Standard) #46, which defines a standard to

View File

@ -5,7 +5,6 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build !go1.10 //go:build !go1.10
// +build !go1.10
// Package idna implements IDNA2008 using the compatibility processing // Package idna implements IDNA2008 using the compatibility processing
// defined by UTS (Unicode Technical Standard) #46, which defines a standard to // defined by UTS (Unicode Technical Standard) #46, which defines a standard to

View File

@ -5,7 +5,6 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build !go1.18 //go:build !go1.18
// +build !go1.18
package idna package idna

View File

@ -1,7 +1,6 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT. // Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
//go:build go1.10 && !go1.13 //go:build go1.10 && !go1.13
// +build go1.10,!go1.13
package idna package idna

View File

@ -1,7 +1,6 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT. // Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
//go:build go1.13 && !go1.14 //go:build go1.13 && !go1.14
// +build go1.13,!go1.14
package idna package idna

View File

@ -1,7 +1,6 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT. // Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
//go:build go1.14 && !go1.16 //go:build go1.14 && !go1.16
// +build go1.14,!go1.16
package idna package idna

View File

@ -1,7 +1,6 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT. // Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
//go:build go1.16 && !go1.21 //go:build go1.16 && !go1.21
// +build go1.16,!go1.21
package idna package idna

View File

@ -1,7 +1,6 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT. // Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
//go:build go1.21 //go:build go1.21
// +build go1.21
package idna package idna

View File

@ -1,7 +1,6 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT. // Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
//go:build !go1.10 //go:build !go1.10
// +build !go1.10
package idna package idna

View File

@ -5,7 +5,6 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build !go1.16 //go:build !go1.16
// +build !go1.16
package idna package idna

View File

@ -5,7 +5,6 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build go1.16 //go:build go1.16
// +build go1.16
package idna package idna

View File

@ -1,30 +0,0 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package unsafeheader contains header declarations for the Go runtime's
// slice and string implementations.
//
// This package allows x/sys to use types equivalent to
// reflect.SliceHeader and reflect.StringHeader without introducing
// a dependency on the (relatively heavy) "reflect" package.
package unsafeheader
import (
"unsafe"
)
// Slice is the runtime representation of a slice.
// It cannot be used safely or portably and its representation may change in a later release.
type Slice struct {
Data unsafe.Pointer
Len int
Cap int
}
// String is the runtime representation of a string.
// It cannot be used safely or portably and its representation may change in a later release.
type String struct {
Data unsafe.Pointer
Len int
}

View File

@ -2,9 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build (aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos) && go1.9 //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos
// +build go1.9
package unix package unix

View File

@ -3,7 +3,6 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build gc //go:build gc
// +build gc
#include "textflag.h" #include "textflag.h"

View File

@ -3,8 +3,6 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build (freebsd || netbsd || openbsd) && gc //go:build (freebsd || netbsd || openbsd) && gc
// +build freebsd netbsd openbsd
// +build gc
#include "textflag.h" #include "textflag.h"

View File

@ -3,8 +3,6 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && gc //go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && gc
// +build darwin dragonfly freebsd netbsd openbsd
// +build gc
#include "textflag.h" #include "textflag.h"

View File

@ -3,8 +3,6 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build (freebsd || netbsd || openbsd) && gc //go:build (freebsd || netbsd || openbsd) && gc
// +build freebsd netbsd openbsd
// +build gc
#include "textflag.h" #include "textflag.h"

Some files were not shown because too many files have changed in this diff Show More