222 lines
7.9 KiB
Go
222 lines
7.9 KiB
Go
// SPDX-License-Identifier: MIT
|
|
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
|
|
|
|
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"github.com/labstack/echo/v4"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// ---------------------------------------------------------------------------------------------------------------
|
|
// WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING
|
|
// WARNING: Timeout middleware causes more problems than it solves.
|
|
// WARNING: This middleware should be first middleware as it messes with request Writer and could cause data race if
|
|
// it is in other position
|
|
//
|
|
// Depending on out requirements you could be better of setting timeout to context and
|
|
// check its deadline from handler.
|
|
//
|
|
// For example: create middleware to set timeout to context
|
|
// func RequestTimeout(timeout time.Duration) echo.MiddlewareFunc {
|
|
// return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
// return func(c echo.Context) error {
|
|
// timeoutCtx, cancel := context.WithTimeout(c.Request().Context(), timeout)
|
|
// c.SetRequest(c.Request().WithContext(timeoutCtx))
|
|
// defer cancel()
|
|
// return next(c)
|
|
// }
|
|
// }
|
|
//}
|
|
//
|
|
// Create handler that checks for context deadline and runs actual task in separate coroutine
|
|
// Note: separate coroutine may not be even if you do not want to process continue executing and
|
|
// just want to stop long-running handler to stop and you are using "context aware" methods (ala db queries with ctx)
|
|
// e.GET("/", func(c echo.Context) error {
|
|
//
|
|
// doneCh := make(chan error)
|
|
// go func(ctx context.Context) {
|
|
// doneCh <- myPossiblyLongRunningBackgroundTaskWithCtx(ctx)
|
|
// }(c.Request().Context())
|
|
//
|
|
// select { // wait for task to finish or context to timeout/cancelled
|
|
// case err := <-doneCh:
|
|
// if err != nil {
|
|
// return err
|
|
// }
|
|
// return c.String(http.StatusOK, "OK")
|
|
// case <-c.Request().Context().Done():
|
|
// if c.Request().Context().Err() == context.DeadlineExceeded {
|
|
// return c.String(http.StatusServiceUnavailable, "timeout")
|
|
// }
|
|
// return c.Request().Context().Err()
|
|
// }
|
|
//
|
|
// })
|
|
//
|
|
|
|
// TimeoutConfig defines the config for Timeout middleware.
|
|
type TimeoutConfig struct {
|
|
// Skipper defines a function to skip middleware.
|
|
Skipper Skipper
|
|
|
|
// ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code
|
|
// It can be used to define a custom timeout error message
|
|
ErrorMessage string
|
|
|
|
// OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after
|
|
// request timeouted and we already had sent the error code (503) and message response to the client.
|
|
// NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer
|
|
// will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()`
|
|
OnTimeoutRouteErrorHandler func(err error, c echo.Context)
|
|
|
|
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout
|
|
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
|
|
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
|
|
// difference over 500microseconds (0.5millisecond) response seems to be reliable
|
|
Timeout time.Duration
|
|
}
|
|
|
|
// DefaultTimeoutConfig is the default Timeout middleware config.
|
|
var DefaultTimeoutConfig = TimeoutConfig{
|
|
Skipper: DefaultSkipper,
|
|
Timeout: 0,
|
|
ErrorMessage: "",
|
|
}
|
|
|
|
// Timeout returns a middleware which returns error (503 Service Unavailable error) to client immediately when handler
|
|
// call runs for longer than its time limit. NB: timeout does not stop handler execution.
|
|
func Timeout() echo.MiddlewareFunc {
|
|
return TimeoutWithConfig(DefaultTimeoutConfig)
|
|
}
|
|
|
|
// TimeoutWithConfig returns a Timeout middleware with config or panics on invalid configuration.
|
|
func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
|
|
mw, err := config.ToMiddleware()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return mw
|
|
}
|
|
|
|
// ToMiddleware converts Config to middleware or returns an error for invalid configuration
|
|
func (config TimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|
if config.Skipper == nil {
|
|
config.Skipper = DefaultTimeoutConfig.Skipper
|
|
}
|
|
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
if config.Skipper(c) || config.Timeout == 0 {
|
|
return next(c)
|
|
}
|
|
|
|
errChan := make(chan error, 1)
|
|
handlerWrapper := echoHandlerFuncWrapper{
|
|
writer: &ignorableWriter{ResponseWriter: c.Response().Writer},
|
|
ctx: c,
|
|
handler: next,
|
|
errChan: errChan,
|
|
errHandler: config.OnTimeoutRouteErrorHandler,
|
|
}
|
|
handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage)
|
|
handler.ServeHTTP(handlerWrapper.writer, c.Request())
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
return err
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
}, nil
|
|
}
|
|
|
|
type echoHandlerFuncWrapper struct {
|
|
writer *ignorableWriter
|
|
ctx echo.Context
|
|
handler echo.HandlerFunc
|
|
errHandler func(err error, c echo.Context)
|
|
errChan chan error
|
|
}
|
|
|
|
func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
|
// replace echo.Context Request with the one provided by TimeoutHandler to let later middlewares/handler on the chain
|
|
// handle properly it's cancellation
|
|
t.ctx.SetRequest(r)
|
|
|
|
// replace writer with TimeoutHandler custom one. This will guarantee that
|
|
// `writes by h to its ResponseWriter will return ErrHandlerTimeout.`
|
|
originalWriter := t.ctx.Response().Writer
|
|
t.ctx.Response().Writer = rw
|
|
|
|
// in case of panic we restore original writer and call panic again
|
|
// so it could be handled with global middleware Recover()
|
|
defer func() {
|
|
if err := recover(); err != nil {
|
|
t.ctx.Response().Writer = originalWriter
|
|
panic(err)
|
|
}
|
|
}()
|
|
|
|
err := t.handler(t.ctx)
|
|
if ctxErr := r.Context().Err(); ctxErr == context.DeadlineExceeded {
|
|
if err != nil && t.errHandler != nil {
|
|
t.errHandler(err, t.ctx)
|
|
}
|
|
return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers
|
|
}
|
|
if err != nil {
|
|
// This is needed as `http.TimeoutHandler` will write status code by itself on error and after that our tries to write
|
|
// status code will not work anymore as Echo.Response thinks it has been already "committed" and further writes
|
|
// create errors in log about `superfluous response.WriteHeader call from`
|
|
t.writer.Ignore(true)
|
|
t.ctx.Response().Writer = originalWriter // make sure we restore writer before we signal original coroutine about the error
|
|
// we pass error from handler to middlewares up in handler chain to act on it if needed.
|
|
t.errChan <- err
|
|
return
|
|
}
|
|
// we restore original writer only for cases we did not timeout. On timeout we have already sent response to client
|
|
// and should not anymore send additional headers/data
|
|
// so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body
|
|
t.ctx.Response().Writer = originalWriter
|
|
}
|
|
|
|
// ignorableWriter is ResponseWriter implementations that allows us to mark writer to ignore further write calls. This
|
|
// is handy in cases when you do not have direct control of code being executed (3rd party middleware) but want to make
|
|
// sure that external code will not be able to write response to the client.
|
|
// Writer is coroutine safe for writes.
|
|
type ignorableWriter struct {
|
|
http.ResponseWriter
|
|
|
|
lock sync.Mutex
|
|
ignoreWrites bool
|
|
}
|
|
|
|
func (w *ignorableWriter) Ignore(ignore bool) {
|
|
w.lock.Lock()
|
|
w.ignoreWrites = ignore
|
|
w.lock.Unlock()
|
|
}
|
|
|
|
func (w *ignorableWriter) WriteHeader(code int) {
|
|
w.lock.Lock()
|
|
defer w.lock.Unlock()
|
|
if w.ignoreWrites {
|
|
return
|
|
}
|
|
w.ResponseWriter.WriteHeader(code)
|
|
}
|
|
|
|
func (w *ignorableWriter) Write(b []byte) (int, error) {
|
|
w.lock.Lock()
|
|
defer w.lock.Unlock()
|
|
if w.ignoreWrites {
|
|
return len(b), nil
|
|
}
|
|
return w.ResponseWriter.Write(b)
|
|
}
|