Compare commits

..

No commits in common. "master" and "0.2.7" have entirely different histories.

1044 changed files with 38612 additions and 133458 deletions

View File

@ -1,70 +1,89 @@
---
kind: pipeline
type: docker
name: build-linux
environment:
GOOS: linux
GOOPTIONS: -mod=vendor
SRCFILES: cmd/pki/*.go
PROJECTNAME: pki
name: cleanup-before
steps:
- name: build-linux-amd64
image: golang
- name: clean
image: alpine
commands:
- go build -o $PROJECTNAME $GOOPTIONS $SRCFILES
environment:
GOARCH: amd64
- rm -rf /build/*
volumes:
- name: build
path: /build
when:
event:
exclude:
- tag
- name: build-linux-arm64
image: golang
commands:
- go build -o $PROJECTNAME $GOOPTIONS $SRCFILES
environment:
GOARCH: arm64
when:
event:
exclude:
- tag
event: tag
volumes:
- name: build
host:
path: /tmp/pki/build
---
kind: pipeline
type: docker
name: gitea-release-linux
environment:
GOOS: linux
GOOPTIONS: -mod=vendor
SRCFILES: cmd/pki/*.go
PROJECTNAME: pki
name: default-linux-amd64
steps:
- name: build-linux-amd64
- name: build
image: golang
commands:
- go build -o $PROJECTNAME $GOOPTIONS $SRCFILES
- tar -czvf $PROJECTNAME-$DRONE_TAG-$GOOS-$GOARCH.tar.gz $PROJECTNAME
- echo $PROJECTNAME $DRONE_TAG > VERSION
- ./ci-build.sh build
environment:
GOOS: linux
GOARCH: amd64
when:
event:
- tag
- name: build-linux-arm64
volumes:
- name: build
path: /build
volumes:
- name: build
host:
path: /tmp/pki/build
depends_on:
- cleanup-before
---
kind: pipeline
type: docker
name: default-linux-arm64
steps:
- name: build
image: golang
commands:
- go build -o $PROJECTNAME $GOOPTIONS $SRCFILES
- tar -czvf $PROJECTNAME-$DRONE_TAG-$GOOS-$GOARCH.tar.gz $PROJECTNAME
- echo $PROJECTNAME $DRONE_TAG > VERSION
- ./ci-build.sh build
environment:
GOOS: linux
GOARCH: arm64
volumes:
- name: build
path: /build
volumes:
- name: build
host:
path: /tmp/pki/build
depends_on:
- cleanup-before
---
kind: pipeline
type: docker
name: gitea-release
steps:
- name: move
image: alpine
commands:
- mv build/* ./
volumes:
- name: build
path: /drone/src/build
when:
event:
- tag
event: tag
- name: release
image: plugins/gitea-release
settings:
@ -76,6 +95,50 @@ steps:
- sha256
- sha512
title: VERSION
volumes:
- name: build
path: /drone/src/build
when:
event:
- tag
event: tag
- name: ls
image: alpine
commands:
- find .
volumes:
- name: build
path: /drone/src/build
when:
event: tag
volumes:
- name: build
host:
path: /tmp/pki/build
depends_on:
- default-linux-amd64
- default-linux-arm64
---
kind: pipeline
type: docker
name: cleanup-after
steps:
- name: clean
image: alpine
commands:
- rm -rf /build/*
volumes:
- name: build
path: /build
when:
event: tag
volumes:
- name: build
host:
path: /tmp/pki/build
depends_on:
- gitea-release

18
Makefile Normal file
View File

@ -0,0 +1,18 @@
# pki Makefile
GOCMD=go
GOBUILDCMD=${GOCMD} build
GOOPTIONS=-mod=vendor -ldflags="-s -w"
RMCMD=rm
BINNAME=pki
SRCFILES=cmd/pki/*.go
all: build
build:
${GOBUILDCMD} ${GOOPTIONS} ${SRCFILES}
clean:
${RMCMD} -f ${BINNAME}

View File

@ -10,7 +10,7 @@ PKI is a centralized Letsencrypt database server and renewer for certificate man
### Build
```bash
go build cmd/pki/pki.go
make
```
### Sample config in pki.ini
@ -40,7 +40,7 @@ ovhck=
## License
```text
Copyright (c) 2020, 2021, 2022 PaulBSD
Copyright (c) 2020, 2021 PaulBSD
All rights reserved.
Redistribution and use in source and binary forms, with or without

62
ci-build.sh Executable file
View File

@ -0,0 +1,62 @@
#!/bin/bash
set -e
PROJECTNAME=pki
RELEASENAME=${PROJECTNAME}
VERSION="0"
GOOPTIONS="-mod=vendor"
SRCFILES=cmd/${PROJECTNAME}/*.go
build() {
echo "Begin of build"
if [[ ! -z $DRONE_TAG ]]
then
echo "Drone tag set, let's do a release"
VERSION=$DRONE_TAG
echo "${PROJECTNAME} ${VERSION}" > /build/VERSION
elif [[ ! -z $DRONE_TAG ]]
then
echo "Drone not set, let's only do a build"
VERSION=$DRONE_COMMIT
fi
if [[ ! -z $VERSION && ! -z $GOOS && ! -z $GOARCH ]]
then
echo "Let's set a release name"
RELEASENAME=${PROJECTNAME}-${VERSION}-${GOOS}-${GOARCH}
fi
echo "Building project"
go build -o ${PROJECTNAME} ${GOOPTIONS} ${SRCFILES}
if [[ ! -z $DRONE_TAG ]]
then
echo "Let's make archives"
mkdir -p /build
tar -czvf /build/${RELEASENAME}.tar.gz ${PROJECTNAME}
fi
echo "Removing binary file"
rm ${PROJECTNAME}
echo "End of build"
}
clean() {
rm -rf $RELEASEDIR
}
case $1 in
"build")
build
;;
"clean")
clean
;;
*)
echo "No options choosen"
exit 1
;;
esac

53
go.mod
View File

@ -1,42 +1,41 @@
module git.paulbsd.com/paulbsd/pki
go 1.23
go 1.17
require (
github.com/go-acme/lego/v4 v4.17.4
github.com/go-acme/lego/v4 v4.4.0
github.com/golang/snappy v0.0.4 // indirect
github.com/labstack/echo/v4 v4.12.0
github.com/lib/pq v1.10.9
github.com/miekg/dns v1.1.62 // indirect
github.com/google/go-cmp v0.5.5 // indirect
github.com/gopherjs/gopherjs v0.0.0-20210406100015-1e088ea4ee04 // indirect
github.com/labstack/echo/v4 v4.5.0
github.com/lib/pq v1.10.3
github.com/miekg/dns v1.1.43 // indirect
github.com/onsi/ginkgo v1.16.0 // indirect
github.com/onsi/gomega v1.11.0 // indirect
golang.org/x/crypto v0.26.0 // indirect
golang.org/x/net v0.28.0 // indirect
golang.org/x/sys v0.24.0 // indirect
golang.org/x/text v0.17.0 // indirect
golang.org/x/time v0.6.0 // indirect
gopkg.in/ini.v1 v1.67.0
xorm.io/builder v0.3.13 // indirect
xorm.io/xorm v1.3.9
github.com/smartystreets/assertions v1.2.0 // indirect
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 // indirect
golang.org/x/net v0.0.0-20210903162142-ad29c8ab022f // indirect
golang.org/x/sys v0.0.0-20210903071746-97244b99971b // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac // indirect
gopkg.in/ini.v1 v1.62.1
xorm.io/builder v0.3.9 // indirect
xorm.io/xorm v1.2.3
)
require (
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/go-jose/go-jose/v4 v4.0.4 // indirect
github.com/goccy/go-json v0.10.3 // indirect
github.com/cenkalti/backoff/v4 v4.1.1 // indirect
github.com/goccy/go-json v0.7.8 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/labstack/gommon v0.4.2 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/json-iterator/go v1.1.11 // indirect
github.com/labstack/gommon v0.3.0 // indirect
github.com/mattn/go-colorable v0.1.8 // indirect
github.com/mattn/go-isatty v0.0.13 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/ovh/go-ovh v1.6.0 // indirect
github.com/modern-go/reflect2 v1.0.1 // indirect
github.com/ovh/go-ovh v1.1.0 // indirect
github.com/syndtr/goleveldb v1.0.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasttemplate v1.2.2 // indirect
golang.org/x/mod v0.20.0 // indirect
golang.org/x/oauth2 v0.22.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/tools v0.24.0 // indirect
github.com/valyala/fasttemplate v1.2.1 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
)

1007
go.sum

File diff suppressed because it is too large Load Diff

View File

@ -2,13 +2,10 @@ package cert
import "time"
func (e *Entry) Z() {
}
// Entry is the main struct for stored certificates
type Entry struct {
ID int `xorm:"pk autoincr"`
Domain string `xorm:"notnull"`
Domains string `xorm:"notnull"`
Certificate string `xorm:"text notnull"`
PrivateKey string `xorm:"text notnull"`
AuthURL string `xorm:"notnull"`

View File

@ -60,13 +60,10 @@ func (cfg *Config) GetConfig() error {
options["ovhas"] = pkisection.Key("ovhas").MustString("")
options["ovhck"] = pkisection.Key("ovhck").MustString("")
options["pdnsapiurl"] = pkisection.Key("pdnsapiurl").MustString("")
options["pdnsapikey"] = pkisection.Key("pdnsapikey").MustString("")
cfg.ACME.ProviderOptions = options
for key, value := range options {
if value == "" {
utils.Advice(fmt.Sprintf("Provider parameter %s not set", key))
for k, v := range options {
if v == "" {
utils.Advice(fmt.Sprintf("OVH provider parameter %s not set", k))
}
}
@ -75,8 +72,6 @@ func (cfg *Config) GetConfig() error {
cfg.ACME.AuthURL = lego.LEDirectoryProduction
case "staging":
cfg.ACME.AuthURL = lego.LEDirectoryStaging
default:
cfg.ACME.AuthURL = lego.LEDirectoryStaging
}
return nil

View File

@ -7,7 +7,6 @@ import (
"git.paulbsd.com/paulbsd/pki/src/cert"
"git.paulbsd.com/paulbsd/pki/src/config"
"git.paulbsd.com/paulbsd/pki/src/domain"
"git.paulbsd.com/paulbsd/pki/src/pki"
_ "github.com/lib/pq"
"xorm.io/xorm"
@ -18,7 +17,7 @@ import (
func Init(cfg *config.Config) (err error) {
var databaseEngine = "postgres"
tables := []interface{}{cert.Entry{},
pki.User{}, domain.Domain{}}
pki.User{}}
cfg.Db, err = xorm.NewEngine(databaseEngine,
fmt.Sprintf("%s://%s:%s@%s/%s",

View File

@ -1,12 +0,0 @@
package domain
import "time"
// Domain describes a domain
type Domain struct {
ID int `xorm:"pk autoincr"`
Domain string `xorm:"text notnull unique(domain_provider)"`
Provider string `xorm:"text notnull unique(domain_provider)"`
Created time.Time `xorm:"created notnull"`
Updated time.Time `xorm:"updated notnull"`
}

View File

@ -9,13 +9,12 @@ import (
"encoding/pem"
"fmt"
"log"
"strings"
"git.paulbsd.com/paulbsd/pki/src/cert"
"git.paulbsd.com/paulbsd/pki/src/config"
"git.paulbsd.com/paulbsd/pki/src/domain"
"github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/certificate"
"github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/registration"
)
@ -30,12 +29,15 @@ func (u *User) Init(cfg *config.Config) (err error) {
}
// GetEntry returns requested acme ressource in database relative to domain
func (u *User) GetEntry(cfg *config.Config, domain *string) (Entry cert.Entry, err error) {
has, err := cfg.Db.Where("domain = ?", domain).And(
func (u *User) GetEntry(cfg *config.Config, domains []string) (Entry cert.Entry, err error) {
has, err := cfg.Db.Where("domains = ?", strings.Join(domains, ",")).And(
"auth_url = ?", cfg.ACME.AuthURL).And(
fmt.Sprintf("validity_end::timestamp-'%d DAY'::INTERVAL >= now()", cfg.ACME.MaxDaysBefore)).Desc(
"id").Get(&Entry)
fmt.Println(has, err)
if !has {
err = fmt.Errorf("entry doesn't exists")
}
@ -66,38 +68,12 @@ func (u *User) HandleRegistration(cfg *config.Config, client *lego.Client) (err
}
// RequestNewCert returns a newly requested certificate to letsencrypt
func (u *User) RequestNewCert(cfg *config.Config, domainnames *[]string) (certs *certificate.Resource, err error) {
func (u *User) RequestNewCert(cfg *config.Config, domains []string) (certificates *certificate.Resource, err error) {
legoconfig := lego.NewConfig(u)
legoconfig.CADirURL = cfg.ACME.AuthURL
legoconfig.Certificate.KeyType = certcrypto.RSA2048
var dom domain.Domain
var has bool
for _, d := range *domainnames {
dom = domain.Domain{Domain: d}
if has, err = cfg.Db.Get(&dom); has {
break
}
if err != nil {
log.Println(err)
}
}
if !has {
err = fmt.Errorf("supplied domain not in allowed domains")
return
}
var provider challenge.Provider
switch dom.Provider {
case "ovh":
provider, err = initOVHProvider(cfg)
case "pdns":
provider, err = initPowerDNSProvider(cfg)
default:
return
}
ovhprovider, err := initProvider(cfg)
if err != nil {
log.Println(err)
}
@ -107,7 +83,7 @@ func (u *User) RequestNewCert(cfg *config.Config, domainnames *[]string) (certs
log.Println(err)
}
err = client.Challenge.SetDNS01Provider(provider)
err = client.Challenge.SetDNS01Provider(ovhprovider)
if err != nil {
log.Println(err)
}
@ -121,15 +97,14 @@ func (u *User) RequestNewCert(cfg *config.Config, domainnames *[]string) (certs
}
request := certificate.ObtainRequest{
Domains: *domainnames,
Domains: domains,
Bundle: true,
}
certs, err = client.Certificate.Obtain(request)
certificates, err = client.Certificate.Obtain(request)
if err != nil {
log.Println(err)
}
return
}

View File

@ -1,16 +1,12 @@
package pki
import (
"log"
"net/url"
"git.paulbsd.com/paulbsd/pki/src/config"
"github.com/go-acme/lego/v4/providers/dns/ovh"
"github.com/go-acme/lego/v4/providers/dns/pdns"
)
// initOVHProvider initialize DNS provider configuration
func initOVHProvider(cfg *config.Config) (ovhprovider *ovh.DNSProvider, err error) {
// initProvider initialize DNS provider configuration
func initProvider(cfg *config.Config) (ovhprovider *ovh.DNSProvider, err error) {
ovhconfig := ovh.NewDefaultConfig()
ovhconfig.APIEndpoint = cfg.ACME.ProviderOptions["ovhendpoint"]
@ -19,24 +15,6 @@ func initOVHProvider(cfg *config.Config) (ovhprovider *ovh.DNSProvider, err erro
ovhconfig.ConsumerKey = cfg.ACME.ProviderOptions["ovhck"]
ovhprovider, err = ovh.NewDNSProviderConfig(ovhconfig)
if err != nil {
log.Println(err)
}
return
}
// initPowerDNSProvider initialize DNS provider configuration
func initPowerDNSProvider(cfg *config.Config) (pdnsprovider *pdns.DNSProvider, err error) {
pdnsconfig := pdns.NewDefaultConfig()
pdnsconfig.Host, err = url.Parse(cfg.ACME.ProviderOptions["pdnsapiurl"])
pdnsconfig.APIKey = cfg.ACME.ProviderOptions["pdnsapikey"]
pdnsprovider, err = pdns.NewDNSProviderConfig(pdnsconfig)
if err != nil {
log.Println(err)
}
return
}

View File

@ -4,6 +4,7 @@ import (
"fmt"
"log"
"net/http"
"strings"
"git.paulbsd.com/paulbsd/pki/src/config"
"git.paulbsd.com/paulbsd/pki/src/pki"
@ -29,18 +30,13 @@ func RunServer(cfg *config.Config) (err error) {
e.GET("/", func(c echo.Context) error {
return c.String(http.StatusOK, "Welcome to PKI software (https://git.paulbsd.com/paulbsd/pki)")
})
e.POST("/cert", func(c echo.Context) (err error) {
var request = new(EntryRequest)
var result = make(map[string]EntryResponse)
err = c.Bind(&request)
if err != nil {
log.Println(err)
return c.JSON(http.StatusInternalServerError, "error parsing request")
}
e.GET("/domain/:domains", func(c echo.Context) (err error) {
var result EntryResponse
var domains = strings.Split(c.Param("domains"), ",")
log.Printf("Providing %s to user %s at %s\n", request.Domains, c.Get("username"), c.RealIP())
log.Println(fmt.Sprintf("Providing %s to user %s at %s", domains, c.Get("username"), c.RealIP()))
result, err = GetCertificate(cfg, c.Get("user").(*pki.User), &request.Domains)
result, err = GetCertificate(cfg, c.Get("user").(*pki.User), domains)
if err != nil {
return c.String(http.StatusInternalServerError, fmt.Sprintf("%s", err))
}

View File

@ -6,6 +6,7 @@ import (
"fmt"
"log"
"regexp"
"strings"
"time"
"git.paulbsd.com/paulbsd/pki/src/cert"
@ -13,24 +14,18 @@ import (
"git.paulbsd.com/paulbsd/pki/src/pki"
)
const timeformatstring string = "2006-01-02 15:04:05"
var domainRegex, err = regexp.Compile(`^[a-z0-9\*]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,6}$`)
// GetCertificate get certificate from database if exists, of request it from ACME
func GetCertificate(cfg *config.Config, user *pki.User, domains *[]string) (result map[string]EntryResponse, err error) {
func GetCertificate(cfg *config.Config, user *pki.User, domains []string) (result EntryResponse, err error) {
err = CheckDomains(domains)
if err != nil {
return result, err
}
result = make(map[string]EntryResponse)
firstdomain := (*domains)[0]
entry, err := user.GetEntry(cfg, &firstdomain)
entry, err := user.GetEntry(cfg, domains)
if err != nil {
certs, err := user.RequestNewCert(cfg, domains)
if err != nil {
log.Printf("Error fetching new certificate %s\n", err)
log.Println(fmt.Sprintf("Error fetching new certificate %s", err))
return result, err
}
NotBefore, NotAfter, err := GetDates(certs.Certificate)
@ -38,36 +33,33 @@ func GetCertificate(cfg *config.Config, user *pki.User, domains *[]string) (resu
log.Println("Error where parsing dates")
return result, err
}
entry := cert.Entry{Domain: certs.Domain,
entry := cert.Entry{Domains: strings.Join(domains, ","),
Certificate: string(certs.Certificate),
PrivateKey: string(certs.PrivateKey),
ValidityBegin: NotBefore,
ValidityEnd: NotAfter,
AuthURL: cfg.ACME.AuthURL}
cfg.Db.Insert(&entry)
result[firstdomain] = convertEntryToResponse(entry)
result = convertEntryToResponse(entry)
return result, err
}
result[firstdomain] = convertEntryToResponse(entry)
result = convertEntryToResponse(entry)
return
}
// CheckDomains check if requested domains are valid
func CheckDomains(domains *[]string) (err error) {
for _, domain := range *domains {
err = CheckDomain(&domain)
func CheckDomains(domains []string) (err error) {
domainRegex, err := regexp.Compile(`^[a-z0-9\*]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,6}$`)
if err != nil {
return
}
}
return
}
// CheckDomain check if requested domain are valid
func CheckDomain(domain *string) (err error) {
res := domainRegex.Match([]byte(*domain))
for _, d := range domains {
res := domainRegex.Match([]byte(d))
if !res {
return fmt.Errorf("Domain %s has not a valid syntax %s, please verify", *domain, err)
return fmt.Errorf(fmt.Sprintf("Domain %s has not a valid syntax %s, please verify", d, err))
}
}
return
}
@ -88,7 +80,9 @@ func GetDates(cert []byte) (NotBefore time.Time, NotAfter time.Time, err error)
// convertEntryToResponse converts database ACME entry to JSON ACME entry
func convertEntryToResponse(in cert.Entry) (out EntryResponse) {
out.Domains = append(out.Domains, in.Domain)
timeformatstring := "2006-01-02 15:04:05"
out.Domains = in.Domains
out.Certificate = in.Certificate
out.PrivateKey = in.PrivateKey
out.ValidityBegin = in.ValidityBegin.Format(timeformatstring)
@ -97,14 +91,9 @@ func convertEntryToResponse(in cert.Entry) (out EntryResponse) {
return
}
// EntryRequest
type EntryRequest struct {
Domains []string `json:"domains"`
}
// EntryResponse is the struct defining JSON response from webservice
type EntryResponse struct {
Domains []string `json:"domains"`
Domains string `json:"domains"`
Certificate string `json:"certificate"`
PrivateKey string `json:"privatekey"`
ValidityBegin string `json:"validitybegin"`

10
vendor/github.com/cenkalti/backoff/v4/.travis.yml generated vendored Normal file
View File

@ -0,0 +1,10 @@
language: go
go:
- 1.13
- 1.x
- tip
before_install:
- go get github.com/mattn/goveralls
- go get golang.org/x/tools/cmd/cover
script:
- $HOME/gopath/bin/goveralls -service=travis-ci

View File

@ -1,4 +1,4 @@
# Exponential Backoff [![GoDoc][godoc image]][godoc] [![Coverage Status][coveralls image]][coveralls]
# Exponential Backoff [![GoDoc][godoc image]][godoc] [![Build Status][travis image]][travis] [![Coverage Status][coveralls image]][coveralls]
This is a Go port of the exponential backoff algorithm from [Google's HTTP Client Library for Java][google-http-java-client].
@ -21,6 +21,8 @@ Use https://pkg.go.dev/github.com/cenkalti/backoff/v4 to view the documentation.
[godoc]: https://pkg.go.dev/github.com/cenkalti/backoff/v4
[godoc image]: https://godoc.org/github.com/cenkalti/backoff?status.png
[travis]: https://travis-ci.org/cenkalti/backoff
[travis image]: https://travis-ci.org/cenkalti/backoff.png?branch=master
[coveralls]: https://coveralls.io/github/cenkalti/backoff?branch=master
[coveralls image]: https://coveralls.io/repos/github/cenkalti/backoff/badge.svg?branch=master

View File

@ -71,9 +71,6 @@ type Clock interface {
Now() time.Time
}
// ExponentialBackOffOpts is a function type used to configure ExponentialBackOff options.
type ExponentialBackOffOpts func(*ExponentialBackOff)
// Default values for ExponentialBackOff.
const (
DefaultInitialInterval = 500 * time.Millisecond
@ -84,7 +81,7 @@ const (
)
// NewExponentialBackOff creates an instance of ExponentialBackOff using default values.
func NewExponentialBackOff(opts ...ExponentialBackOffOpts) *ExponentialBackOff {
func NewExponentialBackOff() *ExponentialBackOff {
b := &ExponentialBackOff{
InitialInterval: DefaultInitialInterval,
RandomizationFactor: DefaultRandomizationFactor,
@ -94,62 +91,10 @@ func NewExponentialBackOff(opts ...ExponentialBackOffOpts) *ExponentialBackOff {
Stop: Stop,
Clock: SystemClock,
}
for _, fn := range opts {
fn(b)
}
b.Reset()
return b
}
// WithInitialInterval sets the initial interval between retries.
func WithInitialInterval(duration time.Duration) ExponentialBackOffOpts {
return func(ebo *ExponentialBackOff) {
ebo.InitialInterval = duration
}
}
// WithRandomizationFactor sets the randomization factor to add jitter to intervals.
func WithRandomizationFactor(randomizationFactor float64) ExponentialBackOffOpts {
return func(ebo *ExponentialBackOff) {
ebo.RandomizationFactor = randomizationFactor
}
}
// WithMultiplier sets the multiplier for increasing the interval after each retry.
func WithMultiplier(multiplier float64) ExponentialBackOffOpts {
return func(ebo *ExponentialBackOff) {
ebo.Multiplier = multiplier
}
}
// WithMaxInterval sets the maximum interval between retries.
func WithMaxInterval(duration time.Duration) ExponentialBackOffOpts {
return func(ebo *ExponentialBackOff) {
ebo.MaxInterval = duration
}
}
// WithMaxElapsedTime sets the maximum total time for retries.
func WithMaxElapsedTime(duration time.Duration) ExponentialBackOffOpts {
return func(ebo *ExponentialBackOff) {
ebo.MaxElapsedTime = duration
}
}
// WithRetryStopDuration sets the duration after which retries should stop.
func WithRetryStopDuration(duration time.Duration) ExponentialBackOffOpts {
return func(ebo *ExponentialBackOff) {
ebo.Stop = duration
}
}
// WithClockProvider sets the clock used to measure time.
func WithClockProvider(clock Clock) ExponentialBackOffOpts {
return func(ebo *ExponentialBackOff) {
ebo.Clock = clock
}
}
type systemClock struct{}
func (t systemClock) Now() time.Time {
@ -202,9 +147,6 @@ func (b *ExponentialBackOff) incrementCurrentInterval() {
// Returns a random value from the following interval:
// [currentInterval - randomizationFactor * currentInterval, currentInterval + randomizationFactor * currentInterval].
func getRandomValueFromInterval(randomizationFactor, random float64, currentInterval time.Duration) time.Duration {
if randomizationFactor == 0 {
return currentInterval // make sure no randomness is used when randomizationFactor is 0.
}
var delta = randomizationFactor * float64(currentInterval)
var minInterval = float64(currentInterval) - delta
var maxInterval = float64(currentInterval) + delta

View File

@ -5,20 +5,10 @@ import (
"time"
)
// An OperationWithData is executing by RetryWithData() or RetryNotifyWithData().
// The operation will be retried using a backoff policy if it returns an error.
type OperationWithData[T any] func() (T, error)
// An Operation is executing by Retry() or RetryNotify().
// The operation will be retried using a backoff policy if it returns an error.
type Operation func() error
func (o Operation) withEmptyData() OperationWithData[struct{}] {
return func() (struct{}, error) {
return struct{}{}, o()
}
}
// Notify is a notify-on-error function. It receives an operation error and
// backoff delay if the operation failed (with an error).
//
@ -38,41 +28,18 @@ func Retry(o Operation, b BackOff) error {
return RetryNotify(o, b, nil)
}
// RetryWithData is like Retry but returns data in the response too.
func RetryWithData[T any](o OperationWithData[T], b BackOff) (T, error) {
return RetryNotifyWithData(o, b, nil)
}
// RetryNotify calls notify function with the error and wait duration
// for each failed attempt before sleep.
func RetryNotify(operation Operation, b BackOff, notify Notify) error {
return RetryNotifyWithTimer(operation, b, notify, nil)
}
// RetryNotifyWithData is like RetryNotify but returns data in the response too.
func RetryNotifyWithData[T any](operation OperationWithData[T], b BackOff, notify Notify) (T, error) {
return doRetryNotify(operation, b, notify, nil)
}
// RetryNotifyWithTimer calls notify function with the error and wait duration using the given Timer
// for each failed attempt before sleep.
// A default timer that uses system timer is used when nil is passed.
func RetryNotifyWithTimer(operation Operation, b BackOff, notify Notify, t Timer) error {
_, err := doRetryNotify(operation.withEmptyData(), b, notify, t)
return err
}
// RetryNotifyWithTimerAndData is like RetryNotifyWithTimer but returns data in the response too.
func RetryNotifyWithTimerAndData[T any](operation OperationWithData[T], b BackOff, notify Notify, t Timer) (T, error) {
return doRetryNotify(operation, b, notify, t)
}
func doRetryNotify[T any](operation OperationWithData[T], b BackOff, notify Notify, t Timer) (T, error) {
var (
err error
next time.Duration
res T
)
var err error
var next time.Duration
if t == nil {
t = &defaultTimer{}
}
@ -85,22 +52,21 @@ func doRetryNotify[T any](operation OperationWithData[T], b BackOff, notify Noti
b.Reset()
for {
res, err = operation()
if err == nil {
return res, nil
if err = operation(); err == nil {
return nil
}
var permanent *PermanentError
if errors.As(err, &permanent) {
return res, permanent.Err
return permanent.Err
}
if next = b.NextBackOff(); next == Stop {
if cerr := ctx.Err(); cerr != nil {
return res, cerr
return cerr
}
return res, err
return err
}
if notify != nil {
@ -111,7 +77,7 @@ func doRetryNotify[T any](operation OperationWithData[T], b BackOff, notify Noti
select {
case <-ctx.Done():
return res, ctx.Err()
return ctx.Err()
case <-t.C():
}
}

View File

@ -16,7 +16,7 @@ func (a *AccountService) New(req acme.Account) (acme.ExtendedAccount, error) {
resp, err := a.core.post(a.core.GetDirectory().NewAccountURL, req, &account)
location := getLocation(resp)
if location != "" {
if len(location) > 0 {
a.core.jws.SetKid(location)
}

View File

@ -2,6 +2,7 @@ package api
import (
"bytes"
"context"
"crypto"
"encoding/json"
"errors"
@ -70,7 +71,7 @@ func (a *Core) post(uri string, reqBody, response interface{}) (*http.Response,
}
// postAsGet performs an HTTP POST ("POST-as-GET") request.
// https://www.rfc-editor.org/rfc/rfc8555.html#section-6.3
// https://tools.ietf.org/html/rfc8555#section-6.3
func (a *Core) postAsGet(uri string, response interface{}) (*http.Response, error) {
return a.retrievablePost(uri, []byte{}, response)
}
@ -82,6 +83,8 @@ func (a *Core) retrievablePost(uri string, content []byte, response interface{})
bo.MaxInterval = 5 * time.Second
bo.MaxElapsedTime = 20 * time.Second
ctx, cancel := context.WithCancel(context.Background())
var resp *http.Response
operation := func() error {
var err error
@ -93,7 +96,8 @@ func (a *Core) retrievablePost(uri string, content []byte, response interface{})
return err
}
return backoff.Permanent(err)
cancel()
return err
}
return nil
@ -103,7 +107,7 @@ func (a *Core) retrievablePost(uri string, content []byte, response interface{})
log.Infof("retry due to: %v", err)
}
err := backoff.RetryNotify(operation, bo, notify)
err := backoff.RetryNotify(operation, backoff.WithContext(bo, ctx), notify)
if err != nil {
return resp, err
}
@ -117,7 +121,7 @@ func (a *Core) signedPost(uri string, content []byte, response interface{}) (*ht
return nil, fmt.Errorf("failed to post JWS message: failed to sign content: %w", err)
}
signedBody := bytes.NewBufferString(signedContent.FullSerialize())
signedBody := bytes.NewBuffer([]byte(signedContent.FullSerialize()))
resp, err := a.doer.Post(uri, signedBody, "application/jose+json", response)

View File

@ -1,11 +1,10 @@
package api
import (
"bytes"
"crypto/x509"
"encoding/pem"
"errors"
"io"
"io/ioutil"
"net/http"
"github.com/go-acme/lego/v4/acme"
@ -40,7 +39,7 @@ func (c *CertificateService) GetAll(certURL string, bundle bool) (map[string]*ac
certs := map[string]*acme.RawCertificate{certURL: cert}
// URLs of "alternate" link relation
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.4.2
// - https://tools.ietf.org/html/rfc8555#section-7.4.2
alts := getLinks(headers, "alternate")
for _, alt := range alts {
@ -72,7 +71,7 @@ func (c *CertificateService) get(certURL string, bundle bool) (*acme.RawCertific
return nil, nil, err
}
data, err := io.ReadAll(http.MaxBytesReader(nil, resp.Body, maxBodySize))
data, err := ioutil.ReadAll(http.MaxBytesReader(nil, resp.Body, maxBodySize))
if err != nil {
return nil, resp.Header, err
}
@ -88,17 +87,12 @@ func (c *CertificateService) getCertificateChain(cert []byte, headers http.Heade
// See https://community.letsencrypt.org/t/acme-v2-no-up-link-in-response/64962
_, issuer := pem.Decode(cert)
if issuer != nil {
// If bundle is false, we want to return a single certificate.
// To do this, we remove the issuer cert(s) from the issued cert.
if !bundle {
cert = bytes.TrimSuffix(cert, issuer)
}
return &acme.RawCertificate{Cert: cert, Issuer: issuer}
}
// The issuer certificate link may be supplied via an "up" link
// in the response headers of a new certificate.
// See https://www.rfc-editor.org/rfc/rfc8555.html#section-7.4.2
// See https://tools.ietf.org/html/rfc8555#section-7.4.2
up := getLink(headers, "up")
issuer, err := c.getIssuerFromLink(up)

View File

@ -63,7 +63,7 @@ func (n *Manager) getNonce() (string, error) {
return GetFromResponse(resp)
}
// GetFromResponse Extracts a nonce from an HTTP response.
// GetFromResponse Extracts a nonce from a HTTP response.
func GetFromResponse(resp *http.Response) (string, error) {
if resp == nil {
return "", errors.New("nil response")

View File

@ -9,7 +9,7 @@ import (
"fmt"
"github.com/go-acme/lego/v4/acme/api/internal/nonces"
jose "github.com/go-jose/go-jose/v4"
jose "gopkg.in/square/go-jose.v2"
)
// JWS Represents a JWS.

View File

@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"runtime"
"strings"
@ -95,7 +96,7 @@ func (d *Doer) do(req *http.Request, response interface{}) (*http.Response, erro
}
if response != nil {
raw, err := io.ReadAll(resp.Body)
raw, err := ioutil.ReadAll(resp.Body)
if err != nil {
return resp, err
}
@ -119,7 +120,7 @@ func (d *Doer) formatUserAgent() string {
func checkError(req *http.Request, resp *http.Response) error {
if resp.StatusCode >= http.StatusBadRequest {
body, err := io.ReadAll(resp.Body)
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("%d :: %s :: %s :: %w", resp.StatusCode, req.Method, req.URL, err)
}
@ -133,10 +134,6 @@ func checkError(req *http.Request, resp *http.Response) error {
errorDetails.Method = req.Method
errorDetails.URL = req.URL.String()
if errorDetails.HTTPStatus == 0 {
errorDetails.HTTPStatus = resp.StatusCode
}
// Check for errors we handle specifically
if errorDetails.HTTPStatus == http.StatusBadRequest && errorDetails.Type == acme.BadNonceErr {
return &acme.NonceError{ProblemDetails: errorDetails}

View File

@ -5,7 +5,7 @@ package sender
const (
// ourUserAgent is the User-Agent of this underlying library package.
ourUserAgent = "xenolf-acme/4.17.4"
ourUserAgent = "xenolf-acme/4.4.0"
// ourUserAgentComment is part of the UA comment linked to the version status of this underlying library package.
// values: detach|release

View File

@ -3,58 +3,21 @@ package api
import (
"encoding/base64"
"errors"
"net"
"time"
"github.com/go-acme/lego/v4/acme"
)
// OrderOptions used to create an order (optional).
type OrderOptions struct {
NotBefore time.Time
NotAfter time.Time
// A string uniquely identifying a previously-issued certificate which this
// order is intended to replace.
// - https://datatracker.ietf.org/doc/html/draft-ietf-acme-ari-03#section-5
ReplacesCertID string
}
type OrderService service
// New Creates a new order.
func (o *OrderService) New(domains []string) (acme.ExtendedOrder, error) {
return o.NewWithOptions(domains, nil)
}
// NewWithOptions Creates a new order.
func (o *OrderService) NewWithOptions(domains []string, opts *OrderOptions) (acme.ExtendedOrder, error) {
var identifiers []acme.Identifier
for _, domain := range domains {
ident := acme.Identifier{Value: domain, Type: "dns"}
if net.ParseIP(domain) != nil {
ident.Type = "ip"
}
identifiers = append(identifiers, ident)
identifiers = append(identifiers, acme.Identifier{Type: "dns", Value: domain})
}
orderReq := acme.Order{Identifiers: identifiers}
if opts != nil {
if !opts.NotAfter.IsZero() {
orderReq.NotAfter = opts.NotAfter.Format(time.RFC3339)
}
if !opts.NotBefore.IsZero() {
orderReq.NotBefore = opts.NotBefore.Format(time.RFC3339)
}
if o.core.GetDirectory().RenewalInfo != "" {
orderReq.Replaces = opts.ReplacesCertID
}
}
var order acme.Order
resp, err := o.core.post(o.core.GetDirectory().NewOrderURL, orderReq, &order)
if err != nil {

View File

@ -1,28 +0,0 @@
package api
import (
"errors"
"net/http"
)
// ErrNoARI is returned when the server does not advertise a renewal info endpoint.
var ErrNoARI = errors.New("renewalInfo[get/post]: server does not advertise a renewal info endpoint")
// GetRenewalInfo GETs renewal information for a certificate from the renewalInfo endpoint.
// This is used to determine if a certificate needs to be renewed.
//
// Note: this endpoint is part of a draft specification, not all ACME servers will implement it.
// This method will return api.ErrNoARI if the server does not advertise a renewal info endpoint.
//
// https://datatracker.ietf.org/doc/draft-ietf-acme-ari
func (c *CertificateService) GetRenewalInfo(certID string) (*http.Response, error) {
if c.core.GetDirectory().RenewalInfo == "" {
return nil, ErrNoARI
}
if certID == "" {
return nil, errors.New("renewalInfo[get]: 'certID' cannot be empty")
}
return c.core.HTTPClient.Get(c.core.GetDirectory().RenewalInfo + "/" + certID)
}

View File

@ -1,5 +1,5 @@
// Package acme contains all objects related the ACME endpoints.
// https://www.rfc-editor.org/rfc/rfc8555.html
// https://tools.ietf.org/html/rfc8555
package acme
import (
@ -7,38 +7,20 @@ import (
"time"
)
// ACME status values of Account, Order, Authorization and Challenge objects.
// See https://www.rfc-editor.org/rfc/rfc8555.html#section-7.1.6 for details.
// Challenge statuses.
// https://tools.ietf.org/html/rfc8555#section-7.1.6
const (
StatusPending = "pending"
StatusInvalid = "invalid"
StatusValid = "valid"
StatusProcessing = "processing"
StatusDeactivated = "deactivated"
StatusExpired = "expired"
StatusInvalid = "invalid"
StatusPending = "pending"
StatusProcessing = "processing"
StatusReady = "ready"
StatusRevoked = "revoked"
StatusUnknown = "unknown"
StatusValid = "valid"
)
// CRL reason codes as defined in RFC 5280.
// https://datatracker.ietf.org/doc/html/rfc5280#section-5.3.1
const (
CRLReasonUnspecified uint = 0
CRLReasonKeyCompromise uint = 1
CRLReasonCACompromise uint = 2
CRLReasonAffiliationChanged uint = 3
CRLReasonSuperseded uint = 4
CRLReasonCessationOfOperation uint = 5
CRLReasonCertificateHold uint = 6
CRLReasonRemoveFromCRL uint = 8
CRLReasonPrivilegeWithdrawn uint = 9
CRLReasonAACompromise uint = 10
)
// Directory the ACME directory object.
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.1.1
// - https://datatracker.ietf.org/doc/draft-ietf-acme-ari/
// - https://tools.ietf.org/html/rfc8555#section-7.1.1
type Directory struct {
NewNonceURL string `json:"newNonce"`
NewAccountURL string `json:"newAccount"`
@ -47,11 +29,10 @@ type Directory struct {
RevokeCertURL string `json:"revokeCert"`
KeyChangeURL string `json:"keyChange"`
Meta Meta `json:"meta"`
RenewalInfo string `json:"renewalInfo"`
}
// Meta the ACME meta object (related to Directory).
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.1.1
// - https://tools.ietf.org/html/rfc8555#section-7.1.1
type Meta struct {
// termsOfService (optional, string):
// A URL identifying the current terms of service.
@ -76,7 +57,7 @@ type Meta struct {
ExternalAccountRequired bool `json:"externalAccountRequired"`
}
// ExtendedAccount an extended Account.
// ExtendedAccount a extended Account.
type ExtendedAccount struct {
Account
// Contains the value of the response header `Location`
@ -84,8 +65,8 @@ type ExtendedAccount struct {
}
// Account the ACME account Object.
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.1.2
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.3
// - https://tools.ietf.org/html/rfc8555#section-7.1.2
// - https://tools.ietf.org/html/rfc8555#section-7.3
type Account struct {
// status (required, string):
// The status of this account.
@ -131,7 +112,7 @@ type ExtendedOrder struct {
}
// Order the ACME order Object.
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.1.3
// - https://tools.ietf.org/html/rfc8555#section-7.1.3
type Order struct {
// status (required, string):
// The status of this order.
@ -181,16 +162,10 @@ type Order struct {
// certificate (optional, string):
// A URL for the certificate that has been issued in response to this order
Certificate string `json:"certificate,omitempty"`
// replaces (optional, string):
// replaces (string, optional): A string uniquely identifying a
// previously-issued certificate which this order is intended to replace.
// - https://datatracker.ietf.org/doc/html/draft-ietf-acme-ari-03#section-5
Replaces string `json:"replaces,omitempty"`
}
// Authorization the ACME authorization object.
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.1.4
// - https://tools.ietf.org/html/rfc8555#section-7.1.4
type Authorization struct {
// status (required, string):
// The status of this authorization.
@ -232,8 +207,8 @@ type ExtendedChallenge struct {
}
// Challenge the ACME challenge object.
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.1.5
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-8
// - https://tools.ietf.org/html/rfc8555#section-7.1.5
// - https://tools.ietf.org/html/rfc8555#section-8
type Challenge struct {
// type (required, string):
// The type of challenge encoded in the object.
@ -266,23 +241,23 @@ type Challenge struct {
// It MUST NOT contain any characters outside the base64url alphabet,
// and MUST NOT include base64 padding characters ("=").
// See [RFC4086] for additional information on randomness requirements.
// https://www.rfc-editor.org/rfc/rfc8555.html#section-8.3
// https://www.rfc-editor.org/rfc/rfc8555.html#section-8.4
// https://tools.ietf.org/html/rfc8555#section-8.3
// https://tools.ietf.org/html/rfc8555#section-8.4
Token string `json:"token"`
// https://www.rfc-editor.org/rfc/rfc8555.html#section-8.1
// https://tools.ietf.org/html/rfc8555#section-8.1
KeyAuthorization string `json:"keyAuthorization"`
}
// Identifier the ACME identifier object.
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-9.7.7
// - https://tools.ietf.org/html/rfc8555#section-9.7.7
type Identifier struct {
Type string `json:"type"`
Value string `json:"value"`
}
// CSRMessage Certificate Signing Request.
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.4
// - https://tools.ietf.org/html/rfc8555#section-7.4
type CSRMessage struct {
// csr (required, string):
// A CSR encoding the parameters for the certificate being requested [RFC2986].
@ -292,8 +267,8 @@ type CSRMessage struct {
}
// RevokeCertMessage a certificate revocation message.
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.6
// - https://www.rfc-editor.org/rfc/rfc5280.html#section-5.3.1
// - https://tools.ietf.org/html/rfc8555#section-7.6
// - https://tools.ietf.org/html/rfc5280#section-5.3.1
type RevokeCertMessage struct {
// certificate (required, string):
// The certificate to be revoked, in the base64url-encoded version of the DER format.
@ -314,36 +289,3 @@ type RawCertificate struct {
Cert []byte
Issuer []byte
}
// Window is a window of time.
type Window struct {
Start time.Time `json:"start"`
End time.Time `json:"end"`
}
// RenewalInfoResponse is the response to GET requests made the renewalInfo endpoint.
// - (4.1. Getting Renewal Information) https://datatracker.ietf.org/doc/draft-ietf-acme-ari/
type RenewalInfoResponse struct {
// SuggestedWindow contains two fields, start and end,
// whose values are timestamps which bound the window of time in which the CA recommends renewing the certificate.
SuggestedWindow Window `json:"suggestedWindow"`
// ExplanationURL is an optional URL pointing to a page which may explain why the suggested renewal window is what it is.
// For example, it may be a page explaining the CA's dynamic load-balancing strategy,
// or a page documenting which certificates are affected by a mass revocation event.
// Callers SHOULD provide this URL to their operator, if present.
ExplanationURL string `json:"explanationURL"`
}
// RenewalInfoUpdateRequest is the JWS payload for POST requests made to the renewalInfo endpoint.
// - (4.2. RenewalInfo Objects) https://datatracker.ietf.org/doc/html/draft-ietf-acme-ari-03#section-4.2
type RenewalInfoUpdateRequest struct {
// CertID is a composite string in the format: base64url(AKI) || '.' || base64url(Serial), where AKI is the
// certificate's authority key identifier and Serial is the certificate's serial number. For details, see:
// https://datatracker.ietf.org/doc/html/draft-ietf-acme-ari-03#section-4.1
CertID string `json:"certID"`
// Replaced is required and indicates whether or not the client considers the certificate to have been replaced.
// A certificate is considered replaced when its revocation would not disrupt any ongoing services,
// for instance because it has been renewed and the new certificate is in use, or because it is no longer in use.
// Clients SHOULD NOT send a request where this value is false.
Replaced bool `json:"replaced"`
}

View File

@ -11,8 +11,8 @@ const (
)
// ProblemDetails the problem details object.
// - https://www.rfc-editor.org/rfc/rfc7807.html#section-3.1
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.3.3
// - https://tools.ietf.org/html/rfc7807#section-3.1
// - https://tools.ietf.org/html/rfc8555#section-7.3.3
type ProblemDetails struct {
Type string `json:"type,omitempty"`
Detail string `json:"detail,omitempty"`
@ -26,7 +26,7 @@ type ProblemDetails struct {
}
// SubProblem a "subproblems".
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-6.7.1
// - https://tools.ietf.org/html/rfc8555#section-6.7.1
type SubProblem struct {
Type string `json:"type,omitempty"`
Detail string `json:"detail,omitempty"`

View File

@ -14,8 +14,6 @@ import (
"errors"
"fmt"
"math/big"
"net"
"slices"
"strings"
"time"
@ -27,7 +25,6 @@ const (
EC256 = KeyType("P256")
EC384 = KeyType("P384")
RSA2048 = KeyType("2048")
RSA3072 = KeyType("3072")
RSA4096 = KeyType("4096")
RSA8192 = KeyType("8192")
)
@ -85,12 +82,9 @@ func ParsePEMBundle(bundle []byte) ([]*x509.Certificate, error) {
// ParsePEMPrivateKey parses a private key from key, which is a PEM block.
// Borrowed from Go standard library, to handle various private key and PEM block types.
// https://github.com/golang/go/blob/693748e9fa385f1e2c3b91ca9acbb6c0ad2d133d/src/crypto/tls/tls.go#L291-L308
// https://github.com/golang/go/blob/693748e9fa385f1e2c3b91ca9acbb6c0ad2d133d/src/crypto/tls/tls.go#L238
// https://github.com/golang/go/blob/693748e9fa385f1e2c3b91ca9acbb6c0ad2d133d/src/crypto/tls/tls.go#L238)
func ParsePEMPrivateKey(key []byte) (crypto.PrivateKey, error) {
keyBlockDER, _ := pem.Decode(key)
if keyBlockDER == nil {
return nil, errors.New("invalid PEM block")
}
if keyBlockDER.Type != "PRIVATE KEY" && !strings.HasSuffix(keyBlockDER.Type, " PRIVATE KEY") {
return nil, fmt.Errorf("unknown PEM header %q", keyBlockDER.Type)
@ -124,8 +118,6 @@ func GeneratePrivateKey(keyType KeyType) (crypto.PrivateKey, error) {
return ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
case RSA2048:
return rsa.GenerateKey(rand.Reader, 2048)
case RSA3072:
return rsa.GenerateKey(rand.Reader, 3072)
case RSA4096:
return rsa.GenerateKey(rand.Reader, 4096)
case RSA8192:
@ -136,20 +128,9 @@ func GeneratePrivateKey(keyType KeyType) (crypto.PrivateKey, error) {
}
func GenerateCSR(privateKey crypto.PrivateKey, domain string, san []string, mustStaple bool) ([]byte, error) {
var dnsNames []string
var ipAddresses []net.IP
for _, altname := range san {
if ip := net.ParseIP(altname); ip != nil {
ipAddresses = append(ipAddresses, ip)
} else {
dnsNames = append(dnsNames, altname)
}
}
template := x509.CertificateRequest{
Subject: pkix.Name{CommonName: domain},
DNSNames: dnsNames,
IPAddresses: ipAddresses,
DNSNames: san,
}
if mustStaple {
@ -217,26 +198,6 @@ func ParsePEMCertificate(cert []byte) (*x509.Certificate, error) {
return x509.ParseCertificate(pemBlock.Bytes)
}
func GetCertificateMainDomain(cert *x509.Certificate) (string, error) {
return getMainDomain(cert.Subject, cert.DNSNames)
}
func GetCSRMainDomain(cert *x509.CertificateRequest) (string, error) {
return getMainDomain(cert.Subject, cert.DNSNames)
}
func getMainDomain(subject pkix.Name, dnsNames []string) (string, error) {
if subject.CommonName == "" && len(dnsNames) == 0 {
return "", errors.New("missing domain")
}
if subject.CommonName != "" {
return subject.CommonName, nil
}
return dnsNames[0], nil
}
func ExtractDomains(cert *x509.Certificate) []string {
var domains []string
if cert.Subject.CommonName != "" {
@ -251,13 +212,6 @@ func ExtractDomains(cert *x509.Certificate) []string {
domains = append(domains, sanDomain)
}
commonNameIP := net.ParseIP(cert.Subject.CommonName)
for _, sanIP := range cert.IPAddresses {
if !commonNameIP.Equal(sanIP) {
domains = append(domains, sanIP.String())
}
}
return domains
}
@ -269,7 +223,7 @@ func ExtractDomainsCSR(csr *x509.CertificateRequest) []string {
// loop over the SubjectAltName DNS names
for _, sanName := range csr.DNSNames {
if slices.Contains(domains, sanName) {
if containsSAN(domains, sanName) {
// Duplicate; skip this name
continue
}
@ -278,14 +232,16 @@ func ExtractDomainsCSR(csr *x509.CertificateRequest) []string {
domains = append(domains, sanName)
}
cnip := net.ParseIP(csr.Subject.CommonName)
for _, sanIP := range csr.IPAddresses {
if !cnip.Equal(sanIP) {
domains = append(domains, sanIP.String())
}
return domains
}
return domains
func containsSAN(domains []string, sanName string) bool {
for _, existingName := range domains {
if existingName == sanName {
return true
}
}
return false
}
func GeneratePemCert(privateKey *rsa.PrivateKey, domain string, extensions []pkix.Extension) ([]byte, error) {
@ -305,7 +261,7 @@ func generateDerCert(privateKey *rsa.PrivateKey, expiration time.Time, domain st
}
if expiration.IsZero() {
expiration = time.Now().AddDate(1, 0, 0)
expiration = time.Now().Add(365)
}
template := x509.Certificate{
@ -318,15 +274,9 @@ func generateDerCert(privateKey *rsa.PrivateKey, expiration time.Time, domain st
KeyUsage: x509.KeyUsageKeyEncipherment,
BasicConstraintsValid: true,
DNSNames: []string{domain},
ExtraExtensions: extensions,
}
// handling SAN filling as type suspected
if ip := net.ParseIP(domain); ip != nil {
template.IPAddresses = []net.IP{ip}
} else {
template.DNSNames = []string{domain}
}
return x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
}

View File

@ -12,7 +12,6 @@ const (
// limited on the "new-reg", "new-authz" and "new-cert" endpoints.
// From the documentation the limitation is 20 requests per second,
// but using 20 as value doesn't work but 18 do.
// https://letsencrypt.org/docs/rate-limits/
overallRequestLimit = 18
)
@ -36,14 +35,13 @@ func (c *Certifier) getAuthorizations(order acme.ExtendedOrder) ([]acme.Authoriz
}
var responses []acme.Authorization
failures := newObtainError()
for range len(order.Authorizations) {
failures := make(obtainError)
for i := 0; i < len(order.Authorizations); i++ {
select {
case res := <-resc:
responses = append(responses, res)
case err := <-errc:
failures.Add(err.Domain, err.Error)
failures[err.Domain] = err.Error
}
}
@ -54,10 +52,15 @@ func (c *Certifier) getAuthorizations(order acme.ExtendedOrder) ([]acme.Authoriz
close(resc)
close(errc)
return responses, failures.Join()
// be careful to not return an empty failures map;
// even if empty, they become non-nil error values
if len(failures) > 0 {
return responses, failures
}
return responses, nil
}
func (c *Certifier) deactivateAuthorizations(order acme.ExtendedOrder, force bool) {
func (c *Certifier) deactivateAuthorizations(order acme.ExtendedOrder) {
for _, authzURL := range order.Authorizations {
auth, err := c.core.Authorizations.Get(authzURL)
if err != nil {
@ -65,7 +68,7 @@ func (c *Certifier) deactivateAuthorizations(order acme.ExtendedOrder, force boo
continue
}
if auth.Status == acme.StatusValid && !force {
if auth.Status == acme.StatusValid {
log.Infof("Skipping deactivating of valid auth: %s", authzURL)
continue
}

View File

@ -7,7 +7,7 @@ import (
"encoding/base64"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
"time"
@ -49,44 +49,22 @@ type Resource struct {
// If you do not want that you can supply your own private key in the privateKey parameter.
// If this parameter is non-nil it will be used instead of generating a new one.
//
// If `Bundle` is true, the `[]byte` contains both the issuer certificate and your issued certificate as a bundle.
//
// If `AlwaysDeactivateAuthorizations` is true, the authorizations are also relinquished if the obtain request was successful.
// See https://datatracker.ietf.org/doc/html/rfc8555#section-7.5.2.
// If bundle is true, the []byte contains both the issuer certificate and your issued certificate as a bundle.
type ObtainRequest struct {
Domains []string
Bundle bool
PrivateKey crypto.PrivateKey
MustStaple bool
NotBefore time.Time
NotAfter time.Time
Bundle bool
PreferredChain string
AlwaysDeactivateAuthorizations bool
// A string uniquely identifying a previously-issued certificate which this
// order is intended to replace.
// - https://datatracker.ietf.org/doc/html/draft-ietf-acme-ari-03#section-5
ReplacesCertID string
}
// ObtainForCSRRequest The request to obtain a certificate matching the CSR passed into it.
//
// If `Bundle` is true, the `[]byte` contains both the issuer certificate and your issued certificate as a bundle.
//
// If `AlwaysDeactivateAuthorizations` is true, the authorizations are also relinquished if the obtain request was successful.
// See https://datatracker.ietf.org/doc/html/rfc8555#section-7.5.2.
// If bundle is true, the []byte contains both the issuer certificate and your issued certificate as a bundle.
type ObtainForCSRRequest struct {
CSR *x509.CertificateRequest
NotBefore time.Time
NotAfter time.Time
Bundle bool
PreferredChain string
AlwaysDeactivateAuthorizations bool
// A string uniquely identifying a previously-issued certificate which this
// order is intended to replace.
// - https://datatracker.ietf.org/doc/html/draft-ietf-acme-ari-03#section-5
ReplacesCertID string
}
type resolver interface {
@ -131,13 +109,7 @@ func (c *Certifier) Obtain(request ObtainRequest) (*Resource, error) {
log.Infof("[%s] acme: Obtaining SAN certificate", strings.Join(domains, ", "))
}
orderOpts := &api.OrderOptions{
NotBefore: request.NotBefore,
NotAfter: request.NotAfter,
ReplacesCertID: request.ReplacesCertID,
}
order, err := c.core.Orders.NewWithOptions(domains, orderOpts)
order, err := c.core.Orders.New(domains)
if err != nil {
return nil, err
}
@ -145,32 +117,33 @@ func (c *Certifier) Obtain(request ObtainRequest) (*Resource, error) {
authz, err := c.getAuthorizations(order)
if err != nil {
// If any challenge fails, return. Do not generate partial SAN certificates.
c.deactivateAuthorizations(order, request.AlwaysDeactivateAuthorizations)
c.deactivateAuthorizations(order)
return nil, err
}
err = c.resolver.Solve(authz)
if err != nil {
// If any challenge fails, return. Do not generate partial SAN certificates.
c.deactivateAuthorizations(order, request.AlwaysDeactivateAuthorizations)
c.deactivateAuthorizations(order)
return nil, err
}
log.Infof("[%s] acme: Validations succeeded; requesting certificates", strings.Join(domains, ", "))
failures := newObtainError()
failures := make(obtainError)
cert, err := c.getForOrder(domains, order, request.Bundle, request.PrivateKey, request.MustStaple, request.PreferredChain)
if err != nil {
for _, auth := range authz {
failures.Add(challenge.GetTargetedDomain(auth), err)
failures[challenge.GetTargetedDomain(auth)] = err
}
}
if request.AlwaysDeactivateAuthorizations {
c.deactivateAuthorizations(order, true)
// Do not return an empty failures map, because
// it would still be a non-nil error value
if len(failures) > 0 {
return cert, failures
}
return cert, failures.Join()
return cert, nil
}
// ObtainForCSR tries to obtain a certificate matching the CSR passed into it.
@ -197,13 +170,7 @@ func (c *Certifier) ObtainForCSR(request ObtainForCSRRequest) (*Resource, error)
log.Infof("[%s] acme: Obtaining SAN certificate given a CSR", strings.Join(domains, ", "))
}
orderOpts := &api.OrderOptions{
NotBefore: request.NotBefore,
NotAfter: request.NotAfter,
ReplacesCertID: request.ReplacesCertID,
}
order, err := c.core.Orders.NewWithOptions(domains, orderOpts)
order, err := c.core.Orders.New(domains)
if err != nil {
return nil, err
}
@ -211,37 +178,38 @@ func (c *Certifier) ObtainForCSR(request ObtainForCSRRequest) (*Resource, error)
authz, err := c.getAuthorizations(order)
if err != nil {
// If any challenge fails, return. Do not generate partial SAN certificates.
c.deactivateAuthorizations(order, request.AlwaysDeactivateAuthorizations)
c.deactivateAuthorizations(order)
return nil, err
}
err = c.resolver.Solve(authz)
if err != nil {
// If any challenge fails, return. Do not generate partial SAN certificates.
c.deactivateAuthorizations(order, request.AlwaysDeactivateAuthorizations)
c.deactivateAuthorizations(order)
return nil, err
}
log.Infof("[%s] acme: Validations succeeded; requesting certificates", strings.Join(domains, ", "))
failures := newObtainError()
failures := make(obtainError)
cert, err := c.getForCSR(domains, order, request.Bundle, request.CSR.Raw, nil, request.PreferredChain)
if err != nil {
for _, auth := range authz {
failures.Add(challenge.GetTargetedDomain(auth), err)
failures[challenge.GetTargetedDomain(auth)] = err
}
}
if request.AlwaysDeactivateAuthorizations {
c.deactivateAuthorizations(order, true)
}
if cert != nil {
// Add the CSR to the certificate so that it can be used for renewals.
cert.CSR = certcrypto.PEMEncode(request.CSR)
}
return cert, failures.Join()
// Do not return an empty failures map,
// because it would still be a non-nil error value
if len(failures) > 0 {
return cert, failures
}
return cert, nil
}
func (c *Certifier) getForOrder(domains []string, order acme.ExtendedOrder, bundle bool, privateKey crypto.PrivateKey, mustStaple bool, preferredChain string) (*Resource, error) {
@ -253,23 +221,16 @@ func (c *Certifier) getForOrder(domains []string, order acme.ExtendedOrder, bund
}
}
commonName := ""
if len(domains[0]) <= 64 {
commonName = domains[0]
}
// Determine certificate name(s) based on the authorization resources
commonName := domains[0]
// RFC8555 Section 7.4 "Applying for Certificate Issuance"
// https://www.rfc-editor.org/rfc/rfc8555.html#section-7.4
// https://tools.ietf.org/html/rfc8555#section-7.4
// says:
// Clients SHOULD NOT make any assumptions about the sort order of
// "identifiers" or "authorizations" elements in the returned order
// object.
var san []string
if commonName != "" {
san = append(san, commonName)
}
san := []string{commonName}
for _, auth := range order.Identifiers {
if auth.Value != commonName {
san = append(san, auth.Value)
@ -291,8 +252,9 @@ func (c *Certifier) getForCSR(domains []string, order acme.ExtendedOrder, bundle
return nil, err
}
commonName := domains[0]
certRes := &Resource{
Domain: domains[0],
Domain: commonName,
CertURL: respOrder.Certificate,
PrivateKey: privateKeyPem,
}
@ -387,11 +349,6 @@ func (c *Certifier) checkResponse(order acme.ExtendedOrder, certRes *Resource, b
// Revoke takes a PEM encoded certificate or bundle and tries to revoke it at the CA.
func (c *Certifier) Revoke(cert []byte) error {
return c.RevokeWithReason(cert, nil)
}
// RevokeWithReason takes a PEM encoded certificate or bundle and tries to revoke it at the CA.
func (c *Certifier) RevokeWithReason(cert []byte, reason *uint) error {
certificates, err := certcrypto.ParsePEMBundle(cert)
if err != nil {
return err
@ -404,24 +361,11 @@ func (c *Certifier) RevokeWithReason(cert []byte, reason *uint) error {
revokeMsg := acme.RevokeCertMessage{
Certificate: base64.RawURLEncoding.EncodeToString(x509Cert.Raw),
Reason: reason,
}
return c.core.Certificates.Revoke(revokeMsg)
}
// RenewOptions options used by Certifier.RenewWithOptions.
type RenewOptions struct {
NotBefore time.Time
NotAfter time.Time
// If true, the []byte contains both the issuer certificate and your issued certificate as a bundle.
Bundle bool
PreferredChain string
AlwaysDeactivateAuthorizations bool
// Not supported for CSR request.
MustStaple bool
}
// Renew takes a Resource and tries to renew the certificate.
//
// If the renewal process succeeds, the new certificate will be returned in a new CertResource.
@ -432,26 +376,7 @@ type RenewOptions struct {
// If bundle is true, the []byte contains both the issuer certificate and your issued certificate as a bundle.
//
// For private key reuse the PrivateKey property of the passed in Resource should be non-nil.
// Deprecated: use RenewWithOptions instead.
func (c *Certifier) Renew(certRes Resource, bundle, mustStaple bool, preferredChain string) (*Resource, error) {
return c.RenewWithOptions(certRes, &RenewOptions{
Bundle: bundle,
PreferredChain: preferredChain,
MustStaple: mustStaple,
})
}
// RenewWithOptions takes a Resource and tries to renew the certificate.
//
// If the renewal process succeeds, the new certificate will be returned in a new CertResource.
// Please be aware that this function will return a new certificate in ANY case that is not an error.
// If the server does not provide us with a new cert on a GET request to the CertURL
// this function will start a new-cert flow where a new certificate gets generated.
//
// If bundle is true, the []byte contains both the issuer certificate and your issued certificate as a bundle.
//
// For private key reuse the PrivateKey property of the passed in Resource should be non-nil.
func (c *Certifier) RenewWithOptions(certRes Resource, options *RenewOptions) (*Resource, error) {
// Input certificate is PEM encoded.
// Decode it here as we may need the decoded cert later on in the renewal process.
// The input may be a bundle or a single certificate.
@ -478,17 +403,11 @@ func (c *Certifier) RenewWithOptions(certRes Resource, options *RenewOptions) (*
return nil, errP
}
request := ObtainForCSRRequest{CSR: csr}
if options != nil {
request.NotBefore = options.NotBefore
request.NotAfter = options.NotAfter
request.Bundle = options.Bundle
request.PreferredChain = options.PreferredChain
request.AlwaysDeactivateAuthorizations = options.AlwaysDeactivateAuthorizations
}
return c.ObtainForCSR(request)
return c.ObtainForCSR(ObtainForCSRRequest{
CSR: csr,
Bundle: bundle,
PreferredChain: preferredChain,
})
}
var privateKey crypto.PrivateKey
@ -499,21 +418,13 @@ func (c *Certifier) RenewWithOptions(certRes Resource, options *RenewOptions) (*
}
}
request := ObtainRequest{
query := ObtainRequest{
Domains: certcrypto.ExtractDomains(x509Cert),
Bundle: bundle,
PrivateKey: privateKey,
MustStaple: mustStaple,
}
if options != nil {
request.MustStaple = options.MustStaple
request.NotBefore = options.NotBefore
request.NotAfter = options.NotAfter
request.Bundle = options.Bundle
request.PreferredChain = options.PreferredChain
request.AlwaysDeactivateAuthorizations = options.AlwaysDeactivateAuthorizations
}
return c.Obtain(request)
return c.Obtain(query)
}
// GetOCSP takes a PEM encoded cert or cert bundle returning the raw OCSP response,
@ -554,7 +465,7 @@ func (c *Certifier) GetOCSP(bundle []byte) ([]byte, *ocsp.Response, error) {
}
defer resp.Body.Close()
issuerBytes, errC := io.ReadAll(http.MaxBytesReader(nil, resp.Body, maxBodySize))
issuerBytes, errC := ioutil.ReadAll(http.MaxBytesReader(nil, resp.Body, maxBodySize))
if errC != nil {
return nil, nil, errC
}
@ -583,7 +494,7 @@ func (c *Certifier) GetOCSP(bundle []byte) ([]byte, *ocsp.Response, error) {
}
defer resp.Body.Close()
ocspResBytes, err := io.ReadAll(http.MaxBytesReader(nil, resp.Body, maxBodySize))
ocspResBytes, err := ioutil.ReadAll(http.MaxBytesReader(nil, resp.Body, maxBodySize))
if err != nil {
return nil, nil, err
}
@ -614,13 +525,8 @@ func (c *Certifier) Get(url string, bundle bool) (*Resource, error) {
return nil, err
}
domain, err := certcrypto.GetCertificateMainDomain(x509Certs[0])
if err != nil {
return nil, err
}
return &Resource{
Domain: domain,
Domain: x509Certs[0].Subject.CommonName,
Certificate: cert,
IssuerCertificate: issuer,
CertURL: url,
@ -654,11 +560,11 @@ func checkOrderStatus(order acme.ExtendedOrder) (bool, error) {
}
}
// https://www.rfc-editor.org/rfc/rfc8555.html#section-7.1.4
// https://tools.ietf.org/html/rfc8555#section-7.1.4
// The domain name MUST be encoded in the form in which it would appear in a certificate.
// That is, it MUST be encoded according to the rules in Section 7 of [RFC5280].
//
// https://www.rfc-editor.org/rfc/rfc5280.html#section-7
// https://tools.ietf.org/html/rfc5280#section-7
func sanitizeDomain(domains []string) []string {
var sanitizedDomains []string
for _, domain := range domains {

View File

@ -1,37 +1,27 @@
package certificate
import (
"errors"
"bytes"
"fmt"
"sort"
)
type obtainError struct {
data map[string]error
}
// obtainError is returned when there are specific errors available per domain.
type obtainError map[string]error
func newObtainError() *obtainError {
return &obtainError{data: make(map[string]error)}
}
func (e obtainError) Error() string {
buffer := bytes.NewBufferString("error: one or more domains had a problem:\n")
func (e *obtainError) Add(domain string, err error) {
e.data[domain] = err
var domains []string
for domain := range e {
domains = append(domains, domain)
}
sort.Strings(domains)
func (e *obtainError) Join() error {
if e == nil {
return nil
for _, domain := range domains {
buffer.WriteString(fmt.Sprintf("[%s] %s\n", domain, e[domain]))
}
if len(e.data) == 0 {
return nil
}
var err error
for d, e := range e.data {
err = errors.Join(err, fmt.Errorf("%s: %w", d, e))
}
return fmt.Errorf("error: one or more domains had a problem:\n%w", err)
return buffer.String()
}
type domainError struct {

View File

@ -1,129 +0,0 @@
package certificate
import (
"crypto/x509"
"encoding/asn1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/rand"
"time"
"github.com/go-acme/lego/v4/acme"
)
// RenewalInfoRequest contains the necessary renewal information.
type RenewalInfoRequest struct {
Cert *x509.Certificate
}
// RenewalInfoResponse is a wrapper around acme.RenewalInfoResponse that provides a method for determining when to renew a certificate.
type RenewalInfoResponse struct {
acme.RenewalInfoResponse
// RetryAfter header indicating the polling interval that the ACME server recommends.
// Conforming clients SHOULD query the renewalInfo URL again after the RetryAfter period has passed,
// as the server may provide a different suggestedWindow.
// https://datatracker.ietf.org/doc/html/draft-ietf-acme-ari-03#section-4.2
RetryAfter time.Duration
}
// ShouldRenewAt determines the optimal renewal time based on the current time (UTC),renewal window suggest by ARI, and the client's willingness to sleep.
// It returns a pointer to a time.Time value indicating when the renewal should be attempted or nil if deferred until the next normal wake time.
// This method implements the RECOMMENDED algorithm described in draft-ietf-acme-ari.
//
// - (4.1-11. Getting Renewal Information) https://datatracker.ietf.org/doc/draft-ietf-acme-ari/
func (r *RenewalInfoResponse) ShouldRenewAt(now time.Time, willingToSleep time.Duration) *time.Time {
// Explicitly convert all times to UTC.
now = now.UTC()
start := r.SuggestedWindow.Start.UTC()
end := r.SuggestedWindow.End.UTC()
// Select a uniform random time within the suggested window.
window := end.Sub(start)
randomDuration := time.Duration(rand.Int63n(int64(window)))
rt := start.Add(randomDuration)
// If the selected time is in the past, attempt renewal immediately.
if rt.Before(now) {
return &now
}
// Otherwise, if the client can schedule itself to attempt renewal at exactly the selected time, do so.
willingToSleepUntil := now.Add(willingToSleep)
if willingToSleepUntil.After(rt) || willingToSleepUntil.Equal(rt) {
return &rt
}
// TODO: Otherwise, if the selected time is before the next time that the client would wake up normally, attempt renewal immediately.
// Otherwise, sleep until the next normal wake time, re-check ARI, and return to Step 1.
return nil
}
// GetRenewalInfo sends a request to the ACME server's renewalInfo endpoint to obtain a suggested renewal window.
// The caller MUST provide the certificate and issuer certificate for the certificate they wish to renew.
// The caller should attempt to renew the certificate at the time indicated by the ShouldRenewAt method of the returned RenewalInfoResponse object.
//
// Note: this endpoint is part of a draft specification, not all ACME servers will implement it.
// This method will return api.ErrNoARI if the server does not advertise a renewal info endpoint.
//
// https://datatracker.ietf.org/doc/draft-ietf-acme-ari
func (c *Certifier) GetRenewalInfo(req RenewalInfoRequest) (*RenewalInfoResponse, error) {
certID, err := MakeARICertID(req.Cert)
if err != nil {
return nil, fmt.Errorf("error making certID: %w", err)
}
resp, err := c.core.Certificates.GetRenewalInfo(certID)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var info RenewalInfoResponse
err = json.NewDecoder(resp.Body).Decode(&info)
if err != nil {
return nil, err
}
if retry := resp.Header.Get("Retry-After"); retry != "" {
info.RetryAfter, err = time.ParseDuration(retry + "s")
if err != nil {
return nil, err
}
}
return &info, nil
}
// MakeARICertID constructs a certificate identifier as described in draft-ietf-acme-ari-03, section 4.1.
func MakeARICertID(leaf *x509.Certificate) (string, error) {
if leaf == nil {
return "", errors.New("leaf certificate is nil")
}
// Marshal the Serial Number into DER.
der, err := asn1.Marshal(leaf.SerialNumber)
if err != nil {
return "", err
}
// Check if the DER encoded bytes are sufficient (at least 3 bytes: tag,
// length, and value).
if len(der) < 3 {
return "", errors.New("invalid DER encoding of serial number")
}
// Extract only the integer bytes from the DER encoded Serial Number
// Skipping the first 2 bytes (tag and length).
serial := base64.RawURLEncoding.EncodeToString(der[2:])
// Convert the Authority Key Identifier to base64url encoding without
// padding.
aki := base64.RawURLEncoding.EncodeToString(leaf.AuthorityKeyId)
// Construct the final identifier by concatenating AKI and Serial Number.
return fmt.Sprintf("%s.%s", aki, serial), nil
}

View File

@ -10,15 +10,15 @@ import (
type Type string
const (
// HTTP01 is the "http-01" ACME challenge https://www.rfc-editor.org/rfc/rfc8555.html#section-8.3
// HTTP01 is the "http-01" ACME challenge https://tools.ietf.org/html/rfc8555#section-8.3
// Note: ChallengePath returns the URL path to fulfill this challenge.
HTTP01 = Type("http-01")
// DNS01 is the "dns-01" ACME challenge https://www.rfc-editor.org/rfc/rfc8555.html#section-8.4
// DNS01 is the "dns-01" ACME challenge https://tools.ietf.org/html/rfc8555#section-8.4
// Note: GetRecord returns a DNS record which will fulfill this challenge.
DNS01 = Type("dns-01")
// TLSALPN01 is the "tls-alpn-01" ACME challenge https://www.rfc-editor.org/rfc/rfc8737.html
// TLSALPN01 is the "tls-alpn-01" ACME challenge https://tools.ietf.org/html/draft-ietf-acme-tls-alpn-07
TLSALPN01 = Type("tls-alpn-01")
)

View File

@ -1,16 +1,12 @@
package dns01
import (
"strings"
"github.com/miekg/dns"
)
import "github.com/miekg/dns"
// Update FQDN with CNAME if any.
func updateDomainWithCName(r *dns.Msg, fqdn string) string {
for _, rr := range r.Answer {
if cn, ok := rr.(*dns.CNAME); ok {
if strings.EqualFold(cn.Hdr.Name, fqdn) {
if cn.Hdr.Name == fqdn {
return cn.Target
}
}

View File

@ -6,7 +6,6 @@ import (
"fmt"
"os"
"strconv"
"strings"
"time"
"github.com/go-acme/lego/v4/acme"
@ -115,7 +114,7 @@ func (c *Challenge) Solve(authz acme.Authorization) error {
return err
}
info := GetChallengeInfo(authz.Identifier.Value, keyAuth)
fqdn, value := GetRecord(authz.Identifier.Value, keyAuth)
var timeout, interval time.Duration
switch provider := c.provider.(type) {
@ -125,12 +124,12 @@ func (c *Challenge) Solve(authz acme.Authorization) error {
timeout, interval = DefaultPropagationTimeout, DefaultPollingInterval
}
log.Infof("[%s] acme: Checking DNS record propagation. [nameservers=%s]", domain, strings.Join(recursiveNameservers, ","))
log.Infof("[%s] acme: Checking DNS record propagation using %+v", domain, recursiveNameservers)
time.Sleep(interval)
err = wait.For("propagation", timeout, interval, func() (bool, error) {
stop, errP := c.preCheck.call(domain, info.EffectiveFQDN, info.Value)
stop, errP := c.preCheck.call(domain, fqdn, value)
if !stop || errP != nil {
log.Infof("[%s] acme: Waiting for DNS record propagation.", domain)
}
@ -173,67 +172,19 @@ type sequential interface {
}
// GetRecord returns a DNS record which will fulfill the `dns-01` challenge.
// Deprecated: use GetChallengeInfo instead.
func GetRecord(domain, keyAuth string) (fqdn, value string) {
info := GetChallengeInfo(domain, keyAuth)
return info.EffectiveFQDN, info.Value
}
// ChallengeInfo contains the information use to create the TXT record.
type ChallengeInfo struct {
// FQDN is the full-qualified challenge domain (i.e. `_acme-challenge.[domain].`)
FQDN string
// EffectiveFQDN contains the resulting FQDN after the CNAMEs resolutions.
EffectiveFQDN string
// Value contains the value for the TXT record.
Value string
}
// GetChallengeInfo returns information used to create a DNS record which will fulfill the `dns-01` challenge.
func GetChallengeInfo(domain, keyAuth string) ChallengeInfo {
keyAuthShaBytes := sha256.Sum256([]byte(keyAuth))
// base64URL encoding without padding
value := base64.RawURLEncoding.EncodeToString(keyAuthShaBytes[:sha256.Size])
value = base64.RawURLEncoding.EncodeToString(keyAuthShaBytes[:sha256.Size])
fqdn = fmt.Sprintf("_acme-challenge.%s.", domain)
ok, _ := strconv.ParseBool(os.Getenv("LEGO_DISABLE_CNAME_SUPPORT"))
return ChallengeInfo{
Value: value,
FQDN: getChallengeFQDN(domain, false),
EffectiveFQDN: getChallengeFQDN(domain, !ok),
}
}
func getChallengeFQDN(domain string, followCNAME bool) string {
fqdn := fmt.Sprintf("_acme-challenge.%s.", domain)
if !followCNAME {
return fqdn
}
// recursion counter so it doesn't spin out of control
for range 50 {
// Keep following CNAMEs
if ok, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_CNAME_SUPPORT")); ok {
r, err := dnsQuery(fqdn, dns.TypeCNAME, recursiveNameservers, true)
if err != nil || r.Rcode != dns.RcodeSuccess {
// No more CNAME records to follow, exit
break
// Check if the domain has CNAME then return that
if err == nil && r.Rcode == dns.RcodeSuccess {
fqdn = updateDomainWithCName(r, fqdn)
}
}
// Check if the domain has CNAME then use that
cname := updateDomainWithCName(r, fqdn)
if cname == fqdn {
break
}
log.Infof("Found CNAME entry for %q: %q", fqdn, cname)
fqdn = cname
}
return fqdn
return
}

View File

@ -8,7 +8,7 @@ import (
)
const (
dnsTemplate = `%s %d IN TXT %q`
dnsTemplate = `%s %d IN TXT "%s"`
)
// DNSProviderManual is an implementation of the ChallengeProvider interface.
@ -21,36 +21,33 @@ func NewDNSProviderManual() (*DNSProviderManual, error) {
// Present prints instructions for manually creating the TXT record.
func (*DNSProviderManual) Present(domain, token, keyAuth string) error {
info := GetChallengeInfo(domain, keyAuth)
fqdn, value := GetRecord(domain, keyAuth)
authZone, err := FindZoneByFqdn(info.EffectiveFQDN)
authZone, err := FindZoneByFqdn(fqdn)
if err != nil {
return fmt.Errorf("manual: could not find zone: %w", err)
return err
}
fmt.Printf("lego: Please create the following TXT record in your %s zone:\n", authZone)
fmt.Printf(dnsTemplate+"\n", info.EffectiveFQDN, DefaultTTL, info.Value)
fmt.Printf(dnsTemplate+"\n", fqdn, DefaultTTL, value)
fmt.Printf("lego: Press 'Enter' when you are done\n")
_, err = bufio.NewReader(os.Stdin).ReadBytes('\n')
if err != nil {
return fmt.Errorf("manual: %w", err)
}
return nil
return err
}
// CleanUp prints instructions for manually removing the TXT record.
func (*DNSProviderManual) CleanUp(domain, token, keyAuth string) error {
info := GetChallengeInfo(domain, keyAuth)
fqdn, _ := GetRecord(domain, keyAuth)
authZone, err := FindZoneByFqdn(info.EffectiveFQDN)
authZone, err := FindZoneByFqdn(fqdn)
if err != nil {
return fmt.Errorf("manual: could not find zone: %w", err)
return err
}
fmt.Printf("lego: You can now remove this TXT record from your %s zone:\n", authZone)
fmt.Printf(dnsTemplate+"\n", info.EffectiveFQDN, DefaultTTL, "...")
fmt.Printf(dnsTemplate+"\n", fqdn, DefaultTTL, "...")
return nil
}

View File

@ -1,24 +0,0 @@
package dns01
import (
"fmt"
"strings"
"github.com/miekg/dns"
)
// ExtractSubDomain extracts the subdomain part from a domain and a zone.
func ExtractSubDomain(domain, zone string) (string, error) {
canonDomain := dns.Fqdn(domain)
canonZone := dns.Fqdn(zone)
if canonDomain == canonZone {
return "", fmt.Errorf("no subdomain because the domain and the zone are identical: %s", canonDomain)
}
if !dns.IsSubDomain(canonZone, canonDomain) {
return "", fmt.Errorf("%s is not a subdomain of %s", canonDomain, canonZone)
}
return strings.TrimSuffix(canonDomain, "."+canonZone), nil
}

View File

@ -4,9 +4,6 @@ import (
"errors"
"fmt"
"net"
"os"
"slices"
"strconv"
"strings"
"sync"
"time"
@ -16,6 +13,9 @@ import (
const defaultResolvConf = "/etc/resolv.conf"
// dnsTimeout is used to override the default DNS timeout of 10 seconds.
var dnsTimeout = 10 * time.Second
var (
fqdnSoaCache = map[string]*soaCacheEntry{}
muFqdnSoaCache sync.Mutex
@ -99,12 +99,12 @@ func lookupNameservers(fqdn string) ([]string, error) {
zone, err := FindZoneByFqdn(fqdn)
if err != nil {
return nil, fmt.Errorf("could not find zone: %w", err)
return nil, fmt.Errorf("could not determine the zone: %w", err)
}
r, err := dnsQuery(zone, dns.TypeNS, recursiveNameservers, true)
if err != nil {
return nil, fmt.Errorf("NS call failed: %w", err)
return nil, err
}
for _, rr := range r.Answer {
@ -116,8 +116,7 @@ func lookupNameservers(fqdn string) ([]string, error) {
if len(authoritativeNss) > 0 {
return authoritativeNss, nil
}
return nil, fmt.Errorf("[zone=%s] could not determine authoritative nameservers", zone)
return nil, errors.New("could not determine authoritative nameservers")
}
// FindPrimaryNsByFqdn determines the primary nameserver of the zone apex for the given fqdn
@ -131,7 +130,7 @@ func FindPrimaryNsByFqdn(fqdn string) (string, error) {
func FindPrimaryNsByFqdnCustom(fqdn string, nameservers []string) (string, error) {
soa, err := lookupSoaByFqdn(fqdn, nameservers)
if err != nil {
return "", fmt.Errorf("[fqdn=%s] %w", fqdn, err)
return "", err
}
return soa.primaryNs, nil
}
@ -147,7 +146,7 @@ func FindZoneByFqdn(fqdn string) (string, error) {
func FindZoneByFqdnCustom(fqdn string, nameservers []string) (string, error) {
soa, err := lookupSoaByFqdn(fqdn, nameservers)
if err != nil {
return "", fmt.Errorf("[fqdn=%s] %w", fqdn, err)
return "", err
}
return soa.zone, nil
}
@ -172,35 +171,35 @@ func lookupSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error)
func fetchSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) {
var err error
var r *dns.Msg
var in *dns.Msg
labelIndexes := dns.Split(fqdn)
for _, index := range labelIndexes {
domain := fqdn[index:]
r, err = dnsQuery(domain, dns.TypeSOA, nameservers, true)
in, err = dnsQuery(domain, dns.TypeSOA, nameservers, true)
if err != nil {
continue
}
if r == nil {
if in == nil {
continue
}
switch r.Rcode {
switch in.Rcode {
case dns.RcodeSuccess:
// Check if we got a SOA RR in the answer section
if len(r.Answer) == 0 {
if len(in.Answer) == 0 {
continue
}
// CNAME records cannot/should not exist at the root of a zone.
// So we skip a domain when a CNAME is found.
if dnsMsgContainsCNAME(r) {
if dnsMsgContainsCNAME(in) {
continue
}
for _, ans := range r.Answer {
for _, ans := range in.Answer {
if soa, ok := ans.(*dns.SOA); ok {
return newSoaCacheEntry(soa), nil
}
@ -209,46 +208,36 @@ func fetchSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) {
// NXDOMAIN
default:
// Any response code other than NOERROR and NXDOMAIN is treated as error
return nil, &DNSError{Message: fmt.Sprintf("unexpected response for '%s'", domain), MsgOut: r}
return nil, fmt.Errorf("unexpected response code '%s' for %s", dns.RcodeToString[in.Rcode], domain)
}
}
return nil, &DNSError{Message: fmt.Sprintf("could not find the start of authority for '%s'", fqdn), MsgOut: r, Err: err}
return nil, fmt.Errorf("could not find the start of authority for %s%s", fqdn, formatDNSError(in, err))
}
// dnsMsgContainsCNAME checks for a CNAME answer in msg.
func dnsMsgContainsCNAME(msg *dns.Msg) bool {
return slices.ContainsFunc(msg.Answer, func(rr dns.RR) bool {
_, ok := rr.(*dns.CNAME)
return ok
})
for _, ans := range msg.Answer {
if _, ok := ans.(*dns.CNAME); ok {
return true
}
}
return false
}
func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
m := createDNSMsg(fqdn, rtype, recursive)
if len(nameservers) == 0 {
return nil, &DNSError{Message: "empty list of nameservers"}
}
var r *dns.Msg
var in *dns.Msg
var err error
var errAll error
for _, ns := range nameservers {
r, err = sendDNSQuery(m, ns)
if err == nil && len(r.Answer) > 0 {
in, err = sendDNSQuery(m, ns)
if err == nil && len(in.Answer) > 0 {
break
}
errAll = errors.Join(errAll, err)
}
if err != nil {
return r, errAll
}
return r, nil
return in, err
}
func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
@ -264,84 +253,32 @@ func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
}
func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
if ok, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_DNS_TCP_ONLY")); ok {
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
r, _, err := tcp.Exchange(m, ns)
if err != nil {
return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
}
return r, nil
}
udp := &dns.Client{Net: "udp", Timeout: dnsTimeout}
r, _, err := udp.Exchange(m, ns)
in, _, err := udp.Exchange(m, ns)
if r != nil && r.Truncated {
if in != nil && in.Truncated {
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
// If the TCP request succeeds, the "err" will reset to nil
r, _, err = tcp.Exchange(m, ns)
// If the TCP request succeeds, the err will reset to nil
in, _, err = tcp.Exchange(m, ns)
}
return in, err
}
func formatDNSError(msg *dns.Msg, err error) string {
var parts []string
if msg != nil {
parts = append(parts, dns.RcodeToString[msg.Rcode])
}
if err != nil {
return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
parts = append(parts, err.Error())
}
return r, nil
if len(parts) > 0 {
return ": " + strings.Join(parts, " ")
}
// DNSError error related to DNS calls.
type DNSError struct {
Message string
NS string
MsgIn *dns.Msg
MsgOut *dns.Msg
Err error
}
func (d *DNSError) Error() string {
var details []string
if d.NS != "" {
details = append(details, "ns="+d.NS)
}
if d.MsgIn != nil && len(d.MsgIn.Question) > 0 {
details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgIn.Question)))
}
if d.MsgOut != nil {
if d.MsgIn == nil || len(d.MsgIn.Question) == 0 {
details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgOut.Question)))
}
details = append(details, "code="+dns.RcodeToString[d.MsgOut.Rcode])
}
msg := "DNS error"
if d.Message != "" {
msg = d.Message
}
if d.Err != nil {
msg += ": " + d.Err.Error()
}
if len(details) > 0 {
msg += " [" + strings.Join(details, ", ") + "]"
}
return msg
}
func (d *DNSError) Unwrap() error {
return d.Err
}
func formatQuestions(questions []dns.Question) string {
var parts []string
for _, question := range questions {
parts = append(parts, strings.ReplaceAll(strings.TrimPrefix(question.String(), ";"), "\t", " "))
}
return strings.Join(parts, ";")
return ""
}

View File

@ -1,8 +0,0 @@
//go:build !windows
package dns01
import "time"
// dnsTimeout is used to override the default DNS timeout of 10 seconds.
var dnsTimeout = 10 * time.Second

View File

@ -1,8 +0,0 @@
//go:build windows
package dns01
import "time"
// dnsTimeout is used to override the default DNS timeout of 20 seconds.
var dnsTimeout = 20 * time.Second

View File

@ -23,16 +23,13 @@ import (
// RFC7239 has standardized the different forwarding headers into a single header named Forwarded.
// The header value has a different format, so you should use forwardedMatcher
// when the http01.ProviderServer operates behind a RFC7239 compatible proxy.
// https://www.rfc-editor.org/rfc/rfc7239.html
// https://tools.ietf.org/html/rfc7239
//
// Note: RFC7239 also reminds us, "that an HTTP list [...] may be split over multiple header fields" (section 7.1),
// meaning that
//
// X-Header: a
// X-Header: b
//
// is equal to
//
// X-Header: a, b
//
// All matcher implementations (explicitly not excluding arbitraryMatcher!)
@ -69,7 +66,7 @@ func (m arbitraryMatcher) matches(r *http.Request, domain string) bool {
}
// forwardedMatcher checks whether the Forwarded header contains a "host" element starting with a domain name.
// See https://www.rfc-editor.org/rfc/rfc7239.html for details.
// See https://tools.ietf.org/html/rfc7239 for details.
type forwardedMatcher struct{}
func (m *forwardedMatcher) name() string {

View File

@ -2,11 +2,9 @@ package http01
import (
"fmt"
"io/fs"
"net"
"net/http"
"net/textproto"
"os"
"strings"
"github.com/go-acme/lego/v4/log"
@ -16,11 +14,8 @@ import (
// It may be instantiated without using the NewProviderServer function if
// you want only to use the default values.
type ProviderServer struct {
address string
network string // must be valid argument to net.Listen
socketMode fs.FileMode
iface string
port string
matcher domainMatcher
done chan bool
listener net.Listener
@ -34,34 +29,24 @@ func NewProviderServer(iface, port string) *ProviderServer {
port = "80"
}
return &ProviderServer{network: "tcp", address: net.JoinHostPort(iface, port), matcher: &hostMatcher{}}
}
func NewUnixProviderServer(socketPath string, mode fs.FileMode) *ProviderServer {
return &ProviderServer{network: "unix", address: socketPath, socketMode: mode, matcher: &hostMatcher{}}
return &ProviderServer{iface: iface, port: port, matcher: &hostMatcher{}}
}
// Present starts a web server and makes the token available at `ChallengePath(token)` for web requests.
func (s *ProviderServer) Present(domain, token, keyAuth string) error {
var err error
s.listener, err = net.Listen(s.network, s.GetAddress())
s.listener, err = net.Listen("tcp", s.GetAddress())
if err != nil {
return fmt.Errorf("could not start HTTP server for challenge: %w", err)
}
if s.network == "unix" {
if err = os.Chmod(s.address, s.socketMode); err != nil {
return fmt.Errorf("chmod %s: %w", s.address, err)
}
}
s.done = make(chan bool)
go s.serve(domain, token, keyAuth)
return nil
}
func (s *ProviderServer) GetAddress() string {
return s.address
return net.JoinHostPort(s.iface, s.port)
}
// CleanUp closes the HTTP server and removes the token from `ChallengePath(token)`.
@ -84,7 +69,7 @@ func (s *ProviderServer) CleanUp(domain, token, keyAuth string) error {
//
// The exact behavior depends on the value of headerName:
// - "" (the empty string) and "Host" will restore the default and only check the Host header
// - "Forwarded" will look for a Forwarded header, and inspect it according to https://www.rfc-editor.org/rfc/rfc7239.html
// - "Forwarded" will look for a Forwarded header, and inspect it according to https://tools.ietf.org/html/rfc7239
// - any other value will check the header value with the same name.
func (s *ProviderServer) SetProxyHeader(headerName string) {
switch h := textproto.CanonicalMIMEHeaderKey(headerName); h {
@ -100,7 +85,7 @@ func (s *ProviderServer) SetProxyHeader(headerName string) {
func (s *ProviderServer) serve(domain, token, keyAuth string) {
path := ChallengePath(token)
// The incoming request will be validated to prevent DNS rebind attacks.
// The incoming request must will be validated to prevent DNS rebind attacks.
// We only respond with the keyAuth, when we're receiving a GET requests with
// the "Host" header matching the domain (the latter is configurable though SetProxyHeader).
mux := http.NewServeMux()
@ -114,7 +99,7 @@ func (s *ProviderServer) serve(domain, token, keyAuth string) {
}
log.Infof("[%s] Served key authentication", domain)
} else {
log.Warnf("Received request for domain %s with method %s but the domain did not match any challenge. Please ensure you are passing the %s header properly.", r.Host, r.Method, s.matcher.name())
log.Warnf("Received request for domain %s with method %s but the domain did not match any challenge. Please ensure your are passing the %s header properly.", r.Host, r.Method, s.matcher.name())
_, err := w.Write([]byte("TEST"))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)

View File

@ -19,7 +19,7 @@ func (e obtainError) Error() string {
sort.Strings(domains)
for _, domain := range domains {
_, _ = fmt.Fprintf(buffer, "[%s] %s\n", domain, e[domain])
buffer.WriteString(fmt.Sprintf("[%s] %s\n", domain, e[domain]))
}
return buffer.String()
}

View File

@ -128,7 +128,7 @@ func sequentialSolve(authSolvers []*selectedAuthSolver, failures obtainError) {
}
func parallelSolve(authSolvers []*selectedAuthSolver, failures obtainError) {
// For all valid preSolvers, first submit the challenges, so they have max time to propagate
// For all valid preSolvers, first submit the challenges so they have max time to propagate
for _, authSolver := range authSolvers {
authz := authSolver.authz
if solvr, ok := authSolver.solver.(preSolver); ok {

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"errors"
"fmt"
"sort"
@ -53,7 +54,7 @@ func (c *SolverManager) SetDNS01Provider(p challenge.Provider, opts ...dns01.Cha
return nil
}
// Remove removes a challenge type from the available solvers.
// Remove Remove a challenge type from the available solvers.
func (c *SolverManager) Remove(chlgType challenge.Type) {
delete(c.solvers, chlgType)
}
@ -106,17 +107,21 @@ func validate(core *api.Core, domain string, chlg acme.Challenge) error {
bo.MaxInterval = 10 * initialInterval
bo.MaxElapsedTime = 100 * initialInterval
ctx, cancel := context.WithCancel(context.Background())
// After the path is sent, the ACME server will access our server.
// Repeatedly check the server for an updated status on our request.
operation := func() error {
authz, err := core.Authorizations.Get(chlng.AuthorizationURL)
if err != nil {
return backoff.Permanent(err)
cancel()
return err
}
valid, err := checkAuthorizationStatus(authz)
if err != nil {
return backoff.Permanent(err)
cancel()
return err
}
if valid {
@ -127,7 +132,7 @@ func validate(core *api.Core, domain string, chlg acme.Challenge) error {
return errors.New("the server didn't respond to our request")
}
return backoff.Retry(operation, bo)
return backoff.Retry(operation, backoff.WithContext(bo, ctx))
}
func checkChallengeStatus(chlng acme.ExtendedChallenge) (bool, error) {

View File

@ -16,7 +16,7 @@ import (
)
// idPeAcmeIdentifierV1 is the SMI Security for PKIX Certification Extension OID referencing the ACME extension.
// Reference: https://www.rfc-editor.org/rfc/rfc8737.html#section-6.1
// Reference: https://tools.ietf.org/html/draft-ietf-acme-tls-alpn-07#section-6.1
var idPeAcmeIdentifierV1 = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31}
type ValidateFunc func(core *api.Core, domain string, chlng acme.Challenge) error
@ -83,7 +83,7 @@ func ChallengeBlocks(domain, keyAuth string) ([]byte, []byte, error) {
// Add the keyAuth digest as the acmeValidation-v1 extension
// (marked as critical such that it won't be used by non-ACME software).
// Reference: https://www.rfc-editor.org/rfc/rfc8737.html#section-3
// Reference: https://tools.ietf.org/html/draft-ietf-acme-tls-alpn-07#section-3
extensions := []pkix.Extension{
{
Id: idPeAcmeIdentifierV1,

View File

@ -40,7 +40,7 @@ func (s *ProviderServer) GetAddress() string {
return net.JoinHostPort(s.iface, s.port)
}
// Present generates a certificate with an SHA-256 digest of the keyAuth provided
// Present generates a certificate with a SHA-256 digest of the keyAuth provided
// as the acmeValidation-v1 extension value to conform to the ACME-TLS-ALPN spec.
func (s *ProviderServer) Present(domain, token, keyAuth string) error {
if s.port == "" {
@ -61,7 +61,7 @@ func (s *ProviderServer) Present(domain, token, keyAuth string) error {
// We must set that the `acme-tls/1` application level protocol is supported
// so that the protocol negotiation can succeed. Reference:
// https://www.rfc-editor.org/rfc/rfc8737.html#section-6.2
// https://tools.ietf.org/html/draft-ietf-acme-tls-alpn-07#section-6.2
tlsConf.NextProtos = []string{ACMETLS1Protocol}
// Create the listener with the created tls.Config.

View File

@ -4,11 +4,10 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/go-acme/lego/v4/certcrypto"
@ -18,18 +17,13 @@ import (
const (
// caCertificatesEnvVar is the environment variable name that can be used to
// specify the path to PEM encoded CA Certificates that can be used to
// authenticate an ACME server with an HTTPS certificate not issued by a CA in
// authenticate an ACME server with a HTTPS certificate not issued by a CA in
// the system-wide trusted root list.
// Multiple file paths can be added by using os.PathListSeparator as a separator.
caCertificatesEnvVar = "LEGO_CA_CERTIFICATES"
// caSystemCertPool is the environment variable name that can be used to define
// if the certificates pool must use a copy of the system cert pool.
caSystemCertPool = "LEGO_CA_SYSTEM_CERT_POOL"
// caServerNameEnvVar is the environment variable name that can be used to
// specify the CA server name that can be used to
// authenticate an ACME server with an HTTPS certificate not issued by a CA in
// authenticate an ACME server with a HTTPS certificate not issued by a CA in
// the system-wide trusted root list.
caServerNameEnvVar = "LEGO_CA_SERVER_NAME"
@ -88,44 +82,23 @@ func createDefaultHTTPClient() *http.Client {
}
// initCertPool creates a *x509.CertPool populated with the PEM certificates
// found in the filepath specified in the caCertificatesEnvVar OS environment variable.
// If the caCertificatesEnvVar is not set then initCertPool will return nil.
// If there is an error creating a *x509.CertPool from the provided caCertificatesEnvVar value then initCertPool will panic.
// If the caSystemCertPool is set to a "truthy value" (`1`, `t`, `T`, `TRUE`, `true`, `True`) then a copy of system cert pool will be used.
// caSystemCertPool requires caCertificatesEnvVar to be set.
// found in the filepath specified in the caCertificatesEnvVar OS environment
// variable. If the caCertificatesEnvVar is not set then initCertPool will
// return nil. If there is an error creating a *x509.CertPool from the provided
// caCertificatesEnvVar value then initCertPool will panic.
func initCertPool() *x509.CertPool {
customCACertsPath := os.Getenv(caCertificatesEnvVar)
if customCACertsPath == "" {
return nil
}
certPool := getCertPool()
for _, customPath := range strings.Split(customCACertsPath, string(os.PathListSeparator)) {
customCAs, err := os.ReadFile(customPath)
if customCACertsPath := os.Getenv(caCertificatesEnvVar); customCACertsPath != "" {
customCAs, err := ioutil.ReadFile(customCACertsPath)
if err != nil {
panic(fmt.Sprintf("error reading %s=%q: %v",
caCertificatesEnvVar, customPath, err))
caCertificatesEnvVar, customCACertsPath, err))
}
certPool := x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM(customCAs); !ok {
panic(fmt.Sprintf("error creating x509 cert pool from %s=%q: %v",
caCertificatesEnvVar, customPath, err))
caCertificatesEnvVar, customCACertsPath, err))
}
}
return certPool
}
func getCertPool() *x509.CertPool {
useSystemCertPool, _ := strconv.ParseBool(os.Getenv(caSystemCertPool))
if !useSystemCertPool {
return x509.NewCertPool()
}
pool, err := x509.SystemCertPool()
if err == nil {
return pool
}
return x509.NewCertPool()
return nil
}

View File

@ -3,6 +3,7 @@ package env
import (
"errors"
"fmt"
"io/ioutil"
"os"
"strconv"
"strings"
@ -54,6 +55,7 @@ func Get(names ...string) (map[string]string, error) {
// // LEGO_TWO=""
// env.GetWithFallback([]string{"LEGO_ONE", "LEGO_TWO"})
// // => error
//
func GetWithFallback(groups ...[]string) (map[string]string, error) {
values := map[string]string{}
@ -78,26 +80,15 @@ func GetWithFallback(groups ...[]string) (map[string]string, error) {
return values, nil
}
func GetOneWithFallback[T any](main string, defaultValue T, fn func(string) (T, error), names ...string) T {
v, _ := getOneWithFallback(main, names...)
value, err := fn(v)
if err != nil {
return defaultValue
}
return value
}
func getOneWithFallback(main string, names ...string) (string, string) {
value := GetOrFile(main)
if value != "" {
if len(value) > 0 {
return value, main
}
for _, name := range names {
value := GetOrFile(name)
if value != "" {
if len(value) > 0 {
return value, main
}
}
@ -105,32 +96,43 @@ func getOneWithFallback(main string, names ...string) (string, string) {
return "", main
}
// GetOrDefaultInt returns the given environment variable value as an integer.
// Returns the default if the envvar cannot be coopered to an int, or is not found.
func GetOrDefaultInt(envVar string, defaultValue int) int {
v, err := strconv.Atoi(GetOrFile(envVar))
if err != nil {
return defaultValue
}
return v
}
// GetOrDefaultSecond returns the given environment variable value as an time.Duration (second).
// Returns the default if the envvar cannot be coopered to an int, or is not found.
func GetOrDefaultSecond(envVar string, defaultValue time.Duration) time.Duration {
v := GetOrDefaultInt(envVar, -1)
if v < 0 {
return defaultValue
}
return time.Duration(v) * time.Second
}
// GetOrDefaultString returns the given environment variable value as a string.
// Returns the default if the env var cannot be found.
func GetOrDefaultString(envVar string, defaultValue string) string {
return getOrDefault(envVar, defaultValue, ParseString)
// Returns the default if the envvar cannot be find.
func GetOrDefaultString(envVar, defaultValue string) string {
v := GetOrFile(envVar)
if v == "" {
return defaultValue
}
return v
}
// GetOrDefaultBool returns the given environment variable value as a boolean.
// Returns the default if the envvar cannot be coopered to a boolean, or is not found.
func GetOrDefaultBool(envVar string, defaultValue bool) bool {
return getOrDefault(envVar, defaultValue, strconv.ParseBool)
}
// GetOrDefaultInt returns the given environment variable value as an integer.
// Returns the default if the env var cannot be coopered to an int, or is not found.
func GetOrDefaultInt(envVar string, defaultValue int) int {
return getOrDefault(envVar, defaultValue, strconv.Atoi)
}
// GetOrDefaultSecond returns the given environment variable value as a time.Duration (second).
// Returns the default if the env var cannot be coopered to an int, or is not found.
func GetOrDefaultSecond(envVar string, defaultValue time.Duration) time.Duration {
return getOrDefault(envVar, defaultValue, ParseSecond)
}
func getOrDefault[T any](envVar string, defaultValue T, fn func(string) (T, error)) T {
v, err := fn(GetOrFile(envVar))
v, err := strconv.ParseBool(GetOrFile(envVar))
if err != nil {
return defaultValue
}
@ -153,7 +155,7 @@ func GetOrFile(envVar string) string {
return envVarValue
}
fileContents, err := os.ReadFile(fileVarValue)
fileContents, err := ioutil.ReadFile(fileVarValue)
if err != nil {
log.Printf("Failed to read the file %s (defined by env var %s): %s", fileVarValue, fileVar, err)
return ""
@ -161,26 +163,3 @@ func GetOrFile(envVar string) string {
return strings.TrimSuffix(string(fileContents), "\n")
}
// ParseSecond parses env var value (string) to a second (time.Duration).
func ParseSecond(s string) (time.Duration, error) {
v, err := strconv.Atoi(s)
if err != nil {
return 0, err
}
if v < 0 {
return 0, fmt.Errorf("unsupported value: %d", v)
}
return time.Duration(v) * time.Second, nil
}
// ParseString parses env var value (string) to a string but throws an error when the string is empty.
func ParseString(s string) (string, error) {
if s == "" {
return "", errors.New("empty string")
}
return s, nil
}

View File

@ -1,6 +1,7 @@
package wait
import (
"errors"
"fmt"
"time"
@ -17,9 +18,9 @@ func For(msg string, timeout, interval time.Duration, f func() (bool, error)) er
select {
case <-timeUp:
if lastErr == nil {
return fmt.Errorf("%s: time limit exceeded", msg)
return errors.New("time limit exceeded")
}
return fmt.Errorf("%s: time limit exceeded: last error: %w", msg, lastErr)
return fmt.Errorf("time limit exceeded: last error: %w", lastErr)
default:
}

View File

@ -1,133 +0,0 @@
package errutils
import (
"bytes"
"fmt"
"io"
"net/http"
"os"
"strconv"
)
const legoDebugClientVerboseError = "LEGO_DEBUG_CLIENT_VERBOSE_ERROR"
// HTTPDoError uses with `(http.Client).Do` error.
type HTTPDoError struct {
req *http.Request
err error
}
// NewHTTPDoError creates a new HTTPDoError.
func NewHTTPDoError(req *http.Request, err error) *HTTPDoError {
return &HTTPDoError{req: req, err: err}
}
func (h HTTPDoError) Error() string {
msg := "unable to communicate with the API server:"
if ok, _ := strconv.ParseBool(os.Getenv(legoDebugClientVerboseError)); ok {
msg += fmt.Sprintf(" [request: %s %s]", h.req.Method, h.req.URL)
}
if h.err == nil {
return msg
}
return msg + fmt.Sprintf(" error: %v", h.err)
}
func (h HTTPDoError) Unwrap() error {
return h.err
}
// ReadResponseError use with `io.ReadAll` when reading response body.
type ReadResponseError struct {
req *http.Request
StatusCode int
err error
}
// NewReadResponseError creates a new ReadResponseError.
func NewReadResponseError(req *http.Request, statusCode int, err error) *ReadResponseError {
return &ReadResponseError{req: req, StatusCode: statusCode, err: err}
}
func (r ReadResponseError) Error() string {
msg := "unable to read response body:"
if ok, _ := strconv.ParseBool(os.Getenv(legoDebugClientVerboseError)); ok {
msg += fmt.Sprintf(" [request: %s %s]", r.req.Method, r.req.URL)
}
msg += fmt.Sprintf(" [status code: %d]", r.StatusCode)
if r.err == nil {
return msg
}
return msg + fmt.Sprintf(" error: %v", r.err)
}
func (r ReadResponseError) Unwrap() error {
return r.err
}
// UnmarshalError uses with `json.Unmarshal` or `xml.Unmarshal` when reading response body.
type UnmarshalError struct {
req *http.Request
StatusCode int
Body []byte
err error
}
// NewUnmarshalError creates a new UnmarshalError.
func NewUnmarshalError(req *http.Request, statusCode int, body []byte, err error) *UnmarshalError {
return &UnmarshalError{req: req, StatusCode: statusCode, Body: bytes.TrimSpace(body), err: err}
}
func (u UnmarshalError) Error() string {
msg := "unable to unmarshal response:"
if ok, _ := strconv.ParseBool(os.Getenv(legoDebugClientVerboseError)); ok {
msg += fmt.Sprintf(" [request: %s %s]", u.req.Method, u.req.URL)
}
msg += fmt.Sprintf(" [status code: %d] body: %s", u.StatusCode, string(u.Body))
if u.err == nil {
return msg
}
return msg + fmt.Sprintf(" error: %v", u.err)
}
func (u UnmarshalError) Unwrap() error {
return u.err
}
// UnexpectedStatusCodeError use when the status of the response is unexpected but there is no API error type.
type UnexpectedStatusCodeError struct {
req *http.Request
StatusCode int
Body []byte
}
// NewUnexpectedStatusCodeError creates a new UnexpectedStatusCodeError.
func NewUnexpectedStatusCodeError(req *http.Request, statusCode int, body []byte) *UnexpectedStatusCodeError {
return &UnexpectedStatusCodeError{req: req, StatusCode: statusCode, Body: bytes.TrimSpace(body)}
}
func NewUnexpectedResponseStatusCodeError(req *http.Request, resp *http.Response) *UnexpectedStatusCodeError {
raw, _ := io.ReadAll(resp.Body)
return &UnexpectedStatusCodeError{req: req, StatusCode: resp.StatusCode, Body: bytes.TrimSpace(raw)}
}
func (u UnexpectedStatusCodeError) Error() string {
msg := "unexpected status code:"
if ok, _ := strconv.ParseBool(os.Getenv(legoDebugClientVerboseError)); ok {
msg += fmt.Sprintf(" [request: %s %s]", u.req.Method, u.req.URL)
}
return msg + fmt.Sprintf(" [status code: %d] body: %s", u.StatusCode, string(u.Body))
}

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net/http"
"strings"
"sync"
"time"
@ -15,13 +16,15 @@ import (
// OVH API reference: https://eu.api.ovh.com/
// Create a Token: https://eu.api.ovh.com/createToken/
// Create a OAuth2 client: https://eu.api.ovh.com/console-preview/?section=%2Fme&branch=v1#post-/me/api/oauth2/client
// Environment variables names.
const (
envNamespace = "OVH_"
EnvEndpoint = envNamespace + "ENDPOINT"
EnvApplicationKey = envNamespace + "APPLICATION_KEY"
EnvApplicationSecret = envNamespace + "APPLICATION_SECRET"
EnvConsumerKey = envNamespace + "CONSUMER_KEY"
EnvTTL = envNamespace + "TTL"
EnvPropagationTimeout = envNamespace + "PROPAGATION_TIMEOUT"
@ -29,19 +32,6 @@ const (
EnvHTTPTimeout = envNamespace + "HTTP_TIMEOUT"
)
// Authenticate using application key.
const (
EnvApplicationKey = envNamespace + "APPLICATION_KEY"
EnvApplicationSecret = envNamespace + "APPLICATION_SECRET"
EnvConsumerKey = envNamespace + "CONSUMER_KEY"
)
// Authenticate using OAuth2 client.
const (
EnvClientID = envNamespace + "CLIENT_ID"
EnvClientSecret = envNamespace + "CLIENT_SECRET"
)
// Record a DNS record.
type Record struct {
ID int64 `json:"id,omitempty"`
@ -52,32 +42,18 @@ type Record struct {
Zone string `json:"zone,omitempty"`
}
// OAuth2Config the OAuth2 specific configuration.
type OAuth2Config struct {
ClientID string
ClientSecret string
}
// Config is used to configure the creation of the DNSProvider.
type Config struct {
APIEndpoint string
ApplicationKey string
ApplicationSecret string
ConsumerKey string
OAuth2Config *OAuth2Config
PropagationTimeout time.Duration
PollingInterval time.Duration
TTL int
HTTPClient *http.Client
}
func (c *Config) hasAppKeyAuth() bool {
return c.ApplicationKey != "" || c.ApplicationSecret != "" || c.ConsumerKey != ""
}
// NewDefaultConfig returns a default configuration for the DNSProvider.
func NewDefaultConfig() *Config {
return &Config{
@ -102,11 +78,17 @@ type DNSProvider struct {
// Credentials must be passed in the environment variables:
// OVH_ENDPOINT (must be either "ovh-eu" or "ovh-ca"), OVH_APPLICATION_KEY, OVH_APPLICATION_SECRET, OVH_CONSUMER_KEY.
func NewDNSProvider() (*DNSProvider, error) {
config, err := createConfigFromEnvVars()
values, err := env.Get(EnvEndpoint, EnvApplicationKey, EnvApplicationSecret, EnvConsumerKey)
if err != nil {
return nil, fmt.Errorf("ovh: %w", err)
}
config := NewDefaultConfig()
config.APIEndpoint = values[EnvEndpoint]
config.ApplicationKey = values[EnvApplicationKey]
config.ApplicationSecret = values[EnvApplicationSecret]
config.ConsumerKey = values[EnvConsumerKey]
return NewDNSProviderConfig(config)
}
@ -116,11 +98,16 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
return nil, errors.New("ovh: the configuration of the DNS provider is nil")
}
if config.OAuth2Config != nil && config.hasAppKeyAuth() {
return nil, errors.New("ovh: can't use both authentication systems (ApplicationKey and OAuth2)")
if config.APIEndpoint == "" || config.ApplicationKey == "" || config.ApplicationSecret == "" || config.ConsumerKey == "" {
return nil, errors.New("ovh: credentials missing")
}
client, err := newClient(config)
client, err := ovh.NewClient(
config.APIEndpoint,
config.ApplicationKey,
config.ApplicationSecret,
config.ConsumerKey,
)
if err != nil {
return nil, fmt.Errorf("ovh: %w", err)
}
@ -136,23 +123,19 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
// Present creates a TXT record to fulfill the dns-01 challenge.
func (d *DNSProvider) Present(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
fqdn, value := dns01.GetRecord(domain, keyAuth)
// Parse domain name
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
authZone, err := dns01.FindZoneByFqdn(dns01.ToFqdn(domain))
if err != nil {
return fmt.Errorf("ovh: could not find zone for domain %q: %w", domain, err)
return fmt.Errorf("ovh: could not determine zone for domain %q: %w", domain, err)
}
authZone = dns01.UnFqdn(authZone)
subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone)
if err != nil {
return fmt.Errorf("ovh: %w", err)
}
subDomain := extractRecordName(fqdn, authZone)
reqURL := fmt.Sprintf("/domain/zone/%s/record", authZone)
reqData := Record{FieldType: "TXT", SubDomain: subDomain, Target: info.Value, TTL: d.config.TTL}
reqData := Record{FieldType: "TXT", SubDomain: subDomain, Target: value, TTL: d.config.TTL}
// Create TXT record
var respData Record
@ -177,19 +160,19 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
// CleanUp removes the TXT record matching the specified parameters.
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
fqdn, _ := dns01.GetRecord(domain, keyAuth)
// get the record's unique ID from when we created it
d.recordIDsMu.Lock()
recordID, ok := d.recordIDs[token]
d.recordIDsMu.Unlock()
if !ok {
return fmt.Errorf("ovh: unknown record ID for '%s'", info.EffectiveFQDN)
return fmt.Errorf("ovh: unknown record ID for '%s'", fqdn)
}
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
authZone, err := dns01.FindZoneByFqdn(dns01.ToFqdn(domain))
if err != nil {
return fmt.Errorf("ovh: could not find zone for domain %q: %w", domain, err)
return fmt.Errorf("ovh: could not determine zone for domain %q: %w", domain, err)
}
authZone = dns01.UnFqdn(authZone)
@ -222,94 +205,10 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
return d.config.PropagationTimeout, d.config.PollingInterval
}
func createConfigFromEnvVars() (*Config, error) {
firstAppKeyEnvVar := findFirstValuedEnvVar(EnvApplicationKey, EnvApplicationSecret, EnvConsumerKey)
firstOAuth2EnvVar := findFirstValuedEnvVar(EnvClientID, EnvClientSecret)
if firstAppKeyEnvVar != "" && firstOAuth2EnvVar != "" {
return nil, fmt.Errorf("can't use both %s and %s at the same time", firstAppKeyEnvVar, firstOAuth2EnvVar)
func extractRecordName(fqdn, zone string) string {
name := dns01.UnFqdn(fqdn)
if idx := strings.Index(name, "."+zone); idx != -1 {
return name[:idx]
}
config := NewDefaultConfig()
if firstOAuth2EnvVar != "" {
values, err := env.Get(EnvEndpoint, EnvClientID, EnvClientSecret)
if err != nil {
return nil, err
}
config.APIEndpoint = values[EnvEndpoint]
config.OAuth2Config = &OAuth2Config{
ClientID: values[EnvClientID],
ClientSecret: values[EnvClientSecret],
}
return config, nil
}
values, err := env.Get(EnvEndpoint, EnvApplicationKey, EnvApplicationSecret, EnvConsumerKey)
if err != nil {
return nil, err
}
config.APIEndpoint = values[EnvEndpoint]
config.ApplicationKey = values[EnvApplicationKey]
config.ApplicationSecret = values[EnvApplicationSecret]
config.ConsumerKey = values[EnvConsumerKey]
return config, nil
}
func findFirstValuedEnvVar(envVars ...string) string {
for _, envVar := range envVars {
if env.GetOrFile(envVar) != "" {
return envVar
}
}
return ""
}
func newClient(config *Config) (*ovh.Client, error) {
if config.OAuth2Config == nil {
return newClientApplicationKey(config)
}
return newClientOAuth2(config)
}
func newClientApplicationKey(config *Config) (*ovh.Client, error) {
if config.APIEndpoint == "" || config.ApplicationKey == "" || config.ApplicationSecret == "" || config.ConsumerKey == "" {
return nil, errors.New("credentials are missing")
}
client, err := ovh.NewClient(
config.APIEndpoint,
config.ApplicationKey,
config.ApplicationSecret,
config.ConsumerKey,
)
if err != nil {
return nil, fmt.Errorf("new client: %w", err)
}
return client, nil
}
func newClientOAuth2(config *Config) (*ovh.Client, error) {
if config.APIEndpoint == "" || config.OAuth2Config.ClientID == "" || config.OAuth2Config.ClientSecret == "" {
return nil, errors.New("credentials are missing")
}
client, err := ovh.NewOAuth2Client(
config.APIEndpoint,
config.OAuth2Config.ClientID,
config.OAuth2Config.ClientSecret,
)
if err != nil {
return nil, fmt.Errorf("new OAuth2 client: %w", err)
}
return client, nil
return name
}

View File

@ -5,20 +5,11 @@ Code = "ovh"
Since = "v0.4.0"
Example = '''
# Application Key authentication:
OVH_APPLICATION_KEY=1234567898765432 \
OVH_APPLICATION_SECRET=b9841238feb177a84330febba8a832089 \
OVH_CONSUMER_KEY=256vfsd347245sdfg \
OVH_ENDPOINT=ovh-eu \
lego --email you@example.com --dns ovh --domains my.example.org run
# Or OAuth2:
OVH_CLIENT_ID=yyy \
OVH_CLIENT_SECRET=xxx \
OVH_ENDPOINT=ovh-eu \
lego --email you@example.com --dns ovh --domains my.example.org run
lego --email myemail@example.com --dns ovh --domains my.example.org run
'''
Additional = '''
@ -26,7 +17,7 @@ Additional = '''
Application key and secret can be created by following the [OVH guide](https://docs.ovh.com/gb/en/customer/first-steps-with-ovh-api/).
When requesting the consumer key, the following configuration can be used to define access rights:
When requesting the consumer key, the following configuration can be use to define access rights:
```json
{
@ -42,32 +33,14 @@ When requesting the consumer key, the following configuration can be used to def
]
}
```
## OAuth2 Client Credentials
Another method for authentication is by using OAuth2 client credentials.
An IAM policy and service account can be created by following the [OVH guide](https://help.ovhcloud.com/csm/en-manage-service-account?id=kb_article_view&sysparm_article=KB0059343).
Following IAM policies need to be authorized for the affected domain:
* dnsZone:apiovh:record/create
* dnsZone:apiovh:record/delete
* dnsZone:apiovh:refresh
## Important Note
Both authentication methods cannot be used at the same time.
'''
[Configuration]
[Configuration.Credentials]
OVH_ENDPOINT = "Endpoint URL (ovh-eu or ovh-ca)"
OVH_APPLICATION_KEY = "Application key (Application Key authentication)"
OVH_APPLICATION_SECRET = "Application secret (Application Key authentication)"
OVH_CONSUMER_KEY = "Consumer key (Application Key authentication)"
OVH_CLIENT_ID = "Client ID (OAuth2)"
OVH_CLIENT_SECRET = "Client secret (OAuth2)"
OVH_APPLICATION_KEY = "Application key"
OVH_APPLICATION_SECRET = "Application secret"
OVH_CONSUMER_KEY = "Consumer key"
[Configuration.Additional]
OVH_POLLING_INTERVAL = "Time between DNS propagation check"
OVH_PROPAGATION_TIMEOUT = "Maximum waiting time for DNS propagation"

View File

@ -1,228 +0,0 @@
package internal
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
"github.com/miekg/dns"
)
// Client the PowerDNS API client.
type Client struct {
serverName string
apiKey string
apiVersion int
Host *url.URL
HTTPClient *http.Client
}
// NewClient creates a new Client.
func NewClient(host *url.URL, serverName string, apiVersion int, apiKey string) *Client {
return &Client{
serverName: serverName,
apiKey: apiKey,
apiVersion: apiVersion,
Host: host,
HTTPClient: &http.Client{Timeout: 5 * time.Second},
}
}
func (c *Client) APIVersion() int {
return c.apiVersion
}
func (c *Client) SetAPIVersion(ctx context.Context) error {
var err error
c.apiVersion, err = c.getAPIVersion(ctx)
return err
}
func (c *Client) getAPIVersion(ctx context.Context) (int, error) {
endpoint := c.joinPath("/", "api")
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return 0, err
}
result, err := c.do(req)
if err != nil {
return 0, err
}
var versions []apiVersion
err = json.Unmarshal(result, &versions)
if err != nil {
return 0, err
}
latestVersion := 0
for _, v := range versions {
if v.Version > latestVersion {
latestVersion = v.Version
}
}
return latestVersion, err
}
func (c *Client) GetHostedZone(ctx context.Context, authZone string) (*HostedZone, error) {
endpoint := c.joinPath("/", "servers", c.serverName, "zones", dns.Fqdn(authZone))
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
result, err := c.do(req)
if err != nil {
return nil, err
}
var zone HostedZone
err = json.Unmarshal(result, &zone)
if err != nil {
return nil, err
}
// convert pre-v1 API result
if len(zone.Records) > 0 {
zone.RRSets = []RRSet{}
for _, record := range zone.Records {
set := RRSet{
Name: record.Name,
Type: record.Type,
Records: []Record{record},
}
zone.RRSets = append(zone.RRSets, set)
}
}
return &zone, nil
}
func (c *Client) UpdateRecords(ctx context.Context, zone *HostedZone, sets RRSets) error {
endpoint := c.joinPath("/", "servers", c.serverName, "zones", zone.ID)
req, err := newJSONRequest(ctx, http.MethodPatch, endpoint, sets)
if err != nil {
return err
}
_, err = c.do(req)
if err != nil {
return err
}
return nil
}
func (c *Client) Notify(ctx context.Context, zone *HostedZone) error {
if c.apiVersion < 1 || zone.Kind != "Master" && zone.Kind != "Slave" {
return nil
}
endpoint := c.joinPath("/", "servers", c.serverName, "zones", zone.ID, "notify")
req, err := newJSONRequest(ctx, http.MethodPut, endpoint, nil)
if err != nil {
return err
}
_, err = c.do(req)
if err != nil {
return err
}
return nil
}
func (c *Client) joinPath(elem ...string) *url.URL {
p := path.Join(elem...)
if p != "/api" && c.apiVersion > 0 && !strings.HasPrefix(p, "/api/v") {
p = path.Join("/api", "v"+strconv.Itoa(c.apiVersion), p)
}
return c.Host.JoinPath(p)
}
func (c *Client) do(req *http.Request) (json.RawMessage, error) {
req.Header.Set("X-API-Key", c.apiKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusUnprocessableEntity && (resp.StatusCode < 200 || resp.StatusCode >= 300) {
return nil, errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
var msg json.RawMessage
err = json.NewDecoder(resp.Body).Decode(&msg)
if err != nil {
if errors.Is(err, io.EOF) {
// empty body
return nil, nil
}
// other error
return nil, err
}
// check for PowerDNS error message
if len(msg) > 0 && msg[0] == '{' {
var errInfo apiError
err = json.Unmarshal(msg, &errInfo)
if err != nil {
return nil, errutils.NewUnmarshalError(req, resp.StatusCode, msg, err)
}
if errInfo.ShortMsg != "" {
return nil, fmt.Errorf("error talking to PDNS API: %w", errInfo)
}
}
return msg, nil
}
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
buf := new(bytes.Buffer)
if payload != nil {
err := json.NewEncoder(buf).Encode(payload)
if err != nil {
return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, method, strings.TrimSuffix(endpoint.String(), "/"), buf)
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
// PowerDNS doesn't follow HTTP convention about the "Content-Type" header.
if method != http.MethodGet && method != http.MethodDelete {
req.Header.Set("Content-Type", "application/json")
}
return req, nil
}

View File

@ -1,48 +0,0 @@
package internal
type Record struct {
Content string `json:"content"`
Disabled bool `json:"disabled"`
// pre-v1 API
Name string `json:"name"`
Type string `json:"type"`
TTL int `json:"ttl,omitempty"`
}
type HostedZone struct {
ID string `json:"id"`
Name string `json:"name"`
URL string `json:"url"`
Kind string `json:"kind"`
RRSets []RRSet `json:"rrsets"`
// pre-v1 API
Records []Record `json:"records"`
}
type RRSet struct {
Name string `json:"name"`
Type string `json:"type"`
Kind string `json:"kind"`
ChangeType string `json:"changetype"`
Records []Record `json:"records,omitempty"`
TTL int `json:"ttl,omitempty"`
}
type RRSets struct {
RRSets []RRSet `json:"rrsets"`
}
type apiError struct {
ShortMsg string `json:"error"`
}
func (a apiError) Error() string {
return a.ShortMsg
}
type apiVersion struct {
URL string `json:"url"`
Version int `json:"version"`
}

View File

@ -1,228 +0,0 @@
// Package pdns implements a DNS provider for solving the DNS-01 challenge using PowerDNS nameserver.
package pdns
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"time"
"github.com/go-acme/lego/v4/challenge/dns01"
"github.com/go-acme/lego/v4/log"
"github.com/go-acme/lego/v4/platform/config/env"
"github.com/go-acme/lego/v4/providers/dns/pdns/internal"
)
// Environment variables names.
const (
envNamespace = "PDNS_"
EnvAPIKey = envNamespace + "API_KEY"
EnvAPIURL = envNamespace + "API_URL"
EnvTTL = envNamespace + "TTL"
EnvAPIVersion = envNamespace + "API_VERSION"
EnvPropagationTimeout = envNamespace + "PROPAGATION_TIMEOUT"
EnvPollingInterval = envNamespace + "POLLING_INTERVAL"
EnvHTTPTimeout = envNamespace + "HTTP_TIMEOUT"
EnvServerName = envNamespace + "SERVER_NAME"
)
// Config is used to configure the creation of the DNSProvider.
type Config struct {
APIKey string
Host *url.URL
ServerName string
APIVersion int
PropagationTimeout time.Duration
PollingInterval time.Duration
TTL int
HTTPClient *http.Client
}
// NewDefaultConfig returns a default configuration for the DNSProvider.
func NewDefaultConfig() *Config {
return &Config{
ServerName: env.GetOrDefaultString(EnvServerName, "localhost"),
APIVersion: env.GetOrDefaultInt(EnvAPIVersion, 0),
TTL: env.GetOrDefaultInt(EnvTTL, dns01.DefaultTTL),
PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 120*time.Second),
PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, 2*time.Second),
HTTPClient: &http.Client{
Timeout: env.GetOrDefaultSecond(EnvHTTPTimeout, 30*time.Second),
},
}
}
// DNSProvider implements the challenge.Provider interface.
type DNSProvider struct {
config *Config
client *internal.Client
}
// NewDNSProvider returns a DNSProvider instance configured for pdns.
// Credentials must be passed in the environment variable:
// PDNS_API_URL and PDNS_API_KEY.
func NewDNSProvider() (*DNSProvider, error) {
values, err := env.Get(EnvAPIKey, EnvAPIURL)
if err != nil {
return nil, fmt.Errorf("pdns: %w", err)
}
hostURL, err := url.Parse(values[EnvAPIURL])
if err != nil {
return nil, fmt.Errorf("pdns: %w", err)
}
config := NewDefaultConfig()
config.Host = hostURL
config.APIKey = values[EnvAPIKey]
return NewDNSProviderConfig(config)
}
// NewDNSProviderConfig return a DNSProvider instance configured for pdns.
func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
if config == nil {
return nil, errors.New("pdns: the configuration of the DNS provider is nil")
}
if config.APIKey == "" {
return nil, errors.New("pdns: API key missing")
}
if config.Host == nil || config.Host.Host == "" {
return nil, errors.New("pdns: API URL missing")
}
client := internal.NewClient(config.Host, config.ServerName, config.APIVersion, config.APIKey)
if config.APIVersion <= 0 {
err := client.SetAPIVersion(context.Background())
if err != nil {
log.Warnf("pdns: failed to get API version %v", err)
}
}
return &DNSProvider{config: config, client: client}, nil
}
// Timeout returns the timeout and interval to use when checking for DNS propagation.
// Adjusting here to cope with spikes in propagation times.
func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
return d.config.PropagationTimeout, d.config.PollingInterval
}
// Present creates a TXT record to fulfill the dns-01 challenge.
func (d *DNSProvider) Present(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("pdns: could not find zone for domain %q: %w", domain, err)
}
ctx := context.Background()
zone, err := d.client.GetHostedZone(ctx, authZone)
if err != nil {
return fmt.Errorf("pdns: %w", err)
}
name := info.EffectiveFQDN
if d.client.APIVersion() == 0 {
// pre-v1 API wants non-fqdn
name = dns01.UnFqdn(info.EffectiveFQDN)
}
// Look for existing records.
existingRRSet := findTxtRecord(zone, info.EffectiveFQDN)
// merge the existing and new records
var records []internal.Record
if existingRRSet != nil {
records = existingRRSet.Records
}
rec := internal.Record{
Content: "\"" + info.Value + "\"",
Disabled: false,
// pre-v1 API
Type: "TXT",
Name: name,
TTL: d.config.TTL,
}
rrSets := internal.RRSets{
RRSets: []internal.RRSet{
{
Name: name,
ChangeType: "REPLACE",
Type: "TXT",
Kind: "Master",
TTL: d.config.TTL,
Records: append(records, rec),
},
},
}
err = d.client.UpdateRecords(ctx, zone, rrSets)
if err != nil {
return fmt.Errorf("pdns: %w", err)
}
return d.client.Notify(ctx, zone)
}
// CleanUp removes the TXT record matching the specified parameters.
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("pdns: could not find zone for domain %q: %w", domain, err)
}
ctx := context.Background()
zone, err := d.client.GetHostedZone(ctx, authZone)
if err != nil {
return fmt.Errorf("pdns: %w", err)
}
set := findTxtRecord(zone, info.EffectiveFQDN)
if set == nil {
return fmt.Errorf("pdns: no existing record found for %s", info.EffectiveFQDN)
}
rrSets := internal.RRSets{
RRSets: []internal.RRSet{
{
Name: set.Name,
Type: set.Type,
ChangeType: "DELETE",
},
},
}
err = d.client.UpdateRecords(ctx, zone, rrSets)
if err != nil {
return fmt.Errorf("pdns: %w", err)
}
return d.client.Notify(ctx, zone)
}
func findTxtRecord(zone *internal.HostedZone, fqdn string) *internal.RRSet {
for _, set := range zone.RRSets {
if set.Type == "TXT" && (set.Name == dns01.UnFqdn(fqdn) || set.Name == fqdn) {
return &set
}
}
return nil
}

View File

@ -1,37 +0,0 @@
Name = "PowerDNS"
Description = ''''''
URL = "https://www.powerdns.com/"
Code = "pdns"
Since = "v0.4.0"
Example = '''
PDNS_API_URL=http://pdns-server:80/ \
PDNS_API_KEY=xxxx \
lego --email you@example.com --dns pdns --domains my.example.org run
'''
Additional = '''
## Information
Tested and confirmed to work with PowerDNS authoritative server 3.4.8 and 4.0.1. Refer to [PowerDNS documentation](https://doc.powerdns.com/md/httpapi/README/) instructions on how to enable the built-in API interface.
PowerDNS Notes:
- PowerDNS API does not currently support SSL, therefore you should take care to ensure that traffic between lego and the PowerDNS API is over a trusted network, VPN etc.
- In order to have the SOA serial automatically increment each time the `_acme-challenge` record is added/modified via the API, set `SOA-EDIT-API` to `INCEPTION-INCREMENT` for the zone in the `domainmetadata` table
- Some PowerDNS servers doesn't have root API endpoints enabled and API version autodetection will not work. In that case version number can be defined using `PDNS_API_VERSION`.
'''
[Configuration]
[Configuration.Credentials]
PDNS_API_KEY = "API key"
PDNS_API_URL = "API URL"
[Configuration.Additional]
PDNS_SERVER_NAME = "Name of the server in the URL, 'localhost' by default"
PDNS_API_VERSION = "Skip API version autodetection and use the provided version number."
PDNS_POLLING_INTERVAL = "Time between DNS propagation check"
PDNS_PROPAGATION_TIMEOUT = "Maximum waiting time for DNS propagation"
PDNS_TTL = "The TTL of the TXT record used for the DNS challenge"
PDNS_HTTP_TIMEOUT = "API request timeout"
[Links]
API = "https://doc.powerdns.com/md/httpapi/README/"

View File

@ -9,11 +9,9 @@ import (
"github.com/go-acme/lego/v4/log"
)
const mailTo = "mailto:"
// Resource represents all important information about a registration
// of which the client needs to keep track itself.
// WARNING: will be removed in the future (acme.ExtendedAccount), https://github.com/go-acme/lego/issues/855.
// WARNING: will be remove in the future (acme.ExtendedAccount), https://github.com/go-acme/lego/issues/855.
type Resource struct {
Body acme.Account `json:"body,omitempty"`
URI string `json:"uri,omitempty"`
@ -54,7 +52,7 @@ func (r *Registrar) Register(options RegisterOptions) (*Resource, error) {
if r.user.GetEmail() != "" {
log.Infof("acme: Registering account for %s", r.user.GetEmail())
accMsg.Contact = []string{mailTo + r.user.GetEmail()}
accMsg.Contact = []string{"mailto:" + r.user.GetEmail()}
}
account, err := r.core.Accounts.New(accMsg)
@ -78,7 +76,7 @@ func (r *Registrar) RegisterWithExternalAccountBinding(options RegisterEABOption
if r.user.GetEmail() != "" {
log.Infof("acme: Registering account for %s", r.user.GetEmail())
accMsg.Contact = []string{mailTo + r.user.GetEmail()}
accMsg.Contact = []string{"mailto:" + r.user.GetEmail()}
}
account, err := r.core.Accounts.NewEAB(accMsg, options.Kid, options.HmacEncoded)
@ -130,7 +128,7 @@ func (r *Registrar) UpdateRegistration(options RegisterOptions) (*Resource, erro
if r.user.GetEmail() != "" {
log.Infof("acme: Registering account for %s", r.user.GetEmail())
accMsg.Contact = []string{mailTo + r.user.GetEmail()}
accMsg.Contact = []string{"mailto:" + r.user.GetEmail()}
}
accountURL := r.user.GetRegistration().URI

View File

@ -1,2 +0,0 @@
jose-util/jose-util
jose-util.t.err

View File

@ -1,53 +0,0 @@
# https://github.com/golangci/golangci-lint
run:
skip-files:
- doc_test.go
modules-download-mode: readonly
linters:
enable-all: true
disable:
- gochecknoglobals
- goconst
- lll
- maligned
- nakedret
- scopelint
- unparam
- funlen # added in 1.18 (requires go-jose changes before it can be enabled)
linters-settings:
gocyclo:
min-complexity: 35
issues:
exclude-rules:
- text: "don't use ALL_CAPS in Go names"
linters:
- golint
- text: "hardcoded credentials"
linters:
- gosec
- text: "weak cryptographic primitive"
linters:
- gosec
- path: json/
linters:
- dupl
- errcheck
- gocritic
- gocyclo
- golint
- govet
- ineffassign
- staticcheck
- structcheck
- stylecheck
- unused
- path: _test\.go
linters:
- scopelint
- path: jwk.go
linters:
- gocyclo

View File

@ -1,33 +0,0 @@
language: go
matrix:
fast_finish: true
allow_failures:
- go: tip
go:
- "1.13.x"
- "1.14.x"
- tip
before_script:
- export PATH=$HOME/.local/bin:$PATH
before_install:
- go get -u github.com/mattn/goveralls github.com/wadey/gocovmerge
- curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b $(go env GOPATH)/bin v1.18.0
- pip install cram --user
script:
- go test -v -covermode=count -coverprofile=profile.cov .
- go test -v -covermode=count -coverprofile=cryptosigner/profile.cov ./cryptosigner
- go test -v -covermode=count -coverprofile=cipher/profile.cov ./cipher
- go test -v -covermode=count -coverprofile=jwt/profile.cov ./jwt
- go test -v ./json # no coverage for forked encoding/json package
- golangci-lint run
- cd jose-util && go build && PATH=$PWD:$PATH cram -v jose-util.t # cram tests jose-util
- cd ..
after_success:
- gocovmerge *.cov */*.cov > merged.coverprofile
- goveralls -coverprofile merged.coverprofile -service=travis-ci

View File

@ -1,96 +0,0 @@
# v4.0.4
## Fixed
- Reverted "Allow unmarshalling JSONWebKeySets with unsupported key types" as a
breaking change. See #136 / #137.
# v4.0.3
## Changed
- Allow unmarshalling JSONWebKeySets with unsupported key types (#130)
- Document that OpaqueKeyEncrypter can't be implemented (for now) (#129)
- Dependency updates
# v4.0.2
## Changed
- Improved documentation of Verify() to note that JSONWebKeySet is a supported
argument type (#104)
- Defined exported error values for missing x5c header and unsupported elliptic
curves error cases (#117)
# v4.0.1
## Fixed
- An attacker could send a JWE containing compressed data that used large
amounts of memory and CPU when decompressed by `Decrypt` or `DecryptMulti`.
Those functions now return an error if the decompressed data would exceed
250kB or 10x the compressed size (whichever is larger). Thanks to
Enze Wang@Alioth and Jianjun Chen@Zhongguancun Lab (@zer0yu and @chenjj)
for reporting.
# v4.0.0
This release makes some breaking changes in order to more thoroughly
address the vulnerabilities discussed in [Three New Attacks Against JSON Web
Tokens][1], "Sign/encrypt confusion", "Billion hash attack", and "Polyglot
token".
## Changed
- Limit JWT encryption types (exclude password or public key types) (#78)
- Enforce minimum length for HMAC keys (#85)
- jwt: match any audience in a list, rather than requiring all audiences (#81)
- jwt: accept only Compact Serialization (#75)
- jws: Add expected algorithms for signatures (#74)
- Require specifying expected algorithms for ParseEncrypted,
ParseSigned, ParseDetached, jwt.ParseEncrypted, jwt.ParseSigned,
jwt.ParseSignedAndEncrypted (#69, #74)
- Usually there is a small, known set of appropriate algorithms for a program
to use and it's a mistake to allow unexpected algorithms. For instance the
"billion hash attack" relies in part on programs accepting the PBES2
encryption algorithm and doing the necessary work even if they weren't
specifically configured to allow PBES2.
- Revert "Strip padding off base64 strings" (#82)
- The specs require base64url encoding without padding.
- Minimum supported Go version is now 1.21
## Added
- ParseSignedCompact, ParseSignedJSON, ParseEncryptedCompact, ParseEncryptedJSON.
- These allow parsing a specific serialization, as opposed to ParseSigned and
ParseEncrypted, which try to automatically detect which serialization was
provided. It's common to require a specific serialization for a specific
protocol - for instance JWT requires Compact serialization.
[1]: https://i.blackhat.com/BH-US-23/Presentations/US-23-Tervoort-Three-New-Attacks-Against-JSON-Web-Tokens.pdf
# v3.0.2
## Fixed
- DecryptMulti: handle decompression error (#19)
## Changed
- jwe/CompactSerialize: improve performance (#67)
- Increase the default number of PBKDF2 iterations to 600k (#48)
- Return the proper algorithm for ECDSA keys (#45)
## Added
- Add Thumbprint support for opaque signers (#38)
# v3.0.1
## Fixed
- Security issue: an attacker specifying a large "p2c" value can cause
JSONWebEncryption.Decrypt and JSONWebEncryption.DecryptMulti to consume large
amounts of CPU, causing a DoS. Thanks to Matt Schwager (@mschwager) for the
disclosure and to Tom Tervoort for originally publishing the category of attack.
https://i.blackhat.com/BH-US-23/Presentations/US-23-Tervoort-Three-New-Attacks-Against-JSON-Web-Tokens.pdf

View File

@ -1,13 +0,0 @@
# Security Policy
This document explains how to contact the Let's Encrypt security team to report security vulnerabilities.
## Supported Versions
| Version | Supported |
| ------- | ----------|
| >= v3 | &check; |
| v2 | &cross; |
| v1 | &cross; |
## Reporting a vulnerability
Please see [https://letsencrypt.org/contact/#security](https://letsencrypt.org/contact/#security) for the email address to report a vulnerability. Ensure that the subject line for your report contains the word `vulnerability` and is descriptive. Your email should be acknowledged within 24 hours. If you do not receive a response within 24 hours, please follow-up again with another email.

View File

@ -48,17 +48,6 @@ linters:
- nlreturn
- testpackage
- wsl
- varnamelen
- nilnil
- ireturn
- govet
- forcetypeassert
- cyclop
- containedctx
- revive
- nosnakecase
- exhaustruct
- depguard
issues:
exclude-rules:

View File

@ -1,168 +1,3 @@
# v0.10.2 - 2023/03/20
### New features
* Support DebugDOT option for debugging encoder ( #440 )
### Fix bugs
* Fix combination of embedding structure and omitempty option ( #442 )
# v0.10.1 - 2023/03/13
### Fix bugs
* Fix checkptr error for array decoder ( #415 )
* Fix added buffer size check when decoding key ( #430 )
* Fix handling of anonymous fields other than struct ( #431 )
* Fix to not optimize when lower conversion can't handle byte-by-byte ( #432 )
* Fix a problem that MarshalIndent does not work when UnorderedMap is specified ( #435 )
* Fix mapDecoder.DecodeStream() for empty objects containing whitespace ( #425 )
* Fix an issue that could not set the correct NextField for fields in the embedded structure ( #438 )
# v0.10.0 - 2022/11/29
### New features
* Support JSON Path ( #250 )
### Fix bugs
* Fix marshaler for map's key ( #409 )
# v0.9.11 - 2022/08/18
### Fix bugs
* Fix unexpected behavior when buffer ends with backslash ( #383 )
* Fix stream decoding of escaped character ( #387 )
# v0.9.10 - 2022/07/15
### Fix bugs
* Fix boundary exception of type caching ( #382 )
# v0.9.9 - 2022/07/15
### Fix bugs
* Fix encoding of directed interface with typed nil ( #377 )
* Fix embedded primitive type encoding using alias ( #378 )
* Fix slice/array type encoding with types implementing MarshalJSON ( #379 )
* Fix unicode decoding when the expected buffer state is not met after reading ( #380 )
# v0.9.8 - 2022/06/30
### Fix bugs
* Fix decoding of surrogate-pair ( #365 )
* Fix handling of embedded primitive type ( #366 )
* Add validation of escape sequence for decoder ( #367 )
* Fix stream tokenizing respecting UseNumber ( #369 )
* Fix encoding when struct pointer type that implements Marshal JSON is embedded ( #375 )
### Improve performance
* Improve performance of linkRecursiveCode ( #368 )
# v0.9.7 - 2022/04/22
### Fix bugs
#### Encoder
* Add filtering process for encoding on slow path ( #355 )
* Fix encoding of interface{} with pointer type ( #363 )
#### Decoder
* Fix map key decoder that implements UnmarshalJSON ( #353 )
* Fix decoding of []uint8 type ( #361 )
### New features
* Add DebugWith option for encoder ( #356 )
# v0.9.6 - 2022/03/22
### Fix bugs
* Correct the handling of the minimum value of int type for decoder ( #344 )
* Fix bugs of stream decoder's bufferSize ( #349 )
* Add a guard to use typeptr more safely ( #351 )
### Improve decoder performance
* Improve escapeString's performance ( #345 )
### Others
* Update go version for CI ( #347 )
# v0.9.5 - 2022/03/04
### Fix bugs
* Fix panic when decoding time.Time with context ( #328 )
* Fix reading the next character in buffer to nul consideration ( #338 )
* Fix incorrect handling on skipValue ( #341 )
### Improve decoder performance
* Improve performance when a payload contains escape sequence ( #334 )
# v0.9.4 - 2022/01/21
* Fix IsNilForMarshaler for string type with omitempty ( #323 )
* Fix the case where the embedded field is at the end ( #326 )
# v0.9.3 - 2022/01/14
* Fix logic of removing struct field for decoder ( #322 )
# v0.9.2 - 2022/01/14
* Add invalid decoder to delay type error judgment at decode ( #321 )
# v0.9.1 - 2022/01/11
* Fix encoding of MarshalText/MarshalJSON operation with head offset ( #319 )
# v0.9.0 - 2022/01/05
### New feature
* Supports dynamic filtering of struct fields ( #314 )
### Improve encoding performance
* Improve map encoding performance ( #310 )
* Optimize encoding path for escaped string ( #311 )
* Add encoding option for performance ( #312 )
### Fix bugs
* Fix panic at encoding map value on 1.18 ( #310 )
* Fix MarshalIndent for interface type ( #317 )
# v0.8.1 - 2021/12/05
* Fix operation conversion from PtrHead to Head in Recursive type ( #305 )
# v0.8.0 - 2021/12/02
* Fix embedded field conflict behavior ( #300 )
* Refactor compiler for encoder ( #301 #302 )
# v0.7.10 - 2021/10/16
* Fix conversion from pointer to uint64 ( #294 )
# v0.7.9 - 2021/09/28
* Fix encoding of nil value about interface type that has method ( #291 )
# v0.7.8 - 2021/09/01
* Fix mapassign_faststr for indirect struct type ( #283 )

View File

@ -22,7 +22,7 @@ cover-html: cover
.PHONY: lint
lint: golangci-lint
$(BIN_DIR)/golangci-lint run
golangci-lint run
golangci-lint: | $(BIN_DIR)
@{ \
@ -30,7 +30,7 @@ golangci-lint: | $(BIN_DIR)
GOLANGCI_LINT_TMP_DIR=$$(mktemp -d); \
cd $$GOLANGCI_LINT_TMP_DIR; \
go mod init tmp; \
GOBIN=$(BIN_DIR) go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.54.2; \
GOBIN=$(BIN_DIR) go get github.com/golangci/golangci-lint/cmd/golangci-lint@v1.36.0; \
rm -rf $$GOLANGCI_LINT_TMP_DIR; \
}

View File

@ -13,7 +13,7 @@ Fast JSON encoder/decoder compatible with encoding/json for Go
```
* version ( expected release date )
* v0.9.0
* v0.7.0
|
| while maintaining compatibility with encoding/json, we will add convenient APIs
|
@ -21,8 +21,9 @@ Fast JSON encoder/decoder compatible with encoding/json for Go
* v1.0.0
```
We are accepting requests for features that will be implemented between v0.9.0 and v.1.0.0.
We are accepting requests for features that will be implemented between v0.7.0 and v.1.0.0.
If you have the API you need, please submit your issue [here](https://github.com/goccy/go-json/issues).
For example, I'm thinking of supporting `context.Context` of `json.Marshaler` and decoding using JSON Path.
# Features
@ -31,7 +32,6 @@ If you have the API you need, please submit your issue [here](https://github.com
- Flexible customization with options
- Coloring the encoded string
- Can propagate context.Context to `MarshalJSON` or `UnmarshalJSON`
- Can dynamically filter the fields of the structure type-safely
# Installation
@ -184,7 +184,7 @@ func Marshal(v interface{}) ([]byte, error) {
`json.Marshal` and `json.Unmarshal` receive `interface{}` value and they perform type determination dynamically to process.
In normal case, you need to use the `reflect` library to determine the type dynamically, but since `reflect.Type` is defined as `interface`, when you call the method of `reflect.Type`, The reflect's argument is escaped.
Therefore, the arguments for `Marshal` and `Unmarshal` are always escaped to the heap.
Therefore, the arguments for `Marshal` and `Unmarshal` are always escape to the heap.
However, `go-json` can use the feature of `reflect.Type` while avoiding escaping.
`reflect.Type` is defined as `interface`, but in reality `reflect.Type` is implemented only by the structure `rtype` defined in the `reflect` package.

View File

@ -83,37 +83,6 @@ func unmarshalContext(ctx context.Context, data []byte, v interface{}, optFuncs
return validateEndBuf(src, cursor)
}
var (
pathDecoder = decoder.NewPathDecoder()
)
func extractFromPath(path *Path, data []byte, optFuncs ...DecodeOptionFunc) ([][]byte, error) {
if path.path.RootSelectorOnly {
return [][]byte{data}, nil
}
src := make([]byte, len(data)+1) // append nul byte to the end
copy(src, data)
ctx := decoder.TakeRuntimeContext()
ctx.Buf = src
ctx.Option.Flags = 0
ctx.Option.Flags |= decoder.PathOption
ctx.Option.Path = path.path
for _, optFunc := range optFuncs {
optFunc(ctx.Option)
}
paths, cursor, err := pathDecoder.DecodePath(ctx, 0, 0)
if err != nil {
decoder.ReleaseRuntimeContext(ctx)
return nil, err
}
decoder.ReleaseRuntimeContext(ctx)
if err := validateEndBuf(src, cursor); err != nil {
return nil, err
}
return paths, nil
}
func unmarshalNoEscape(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
src := make([]byte, len(data)+1) // append nul byte to the end
copy(src, data)

View File

@ -1,7 +1,7 @@
version: '2'
services:
go-json:
image: golang:1.18
image: golang:1.16
volumes:
- '.:/go/src/go-json'
deploy:

View File

@ -3,7 +3,6 @@ package json
import (
"context"
"io"
"os"
"unsafe"
"github.com/goccy/go-json/internal/encoder"
@ -52,7 +51,7 @@ func (e *Encoder) EncodeContext(ctx context.Context, v interface{}, optFuncs ...
rctx.Option.Flag |= encoder.ContextOption
rctx.Option.Context = ctx
err := e.encodeWithOption(rctx, v, optFuncs...) //nolint: contextcheck
err := e.encodeWithOption(rctx, v, optFuncs...)
encoder.ReleaseRuntimeContext(rctx)
return err
@ -62,8 +61,6 @@ func (e *Encoder) encodeWithOption(ctx *encoder.RuntimeContext, v interface{}, o
if e.enabledHTMLEscape {
ctx.Option.Flag |= encoder.HTMLEscapeOption
}
ctx.Option.Flag |= encoder.NormalizeUTF8Option
ctx.Option.DebugOut = os.Stdout
for _, optFunc := range optFuncs {
optFunc(ctx.Option)
}
@ -114,13 +111,13 @@ func (e *Encoder) SetIndent(prefix, indent string) {
func marshalContext(ctx context.Context, v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) {
rctx := encoder.TakeRuntimeContext()
rctx.Option.Flag = 0
rctx.Option.Flag = encoder.HTMLEscapeOption | encoder.NormalizeUTF8Option | encoder.ContextOption
rctx.Option.Flag = encoder.HTMLEscapeOption | encoder.ContextOption
rctx.Option.Context = ctx
for _, optFunc := range optFuncs {
optFunc(rctx.Option)
}
buf, err := encode(rctx, v) //nolint: contextcheck
buf, err := encode(rctx, v)
if err != nil {
encoder.ReleaseRuntimeContext(rctx)
return nil, err
@ -142,7 +139,7 @@ func marshal(v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) {
ctx := encoder.TakeRuntimeContext()
ctx.Option.Flag = 0
ctx.Option.Flag |= (encoder.HTMLEscapeOption | encoder.NormalizeUTF8Option)
ctx.Option.Flag |= encoder.HTMLEscapeOption
for _, optFunc := range optFuncs {
optFunc(ctx.Option)
}
@ -169,7 +166,7 @@ func marshalNoEscape(v interface{}) ([]byte, error) {
ctx := encoder.TakeRuntimeContext()
ctx.Option.Flag = 0
ctx.Option.Flag |= (encoder.HTMLEscapeOption | encoder.NormalizeUTF8Option)
ctx.Option.Flag |= encoder.HTMLEscapeOption
buf, err := encodeNoEscape(ctx, v)
if err != nil {
@ -193,7 +190,7 @@ func marshalIndent(v interface{}, prefix, indent string, optFuncs ...EncodeOptio
ctx := encoder.TakeRuntimeContext()
ctx.Option.Flag = 0
ctx.Option.Flag |= (encoder.HTMLEscapeOption | encoder.NormalizeUTF8Option | encoder.IndentOption)
ctx.Option.Flag |= (encoder.HTMLEscapeOption | encoder.IndentOption)
for _, optFunc := range optFuncs {
optFunc(ctx.Option)
}
@ -223,7 +220,7 @@ func encode(ctx *encoder.RuntimeContext, v interface{}) ([]byte, error) {
typ := header.typ
typeptr := uintptr(unsafe.Pointer(typ))
codeSet, err := encoder.CompileToGetCodeSet(ctx, typeptr)
codeSet, err := encoder.CompileToGetCodeSet(typeptr)
if err != nil {
return nil, err
}
@ -251,7 +248,7 @@ func encodeNoEscape(ctx *encoder.RuntimeContext, v interface{}) ([]byte, error)
typ := header.typ
typeptr := uintptr(unsafe.Pointer(typ))
codeSet, err := encoder.CompileToGetCodeSet(ctx, typeptr)
codeSet, err := encoder.CompileToGetCodeSet(typeptr)
if err != nil {
return nil, err
}
@ -278,7 +275,7 @@ func encodeIndent(ctx *encoder.RuntimeContext, v interface{}, prefix, indent str
typ := header.typ
typeptr := uintptr(unsafe.Pointer(typ))
codeSet, err := encoder.CompileToGetCodeSet(ctx, typeptr)
codeSet, err := encoder.CompileToGetCodeSet(typeptr)
if err != nil {
return nil, err
}

View File

@ -37,5 +37,3 @@ type UnmarshalTypeError = errors.UnmarshalTypeError
type UnsupportedTypeError = errors.UnsupportedTypeError
type UnsupportedValueError = errors.UnsupportedValueError
type PathError = errors.PathError

View File

@ -35,7 +35,3 @@ func (d *anonymousFieldDecoder) Decode(ctx *RuntimeContext, cursor, depth int64,
p = *(*unsafe.Pointer)(p)
return d.dec.Decode(ctx, cursor, depth, unsafe.Pointer(uintptr(p)+d.offset))
}
func (d *anonymousFieldDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return d.dec.DecodePath(ctx, cursor, depth)
}

View File

@ -1,7 +1,6 @@
package decoder
import (
"fmt"
"unsafe"
"github.com/goccy/go-json/internal/errors"
@ -19,9 +18,7 @@ type arrayDecoder struct {
}
func newArrayDecoder(dec Decoder, elemType *runtime.Type, alen int, structName, fieldName string) *arrayDecoder {
// workaround to avoid checkptr errors. cannot use `*(*unsafe.Pointer)(unsafe_New(elemType))` directly.
zeroValuePtr := unsafe_New(elemType)
zeroValue := **(**unsafe.Pointer)(unsafe.Pointer(&zeroValuePtr))
zeroValue := *(*unsafe.Pointer)(unsafe_New(elemType))
return &arrayDecoder{
valueDecoder: dec,
elemType: elemType,
@ -170,7 +167,3 @@ func (d *arrayDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe
}
}
}
func (d *arrayDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: array decoder does not support decode path")
}

View File

@ -1,438 +0,0 @@
package decoder
import (
"fmt"
"reflect"
"strconv"
)
var (
nilValue = reflect.ValueOf(nil)
)
func AssignValue(src, dst reflect.Value) error {
if dst.Type().Kind() != reflect.Ptr {
return fmt.Errorf("invalid dst type. required pointer type: %T", dst.Type())
}
casted, err := castValue(dst.Elem().Type(), src)
if err != nil {
return err
}
dst.Elem().Set(casted)
return nil
}
func castValue(t reflect.Type, v reflect.Value) (reflect.Value, error) {
switch t.Kind() {
case reflect.Int:
vv, err := castInt(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(int(vv.Int())), nil
case reflect.Int8:
vv, err := castInt(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(int8(vv.Int())), nil
case reflect.Int16:
vv, err := castInt(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(int16(vv.Int())), nil
case reflect.Int32:
vv, err := castInt(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(int32(vv.Int())), nil
case reflect.Int64:
return castInt(v)
case reflect.Uint:
vv, err := castUint(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(uint(vv.Uint())), nil
case reflect.Uint8:
vv, err := castUint(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(uint8(vv.Uint())), nil
case reflect.Uint16:
vv, err := castUint(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(uint16(vv.Uint())), nil
case reflect.Uint32:
vv, err := castUint(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(uint32(vv.Uint())), nil
case reflect.Uint64:
return castUint(v)
case reflect.Uintptr:
vv, err := castUint(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(uintptr(vv.Uint())), nil
case reflect.String:
return castString(v)
case reflect.Bool:
return castBool(v)
case reflect.Float32:
vv, err := castFloat(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(float32(vv.Float())), nil
case reflect.Float64:
return castFloat(v)
case reflect.Array:
return castArray(t, v)
case reflect.Slice:
return castSlice(t, v)
case reflect.Map:
return castMap(t, v)
case reflect.Struct:
return castStruct(t, v)
}
return v, nil
}
func castInt(v reflect.Value) (reflect.Value, error) {
switch v.Type().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v, nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return reflect.ValueOf(int64(v.Uint())), nil
case reflect.String:
i64, err := strconv.ParseInt(v.String(), 10, 64)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(i64), nil
case reflect.Bool:
if v.Bool() {
return reflect.ValueOf(int64(1)), nil
}
return reflect.ValueOf(int64(0)), nil
case reflect.Float32, reflect.Float64:
return reflect.ValueOf(int64(v.Float())), nil
case reflect.Array:
if v.Len() > 0 {
return castInt(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to int64 from empty array")
case reflect.Slice:
if v.Len() > 0 {
return castInt(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to int64 from empty slice")
case reflect.Interface:
return castInt(reflect.ValueOf(v.Interface()))
case reflect.Map:
return nilValue, fmt.Errorf("failed to cast to int64 from map")
case reflect.Struct:
return nilValue, fmt.Errorf("failed to cast to int64 from struct")
case reflect.Ptr:
return castInt(v.Elem())
}
return nilValue, fmt.Errorf("failed to cast to int64 from %s", v.Type().Kind())
}
func castUint(v reflect.Value) (reflect.Value, error) {
switch v.Type().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return reflect.ValueOf(uint64(v.Int())), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v, nil
case reflect.String:
u64, err := strconv.ParseUint(v.String(), 10, 64)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(u64), nil
case reflect.Bool:
if v.Bool() {
return reflect.ValueOf(uint64(1)), nil
}
return reflect.ValueOf(uint64(0)), nil
case reflect.Float32, reflect.Float64:
return reflect.ValueOf(uint64(v.Float())), nil
case reflect.Array:
if v.Len() > 0 {
return castUint(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to uint64 from empty array")
case reflect.Slice:
if v.Len() > 0 {
return castUint(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to uint64 from empty slice")
case reflect.Interface:
return castUint(reflect.ValueOf(v.Interface()))
case reflect.Map:
return nilValue, fmt.Errorf("failed to cast to uint64 from map")
case reflect.Struct:
return nilValue, fmt.Errorf("failed to cast to uint64 from struct")
case reflect.Ptr:
return castUint(v.Elem())
}
return nilValue, fmt.Errorf("failed to cast to uint64 from %s", v.Type().Kind())
}
func castString(v reflect.Value) (reflect.Value, error) {
switch v.Type().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return reflect.ValueOf(fmt.Sprint(v.Int())), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return reflect.ValueOf(fmt.Sprint(v.Uint())), nil
case reflect.String:
return v, nil
case reflect.Bool:
if v.Bool() {
return reflect.ValueOf("true"), nil
}
return reflect.ValueOf("false"), nil
case reflect.Float32, reflect.Float64:
return reflect.ValueOf(fmt.Sprint(v.Float())), nil
case reflect.Array:
if v.Len() > 0 {
return castString(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to string from empty array")
case reflect.Slice:
if v.Len() > 0 {
return castString(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to string from empty slice")
case reflect.Interface:
return castString(reflect.ValueOf(v.Interface()))
case reflect.Map:
return nilValue, fmt.Errorf("failed to cast to string from map")
case reflect.Struct:
return nilValue, fmt.Errorf("failed to cast to string from struct")
case reflect.Ptr:
return castString(v.Elem())
}
return nilValue, fmt.Errorf("failed to cast to string from %s", v.Type().Kind())
}
func castBool(v reflect.Value) (reflect.Value, error) {
switch v.Type().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch v.Int() {
case 0:
return reflect.ValueOf(false), nil
case 1:
return reflect.ValueOf(true), nil
}
return nilValue, fmt.Errorf("failed to cast to bool from %d", v.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
switch v.Uint() {
case 0:
return reflect.ValueOf(false), nil
case 1:
return reflect.ValueOf(true), nil
}
return nilValue, fmt.Errorf("failed to cast to bool from %d", v.Uint())
case reflect.String:
b, err := strconv.ParseBool(v.String())
if err != nil {
return nilValue, err
}
return reflect.ValueOf(b), nil
case reflect.Bool:
return v, nil
case reflect.Float32, reflect.Float64:
switch v.Float() {
case 0:
return reflect.ValueOf(false), nil
case 1:
return reflect.ValueOf(true), nil
}
return nilValue, fmt.Errorf("failed to cast to bool from %f", v.Float())
case reflect.Array:
if v.Len() > 0 {
return castBool(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to string from empty array")
case reflect.Slice:
if v.Len() > 0 {
return castBool(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to string from empty slice")
case reflect.Interface:
return castBool(reflect.ValueOf(v.Interface()))
case reflect.Map:
return nilValue, fmt.Errorf("failed to cast to string from map")
case reflect.Struct:
return nilValue, fmt.Errorf("failed to cast to string from struct")
case reflect.Ptr:
return castBool(v.Elem())
}
return nilValue, fmt.Errorf("failed to cast to bool from %s", v.Type().Kind())
}
func castFloat(v reflect.Value) (reflect.Value, error) {
switch v.Type().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return reflect.ValueOf(float64(v.Int())), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return reflect.ValueOf(float64(v.Uint())), nil
case reflect.String:
f64, err := strconv.ParseFloat(v.String(), 64)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(f64), nil
case reflect.Bool:
if v.Bool() {
return reflect.ValueOf(float64(1)), nil
}
return reflect.ValueOf(float64(0)), nil
case reflect.Float32, reflect.Float64:
return v, nil
case reflect.Array:
if v.Len() > 0 {
return castFloat(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to float64 from empty array")
case reflect.Slice:
if v.Len() > 0 {
return castFloat(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to float64 from empty slice")
case reflect.Interface:
return castFloat(reflect.ValueOf(v.Interface()))
case reflect.Map:
return nilValue, fmt.Errorf("failed to cast to float64 from map")
case reflect.Struct:
return nilValue, fmt.Errorf("failed to cast to float64 from struct")
case reflect.Ptr:
return castFloat(v.Elem())
}
return nilValue, fmt.Errorf("failed to cast to float64 from %s", v.Type().Kind())
}
func castArray(t reflect.Type, v reflect.Value) (reflect.Value, error) {
kind := v.Type().Kind()
if kind == reflect.Interface {
return castArray(t, reflect.ValueOf(v.Interface()))
}
if kind != reflect.Slice && kind != reflect.Array {
return nilValue, fmt.Errorf("failed to cast to array from %s", kind)
}
if t.Elem() == v.Type().Elem() {
return v, nil
}
if t.Len() != v.Len() {
return nilValue, fmt.Errorf("failed to cast [%d]array from slice of %d length", t.Len(), v.Len())
}
ret := reflect.New(t).Elem()
for i := 0; i < v.Len(); i++ {
vv, err := castValue(t.Elem(), v.Index(i))
if err != nil {
return nilValue, err
}
ret.Index(i).Set(vv)
}
return ret, nil
}
func castSlice(t reflect.Type, v reflect.Value) (reflect.Value, error) {
kind := v.Type().Kind()
if kind == reflect.Interface {
return castSlice(t, reflect.ValueOf(v.Interface()))
}
if kind != reflect.Slice && kind != reflect.Array {
return nilValue, fmt.Errorf("failed to cast to slice from %s", kind)
}
if t.Elem() == v.Type().Elem() {
return v, nil
}
ret := reflect.MakeSlice(t, v.Len(), v.Len())
for i := 0; i < v.Len(); i++ {
vv, err := castValue(t.Elem(), v.Index(i))
if err != nil {
return nilValue, err
}
ret.Index(i).Set(vv)
}
return ret, nil
}
func castMap(t reflect.Type, v reflect.Value) (reflect.Value, error) {
ret := reflect.MakeMap(t)
switch v.Type().Kind() {
case reflect.Map:
iter := v.MapRange()
for iter.Next() {
key, err := castValue(t.Key(), iter.Key())
if err != nil {
return nilValue, err
}
value, err := castValue(t.Elem(), iter.Value())
if err != nil {
return nilValue, err
}
ret.SetMapIndex(key, value)
}
return ret, nil
case reflect.Interface:
return castMap(t, reflect.ValueOf(v.Interface()))
case reflect.Slice:
if v.Len() > 0 {
return castMap(t, v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to map from empty slice")
}
return nilValue, fmt.Errorf("failed to cast to map from %s", v.Type().Kind())
}
func castStruct(t reflect.Type, v reflect.Value) (reflect.Value, error) {
ret := reflect.New(t).Elem()
switch v.Type().Kind() {
case reflect.Map:
iter := v.MapRange()
for iter.Next() {
key := iter.Key()
k, err := castString(key)
if err != nil {
return nilValue, err
}
fieldName := k.String()
field, ok := t.FieldByName(fieldName)
if ok {
value, err := castValue(field.Type, iter.Value())
if err != nil {
return nilValue, err
}
ret.FieldByName(fieldName).Set(value)
}
}
return ret, nil
case reflect.Struct:
for i := 0; i < v.Type().NumField(); i++ {
name := v.Type().Field(i).Name
ret.FieldByName(name).Set(v.FieldByName(name))
}
return ret, nil
case reflect.Interface:
return castStruct(t, reflect.ValueOf(v.Interface()))
case reflect.Slice:
if v.Len() > 0 {
return castStruct(t, v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to struct from empty slice")
default:
return nilValue, fmt.Errorf("failed to cast to struct from %s", v.Type().Kind())
}
}

View File

@ -1,7 +1,6 @@
package decoder
import (
"fmt"
"unsafe"
"github.com/goccy/go-json/internal/errors"
@ -77,7 +76,3 @@ func (d *boolDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.
}
return 0, errors.ErrUnexpectedEndOfJSON("bool", cursor)
}
func (d *boolDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: bool decoder does not support decode path")
}

View File

@ -2,7 +2,6 @@ package decoder
import (
"encoding/base64"
"fmt"
"unsafe"
"github.com/goccy/go-json/internal/errors"
@ -24,8 +23,9 @@ func byteUnmarshalerSliceDecoder(typ *runtime.Type, structName string, fieldName
unmarshalDecoder = newUnmarshalJSONDecoder(runtime.PtrTo(typ), structName, fieldName)
case runtime.PtrTo(typ).Implements(unmarshalTextType):
unmarshalDecoder = newUnmarshalTextDecoder(runtime.PtrTo(typ), structName, fieldName)
default:
unmarshalDecoder, _ = compileUint8(typ, structName, fieldName)
}
if unmarshalDecoder == nil {
return nil
}
return newSliceDecoder(unmarshalDecoder, typ, 1, structName, fieldName)
}
@ -79,10 +79,6 @@ func (d *bytesDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe
return cursor, nil
}
func (d *bytesDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: []byte decoder does not support decode path")
}
func (d *bytesDecoder) decodeStreamBinary(s *Stream, depth int64, p unsafe.Pointer) ([]byte, error) {
c := s.skipWhiteSpace()
if c == '[' {

View File

@ -9,6 +9,7 @@ import (
"unicode"
"unsafe"
"github.com/goccy/go-json/internal/errors"
"github.com/goccy/go-json/internal/runtime"
)
@ -24,7 +25,7 @@ func init() {
if typeAddr == nil {
typeAddr = &runtime.TypeAddr{}
}
cachedDecoder = make([]Decoder, typeAddr.AddrRange>>typeAddr.AddrShift+1)
cachedDecoder = make([]Decoder, typeAddr.AddrRange>>typeAddr.AddrShift)
}
func loadDecoderMap() map[uintptr]Decoder {
@ -125,7 +126,13 @@ func compile(typ *runtime.Type, structName, fieldName string, structTypeToDecode
case reflect.Func:
return compileFunc(typ, structName, fieldName)
}
return newInvalidDecoder(typ, structName, fieldName), nil
return nil, &errors.UnmarshalTypeError{
Value: "object",
Type: runtime.RType2Type(typ),
Offset: 0,
Struct: structName,
Field: fieldName,
}
}
func isStringTagSupportedType(typ *runtime.Type) bool {
@ -154,9 +161,6 @@ func compileMapKey(typ *runtime.Type, structName, fieldName string, structTypeTo
if runtime.PtrTo(typ).Implements(unmarshalTextType) {
return newUnmarshalTextDecoder(runtime.PtrTo(typ), structName, fieldName), nil
}
if typ.Kind() == reflect.String {
return newStringDecoder(structName, fieldName), nil
}
dec, err := compile(typ, structName, fieldName, structTypeToDecoder)
if err != nil {
return nil, err
@ -170,9 +174,17 @@ func compileMapKey(typ *runtime.Type, structName, fieldName string, structTypeTo
case *ptrDecoder:
dec = t.dec
default:
return newInvalidDecoder(typ, structName, fieldName), nil
goto ERROR
}
}
ERROR:
return nil, &errors.UnmarshalTypeError{
Value: "object",
Type: runtime.RType2Type(typ),
Offset: 0,
Struct: structName,
Field: fieldName,
}
}
func compilePtr(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
@ -310,21 +322,64 @@ func compileFunc(typ *runtime.Type, strutName, fieldName string) (Decoder, error
return newFuncDecoder(typ, strutName, fieldName), nil
}
func typeToStructTags(typ *runtime.Type) runtime.StructTags {
tags := runtime.StructTags{}
fieldNum := typ.NumField()
for i := 0; i < fieldNum; i++ {
field := typ.Field(i)
if runtime.IsIgnoredStructField(field) {
func removeConflictFields(fieldMap map[string]*structFieldSet, conflictedMap map[string]struct{}, dec *structDecoder, field reflect.StructField) {
for k, v := range dec.fieldMap {
if _, exists := conflictedMap[k]; exists {
// already conflicted key
continue
}
tags = append(tags, runtime.StructTagFromField(field))
set, exists := fieldMap[k]
if !exists {
fieldSet := &structFieldSet{
dec: v.dec,
offset: field.Offset + v.offset,
isTaggedKey: v.isTaggedKey,
key: k,
keyLen: int64(len(k)),
}
fieldMap[k] = fieldSet
lower := strings.ToLower(k)
if _, exists := fieldMap[lower]; !exists {
fieldMap[lower] = fieldSet
}
continue
}
if set.isTaggedKey {
if v.isTaggedKey {
// conflict tag key
delete(fieldMap, k)
delete(fieldMap, strings.ToLower(k))
conflictedMap[k] = struct{}{}
conflictedMap[strings.ToLower(k)] = struct{}{}
}
} else {
if v.isTaggedKey {
fieldSet := &structFieldSet{
dec: v.dec,
offset: field.Offset + v.offset,
isTaggedKey: v.isTaggedKey,
key: k,
keyLen: int64(len(k)),
}
fieldMap[k] = fieldSet
lower := strings.ToLower(k)
if _, exists := fieldMap[lower]; !exists {
fieldMap[lower] = fieldSet
}
} else {
// conflict tag key
delete(fieldMap, k)
delete(fieldMap, strings.ToLower(k))
conflictedMap[k] = struct{}{}
conflictedMap[strings.ToLower(k)] = struct{}{}
}
}
}
return tags
}
func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
fieldNum := typ.NumField()
conflictedMap := map[string]struct{}{}
fieldMap := map[string]*structFieldSet{}
typeptr := uintptr(unsafe.Pointer(typ))
if dec, exists := structTypeToDecoder[typeptr]; exists {
@ -333,8 +388,6 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
structDec := newStructDecoder(structName, fieldName, fieldMap)
structTypeToDecoder[typeptr] = structDec
structName = typ.Name()
tags := typeToStructTags(typ)
allFields := []*structFieldSet{}
for i := 0; i < fieldNum; i++ {
field := typ.Field(i)
if runtime.IsIgnoredStructField(field) {
@ -352,19 +405,7 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
// recursive definition
continue
}
for k, v := range stDec.fieldMap {
if tags.ExistsKey(k) {
continue
}
fieldSet := &structFieldSet{
dec: v.dec,
offset: field.Offset + v.offset,
isTaggedKey: v.isTaggedKey,
key: k,
keyLen: int64(len(k)),
}
allFields = append(allFields, fieldSet)
}
removeConflictFields(fieldMap, conflictedMap, stDec, field)
} else if pdec, ok := dec.(*ptrDecoder); ok {
contentDec := pdec.contentDecoder()
if pdec.typ == typ {
@ -380,9 +421,12 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
}
if dec, ok := contentDec.(*structDecoder); ok {
for k, v := range dec.fieldMap {
if tags.ExistsKey(k) {
if _, exists := conflictedMap[k]; exists {
// already conflicted key
continue
}
set, exists := fieldMap[k]
if !exists {
fieldSet := &structFieldSet{
dec: newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec),
offset: field.Offset,
@ -391,27 +435,46 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
keyLen: int64(len(k)),
err: fieldSetErr,
}
allFields = append(allFields, fieldSet)
fieldMap[k] = fieldSet
lower := strings.ToLower(k)
if _, exists := fieldMap[lower]; !exists {
fieldMap[lower] = fieldSet
}
continue
}
if set.isTaggedKey {
if v.isTaggedKey {
// conflict tag key
delete(fieldMap, k)
delete(fieldMap, strings.ToLower(k))
conflictedMap[k] = struct{}{}
conflictedMap[strings.ToLower(k)] = struct{}{}
}
} else {
if v.isTaggedKey {
fieldSet := &structFieldSet{
dec: pdec,
dec: newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec),
offset: field.Offset,
isTaggedKey: tag.IsTaggedKey,
key: field.Name,
keyLen: int64(len(field.Name)),
isTaggedKey: v.isTaggedKey,
key: k,
keyLen: int64(len(k)),
err: fieldSetErr,
}
allFields = append(allFields, fieldSet)
fieldMap[k] = fieldSet
lower := strings.ToLower(k)
if _, exists := fieldMap[lower]; !exists {
fieldMap[lower] = fieldSet
}
} else {
fieldSet := &structFieldSet{
dec: dec,
offset: field.Offset,
isTaggedKey: tag.IsTaggedKey,
key: field.Name,
keyLen: int64(len(field.Name)),
// conflict tag key
delete(fieldMap, k)
delete(fieldMap, strings.ToLower(k))
conflictedMap[k] = struct{}{}
conflictedMap[strings.ToLower(k)] = struct{}{}
}
}
}
}
allFields = append(allFields, fieldSet)
}
} else {
if tag.IsString && isStringTagSupportedType(runtime.Type2RType(field.Type)) {
@ -430,15 +493,11 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
key: key,
keyLen: int64(len(key)),
}
allFields = append(allFields, fieldSet)
}
}
for _, set := range filterDuplicatedFields(allFields) {
fieldMap[set.key] = set
lower := strings.ToLower(set.key)
fieldMap[key] = fieldSet
lower := strings.ToLower(key)
if _, exists := fieldMap[lower]; !exists {
// first win
fieldMap[lower] = set
fieldMap[lower] = fieldSet
}
}
}
delete(structTypeToDecoder, typeptr)
@ -446,42 +505,6 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
return structDec, nil
}
func filterDuplicatedFields(allFields []*structFieldSet) []*structFieldSet {
fieldMap := map[string][]*structFieldSet{}
for _, field := range allFields {
fieldMap[field.key] = append(fieldMap[field.key], field)
}
duplicatedFieldMap := map[string]struct{}{}
for k, sets := range fieldMap {
sets = filterFieldSets(sets)
if len(sets) != 1 {
duplicatedFieldMap[k] = struct{}{}
}
}
filtered := make([]*structFieldSet, 0, len(allFields))
for _, field := range allFields {
if _, exists := duplicatedFieldMap[field.key]; exists {
continue
}
filtered = append(filtered, field)
}
return filtered
}
func filterFieldSets(sets []*structFieldSet) []*structFieldSet {
if len(sets) == 1 {
return sets
}
filtered := make([]*structFieldSet, 0, len(sets))
for _, set := range sets {
if set.isTaggedKey {
filtered = append(filtered, set)
}
}
return filtered
}
func implementsUnmarshalJSONType(typ *runtime.Type) bool {
return typ.Implements(unmarshalJSONType) || typ.Implements(unmarshalJSONContextType)
}

View File

@ -1,4 +1,3 @@
//go:build !race
// +build !race
package decoder

View File

@ -1,4 +1,3 @@
//go:build race
// +build race
package decoder

View File

@ -156,15 +156,3 @@ func (d *floatDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe
d.op(p, f64)
return cursor, nil
}
func (d *floatDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
buf := ctx.Buf
bytes, c, err := d.decodeByte(buf, cursor)
if err != nil {
return nil, 0, err
}
if bytes == nil {
return [][]byte{nullbytes}, c, nil
}
return [][]byte{bytes}, c, nil
}

View File

@ -2,7 +2,6 @@ package decoder
import (
"bytes"
"fmt"
"unsafe"
"github.com/goccy/go-json/internal/errors"
@ -140,7 +139,3 @@ func (d *funcDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.
}
return cursor, errors.ErrInvalidBeginningOfValue(buf[cursor], cursor)
}
func (d *funcDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: func decoder does not support decode path")
}

View File

@ -192,15 +192,15 @@ func (d *intDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) erro
}
switch d.kind {
case reflect.Int8:
if i64 < -1*(1<<7) || (1<<7) <= i64 {
if i64 <= -1*(1<<7) || (1<<7) <= i64 {
return d.typeError(bytes, s.totalOffset())
}
case reflect.Int16:
if i64 < -1*(1<<15) || (1<<15) <= i64 {
if i64 <= -1*(1<<15) || (1<<15) <= i64 {
return d.typeError(bytes, s.totalOffset())
}
case reflect.Int32:
if i64 < -1*(1<<31) || (1<<31) <= i64 {
if i64 <= -1*(1<<31) || (1<<31) <= i64 {
return d.typeError(bytes, s.totalOffset())
}
}
@ -225,22 +225,18 @@ func (d *intDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.P
}
switch d.kind {
case reflect.Int8:
if i64 < -1*(1<<7) || (1<<7) <= i64 {
if i64 <= -1*(1<<7) || (1<<7) <= i64 {
return 0, d.typeError(bytes, cursor)
}
case reflect.Int16:
if i64 < -1*(1<<15) || (1<<15) <= i64 {
if i64 <= -1*(1<<15) || (1<<15) <= i64 {
return 0, d.typeError(bytes, cursor)
}
case reflect.Int32:
if i64 < -1*(1<<31) || (1<<31) <= i64 {
if i64 <= -1*(1<<31) || (1<<31) <= i64 {
return 0, d.typeError(bytes, cursor)
}
}
d.op(p, i64)
return cursor, nil
}
func (d *intDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: int decoder does not support decode path")
}

View File

@ -94,7 +94,6 @@ func (d *interfaceDecoder) numDecoder(s *Stream) Decoder {
var (
emptyInterfaceType = runtime.Type2RType(reflect.TypeOf((*interface{})(nil)).Elem())
EmptyInterfaceType = emptyInterfaceType
interfaceMapType = runtime.Type2RType(
reflect.TypeOf((*map[string]interface{})(nil)).Elem(),
)
@ -457,72 +456,3 @@ func (d *interfaceDecoder) decodeEmptyInterface(ctx *RuntimeContext, cursor, dep
}
return cursor, errors.ErrInvalidBeginningOfValue(buf[cursor], cursor)
}
func NewPathDecoder() Decoder {
ifaceDecoder := &interfaceDecoder{
typ: emptyInterfaceType,
structName: "",
fieldName: "",
floatDecoder: newFloatDecoder("", "", func(p unsafe.Pointer, v float64) {
*(*interface{})(p) = v
}),
numberDecoder: newNumberDecoder("", "", func(p unsafe.Pointer, v json.Number) {
*(*interface{})(p) = v
}),
stringDecoder: newStringDecoder("", ""),
}
ifaceDecoder.sliceDecoder = newSliceDecoder(
ifaceDecoder,
emptyInterfaceType,
emptyInterfaceType.Size(),
"", "",
)
ifaceDecoder.mapDecoder = newMapDecoder(
interfaceMapType,
stringType,
ifaceDecoder.stringDecoder,
interfaceMapType.Elem(),
ifaceDecoder,
"", "",
)
return ifaceDecoder
}
var (
truebytes = []byte("true")
falsebytes = []byte("false")
)
func (d *interfaceDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
buf := ctx.Buf
cursor = skipWhiteSpace(buf, cursor)
switch buf[cursor] {
case '{':
return d.mapDecoder.DecodePath(ctx, cursor, depth)
case '[':
return d.sliceDecoder.DecodePath(ctx, cursor, depth)
case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
return d.floatDecoder.DecodePath(ctx, cursor, depth)
case '"':
return d.stringDecoder.DecodePath(ctx, cursor, depth)
case 't':
if err := validateTrue(buf, cursor); err != nil {
return nil, 0, err
}
cursor += 4
return [][]byte{truebytes}, cursor, nil
case 'f':
if err := validateFalse(buf, cursor); err != nil {
return nil, 0, err
}
cursor += 5
return [][]byte{falsebytes}, cursor, nil
case 'n':
if err := validateNull(buf, cursor); err != nil {
return nil, 0, err
}
cursor += 4
return [][]byte{nullbytes}, cursor, nil
}
return nil, cursor, errors.ErrInvalidBeginningOfValue(buf[cursor], cursor)
}

View File

@ -1,55 +0,0 @@
package decoder
import (
"reflect"
"unsafe"
"github.com/goccy/go-json/internal/errors"
"github.com/goccy/go-json/internal/runtime"
)
type invalidDecoder struct {
typ *runtime.Type
kind reflect.Kind
structName string
fieldName string
}
func newInvalidDecoder(typ *runtime.Type, structName, fieldName string) *invalidDecoder {
return &invalidDecoder{
typ: typ,
kind: typ.Kind(),
structName: structName,
fieldName: fieldName,
}
}
func (d *invalidDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) error {
return &errors.UnmarshalTypeError{
Value: "object",
Type: runtime.RType2Type(d.typ),
Offset: s.totalOffset(),
Struct: d.structName,
Field: d.fieldName,
}
}
func (d *invalidDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.Pointer) (int64, error) {
return 0, &errors.UnmarshalTypeError{
Value: "object",
Type: runtime.RType2Type(d.typ),
Offset: cursor,
Struct: d.structName,
Field: d.fieldName,
}
}
func (d *invalidDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, &errors.UnmarshalTypeError{
Value: "object",
Type: runtime.RType2Type(d.typ),
Offset: cursor,
Struct: d.structName,
Field: d.fieldName,
}
}

View File

@ -87,13 +87,13 @@ func (d *mapDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) erro
if mapValue == nil {
mapValue = makemap(d.mapType, 0)
}
s.cursor++
if s.skipWhiteSpace() == '}' {
if s.buf[s.cursor+1] == '}' {
*(*unsafe.Pointer)(p) = mapValue
s.cursor++
s.cursor += 2
return nil
}
for {
s.cursor++
k := unsafe_New(d.keyType)
if err := d.keyDecoder.DecodeStream(s, depth, k); err != nil {
return err
@ -117,7 +117,6 @@ func (d *mapDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) erro
if !s.equalChar(',') {
return errors.ErrExpected("comma after object value", s.totalOffset())
}
s.cursor++
}
}
@ -185,96 +184,3 @@ func (d *mapDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.P
cursor++
}
}
func (d *mapDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
buf := ctx.Buf
depth++
if depth > maxDecodeNestingDepth {
return nil, 0, errors.ErrExceededMaxDepth(buf[cursor], cursor)
}
cursor = skipWhiteSpace(buf, cursor)
buflen := int64(len(buf))
if buflen < 2 {
return nil, 0, errors.ErrExpected("{} for map", cursor)
}
switch buf[cursor] {
case 'n':
if err := validateNull(buf, cursor); err != nil {
return nil, 0, err
}
cursor += 4
return [][]byte{nullbytes}, cursor, nil
case '{':
default:
return nil, 0, errors.ErrExpected("{ character for map value", cursor)
}
cursor++
cursor = skipWhiteSpace(buf, cursor)
if buf[cursor] == '}' {
cursor++
return nil, cursor, nil
}
keyDecoder, ok := d.keyDecoder.(*stringDecoder)
if !ok {
return nil, 0, &errors.UnmarshalTypeError{
Value: "string",
Type: reflect.TypeOf(""),
Offset: cursor,
Struct: d.structName,
Field: d.fieldName,
}
}
ret := [][]byte{}
for {
key, keyCursor, err := keyDecoder.decodeByte(buf, cursor)
if err != nil {
return nil, 0, err
}
cursor = skipWhiteSpace(buf, keyCursor)
if buf[cursor] != ':' {
return nil, 0, errors.ErrExpected("colon after object key", cursor)
}
cursor++
child, found, err := ctx.Option.Path.Field(string(key))
if err != nil {
return nil, 0, err
}
if found {
if child != nil {
oldPath := ctx.Option.Path.node
ctx.Option.Path.node = child
paths, c, err := d.valueDecoder.DecodePath(ctx, cursor, depth)
if err != nil {
return nil, 0, err
}
ctx.Option.Path.node = oldPath
ret = append(ret, paths...)
cursor = c
} else {
start := cursor
end, err := skipValue(buf, cursor, depth)
if err != nil {
return nil, 0, err
}
ret = append(ret, buf[start:end])
cursor = end
}
} else {
c, err := skipValue(buf, cursor, depth)
if err != nil {
return nil, 0, err
}
cursor = c
}
cursor = skipWhiteSpace(buf, cursor)
if buf[cursor] == '}' {
cursor++
return ret, cursor, nil
}
if buf[cursor] != ',' {
return nil, 0, errors.ErrExpected("comma after object value", cursor)
}
cursor++
}
}

View File

@ -51,17 +51,6 @@ func (d *numberDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsaf
return cursor, nil
}
func (d *numberDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
bytes, c, err := d.decodeByte(ctx.Buf, cursor)
if err != nil {
return nil, 0, err
}
if bytes == nil {
return [][]byte{nullbytes}, c, nil
}
return [][]byte{bytes}, c, nil
}
func (d *numberDecoder) decodeStreamByte(s *Stream) ([]byte, error) {
start := s.cursor
for {

View File

@ -7,11 +7,9 @@ type OptionFlags uint8
const (
FirstWinOption OptionFlags = 1 << iota
ContextOption
PathOption
)
type Option struct {
Flags OptionFlags
Context context.Context
Path *Path
}

View File

@ -1,670 +0,0 @@
package decoder
import (
"fmt"
"reflect"
"strconv"
"github.com/goccy/go-json/internal/errors"
"github.com/goccy/go-json/internal/runtime"
)
type PathString string
func (s PathString) Build() (*Path, error) {
builder := new(PathBuilder)
return builder.Build([]rune(s))
}
type PathBuilder struct {
root PathNode
node PathNode
singleQuotePathSelector bool
doubleQuotePathSelector bool
}
func (b *PathBuilder) Build(buf []rune) (*Path, error) {
node, err := b.build(buf)
if err != nil {
return nil, err
}
return &Path{
node: node,
RootSelectorOnly: node == nil,
SingleQuotePathSelector: b.singleQuotePathSelector,
DoubleQuotePathSelector: b.doubleQuotePathSelector,
}, nil
}
func (b *PathBuilder) build(buf []rune) (PathNode, error) {
if len(buf) == 0 {
return nil, errors.ErrEmptyPath()
}
if buf[0] != '$' {
return nil, errors.ErrInvalidPath("JSON Path must start with a $ character")
}
if len(buf) == 1 {
return nil, nil
}
buf = buf[1:]
offset, err := b.buildNext(buf)
if err != nil {
return nil, err
}
if len(buf) > offset {
return nil, errors.ErrInvalidPath("remain invalid path %q", buf[offset:])
}
return b.root, nil
}
func (b *PathBuilder) buildNextCharIfExists(buf []rune, cursor int) (int, error) {
if len(buf) > cursor {
offset, err := b.buildNext(buf[cursor:])
if err != nil {
return 0, err
}
return cursor + 1 + offset, nil
}
return cursor, nil
}
func (b *PathBuilder) buildNext(buf []rune) (int, error) {
switch buf[0] {
case '.':
if len(buf) == 1 {
return 0, errors.ErrInvalidPath("JSON Path ends with dot character")
}
offset, err := b.buildSelector(buf[1:])
if err != nil {
return 0, err
}
return offset + 1, nil
case '[':
if len(buf) == 1 {
return 0, errors.ErrInvalidPath("JSON Path ends with left bracket character")
}
offset, err := b.buildIndex(buf[1:])
if err != nil {
return 0, err
}
return offset + 1, nil
default:
return 0, errors.ErrInvalidPath("expect dot or left bracket character. but found %c character", buf[0])
}
}
func (b *PathBuilder) buildSelector(buf []rune) (int, error) {
switch buf[0] {
case '.':
if len(buf) == 1 {
return 0, errors.ErrInvalidPath("JSON Path ends with double dot character")
}
offset, err := b.buildPathRecursive(buf[1:])
if err != nil {
return 0, err
}
return 1 + offset, nil
case '[', ']', '$', '*':
return 0, errors.ErrInvalidPath("found invalid path character %c after dot", buf[0])
}
for cursor := 0; cursor < len(buf); cursor++ {
switch buf[cursor] {
case '$', '*', ']':
return 0, errors.ErrInvalidPath("found %c character in field selector context", buf[cursor])
case '.':
if cursor+1 >= len(buf) {
return 0, errors.ErrInvalidPath("JSON Path ends with dot character")
}
selector := buf[:cursor]
b.addSelectorNode(string(selector))
offset, err := b.buildSelector(buf[cursor+1:])
if err != nil {
return 0, err
}
return cursor + 1 + offset, nil
case '[':
if cursor+1 >= len(buf) {
return 0, errors.ErrInvalidPath("JSON Path ends with left bracket character")
}
selector := buf[:cursor]
b.addSelectorNode(string(selector))
offset, err := b.buildIndex(buf[cursor+1:])
if err != nil {
return 0, err
}
return cursor + 1 + offset, nil
case '"':
if cursor+1 >= len(buf) {
return 0, errors.ErrInvalidPath("JSON Path ends with double quote character")
}
offset, err := b.buildQuoteSelector(buf[cursor+1:], DoubleQuotePathSelector)
if err != nil {
return 0, err
}
return cursor + 1 + offset, nil
}
}
b.addSelectorNode(string(buf))
return len(buf), nil
}
func (b *PathBuilder) buildQuoteSelector(buf []rune, sel QuotePathSelector) (int, error) {
switch buf[0] {
case '[', ']', '$', '.', '*', '\'', '"':
return 0, errors.ErrInvalidPath("found invalid path character %c after quote", buf[0])
}
for cursor := 0; cursor < len(buf); cursor++ {
switch buf[cursor] {
case '\'':
if sel != SingleQuotePathSelector {
return 0, errors.ErrInvalidPath("found double quote character in field selector with single quote context")
}
if len(buf) <= cursor+1 {
return 0, errors.ErrInvalidPath("JSON Path ends with single quote character in field selector context")
}
if buf[cursor+1] != ']' {
return 0, errors.ErrInvalidPath("expect right bracket for field selector with single quote but found %c", buf[cursor+1])
}
selector := buf[:cursor]
b.addSelectorNode(string(selector))
b.singleQuotePathSelector = true
return b.buildNextCharIfExists(buf, cursor+2)
case '"':
if sel != DoubleQuotePathSelector {
return 0, errors.ErrInvalidPath("found single quote character in field selector with double quote context")
}
selector := buf[:cursor]
b.addSelectorNode(string(selector))
b.doubleQuotePathSelector = true
return b.buildNextCharIfExists(buf, cursor+1)
}
}
return 0, errors.ErrInvalidPath("couldn't find quote character in selector quote path context")
}
func (b *PathBuilder) buildPathRecursive(buf []rune) (int, error) {
switch buf[0] {
case '.', '[', ']', '$', '*':
return 0, errors.ErrInvalidPath("found invalid path character %c after double dot", buf[0])
}
for cursor := 0; cursor < len(buf); cursor++ {
switch buf[cursor] {
case '$', '*', ']':
return 0, errors.ErrInvalidPath("found %c character in field selector context", buf[cursor])
case '.':
if cursor+1 >= len(buf) {
return 0, errors.ErrInvalidPath("JSON Path ends with dot character")
}
selector := buf[:cursor]
b.addRecursiveNode(string(selector))
offset, err := b.buildSelector(buf[cursor+1:])
if err != nil {
return 0, err
}
return cursor + 1 + offset, nil
case '[':
if cursor+1 >= len(buf) {
return 0, errors.ErrInvalidPath("JSON Path ends with left bracket character")
}
selector := buf[:cursor]
b.addRecursiveNode(string(selector))
offset, err := b.buildIndex(buf[cursor+1:])
if err != nil {
return 0, err
}
return cursor + 1 + offset, nil
}
}
b.addRecursiveNode(string(buf))
return len(buf), nil
}
func (b *PathBuilder) buildIndex(buf []rune) (int, error) {
switch buf[0] {
case '.', '[', ']', '$':
return 0, errors.ErrInvalidPath("found invalid path character %c after left bracket", buf[0])
case '\'':
if len(buf) == 1 {
return 0, errors.ErrInvalidPath("JSON Path ends with single quote character")
}
offset, err := b.buildQuoteSelector(buf[1:], SingleQuotePathSelector)
if err != nil {
return 0, err
}
return 1 + offset, nil
case '*':
if len(buf) == 1 {
return 0, errors.ErrInvalidPath("JSON Path ends with star character")
}
if buf[1] != ']' {
return 0, errors.ErrInvalidPath("expect right bracket character for index all path but found %c character", buf[1])
}
b.addIndexAllNode()
offset := len("*]")
if len(buf) > 2 {
buildOffset, err := b.buildNext(buf[2:])
if err != nil {
return 0, err
}
return offset + buildOffset, nil
}
return offset, nil
}
for cursor := 0; cursor < len(buf); cursor++ {
switch buf[cursor] {
case ']':
index, err := strconv.ParseInt(string(buf[:cursor]), 10, 64)
if err != nil {
return 0, errors.ErrInvalidPath("%q is unexpected index path", buf[:cursor])
}
b.addIndexNode(int(index))
return b.buildNextCharIfExists(buf, cursor+1)
}
}
return 0, errors.ErrInvalidPath("couldn't find right bracket character in index path context")
}
func (b *PathBuilder) addIndexAllNode() {
node := newPathIndexAllNode()
if b.root == nil {
b.root = node
b.node = node
} else {
b.node = b.node.chain(node)
}
}
func (b *PathBuilder) addRecursiveNode(selector string) {
node := newPathRecursiveNode(selector)
if b.root == nil {
b.root = node
b.node = node
} else {
b.node = b.node.chain(node)
}
}
func (b *PathBuilder) addSelectorNode(name string) {
node := newPathSelectorNode(name)
if b.root == nil {
b.root = node
b.node = node
} else {
b.node = b.node.chain(node)
}
}
func (b *PathBuilder) addIndexNode(idx int) {
node := newPathIndexNode(idx)
if b.root == nil {
b.root = node
b.node = node
} else {
b.node = b.node.chain(node)
}
}
type QuotePathSelector int
const (
SingleQuotePathSelector QuotePathSelector = 1
DoubleQuotePathSelector QuotePathSelector = 2
)
type Path struct {
node PathNode
RootSelectorOnly bool
SingleQuotePathSelector bool
DoubleQuotePathSelector bool
}
func (p *Path) Field(sel string) (PathNode, bool, error) {
if p.node == nil {
return nil, false, nil
}
return p.node.Field(sel)
}
func (p *Path) Get(src, dst reflect.Value) error {
if p.node == nil {
return nil
}
return p.node.Get(src, dst)
}
func (p *Path) String() string {
if p.node == nil {
return "$"
}
return p.node.String()
}
type PathNode interface {
fmt.Stringer
Index(idx int) (PathNode, bool, error)
Field(fieldName string) (PathNode, bool, error)
Get(src, dst reflect.Value) error
chain(PathNode) PathNode
target() bool
single() bool
}
type BasePathNode struct {
child PathNode
}
func (n *BasePathNode) chain(node PathNode) PathNode {
n.child = node
return node
}
func (n *BasePathNode) target() bool {
return n.child == nil
}
func (n *BasePathNode) single() bool {
return true
}
type PathSelectorNode struct {
*BasePathNode
selector string
}
func newPathSelectorNode(selector string) *PathSelectorNode {
return &PathSelectorNode{
BasePathNode: &BasePathNode{},
selector: selector,
}
}
func (n *PathSelectorNode) Index(idx int) (PathNode, bool, error) {
return nil, false, &errors.PathError{}
}
func (n *PathSelectorNode) Field(fieldName string) (PathNode, bool, error) {
if n.selector == fieldName {
return n.child, true, nil
}
return nil, false, nil
}
func (n *PathSelectorNode) Get(src, dst reflect.Value) error {
switch src.Type().Kind() {
case reflect.Map:
iter := src.MapRange()
for iter.Next() {
key, ok := iter.Key().Interface().(string)
if !ok {
return fmt.Errorf("invalid map key type %T", src.Type().Key())
}
child, found, err := n.Field(key)
if err != nil {
return err
}
if found {
if child != nil {
return child.Get(iter.Value(), dst)
}
return AssignValue(iter.Value(), dst)
}
}
case reflect.Struct:
typ := src.Type()
for i := 0; i < typ.Len(); i++ {
tag := runtime.StructTagFromField(typ.Field(i))
child, found, err := n.Field(tag.Key)
if err != nil {
return err
}
if found {
if child != nil {
return child.Get(src.Field(i), dst)
}
return AssignValue(src.Field(i), dst)
}
}
case reflect.Ptr:
return n.Get(src.Elem(), dst)
case reflect.Interface:
return n.Get(reflect.ValueOf(src.Interface()), dst)
case reflect.Float64, reflect.String, reflect.Bool:
return AssignValue(src, dst)
}
return fmt.Errorf("failed to get %s value from %s", n.selector, src.Type())
}
func (n *PathSelectorNode) String() string {
s := fmt.Sprintf(".%s", n.selector)
if n.child != nil {
s += n.child.String()
}
return s
}
type PathIndexNode struct {
*BasePathNode
selector int
}
func newPathIndexNode(selector int) *PathIndexNode {
return &PathIndexNode{
BasePathNode: &BasePathNode{},
selector: selector,
}
}
func (n *PathIndexNode) Index(idx int) (PathNode, bool, error) {
if n.selector == idx {
return n.child, true, nil
}
return nil, false, nil
}
func (n *PathIndexNode) Field(fieldName string) (PathNode, bool, error) {
return nil, false, &errors.PathError{}
}
func (n *PathIndexNode) Get(src, dst reflect.Value) error {
switch src.Type().Kind() {
case reflect.Array, reflect.Slice:
if src.Len() > n.selector {
if n.child != nil {
return n.child.Get(src.Index(n.selector), dst)
}
return AssignValue(src.Index(n.selector), dst)
}
case reflect.Ptr:
return n.Get(src.Elem(), dst)
case reflect.Interface:
return n.Get(reflect.ValueOf(src.Interface()), dst)
}
return fmt.Errorf("failed to get [%d] value from %s", n.selector, src.Type())
}
func (n *PathIndexNode) String() string {
s := fmt.Sprintf("[%d]", n.selector)
if n.child != nil {
s += n.child.String()
}
return s
}
type PathIndexAllNode struct {
*BasePathNode
}
func newPathIndexAllNode() *PathIndexAllNode {
return &PathIndexAllNode{
BasePathNode: &BasePathNode{},
}
}
func (n *PathIndexAllNode) Index(idx int) (PathNode, bool, error) {
return n.child, true, nil
}
func (n *PathIndexAllNode) Field(fieldName string) (PathNode, bool, error) {
return nil, false, &errors.PathError{}
}
func (n *PathIndexAllNode) Get(src, dst reflect.Value) error {
switch src.Type().Kind() {
case reflect.Array, reflect.Slice:
var arr []interface{}
for i := 0; i < src.Len(); i++ {
var v interface{}
rv := reflect.ValueOf(&v)
if n.child != nil {
if err := n.child.Get(src.Index(i), rv); err != nil {
return err
}
} else {
if err := AssignValue(src.Index(i), rv); err != nil {
return err
}
}
arr = append(arr, v)
}
if err := AssignValue(reflect.ValueOf(arr), dst); err != nil {
return err
}
return nil
case reflect.Ptr:
return n.Get(src.Elem(), dst)
case reflect.Interface:
return n.Get(reflect.ValueOf(src.Interface()), dst)
}
return fmt.Errorf("failed to get all value from %s", src.Type())
}
func (n *PathIndexAllNode) String() string {
s := "[*]"
if n.child != nil {
s += n.child.String()
}
return s
}
type PathRecursiveNode struct {
*BasePathNode
selector string
}
func newPathRecursiveNode(selector string) *PathRecursiveNode {
node := newPathSelectorNode(selector)
return &PathRecursiveNode{
BasePathNode: &BasePathNode{
child: node,
},
selector: selector,
}
}
func (n *PathRecursiveNode) Field(fieldName string) (PathNode, bool, error) {
if n.selector == fieldName {
return n.child, true, nil
}
return nil, false, nil
}
func (n *PathRecursiveNode) Index(_ int) (PathNode, bool, error) {
return n, true, nil
}
func valueToSliceValue(v interface{}) []interface{} {
rv := reflect.ValueOf(v)
ret := []interface{}{}
if rv.Type().Kind() == reflect.Slice || rv.Type().Kind() == reflect.Array {
for i := 0; i < rv.Len(); i++ {
ret = append(ret, rv.Index(i).Interface())
}
return ret
}
return []interface{}{v}
}
func (n *PathRecursiveNode) Get(src, dst reflect.Value) error {
if n.child == nil {
return fmt.Errorf("failed to get by recursive path ..%s", n.selector)
}
var arr []interface{}
switch src.Type().Kind() {
case reflect.Map:
iter := src.MapRange()
for iter.Next() {
key, ok := iter.Key().Interface().(string)
if !ok {
return fmt.Errorf("invalid map key type %T", src.Type().Key())
}
child, found, err := n.Field(key)
if err != nil {
return err
}
if found {
var v interface{}
rv := reflect.ValueOf(&v)
_ = child.Get(iter.Value(), rv)
arr = append(arr, valueToSliceValue(v)...)
} else {
var v interface{}
rv := reflect.ValueOf(&v)
_ = n.Get(iter.Value(), rv)
if v != nil {
arr = append(arr, valueToSliceValue(v)...)
}
}
}
_ = AssignValue(reflect.ValueOf(arr), dst)
return nil
case reflect.Struct:
typ := src.Type()
for i := 0; i < typ.Len(); i++ {
tag := runtime.StructTagFromField(typ.Field(i))
child, found, err := n.Field(tag.Key)
if err != nil {
return err
}
if found {
var v interface{}
rv := reflect.ValueOf(&v)
_ = child.Get(src.Field(i), rv)
arr = append(arr, valueToSliceValue(v)...)
} else {
var v interface{}
rv := reflect.ValueOf(&v)
_ = n.Get(src.Field(i), rv)
if v != nil {
arr = append(arr, valueToSliceValue(v)...)
}
}
}
_ = AssignValue(reflect.ValueOf(arr), dst)
return nil
case reflect.Array, reflect.Slice:
for i := 0; i < src.Len(); i++ {
var v interface{}
rv := reflect.ValueOf(&v)
_ = n.Get(src.Index(i), rv)
if v != nil {
arr = append(arr, valueToSliceValue(v)...)
}
}
_ = AssignValue(reflect.ValueOf(arr), dst)
return nil
case reflect.Ptr:
return n.Get(src.Elem(), dst)
case reflect.Interface:
return n.Get(reflect.ValueOf(src.Interface()), dst)
}
return fmt.Errorf("failed to get %s value from %s", n.selector, src.Type())
}
func (n *PathRecursiveNode) String() string {
s := fmt.Sprintf("..%s", n.selector)
if n.child != nil {
s += n.child.String()
}
return s
}

View File

@ -1,7 +1,6 @@
package decoder
import (
"fmt"
"unsafe"
"github.com/goccy/go-json/internal/runtime"
@ -35,10 +34,6 @@ func (d *ptrDecoder) contentDecoder() Decoder {
//go:linkname unsafe_New reflect.unsafe_New
func unsafe_New(*runtime.Type) unsafe.Pointer
func UnsafeNew(t *runtime.Type) unsafe.Pointer {
return unsafe_New(t)
}
func (d *ptrDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) error {
if s.skipWhiteSpace() == nul {
s.read()
@ -85,13 +80,8 @@ func (d *ptrDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.P
}
c, err := d.dec.Decode(ctx, cursor, depth, newptr)
if err != nil {
*(*unsafe.Pointer)(p) = nil
return 0, err
}
cursor = c
return cursor, nil
}
func (d *ptrDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: ptr decoder does not support decode path")
}

View File

@ -299,82 +299,3 @@ func (d *sliceDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe
}
}
}
func (d *sliceDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
buf := ctx.Buf
depth++
if depth > maxDecodeNestingDepth {
return nil, 0, errors.ErrExceededMaxDepth(buf[cursor], cursor)
}
ret := [][]byte{}
for {
switch buf[cursor] {
case ' ', '\n', '\t', '\r':
cursor++
continue
case 'n':
if err := validateNull(buf, cursor); err != nil {
return nil, 0, err
}
cursor += 4
return [][]byte{nullbytes}, cursor, nil
case '[':
cursor++
cursor = skipWhiteSpace(buf, cursor)
if buf[cursor] == ']' {
cursor++
return ret, cursor, nil
}
idx := 0
for {
child, found, err := ctx.Option.Path.node.Index(idx)
if err != nil {
return nil, 0, err
}
if found {
if child != nil {
oldPath := ctx.Option.Path.node
ctx.Option.Path.node = child
paths, c, err := d.valueDecoder.DecodePath(ctx, cursor, depth)
if err != nil {
return nil, 0, err
}
ctx.Option.Path.node = oldPath
ret = append(ret, paths...)
cursor = c
} else {
start := cursor
end, err := skipValue(buf, cursor, depth)
if err != nil {
return nil, 0, err
}
ret = append(ret, buf[start:end])
cursor = end
}
} else {
c, err := skipValue(buf, cursor, depth)
if err != nil {
return nil, 0, err
}
cursor = c
}
cursor = skipWhiteSpace(buf, cursor)
switch buf[cursor] {
case ']':
cursor++
return ret, cursor, nil
case ',':
idx++
default:
return nil, 0, errors.ErrInvalidCharacter(buf[cursor], "slice", cursor)
}
cursor++
}
case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
return nil, 0, d.errNumber(cursor)
default:
return nil, 0, errors.ErrUnexpectedEndOfJSON("slice", cursor)
}
}
}

View File

@ -103,7 +103,7 @@ func (s *Stream) statForRetry() ([]byte, int64, unsafe.Pointer) {
func (s *Stream) Reset() {
s.reset()
s.bufSize = int64(len(s.buf))
s.bufSize = initBufSize
}
func (s *Stream) More() bool {
@ -138,11 +138,8 @@ func (s *Stream) Token() (interface{}, error) {
s.cursor++
case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
bytes := floatBytes(s)
str := *(*string)(unsafe.Pointer(&bytes))
if s.UseNumber {
return json.Number(str), nil
}
f64, err := strconv.ParseFloat(str, 64)
s := *(*string)(unsafe.Pointer(&bytes))
f64, err := strconv.ParseFloat(s, 64)
if err != nil {
return nil, err
}
@ -280,7 +277,7 @@ func (s *Stream) skipObject(depth int64) error {
if char(p, cursor) == nul {
s.cursor = cursor
if s.read() {
_, cursor, p = s.stat()
_, cursor, p = s.statForRetry()
continue
}
return errors.ErrUnexpectedEndOfJSON("string of object", cursor)
@ -343,7 +340,7 @@ func (s *Stream) skipArray(depth int64) error {
if char(p, cursor) == nul {
s.cursor = cursor
if s.read() {
_, cursor, p = s.stat()
_, cursor, p = s.statForRetry()
continue
}
return errors.ErrUnexpectedEndOfJSON("string of object", cursor)
@ -401,7 +398,7 @@ func (s *Stream) skipValue(depth int64) error {
if char(p, cursor) == nul {
s.cursor = cursor
if s.read() {
_, cursor, p = s.stat()
_, cursor, p = s.statForRetry()
continue
}
return errors.ErrUnexpectedEndOfJSON("value of string", s.totalOffset())
@ -426,6 +423,7 @@ func (s *Stream) skipValue(depth int64) error {
continue
} else if c == nul {
if s.read() {
s.cursor-- // for retry current character
_, cursor, p = s.stat()
continue
}

View File

@ -1,8 +1,6 @@
package decoder
import (
"bytes"
"fmt"
"reflect"
"unicode"
"unicode/utf16"
@ -60,17 +58,6 @@ func (d *stringDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsaf
return cursor, nil
}
func (d *stringDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
bytes, c, err := d.decodeByte(ctx.Buf, cursor)
if err != nil {
return nil, 0, err
}
if bytes == nil {
return [][]byte{nullbytes}, c, nil
}
return [][]byte{bytes}, c, nil
}
var (
hexToInt = [256]int{
'0': 0,
@ -106,30 +93,24 @@ func unicodeToRune(code []byte) rune {
return r
}
func readAtLeast(s *Stream, n int64, p *unsafe.Pointer) bool {
for s.cursor+n >= s.length {
if !s.read() {
return false
}
*p = s.bufptr()
}
return true
}
func decodeUnicodeRune(s *Stream, p unsafe.Pointer) (rune, int64, unsafe.Pointer, error) {
const defaultOffset = 5
const surrogateOffset = 11
if !readAtLeast(s, defaultOffset, &p) {
if s.cursor+defaultOffset >= s.length {
if !s.read() {
return rune(0), 0, nil, errors.ErrInvalidCharacter(s.char(), "escaped string", s.totalOffset())
}
p = s.bufptr()
}
r := unicodeToRune(s.buf[s.cursor+1 : s.cursor+defaultOffset])
if utf16.IsSurrogate(r) {
if !readAtLeast(s, surrogateOffset, &p) {
return unicode.ReplacementChar, defaultOffset, p, nil
if s.cursor+surrogateOffset >= s.length {
s.read()
p = s.bufptr()
}
if s.buf[s.cursor+defaultOffset] != '\\' || s.buf[s.cursor+defaultOffset+1] != 'u' {
if s.cursor+surrogateOffset >= s.length || s.buf[s.cursor+defaultOffset] != '\\' || s.buf[s.cursor+defaultOffset+1] != 'u' {
return unicode.ReplacementChar, defaultOffset, p, nil
}
r2 := unicodeToRune(s.buf[s.cursor+defaultOffset+2 : s.cursor+surrogateOffset])
@ -182,7 +163,6 @@ RETRY:
if !s.read() {
return nil, errors.ErrInvalidCharacter(s.char(), "escaped string", s.totalOffset())
}
p = s.bufptr()
goto RETRY
default:
return nil, errors.ErrUnexpectedEndOfJSON("string", s.totalOffset())
@ -328,36 +308,49 @@ func (d *stringDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, err
cursor++
start := cursor
b := (*sliceHeader)(unsafe.Pointer(&buf)).data
escaped := 0
for {
switch char(b, cursor) {
case '\\':
escaped++
cursor++
switch char(b, cursor) {
case '"', '\\', '/', 'b', 'f', 'n', 'r', 't':
cursor++
case '"':
buf[cursor] = '"'
buf = append(buf[:cursor-1], buf[cursor:]...)
case '\\':
buf[cursor] = '\\'
buf = append(buf[:cursor-1], buf[cursor:]...)
case '/':
buf[cursor] = '/'
buf = append(buf[:cursor-1], buf[cursor:]...)
case 'b':
buf[cursor] = '\b'
buf = append(buf[:cursor-1], buf[cursor:]...)
case 'f':
buf[cursor] = '\f'
buf = append(buf[:cursor-1], buf[cursor:]...)
case 'n':
buf[cursor] = '\n'
buf = append(buf[:cursor-1], buf[cursor:]...)
case 'r':
buf[cursor] = '\r'
buf = append(buf[:cursor-1], buf[cursor:]...)
case 't':
buf[cursor] = '\t'
buf = append(buf[:cursor-1], buf[cursor:]...)
case 'u':
buflen := int64(len(buf))
if cursor+5 >= buflen {
return nil, 0, errors.ErrUnexpectedEndOfJSON("escaped string", cursor)
}
for i := int64(1); i <= 4; i++ {
c := char(b, cursor+i)
if !(('0' <= c && c <= '9') || ('a' <= c && c <= 'f') || ('A' <= c && c <= 'F')) {
return nil, 0, errors.ErrSyntax(fmt.Sprintf("json: invalid character %c in \\u hexadecimal character escape", c), cursor+i)
}
}
cursor += 5
code := unicodeToRune(buf[cursor+1 : cursor+5])
unicode := []byte(string(code))
buf = append(append(buf[:cursor-1], unicode...), buf[cursor+5:]...)
default:
return nil, 0, errors.ErrUnexpectedEndOfJSON("escaped string", cursor)
}
continue
case '"':
literal := buf[start:cursor]
if escaped > 0 {
literal = literal[:unescapeString(literal)]
}
cursor++
return literal, cursor, nil
case nul:
@ -376,77 +369,3 @@ func (d *stringDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, err
}
}
}
var unescapeMap = [256]byte{
'"': '"',
'\\': '\\',
'/': '/',
'b': '\b',
'f': '\f',
'n': '\n',
'r': '\r',
't': '\t',
}
func unsafeAdd(ptr unsafe.Pointer, offset int) unsafe.Pointer {
return unsafe.Pointer(uintptr(ptr) + uintptr(offset))
}
func unescapeString(buf []byte) int {
p := (*sliceHeader)(unsafe.Pointer(&buf)).data
end := unsafeAdd(p, len(buf))
src := unsafeAdd(p, bytes.IndexByte(buf, '\\'))
dst := src
for src != end {
c := char(src, 0)
if c == '\\' {
escapeChar := char(src, 1)
if escapeChar != 'u' {
*(*byte)(dst) = unescapeMap[escapeChar]
src = unsafeAdd(src, 2)
dst = unsafeAdd(dst, 1)
} else {
v1 := hexToInt[char(src, 2)]
v2 := hexToInt[char(src, 3)]
v3 := hexToInt[char(src, 4)]
v4 := hexToInt[char(src, 5)]
code := rune((v1 << 12) | (v2 << 8) | (v3 << 4) | v4)
if code >= 0xd800 && code < 0xdc00 && uintptr(unsafeAdd(src, 11)) < uintptr(end) {
if char(src, 6) == '\\' && char(src, 7) == 'u' {
v1 := hexToInt[char(src, 8)]
v2 := hexToInt[char(src, 9)]
v3 := hexToInt[char(src, 10)]
v4 := hexToInt[char(src, 11)]
lo := rune((v1 << 12) | (v2 << 8) | (v3 << 4) | v4)
if lo >= 0xdc00 && lo < 0xe000 {
code = (code-0xd800)<<10 | (lo - 0xdc00) + 0x10000
src = unsafeAdd(src, 6)
}
}
}
var b [utf8.UTFMax]byte
n := utf8.EncodeRune(b[:], code)
switch n {
case 4:
*(*byte)(unsafeAdd(dst, 3)) = b[3]
fallthrough
case 3:
*(*byte)(unsafeAdd(dst, 2)) = b[2]
fallthrough
case 2:
*(*byte)(unsafeAdd(dst, 1)) = b[1]
fallthrough
case 1:
*(*byte)(unsafeAdd(dst, 0)) = b[0]
}
src = unsafeAdd(src, 6)
dst = unsafeAdd(dst, n)
}
} else {
*(*byte)(dst) = c
src = unsafeAdd(src, 1)
dst = unsafeAdd(dst, 1)
}
}
return int(uintptr(dst) - uintptr(p))
}

View File

@ -51,14 +51,6 @@ func init() {
}
}
func toASCIILower(s string) string {
b := []byte(s)
for i := range b {
b[i] = largeToSmallTable[b[i]]
}
return string(b)
}
func newStructDecoder(structName, fieldName string, fieldMap map[string]*structFieldSet) *structDecoder {
return &structDecoder{
fieldMap: fieldMap,
@ -99,10 +91,6 @@ func (d *structDecoder) tryOptimize() {
for k, v := range d.fieldMap {
key := strings.ToLower(k)
if key != k {
if key != toASCIILower(k) {
d.isTriedOptimize = true
return
}
// already exists same key (e.g. Hello and HELLO has same lower case key
if _, exists := conflicted[key]; exists {
d.isTriedOptimize = true
@ -170,53 +158,49 @@ func (d *structDecoder) tryOptimize() {
}
// decode from '\uXXXX'
func decodeKeyCharByUnicodeRune(buf []byte, cursor int64) ([]byte, int64, error) {
func decodeKeyCharByUnicodeRune(buf []byte, cursor int64) ([]byte, int64) {
const defaultOffset = 4
const surrogateOffset = 6
if cursor+defaultOffset >= int64(len(buf)) {
return nil, 0, errors.ErrUnexpectedEndOfJSON("escaped string", cursor)
}
r := unicodeToRune(buf[cursor : cursor+defaultOffset])
if utf16.IsSurrogate(r) {
cursor += defaultOffset
if cursor+surrogateOffset >= int64(len(buf)) || buf[cursor] != '\\' || buf[cursor+1] != 'u' {
return []byte(string(unicode.ReplacementChar)), cursor + defaultOffset - 1, nil
return []byte(string(unicode.ReplacementChar)), cursor + defaultOffset - 1
}
cursor += 2
r2 := unicodeToRune(buf[cursor : cursor+defaultOffset])
if r := utf16.DecodeRune(r, r2); r != unicode.ReplacementChar {
return []byte(string(r)), cursor + defaultOffset - 1, nil
return []byte(string(r)), cursor + defaultOffset - 1
}
}
return []byte(string(r)), cursor + defaultOffset - 1, nil
return []byte(string(r)), cursor + defaultOffset - 1
}
func decodeKeyCharByEscapedChar(buf []byte, cursor int64) ([]byte, int64, error) {
func decodeKeyCharByEscapedChar(buf []byte, cursor int64) ([]byte, int64) {
c := buf[cursor]
cursor++
switch c {
case '"':
return []byte{'"'}, cursor, nil
return []byte{'"'}, cursor
case '\\':
return []byte{'\\'}, cursor, nil
return []byte{'\\'}, cursor
case '/':
return []byte{'/'}, cursor, nil
return []byte{'/'}, cursor
case 'b':
return []byte{'\b'}, cursor, nil
return []byte{'\b'}, cursor
case 'f':
return []byte{'\f'}, cursor, nil
return []byte{'\f'}, cursor
case 'n':
return []byte{'\n'}, cursor, nil
return []byte{'\n'}, cursor
case 'r':
return []byte{'\r'}, cursor, nil
return []byte{'\r'}, cursor
case 't':
return []byte{'\t'}, cursor, nil
return []byte{'\t'}, cursor
case 'u':
return decodeKeyCharByUnicodeRune(buf, cursor)
}
return nil, cursor, nil
return nil, cursor
}
func decodeKeyByBitmapUint8(d *structDecoder, buf []byte, cursor int64) (int64, *structFieldSet, error) {
@ -258,10 +242,7 @@ func decodeKeyByBitmapUint8(d *structDecoder, buf []byte, cursor int64) (int64,
return 0, nil, errors.ErrUnexpectedEndOfJSON("string", cursor)
case '\\':
cursor++
chars, nextCursor, err := decodeKeyCharByEscapedChar(buf, cursor)
if err != nil {
return 0, nil, err
}
chars, nextCursor := decodeKeyCharByEscapedChar(buf, cursor)
for _, c := range chars {
curBit &= bitmap[keyIdx][largeToSmallTable[c]]
if curBit == 0 {
@ -324,10 +305,7 @@ func decodeKeyByBitmapUint16(d *structDecoder, buf []byte, cursor int64) (int64,
return 0, nil, errors.ErrUnexpectedEndOfJSON("string", cursor)
case '\\':
cursor++
chars, nextCursor, err := decodeKeyCharByEscapedChar(buf, cursor)
if err != nil {
return 0, nil, err
}
chars, nextCursor := decodeKeyCharByEscapedChar(buf, cursor)
for _, c := range chars {
curBit &= bitmap[keyIdx][largeToSmallTable[c]]
if curBit == 0 {
@ -839,7 +817,3 @@ func (d *structDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsaf
cursor++
}
}
func (d *structDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: struct decoder does not support decode path")
}

View File

@ -10,7 +10,6 @@ import (
type Decoder interface {
Decode(*RuntimeContext, int64, int64, unsafe.Pointer) (int64, error)
DecodePath(*RuntimeContext, int64, int64) ([][]byte, int64, error)
DecodeStream(*Stream, int64, unsafe.Pointer) error
}

View File

@ -188,7 +188,3 @@ func (d *uintDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.
d.op(p, u64)
return cursor, nil
}
func (d *uintDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: uint decoder does not support decode path")
}

View File

@ -1,9 +1,7 @@
package decoder
import (
"context"
"encoding/json"
"fmt"
"unsafe"
"github.com/goccy/go-json/internal/errors"
@ -48,20 +46,13 @@ func (d *unmarshalJSONDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Poi
typ: d.typ,
ptr: p,
}))
switch v := v.(type) {
case unmarshalerContext:
var ctx context.Context
if (s.Option.Flags & ContextOption) != 0 {
ctx = s.Option.Context
} else {
ctx = context.Background()
}
if err := v.UnmarshalJSON(ctx, dst); err != nil {
if err := v.(unmarshalerContext).UnmarshalJSON(s.Option.Context, dst); err != nil {
d.annotateError(s.cursor, err)
return err
}
case json.Unmarshaler:
if err := v.UnmarshalJSON(dst); err != nil {
} else {
if err := v.(json.Unmarshaler).UnmarshalJSON(dst); err != nil {
d.annotateError(s.cursor, err)
return err
}
@ -98,7 +89,3 @@ func (d *unmarshalJSONDecoder) Decode(ctx *RuntimeContext, cursor, depth int64,
}
return end, nil
}
func (d *unmarshalJSONDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: unmarshal json decoder does not support decode path")
}

View File

@ -3,7 +3,6 @@ package decoder
import (
"bytes"
"encoding"
"fmt"
"unicode"
"unicode/utf16"
"unicode/utf8"
@ -143,11 +142,7 @@ func (d *unmarshalTextDecoder) Decode(ctx *RuntimeContext, cursor, depth int64,
return end, nil
}
func (d *unmarshalTextDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: unmarshal text decoder does not support decode path")
}
func unquoteBytes(s []byte) (t []byte, ok bool) { //nolint: nonamedreturns
func unquoteBytes(s []byte) (t []byte, ok bool) {
length := len(s)
if length < 2 || s[0] != '"' || s[length-1] != '"' {
return

View File

@ -1,7 +1,6 @@
package decoder
import (
"fmt"
"reflect"
"unsafe"
@ -67,7 +66,3 @@ func (d *wrappedStringDecoder) Decode(ctx *RuntimeContext, cursor, depth int64,
ctx.Buf = oldBuf
return c, nil
}
func (d *wrappedStringDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: wrapped string decoder does not support decode path")
}

Some files were not shown because too many files have changed in this diff Show More