Compare commits

..

No commits in common. "master" and "0.2.7" have entirely different histories.

1044 changed files with 38612 additions and 133458 deletions

View File

@ -1,70 +1,89 @@
--- ---
kind: pipeline kind: pipeline
type: docker type: docker
name: build-linux name: cleanup-before
environment:
GOOS: linux
GOOPTIONS: -mod=vendor
SRCFILES: cmd/pki/*.go
PROJECTNAME: pki
steps: steps:
- name: build-linux-amd64 - name: clean
image: golang image: alpine
commands: commands:
- go build -o $PROJECTNAME $GOOPTIONS $SRCFILES - rm -rf /build/*
environment: volumes:
GOARCH: amd64 - name: build
path: /build
when: when:
event: event: tag
exclude:
- tag volumes:
- name: build-linux-arm64 - name: build
image: golang host:
commands: path: /tmp/pki/build
- go build -o $PROJECTNAME $GOOPTIONS $SRCFILES
environment:
GOARCH: arm64
when:
event:
exclude:
- tag
--- ---
kind: pipeline kind: pipeline
type: docker type: docker
name: gitea-release-linux name: default-linux-amd64
environment:
GOOS: linux
GOOPTIONS: -mod=vendor
SRCFILES: cmd/pki/*.go
PROJECTNAME: pki
steps: steps:
- name: build-linux-amd64 - name: build
image: golang image: golang
commands: commands:
- go build -o $PROJECTNAME $GOOPTIONS $SRCFILES - ./ci-build.sh build
- tar -czvf $PROJECTNAME-$DRONE_TAG-$GOOS-$GOARCH.tar.gz $PROJECTNAME
- echo $PROJECTNAME $DRONE_TAG > VERSION
environment: environment:
GOOS: linux
GOARCH: amd64 GOARCH: amd64
when: volumes:
event: - name: build
- tag path: /build
- name: build-linux-arm64
volumes:
- name: build
host:
path: /tmp/pki/build
depends_on:
- cleanup-before
---
kind: pipeline
type: docker
name: default-linux-arm64
steps:
- name: build
image: golang image: golang
commands: commands:
- go build -o $PROJECTNAME $GOOPTIONS $SRCFILES - ./ci-build.sh build
- tar -czvf $PROJECTNAME-$DRONE_TAG-$GOOS-$GOARCH.tar.gz $PROJECTNAME
- echo $PROJECTNAME $DRONE_TAG > VERSION
environment: environment:
GOOS: linux
GOARCH: arm64 GOARCH: arm64
volumes:
- name: build
path: /build
volumes:
- name: build
host:
path: /tmp/pki/build
depends_on:
- cleanup-before
---
kind: pipeline
type: docker
name: gitea-release
steps:
- name: move
image: alpine
commands:
- mv build/* ./
volumes:
- name: build
path: /drone/src/build
when: when:
event: event: tag
- tag
- name: release - name: release
image: plugins/gitea-release image: plugins/gitea-release
settings: settings:
@ -76,6 +95,50 @@ steps:
- sha256 - sha256
- sha512 - sha512
title: VERSION title: VERSION
volumes:
- name: build
path: /drone/src/build
when: when:
event: event: tag
- tag - name: ls
image: alpine
commands:
- find .
volumes:
- name: build
path: /drone/src/build
when:
event: tag
volumes:
- name: build
host:
path: /tmp/pki/build
depends_on:
- default-linux-amd64
- default-linux-arm64
---
kind: pipeline
type: docker
name: cleanup-after
steps:
- name: clean
image: alpine
commands:
- rm -rf /build/*
volumes:
- name: build
path: /build
when:
event: tag
volumes:
- name: build
host:
path: /tmp/pki/build
depends_on:
- gitea-release

18
Makefile Normal file
View File

@ -0,0 +1,18 @@
# pki Makefile
GOCMD=go
GOBUILDCMD=${GOCMD} build
GOOPTIONS=-mod=vendor -ldflags="-s -w"
RMCMD=rm
BINNAME=pki
SRCFILES=cmd/pki/*.go
all: build
build:
${GOBUILDCMD} ${GOOPTIONS} ${SRCFILES}
clean:
${RMCMD} -f ${BINNAME}

View File

@ -10,7 +10,7 @@ PKI is a centralized Letsencrypt database server and renewer for certificate man
### Build ### Build
```bash ```bash
go build cmd/pki/pki.go make
``` ```
### Sample config in pki.ini ### Sample config in pki.ini
@ -40,7 +40,7 @@ ovhck=
## License ## License
```text ```text
Copyright (c) 2020, 2021, 2022 PaulBSD Copyright (c) 2020, 2021 PaulBSD
All rights reserved. All rights reserved.
Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without

62
ci-build.sh Executable file
View File

@ -0,0 +1,62 @@
#!/bin/bash
set -e
PROJECTNAME=pki
RELEASENAME=${PROJECTNAME}
VERSION="0"
GOOPTIONS="-mod=vendor"
SRCFILES=cmd/${PROJECTNAME}/*.go
build() {
echo "Begin of build"
if [[ ! -z $DRONE_TAG ]]
then
echo "Drone tag set, let's do a release"
VERSION=$DRONE_TAG
echo "${PROJECTNAME} ${VERSION}" > /build/VERSION
elif [[ ! -z $DRONE_TAG ]]
then
echo "Drone not set, let's only do a build"
VERSION=$DRONE_COMMIT
fi
if [[ ! -z $VERSION && ! -z $GOOS && ! -z $GOARCH ]]
then
echo "Let's set a release name"
RELEASENAME=${PROJECTNAME}-${VERSION}-${GOOS}-${GOARCH}
fi
echo "Building project"
go build -o ${PROJECTNAME} ${GOOPTIONS} ${SRCFILES}
if [[ ! -z $DRONE_TAG ]]
then
echo "Let's make archives"
mkdir -p /build
tar -czvf /build/${RELEASENAME}.tar.gz ${PROJECTNAME}
fi
echo "Removing binary file"
rm ${PROJECTNAME}
echo "End of build"
}
clean() {
rm -rf $RELEASEDIR
}
case $1 in
"build")
build
;;
"clean")
clean
;;
*)
echo "No options choosen"
exit 1
;;
esac

53
go.mod
View File

@ -1,42 +1,41 @@
module git.paulbsd.com/paulbsd/pki module git.paulbsd.com/paulbsd/pki
go 1.23 go 1.17
require ( require (
github.com/go-acme/lego/v4 v4.17.4 github.com/go-acme/lego/v4 v4.4.0
github.com/golang/snappy v0.0.4 // indirect github.com/golang/snappy v0.0.4 // indirect
github.com/labstack/echo/v4 v4.12.0 github.com/google/go-cmp v0.5.5 // indirect
github.com/lib/pq v1.10.9 github.com/gopherjs/gopherjs v0.0.0-20210406100015-1e088ea4ee04 // indirect
github.com/miekg/dns v1.1.62 // indirect github.com/labstack/echo/v4 v4.5.0
github.com/lib/pq v1.10.3
github.com/miekg/dns v1.1.43 // indirect
github.com/onsi/ginkgo v1.16.0 // indirect github.com/onsi/ginkgo v1.16.0 // indirect
github.com/onsi/gomega v1.11.0 // indirect github.com/onsi/gomega v1.11.0 // indirect
golang.org/x/crypto v0.26.0 // indirect github.com/smartystreets/assertions v1.2.0 // indirect
golang.org/x/net v0.28.0 // indirect golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 // indirect
golang.org/x/sys v0.24.0 // indirect golang.org/x/net v0.0.0-20210903162142-ad29c8ab022f // indirect
golang.org/x/text v0.17.0 // indirect golang.org/x/sys v0.0.0-20210903071746-97244b99971b // indirect
golang.org/x/time v0.6.0 // indirect golang.org/x/text v0.3.7 // indirect
gopkg.in/ini.v1 v1.67.0 golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac // indirect
xorm.io/builder v0.3.13 // indirect gopkg.in/ini.v1 v1.62.1
xorm.io/xorm v1.3.9 xorm.io/builder v0.3.9 // indirect
xorm.io/xorm v1.2.3
) )
require ( require (
github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cenkalti/backoff/v4 v4.1.1 // indirect
github.com/go-jose/go-jose/v4 v4.0.4 // indirect github.com/goccy/go-json v0.7.8 // indirect
github.com/goccy/go-json v0.10.3 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.11 // indirect
github.com/labstack/gommon v0.4.2 // indirect github.com/labstack/gommon v0.3.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.8 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.13 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.1 // indirect
github.com/ovh/go-ovh v1.6.0 // indirect github.com/ovh/go-ovh v1.1.0 // indirect
github.com/syndtr/goleveldb v1.0.0 // indirect github.com/syndtr/goleveldb v1.0.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasttemplate v1.2.2 // indirect github.com/valyala/fasttemplate v1.2.1 // indirect
golang.org/x/mod v0.20.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect
golang.org/x/oauth2 v0.22.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/tools v0.24.0 // indirect
) )

1007
go.sum

File diff suppressed because it is too large Load Diff

View File

@ -2,13 +2,10 @@ package cert
import "time" import "time"
func (e *Entry) Z() {
}
// Entry is the main struct for stored certificates // Entry is the main struct for stored certificates
type Entry struct { type Entry struct {
ID int `xorm:"pk autoincr"` ID int `xorm:"pk autoincr"`
Domain string `xorm:"notnull"` Domains string `xorm:"notnull"`
Certificate string `xorm:"text notnull"` Certificate string `xorm:"text notnull"`
PrivateKey string `xorm:"text notnull"` PrivateKey string `xorm:"text notnull"`
AuthURL string `xorm:"notnull"` AuthURL string `xorm:"notnull"`

View File

@ -60,13 +60,10 @@ func (cfg *Config) GetConfig() error {
options["ovhas"] = pkisection.Key("ovhas").MustString("") options["ovhas"] = pkisection.Key("ovhas").MustString("")
options["ovhck"] = pkisection.Key("ovhck").MustString("") options["ovhck"] = pkisection.Key("ovhck").MustString("")
options["pdnsapiurl"] = pkisection.Key("pdnsapiurl").MustString("")
options["pdnsapikey"] = pkisection.Key("pdnsapikey").MustString("")
cfg.ACME.ProviderOptions = options cfg.ACME.ProviderOptions = options
for key, value := range options { for k, v := range options {
if value == "" { if v == "" {
utils.Advice(fmt.Sprintf("Provider parameter %s not set", key)) utils.Advice(fmt.Sprintf("OVH provider parameter %s not set", k))
} }
} }
@ -75,8 +72,6 @@ func (cfg *Config) GetConfig() error {
cfg.ACME.AuthURL = lego.LEDirectoryProduction cfg.ACME.AuthURL = lego.LEDirectoryProduction
case "staging": case "staging":
cfg.ACME.AuthURL = lego.LEDirectoryStaging cfg.ACME.AuthURL = lego.LEDirectoryStaging
default:
cfg.ACME.AuthURL = lego.LEDirectoryStaging
} }
return nil return nil

View File

@ -7,7 +7,6 @@ import (
"git.paulbsd.com/paulbsd/pki/src/cert" "git.paulbsd.com/paulbsd/pki/src/cert"
"git.paulbsd.com/paulbsd/pki/src/config" "git.paulbsd.com/paulbsd/pki/src/config"
"git.paulbsd.com/paulbsd/pki/src/domain"
"git.paulbsd.com/paulbsd/pki/src/pki" "git.paulbsd.com/paulbsd/pki/src/pki"
_ "github.com/lib/pq" _ "github.com/lib/pq"
"xorm.io/xorm" "xorm.io/xorm"
@ -18,7 +17,7 @@ import (
func Init(cfg *config.Config) (err error) { func Init(cfg *config.Config) (err error) {
var databaseEngine = "postgres" var databaseEngine = "postgres"
tables := []interface{}{cert.Entry{}, tables := []interface{}{cert.Entry{},
pki.User{}, domain.Domain{}} pki.User{}}
cfg.Db, err = xorm.NewEngine(databaseEngine, cfg.Db, err = xorm.NewEngine(databaseEngine,
fmt.Sprintf("%s://%s:%s@%s/%s", fmt.Sprintf("%s://%s:%s@%s/%s",

View File

@ -1,12 +0,0 @@
package domain
import "time"
// Domain describes a domain
type Domain struct {
ID int `xorm:"pk autoincr"`
Domain string `xorm:"text notnull unique(domain_provider)"`
Provider string `xorm:"text notnull unique(domain_provider)"`
Created time.Time `xorm:"created notnull"`
Updated time.Time `xorm:"updated notnull"`
}

View File

@ -9,13 +9,12 @@ import (
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"log" "log"
"strings"
"git.paulbsd.com/paulbsd/pki/src/cert" "git.paulbsd.com/paulbsd/pki/src/cert"
"git.paulbsd.com/paulbsd/pki/src/config" "git.paulbsd.com/paulbsd/pki/src/config"
"git.paulbsd.com/paulbsd/pki/src/domain"
"github.com/go-acme/lego/v4/certcrypto" "github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/certificate" "github.com/go-acme/lego/v4/certificate"
"github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/lego" "github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/registration" "github.com/go-acme/lego/v4/registration"
) )
@ -30,12 +29,15 @@ func (u *User) Init(cfg *config.Config) (err error) {
} }
// GetEntry returns requested acme ressource in database relative to domain // GetEntry returns requested acme ressource in database relative to domain
func (u *User) GetEntry(cfg *config.Config, domain *string) (Entry cert.Entry, err error) { func (u *User) GetEntry(cfg *config.Config, domains []string) (Entry cert.Entry, err error) {
has, err := cfg.Db.Where("domain = ?", domain).And(
has, err := cfg.Db.Where("domains = ?", strings.Join(domains, ",")).And(
"auth_url = ?", cfg.ACME.AuthURL).And( "auth_url = ?", cfg.ACME.AuthURL).And(
fmt.Sprintf("validity_end::timestamp-'%d DAY'::INTERVAL >= now()", cfg.ACME.MaxDaysBefore)).Desc( fmt.Sprintf("validity_end::timestamp-'%d DAY'::INTERVAL >= now()", cfg.ACME.MaxDaysBefore)).Desc(
"id").Get(&Entry) "id").Get(&Entry)
fmt.Println(has, err)
if !has { if !has {
err = fmt.Errorf("entry doesn't exists") err = fmt.Errorf("entry doesn't exists")
} }
@ -66,38 +68,12 @@ func (u *User) HandleRegistration(cfg *config.Config, client *lego.Client) (err
} }
// RequestNewCert returns a newly requested certificate to letsencrypt // RequestNewCert returns a newly requested certificate to letsencrypt
func (u *User) RequestNewCert(cfg *config.Config, domainnames *[]string) (certs *certificate.Resource, err error) { func (u *User) RequestNewCert(cfg *config.Config, domains []string) (certificates *certificate.Resource, err error) {
legoconfig := lego.NewConfig(u) legoconfig := lego.NewConfig(u)
legoconfig.CADirURL = cfg.ACME.AuthURL legoconfig.CADirURL = cfg.ACME.AuthURL
legoconfig.Certificate.KeyType = certcrypto.RSA2048 legoconfig.Certificate.KeyType = certcrypto.RSA2048
var dom domain.Domain ovhprovider, err := initProvider(cfg)
var has bool
for _, d := range *domainnames {
dom = domain.Domain{Domain: d}
if has, err = cfg.Db.Get(&dom); has {
break
}
if err != nil {
log.Println(err)
}
}
if !has {
err = fmt.Errorf("supplied domain not in allowed domains")
return
}
var provider challenge.Provider
switch dom.Provider {
case "ovh":
provider, err = initOVHProvider(cfg)
case "pdns":
provider, err = initPowerDNSProvider(cfg)
default:
return
}
if err != nil { if err != nil {
log.Println(err) log.Println(err)
} }
@ -107,7 +83,7 @@ func (u *User) RequestNewCert(cfg *config.Config, domainnames *[]string) (certs
log.Println(err) log.Println(err)
} }
err = client.Challenge.SetDNS01Provider(provider) err = client.Challenge.SetDNS01Provider(ovhprovider)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
} }
@ -121,15 +97,14 @@ func (u *User) RequestNewCert(cfg *config.Config, domainnames *[]string) (certs
} }
request := certificate.ObtainRequest{ request := certificate.ObtainRequest{
Domains: *domainnames, Domains: domains,
Bundle: true, Bundle: true,
} }
certs, err = client.Certificate.Obtain(request) certificates, err = client.Certificate.Obtain(request)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
} }
return return
} }

View File

@ -1,16 +1,12 @@
package pki package pki
import ( import (
"log"
"net/url"
"git.paulbsd.com/paulbsd/pki/src/config" "git.paulbsd.com/paulbsd/pki/src/config"
"github.com/go-acme/lego/v4/providers/dns/ovh" "github.com/go-acme/lego/v4/providers/dns/ovh"
"github.com/go-acme/lego/v4/providers/dns/pdns"
) )
// initOVHProvider initialize DNS provider configuration // initProvider initialize DNS provider configuration
func initOVHProvider(cfg *config.Config) (ovhprovider *ovh.DNSProvider, err error) { func initProvider(cfg *config.Config) (ovhprovider *ovh.DNSProvider, err error) {
ovhconfig := ovh.NewDefaultConfig() ovhconfig := ovh.NewDefaultConfig()
ovhconfig.APIEndpoint = cfg.ACME.ProviderOptions["ovhendpoint"] ovhconfig.APIEndpoint = cfg.ACME.ProviderOptions["ovhendpoint"]
@ -19,24 +15,6 @@ func initOVHProvider(cfg *config.Config) (ovhprovider *ovh.DNSProvider, err erro
ovhconfig.ConsumerKey = cfg.ACME.ProviderOptions["ovhck"] ovhconfig.ConsumerKey = cfg.ACME.ProviderOptions["ovhck"]
ovhprovider, err = ovh.NewDNSProviderConfig(ovhconfig) ovhprovider, err = ovh.NewDNSProviderConfig(ovhconfig)
if err != nil {
log.Println(err)
}
return
}
// initPowerDNSProvider initialize DNS provider configuration
func initPowerDNSProvider(cfg *config.Config) (pdnsprovider *pdns.DNSProvider, err error) {
pdnsconfig := pdns.NewDefaultConfig()
pdnsconfig.Host, err = url.Parse(cfg.ACME.ProviderOptions["pdnsapiurl"])
pdnsconfig.APIKey = cfg.ACME.ProviderOptions["pdnsapikey"]
pdnsprovider, err = pdns.NewDNSProviderConfig(pdnsconfig)
if err != nil {
log.Println(err)
}
return return
} }

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"log" "log"
"net/http" "net/http"
"strings"
"git.paulbsd.com/paulbsd/pki/src/config" "git.paulbsd.com/paulbsd/pki/src/config"
"git.paulbsd.com/paulbsd/pki/src/pki" "git.paulbsd.com/paulbsd/pki/src/pki"
@ -29,18 +30,13 @@ func RunServer(cfg *config.Config) (err error) {
e.GET("/", func(c echo.Context) error { e.GET("/", func(c echo.Context) error {
return c.String(http.StatusOK, "Welcome to PKI software (https://git.paulbsd.com/paulbsd/pki)") return c.String(http.StatusOK, "Welcome to PKI software (https://git.paulbsd.com/paulbsd/pki)")
}) })
e.POST("/cert", func(c echo.Context) (err error) { e.GET("/domain/:domains", func(c echo.Context) (err error) {
var request = new(EntryRequest) var result EntryResponse
var result = make(map[string]EntryResponse) var domains = strings.Split(c.Param("domains"), ",")
err = c.Bind(&request)
if err != nil {
log.Println(err)
return c.JSON(http.StatusInternalServerError, "error parsing request")
}
log.Printf("Providing %s to user %s at %s\n", request.Domains, c.Get("username"), c.RealIP()) log.Println(fmt.Sprintf("Providing %s to user %s at %s", domains, c.Get("username"), c.RealIP()))
result, err = GetCertificate(cfg, c.Get("user").(*pki.User), &request.Domains) result, err = GetCertificate(cfg, c.Get("user").(*pki.User), domains)
if err != nil { if err != nil {
return c.String(http.StatusInternalServerError, fmt.Sprintf("%s", err)) return c.String(http.StatusInternalServerError, fmt.Sprintf("%s", err))
} }

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"log" "log"
"regexp" "regexp"
"strings"
"time" "time"
"git.paulbsd.com/paulbsd/pki/src/cert" "git.paulbsd.com/paulbsd/pki/src/cert"
@ -13,24 +14,18 @@ import (
"git.paulbsd.com/paulbsd/pki/src/pki" "git.paulbsd.com/paulbsd/pki/src/pki"
) )
const timeformatstring string = "2006-01-02 15:04:05"
var domainRegex, err = regexp.Compile(`^[a-z0-9\*]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,6}$`)
// GetCertificate get certificate from database if exists, of request it from ACME // GetCertificate get certificate from database if exists, of request it from ACME
func GetCertificate(cfg *config.Config, user *pki.User, domains *[]string) (result map[string]EntryResponse, err error) { func GetCertificate(cfg *config.Config, user *pki.User, domains []string) (result EntryResponse, err error) {
err = CheckDomains(domains) err = CheckDomains(domains)
if err != nil { if err != nil {
return result, err return result, err
} }
result = make(map[string]EntryResponse)
firstdomain := (*domains)[0] entry, err := user.GetEntry(cfg, domains)
entry, err := user.GetEntry(cfg, &firstdomain)
if err != nil { if err != nil {
certs, err := user.RequestNewCert(cfg, domains) certs, err := user.RequestNewCert(cfg, domains)
if err != nil { if err != nil {
log.Printf("Error fetching new certificate %s\n", err) log.Println(fmt.Sprintf("Error fetching new certificate %s", err))
return result, err return result, err
} }
NotBefore, NotAfter, err := GetDates(certs.Certificate) NotBefore, NotAfter, err := GetDates(certs.Certificate)
@ -38,36 +33,33 @@ func GetCertificate(cfg *config.Config, user *pki.User, domains *[]string) (resu
log.Println("Error where parsing dates") log.Println("Error where parsing dates")
return result, err return result, err
} }
entry := cert.Entry{Domain: certs.Domain, entry := cert.Entry{Domains: strings.Join(domains, ","),
Certificate: string(certs.Certificate), Certificate: string(certs.Certificate),
PrivateKey: string(certs.PrivateKey), PrivateKey: string(certs.PrivateKey),
ValidityBegin: NotBefore, ValidityBegin: NotBefore,
ValidityEnd: NotAfter, ValidityEnd: NotAfter,
AuthURL: cfg.ACME.AuthURL} AuthURL: cfg.ACME.AuthURL}
cfg.Db.Insert(&entry) cfg.Db.Insert(&entry)
result[firstdomain] = convertEntryToResponse(entry) result = convertEntryToResponse(entry)
return result, err return result, err
} }
result[firstdomain] = convertEntryToResponse(entry) result = convertEntryToResponse(entry)
return return
} }
// CheckDomains check if requested domains are valid // CheckDomains check if requested domains are valid
func CheckDomains(domains *[]string) (err error) { func CheckDomains(domains []string) (err error) {
for _, domain := range *domains { domainRegex, err := regexp.Compile(`^[a-z0-9\*]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,6}$`)
err = CheckDomain(&domain)
if err != nil {
return
}
}
return
}
// CheckDomain check if requested domain are valid if err != nil {
func CheckDomain(domain *string) (err error) { return
res := domainRegex.Match([]byte(*domain)) }
if !res {
return fmt.Errorf("Domain %s has not a valid syntax %s, please verify", *domain, err) for _, d := range domains {
res := domainRegex.Match([]byte(d))
if !res {
return fmt.Errorf(fmt.Sprintf("Domain %s has not a valid syntax %s, please verify", d, err))
}
} }
return return
} }
@ -88,7 +80,9 @@ func GetDates(cert []byte) (NotBefore time.Time, NotAfter time.Time, err error)
// convertEntryToResponse converts database ACME entry to JSON ACME entry // convertEntryToResponse converts database ACME entry to JSON ACME entry
func convertEntryToResponse(in cert.Entry) (out EntryResponse) { func convertEntryToResponse(in cert.Entry) (out EntryResponse) {
out.Domains = append(out.Domains, in.Domain) timeformatstring := "2006-01-02 15:04:05"
out.Domains = in.Domains
out.Certificate = in.Certificate out.Certificate = in.Certificate
out.PrivateKey = in.PrivateKey out.PrivateKey = in.PrivateKey
out.ValidityBegin = in.ValidityBegin.Format(timeformatstring) out.ValidityBegin = in.ValidityBegin.Format(timeformatstring)
@ -97,16 +91,11 @@ func convertEntryToResponse(in cert.Entry) (out EntryResponse) {
return return
} }
// EntryRequest
type EntryRequest struct {
Domains []string `json:"domains"`
}
// EntryResponse is the struct defining JSON response from webservice // EntryResponse is the struct defining JSON response from webservice
type EntryResponse struct { type EntryResponse struct {
Domains []string `json:"domains"` Domains string `json:"domains"`
Certificate string `json:"certificate"` Certificate string `json:"certificate"`
PrivateKey string `json:"privatekey"` PrivateKey string `json:"privatekey"`
ValidityBegin string `json:"validitybegin"` ValidityBegin string `json:"validitybegin"`
ValidityEnd string `json:"validityend"` ValidityEnd string `json:"validityend"`
} }

10
vendor/github.com/cenkalti/backoff/v4/.travis.yml generated vendored Normal file
View File

@ -0,0 +1,10 @@
language: go
go:
- 1.13
- 1.x
- tip
before_install:
- go get github.com/mattn/goveralls
- go get golang.org/x/tools/cmd/cover
script:
- $HOME/gopath/bin/goveralls -service=travis-ci

View File

@ -1,4 +1,4 @@
# Exponential Backoff [![GoDoc][godoc image]][godoc] [![Coverage Status][coveralls image]][coveralls] # Exponential Backoff [![GoDoc][godoc image]][godoc] [![Build Status][travis image]][travis] [![Coverage Status][coveralls image]][coveralls]
This is a Go port of the exponential backoff algorithm from [Google's HTTP Client Library for Java][google-http-java-client]. This is a Go port of the exponential backoff algorithm from [Google's HTTP Client Library for Java][google-http-java-client].
@ -21,6 +21,8 @@ Use https://pkg.go.dev/github.com/cenkalti/backoff/v4 to view the documentation.
[godoc]: https://pkg.go.dev/github.com/cenkalti/backoff/v4 [godoc]: https://pkg.go.dev/github.com/cenkalti/backoff/v4
[godoc image]: https://godoc.org/github.com/cenkalti/backoff?status.png [godoc image]: https://godoc.org/github.com/cenkalti/backoff?status.png
[travis]: https://travis-ci.org/cenkalti/backoff
[travis image]: https://travis-ci.org/cenkalti/backoff.png?branch=master
[coveralls]: https://coveralls.io/github/cenkalti/backoff?branch=master [coveralls]: https://coveralls.io/github/cenkalti/backoff?branch=master
[coveralls image]: https://coveralls.io/repos/github/cenkalti/backoff/badge.svg?branch=master [coveralls image]: https://coveralls.io/repos/github/cenkalti/backoff/badge.svg?branch=master

View File

@ -71,9 +71,6 @@ type Clock interface {
Now() time.Time Now() time.Time
} }
// ExponentialBackOffOpts is a function type used to configure ExponentialBackOff options.
type ExponentialBackOffOpts func(*ExponentialBackOff)
// Default values for ExponentialBackOff. // Default values for ExponentialBackOff.
const ( const (
DefaultInitialInterval = 500 * time.Millisecond DefaultInitialInterval = 500 * time.Millisecond
@ -84,7 +81,7 @@ const (
) )
// NewExponentialBackOff creates an instance of ExponentialBackOff using default values. // NewExponentialBackOff creates an instance of ExponentialBackOff using default values.
func NewExponentialBackOff(opts ...ExponentialBackOffOpts) *ExponentialBackOff { func NewExponentialBackOff() *ExponentialBackOff {
b := &ExponentialBackOff{ b := &ExponentialBackOff{
InitialInterval: DefaultInitialInterval, InitialInterval: DefaultInitialInterval,
RandomizationFactor: DefaultRandomizationFactor, RandomizationFactor: DefaultRandomizationFactor,
@ -94,62 +91,10 @@ func NewExponentialBackOff(opts ...ExponentialBackOffOpts) *ExponentialBackOff {
Stop: Stop, Stop: Stop,
Clock: SystemClock, Clock: SystemClock,
} }
for _, fn := range opts {
fn(b)
}
b.Reset() b.Reset()
return b return b
} }
// WithInitialInterval sets the initial interval between retries.
func WithInitialInterval(duration time.Duration) ExponentialBackOffOpts {
return func(ebo *ExponentialBackOff) {
ebo.InitialInterval = duration
}
}
// WithRandomizationFactor sets the randomization factor to add jitter to intervals.
func WithRandomizationFactor(randomizationFactor float64) ExponentialBackOffOpts {
return func(ebo *ExponentialBackOff) {
ebo.RandomizationFactor = randomizationFactor
}
}
// WithMultiplier sets the multiplier for increasing the interval after each retry.
func WithMultiplier(multiplier float64) ExponentialBackOffOpts {
return func(ebo *ExponentialBackOff) {
ebo.Multiplier = multiplier
}
}
// WithMaxInterval sets the maximum interval between retries.
func WithMaxInterval(duration time.Duration) ExponentialBackOffOpts {
return func(ebo *ExponentialBackOff) {
ebo.MaxInterval = duration
}
}
// WithMaxElapsedTime sets the maximum total time for retries.
func WithMaxElapsedTime(duration time.Duration) ExponentialBackOffOpts {
return func(ebo *ExponentialBackOff) {
ebo.MaxElapsedTime = duration
}
}
// WithRetryStopDuration sets the duration after which retries should stop.
func WithRetryStopDuration(duration time.Duration) ExponentialBackOffOpts {
return func(ebo *ExponentialBackOff) {
ebo.Stop = duration
}
}
// WithClockProvider sets the clock used to measure time.
func WithClockProvider(clock Clock) ExponentialBackOffOpts {
return func(ebo *ExponentialBackOff) {
ebo.Clock = clock
}
}
type systemClock struct{} type systemClock struct{}
func (t systemClock) Now() time.Time { func (t systemClock) Now() time.Time {
@ -202,9 +147,6 @@ func (b *ExponentialBackOff) incrementCurrentInterval() {
// Returns a random value from the following interval: // Returns a random value from the following interval:
// [currentInterval - randomizationFactor * currentInterval, currentInterval + randomizationFactor * currentInterval]. // [currentInterval - randomizationFactor * currentInterval, currentInterval + randomizationFactor * currentInterval].
func getRandomValueFromInterval(randomizationFactor, random float64, currentInterval time.Duration) time.Duration { func getRandomValueFromInterval(randomizationFactor, random float64, currentInterval time.Duration) time.Duration {
if randomizationFactor == 0 {
return currentInterval // make sure no randomness is used when randomizationFactor is 0.
}
var delta = randomizationFactor * float64(currentInterval) var delta = randomizationFactor * float64(currentInterval)
var minInterval = float64(currentInterval) - delta var minInterval = float64(currentInterval) - delta
var maxInterval = float64(currentInterval) + delta var maxInterval = float64(currentInterval) + delta

View File

@ -5,20 +5,10 @@ import (
"time" "time"
) )
// An OperationWithData is executing by RetryWithData() or RetryNotifyWithData().
// The operation will be retried using a backoff policy if it returns an error.
type OperationWithData[T any] func() (T, error)
// An Operation is executing by Retry() or RetryNotify(). // An Operation is executing by Retry() or RetryNotify().
// The operation will be retried using a backoff policy if it returns an error. // The operation will be retried using a backoff policy if it returns an error.
type Operation func() error type Operation func() error
func (o Operation) withEmptyData() OperationWithData[struct{}] {
return func() (struct{}, error) {
return struct{}{}, o()
}
}
// Notify is a notify-on-error function. It receives an operation error and // Notify is a notify-on-error function. It receives an operation error and
// backoff delay if the operation failed (with an error). // backoff delay if the operation failed (with an error).
// //
@ -38,41 +28,18 @@ func Retry(o Operation, b BackOff) error {
return RetryNotify(o, b, nil) return RetryNotify(o, b, nil)
} }
// RetryWithData is like Retry but returns data in the response too.
func RetryWithData[T any](o OperationWithData[T], b BackOff) (T, error) {
return RetryNotifyWithData(o, b, nil)
}
// RetryNotify calls notify function with the error and wait duration // RetryNotify calls notify function with the error and wait duration
// for each failed attempt before sleep. // for each failed attempt before sleep.
func RetryNotify(operation Operation, b BackOff, notify Notify) error { func RetryNotify(operation Operation, b BackOff, notify Notify) error {
return RetryNotifyWithTimer(operation, b, notify, nil) return RetryNotifyWithTimer(operation, b, notify, nil)
} }
// RetryNotifyWithData is like RetryNotify but returns data in the response too.
func RetryNotifyWithData[T any](operation OperationWithData[T], b BackOff, notify Notify) (T, error) {
return doRetryNotify(operation, b, notify, nil)
}
// RetryNotifyWithTimer calls notify function with the error and wait duration using the given Timer // RetryNotifyWithTimer calls notify function with the error and wait duration using the given Timer
// for each failed attempt before sleep. // for each failed attempt before sleep.
// A default timer that uses system timer is used when nil is passed. // A default timer that uses system timer is used when nil is passed.
func RetryNotifyWithTimer(operation Operation, b BackOff, notify Notify, t Timer) error { func RetryNotifyWithTimer(operation Operation, b BackOff, notify Notify, t Timer) error {
_, err := doRetryNotify(operation.withEmptyData(), b, notify, t) var err error
return err var next time.Duration
}
// RetryNotifyWithTimerAndData is like RetryNotifyWithTimer but returns data in the response too.
func RetryNotifyWithTimerAndData[T any](operation OperationWithData[T], b BackOff, notify Notify, t Timer) (T, error) {
return doRetryNotify(operation, b, notify, t)
}
func doRetryNotify[T any](operation OperationWithData[T], b BackOff, notify Notify, t Timer) (T, error) {
var (
err error
next time.Duration
res T
)
if t == nil { if t == nil {
t = &defaultTimer{} t = &defaultTimer{}
} }
@ -85,22 +52,21 @@ func doRetryNotify[T any](operation OperationWithData[T], b BackOff, notify Noti
b.Reset() b.Reset()
for { for {
res, err = operation() if err = operation(); err == nil {
if err == nil { return nil
return res, nil
} }
var permanent *PermanentError var permanent *PermanentError
if errors.As(err, &permanent) { if errors.As(err, &permanent) {
return res, permanent.Err return permanent.Err
} }
if next = b.NextBackOff(); next == Stop { if next = b.NextBackOff(); next == Stop {
if cerr := ctx.Err(); cerr != nil { if cerr := ctx.Err(); cerr != nil {
return res, cerr return cerr
} }
return res, err return err
} }
if notify != nil { if notify != nil {
@ -111,7 +77,7 @@ func doRetryNotify[T any](operation OperationWithData[T], b BackOff, notify Noti
select { select {
case <-ctx.Done(): case <-ctx.Done():
return res, ctx.Err() return ctx.Err()
case <-t.C(): case <-t.C():
} }
} }

View File

@ -16,7 +16,7 @@ func (a *AccountService) New(req acme.Account) (acme.ExtendedAccount, error) {
resp, err := a.core.post(a.core.GetDirectory().NewAccountURL, req, &account) resp, err := a.core.post(a.core.GetDirectory().NewAccountURL, req, &account)
location := getLocation(resp) location := getLocation(resp)
if location != "" { if len(location) > 0 {
a.core.jws.SetKid(location) a.core.jws.SetKid(location)
} }

View File

@ -2,6 +2,7 @@ package api
import ( import (
"bytes" "bytes"
"context"
"crypto" "crypto"
"encoding/json" "encoding/json"
"errors" "errors"
@ -70,7 +71,7 @@ func (a *Core) post(uri string, reqBody, response interface{}) (*http.Response,
} }
// postAsGet performs an HTTP POST ("POST-as-GET") request. // postAsGet performs an HTTP POST ("POST-as-GET") request.
// https://www.rfc-editor.org/rfc/rfc8555.html#section-6.3 // https://tools.ietf.org/html/rfc8555#section-6.3
func (a *Core) postAsGet(uri string, response interface{}) (*http.Response, error) { func (a *Core) postAsGet(uri string, response interface{}) (*http.Response, error) {
return a.retrievablePost(uri, []byte{}, response) return a.retrievablePost(uri, []byte{}, response)
} }
@ -82,6 +83,8 @@ func (a *Core) retrievablePost(uri string, content []byte, response interface{})
bo.MaxInterval = 5 * time.Second bo.MaxInterval = 5 * time.Second
bo.MaxElapsedTime = 20 * time.Second bo.MaxElapsedTime = 20 * time.Second
ctx, cancel := context.WithCancel(context.Background())
var resp *http.Response var resp *http.Response
operation := func() error { operation := func() error {
var err error var err error
@ -93,7 +96,8 @@ func (a *Core) retrievablePost(uri string, content []byte, response interface{})
return err return err
} }
return backoff.Permanent(err) cancel()
return err
} }
return nil return nil
@ -103,7 +107,7 @@ func (a *Core) retrievablePost(uri string, content []byte, response interface{})
log.Infof("retry due to: %v", err) log.Infof("retry due to: %v", err)
} }
err := backoff.RetryNotify(operation, bo, notify) err := backoff.RetryNotify(operation, backoff.WithContext(bo, ctx), notify)
if err != nil { if err != nil {
return resp, err return resp, err
} }
@ -117,7 +121,7 @@ func (a *Core) signedPost(uri string, content []byte, response interface{}) (*ht
return nil, fmt.Errorf("failed to post JWS message: failed to sign content: %w", err) return nil, fmt.Errorf("failed to post JWS message: failed to sign content: %w", err)
} }
signedBody := bytes.NewBufferString(signedContent.FullSerialize()) signedBody := bytes.NewBuffer([]byte(signedContent.FullSerialize()))
resp, err := a.doer.Post(uri, signedBody, "application/jose+json", response) resp, err := a.doer.Post(uri, signedBody, "application/jose+json", response)

View File

@ -1,11 +1,10 @@
package api package api
import ( import (
"bytes"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"errors" "errors"
"io" "io/ioutil"
"net/http" "net/http"
"github.com/go-acme/lego/v4/acme" "github.com/go-acme/lego/v4/acme"
@ -40,7 +39,7 @@ func (c *CertificateService) GetAll(certURL string, bundle bool) (map[string]*ac
certs := map[string]*acme.RawCertificate{certURL: cert} certs := map[string]*acme.RawCertificate{certURL: cert}
// URLs of "alternate" link relation // URLs of "alternate" link relation
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.4.2 // - https://tools.ietf.org/html/rfc8555#section-7.4.2
alts := getLinks(headers, "alternate") alts := getLinks(headers, "alternate")
for _, alt := range alts { for _, alt := range alts {
@ -72,7 +71,7 @@ func (c *CertificateService) get(certURL string, bundle bool) (*acme.RawCertific
return nil, nil, err return nil, nil, err
} }
data, err := io.ReadAll(http.MaxBytesReader(nil, resp.Body, maxBodySize)) data, err := ioutil.ReadAll(http.MaxBytesReader(nil, resp.Body, maxBodySize))
if err != nil { if err != nil {
return nil, resp.Header, err return nil, resp.Header, err
} }
@ -88,17 +87,12 @@ func (c *CertificateService) getCertificateChain(cert []byte, headers http.Heade
// See https://community.letsencrypt.org/t/acme-v2-no-up-link-in-response/64962 // See https://community.letsencrypt.org/t/acme-v2-no-up-link-in-response/64962
_, issuer := pem.Decode(cert) _, issuer := pem.Decode(cert)
if issuer != nil { if issuer != nil {
// If bundle is false, we want to return a single certificate.
// To do this, we remove the issuer cert(s) from the issued cert.
if !bundle {
cert = bytes.TrimSuffix(cert, issuer)
}
return &acme.RawCertificate{Cert: cert, Issuer: issuer} return &acme.RawCertificate{Cert: cert, Issuer: issuer}
} }
// The issuer certificate link may be supplied via an "up" link // The issuer certificate link may be supplied via an "up" link
// in the response headers of a new certificate. // in the response headers of a new certificate.
// See https://www.rfc-editor.org/rfc/rfc8555.html#section-7.4.2 // See https://tools.ietf.org/html/rfc8555#section-7.4.2
up := getLink(headers, "up") up := getLink(headers, "up")
issuer, err := c.getIssuerFromLink(up) issuer, err := c.getIssuerFromLink(up)

View File

@ -63,7 +63,7 @@ func (n *Manager) getNonce() (string, error) {
return GetFromResponse(resp) return GetFromResponse(resp)
} }
// GetFromResponse Extracts a nonce from an HTTP response. // GetFromResponse Extracts a nonce from a HTTP response.
func GetFromResponse(resp *http.Response) (string, error) { func GetFromResponse(resp *http.Response) (string, error) {
if resp == nil { if resp == nil {
return "", errors.New("nil response") return "", errors.New("nil response")

View File

@ -9,7 +9,7 @@ import (
"fmt" "fmt"
"github.com/go-acme/lego/v4/acme/api/internal/nonces" "github.com/go-acme/lego/v4/acme/api/internal/nonces"
jose "github.com/go-jose/go-jose/v4" jose "gopkg.in/square/go-jose.v2"
) )
// JWS Represents a JWS. // JWS Represents a JWS.

View File

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"runtime" "runtime"
"strings" "strings"
@ -95,7 +96,7 @@ func (d *Doer) do(req *http.Request, response interface{}) (*http.Response, erro
} }
if response != nil { if response != nil {
raw, err := io.ReadAll(resp.Body) raw, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
return resp, err return resp, err
} }
@ -119,7 +120,7 @@ func (d *Doer) formatUserAgent() string {
func checkError(req *http.Request, resp *http.Response) error { func checkError(req *http.Request, resp *http.Response) error {
if resp.StatusCode >= http.StatusBadRequest { if resp.StatusCode >= http.StatusBadRequest {
body, err := io.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
return fmt.Errorf("%d :: %s :: %s :: %w", resp.StatusCode, req.Method, req.URL, err) return fmt.Errorf("%d :: %s :: %s :: %w", resp.StatusCode, req.Method, req.URL, err)
} }
@ -133,10 +134,6 @@ func checkError(req *http.Request, resp *http.Response) error {
errorDetails.Method = req.Method errorDetails.Method = req.Method
errorDetails.URL = req.URL.String() errorDetails.URL = req.URL.String()
if errorDetails.HTTPStatus == 0 {
errorDetails.HTTPStatus = resp.StatusCode
}
// Check for errors we handle specifically // Check for errors we handle specifically
if errorDetails.HTTPStatus == http.StatusBadRequest && errorDetails.Type == acme.BadNonceErr { if errorDetails.HTTPStatus == http.StatusBadRequest && errorDetails.Type == acme.BadNonceErr {
return &acme.NonceError{ProblemDetails: errorDetails} return &acme.NonceError{ProblemDetails: errorDetails}

View File

@ -5,7 +5,7 @@ package sender
const ( const (
// ourUserAgent is the User-Agent of this underlying library package. // ourUserAgent is the User-Agent of this underlying library package.
ourUserAgent = "xenolf-acme/4.17.4" ourUserAgent = "xenolf-acme/4.4.0"
// ourUserAgentComment is part of the UA comment linked to the version status of this underlying library package. // ourUserAgentComment is part of the UA comment linked to the version status of this underlying library package.
// values: detach|release // values: detach|release

View File

@ -3,58 +3,21 @@ package api
import ( import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"net"
"time"
"github.com/go-acme/lego/v4/acme" "github.com/go-acme/lego/v4/acme"
) )
// OrderOptions used to create an order (optional).
type OrderOptions struct {
NotBefore time.Time
NotAfter time.Time
// A string uniquely identifying a previously-issued certificate which this
// order is intended to replace.
// - https://datatracker.ietf.org/doc/html/draft-ietf-acme-ari-03#section-5
ReplacesCertID string
}
type OrderService service type OrderService service
// New Creates a new order. // New Creates a new order.
func (o *OrderService) New(domains []string) (acme.ExtendedOrder, error) { func (o *OrderService) New(domains []string) (acme.ExtendedOrder, error) {
return o.NewWithOptions(domains, nil)
}
// NewWithOptions Creates a new order.
func (o *OrderService) NewWithOptions(domains []string, opts *OrderOptions) (acme.ExtendedOrder, error) {
var identifiers []acme.Identifier var identifiers []acme.Identifier
for _, domain := range domains { for _, domain := range domains {
ident := acme.Identifier{Value: domain, Type: "dns"} identifiers = append(identifiers, acme.Identifier{Type: "dns", Value: domain})
if net.ParseIP(domain) != nil {
ident.Type = "ip"
}
identifiers = append(identifiers, ident)
} }
orderReq := acme.Order{Identifiers: identifiers} orderReq := acme.Order{Identifiers: identifiers}
if opts != nil {
if !opts.NotAfter.IsZero() {
orderReq.NotAfter = opts.NotAfter.Format(time.RFC3339)
}
if !opts.NotBefore.IsZero() {
orderReq.NotBefore = opts.NotBefore.Format(time.RFC3339)
}
if o.core.GetDirectory().RenewalInfo != "" {
orderReq.Replaces = opts.ReplacesCertID
}
}
var order acme.Order var order acme.Order
resp, err := o.core.post(o.core.GetDirectory().NewOrderURL, orderReq, &order) resp, err := o.core.post(o.core.GetDirectory().NewOrderURL, orderReq, &order)
if err != nil { if err != nil {

View File

@ -1,28 +0,0 @@
package api
import (
"errors"
"net/http"
)
// ErrNoARI is returned when the server does not advertise a renewal info endpoint.
var ErrNoARI = errors.New("renewalInfo[get/post]: server does not advertise a renewal info endpoint")
// GetRenewalInfo GETs renewal information for a certificate from the renewalInfo endpoint.
// This is used to determine if a certificate needs to be renewed.
//
// Note: this endpoint is part of a draft specification, not all ACME servers will implement it.
// This method will return api.ErrNoARI if the server does not advertise a renewal info endpoint.
//
// https://datatracker.ietf.org/doc/draft-ietf-acme-ari
func (c *CertificateService) GetRenewalInfo(certID string) (*http.Response, error) {
if c.core.GetDirectory().RenewalInfo == "" {
return nil, ErrNoARI
}
if certID == "" {
return nil, errors.New("renewalInfo[get]: 'certID' cannot be empty")
}
return c.core.HTTPClient.Get(c.core.GetDirectory().RenewalInfo + "/" + certID)
}

View File

@ -1,5 +1,5 @@
// Package acme contains all objects related the ACME endpoints. // Package acme contains all objects related the ACME endpoints.
// https://www.rfc-editor.org/rfc/rfc8555.html // https://tools.ietf.org/html/rfc8555
package acme package acme
import ( import (
@ -7,38 +7,20 @@ import (
"time" "time"
) )
// ACME status values of Account, Order, Authorization and Challenge objects. // Challenge statuses.
// See https://www.rfc-editor.org/rfc/rfc8555.html#section-7.1.6 for details. // https://tools.ietf.org/html/rfc8555#section-7.1.6
const ( const (
StatusPending = "pending"
StatusInvalid = "invalid"
StatusValid = "valid"
StatusProcessing = "processing"
StatusDeactivated = "deactivated" StatusDeactivated = "deactivated"
StatusExpired = "expired" StatusExpired = "expired"
StatusInvalid = "invalid"
StatusPending = "pending"
StatusProcessing = "processing"
StatusReady = "ready"
StatusRevoked = "revoked" StatusRevoked = "revoked"
StatusUnknown = "unknown"
StatusValid = "valid"
)
// CRL reason codes as defined in RFC 5280.
// https://datatracker.ietf.org/doc/html/rfc5280#section-5.3.1
const (
CRLReasonUnspecified uint = 0
CRLReasonKeyCompromise uint = 1
CRLReasonCACompromise uint = 2
CRLReasonAffiliationChanged uint = 3
CRLReasonSuperseded uint = 4
CRLReasonCessationOfOperation uint = 5
CRLReasonCertificateHold uint = 6
CRLReasonRemoveFromCRL uint = 8
CRLReasonPrivilegeWithdrawn uint = 9
CRLReasonAACompromise uint = 10
) )
// Directory the ACME directory object. // Directory the ACME directory object.
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.1.1 // - https://tools.ietf.org/html/rfc8555#section-7.1.1
// - https://datatracker.ietf.org/doc/draft-ietf-acme-ari/
type Directory struct { type Directory struct {
NewNonceURL string `json:"newNonce"` NewNonceURL string `json:"newNonce"`
NewAccountURL string `json:"newAccount"` NewAccountURL string `json:"newAccount"`
@ -47,11 +29,10 @@ type Directory struct {
RevokeCertURL string `json:"revokeCert"` RevokeCertURL string `json:"revokeCert"`
KeyChangeURL string `json:"keyChange"` KeyChangeURL string `json:"keyChange"`
Meta Meta `json:"meta"` Meta Meta `json:"meta"`
RenewalInfo string `json:"renewalInfo"`
} }
// Meta the ACME meta object (related to Directory). // Meta the ACME meta object (related to Directory).
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.1.1 // - https://tools.ietf.org/html/rfc8555#section-7.1.1
type Meta struct { type Meta struct {
// termsOfService (optional, string): // termsOfService (optional, string):
// A URL identifying the current terms of service. // A URL identifying the current terms of service.
@ -71,12 +52,12 @@ type Meta struct {
// externalAccountRequired (optional, boolean): // externalAccountRequired (optional, boolean):
// If this field is present and set to "true", // If this field is present and set to "true",
// then the CA requires that all new-account requests include an "externalAccountBinding" field // then the CA requires that all new- account requests include an "externalAccountBinding" field
// associating the new account with an external account. // associating the new account with an external account.
ExternalAccountRequired bool `json:"externalAccountRequired"` ExternalAccountRequired bool `json:"externalAccountRequired"`
} }
// ExtendedAccount an extended Account. // ExtendedAccount a extended Account.
type ExtendedAccount struct { type ExtendedAccount struct {
Account Account
// Contains the value of the response header `Location` // Contains the value of the response header `Location`
@ -84,14 +65,14 @@ type ExtendedAccount struct {
} }
// Account the ACME account Object. // Account the ACME account Object.
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.1.2 // - https://tools.ietf.org/html/rfc8555#section-7.1.2
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.3 // - https://tools.ietf.org/html/rfc8555#section-7.3
type Account struct { type Account struct {
// status (required, string): // status (required, string):
// The status of this account. // The status of this account.
// Possible values are: "valid", "deactivated", and "revoked". // Possible values are: "valid", "deactivated", and "revoked".
// The value "deactivated" should be used to indicate client-initiated deactivation // The value "deactivated" should be used to indicate client-initiated deactivation
// whereas "revoked" should be used to indicate server-initiated deactivation. (See Section 7.1.6) // whereas "revoked" should be used to indicate server- initiated deactivation. (See Section 7.1.6)
Status string `json:"status,omitempty"` Status string `json:"status,omitempty"`
// contact (optional, array of string): // contact (optional, array of string):
@ -131,7 +112,7 @@ type ExtendedOrder struct {
} }
// Order the ACME order Object. // Order the ACME order Object.
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.1.3 // - https://tools.ietf.org/html/rfc8555#section-7.1.3
type Order struct { type Order struct {
// status (required, string): // status (required, string):
// The status of this order. // The status of this order.
@ -181,16 +162,10 @@ type Order struct {
// certificate (optional, string): // certificate (optional, string):
// A URL for the certificate that has been issued in response to this order // A URL for the certificate that has been issued in response to this order
Certificate string `json:"certificate,omitempty"` Certificate string `json:"certificate,omitempty"`
// replaces (optional, string):
// replaces (string, optional): A string uniquely identifying a
// previously-issued certificate which this order is intended to replace.
// - https://datatracker.ietf.org/doc/html/draft-ietf-acme-ari-03#section-5
Replaces string `json:"replaces,omitempty"`
} }
// Authorization the ACME authorization object. // Authorization the ACME authorization object.
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.1.4 // - https://tools.ietf.org/html/rfc8555#section-7.1.4
type Authorization struct { type Authorization struct {
// status (required, string): // status (required, string):
// The status of this authorization. // The status of this authorization.
@ -232,8 +207,8 @@ type ExtendedChallenge struct {
} }
// Challenge the ACME challenge object. // Challenge the ACME challenge object.
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.1.5 // - https://tools.ietf.org/html/rfc8555#section-7.1.5
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-8 // - https://tools.ietf.org/html/rfc8555#section-8
type Challenge struct { type Challenge struct {
// type (required, string): // type (required, string):
// The type of challenge encoded in the object. // The type of challenge encoded in the object.
@ -266,23 +241,23 @@ type Challenge struct {
// It MUST NOT contain any characters outside the base64url alphabet, // It MUST NOT contain any characters outside the base64url alphabet,
// and MUST NOT include base64 padding characters ("="). // and MUST NOT include base64 padding characters ("=").
// See [RFC4086] for additional information on randomness requirements. // See [RFC4086] for additional information on randomness requirements.
// https://www.rfc-editor.org/rfc/rfc8555.html#section-8.3 // https://tools.ietf.org/html/rfc8555#section-8.3
// https://www.rfc-editor.org/rfc/rfc8555.html#section-8.4 // https://tools.ietf.org/html/rfc8555#section-8.4
Token string `json:"token"` Token string `json:"token"`
// https://www.rfc-editor.org/rfc/rfc8555.html#section-8.1 // https://tools.ietf.org/html/rfc8555#section-8.1
KeyAuthorization string `json:"keyAuthorization"` KeyAuthorization string `json:"keyAuthorization"`
} }
// Identifier the ACME identifier object. // Identifier the ACME identifier object.
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-9.7.7 // - https://tools.ietf.org/html/rfc8555#section-9.7.7
type Identifier struct { type Identifier struct {
Type string `json:"type"` Type string `json:"type"`
Value string `json:"value"` Value string `json:"value"`
} }
// CSRMessage Certificate Signing Request. // CSRMessage Certificate Signing Request.
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.4 // - https://tools.ietf.org/html/rfc8555#section-7.4
type CSRMessage struct { type CSRMessage struct {
// csr (required, string): // csr (required, string):
// A CSR encoding the parameters for the certificate being requested [RFC2986]. // A CSR encoding the parameters for the certificate being requested [RFC2986].
@ -292,8 +267,8 @@ type CSRMessage struct {
} }
// RevokeCertMessage a certificate revocation message. // RevokeCertMessage a certificate revocation message.
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.6 // - https://tools.ietf.org/html/rfc8555#section-7.6
// - https://www.rfc-editor.org/rfc/rfc5280.html#section-5.3.1 // - https://tools.ietf.org/html/rfc5280#section-5.3.1
type RevokeCertMessage struct { type RevokeCertMessage struct {
// certificate (required, string): // certificate (required, string):
// The certificate to be revoked, in the base64url-encoded version of the DER format. // The certificate to be revoked, in the base64url-encoded version of the DER format.
@ -314,36 +289,3 @@ type RawCertificate struct {
Cert []byte Cert []byte
Issuer []byte Issuer []byte
} }
// Window is a window of time.
type Window struct {
Start time.Time `json:"start"`
End time.Time `json:"end"`
}
// RenewalInfoResponse is the response to GET requests made the renewalInfo endpoint.
// - (4.1. Getting Renewal Information) https://datatracker.ietf.org/doc/draft-ietf-acme-ari/
type RenewalInfoResponse struct {
// SuggestedWindow contains two fields, start and end,
// whose values are timestamps which bound the window of time in which the CA recommends renewing the certificate.
SuggestedWindow Window `json:"suggestedWindow"`
// ExplanationURL is an optional URL pointing to a page which may explain why the suggested renewal window is what it is.
// For example, it may be a page explaining the CA's dynamic load-balancing strategy,
// or a page documenting which certificates are affected by a mass revocation event.
// Callers SHOULD provide this URL to their operator, if present.
ExplanationURL string `json:"explanationURL"`
}
// RenewalInfoUpdateRequest is the JWS payload for POST requests made to the renewalInfo endpoint.
// - (4.2. RenewalInfo Objects) https://datatracker.ietf.org/doc/html/draft-ietf-acme-ari-03#section-4.2
type RenewalInfoUpdateRequest struct {
// CertID is a composite string in the format: base64url(AKI) || '.' || base64url(Serial), where AKI is the
// certificate's authority key identifier and Serial is the certificate's serial number. For details, see:
// https://datatracker.ietf.org/doc/html/draft-ietf-acme-ari-03#section-4.1
CertID string `json:"certID"`
// Replaced is required and indicates whether or not the client considers the certificate to have been replaced.
// A certificate is considered replaced when its revocation would not disrupt any ongoing services,
// for instance because it has been renewed and the new certificate is in use, or because it is no longer in use.
// Clients SHOULD NOT send a request where this value is false.
Replaced bool `json:"replaced"`
}

View File

@ -11,8 +11,8 @@ const (
) )
// ProblemDetails the problem details object. // ProblemDetails the problem details object.
// - https://www.rfc-editor.org/rfc/rfc7807.html#section-3.1 // - https://tools.ietf.org/html/rfc7807#section-3.1
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-7.3.3 // - https://tools.ietf.org/html/rfc8555#section-7.3.3
type ProblemDetails struct { type ProblemDetails struct {
Type string `json:"type,omitempty"` Type string `json:"type,omitempty"`
Detail string `json:"detail,omitempty"` Detail string `json:"detail,omitempty"`
@ -26,7 +26,7 @@ type ProblemDetails struct {
} }
// SubProblem a "subproblems". // SubProblem a "subproblems".
// - https://www.rfc-editor.org/rfc/rfc8555.html#section-6.7.1 // - https://tools.ietf.org/html/rfc8555#section-6.7.1
type SubProblem struct { type SubProblem struct {
Type string `json:"type,omitempty"` Type string `json:"type,omitempty"`
Detail string `json:"detail,omitempty"` Detail string `json:"detail,omitempty"`

View File

@ -14,8 +14,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"math/big" "math/big"
"net"
"slices"
"strings" "strings"
"time" "time"
@ -27,7 +25,6 @@ const (
EC256 = KeyType("P256") EC256 = KeyType("P256")
EC384 = KeyType("P384") EC384 = KeyType("P384")
RSA2048 = KeyType("2048") RSA2048 = KeyType("2048")
RSA3072 = KeyType("3072")
RSA4096 = KeyType("4096") RSA4096 = KeyType("4096")
RSA8192 = KeyType("8192") RSA8192 = KeyType("8192")
) )
@ -85,12 +82,9 @@ func ParsePEMBundle(bundle []byte) ([]*x509.Certificate, error) {
// ParsePEMPrivateKey parses a private key from key, which is a PEM block. // ParsePEMPrivateKey parses a private key from key, which is a PEM block.
// Borrowed from Go standard library, to handle various private key and PEM block types. // Borrowed from Go standard library, to handle various private key and PEM block types.
// https://github.com/golang/go/blob/693748e9fa385f1e2c3b91ca9acbb6c0ad2d133d/src/crypto/tls/tls.go#L291-L308 // https://github.com/golang/go/blob/693748e9fa385f1e2c3b91ca9acbb6c0ad2d133d/src/crypto/tls/tls.go#L291-L308
// https://github.com/golang/go/blob/693748e9fa385f1e2c3b91ca9acbb6c0ad2d133d/src/crypto/tls/tls.go#L238 // https://github.com/golang/go/blob/693748e9fa385f1e2c3b91ca9acbb6c0ad2d133d/src/crypto/tls/tls.go#L238)
func ParsePEMPrivateKey(key []byte) (crypto.PrivateKey, error) { func ParsePEMPrivateKey(key []byte) (crypto.PrivateKey, error) {
keyBlockDER, _ := pem.Decode(key) keyBlockDER, _ := pem.Decode(key)
if keyBlockDER == nil {
return nil, errors.New("invalid PEM block")
}
if keyBlockDER.Type != "PRIVATE KEY" && !strings.HasSuffix(keyBlockDER.Type, " PRIVATE KEY") { if keyBlockDER.Type != "PRIVATE KEY" && !strings.HasSuffix(keyBlockDER.Type, " PRIVATE KEY") {
return nil, fmt.Errorf("unknown PEM header %q", keyBlockDER.Type) return nil, fmt.Errorf("unknown PEM header %q", keyBlockDER.Type)
@ -124,8 +118,6 @@ func GeneratePrivateKey(keyType KeyType) (crypto.PrivateKey, error) {
return ecdsa.GenerateKey(elliptic.P384(), rand.Reader) return ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
case RSA2048: case RSA2048:
return rsa.GenerateKey(rand.Reader, 2048) return rsa.GenerateKey(rand.Reader, 2048)
case RSA3072:
return rsa.GenerateKey(rand.Reader, 3072)
case RSA4096: case RSA4096:
return rsa.GenerateKey(rand.Reader, 4096) return rsa.GenerateKey(rand.Reader, 4096)
case RSA8192: case RSA8192:
@ -136,20 +128,9 @@ func GeneratePrivateKey(keyType KeyType) (crypto.PrivateKey, error) {
} }
func GenerateCSR(privateKey crypto.PrivateKey, domain string, san []string, mustStaple bool) ([]byte, error) { func GenerateCSR(privateKey crypto.PrivateKey, domain string, san []string, mustStaple bool) ([]byte, error) {
var dnsNames []string
var ipAddresses []net.IP
for _, altname := range san {
if ip := net.ParseIP(altname); ip != nil {
ipAddresses = append(ipAddresses, ip)
} else {
dnsNames = append(dnsNames, altname)
}
}
template := x509.CertificateRequest{ template := x509.CertificateRequest{
Subject: pkix.Name{CommonName: domain}, Subject: pkix.Name{CommonName: domain},
DNSNames: dnsNames, DNSNames: san,
IPAddresses: ipAddresses,
} }
if mustStaple { if mustStaple {
@ -217,26 +198,6 @@ func ParsePEMCertificate(cert []byte) (*x509.Certificate, error) {
return x509.ParseCertificate(pemBlock.Bytes) return x509.ParseCertificate(pemBlock.Bytes)
} }
func GetCertificateMainDomain(cert *x509.Certificate) (string, error) {
return getMainDomain(cert.Subject, cert.DNSNames)
}
func GetCSRMainDomain(cert *x509.CertificateRequest) (string, error) {
return getMainDomain(cert.Subject, cert.DNSNames)
}
func getMainDomain(subject pkix.Name, dnsNames []string) (string, error) {
if subject.CommonName == "" && len(dnsNames) == 0 {
return "", errors.New("missing domain")
}
if subject.CommonName != "" {
return subject.CommonName, nil
}
return dnsNames[0], nil
}
func ExtractDomains(cert *x509.Certificate) []string { func ExtractDomains(cert *x509.Certificate) []string {
var domains []string var domains []string
if cert.Subject.CommonName != "" { if cert.Subject.CommonName != "" {
@ -251,13 +212,6 @@ func ExtractDomains(cert *x509.Certificate) []string {
domains = append(domains, sanDomain) domains = append(domains, sanDomain)
} }
commonNameIP := net.ParseIP(cert.Subject.CommonName)
for _, sanIP := range cert.IPAddresses {
if !commonNameIP.Equal(sanIP) {
domains = append(domains, sanIP.String())
}
}
return domains return domains
} }
@ -269,7 +223,7 @@ func ExtractDomainsCSR(csr *x509.CertificateRequest) []string {
// loop over the SubjectAltName DNS names // loop over the SubjectAltName DNS names
for _, sanName := range csr.DNSNames { for _, sanName := range csr.DNSNames {
if slices.Contains(domains, sanName) { if containsSAN(domains, sanName) {
// Duplicate; skip this name // Duplicate; skip this name
continue continue
} }
@ -278,14 +232,16 @@ func ExtractDomainsCSR(csr *x509.CertificateRequest) []string {
domains = append(domains, sanName) domains = append(domains, sanName)
} }
cnip := net.ParseIP(csr.Subject.CommonName) return domains
for _, sanIP := range csr.IPAddresses { }
if !cnip.Equal(sanIP) {
domains = append(domains, sanIP.String()) func containsSAN(domains []string, sanName string) bool {
for _, existingName := range domains {
if existingName == sanName {
return true
} }
} }
return false
return domains
} }
func GeneratePemCert(privateKey *rsa.PrivateKey, domain string, extensions []pkix.Extension) ([]byte, error) { func GeneratePemCert(privateKey *rsa.PrivateKey, domain string, extensions []pkix.Extension) ([]byte, error) {
@ -305,7 +261,7 @@ func generateDerCert(privateKey *rsa.PrivateKey, expiration time.Time, domain st
} }
if expiration.IsZero() { if expiration.IsZero() {
expiration = time.Now().AddDate(1, 0, 0) expiration = time.Now().Add(365)
} }
template := x509.Certificate{ template := x509.Certificate{
@ -318,15 +274,9 @@ func generateDerCert(privateKey *rsa.PrivateKey, expiration time.Time, domain st
KeyUsage: x509.KeyUsageKeyEncipherment, KeyUsage: x509.KeyUsageKeyEncipherment,
BasicConstraintsValid: true, BasicConstraintsValid: true,
DNSNames: []string{domain},
ExtraExtensions: extensions, ExtraExtensions: extensions,
} }
// handling SAN filling as type suspected
if ip := net.ParseIP(domain); ip != nil {
template.IPAddresses = []net.IP{ip}
} else {
template.DNSNames = []string{domain}
}
return x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) return x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
} }

View File

@ -12,7 +12,6 @@ const (
// limited on the "new-reg", "new-authz" and "new-cert" endpoints. // limited on the "new-reg", "new-authz" and "new-cert" endpoints.
// From the documentation the limitation is 20 requests per second, // From the documentation the limitation is 20 requests per second,
// but using 20 as value doesn't work but 18 do. // but using 20 as value doesn't work but 18 do.
// https://letsencrypt.org/docs/rate-limits/
overallRequestLimit = 18 overallRequestLimit = 18
) )
@ -36,14 +35,13 @@ func (c *Certifier) getAuthorizations(order acme.ExtendedOrder) ([]acme.Authoriz
} }
var responses []acme.Authorization var responses []acme.Authorization
failures := make(obtainError)
failures := newObtainError() for i := 0; i < len(order.Authorizations); i++ {
for range len(order.Authorizations) {
select { select {
case res := <-resc: case res := <-resc:
responses = append(responses, res) responses = append(responses, res)
case err := <-errc: case err := <-errc:
failures.Add(err.Domain, err.Error) failures[err.Domain] = err.Error
} }
} }
@ -54,10 +52,15 @@ func (c *Certifier) getAuthorizations(order acme.ExtendedOrder) ([]acme.Authoriz
close(resc) close(resc)
close(errc) close(errc)
return responses, failures.Join() // be careful to not return an empty failures map;
// even if empty, they become non-nil error values
if len(failures) > 0 {
return responses, failures
}
return responses, nil
} }
func (c *Certifier) deactivateAuthorizations(order acme.ExtendedOrder, force bool) { func (c *Certifier) deactivateAuthorizations(order acme.ExtendedOrder) {
for _, authzURL := range order.Authorizations { for _, authzURL := range order.Authorizations {
auth, err := c.core.Authorizations.Get(authzURL) auth, err := c.core.Authorizations.Get(authzURL)
if err != nil { if err != nil {
@ -65,7 +68,7 @@ func (c *Certifier) deactivateAuthorizations(order acme.ExtendedOrder, force boo
continue continue
} }
if auth.Status == acme.StatusValid && !force { if auth.Status == acme.StatusValid {
log.Infof("Skipping deactivating of valid auth: %s", authzURL) log.Infof("Skipping deactivating of valid auth: %s", authzURL)
continue continue
} }

View File

@ -7,7 +7,7 @@ import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"io" "io/ioutil"
"net/http" "net/http"
"strings" "strings"
"time" "time"
@ -49,44 +49,22 @@ type Resource struct {
// If you do not want that you can supply your own private key in the privateKey parameter. // If you do not want that you can supply your own private key in the privateKey parameter.
// If this parameter is non-nil it will be used instead of generating a new one. // If this parameter is non-nil it will be used instead of generating a new one.
// //
// If `Bundle` is true, the `[]byte` contains both the issuer certificate and your issued certificate as a bundle. // If bundle is true, the []byte contains both the issuer certificate and your issued certificate as a bundle.
//
// If `AlwaysDeactivateAuthorizations` is true, the authorizations are also relinquished if the obtain request was successful.
// See https://datatracker.ietf.org/doc/html/rfc8555#section-7.5.2.
type ObtainRequest struct { type ObtainRequest struct {
Domains []string Domains []string
PrivateKey crypto.PrivateKey Bundle bool
MustStaple bool PrivateKey crypto.PrivateKey
MustStaple bool
NotBefore time.Time PreferredChain string
NotAfter time.Time
Bundle bool
PreferredChain string
AlwaysDeactivateAuthorizations bool
// A string uniquely identifying a previously-issued certificate which this
// order is intended to replace.
// - https://datatracker.ietf.org/doc/html/draft-ietf-acme-ari-03#section-5
ReplacesCertID string
} }
// ObtainForCSRRequest The request to obtain a certificate matching the CSR passed into it. // ObtainForCSRRequest The request to obtain a certificate matching the CSR passed into it.
// //
// If `Bundle` is true, the `[]byte` contains both the issuer certificate and your issued certificate as a bundle. // If bundle is true, the []byte contains both the issuer certificate and your issued certificate as a bundle.
//
// If `AlwaysDeactivateAuthorizations` is true, the authorizations are also relinquished if the obtain request was successful.
// See https://datatracker.ietf.org/doc/html/rfc8555#section-7.5.2.
type ObtainForCSRRequest struct { type ObtainForCSRRequest struct {
CSR *x509.CertificateRequest CSR *x509.CertificateRequest
Bundle bool
NotBefore time.Time PreferredChain string
NotAfter time.Time
Bundle bool
PreferredChain string
AlwaysDeactivateAuthorizations bool
// A string uniquely identifying a previously-issued certificate which this
// order is intended to replace.
// - https://datatracker.ietf.org/doc/html/draft-ietf-acme-ari-03#section-5
ReplacesCertID string
} }
type resolver interface { type resolver interface {
@ -131,13 +109,7 @@ func (c *Certifier) Obtain(request ObtainRequest) (*Resource, error) {
log.Infof("[%s] acme: Obtaining SAN certificate", strings.Join(domains, ", ")) log.Infof("[%s] acme: Obtaining SAN certificate", strings.Join(domains, ", "))
} }
orderOpts := &api.OrderOptions{ order, err := c.core.Orders.New(domains)
NotBefore: request.NotBefore,
NotAfter: request.NotAfter,
ReplacesCertID: request.ReplacesCertID,
}
order, err := c.core.Orders.NewWithOptions(domains, orderOpts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -145,32 +117,33 @@ func (c *Certifier) Obtain(request ObtainRequest) (*Resource, error) {
authz, err := c.getAuthorizations(order) authz, err := c.getAuthorizations(order)
if err != nil { if err != nil {
// If any challenge fails, return. Do not generate partial SAN certificates. // If any challenge fails, return. Do not generate partial SAN certificates.
c.deactivateAuthorizations(order, request.AlwaysDeactivateAuthorizations) c.deactivateAuthorizations(order)
return nil, err return nil, err
} }
err = c.resolver.Solve(authz) err = c.resolver.Solve(authz)
if err != nil { if err != nil {
// If any challenge fails, return. Do not generate partial SAN certificates. // If any challenge fails, return. Do not generate partial SAN certificates.
c.deactivateAuthorizations(order, request.AlwaysDeactivateAuthorizations) c.deactivateAuthorizations(order)
return nil, err return nil, err
} }
log.Infof("[%s] acme: Validations succeeded; requesting certificates", strings.Join(domains, ", ")) log.Infof("[%s] acme: Validations succeeded; requesting certificates", strings.Join(domains, ", "))
failures := newObtainError() failures := make(obtainError)
cert, err := c.getForOrder(domains, order, request.Bundle, request.PrivateKey, request.MustStaple, request.PreferredChain) cert, err := c.getForOrder(domains, order, request.Bundle, request.PrivateKey, request.MustStaple, request.PreferredChain)
if err != nil { if err != nil {
for _, auth := range authz { for _, auth := range authz {
failures.Add(challenge.GetTargetedDomain(auth), err) failures[challenge.GetTargetedDomain(auth)] = err
} }
} }
if request.AlwaysDeactivateAuthorizations { // Do not return an empty failures map, because
c.deactivateAuthorizations(order, true) // it would still be a non-nil error value
if len(failures) > 0 {
return cert, failures
} }
return cert, nil
return cert, failures.Join()
} }
// ObtainForCSR tries to obtain a certificate matching the CSR passed into it. // ObtainForCSR tries to obtain a certificate matching the CSR passed into it.
@ -197,13 +170,7 @@ func (c *Certifier) ObtainForCSR(request ObtainForCSRRequest) (*Resource, error)
log.Infof("[%s] acme: Obtaining SAN certificate given a CSR", strings.Join(domains, ", ")) log.Infof("[%s] acme: Obtaining SAN certificate given a CSR", strings.Join(domains, ", "))
} }
orderOpts := &api.OrderOptions{ order, err := c.core.Orders.New(domains)
NotBefore: request.NotBefore,
NotAfter: request.NotAfter,
ReplacesCertID: request.ReplacesCertID,
}
order, err := c.core.Orders.NewWithOptions(domains, orderOpts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -211,37 +178,38 @@ func (c *Certifier) ObtainForCSR(request ObtainForCSRRequest) (*Resource, error)
authz, err := c.getAuthorizations(order) authz, err := c.getAuthorizations(order)
if err != nil { if err != nil {
// If any challenge fails, return. Do not generate partial SAN certificates. // If any challenge fails, return. Do not generate partial SAN certificates.
c.deactivateAuthorizations(order, request.AlwaysDeactivateAuthorizations) c.deactivateAuthorizations(order)
return nil, err return nil, err
} }
err = c.resolver.Solve(authz) err = c.resolver.Solve(authz)
if err != nil { if err != nil {
// If any challenge fails, return. Do not generate partial SAN certificates. // If any challenge fails, return. Do not generate partial SAN certificates.
c.deactivateAuthorizations(order, request.AlwaysDeactivateAuthorizations) c.deactivateAuthorizations(order)
return nil, err return nil, err
} }
log.Infof("[%s] acme: Validations succeeded; requesting certificates", strings.Join(domains, ", ")) log.Infof("[%s] acme: Validations succeeded; requesting certificates", strings.Join(domains, ", "))
failures := newObtainError() failures := make(obtainError)
cert, err := c.getForCSR(domains, order, request.Bundle, request.CSR.Raw, nil, request.PreferredChain) cert, err := c.getForCSR(domains, order, request.Bundle, request.CSR.Raw, nil, request.PreferredChain)
if err != nil { if err != nil {
for _, auth := range authz { for _, auth := range authz {
failures.Add(challenge.GetTargetedDomain(auth), err) failures[challenge.GetTargetedDomain(auth)] = err
} }
} }
if request.AlwaysDeactivateAuthorizations {
c.deactivateAuthorizations(order, true)
}
if cert != nil { if cert != nil {
// Add the CSR to the certificate so that it can be used for renewals. // Add the CSR to the certificate so that it can be used for renewals.
cert.CSR = certcrypto.PEMEncode(request.CSR) cert.CSR = certcrypto.PEMEncode(request.CSR)
} }
return cert, failures.Join() // Do not return an empty failures map,
// because it would still be a non-nil error value
if len(failures) > 0 {
return cert, failures
}
return cert, nil
} }
func (c *Certifier) getForOrder(domains []string, order acme.ExtendedOrder, bundle bool, privateKey crypto.PrivateKey, mustStaple bool, preferredChain string) (*Resource, error) { func (c *Certifier) getForOrder(domains []string, order acme.ExtendedOrder, bundle bool, privateKey crypto.PrivateKey, mustStaple bool, preferredChain string) (*Resource, error) {
@ -253,23 +221,16 @@ func (c *Certifier) getForOrder(domains []string, order acme.ExtendedOrder, bund
} }
} }
commonName := "" // Determine certificate name(s) based on the authorization resources
if len(domains[0]) <= 64 { commonName := domains[0]
commonName = domains[0]
}
// RFC8555 Section 7.4 "Applying for Certificate Issuance" // RFC8555 Section 7.4 "Applying for Certificate Issuance"
// https://www.rfc-editor.org/rfc/rfc8555.html#section-7.4 // https://tools.ietf.org/html/rfc8555#section-7.4
// says: // says:
// Clients SHOULD NOT make any assumptions about the sort order of // Clients SHOULD NOT make any assumptions about the sort order of
// "identifiers" or "authorizations" elements in the returned order // "identifiers" or "authorizations" elements in the returned order
// object. // object.
san := []string{commonName}
var san []string
if commonName != "" {
san = append(san, commonName)
}
for _, auth := range order.Identifiers { for _, auth := range order.Identifiers {
if auth.Value != commonName { if auth.Value != commonName {
san = append(san, auth.Value) san = append(san, auth.Value)
@ -291,14 +252,15 @@ func (c *Certifier) getForCSR(domains []string, order acme.ExtendedOrder, bundle
return nil, err return nil, err
} }
commonName := domains[0]
certRes := &Resource{ certRes := &Resource{
Domain: domains[0], Domain: commonName,
CertURL: respOrder.Certificate, CertURL: respOrder.Certificate,
PrivateKey: privateKeyPem, PrivateKey: privateKeyPem,
} }
if respOrder.Status == acme.StatusValid { if respOrder.Status == acme.StatusValid {
// if the certificate is available right away, shortcut! // if the certificate is available right away, short cut!
ok, errR := c.checkResponse(respOrder, certRes, bundle, preferredChain) ok, errR := c.checkResponse(respOrder, certRes, bundle, preferredChain)
if errR != nil { if errR != nil {
return nil, errR return nil, errR
@ -387,11 +349,6 @@ func (c *Certifier) checkResponse(order acme.ExtendedOrder, certRes *Resource, b
// Revoke takes a PEM encoded certificate or bundle and tries to revoke it at the CA. // Revoke takes a PEM encoded certificate or bundle and tries to revoke it at the CA.
func (c *Certifier) Revoke(cert []byte) error { func (c *Certifier) Revoke(cert []byte) error {
return c.RevokeWithReason(cert, nil)
}
// RevokeWithReason takes a PEM encoded certificate or bundle and tries to revoke it at the CA.
func (c *Certifier) RevokeWithReason(cert []byte, reason *uint) error {
certificates, err := certcrypto.ParsePEMBundle(cert) certificates, err := certcrypto.ParsePEMBundle(cert)
if err != nil { if err != nil {
return err return err
@ -404,24 +361,11 @@ func (c *Certifier) RevokeWithReason(cert []byte, reason *uint) error {
revokeMsg := acme.RevokeCertMessage{ revokeMsg := acme.RevokeCertMessage{
Certificate: base64.RawURLEncoding.EncodeToString(x509Cert.Raw), Certificate: base64.RawURLEncoding.EncodeToString(x509Cert.Raw),
Reason: reason,
} }
return c.core.Certificates.Revoke(revokeMsg) return c.core.Certificates.Revoke(revokeMsg)
} }
// RenewOptions options used by Certifier.RenewWithOptions.
type RenewOptions struct {
NotBefore time.Time
NotAfter time.Time
// If true, the []byte contains both the issuer certificate and your issued certificate as a bundle.
Bundle bool
PreferredChain string
AlwaysDeactivateAuthorizations bool
// Not supported for CSR request.
MustStaple bool
}
// Renew takes a Resource and tries to renew the certificate. // Renew takes a Resource and tries to renew the certificate.
// //
// If the renewal process succeeds, the new certificate will be returned in a new CertResource. // If the renewal process succeeds, the new certificate will be returned in a new CertResource.
@ -432,26 +376,7 @@ type RenewOptions struct {
// If bundle is true, the []byte contains both the issuer certificate and your issued certificate as a bundle. // If bundle is true, the []byte contains both the issuer certificate and your issued certificate as a bundle.
// //
// For private key reuse the PrivateKey property of the passed in Resource should be non-nil. // For private key reuse the PrivateKey property of the passed in Resource should be non-nil.
// Deprecated: use RenewWithOptions instead.
func (c *Certifier) Renew(certRes Resource, bundle, mustStaple bool, preferredChain string) (*Resource, error) { func (c *Certifier) Renew(certRes Resource, bundle, mustStaple bool, preferredChain string) (*Resource, error) {
return c.RenewWithOptions(certRes, &RenewOptions{
Bundle: bundle,
PreferredChain: preferredChain,
MustStaple: mustStaple,
})
}
// RenewWithOptions takes a Resource and tries to renew the certificate.
//
// If the renewal process succeeds, the new certificate will be returned in a new CertResource.
// Please be aware that this function will return a new certificate in ANY case that is not an error.
// If the server does not provide us with a new cert on a GET request to the CertURL
// this function will start a new-cert flow where a new certificate gets generated.
//
// If bundle is true, the []byte contains both the issuer certificate and your issued certificate as a bundle.
//
// For private key reuse the PrivateKey property of the passed in Resource should be non-nil.
func (c *Certifier) RenewWithOptions(certRes Resource, options *RenewOptions) (*Resource, error) {
// Input certificate is PEM encoded. // Input certificate is PEM encoded.
// Decode it here as we may need the decoded cert later on in the renewal process. // Decode it here as we may need the decoded cert later on in the renewal process.
// The input may be a bundle or a single certificate. // The input may be a bundle or a single certificate.
@ -478,17 +403,11 @@ func (c *Certifier) RenewWithOptions(certRes Resource, options *RenewOptions) (*
return nil, errP return nil, errP
} }
request := ObtainForCSRRequest{CSR: csr} return c.ObtainForCSR(ObtainForCSRRequest{
CSR: csr,
if options != nil { Bundle: bundle,
request.NotBefore = options.NotBefore PreferredChain: preferredChain,
request.NotAfter = options.NotAfter })
request.Bundle = options.Bundle
request.PreferredChain = options.PreferredChain
request.AlwaysDeactivateAuthorizations = options.AlwaysDeactivateAuthorizations
}
return c.ObtainForCSR(request)
} }
var privateKey crypto.PrivateKey var privateKey crypto.PrivateKey
@ -499,21 +418,13 @@ func (c *Certifier) RenewWithOptions(certRes Resource, options *RenewOptions) (*
} }
} }
request := ObtainRequest{ query := ObtainRequest{
Domains: certcrypto.ExtractDomains(x509Cert), Domains: certcrypto.ExtractDomains(x509Cert),
Bundle: bundle,
PrivateKey: privateKey, PrivateKey: privateKey,
MustStaple: mustStaple,
} }
return c.Obtain(query)
if options != nil {
request.MustStaple = options.MustStaple
request.NotBefore = options.NotBefore
request.NotAfter = options.NotAfter
request.Bundle = options.Bundle
request.PreferredChain = options.PreferredChain
request.AlwaysDeactivateAuthorizations = options.AlwaysDeactivateAuthorizations
}
return c.Obtain(request)
} }
// GetOCSP takes a PEM encoded cert or cert bundle returning the raw OCSP response, // GetOCSP takes a PEM encoded cert or cert bundle returning the raw OCSP response,
@ -554,7 +465,7 @@ func (c *Certifier) GetOCSP(bundle []byte) ([]byte, *ocsp.Response, error) {
} }
defer resp.Body.Close() defer resp.Body.Close()
issuerBytes, errC := io.ReadAll(http.MaxBytesReader(nil, resp.Body, maxBodySize)) issuerBytes, errC := ioutil.ReadAll(http.MaxBytesReader(nil, resp.Body, maxBodySize))
if errC != nil { if errC != nil {
return nil, nil, errC return nil, nil, errC
} }
@ -583,7 +494,7 @@ func (c *Certifier) GetOCSP(bundle []byte) ([]byte, *ocsp.Response, error) {
} }
defer resp.Body.Close() defer resp.Body.Close()
ocspResBytes, err := io.ReadAll(http.MaxBytesReader(nil, resp.Body, maxBodySize)) ocspResBytes, err := ioutil.ReadAll(http.MaxBytesReader(nil, resp.Body, maxBodySize))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -614,13 +525,8 @@ func (c *Certifier) Get(url string, bundle bool) (*Resource, error) {
return nil, err return nil, err
} }
domain, err := certcrypto.GetCertificateMainDomain(x509Certs[0])
if err != nil {
return nil, err
}
return &Resource{ return &Resource{
Domain: domain, Domain: x509Certs[0].Subject.CommonName,
Certificate: cert, Certificate: cert,
IssuerCertificate: issuer, IssuerCertificate: issuer,
CertURL: url, CertURL: url,
@ -654,11 +560,11 @@ func checkOrderStatus(order acme.ExtendedOrder) (bool, error) {
} }
} }
// https://www.rfc-editor.org/rfc/rfc8555.html#section-7.1.4 // https://tools.ietf.org/html/rfc8555#section-7.1.4
// The domain name MUST be encoded in the form in which it would appear in a certificate. // The domain name MUST be encoded in the form in which it would appear in a certificate.
// That is, it MUST be encoded according to the rules in Section 7 of [RFC5280]. // That is, it MUST be encoded according to the rules in Section 7 of [RFC5280].
// //
// https://www.rfc-editor.org/rfc/rfc5280.html#section-7 // https://tools.ietf.org/html/rfc5280#section-7
func sanitizeDomain(domains []string) []string { func sanitizeDomain(domains []string) []string {
var sanitizedDomains []string var sanitizedDomains []string
for _, domain := range domains { for _, domain := range domains {

View File

@ -1,37 +1,27 @@
package certificate package certificate
import ( import (
"errors" "bytes"
"fmt" "fmt"
"sort"
) )
type obtainError struct { // obtainError is returned when there are specific errors available per domain.
data map[string]error type obtainError map[string]error
}
func newObtainError() *obtainError { func (e obtainError) Error() string {
return &obtainError{data: make(map[string]error)} buffer := bytes.NewBufferString("error: one or more domains had a problem:\n")
}
func (e *obtainError) Add(domain string, err error) { var domains []string
e.data[domain] = err for domain := range e {
} domains = append(domains, domain)
func (e *obtainError) Join() error {
if e == nil {
return nil
} }
sort.Strings(domains)
if len(e.data) == 0 { for _, domain := range domains {
return nil buffer.WriteString(fmt.Sprintf("[%s] %s\n", domain, e[domain]))
} }
return buffer.String()
var err error
for d, e := range e.data {
err = errors.Join(err, fmt.Errorf("%s: %w", d, e))
}
return fmt.Errorf("error: one or more domains had a problem:\n%w", err)
} }
type domainError struct { type domainError struct {

View File

@ -1,129 +0,0 @@
package certificate
import (
"crypto/x509"
"encoding/asn1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/rand"
"time"
"github.com/go-acme/lego/v4/acme"
)
// RenewalInfoRequest contains the necessary renewal information.
type RenewalInfoRequest struct {
Cert *x509.Certificate
}
// RenewalInfoResponse is a wrapper around acme.RenewalInfoResponse that provides a method for determining when to renew a certificate.
type RenewalInfoResponse struct {
acme.RenewalInfoResponse
// RetryAfter header indicating the polling interval that the ACME server recommends.
// Conforming clients SHOULD query the renewalInfo URL again after the RetryAfter period has passed,
// as the server may provide a different suggestedWindow.
// https://datatracker.ietf.org/doc/html/draft-ietf-acme-ari-03#section-4.2
RetryAfter time.Duration
}
// ShouldRenewAt determines the optimal renewal time based on the current time (UTC),renewal window suggest by ARI, and the client's willingness to sleep.
// It returns a pointer to a time.Time value indicating when the renewal should be attempted or nil if deferred until the next normal wake time.
// This method implements the RECOMMENDED algorithm described in draft-ietf-acme-ari.
//
// - (4.1-11. Getting Renewal Information) https://datatracker.ietf.org/doc/draft-ietf-acme-ari/
func (r *RenewalInfoResponse) ShouldRenewAt(now time.Time, willingToSleep time.Duration) *time.Time {
// Explicitly convert all times to UTC.
now = now.UTC()
start := r.SuggestedWindow.Start.UTC()
end := r.SuggestedWindow.End.UTC()
// Select a uniform random time within the suggested window.
window := end.Sub(start)
randomDuration := time.Duration(rand.Int63n(int64(window)))
rt := start.Add(randomDuration)
// If the selected time is in the past, attempt renewal immediately.
if rt.Before(now) {
return &now
}
// Otherwise, if the client can schedule itself to attempt renewal at exactly the selected time, do so.
willingToSleepUntil := now.Add(willingToSleep)
if willingToSleepUntil.After(rt) || willingToSleepUntil.Equal(rt) {
return &rt
}
// TODO: Otherwise, if the selected time is before the next time that the client would wake up normally, attempt renewal immediately.
// Otherwise, sleep until the next normal wake time, re-check ARI, and return to Step 1.
return nil
}
// GetRenewalInfo sends a request to the ACME server's renewalInfo endpoint to obtain a suggested renewal window.
// The caller MUST provide the certificate and issuer certificate for the certificate they wish to renew.
// The caller should attempt to renew the certificate at the time indicated by the ShouldRenewAt method of the returned RenewalInfoResponse object.
//
// Note: this endpoint is part of a draft specification, not all ACME servers will implement it.
// This method will return api.ErrNoARI if the server does not advertise a renewal info endpoint.
//
// https://datatracker.ietf.org/doc/draft-ietf-acme-ari
func (c *Certifier) GetRenewalInfo(req RenewalInfoRequest) (*RenewalInfoResponse, error) {
certID, err := MakeARICertID(req.Cert)
if err != nil {
return nil, fmt.Errorf("error making certID: %w", err)
}
resp, err := c.core.Certificates.GetRenewalInfo(certID)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var info RenewalInfoResponse
err = json.NewDecoder(resp.Body).Decode(&info)
if err != nil {
return nil, err
}
if retry := resp.Header.Get("Retry-After"); retry != "" {
info.RetryAfter, err = time.ParseDuration(retry + "s")
if err != nil {
return nil, err
}
}
return &info, nil
}
// MakeARICertID constructs a certificate identifier as described in draft-ietf-acme-ari-03, section 4.1.
func MakeARICertID(leaf *x509.Certificate) (string, error) {
if leaf == nil {
return "", errors.New("leaf certificate is nil")
}
// Marshal the Serial Number into DER.
der, err := asn1.Marshal(leaf.SerialNumber)
if err != nil {
return "", err
}
// Check if the DER encoded bytes are sufficient (at least 3 bytes: tag,
// length, and value).
if len(der) < 3 {
return "", errors.New("invalid DER encoding of serial number")
}
// Extract only the integer bytes from the DER encoded Serial Number
// Skipping the first 2 bytes (tag and length).
serial := base64.RawURLEncoding.EncodeToString(der[2:])
// Convert the Authority Key Identifier to base64url encoding without
// padding.
aki := base64.RawURLEncoding.EncodeToString(leaf.AuthorityKeyId)
// Construct the final identifier by concatenating AKI and Serial Number.
return fmt.Sprintf("%s.%s", aki, serial), nil
}

View File

@ -10,15 +10,15 @@ import (
type Type string type Type string
const ( const (
// HTTP01 is the "http-01" ACME challenge https://www.rfc-editor.org/rfc/rfc8555.html#section-8.3 // HTTP01 is the "http-01" ACME challenge https://tools.ietf.org/html/rfc8555#section-8.3
// Note: ChallengePath returns the URL path to fulfill this challenge. // Note: ChallengePath returns the URL path to fulfill this challenge.
HTTP01 = Type("http-01") HTTP01 = Type("http-01")
// DNS01 is the "dns-01" ACME challenge https://www.rfc-editor.org/rfc/rfc8555.html#section-8.4 // DNS01 is the "dns-01" ACME challenge https://tools.ietf.org/html/rfc8555#section-8.4
// Note: GetRecord returns a DNS record which will fulfill this challenge. // Note: GetRecord returns a DNS record which will fulfill this challenge.
DNS01 = Type("dns-01") DNS01 = Type("dns-01")
// TLSALPN01 is the "tls-alpn-01" ACME challenge https://www.rfc-editor.org/rfc/rfc8737.html // TLSALPN01 is the "tls-alpn-01" ACME challenge https://tools.ietf.org/html/draft-ietf-acme-tls-alpn-07
TLSALPN01 = Type("tls-alpn-01") TLSALPN01 = Type("tls-alpn-01")
) )

View File

@ -1,16 +1,12 @@
package dns01 package dns01
import ( import "github.com/miekg/dns"
"strings"
"github.com/miekg/dns"
)
// Update FQDN with CNAME if any. // Update FQDN with CNAME if any.
func updateDomainWithCName(r *dns.Msg, fqdn string) string { func updateDomainWithCName(r *dns.Msg, fqdn string) string {
for _, rr := range r.Answer { for _, rr := range r.Answer {
if cn, ok := rr.(*dns.CNAME); ok { if cn, ok := rr.(*dns.CNAME); ok {
if strings.EqualFold(cn.Hdr.Name, fqdn) { if cn.Hdr.Name == fqdn {
return cn.Target return cn.Target
} }
} }

View File

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"os" "os"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/go-acme/lego/v4/acme" "github.com/go-acme/lego/v4/acme"
@ -115,7 +114,7 @@ func (c *Challenge) Solve(authz acme.Authorization) error {
return err return err
} }
info := GetChallengeInfo(authz.Identifier.Value, keyAuth) fqdn, value := GetRecord(authz.Identifier.Value, keyAuth)
var timeout, interval time.Duration var timeout, interval time.Duration
switch provider := c.provider.(type) { switch provider := c.provider.(type) {
@ -125,12 +124,12 @@ func (c *Challenge) Solve(authz acme.Authorization) error {
timeout, interval = DefaultPropagationTimeout, DefaultPollingInterval timeout, interval = DefaultPropagationTimeout, DefaultPollingInterval
} }
log.Infof("[%s] acme: Checking DNS record propagation. [nameservers=%s]", domain, strings.Join(recursiveNameservers, ",")) log.Infof("[%s] acme: Checking DNS record propagation using %+v", domain, recursiveNameservers)
time.Sleep(interval) time.Sleep(interval)
err = wait.For("propagation", timeout, interval, func() (bool, error) { err = wait.For("propagation", timeout, interval, func() (bool, error) {
stop, errP := c.preCheck.call(domain, info.EffectiveFQDN, info.Value) stop, errP := c.preCheck.call(domain, fqdn, value)
if !stop || errP != nil { if !stop || errP != nil {
log.Infof("[%s] acme: Waiting for DNS record propagation.", domain) log.Infof("[%s] acme: Waiting for DNS record propagation.", domain)
} }
@ -173,67 +172,19 @@ type sequential interface {
} }
// GetRecord returns a DNS record which will fulfill the `dns-01` challenge. // GetRecord returns a DNS record which will fulfill the `dns-01` challenge.
// Deprecated: use GetChallengeInfo instead.
func GetRecord(domain, keyAuth string) (fqdn, value string) { func GetRecord(domain, keyAuth string) (fqdn, value string) {
info := GetChallengeInfo(domain, keyAuth)
return info.EffectiveFQDN, info.Value
}
// ChallengeInfo contains the information use to create the TXT record.
type ChallengeInfo struct {
// FQDN is the full-qualified challenge domain (i.e. `_acme-challenge.[domain].`)
FQDN string
// EffectiveFQDN contains the resulting FQDN after the CNAMEs resolutions.
EffectiveFQDN string
// Value contains the value for the TXT record.
Value string
}
// GetChallengeInfo returns information used to create a DNS record which will fulfill the `dns-01` challenge.
func GetChallengeInfo(domain, keyAuth string) ChallengeInfo {
keyAuthShaBytes := sha256.Sum256([]byte(keyAuth)) keyAuthShaBytes := sha256.Sum256([]byte(keyAuth))
// base64URL encoding without padding // base64URL encoding without padding
value := base64.RawURLEncoding.EncodeToString(keyAuthShaBytes[:sha256.Size]) value = base64.RawURLEncoding.EncodeToString(keyAuthShaBytes[:sha256.Size])
fqdn = fmt.Sprintf("_acme-challenge.%s.", domain)
ok, _ := strconv.ParseBool(os.Getenv("LEGO_DISABLE_CNAME_SUPPORT")) if ok, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_CNAME_SUPPORT")); ok {
return ChallengeInfo{
Value: value,
FQDN: getChallengeFQDN(domain, false),
EffectiveFQDN: getChallengeFQDN(domain, !ok),
}
}
func getChallengeFQDN(domain string, followCNAME bool) string {
fqdn := fmt.Sprintf("_acme-challenge.%s.", domain)
if !followCNAME {
return fqdn
}
// recursion counter so it doesn't spin out of control
for range 50 {
// Keep following CNAMEs
r, err := dnsQuery(fqdn, dns.TypeCNAME, recursiveNameservers, true) r, err := dnsQuery(fqdn, dns.TypeCNAME, recursiveNameservers, true)
// Check if the domain has CNAME then return that
if err != nil || r.Rcode != dns.RcodeSuccess { if err == nil && r.Rcode == dns.RcodeSuccess {
// No more CNAME records to follow, exit fqdn = updateDomainWithCName(r, fqdn)
break
} }
// Check if the domain has CNAME then use that
cname := updateDomainWithCName(r, fqdn)
if cname == fqdn {
break
}
log.Infof("Found CNAME entry for %q: %q", fqdn, cname)
fqdn = cname
} }
return fqdn return
} }

View File

@ -8,7 +8,7 @@ import (
) )
const ( const (
dnsTemplate = `%s %d IN TXT %q` dnsTemplate = `%s %d IN TXT "%s"`
) )
// DNSProviderManual is an implementation of the ChallengeProvider interface. // DNSProviderManual is an implementation of the ChallengeProvider interface.
@ -21,36 +21,33 @@ func NewDNSProviderManual() (*DNSProviderManual, error) {
// Present prints instructions for manually creating the TXT record. // Present prints instructions for manually creating the TXT record.
func (*DNSProviderManual) Present(domain, token, keyAuth string) error { func (*DNSProviderManual) Present(domain, token, keyAuth string) error {
info := GetChallengeInfo(domain, keyAuth) fqdn, value := GetRecord(domain, keyAuth)
authZone, err := FindZoneByFqdn(info.EffectiveFQDN) authZone, err := FindZoneByFqdn(fqdn)
if err != nil { if err != nil {
return fmt.Errorf("manual: could not find zone: %w", err) return err
} }
fmt.Printf("lego: Please create the following TXT record in your %s zone:\n", authZone) fmt.Printf("lego: Please create the following TXT record in your %s zone:\n", authZone)
fmt.Printf(dnsTemplate+"\n", info.EffectiveFQDN, DefaultTTL, info.Value) fmt.Printf(dnsTemplate+"\n", fqdn, DefaultTTL, value)
fmt.Printf("lego: Press 'Enter' when you are done\n") fmt.Printf("lego: Press 'Enter' when you are done\n")
_, err = bufio.NewReader(os.Stdin).ReadBytes('\n') _, err = bufio.NewReader(os.Stdin).ReadBytes('\n')
if err != nil {
return fmt.Errorf("manual: %w", err)
}
return nil return err
} }
// CleanUp prints instructions for manually removing the TXT record. // CleanUp prints instructions for manually removing the TXT record.
func (*DNSProviderManual) CleanUp(domain, token, keyAuth string) error { func (*DNSProviderManual) CleanUp(domain, token, keyAuth string) error {
info := GetChallengeInfo(domain, keyAuth) fqdn, _ := GetRecord(domain, keyAuth)
authZone, err := FindZoneByFqdn(info.EffectiveFQDN) authZone, err := FindZoneByFqdn(fqdn)
if err != nil { if err != nil {
return fmt.Errorf("manual: could not find zone: %w", err) return err
} }
fmt.Printf("lego: You can now remove this TXT record from your %s zone:\n", authZone) fmt.Printf("lego: You can now remove this TXT record from your %s zone:\n", authZone)
fmt.Printf(dnsTemplate+"\n", info.EffectiveFQDN, DefaultTTL, "...") fmt.Printf(dnsTemplate+"\n", fqdn, DefaultTTL, "...")
return nil return nil
} }

View File

@ -1,24 +0,0 @@
package dns01
import (
"fmt"
"strings"
"github.com/miekg/dns"
)
// ExtractSubDomain extracts the subdomain part from a domain and a zone.
func ExtractSubDomain(domain, zone string) (string, error) {
canonDomain := dns.Fqdn(domain)
canonZone := dns.Fqdn(zone)
if canonDomain == canonZone {
return "", fmt.Errorf("no subdomain because the domain and the zone are identical: %s", canonDomain)
}
if !dns.IsSubDomain(canonZone, canonDomain) {
return "", fmt.Errorf("%s is not a subdomain of %s", canonDomain, canonZone)
}
return strings.TrimSuffix(canonDomain, "."+canonZone), nil
}

View File

@ -4,9 +4,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"os"
"slices"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -16,6 +13,9 @@ import (
const defaultResolvConf = "/etc/resolv.conf" const defaultResolvConf = "/etc/resolv.conf"
// dnsTimeout is used to override the default DNS timeout of 10 seconds.
var dnsTimeout = 10 * time.Second
var ( var (
fqdnSoaCache = map[string]*soaCacheEntry{} fqdnSoaCache = map[string]*soaCacheEntry{}
muFqdnSoaCache sync.Mutex muFqdnSoaCache sync.Mutex
@ -99,12 +99,12 @@ func lookupNameservers(fqdn string) ([]string, error) {
zone, err := FindZoneByFqdn(fqdn) zone, err := FindZoneByFqdn(fqdn)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not find zone: %w", err) return nil, fmt.Errorf("could not determine the zone: %w", err)
} }
r, err := dnsQuery(zone, dns.TypeNS, recursiveNameservers, true) r, err := dnsQuery(zone, dns.TypeNS, recursiveNameservers, true)
if err != nil { if err != nil {
return nil, fmt.Errorf("NS call failed: %w", err) return nil, err
} }
for _, rr := range r.Answer { for _, rr := range r.Answer {
@ -116,8 +116,7 @@ func lookupNameservers(fqdn string) ([]string, error) {
if len(authoritativeNss) > 0 { if len(authoritativeNss) > 0 {
return authoritativeNss, nil return authoritativeNss, nil
} }
return nil, errors.New("could not determine authoritative nameservers")
return nil, fmt.Errorf("[zone=%s] could not determine authoritative nameservers", zone)
} }
// FindPrimaryNsByFqdn determines the primary nameserver of the zone apex for the given fqdn // FindPrimaryNsByFqdn determines the primary nameserver of the zone apex for the given fqdn
@ -131,7 +130,7 @@ func FindPrimaryNsByFqdn(fqdn string) (string, error) {
func FindPrimaryNsByFqdnCustom(fqdn string, nameservers []string) (string, error) { func FindPrimaryNsByFqdnCustom(fqdn string, nameservers []string) (string, error) {
soa, err := lookupSoaByFqdn(fqdn, nameservers) soa, err := lookupSoaByFqdn(fqdn, nameservers)
if err != nil { if err != nil {
return "", fmt.Errorf("[fqdn=%s] %w", fqdn, err) return "", err
} }
return soa.primaryNs, nil return soa.primaryNs, nil
} }
@ -147,7 +146,7 @@ func FindZoneByFqdn(fqdn string) (string, error) {
func FindZoneByFqdnCustom(fqdn string, nameservers []string) (string, error) { func FindZoneByFqdnCustom(fqdn string, nameservers []string) (string, error) {
soa, err := lookupSoaByFqdn(fqdn, nameservers) soa, err := lookupSoaByFqdn(fqdn, nameservers)
if err != nil { if err != nil {
return "", fmt.Errorf("[fqdn=%s] %w", fqdn, err) return "", err
} }
return soa.zone, nil return soa.zone, nil
} }
@ -172,35 +171,35 @@ func lookupSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error)
func fetchSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) { func fetchSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) {
var err error var err error
var r *dns.Msg var in *dns.Msg
labelIndexes := dns.Split(fqdn) labelIndexes := dns.Split(fqdn)
for _, index := range labelIndexes { for _, index := range labelIndexes {
domain := fqdn[index:] domain := fqdn[index:]
r, err = dnsQuery(domain, dns.TypeSOA, nameservers, true) in, err = dnsQuery(domain, dns.TypeSOA, nameservers, true)
if err != nil { if err != nil {
continue continue
} }
if r == nil { if in == nil {
continue continue
} }
switch r.Rcode { switch in.Rcode {
case dns.RcodeSuccess: case dns.RcodeSuccess:
// Check if we got a SOA RR in the answer section // Check if we got a SOA RR in the answer section
if len(r.Answer) == 0 { if len(in.Answer) == 0 {
continue continue
} }
// CNAME records cannot/should not exist at the root of a zone. // CNAME records cannot/should not exist at the root of a zone.
// So we skip a domain when a CNAME is found. // So we skip a domain when a CNAME is found.
if dnsMsgContainsCNAME(r) { if dnsMsgContainsCNAME(in) {
continue continue
} }
for _, ans := range r.Answer { for _, ans := range in.Answer {
if soa, ok := ans.(*dns.SOA); ok { if soa, ok := ans.(*dns.SOA); ok {
return newSoaCacheEntry(soa), nil return newSoaCacheEntry(soa), nil
} }
@ -209,46 +208,36 @@ func fetchSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) {
// NXDOMAIN // NXDOMAIN
default: default:
// Any response code other than NOERROR and NXDOMAIN is treated as error // Any response code other than NOERROR and NXDOMAIN is treated as error
return nil, &DNSError{Message: fmt.Sprintf("unexpected response for '%s'", domain), MsgOut: r} return nil, fmt.Errorf("unexpected response code '%s' for %s", dns.RcodeToString[in.Rcode], domain)
} }
} }
return nil, &DNSError{Message: fmt.Sprintf("could not find the start of authority for '%s'", fqdn), MsgOut: r, Err: err} return nil, fmt.Errorf("could not find the start of authority for %s%s", fqdn, formatDNSError(in, err))
} }
// dnsMsgContainsCNAME checks for a CNAME answer in msg. // dnsMsgContainsCNAME checks for a CNAME answer in msg.
func dnsMsgContainsCNAME(msg *dns.Msg) bool { func dnsMsgContainsCNAME(msg *dns.Msg) bool {
return slices.ContainsFunc(msg.Answer, func(rr dns.RR) bool { for _, ans := range msg.Answer {
_, ok := rr.(*dns.CNAME) if _, ok := ans.(*dns.CNAME); ok {
return ok return true
}) }
}
return false
} }
func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) { func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
m := createDNSMsg(fqdn, rtype, recursive) m := createDNSMsg(fqdn, rtype, recursive)
if len(nameservers) == 0 { var in *dns.Msg
return nil, &DNSError{Message: "empty list of nameservers"}
}
var r *dns.Msg
var err error var err error
var errAll error
for _, ns := range nameservers { for _, ns := range nameservers {
r, err = sendDNSQuery(m, ns) in, err = sendDNSQuery(m, ns)
if err == nil && len(r.Answer) > 0 { if err == nil && len(in.Answer) > 0 {
break break
} }
errAll = errors.Join(errAll, err)
} }
return in, err
if err != nil {
return r, errAll
}
return r, nil
} }
func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg { func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
@ -264,84 +253,32 @@ func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
} }
func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) { func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
if ok, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_DNS_TCP_ONLY")); ok { udp := &dns.Client{Net: "udp", Timeout: dnsTimeout}
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout} in, _, err := udp.Exchange(m, ns)
r, _, err := tcp.Exchange(m, ns)
if err != nil {
return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
}
return r, nil if in != nil && in.Truncated {
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
// If the TCP request succeeds, the err will reset to nil
in, _, err = tcp.Exchange(m, ns)
} }
udp := &dns.Client{Net: "udp", Timeout: dnsTimeout} return in, err
r, _, err := udp.Exchange(m, ns) }
if r != nil && r.Truncated { func formatDNSError(msg *dns.Msg, err error) string {
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout} var parts []string
// If the TCP request succeeds, the "err" will reset to nil
r, _, err = tcp.Exchange(m, ns) if msg != nil {
parts = append(parts, dns.RcodeToString[msg.Rcode])
} }
if err != nil { if err != nil {
return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err} parts = append(parts, err.Error())
} }
return r, nil if len(parts) > 0 {
} return ": " + strings.Join(parts, " ")
}
// DNSError error related to DNS calls.
type DNSError struct { return ""
Message string
NS string
MsgIn *dns.Msg
MsgOut *dns.Msg
Err error
}
func (d *DNSError) Error() string {
var details []string
if d.NS != "" {
details = append(details, "ns="+d.NS)
}
if d.MsgIn != nil && len(d.MsgIn.Question) > 0 {
details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgIn.Question)))
}
if d.MsgOut != nil {
if d.MsgIn == nil || len(d.MsgIn.Question) == 0 {
details = append(details, fmt.Sprintf("question='%s'", formatQuestions(d.MsgOut.Question)))
}
details = append(details, "code="+dns.RcodeToString[d.MsgOut.Rcode])
}
msg := "DNS error"
if d.Message != "" {
msg = d.Message
}
if d.Err != nil {
msg += ": " + d.Err.Error()
}
if len(details) > 0 {
msg += " [" + strings.Join(details, ", ") + "]"
}
return msg
}
func (d *DNSError) Unwrap() error {
return d.Err
}
func formatQuestions(questions []dns.Question) string {
var parts []string
for _, question := range questions {
parts = append(parts, strings.ReplaceAll(strings.TrimPrefix(question.String(), ";"), "\t", " "))
}
return strings.Join(parts, ";")
} }

View File

@ -1,8 +0,0 @@
//go:build !windows
package dns01
import "time"
// dnsTimeout is used to override the default DNS timeout of 10 seconds.
var dnsTimeout = 10 * time.Second

View File

@ -1,8 +0,0 @@
//go:build windows
package dns01
import "time"
// dnsTimeout is used to override the default DNS timeout of 20 seconds.
var dnsTimeout = 20 * time.Second

View File

@ -23,17 +23,14 @@ import (
// RFC7239 has standardized the different forwarding headers into a single header named Forwarded. // RFC7239 has standardized the different forwarding headers into a single header named Forwarded.
// The header value has a different format, so you should use forwardedMatcher // The header value has a different format, so you should use forwardedMatcher
// when the http01.ProviderServer operates behind a RFC7239 compatible proxy. // when the http01.ProviderServer operates behind a RFC7239 compatible proxy.
// https://www.rfc-editor.org/rfc/rfc7239.html // https://tools.ietf.org/html/rfc7239
// //
// Note: RFC7239 also reminds us, "that an HTTP list [...] may be split over multiple header fields" (section 7.1), // Note: RFC7239 also reminds us, "that an HTTP list [...] may be split over multiple header fields" (section 7.1),
// meaning that // meaning that
// // X-Header: a
// X-Header: a // X-Header: b
// X-Header: b
//
// is equal to // is equal to
// // X-Header: a, b
// X-Header: a, b
// //
// All matcher implementations (explicitly not excluding arbitraryMatcher!) // All matcher implementations (explicitly not excluding arbitraryMatcher!)
// have in common that they only match against the first value in such lists. // have in common that they only match against the first value in such lists.
@ -69,7 +66,7 @@ func (m arbitraryMatcher) matches(r *http.Request, domain string) bool {
} }
// forwardedMatcher checks whether the Forwarded header contains a "host" element starting with a domain name. // forwardedMatcher checks whether the Forwarded header contains a "host" element starting with a domain name.
// See https://www.rfc-editor.org/rfc/rfc7239.html for details. // See https://tools.ietf.org/html/rfc7239 for details.
type forwardedMatcher struct{} type forwardedMatcher struct{}
func (m *forwardedMatcher) name() string { func (m *forwardedMatcher) name() string {

View File

@ -2,11 +2,9 @@ package http01
import ( import (
"fmt" "fmt"
"io/fs"
"net" "net"
"net/http" "net/http"
"net/textproto" "net/textproto"
"os"
"strings" "strings"
"github.com/go-acme/lego/v4/log" "github.com/go-acme/lego/v4/log"
@ -16,11 +14,8 @@ import (
// It may be instantiated without using the NewProviderServer function if // It may be instantiated without using the NewProviderServer function if
// you want only to use the default values. // you want only to use the default values.
type ProviderServer struct { type ProviderServer struct {
address string iface string
network string // must be valid argument to net.Listen port string
socketMode fs.FileMode
matcher domainMatcher matcher domainMatcher
done chan bool done chan bool
listener net.Listener listener net.Listener
@ -34,34 +29,24 @@ func NewProviderServer(iface, port string) *ProviderServer {
port = "80" port = "80"
} }
return &ProviderServer{network: "tcp", address: net.JoinHostPort(iface, port), matcher: &hostMatcher{}} return &ProviderServer{iface: iface, port: port, matcher: &hostMatcher{}}
}
func NewUnixProviderServer(socketPath string, mode fs.FileMode) *ProviderServer {
return &ProviderServer{network: "unix", address: socketPath, socketMode: mode, matcher: &hostMatcher{}}
} }
// Present starts a web server and makes the token available at `ChallengePath(token)` for web requests. // Present starts a web server and makes the token available at `ChallengePath(token)` for web requests.
func (s *ProviderServer) Present(domain, token, keyAuth string) error { func (s *ProviderServer) Present(domain, token, keyAuth string) error {
var err error var err error
s.listener, err = net.Listen(s.network, s.GetAddress()) s.listener, err = net.Listen("tcp", s.GetAddress())
if err != nil { if err != nil {
return fmt.Errorf("could not start HTTP server for challenge: %w", err) return fmt.Errorf("could not start HTTP server for challenge: %w", err)
} }
if s.network == "unix" {
if err = os.Chmod(s.address, s.socketMode); err != nil {
return fmt.Errorf("chmod %s: %w", s.address, err)
}
}
s.done = make(chan bool) s.done = make(chan bool)
go s.serve(domain, token, keyAuth) go s.serve(domain, token, keyAuth)
return nil return nil
} }
func (s *ProviderServer) GetAddress() string { func (s *ProviderServer) GetAddress() string {
return s.address return net.JoinHostPort(s.iface, s.port)
} }
// CleanUp closes the HTTP server and removes the token from `ChallengePath(token)`. // CleanUp closes the HTTP server and removes the token from `ChallengePath(token)`.
@ -84,7 +69,7 @@ func (s *ProviderServer) CleanUp(domain, token, keyAuth string) error {
// //
// The exact behavior depends on the value of headerName: // The exact behavior depends on the value of headerName:
// - "" (the empty string) and "Host" will restore the default and only check the Host header // - "" (the empty string) and "Host" will restore the default and only check the Host header
// - "Forwarded" will look for a Forwarded header, and inspect it according to https://www.rfc-editor.org/rfc/rfc7239.html // - "Forwarded" will look for a Forwarded header, and inspect it according to https://tools.ietf.org/html/rfc7239
// - any other value will check the header value with the same name. // - any other value will check the header value with the same name.
func (s *ProviderServer) SetProxyHeader(headerName string) { func (s *ProviderServer) SetProxyHeader(headerName string) {
switch h := textproto.CanonicalMIMEHeaderKey(headerName); h { switch h := textproto.CanonicalMIMEHeaderKey(headerName); h {
@ -100,7 +85,7 @@ func (s *ProviderServer) SetProxyHeader(headerName string) {
func (s *ProviderServer) serve(domain, token, keyAuth string) { func (s *ProviderServer) serve(domain, token, keyAuth string) {
path := ChallengePath(token) path := ChallengePath(token)
// The incoming request will be validated to prevent DNS rebind attacks. // The incoming request must will be validated to prevent DNS rebind attacks.
// We only respond with the keyAuth, when we're receiving a GET requests with // We only respond with the keyAuth, when we're receiving a GET requests with
// the "Host" header matching the domain (the latter is configurable though SetProxyHeader). // the "Host" header matching the domain (the latter is configurable though SetProxyHeader).
mux := http.NewServeMux() mux := http.NewServeMux()
@ -114,7 +99,7 @@ func (s *ProviderServer) serve(domain, token, keyAuth string) {
} }
log.Infof("[%s] Served key authentication", domain) log.Infof("[%s] Served key authentication", domain)
} else { } else {
log.Warnf("Received request for domain %s with method %s but the domain did not match any challenge. Please ensure you are passing the %s header properly.", r.Host, r.Method, s.matcher.name()) log.Warnf("Received request for domain %s with method %s but the domain did not match any challenge. Please ensure your are passing the %s header properly.", r.Host, r.Method, s.matcher.name())
_, err := w.Write([]byte("TEST")) _, err := w.Write([]byte("TEST"))
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)

View File

@ -19,7 +19,7 @@ func (e obtainError) Error() string {
sort.Strings(domains) sort.Strings(domains)
for _, domain := range domains { for _, domain := range domains {
_, _ = fmt.Fprintf(buffer, "[%s] %s\n", domain, e[domain]) buffer.WriteString(fmt.Sprintf("[%s] %s\n", domain, e[domain]))
} }
return buffer.String() return buffer.String()
} }

View File

@ -128,7 +128,7 @@ func sequentialSolve(authSolvers []*selectedAuthSolver, failures obtainError) {
} }
func parallelSolve(authSolvers []*selectedAuthSolver, failures obtainError) { func parallelSolve(authSolvers []*selectedAuthSolver, failures obtainError) {
// For all valid preSolvers, first submit the challenges, so they have max time to propagate // For all valid preSolvers, first submit the challenges so they have max time to propagate
for _, authSolver := range authSolvers { for _, authSolver := range authSolvers {
authz := authSolver.authz authz := authSolver.authz
if solvr, ok := authSolver.solver.(preSolver); ok { if solvr, ok := authSolver.solver.(preSolver); ok {

View File

@ -1,6 +1,7 @@
package resolver package resolver
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"sort" "sort"
@ -53,7 +54,7 @@ func (c *SolverManager) SetDNS01Provider(p challenge.Provider, opts ...dns01.Cha
return nil return nil
} }
// Remove removes a challenge type from the available solvers. // Remove Remove a challenge type from the available solvers.
func (c *SolverManager) Remove(chlgType challenge.Type) { func (c *SolverManager) Remove(chlgType challenge.Type) {
delete(c.solvers, chlgType) delete(c.solvers, chlgType)
} }
@ -106,17 +107,21 @@ func validate(core *api.Core, domain string, chlg acme.Challenge) error {
bo.MaxInterval = 10 * initialInterval bo.MaxInterval = 10 * initialInterval
bo.MaxElapsedTime = 100 * initialInterval bo.MaxElapsedTime = 100 * initialInterval
ctx, cancel := context.WithCancel(context.Background())
// After the path is sent, the ACME server will access our server. // After the path is sent, the ACME server will access our server.
// Repeatedly check the server for an updated status on our request. // Repeatedly check the server for an updated status on our request.
operation := func() error { operation := func() error {
authz, err := core.Authorizations.Get(chlng.AuthorizationURL) authz, err := core.Authorizations.Get(chlng.AuthorizationURL)
if err != nil { if err != nil {
return backoff.Permanent(err) cancel()
return err
} }
valid, err := checkAuthorizationStatus(authz) valid, err := checkAuthorizationStatus(authz)
if err != nil { if err != nil {
return backoff.Permanent(err) cancel()
return err
} }
if valid { if valid {
@ -127,7 +132,7 @@ func validate(core *api.Core, domain string, chlg acme.Challenge) error {
return errors.New("the server didn't respond to our request") return errors.New("the server didn't respond to our request")
} }
return backoff.Retry(operation, bo) return backoff.Retry(operation, backoff.WithContext(bo, ctx))
} }
func checkChallengeStatus(chlng acme.ExtendedChallenge) (bool, error) { func checkChallengeStatus(chlng acme.ExtendedChallenge) (bool, error) {

View File

@ -16,7 +16,7 @@ import (
) )
// idPeAcmeIdentifierV1 is the SMI Security for PKIX Certification Extension OID referencing the ACME extension. // idPeAcmeIdentifierV1 is the SMI Security for PKIX Certification Extension OID referencing the ACME extension.
// Reference: https://www.rfc-editor.org/rfc/rfc8737.html#section-6.1 // Reference: https://tools.ietf.org/html/draft-ietf-acme-tls-alpn-07#section-6.1
var idPeAcmeIdentifierV1 = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31} var idPeAcmeIdentifierV1 = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31}
type ValidateFunc func(core *api.Core, domain string, chlng acme.Challenge) error type ValidateFunc func(core *api.Core, domain string, chlng acme.Challenge) error
@ -83,7 +83,7 @@ func ChallengeBlocks(domain, keyAuth string) ([]byte, []byte, error) {
// Add the keyAuth digest as the acmeValidation-v1 extension // Add the keyAuth digest as the acmeValidation-v1 extension
// (marked as critical such that it won't be used by non-ACME software). // (marked as critical such that it won't be used by non-ACME software).
// Reference: https://www.rfc-editor.org/rfc/rfc8737.html#section-3 // Reference: https://tools.ietf.org/html/draft-ietf-acme-tls-alpn-07#section-3
extensions := []pkix.Extension{ extensions := []pkix.Extension{
{ {
Id: idPeAcmeIdentifierV1, Id: idPeAcmeIdentifierV1,

View File

@ -40,7 +40,7 @@ func (s *ProviderServer) GetAddress() string {
return net.JoinHostPort(s.iface, s.port) return net.JoinHostPort(s.iface, s.port)
} }
// Present generates a certificate with an SHA-256 digest of the keyAuth provided // Present generates a certificate with a SHA-256 digest of the keyAuth provided
// as the acmeValidation-v1 extension value to conform to the ACME-TLS-ALPN spec. // as the acmeValidation-v1 extension value to conform to the ACME-TLS-ALPN spec.
func (s *ProviderServer) Present(domain, token, keyAuth string) error { func (s *ProviderServer) Present(domain, token, keyAuth string) error {
if s.port == "" { if s.port == "" {
@ -61,7 +61,7 @@ func (s *ProviderServer) Present(domain, token, keyAuth string) error {
// We must set that the `acme-tls/1` application level protocol is supported // We must set that the `acme-tls/1` application level protocol is supported
// so that the protocol negotiation can succeed. Reference: // so that the protocol negotiation can succeed. Reference:
// https://www.rfc-editor.org/rfc/rfc8737.html#section-6.2 // https://tools.ietf.org/html/draft-ietf-acme-tls-alpn-07#section-6.2
tlsConf.NextProtos = []string{ACMETLS1Protocol} tlsConf.NextProtos = []string{ACMETLS1Protocol}
// Create the listener with the created tls.Config. // Create the listener with the created tls.Config.

View File

@ -4,11 +4,10 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"os" "os"
"strconv"
"strings"
"time" "time"
"github.com/go-acme/lego/v4/certcrypto" "github.com/go-acme/lego/v4/certcrypto"
@ -18,18 +17,13 @@ import (
const ( const (
// caCertificatesEnvVar is the environment variable name that can be used to // caCertificatesEnvVar is the environment variable name that can be used to
// specify the path to PEM encoded CA Certificates that can be used to // specify the path to PEM encoded CA Certificates that can be used to
// authenticate an ACME server with an HTTPS certificate not issued by a CA in // authenticate an ACME server with a HTTPS certificate not issued by a CA in
// the system-wide trusted root list. // the system-wide trusted root list.
// Multiple file paths can be added by using os.PathListSeparator as a separator.
caCertificatesEnvVar = "LEGO_CA_CERTIFICATES" caCertificatesEnvVar = "LEGO_CA_CERTIFICATES"
// caSystemCertPool is the environment variable name that can be used to define
// if the certificates pool must use a copy of the system cert pool.
caSystemCertPool = "LEGO_CA_SYSTEM_CERT_POOL"
// caServerNameEnvVar is the environment variable name that can be used to // caServerNameEnvVar is the environment variable name that can be used to
// specify the CA server name that can be used to // specify the CA server name that can be used to
// authenticate an ACME server with an HTTPS certificate not issued by a CA in // authenticate an ACME server with a HTTPS certificate not issued by a CA in
// the system-wide trusted root list. // the system-wide trusted root list.
caServerNameEnvVar = "LEGO_CA_SERVER_NAME" caServerNameEnvVar = "LEGO_CA_SERVER_NAME"
@ -88,44 +82,23 @@ func createDefaultHTTPClient() *http.Client {
} }
// initCertPool creates a *x509.CertPool populated with the PEM certificates // initCertPool creates a *x509.CertPool populated with the PEM certificates
// found in the filepath specified in the caCertificatesEnvVar OS environment variable. // found in the filepath specified in the caCertificatesEnvVar OS environment
// If the caCertificatesEnvVar is not set then initCertPool will return nil. // variable. If the caCertificatesEnvVar is not set then initCertPool will
// If there is an error creating a *x509.CertPool from the provided caCertificatesEnvVar value then initCertPool will panic. // return nil. If there is an error creating a *x509.CertPool from the provided
// If the caSystemCertPool is set to a "truthy value" (`1`, `t`, `T`, `TRUE`, `true`, `True`) then a copy of system cert pool will be used. // caCertificatesEnvVar value then initCertPool will panic.
// caSystemCertPool requires caCertificatesEnvVar to be set.
func initCertPool() *x509.CertPool { func initCertPool() *x509.CertPool {
customCACertsPath := os.Getenv(caCertificatesEnvVar) if customCACertsPath := os.Getenv(caCertificatesEnvVar); customCACertsPath != "" {
if customCACertsPath == "" { customCAs, err := ioutil.ReadFile(customCACertsPath)
return nil
}
certPool := getCertPool()
for _, customPath := range strings.Split(customCACertsPath, string(os.PathListSeparator)) {
customCAs, err := os.ReadFile(customPath)
if err != nil { if err != nil {
panic(fmt.Sprintf("error reading %s=%q: %v", panic(fmt.Sprintf("error reading %s=%q: %v",
caCertificatesEnvVar, customPath, err)) caCertificatesEnvVar, customCACertsPath, err))
} }
certPool := x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM(customCAs); !ok { if ok := certPool.AppendCertsFromPEM(customCAs); !ok {
panic(fmt.Sprintf("error creating x509 cert pool from %s=%q: %v", panic(fmt.Sprintf("error creating x509 cert pool from %s=%q: %v",
caCertificatesEnvVar, customPath, err)) caCertificatesEnvVar, customCACertsPath, err))
} }
return certPool
} }
return nil
return certPool
}
func getCertPool() *x509.CertPool {
useSystemCertPool, _ := strconv.ParseBool(os.Getenv(caSystemCertPool))
if !useSystemCertPool {
return x509.NewCertPool()
}
pool, err := x509.SystemCertPool()
if err == nil {
return pool
}
return x509.NewCertPool()
} }

View File

@ -3,6 +3,7 @@ package env
import ( import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"strconv" "strconv"
"strings" "strings"
@ -54,6 +55,7 @@ func Get(names ...string) (map[string]string, error) {
// // LEGO_TWO="" // // LEGO_TWO=""
// env.GetWithFallback([]string{"LEGO_ONE", "LEGO_TWO"}) // env.GetWithFallback([]string{"LEGO_ONE", "LEGO_TWO"})
// // => error // // => error
//
func GetWithFallback(groups ...[]string) (map[string]string, error) { func GetWithFallback(groups ...[]string) (map[string]string, error) {
values := map[string]string{} values := map[string]string{}
@ -78,26 +80,15 @@ func GetWithFallback(groups ...[]string) (map[string]string, error) {
return values, nil return values, nil
} }
func GetOneWithFallback[T any](main string, defaultValue T, fn func(string) (T, error), names ...string) T {
v, _ := getOneWithFallback(main, names...)
value, err := fn(v)
if err != nil {
return defaultValue
}
return value
}
func getOneWithFallback(main string, names ...string) (string, string) { func getOneWithFallback(main string, names ...string) (string, string) {
value := GetOrFile(main) value := GetOrFile(main)
if value != "" { if len(value) > 0 {
return value, main return value, main
} }
for _, name := range names { for _, name := range names {
value := GetOrFile(name) value := GetOrFile(name)
if value != "" { if len(value) > 0 {
return value, main return value, main
} }
} }
@ -105,32 +96,43 @@ func getOneWithFallback(main string, names ...string) (string, string) {
return "", main return "", main
} }
// GetOrDefaultInt returns the given environment variable value as an integer.
// Returns the default if the envvar cannot be coopered to an int, or is not found.
func GetOrDefaultInt(envVar string, defaultValue int) int {
v, err := strconv.Atoi(GetOrFile(envVar))
if err != nil {
return defaultValue
}
return v
}
// GetOrDefaultSecond returns the given environment variable value as an time.Duration (second).
// Returns the default if the envvar cannot be coopered to an int, or is not found.
func GetOrDefaultSecond(envVar string, defaultValue time.Duration) time.Duration {
v := GetOrDefaultInt(envVar, -1)
if v < 0 {
return defaultValue
}
return time.Duration(v) * time.Second
}
// GetOrDefaultString returns the given environment variable value as a string. // GetOrDefaultString returns the given environment variable value as a string.
// Returns the default if the env var cannot be found. // Returns the default if the envvar cannot be find.
func GetOrDefaultString(envVar string, defaultValue string) string { func GetOrDefaultString(envVar, defaultValue string) string {
return getOrDefault(envVar, defaultValue, ParseString) v := GetOrFile(envVar)
if v == "" {
return defaultValue
}
return v
} }
// GetOrDefaultBool returns the given environment variable value as a boolean. // GetOrDefaultBool returns the given environment variable value as a boolean.
// Returns the default if the env var cannot be coopered to a boolean, or is not found. // Returns the default if the envvar cannot be coopered to a boolean, or is not found.
func GetOrDefaultBool(envVar string, defaultValue bool) bool { func GetOrDefaultBool(envVar string, defaultValue bool) bool {
return getOrDefault(envVar, defaultValue, strconv.ParseBool) v, err := strconv.ParseBool(GetOrFile(envVar))
}
// GetOrDefaultInt returns the given environment variable value as an integer.
// Returns the default if the env var cannot be coopered to an int, or is not found.
func GetOrDefaultInt(envVar string, defaultValue int) int {
return getOrDefault(envVar, defaultValue, strconv.Atoi)
}
// GetOrDefaultSecond returns the given environment variable value as a time.Duration (second).
// Returns the default if the env var cannot be coopered to an int, or is not found.
func GetOrDefaultSecond(envVar string, defaultValue time.Duration) time.Duration {
return getOrDefault(envVar, defaultValue, ParseSecond)
}
func getOrDefault[T any](envVar string, defaultValue T, fn func(string) (T, error)) T {
v, err := fn(GetOrFile(envVar))
if err != nil { if err != nil {
return defaultValue return defaultValue
} }
@ -153,7 +155,7 @@ func GetOrFile(envVar string) string {
return envVarValue return envVarValue
} }
fileContents, err := os.ReadFile(fileVarValue) fileContents, err := ioutil.ReadFile(fileVarValue)
if err != nil { if err != nil {
log.Printf("Failed to read the file %s (defined by env var %s): %s", fileVarValue, fileVar, err) log.Printf("Failed to read the file %s (defined by env var %s): %s", fileVarValue, fileVar, err)
return "" return ""
@ -161,26 +163,3 @@ func GetOrFile(envVar string) string {
return strings.TrimSuffix(string(fileContents), "\n") return strings.TrimSuffix(string(fileContents), "\n")
} }
// ParseSecond parses env var value (string) to a second (time.Duration).
func ParseSecond(s string) (time.Duration, error) {
v, err := strconv.Atoi(s)
if err != nil {
return 0, err
}
if v < 0 {
return 0, fmt.Errorf("unsupported value: %d", v)
}
return time.Duration(v) * time.Second, nil
}
// ParseString parses env var value (string) to a string but throws an error when the string is empty.
func ParseString(s string) (string, error) {
if s == "" {
return "", errors.New("empty string")
}
return s, nil
}

View File

@ -1,6 +1,7 @@
package wait package wait
import ( import (
"errors"
"fmt" "fmt"
"time" "time"
@ -17,9 +18,9 @@ func For(msg string, timeout, interval time.Duration, f func() (bool, error)) er
select { select {
case <-timeUp: case <-timeUp:
if lastErr == nil { if lastErr == nil {
return fmt.Errorf("%s: time limit exceeded", msg) return errors.New("time limit exceeded")
} }
return fmt.Errorf("%s: time limit exceeded: last error: %w", msg, lastErr) return fmt.Errorf("time limit exceeded: last error: %w", lastErr)
default: default:
} }

View File

@ -1,133 +0,0 @@
package errutils
import (
"bytes"
"fmt"
"io"
"net/http"
"os"
"strconv"
)
const legoDebugClientVerboseError = "LEGO_DEBUG_CLIENT_VERBOSE_ERROR"
// HTTPDoError uses with `(http.Client).Do` error.
type HTTPDoError struct {
req *http.Request
err error
}
// NewHTTPDoError creates a new HTTPDoError.
func NewHTTPDoError(req *http.Request, err error) *HTTPDoError {
return &HTTPDoError{req: req, err: err}
}
func (h HTTPDoError) Error() string {
msg := "unable to communicate with the API server:"
if ok, _ := strconv.ParseBool(os.Getenv(legoDebugClientVerboseError)); ok {
msg += fmt.Sprintf(" [request: %s %s]", h.req.Method, h.req.URL)
}
if h.err == nil {
return msg
}
return msg + fmt.Sprintf(" error: %v", h.err)
}
func (h HTTPDoError) Unwrap() error {
return h.err
}
// ReadResponseError use with `io.ReadAll` when reading response body.
type ReadResponseError struct {
req *http.Request
StatusCode int
err error
}
// NewReadResponseError creates a new ReadResponseError.
func NewReadResponseError(req *http.Request, statusCode int, err error) *ReadResponseError {
return &ReadResponseError{req: req, StatusCode: statusCode, err: err}
}
func (r ReadResponseError) Error() string {
msg := "unable to read response body:"
if ok, _ := strconv.ParseBool(os.Getenv(legoDebugClientVerboseError)); ok {
msg += fmt.Sprintf(" [request: %s %s]", r.req.Method, r.req.URL)
}
msg += fmt.Sprintf(" [status code: %d]", r.StatusCode)
if r.err == nil {
return msg
}
return msg + fmt.Sprintf(" error: %v", r.err)
}
func (r ReadResponseError) Unwrap() error {
return r.err
}
// UnmarshalError uses with `json.Unmarshal` or `xml.Unmarshal` when reading response body.
type UnmarshalError struct {
req *http.Request
StatusCode int
Body []byte
err error
}
// NewUnmarshalError creates a new UnmarshalError.
func NewUnmarshalError(req *http.Request, statusCode int, body []byte, err error) *UnmarshalError {
return &UnmarshalError{req: req, StatusCode: statusCode, Body: bytes.TrimSpace(body), err: err}
}
func (u UnmarshalError) Error() string {
msg := "unable to unmarshal response:"
if ok, _ := strconv.ParseBool(os.Getenv(legoDebugClientVerboseError)); ok {
msg += fmt.Sprintf(" [request: %s %s]", u.req.Method, u.req.URL)
}
msg += fmt.Sprintf(" [status code: %d] body: %s", u.StatusCode, string(u.Body))
if u.err == nil {
return msg
}
return msg + fmt.Sprintf(" error: %v", u.err)
}
func (u UnmarshalError) Unwrap() error {
return u.err
}
// UnexpectedStatusCodeError use when the status of the response is unexpected but there is no API error type.
type UnexpectedStatusCodeError struct {
req *http.Request
StatusCode int
Body []byte
}
// NewUnexpectedStatusCodeError creates a new UnexpectedStatusCodeError.
func NewUnexpectedStatusCodeError(req *http.Request, statusCode int, body []byte) *UnexpectedStatusCodeError {
return &UnexpectedStatusCodeError{req: req, StatusCode: statusCode, Body: bytes.TrimSpace(body)}
}
func NewUnexpectedResponseStatusCodeError(req *http.Request, resp *http.Response) *UnexpectedStatusCodeError {
raw, _ := io.ReadAll(resp.Body)
return &UnexpectedStatusCodeError{req: req, StatusCode: resp.StatusCode, Body: bytes.TrimSpace(raw)}
}
func (u UnexpectedStatusCodeError) Error() string {
msg := "unexpected status code:"
if ok, _ := strconv.ParseBool(os.Getenv(legoDebugClientVerboseError)); ok {
msg += fmt.Sprintf(" [request: %s %s]", u.req.Method, u.req.URL)
}
return msg + fmt.Sprintf(" [status code: %d] body: %s", u.StatusCode, string(u.Body))
}

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"strings"
"sync" "sync"
"time" "time"
@ -14,14 +15,16 @@ import (
) )
// OVH API reference: https://eu.api.ovh.com/ // OVH API reference: https://eu.api.ovh.com/
// Create a Token: https://eu.api.ovh.com/createToken/ // Create a Token: https://eu.api.ovh.com/createToken/
// Create a OAuth2 client: https://eu.api.ovh.com/console-preview/?section=%2Fme&branch=v1#post-/me/api/oauth2/client
// Environment variables names. // Environment variables names.
const ( const (
envNamespace = "OVH_" envNamespace = "OVH_"
EnvEndpoint = envNamespace + "ENDPOINT" EnvEndpoint = envNamespace + "ENDPOINT"
EnvApplicationKey = envNamespace + "APPLICATION_KEY"
EnvApplicationSecret = envNamespace + "APPLICATION_SECRET"
EnvConsumerKey = envNamespace + "CONSUMER_KEY"
EnvTTL = envNamespace + "TTL" EnvTTL = envNamespace + "TTL"
EnvPropagationTimeout = envNamespace + "PROPAGATION_TIMEOUT" EnvPropagationTimeout = envNamespace + "PROPAGATION_TIMEOUT"
@ -29,19 +32,6 @@ const (
EnvHTTPTimeout = envNamespace + "HTTP_TIMEOUT" EnvHTTPTimeout = envNamespace + "HTTP_TIMEOUT"
) )
// Authenticate using application key.
const (
EnvApplicationKey = envNamespace + "APPLICATION_KEY"
EnvApplicationSecret = envNamespace + "APPLICATION_SECRET"
EnvConsumerKey = envNamespace + "CONSUMER_KEY"
)
// Authenticate using OAuth2 client.
const (
EnvClientID = envNamespace + "CLIENT_ID"
EnvClientSecret = envNamespace + "CLIENT_SECRET"
)
// Record a DNS record. // Record a DNS record.
type Record struct { type Record struct {
ID int64 `json:"id,omitempty"` ID int64 `json:"id,omitempty"`
@ -52,32 +42,18 @@ type Record struct {
Zone string `json:"zone,omitempty"` Zone string `json:"zone,omitempty"`
} }
// OAuth2Config the OAuth2 specific configuration.
type OAuth2Config struct {
ClientID string
ClientSecret string
}
// Config is used to configure the creation of the DNSProvider. // Config is used to configure the creation of the DNSProvider.
type Config struct { type Config struct {
APIEndpoint string APIEndpoint string
ApplicationKey string
ApplicationKey string ApplicationSecret string
ApplicationSecret string ConsumerKey string
ConsumerKey string
OAuth2Config *OAuth2Config
PropagationTimeout time.Duration PropagationTimeout time.Duration
PollingInterval time.Duration PollingInterval time.Duration
TTL int TTL int
HTTPClient *http.Client HTTPClient *http.Client
} }
func (c *Config) hasAppKeyAuth() bool {
return c.ApplicationKey != "" || c.ApplicationSecret != "" || c.ConsumerKey != ""
}
// NewDefaultConfig returns a default configuration for the DNSProvider. // NewDefaultConfig returns a default configuration for the DNSProvider.
func NewDefaultConfig() *Config { func NewDefaultConfig() *Config {
return &Config{ return &Config{
@ -102,11 +78,17 @@ type DNSProvider struct {
// Credentials must be passed in the environment variables: // Credentials must be passed in the environment variables:
// OVH_ENDPOINT (must be either "ovh-eu" or "ovh-ca"), OVH_APPLICATION_KEY, OVH_APPLICATION_SECRET, OVH_CONSUMER_KEY. // OVH_ENDPOINT (must be either "ovh-eu" or "ovh-ca"), OVH_APPLICATION_KEY, OVH_APPLICATION_SECRET, OVH_CONSUMER_KEY.
func NewDNSProvider() (*DNSProvider, error) { func NewDNSProvider() (*DNSProvider, error) {
config, err := createConfigFromEnvVars() values, err := env.Get(EnvEndpoint, EnvApplicationKey, EnvApplicationSecret, EnvConsumerKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("ovh: %w", err) return nil, fmt.Errorf("ovh: %w", err)
} }
config := NewDefaultConfig()
config.APIEndpoint = values[EnvEndpoint]
config.ApplicationKey = values[EnvApplicationKey]
config.ApplicationSecret = values[EnvApplicationSecret]
config.ConsumerKey = values[EnvConsumerKey]
return NewDNSProviderConfig(config) return NewDNSProviderConfig(config)
} }
@ -116,11 +98,16 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
return nil, errors.New("ovh: the configuration of the DNS provider is nil") return nil, errors.New("ovh: the configuration of the DNS provider is nil")
} }
if config.OAuth2Config != nil && config.hasAppKeyAuth() { if config.APIEndpoint == "" || config.ApplicationKey == "" || config.ApplicationSecret == "" || config.ConsumerKey == "" {
return nil, errors.New("ovh: can't use both authentication systems (ApplicationKey and OAuth2)") return nil, errors.New("ovh: credentials missing")
} }
client, err := newClient(config) client, err := ovh.NewClient(
config.APIEndpoint,
config.ApplicationKey,
config.ApplicationSecret,
config.ConsumerKey,
)
if err != nil { if err != nil {
return nil, fmt.Errorf("ovh: %w", err) return nil, fmt.Errorf("ovh: %w", err)
} }
@ -136,23 +123,19 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
// Present creates a TXT record to fulfill the dns-01 challenge. // Present creates a TXT record to fulfill the dns-01 challenge.
func (d *DNSProvider) Present(domain, token, keyAuth string) error { func (d *DNSProvider) Present(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth) fqdn, value := dns01.GetRecord(domain, keyAuth)
// Parse domain name // Parse domain name
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(dns01.ToFqdn(domain))
if err != nil { if err != nil {
return fmt.Errorf("ovh: could not find zone for domain %q: %w", domain, err) return fmt.Errorf("ovh: could not determine zone for domain %q: %w", domain, err)
} }
authZone = dns01.UnFqdn(authZone) authZone = dns01.UnFqdn(authZone)
subDomain := extractRecordName(fqdn, authZone)
subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone)
if err != nil {
return fmt.Errorf("ovh: %w", err)
}
reqURL := fmt.Sprintf("/domain/zone/%s/record", authZone) reqURL := fmt.Sprintf("/domain/zone/%s/record", authZone)
reqData := Record{FieldType: "TXT", SubDomain: subDomain, Target: info.Value, TTL: d.config.TTL} reqData := Record{FieldType: "TXT", SubDomain: subDomain, Target: value, TTL: d.config.TTL}
// Create TXT record // Create TXT record
var respData Record var respData Record
@ -177,19 +160,19 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
// CleanUp removes the TXT record matching the specified parameters. // CleanUp removes the TXT record matching the specified parameters.
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth) fqdn, _ := dns01.GetRecord(domain, keyAuth)
// get the record's unique ID from when we created it // get the record's unique ID from when we created it
d.recordIDsMu.Lock() d.recordIDsMu.Lock()
recordID, ok := d.recordIDs[token] recordID, ok := d.recordIDs[token]
d.recordIDsMu.Unlock() d.recordIDsMu.Unlock()
if !ok { if !ok {
return fmt.Errorf("ovh: unknown record ID for '%s'", info.EffectiveFQDN) return fmt.Errorf("ovh: unknown record ID for '%s'", fqdn)
} }
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN) authZone, err := dns01.FindZoneByFqdn(dns01.ToFqdn(domain))
if err != nil { if err != nil {
return fmt.Errorf("ovh: could not find zone for domain %q: %w", domain, err) return fmt.Errorf("ovh: could not determine zone for domain %q: %w", domain, err)
} }
authZone = dns01.UnFqdn(authZone) authZone = dns01.UnFqdn(authZone)
@ -222,94 +205,10 @@ func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
return d.config.PropagationTimeout, d.config.PollingInterval return d.config.PropagationTimeout, d.config.PollingInterval
} }
func createConfigFromEnvVars() (*Config, error) { func extractRecordName(fqdn, zone string) string {
firstAppKeyEnvVar := findFirstValuedEnvVar(EnvApplicationKey, EnvApplicationSecret, EnvConsumerKey) name := dns01.UnFqdn(fqdn)
firstOAuth2EnvVar := findFirstValuedEnvVar(EnvClientID, EnvClientSecret) if idx := strings.Index(name, "."+zone); idx != -1 {
return name[:idx]
if firstAppKeyEnvVar != "" && firstOAuth2EnvVar != "" {
return nil, fmt.Errorf("can't use both %s and %s at the same time", firstAppKeyEnvVar, firstOAuth2EnvVar)
} }
return name
config := NewDefaultConfig()
if firstOAuth2EnvVar != "" {
values, err := env.Get(EnvEndpoint, EnvClientID, EnvClientSecret)
if err != nil {
return nil, err
}
config.APIEndpoint = values[EnvEndpoint]
config.OAuth2Config = &OAuth2Config{
ClientID: values[EnvClientID],
ClientSecret: values[EnvClientSecret],
}
return config, nil
}
values, err := env.Get(EnvEndpoint, EnvApplicationKey, EnvApplicationSecret, EnvConsumerKey)
if err != nil {
return nil, err
}
config.APIEndpoint = values[EnvEndpoint]
config.ApplicationKey = values[EnvApplicationKey]
config.ApplicationSecret = values[EnvApplicationSecret]
config.ConsumerKey = values[EnvConsumerKey]
return config, nil
}
func findFirstValuedEnvVar(envVars ...string) string {
for _, envVar := range envVars {
if env.GetOrFile(envVar) != "" {
return envVar
}
}
return ""
}
func newClient(config *Config) (*ovh.Client, error) {
if config.OAuth2Config == nil {
return newClientApplicationKey(config)
}
return newClientOAuth2(config)
}
func newClientApplicationKey(config *Config) (*ovh.Client, error) {
if config.APIEndpoint == "" || config.ApplicationKey == "" || config.ApplicationSecret == "" || config.ConsumerKey == "" {
return nil, errors.New("credentials are missing")
}
client, err := ovh.NewClient(
config.APIEndpoint,
config.ApplicationKey,
config.ApplicationSecret,
config.ConsumerKey,
)
if err != nil {
return nil, fmt.Errorf("new client: %w", err)
}
return client, nil
}
func newClientOAuth2(config *Config) (*ovh.Client, error) {
if config.APIEndpoint == "" || config.OAuth2Config.ClientID == "" || config.OAuth2Config.ClientSecret == "" {
return nil, errors.New("credentials are missing")
}
client, err := ovh.NewOAuth2Client(
config.APIEndpoint,
config.OAuth2Config.ClientID,
config.OAuth2Config.ClientSecret,
)
if err != nil {
return nil, fmt.Errorf("new OAuth2 client: %w", err)
}
return client, nil
} }

View File

@ -5,20 +5,11 @@ Code = "ovh"
Since = "v0.4.0" Since = "v0.4.0"
Example = ''' Example = '''
# Application Key authentication:
OVH_APPLICATION_KEY=1234567898765432 \ OVH_APPLICATION_KEY=1234567898765432 \
OVH_APPLICATION_SECRET=b9841238feb177a84330febba8a832089 \ OVH_APPLICATION_SECRET=b9841238feb177a84330febba8a832089 \
OVH_CONSUMER_KEY=256vfsd347245sdfg \ OVH_CONSUMER_KEY=256vfsd347245sdfg \
OVH_ENDPOINT=ovh-eu \ OVH_ENDPOINT=ovh-eu \
lego --email you@example.com --dns ovh --domains my.example.org run lego --email myemail@example.com --dns ovh --domains my.example.org run
# Or OAuth2:
OVH_CLIENT_ID=yyy \
OVH_CLIENT_SECRET=xxx \
OVH_ENDPOINT=ovh-eu \
lego --email you@example.com --dns ovh --domains my.example.org run
''' '''
Additional = ''' Additional = '''
@ -26,7 +17,7 @@ Additional = '''
Application key and secret can be created by following the [OVH guide](https://docs.ovh.com/gb/en/customer/first-steps-with-ovh-api/). Application key and secret can be created by following the [OVH guide](https://docs.ovh.com/gb/en/customer/first-steps-with-ovh-api/).
When requesting the consumer key, the following configuration can be used to define access rights: When requesting the consumer key, the following configuration can be use to define access rights:
```json ```json
{ {
@ -42,32 +33,14 @@ When requesting the consumer key, the following configuration can be used to def
] ]
} }
``` ```
## OAuth2 Client Credentials
Another method for authentication is by using OAuth2 client credentials.
An IAM policy and service account can be created by following the [OVH guide](https://help.ovhcloud.com/csm/en-manage-service-account?id=kb_article_view&sysparm_article=KB0059343).
Following IAM policies need to be authorized for the affected domain:
* dnsZone:apiovh:record/create
* dnsZone:apiovh:record/delete
* dnsZone:apiovh:refresh
## Important Note
Both authentication methods cannot be used at the same time.
''' '''
[Configuration] [Configuration]
[Configuration.Credentials] [Configuration.Credentials]
OVH_ENDPOINT = "Endpoint URL (ovh-eu or ovh-ca)" OVH_ENDPOINT = "Endpoint URL (ovh-eu or ovh-ca)"
OVH_APPLICATION_KEY = "Application key (Application Key authentication)" OVH_APPLICATION_KEY = "Application key"
OVH_APPLICATION_SECRET = "Application secret (Application Key authentication)" OVH_APPLICATION_SECRET = "Application secret"
OVH_CONSUMER_KEY = "Consumer key (Application Key authentication)" OVH_CONSUMER_KEY = "Consumer key"
OVH_CLIENT_ID = "Client ID (OAuth2)"
OVH_CLIENT_SECRET = "Client secret (OAuth2)"
[Configuration.Additional] [Configuration.Additional]
OVH_POLLING_INTERVAL = "Time between DNS propagation check" OVH_POLLING_INTERVAL = "Time between DNS propagation check"
OVH_PROPAGATION_TIMEOUT = "Maximum waiting time for DNS propagation" OVH_PROPAGATION_TIMEOUT = "Maximum waiting time for DNS propagation"

View File

@ -1,228 +0,0 @@
package internal
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
"github.com/miekg/dns"
)
// Client the PowerDNS API client.
type Client struct {
serverName string
apiKey string
apiVersion int
Host *url.URL
HTTPClient *http.Client
}
// NewClient creates a new Client.
func NewClient(host *url.URL, serverName string, apiVersion int, apiKey string) *Client {
return &Client{
serverName: serverName,
apiKey: apiKey,
apiVersion: apiVersion,
Host: host,
HTTPClient: &http.Client{Timeout: 5 * time.Second},
}
}
func (c *Client) APIVersion() int {
return c.apiVersion
}
func (c *Client) SetAPIVersion(ctx context.Context) error {
var err error
c.apiVersion, err = c.getAPIVersion(ctx)
return err
}
func (c *Client) getAPIVersion(ctx context.Context) (int, error) {
endpoint := c.joinPath("/", "api")
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return 0, err
}
result, err := c.do(req)
if err != nil {
return 0, err
}
var versions []apiVersion
err = json.Unmarshal(result, &versions)
if err != nil {
return 0, err
}
latestVersion := 0
for _, v := range versions {
if v.Version > latestVersion {
latestVersion = v.Version
}
}
return latestVersion, err
}
func (c *Client) GetHostedZone(ctx context.Context, authZone string) (*HostedZone, error) {
endpoint := c.joinPath("/", "servers", c.serverName, "zones", dns.Fqdn(authZone))
req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}
result, err := c.do(req)
if err != nil {
return nil, err
}
var zone HostedZone
err = json.Unmarshal(result, &zone)
if err != nil {
return nil, err
}
// convert pre-v1 API result
if len(zone.Records) > 0 {
zone.RRSets = []RRSet{}
for _, record := range zone.Records {
set := RRSet{
Name: record.Name,
Type: record.Type,
Records: []Record{record},
}
zone.RRSets = append(zone.RRSets, set)
}
}
return &zone, nil
}
func (c *Client) UpdateRecords(ctx context.Context, zone *HostedZone, sets RRSets) error {
endpoint := c.joinPath("/", "servers", c.serverName, "zones", zone.ID)
req, err := newJSONRequest(ctx, http.MethodPatch, endpoint, sets)
if err != nil {
return err
}
_, err = c.do(req)
if err != nil {
return err
}
return nil
}
func (c *Client) Notify(ctx context.Context, zone *HostedZone) error {
if c.apiVersion < 1 || zone.Kind != "Master" && zone.Kind != "Slave" {
return nil
}
endpoint := c.joinPath("/", "servers", c.serverName, "zones", zone.ID, "notify")
req, err := newJSONRequest(ctx, http.MethodPut, endpoint, nil)
if err != nil {
return err
}
_, err = c.do(req)
if err != nil {
return err
}
return nil
}
func (c *Client) joinPath(elem ...string) *url.URL {
p := path.Join(elem...)
if p != "/api" && c.apiVersion > 0 && !strings.HasPrefix(p, "/api/v") {
p = path.Join("/api", "v"+strconv.Itoa(c.apiVersion), p)
}
return c.Host.JoinPath(p)
}
func (c *Client) do(req *http.Request) (json.RawMessage, error) {
req.Header.Set("X-API-Key", c.apiKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, errutils.NewHTTPDoError(req, err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusUnprocessableEntity && (resp.StatusCode < 200 || resp.StatusCode >= 300) {
return nil, errutils.NewUnexpectedResponseStatusCodeError(req, resp)
}
var msg json.RawMessage
err = json.NewDecoder(resp.Body).Decode(&msg)
if err != nil {
if errors.Is(err, io.EOF) {
// empty body
return nil, nil
}
// other error
return nil, err
}
// check for PowerDNS error message
if len(msg) > 0 && msg[0] == '{' {
var errInfo apiError
err = json.Unmarshal(msg, &errInfo)
if err != nil {
return nil, errutils.NewUnmarshalError(req, resp.StatusCode, msg, err)
}
if errInfo.ShortMsg != "" {
return nil, fmt.Errorf("error talking to PDNS API: %w", errInfo)
}
}
return msg, nil
}
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
buf := new(bytes.Buffer)
if payload != nil {
err := json.NewEncoder(buf).Encode(payload)
if err != nil {
return nil, fmt.Errorf("failed to create request JSON body: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, method, strings.TrimSuffix(endpoint.String(), "/"), buf)
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
// PowerDNS doesn't follow HTTP convention about the "Content-Type" header.
if method != http.MethodGet && method != http.MethodDelete {
req.Header.Set("Content-Type", "application/json")
}
return req, nil
}

View File

@ -1,48 +0,0 @@
package internal
type Record struct {
Content string `json:"content"`
Disabled bool `json:"disabled"`
// pre-v1 API
Name string `json:"name"`
Type string `json:"type"`
TTL int `json:"ttl,omitempty"`
}
type HostedZone struct {
ID string `json:"id"`
Name string `json:"name"`
URL string `json:"url"`
Kind string `json:"kind"`
RRSets []RRSet `json:"rrsets"`
// pre-v1 API
Records []Record `json:"records"`
}
type RRSet struct {
Name string `json:"name"`
Type string `json:"type"`
Kind string `json:"kind"`
ChangeType string `json:"changetype"`
Records []Record `json:"records,omitempty"`
TTL int `json:"ttl,omitempty"`
}
type RRSets struct {
RRSets []RRSet `json:"rrsets"`
}
type apiError struct {
ShortMsg string `json:"error"`
}
func (a apiError) Error() string {
return a.ShortMsg
}
type apiVersion struct {
URL string `json:"url"`
Version int `json:"version"`
}

View File

@ -1,228 +0,0 @@
// Package pdns implements a DNS provider for solving the DNS-01 challenge using PowerDNS nameserver.
package pdns
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"time"
"github.com/go-acme/lego/v4/challenge/dns01"
"github.com/go-acme/lego/v4/log"
"github.com/go-acme/lego/v4/platform/config/env"
"github.com/go-acme/lego/v4/providers/dns/pdns/internal"
)
// Environment variables names.
const (
envNamespace = "PDNS_"
EnvAPIKey = envNamespace + "API_KEY"
EnvAPIURL = envNamespace + "API_URL"
EnvTTL = envNamespace + "TTL"
EnvAPIVersion = envNamespace + "API_VERSION"
EnvPropagationTimeout = envNamespace + "PROPAGATION_TIMEOUT"
EnvPollingInterval = envNamespace + "POLLING_INTERVAL"
EnvHTTPTimeout = envNamespace + "HTTP_TIMEOUT"
EnvServerName = envNamespace + "SERVER_NAME"
)
// Config is used to configure the creation of the DNSProvider.
type Config struct {
APIKey string
Host *url.URL
ServerName string
APIVersion int
PropagationTimeout time.Duration
PollingInterval time.Duration
TTL int
HTTPClient *http.Client
}
// NewDefaultConfig returns a default configuration for the DNSProvider.
func NewDefaultConfig() *Config {
return &Config{
ServerName: env.GetOrDefaultString(EnvServerName, "localhost"),
APIVersion: env.GetOrDefaultInt(EnvAPIVersion, 0),
TTL: env.GetOrDefaultInt(EnvTTL, dns01.DefaultTTL),
PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 120*time.Second),
PollingInterval: env.GetOrDefaultSecond(EnvPollingInterval, 2*time.Second),
HTTPClient: &http.Client{
Timeout: env.GetOrDefaultSecond(EnvHTTPTimeout, 30*time.Second),
},
}
}
// DNSProvider implements the challenge.Provider interface.
type DNSProvider struct {
config *Config
client *internal.Client
}
// NewDNSProvider returns a DNSProvider instance configured for pdns.
// Credentials must be passed in the environment variable:
// PDNS_API_URL and PDNS_API_KEY.
func NewDNSProvider() (*DNSProvider, error) {
values, err := env.Get(EnvAPIKey, EnvAPIURL)
if err != nil {
return nil, fmt.Errorf("pdns: %w", err)
}
hostURL, err := url.Parse(values[EnvAPIURL])
if err != nil {
return nil, fmt.Errorf("pdns: %w", err)
}
config := NewDefaultConfig()
config.Host = hostURL
config.APIKey = values[EnvAPIKey]
return NewDNSProviderConfig(config)
}
// NewDNSProviderConfig return a DNSProvider instance configured for pdns.
func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
if config == nil {
return nil, errors.New("pdns: the configuration of the DNS provider is nil")
}
if config.APIKey == "" {
return nil, errors.New("pdns: API key missing")
}
if config.Host == nil || config.Host.Host == "" {
return nil, errors.New("pdns: API URL missing")
}
client := internal.NewClient(config.Host, config.ServerName, config.APIVersion, config.APIKey)
if config.APIVersion <= 0 {
err := client.SetAPIVersion(context.Background())
if err != nil {
log.Warnf("pdns: failed to get API version %v", err)
}
}
return &DNSProvider{config: config, client: client}, nil
}
// Timeout returns the timeout and interval to use when checking for DNS propagation.
// Adjusting here to cope with spikes in propagation times.
func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
return d.config.PropagationTimeout, d.config.PollingInterval
}
// Present creates a TXT record to fulfill the dns-01 challenge.
func (d *DNSProvider) Present(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("pdns: could not find zone for domain %q: %w", domain, err)
}
ctx := context.Background()
zone, err := d.client.GetHostedZone(ctx, authZone)
if err != nil {
return fmt.Errorf("pdns: %w", err)
}
name := info.EffectiveFQDN
if d.client.APIVersion() == 0 {
// pre-v1 API wants non-fqdn
name = dns01.UnFqdn(info.EffectiveFQDN)
}
// Look for existing records.
existingRRSet := findTxtRecord(zone, info.EffectiveFQDN)
// merge the existing and new records
var records []internal.Record
if existingRRSet != nil {
records = existingRRSet.Records
}
rec := internal.Record{
Content: "\"" + info.Value + "\"",
Disabled: false,
// pre-v1 API
Type: "TXT",
Name: name,
TTL: d.config.TTL,
}
rrSets := internal.RRSets{
RRSets: []internal.RRSet{
{
Name: name,
ChangeType: "REPLACE",
Type: "TXT",
Kind: "Master",
TTL: d.config.TTL,
Records: append(records, rec),
},
},
}
err = d.client.UpdateRecords(ctx, zone, rrSets)
if err != nil {
return fmt.Errorf("pdns: %w", err)
}
return d.client.Notify(ctx, zone)
}
// CleanUp removes the TXT record matching the specified parameters.
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
info := dns01.GetChallengeInfo(domain, keyAuth)
authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
if err != nil {
return fmt.Errorf("pdns: could not find zone for domain %q: %w", domain, err)
}
ctx := context.Background()
zone, err := d.client.GetHostedZone(ctx, authZone)
if err != nil {
return fmt.Errorf("pdns: %w", err)
}
set := findTxtRecord(zone, info.EffectiveFQDN)
if set == nil {
return fmt.Errorf("pdns: no existing record found for %s", info.EffectiveFQDN)
}
rrSets := internal.RRSets{
RRSets: []internal.RRSet{
{
Name: set.Name,
Type: set.Type,
ChangeType: "DELETE",
},
},
}
err = d.client.UpdateRecords(ctx, zone, rrSets)
if err != nil {
return fmt.Errorf("pdns: %w", err)
}
return d.client.Notify(ctx, zone)
}
func findTxtRecord(zone *internal.HostedZone, fqdn string) *internal.RRSet {
for _, set := range zone.RRSets {
if set.Type == "TXT" && (set.Name == dns01.UnFqdn(fqdn) || set.Name == fqdn) {
return &set
}
}
return nil
}

View File

@ -1,37 +0,0 @@
Name = "PowerDNS"
Description = ''''''
URL = "https://www.powerdns.com/"
Code = "pdns"
Since = "v0.4.0"
Example = '''
PDNS_API_URL=http://pdns-server:80/ \
PDNS_API_KEY=xxxx \
lego --email you@example.com --dns pdns --domains my.example.org run
'''
Additional = '''
## Information
Tested and confirmed to work with PowerDNS authoritative server 3.4.8 and 4.0.1. Refer to [PowerDNS documentation](https://doc.powerdns.com/md/httpapi/README/) instructions on how to enable the built-in API interface.
PowerDNS Notes:
- PowerDNS API does not currently support SSL, therefore you should take care to ensure that traffic between lego and the PowerDNS API is over a trusted network, VPN etc.
- In order to have the SOA serial automatically increment each time the `_acme-challenge` record is added/modified via the API, set `SOA-EDIT-API` to `INCEPTION-INCREMENT` for the zone in the `domainmetadata` table
- Some PowerDNS servers doesn't have root API endpoints enabled and API version autodetection will not work. In that case version number can be defined using `PDNS_API_VERSION`.
'''
[Configuration]
[Configuration.Credentials]
PDNS_API_KEY = "API key"
PDNS_API_URL = "API URL"
[Configuration.Additional]
PDNS_SERVER_NAME = "Name of the server in the URL, 'localhost' by default"
PDNS_API_VERSION = "Skip API version autodetection and use the provided version number."
PDNS_POLLING_INTERVAL = "Time between DNS propagation check"
PDNS_PROPAGATION_TIMEOUT = "Maximum waiting time for DNS propagation"
PDNS_TTL = "The TTL of the TXT record used for the DNS challenge"
PDNS_HTTP_TIMEOUT = "API request timeout"
[Links]
API = "https://doc.powerdns.com/md/httpapi/README/"

View File

@ -9,11 +9,9 @@ import (
"github.com/go-acme/lego/v4/log" "github.com/go-acme/lego/v4/log"
) )
const mailTo = "mailto:"
// Resource represents all important information about a registration // Resource represents all important information about a registration
// of which the client needs to keep track itself. // of which the client needs to keep track itself.
// WARNING: will be removed in the future (acme.ExtendedAccount), https://github.com/go-acme/lego/issues/855. // WARNING: will be remove in the future (acme.ExtendedAccount), https://github.com/go-acme/lego/issues/855.
type Resource struct { type Resource struct {
Body acme.Account `json:"body,omitempty"` Body acme.Account `json:"body,omitempty"`
URI string `json:"uri,omitempty"` URI string `json:"uri,omitempty"`
@ -54,7 +52,7 @@ func (r *Registrar) Register(options RegisterOptions) (*Resource, error) {
if r.user.GetEmail() != "" { if r.user.GetEmail() != "" {
log.Infof("acme: Registering account for %s", r.user.GetEmail()) log.Infof("acme: Registering account for %s", r.user.GetEmail())
accMsg.Contact = []string{mailTo + r.user.GetEmail()} accMsg.Contact = []string{"mailto:" + r.user.GetEmail()}
} }
account, err := r.core.Accounts.New(accMsg) account, err := r.core.Accounts.New(accMsg)
@ -78,7 +76,7 @@ func (r *Registrar) RegisterWithExternalAccountBinding(options RegisterEABOption
if r.user.GetEmail() != "" { if r.user.GetEmail() != "" {
log.Infof("acme: Registering account for %s", r.user.GetEmail()) log.Infof("acme: Registering account for %s", r.user.GetEmail())
accMsg.Contact = []string{mailTo + r.user.GetEmail()} accMsg.Contact = []string{"mailto:" + r.user.GetEmail()}
} }
account, err := r.core.Accounts.NewEAB(accMsg, options.Kid, options.HmacEncoded) account, err := r.core.Accounts.NewEAB(accMsg, options.Kid, options.HmacEncoded)
@ -130,7 +128,7 @@ func (r *Registrar) UpdateRegistration(options RegisterOptions) (*Resource, erro
if r.user.GetEmail() != "" { if r.user.GetEmail() != "" {
log.Infof("acme: Registering account for %s", r.user.GetEmail()) log.Infof("acme: Registering account for %s", r.user.GetEmail())
accMsg.Contact = []string{mailTo + r.user.GetEmail()} accMsg.Contact = []string{"mailto:" + r.user.GetEmail()}
} }
accountURL := r.user.GetRegistration().URI accountURL := r.user.GetRegistration().URI

View File

@ -1,2 +0,0 @@
jose-util/jose-util
jose-util.t.err

View File

@ -1,53 +0,0 @@
# https://github.com/golangci/golangci-lint
run:
skip-files:
- doc_test.go
modules-download-mode: readonly
linters:
enable-all: true
disable:
- gochecknoglobals
- goconst
- lll
- maligned
- nakedret
- scopelint
- unparam
- funlen # added in 1.18 (requires go-jose changes before it can be enabled)
linters-settings:
gocyclo:
min-complexity: 35
issues:
exclude-rules:
- text: "don't use ALL_CAPS in Go names"
linters:
- golint
- text: "hardcoded credentials"
linters:
- gosec
- text: "weak cryptographic primitive"
linters:
- gosec
- path: json/
linters:
- dupl
- errcheck
- gocritic
- gocyclo
- golint
- govet
- ineffassign
- staticcheck
- structcheck
- stylecheck
- unused
- path: _test\.go
linters:
- scopelint
- path: jwk.go
linters:
- gocyclo

View File

@ -1,33 +0,0 @@
language: go
matrix:
fast_finish: true
allow_failures:
- go: tip
go:
- "1.13.x"
- "1.14.x"
- tip
before_script:
- export PATH=$HOME/.local/bin:$PATH
before_install:
- go get -u github.com/mattn/goveralls github.com/wadey/gocovmerge
- curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b $(go env GOPATH)/bin v1.18.0
- pip install cram --user
script:
- go test -v -covermode=count -coverprofile=profile.cov .
- go test -v -covermode=count -coverprofile=cryptosigner/profile.cov ./cryptosigner
- go test -v -covermode=count -coverprofile=cipher/profile.cov ./cipher
- go test -v -covermode=count -coverprofile=jwt/profile.cov ./jwt
- go test -v ./json # no coverage for forked encoding/json package
- golangci-lint run
- cd jose-util && go build && PATH=$PWD:$PATH cram -v jose-util.t # cram tests jose-util
- cd ..
after_success:
- gocovmerge *.cov */*.cov > merged.coverprofile
- goveralls -coverprofile merged.coverprofile -service=travis-ci

View File

@ -1,96 +0,0 @@
# v4.0.4
## Fixed
- Reverted "Allow unmarshalling JSONWebKeySets with unsupported key types" as a
breaking change. See #136 / #137.
# v4.0.3
## Changed
- Allow unmarshalling JSONWebKeySets with unsupported key types (#130)
- Document that OpaqueKeyEncrypter can't be implemented (for now) (#129)
- Dependency updates
# v4.0.2
## Changed
- Improved documentation of Verify() to note that JSONWebKeySet is a supported
argument type (#104)
- Defined exported error values for missing x5c header and unsupported elliptic
curves error cases (#117)
# v4.0.1
## Fixed
- An attacker could send a JWE containing compressed data that used large
amounts of memory and CPU when decompressed by `Decrypt` or `DecryptMulti`.
Those functions now return an error if the decompressed data would exceed
250kB or 10x the compressed size (whichever is larger). Thanks to
Enze Wang@Alioth and Jianjun Chen@Zhongguancun Lab (@zer0yu and @chenjj)
for reporting.
# v4.0.0
This release makes some breaking changes in order to more thoroughly
address the vulnerabilities discussed in [Three New Attacks Against JSON Web
Tokens][1], "Sign/encrypt confusion", "Billion hash attack", and "Polyglot
token".
## Changed
- Limit JWT encryption types (exclude password or public key types) (#78)
- Enforce minimum length for HMAC keys (#85)
- jwt: match any audience in a list, rather than requiring all audiences (#81)
- jwt: accept only Compact Serialization (#75)
- jws: Add expected algorithms for signatures (#74)
- Require specifying expected algorithms for ParseEncrypted,
ParseSigned, ParseDetached, jwt.ParseEncrypted, jwt.ParseSigned,
jwt.ParseSignedAndEncrypted (#69, #74)
- Usually there is a small, known set of appropriate algorithms for a program
to use and it's a mistake to allow unexpected algorithms. For instance the
"billion hash attack" relies in part on programs accepting the PBES2
encryption algorithm and doing the necessary work even if they weren't
specifically configured to allow PBES2.
- Revert "Strip padding off base64 strings" (#82)
- The specs require base64url encoding without padding.
- Minimum supported Go version is now 1.21
## Added
- ParseSignedCompact, ParseSignedJSON, ParseEncryptedCompact, ParseEncryptedJSON.
- These allow parsing a specific serialization, as opposed to ParseSigned and
ParseEncrypted, which try to automatically detect which serialization was
provided. It's common to require a specific serialization for a specific
protocol - for instance JWT requires Compact serialization.
[1]: https://i.blackhat.com/BH-US-23/Presentations/US-23-Tervoort-Three-New-Attacks-Against-JSON-Web-Tokens.pdf
# v3.0.2
## Fixed
- DecryptMulti: handle decompression error (#19)
## Changed
- jwe/CompactSerialize: improve performance (#67)
- Increase the default number of PBKDF2 iterations to 600k (#48)
- Return the proper algorithm for ECDSA keys (#45)
## Added
- Add Thumbprint support for opaque signers (#38)
# v3.0.1
## Fixed
- Security issue: an attacker specifying a large "p2c" value can cause
JSONWebEncryption.Decrypt and JSONWebEncryption.DecryptMulti to consume large
amounts of CPU, causing a DoS. Thanks to Matt Schwager (@mschwager) for the
disclosure and to Tom Tervoort for originally publishing the category of attack.
https://i.blackhat.com/BH-US-23/Presentations/US-23-Tervoort-Three-New-Attacks-Against-JSON-Web-Tokens.pdf

View File

@ -1,13 +0,0 @@
# Security Policy
This document explains how to contact the Let's Encrypt security team to report security vulnerabilities.
## Supported Versions
| Version | Supported |
| ------- | ----------|
| >= v3 | &check; |
| v2 | &cross; |
| v1 | &cross; |
## Reporting a vulnerability
Please see [https://letsencrypt.org/contact/#security](https://letsencrypt.org/contact/#security) for the email address to report a vulnerability. Ensure that the subject line for your report contains the word `vulnerability` and is descriptive. Your email should be acknowledged within 24 hours. If you do not receive a response within 24 hours, please follow-up again with another email.

View File

@ -48,17 +48,6 @@ linters:
- nlreturn - nlreturn
- testpackage - testpackage
- wsl - wsl
- varnamelen
- nilnil
- ireturn
- govet
- forcetypeassert
- cyclop
- containedctx
- revive
- nosnakecase
- exhaustruct
- depguard
issues: issues:
exclude-rules: exclude-rules:

View File

@ -1,168 +1,3 @@
# v0.10.2 - 2023/03/20
### New features
* Support DebugDOT option for debugging encoder ( #440 )
### Fix bugs
* Fix combination of embedding structure and omitempty option ( #442 )
# v0.10.1 - 2023/03/13
### Fix bugs
* Fix checkptr error for array decoder ( #415 )
* Fix added buffer size check when decoding key ( #430 )
* Fix handling of anonymous fields other than struct ( #431 )
* Fix to not optimize when lower conversion can't handle byte-by-byte ( #432 )
* Fix a problem that MarshalIndent does not work when UnorderedMap is specified ( #435 )
* Fix mapDecoder.DecodeStream() for empty objects containing whitespace ( #425 )
* Fix an issue that could not set the correct NextField for fields in the embedded structure ( #438 )
# v0.10.0 - 2022/11/29
### New features
* Support JSON Path ( #250 )
### Fix bugs
* Fix marshaler for map's key ( #409 )
# v0.9.11 - 2022/08/18
### Fix bugs
* Fix unexpected behavior when buffer ends with backslash ( #383 )
* Fix stream decoding of escaped character ( #387 )
# v0.9.10 - 2022/07/15
### Fix bugs
* Fix boundary exception of type caching ( #382 )
# v0.9.9 - 2022/07/15
### Fix bugs
* Fix encoding of directed interface with typed nil ( #377 )
* Fix embedded primitive type encoding using alias ( #378 )
* Fix slice/array type encoding with types implementing MarshalJSON ( #379 )
* Fix unicode decoding when the expected buffer state is not met after reading ( #380 )
# v0.9.8 - 2022/06/30
### Fix bugs
* Fix decoding of surrogate-pair ( #365 )
* Fix handling of embedded primitive type ( #366 )
* Add validation of escape sequence for decoder ( #367 )
* Fix stream tokenizing respecting UseNumber ( #369 )
* Fix encoding when struct pointer type that implements Marshal JSON is embedded ( #375 )
### Improve performance
* Improve performance of linkRecursiveCode ( #368 )
# v0.9.7 - 2022/04/22
### Fix bugs
#### Encoder
* Add filtering process for encoding on slow path ( #355 )
* Fix encoding of interface{} with pointer type ( #363 )
#### Decoder
* Fix map key decoder that implements UnmarshalJSON ( #353 )
* Fix decoding of []uint8 type ( #361 )
### New features
* Add DebugWith option for encoder ( #356 )
# v0.9.6 - 2022/03/22
### Fix bugs
* Correct the handling of the minimum value of int type for decoder ( #344 )
* Fix bugs of stream decoder's bufferSize ( #349 )
* Add a guard to use typeptr more safely ( #351 )
### Improve decoder performance
* Improve escapeString's performance ( #345 )
### Others
* Update go version for CI ( #347 )
# v0.9.5 - 2022/03/04
### Fix bugs
* Fix panic when decoding time.Time with context ( #328 )
* Fix reading the next character in buffer to nul consideration ( #338 )
* Fix incorrect handling on skipValue ( #341 )
### Improve decoder performance
* Improve performance when a payload contains escape sequence ( #334 )
# v0.9.4 - 2022/01/21
* Fix IsNilForMarshaler for string type with omitempty ( #323 )
* Fix the case where the embedded field is at the end ( #326 )
# v0.9.3 - 2022/01/14
* Fix logic of removing struct field for decoder ( #322 )
# v0.9.2 - 2022/01/14
* Add invalid decoder to delay type error judgment at decode ( #321 )
# v0.9.1 - 2022/01/11
* Fix encoding of MarshalText/MarshalJSON operation with head offset ( #319 )
# v0.9.0 - 2022/01/05
### New feature
* Supports dynamic filtering of struct fields ( #314 )
### Improve encoding performance
* Improve map encoding performance ( #310 )
* Optimize encoding path for escaped string ( #311 )
* Add encoding option for performance ( #312 )
### Fix bugs
* Fix panic at encoding map value on 1.18 ( #310 )
* Fix MarshalIndent for interface type ( #317 )
# v0.8.1 - 2021/12/05
* Fix operation conversion from PtrHead to Head in Recursive type ( #305 )
# v0.8.0 - 2021/12/02
* Fix embedded field conflict behavior ( #300 )
* Refactor compiler for encoder ( #301 #302 )
# v0.7.10 - 2021/10/16
* Fix conversion from pointer to uint64 ( #294 )
# v0.7.9 - 2021/09/28
* Fix encoding of nil value about interface type that has method ( #291 )
# v0.7.8 - 2021/09/01 # v0.7.8 - 2021/09/01
* Fix mapassign_faststr for indirect struct type ( #283 ) * Fix mapassign_faststr for indirect struct type ( #283 )

View File

@ -22,7 +22,7 @@ cover-html: cover
.PHONY: lint .PHONY: lint
lint: golangci-lint lint: golangci-lint
$(BIN_DIR)/golangci-lint run golangci-lint run
golangci-lint: | $(BIN_DIR) golangci-lint: | $(BIN_DIR)
@{ \ @{ \
@ -30,7 +30,7 @@ golangci-lint: | $(BIN_DIR)
GOLANGCI_LINT_TMP_DIR=$$(mktemp -d); \ GOLANGCI_LINT_TMP_DIR=$$(mktemp -d); \
cd $$GOLANGCI_LINT_TMP_DIR; \ cd $$GOLANGCI_LINT_TMP_DIR; \
go mod init tmp; \ go mod init tmp; \
GOBIN=$(BIN_DIR) go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.54.2; \ GOBIN=$(BIN_DIR) go get github.com/golangci/golangci-lint/cmd/golangci-lint@v1.36.0; \
rm -rf $$GOLANGCI_LINT_TMP_DIR; \ rm -rf $$GOLANGCI_LINT_TMP_DIR; \
} }

View File

@ -13,7 +13,7 @@ Fast JSON encoder/decoder compatible with encoding/json for Go
``` ```
* version ( expected release date ) * version ( expected release date )
* v0.9.0 * v0.7.0
| |
| while maintaining compatibility with encoding/json, we will add convenient APIs | while maintaining compatibility with encoding/json, we will add convenient APIs
| |
@ -21,8 +21,9 @@ Fast JSON encoder/decoder compatible with encoding/json for Go
* v1.0.0 * v1.0.0
``` ```
We are accepting requests for features that will be implemented between v0.9.0 and v.1.0.0. We are accepting requests for features that will be implemented between v0.7.0 and v.1.0.0.
If you have the API you need, please submit your issue [here](https://github.com/goccy/go-json/issues). If you have the API you need, please submit your issue [here](https://github.com/goccy/go-json/issues).
For example, I'm thinking of supporting `context.Context` of `json.Marshaler` and decoding using JSON Path.
# Features # Features
@ -31,7 +32,6 @@ If you have the API you need, please submit your issue [here](https://github.com
- Flexible customization with options - Flexible customization with options
- Coloring the encoded string - Coloring the encoded string
- Can propagate context.Context to `MarshalJSON` or `UnmarshalJSON` - Can propagate context.Context to `MarshalJSON` or `UnmarshalJSON`
- Can dynamically filter the fields of the structure type-safely
# Installation # Installation
@ -184,7 +184,7 @@ func Marshal(v interface{}) ([]byte, error) {
`json.Marshal` and `json.Unmarshal` receive `interface{}` value and they perform type determination dynamically to process. `json.Marshal` and `json.Unmarshal` receive `interface{}` value and they perform type determination dynamically to process.
In normal case, you need to use the `reflect` library to determine the type dynamically, but since `reflect.Type` is defined as `interface`, when you call the method of `reflect.Type`, The reflect's argument is escaped. In normal case, you need to use the `reflect` library to determine the type dynamically, but since `reflect.Type` is defined as `interface`, when you call the method of `reflect.Type`, The reflect's argument is escaped.
Therefore, the arguments for `Marshal` and `Unmarshal` are always escaped to the heap. Therefore, the arguments for `Marshal` and `Unmarshal` are always escape to the heap.
However, `go-json` can use the feature of `reflect.Type` while avoiding escaping. However, `go-json` can use the feature of `reflect.Type` while avoiding escaping.
`reflect.Type` is defined as `interface`, but in reality `reflect.Type` is implemented only by the structure `rtype` defined in the `reflect` package. `reflect.Type` is defined as `interface`, but in reality `reflect.Type` is implemented only by the structure `rtype` defined in the `reflect` package.

View File

@ -83,37 +83,6 @@ func unmarshalContext(ctx context.Context, data []byte, v interface{}, optFuncs
return validateEndBuf(src, cursor) return validateEndBuf(src, cursor)
} }
var (
pathDecoder = decoder.NewPathDecoder()
)
func extractFromPath(path *Path, data []byte, optFuncs ...DecodeOptionFunc) ([][]byte, error) {
if path.path.RootSelectorOnly {
return [][]byte{data}, nil
}
src := make([]byte, len(data)+1) // append nul byte to the end
copy(src, data)
ctx := decoder.TakeRuntimeContext()
ctx.Buf = src
ctx.Option.Flags = 0
ctx.Option.Flags |= decoder.PathOption
ctx.Option.Path = path.path
for _, optFunc := range optFuncs {
optFunc(ctx.Option)
}
paths, cursor, err := pathDecoder.DecodePath(ctx, 0, 0)
if err != nil {
decoder.ReleaseRuntimeContext(ctx)
return nil, err
}
decoder.ReleaseRuntimeContext(ctx)
if err := validateEndBuf(src, cursor); err != nil {
return nil, err
}
return paths, nil
}
func unmarshalNoEscape(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error { func unmarshalNoEscape(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
src := make([]byte, len(data)+1) // append nul byte to the end src := make([]byte, len(data)+1) // append nul byte to the end
copy(src, data) copy(src, data)

View File

@ -1,7 +1,7 @@
version: '2' version: '2'
services: services:
go-json: go-json:
image: golang:1.18 image: golang:1.16
volumes: volumes:
- '.:/go/src/go-json' - '.:/go/src/go-json'
deploy: deploy:

View File

@ -3,7 +3,6 @@ package json
import ( import (
"context" "context"
"io" "io"
"os"
"unsafe" "unsafe"
"github.com/goccy/go-json/internal/encoder" "github.com/goccy/go-json/internal/encoder"
@ -52,7 +51,7 @@ func (e *Encoder) EncodeContext(ctx context.Context, v interface{}, optFuncs ...
rctx.Option.Flag |= encoder.ContextOption rctx.Option.Flag |= encoder.ContextOption
rctx.Option.Context = ctx rctx.Option.Context = ctx
err := e.encodeWithOption(rctx, v, optFuncs...) //nolint: contextcheck err := e.encodeWithOption(rctx, v, optFuncs...)
encoder.ReleaseRuntimeContext(rctx) encoder.ReleaseRuntimeContext(rctx)
return err return err
@ -62,8 +61,6 @@ func (e *Encoder) encodeWithOption(ctx *encoder.RuntimeContext, v interface{}, o
if e.enabledHTMLEscape { if e.enabledHTMLEscape {
ctx.Option.Flag |= encoder.HTMLEscapeOption ctx.Option.Flag |= encoder.HTMLEscapeOption
} }
ctx.Option.Flag |= encoder.NormalizeUTF8Option
ctx.Option.DebugOut = os.Stdout
for _, optFunc := range optFuncs { for _, optFunc := range optFuncs {
optFunc(ctx.Option) optFunc(ctx.Option)
} }
@ -114,13 +111,13 @@ func (e *Encoder) SetIndent(prefix, indent string) {
func marshalContext(ctx context.Context, v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) { func marshalContext(ctx context.Context, v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) {
rctx := encoder.TakeRuntimeContext() rctx := encoder.TakeRuntimeContext()
rctx.Option.Flag = 0 rctx.Option.Flag = 0
rctx.Option.Flag = encoder.HTMLEscapeOption | encoder.NormalizeUTF8Option | encoder.ContextOption rctx.Option.Flag = encoder.HTMLEscapeOption | encoder.ContextOption
rctx.Option.Context = ctx rctx.Option.Context = ctx
for _, optFunc := range optFuncs { for _, optFunc := range optFuncs {
optFunc(rctx.Option) optFunc(rctx.Option)
} }
buf, err := encode(rctx, v) //nolint: contextcheck buf, err := encode(rctx, v)
if err != nil { if err != nil {
encoder.ReleaseRuntimeContext(rctx) encoder.ReleaseRuntimeContext(rctx)
return nil, err return nil, err
@ -142,7 +139,7 @@ func marshal(v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) {
ctx := encoder.TakeRuntimeContext() ctx := encoder.TakeRuntimeContext()
ctx.Option.Flag = 0 ctx.Option.Flag = 0
ctx.Option.Flag |= (encoder.HTMLEscapeOption | encoder.NormalizeUTF8Option) ctx.Option.Flag |= encoder.HTMLEscapeOption
for _, optFunc := range optFuncs { for _, optFunc := range optFuncs {
optFunc(ctx.Option) optFunc(ctx.Option)
} }
@ -169,7 +166,7 @@ func marshalNoEscape(v interface{}) ([]byte, error) {
ctx := encoder.TakeRuntimeContext() ctx := encoder.TakeRuntimeContext()
ctx.Option.Flag = 0 ctx.Option.Flag = 0
ctx.Option.Flag |= (encoder.HTMLEscapeOption | encoder.NormalizeUTF8Option) ctx.Option.Flag |= encoder.HTMLEscapeOption
buf, err := encodeNoEscape(ctx, v) buf, err := encodeNoEscape(ctx, v)
if err != nil { if err != nil {
@ -193,7 +190,7 @@ func marshalIndent(v interface{}, prefix, indent string, optFuncs ...EncodeOptio
ctx := encoder.TakeRuntimeContext() ctx := encoder.TakeRuntimeContext()
ctx.Option.Flag = 0 ctx.Option.Flag = 0
ctx.Option.Flag |= (encoder.HTMLEscapeOption | encoder.NormalizeUTF8Option | encoder.IndentOption) ctx.Option.Flag |= (encoder.HTMLEscapeOption | encoder.IndentOption)
for _, optFunc := range optFuncs { for _, optFunc := range optFuncs {
optFunc(ctx.Option) optFunc(ctx.Option)
} }
@ -223,7 +220,7 @@ func encode(ctx *encoder.RuntimeContext, v interface{}) ([]byte, error) {
typ := header.typ typ := header.typ
typeptr := uintptr(unsafe.Pointer(typ)) typeptr := uintptr(unsafe.Pointer(typ))
codeSet, err := encoder.CompileToGetCodeSet(ctx, typeptr) codeSet, err := encoder.CompileToGetCodeSet(typeptr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -251,7 +248,7 @@ func encodeNoEscape(ctx *encoder.RuntimeContext, v interface{}) ([]byte, error)
typ := header.typ typ := header.typ
typeptr := uintptr(unsafe.Pointer(typ)) typeptr := uintptr(unsafe.Pointer(typ))
codeSet, err := encoder.CompileToGetCodeSet(ctx, typeptr) codeSet, err := encoder.CompileToGetCodeSet(typeptr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -278,7 +275,7 @@ func encodeIndent(ctx *encoder.RuntimeContext, v interface{}, prefix, indent str
typ := header.typ typ := header.typ
typeptr := uintptr(unsafe.Pointer(typ)) typeptr := uintptr(unsafe.Pointer(typ))
codeSet, err := encoder.CompileToGetCodeSet(ctx, typeptr) codeSet, err := encoder.CompileToGetCodeSet(typeptr)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -37,5 +37,3 @@ type UnmarshalTypeError = errors.UnmarshalTypeError
type UnsupportedTypeError = errors.UnsupportedTypeError type UnsupportedTypeError = errors.UnsupportedTypeError
type UnsupportedValueError = errors.UnsupportedValueError type UnsupportedValueError = errors.UnsupportedValueError
type PathError = errors.PathError

View File

@ -35,7 +35,3 @@ func (d *anonymousFieldDecoder) Decode(ctx *RuntimeContext, cursor, depth int64,
p = *(*unsafe.Pointer)(p) p = *(*unsafe.Pointer)(p)
return d.dec.Decode(ctx, cursor, depth, unsafe.Pointer(uintptr(p)+d.offset)) return d.dec.Decode(ctx, cursor, depth, unsafe.Pointer(uintptr(p)+d.offset))
} }
func (d *anonymousFieldDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return d.dec.DecodePath(ctx, cursor, depth)
}

View File

@ -1,7 +1,6 @@
package decoder package decoder
import ( import (
"fmt"
"unsafe" "unsafe"
"github.com/goccy/go-json/internal/errors" "github.com/goccy/go-json/internal/errors"
@ -19,9 +18,7 @@ type arrayDecoder struct {
} }
func newArrayDecoder(dec Decoder, elemType *runtime.Type, alen int, structName, fieldName string) *arrayDecoder { func newArrayDecoder(dec Decoder, elemType *runtime.Type, alen int, structName, fieldName string) *arrayDecoder {
// workaround to avoid checkptr errors. cannot use `*(*unsafe.Pointer)(unsafe_New(elemType))` directly. zeroValue := *(*unsafe.Pointer)(unsafe_New(elemType))
zeroValuePtr := unsafe_New(elemType)
zeroValue := **(**unsafe.Pointer)(unsafe.Pointer(&zeroValuePtr))
return &arrayDecoder{ return &arrayDecoder{
valueDecoder: dec, valueDecoder: dec,
elemType: elemType, elemType: elemType,
@ -170,7 +167,3 @@ func (d *arrayDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe
} }
} }
} }
func (d *arrayDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: array decoder does not support decode path")
}

View File

@ -1,438 +0,0 @@
package decoder
import (
"fmt"
"reflect"
"strconv"
)
var (
nilValue = reflect.ValueOf(nil)
)
func AssignValue(src, dst reflect.Value) error {
if dst.Type().Kind() != reflect.Ptr {
return fmt.Errorf("invalid dst type. required pointer type: %T", dst.Type())
}
casted, err := castValue(dst.Elem().Type(), src)
if err != nil {
return err
}
dst.Elem().Set(casted)
return nil
}
func castValue(t reflect.Type, v reflect.Value) (reflect.Value, error) {
switch t.Kind() {
case reflect.Int:
vv, err := castInt(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(int(vv.Int())), nil
case reflect.Int8:
vv, err := castInt(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(int8(vv.Int())), nil
case reflect.Int16:
vv, err := castInt(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(int16(vv.Int())), nil
case reflect.Int32:
vv, err := castInt(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(int32(vv.Int())), nil
case reflect.Int64:
return castInt(v)
case reflect.Uint:
vv, err := castUint(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(uint(vv.Uint())), nil
case reflect.Uint8:
vv, err := castUint(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(uint8(vv.Uint())), nil
case reflect.Uint16:
vv, err := castUint(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(uint16(vv.Uint())), nil
case reflect.Uint32:
vv, err := castUint(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(uint32(vv.Uint())), nil
case reflect.Uint64:
return castUint(v)
case reflect.Uintptr:
vv, err := castUint(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(uintptr(vv.Uint())), nil
case reflect.String:
return castString(v)
case reflect.Bool:
return castBool(v)
case reflect.Float32:
vv, err := castFloat(v)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(float32(vv.Float())), nil
case reflect.Float64:
return castFloat(v)
case reflect.Array:
return castArray(t, v)
case reflect.Slice:
return castSlice(t, v)
case reflect.Map:
return castMap(t, v)
case reflect.Struct:
return castStruct(t, v)
}
return v, nil
}
func castInt(v reflect.Value) (reflect.Value, error) {
switch v.Type().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v, nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return reflect.ValueOf(int64(v.Uint())), nil
case reflect.String:
i64, err := strconv.ParseInt(v.String(), 10, 64)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(i64), nil
case reflect.Bool:
if v.Bool() {
return reflect.ValueOf(int64(1)), nil
}
return reflect.ValueOf(int64(0)), nil
case reflect.Float32, reflect.Float64:
return reflect.ValueOf(int64(v.Float())), nil
case reflect.Array:
if v.Len() > 0 {
return castInt(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to int64 from empty array")
case reflect.Slice:
if v.Len() > 0 {
return castInt(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to int64 from empty slice")
case reflect.Interface:
return castInt(reflect.ValueOf(v.Interface()))
case reflect.Map:
return nilValue, fmt.Errorf("failed to cast to int64 from map")
case reflect.Struct:
return nilValue, fmt.Errorf("failed to cast to int64 from struct")
case reflect.Ptr:
return castInt(v.Elem())
}
return nilValue, fmt.Errorf("failed to cast to int64 from %s", v.Type().Kind())
}
func castUint(v reflect.Value) (reflect.Value, error) {
switch v.Type().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return reflect.ValueOf(uint64(v.Int())), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v, nil
case reflect.String:
u64, err := strconv.ParseUint(v.String(), 10, 64)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(u64), nil
case reflect.Bool:
if v.Bool() {
return reflect.ValueOf(uint64(1)), nil
}
return reflect.ValueOf(uint64(0)), nil
case reflect.Float32, reflect.Float64:
return reflect.ValueOf(uint64(v.Float())), nil
case reflect.Array:
if v.Len() > 0 {
return castUint(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to uint64 from empty array")
case reflect.Slice:
if v.Len() > 0 {
return castUint(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to uint64 from empty slice")
case reflect.Interface:
return castUint(reflect.ValueOf(v.Interface()))
case reflect.Map:
return nilValue, fmt.Errorf("failed to cast to uint64 from map")
case reflect.Struct:
return nilValue, fmt.Errorf("failed to cast to uint64 from struct")
case reflect.Ptr:
return castUint(v.Elem())
}
return nilValue, fmt.Errorf("failed to cast to uint64 from %s", v.Type().Kind())
}
func castString(v reflect.Value) (reflect.Value, error) {
switch v.Type().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return reflect.ValueOf(fmt.Sprint(v.Int())), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return reflect.ValueOf(fmt.Sprint(v.Uint())), nil
case reflect.String:
return v, nil
case reflect.Bool:
if v.Bool() {
return reflect.ValueOf("true"), nil
}
return reflect.ValueOf("false"), nil
case reflect.Float32, reflect.Float64:
return reflect.ValueOf(fmt.Sprint(v.Float())), nil
case reflect.Array:
if v.Len() > 0 {
return castString(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to string from empty array")
case reflect.Slice:
if v.Len() > 0 {
return castString(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to string from empty slice")
case reflect.Interface:
return castString(reflect.ValueOf(v.Interface()))
case reflect.Map:
return nilValue, fmt.Errorf("failed to cast to string from map")
case reflect.Struct:
return nilValue, fmt.Errorf("failed to cast to string from struct")
case reflect.Ptr:
return castString(v.Elem())
}
return nilValue, fmt.Errorf("failed to cast to string from %s", v.Type().Kind())
}
func castBool(v reflect.Value) (reflect.Value, error) {
switch v.Type().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch v.Int() {
case 0:
return reflect.ValueOf(false), nil
case 1:
return reflect.ValueOf(true), nil
}
return nilValue, fmt.Errorf("failed to cast to bool from %d", v.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
switch v.Uint() {
case 0:
return reflect.ValueOf(false), nil
case 1:
return reflect.ValueOf(true), nil
}
return nilValue, fmt.Errorf("failed to cast to bool from %d", v.Uint())
case reflect.String:
b, err := strconv.ParseBool(v.String())
if err != nil {
return nilValue, err
}
return reflect.ValueOf(b), nil
case reflect.Bool:
return v, nil
case reflect.Float32, reflect.Float64:
switch v.Float() {
case 0:
return reflect.ValueOf(false), nil
case 1:
return reflect.ValueOf(true), nil
}
return nilValue, fmt.Errorf("failed to cast to bool from %f", v.Float())
case reflect.Array:
if v.Len() > 0 {
return castBool(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to string from empty array")
case reflect.Slice:
if v.Len() > 0 {
return castBool(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to string from empty slice")
case reflect.Interface:
return castBool(reflect.ValueOf(v.Interface()))
case reflect.Map:
return nilValue, fmt.Errorf("failed to cast to string from map")
case reflect.Struct:
return nilValue, fmt.Errorf("failed to cast to string from struct")
case reflect.Ptr:
return castBool(v.Elem())
}
return nilValue, fmt.Errorf("failed to cast to bool from %s", v.Type().Kind())
}
func castFloat(v reflect.Value) (reflect.Value, error) {
switch v.Type().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return reflect.ValueOf(float64(v.Int())), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return reflect.ValueOf(float64(v.Uint())), nil
case reflect.String:
f64, err := strconv.ParseFloat(v.String(), 64)
if err != nil {
return nilValue, err
}
return reflect.ValueOf(f64), nil
case reflect.Bool:
if v.Bool() {
return reflect.ValueOf(float64(1)), nil
}
return reflect.ValueOf(float64(0)), nil
case reflect.Float32, reflect.Float64:
return v, nil
case reflect.Array:
if v.Len() > 0 {
return castFloat(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to float64 from empty array")
case reflect.Slice:
if v.Len() > 0 {
return castFloat(v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to float64 from empty slice")
case reflect.Interface:
return castFloat(reflect.ValueOf(v.Interface()))
case reflect.Map:
return nilValue, fmt.Errorf("failed to cast to float64 from map")
case reflect.Struct:
return nilValue, fmt.Errorf("failed to cast to float64 from struct")
case reflect.Ptr:
return castFloat(v.Elem())
}
return nilValue, fmt.Errorf("failed to cast to float64 from %s", v.Type().Kind())
}
func castArray(t reflect.Type, v reflect.Value) (reflect.Value, error) {
kind := v.Type().Kind()
if kind == reflect.Interface {
return castArray(t, reflect.ValueOf(v.Interface()))
}
if kind != reflect.Slice && kind != reflect.Array {
return nilValue, fmt.Errorf("failed to cast to array from %s", kind)
}
if t.Elem() == v.Type().Elem() {
return v, nil
}
if t.Len() != v.Len() {
return nilValue, fmt.Errorf("failed to cast [%d]array from slice of %d length", t.Len(), v.Len())
}
ret := reflect.New(t).Elem()
for i := 0; i < v.Len(); i++ {
vv, err := castValue(t.Elem(), v.Index(i))
if err != nil {
return nilValue, err
}
ret.Index(i).Set(vv)
}
return ret, nil
}
func castSlice(t reflect.Type, v reflect.Value) (reflect.Value, error) {
kind := v.Type().Kind()
if kind == reflect.Interface {
return castSlice(t, reflect.ValueOf(v.Interface()))
}
if kind != reflect.Slice && kind != reflect.Array {
return nilValue, fmt.Errorf("failed to cast to slice from %s", kind)
}
if t.Elem() == v.Type().Elem() {
return v, nil
}
ret := reflect.MakeSlice(t, v.Len(), v.Len())
for i := 0; i < v.Len(); i++ {
vv, err := castValue(t.Elem(), v.Index(i))
if err != nil {
return nilValue, err
}
ret.Index(i).Set(vv)
}
return ret, nil
}
func castMap(t reflect.Type, v reflect.Value) (reflect.Value, error) {
ret := reflect.MakeMap(t)
switch v.Type().Kind() {
case reflect.Map:
iter := v.MapRange()
for iter.Next() {
key, err := castValue(t.Key(), iter.Key())
if err != nil {
return nilValue, err
}
value, err := castValue(t.Elem(), iter.Value())
if err != nil {
return nilValue, err
}
ret.SetMapIndex(key, value)
}
return ret, nil
case reflect.Interface:
return castMap(t, reflect.ValueOf(v.Interface()))
case reflect.Slice:
if v.Len() > 0 {
return castMap(t, v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to map from empty slice")
}
return nilValue, fmt.Errorf("failed to cast to map from %s", v.Type().Kind())
}
func castStruct(t reflect.Type, v reflect.Value) (reflect.Value, error) {
ret := reflect.New(t).Elem()
switch v.Type().Kind() {
case reflect.Map:
iter := v.MapRange()
for iter.Next() {
key := iter.Key()
k, err := castString(key)
if err != nil {
return nilValue, err
}
fieldName := k.String()
field, ok := t.FieldByName(fieldName)
if ok {
value, err := castValue(field.Type, iter.Value())
if err != nil {
return nilValue, err
}
ret.FieldByName(fieldName).Set(value)
}
}
return ret, nil
case reflect.Struct:
for i := 0; i < v.Type().NumField(); i++ {
name := v.Type().Field(i).Name
ret.FieldByName(name).Set(v.FieldByName(name))
}
return ret, nil
case reflect.Interface:
return castStruct(t, reflect.ValueOf(v.Interface()))
case reflect.Slice:
if v.Len() > 0 {
return castStruct(t, v.Index(0))
}
return nilValue, fmt.Errorf("failed to cast to struct from empty slice")
default:
return nilValue, fmt.Errorf("failed to cast to struct from %s", v.Type().Kind())
}
}

View File

@ -1,7 +1,6 @@
package decoder package decoder
import ( import (
"fmt"
"unsafe" "unsafe"
"github.com/goccy/go-json/internal/errors" "github.com/goccy/go-json/internal/errors"
@ -77,7 +76,3 @@ func (d *boolDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.
} }
return 0, errors.ErrUnexpectedEndOfJSON("bool", cursor) return 0, errors.ErrUnexpectedEndOfJSON("bool", cursor)
} }
func (d *boolDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: bool decoder does not support decode path")
}

View File

@ -2,7 +2,6 @@ package decoder
import ( import (
"encoding/base64" "encoding/base64"
"fmt"
"unsafe" "unsafe"
"github.com/goccy/go-json/internal/errors" "github.com/goccy/go-json/internal/errors"
@ -24,8 +23,9 @@ func byteUnmarshalerSliceDecoder(typ *runtime.Type, structName string, fieldName
unmarshalDecoder = newUnmarshalJSONDecoder(runtime.PtrTo(typ), structName, fieldName) unmarshalDecoder = newUnmarshalJSONDecoder(runtime.PtrTo(typ), structName, fieldName)
case runtime.PtrTo(typ).Implements(unmarshalTextType): case runtime.PtrTo(typ).Implements(unmarshalTextType):
unmarshalDecoder = newUnmarshalTextDecoder(runtime.PtrTo(typ), structName, fieldName) unmarshalDecoder = newUnmarshalTextDecoder(runtime.PtrTo(typ), structName, fieldName)
default: }
unmarshalDecoder, _ = compileUint8(typ, structName, fieldName) if unmarshalDecoder == nil {
return nil
} }
return newSliceDecoder(unmarshalDecoder, typ, 1, structName, fieldName) return newSliceDecoder(unmarshalDecoder, typ, 1, structName, fieldName)
} }
@ -79,10 +79,6 @@ func (d *bytesDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe
return cursor, nil return cursor, nil
} }
func (d *bytesDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: []byte decoder does not support decode path")
}
func (d *bytesDecoder) decodeStreamBinary(s *Stream, depth int64, p unsafe.Pointer) ([]byte, error) { func (d *bytesDecoder) decodeStreamBinary(s *Stream, depth int64, p unsafe.Pointer) ([]byte, error) {
c := s.skipWhiteSpace() c := s.skipWhiteSpace()
if c == '[' { if c == '[' {

View File

@ -9,6 +9,7 @@ import (
"unicode" "unicode"
"unsafe" "unsafe"
"github.com/goccy/go-json/internal/errors"
"github.com/goccy/go-json/internal/runtime" "github.com/goccy/go-json/internal/runtime"
) )
@ -24,7 +25,7 @@ func init() {
if typeAddr == nil { if typeAddr == nil {
typeAddr = &runtime.TypeAddr{} typeAddr = &runtime.TypeAddr{}
} }
cachedDecoder = make([]Decoder, typeAddr.AddrRange>>typeAddr.AddrShift+1) cachedDecoder = make([]Decoder, typeAddr.AddrRange>>typeAddr.AddrShift)
} }
func loadDecoderMap() map[uintptr]Decoder { func loadDecoderMap() map[uintptr]Decoder {
@ -125,7 +126,13 @@ func compile(typ *runtime.Type, structName, fieldName string, structTypeToDecode
case reflect.Func: case reflect.Func:
return compileFunc(typ, structName, fieldName) return compileFunc(typ, structName, fieldName)
} }
return newInvalidDecoder(typ, structName, fieldName), nil return nil, &errors.UnmarshalTypeError{
Value: "object",
Type: runtime.RType2Type(typ),
Offset: 0,
Struct: structName,
Field: fieldName,
}
} }
func isStringTagSupportedType(typ *runtime.Type) bool { func isStringTagSupportedType(typ *runtime.Type) bool {
@ -154,9 +161,6 @@ func compileMapKey(typ *runtime.Type, structName, fieldName string, structTypeTo
if runtime.PtrTo(typ).Implements(unmarshalTextType) { if runtime.PtrTo(typ).Implements(unmarshalTextType) {
return newUnmarshalTextDecoder(runtime.PtrTo(typ), structName, fieldName), nil return newUnmarshalTextDecoder(runtime.PtrTo(typ), structName, fieldName), nil
} }
if typ.Kind() == reflect.String {
return newStringDecoder(structName, fieldName), nil
}
dec, err := compile(typ, structName, fieldName, structTypeToDecoder) dec, err := compile(typ, structName, fieldName, structTypeToDecoder)
if err != nil { if err != nil {
return nil, err return nil, err
@ -170,9 +174,17 @@ func compileMapKey(typ *runtime.Type, structName, fieldName string, structTypeTo
case *ptrDecoder: case *ptrDecoder:
dec = t.dec dec = t.dec
default: default:
return newInvalidDecoder(typ, structName, fieldName), nil goto ERROR
} }
} }
ERROR:
return nil, &errors.UnmarshalTypeError{
Value: "object",
Type: runtime.RType2Type(typ),
Offset: 0,
Struct: structName,
Field: fieldName,
}
} }
func compilePtr(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) { func compilePtr(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
@ -310,21 +322,64 @@ func compileFunc(typ *runtime.Type, strutName, fieldName string) (Decoder, error
return newFuncDecoder(typ, strutName, fieldName), nil return newFuncDecoder(typ, strutName, fieldName), nil
} }
func typeToStructTags(typ *runtime.Type) runtime.StructTags { func removeConflictFields(fieldMap map[string]*structFieldSet, conflictedMap map[string]struct{}, dec *structDecoder, field reflect.StructField) {
tags := runtime.StructTags{} for k, v := range dec.fieldMap {
fieldNum := typ.NumField() if _, exists := conflictedMap[k]; exists {
for i := 0; i < fieldNum; i++ { // already conflicted key
field := typ.Field(i)
if runtime.IsIgnoredStructField(field) {
continue continue
} }
tags = append(tags, runtime.StructTagFromField(field)) set, exists := fieldMap[k]
if !exists {
fieldSet := &structFieldSet{
dec: v.dec,
offset: field.Offset + v.offset,
isTaggedKey: v.isTaggedKey,
key: k,
keyLen: int64(len(k)),
}
fieldMap[k] = fieldSet
lower := strings.ToLower(k)
if _, exists := fieldMap[lower]; !exists {
fieldMap[lower] = fieldSet
}
continue
}
if set.isTaggedKey {
if v.isTaggedKey {
// conflict tag key
delete(fieldMap, k)
delete(fieldMap, strings.ToLower(k))
conflictedMap[k] = struct{}{}
conflictedMap[strings.ToLower(k)] = struct{}{}
}
} else {
if v.isTaggedKey {
fieldSet := &structFieldSet{
dec: v.dec,
offset: field.Offset + v.offset,
isTaggedKey: v.isTaggedKey,
key: k,
keyLen: int64(len(k)),
}
fieldMap[k] = fieldSet
lower := strings.ToLower(k)
if _, exists := fieldMap[lower]; !exists {
fieldMap[lower] = fieldSet
}
} else {
// conflict tag key
delete(fieldMap, k)
delete(fieldMap, strings.ToLower(k))
conflictedMap[k] = struct{}{}
conflictedMap[strings.ToLower(k)] = struct{}{}
}
}
} }
return tags
} }
func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) { func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
fieldNum := typ.NumField() fieldNum := typ.NumField()
conflictedMap := map[string]struct{}{}
fieldMap := map[string]*structFieldSet{} fieldMap := map[string]*structFieldSet{}
typeptr := uintptr(unsafe.Pointer(typ)) typeptr := uintptr(unsafe.Pointer(typ))
if dec, exists := structTypeToDecoder[typeptr]; exists { if dec, exists := structTypeToDecoder[typeptr]; exists {
@ -333,8 +388,6 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
structDec := newStructDecoder(structName, fieldName, fieldMap) structDec := newStructDecoder(structName, fieldName, fieldMap)
structTypeToDecoder[typeptr] = structDec structTypeToDecoder[typeptr] = structDec
structName = typ.Name() structName = typ.Name()
tags := typeToStructTags(typ)
allFields := []*structFieldSet{}
for i := 0; i < fieldNum; i++ { for i := 0; i < fieldNum; i++ {
field := typ.Field(i) field := typ.Field(i)
if runtime.IsIgnoredStructField(field) { if runtime.IsIgnoredStructField(field) {
@ -352,19 +405,7 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
// recursive definition // recursive definition
continue continue
} }
for k, v := range stDec.fieldMap { removeConflictFields(fieldMap, conflictedMap, stDec, field)
if tags.ExistsKey(k) {
continue
}
fieldSet := &structFieldSet{
dec: v.dec,
offset: field.Offset + v.offset,
isTaggedKey: v.isTaggedKey,
key: k,
keyLen: int64(len(k)),
}
allFields = append(allFields, fieldSet)
}
} else if pdec, ok := dec.(*ptrDecoder); ok { } else if pdec, ok := dec.(*ptrDecoder); ok {
contentDec := pdec.contentDecoder() contentDec := pdec.contentDecoder()
if pdec.typ == typ { if pdec.typ == typ {
@ -380,38 +421,60 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
} }
if dec, ok := contentDec.(*structDecoder); ok { if dec, ok := contentDec.(*structDecoder); ok {
for k, v := range dec.fieldMap { for k, v := range dec.fieldMap {
if tags.ExistsKey(k) { if _, exists := conflictedMap[k]; exists {
// already conflicted key
continue continue
} }
fieldSet := &structFieldSet{ set, exists := fieldMap[k]
dec: newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec), if !exists {
offset: field.Offset, fieldSet := &structFieldSet{
isTaggedKey: v.isTaggedKey, dec: newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec),
key: k, offset: field.Offset,
keyLen: int64(len(k)), isTaggedKey: v.isTaggedKey,
err: fieldSetErr, key: k,
keyLen: int64(len(k)),
err: fieldSetErr,
}
fieldMap[k] = fieldSet
lower := strings.ToLower(k)
if _, exists := fieldMap[lower]; !exists {
fieldMap[lower] = fieldSet
}
continue
}
if set.isTaggedKey {
if v.isTaggedKey {
// conflict tag key
delete(fieldMap, k)
delete(fieldMap, strings.ToLower(k))
conflictedMap[k] = struct{}{}
conflictedMap[strings.ToLower(k)] = struct{}{}
}
} else {
if v.isTaggedKey {
fieldSet := &structFieldSet{
dec: newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec),
offset: field.Offset,
isTaggedKey: v.isTaggedKey,
key: k,
keyLen: int64(len(k)),
err: fieldSetErr,
}
fieldMap[k] = fieldSet
lower := strings.ToLower(k)
if _, exists := fieldMap[lower]; !exists {
fieldMap[lower] = fieldSet
}
} else {
// conflict tag key
delete(fieldMap, k)
delete(fieldMap, strings.ToLower(k))
conflictedMap[k] = struct{}{}
conflictedMap[strings.ToLower(k)] = struct{}{}
}
} }
allFields = append(allFields, fieldSet)
} }
} else {
fieldSet := &structFieldSet{
dec: pdec,
offset: field.Offset,
isTaggedKey: tag.IsTaggedKey,
key: field.Name,
keyLen: int64(len(field.Name)),
}
allFields = append(allFields, fieldSet)
} }
} else {
fieldSet := &structFieldSet{
dec: dec,
offset: field.Offset,
isTaggedKey: tag.IsTaggedKey,
key: field.Name,
keyLen: int64(len(field.Name)),
}
allFields = append(allFields, fieldSet)
} }
} else { } else {
if tag.IsString && isStringTagSupportedType(runtime.Type2RType(field.Type)) { if tag.IsString && isStringTagSupportedType(runtime.Type2RType(field.Type)) {
@ -430,15 +493,11 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
key: key, key: key,
keyLen: int64(len(key)), keyLen: int64(len(key)),
} }
allFields = append(allFields, fieldSet) fieldMap[key] = fieldSet
} lower := strings.ToLower(key)
} if _, exists := fieldMap[lower]; !exists {
for _, set := range filterDuplicatedFields(allFields) { fieldMap[lower] = fieldSet
fieldMap[set.key] = set }
lower := strings.ToLower(set.key)
if _, exists := fieldMap[lower]; !exists {
// first win
fieldMap[lower] = set
} }
} }
delete(structTypeToDecoder, typeptr) delete(structTypeToDecoder, typeptr)
@ -446,42 +505,6 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
return structDec, nil return structDec, nil
} }
func filterDuplicatedFields(allFields []*structFieldSet) []*structFieldSet {
fieldMap := map[string][]*structFieldSet{}
for _, field := range allFields {
fieldMap[field.key] = append(fieldMap[field.key], field)
}
duplicatedFieldMap := map[string]struct{}{}
for k, sets := range fieldMap {
sets = filterFieldSets(sets)
if len(sets) != 1 {
duplicatedFieldMap[k] = struct{}{}
}
}
filtered := make([]*structFieldSet, 0, len(allFields))
for _, field := range allFields {
if _, exists := duplicatedFieldMap[field.key]; exists {
continue
}
filtered = append(filtered, field)
}
return filtered
}
func filterFieldSets(sets []*structFieldSet) []*structFieldSet {
if len(sets) == 1 {
return sets
}
filtered := make([]*structFieldSet, 0, len(sets))
for _, set := range sets {
if set.isTaggedKey {
filtered = append(filtered, set)
}
}
return filtered
}
func implementsUnmarshalJSONType(typ *runtime.Type) bool { func implementsUnmarshalJSONType(typ *runtime.Type) bool {
return typ.Implements(unmarshalJSONType) || typ.Implements(unmarshalJSONContextType) return typ.Implements(unmarshalJSONType) || typ.Implements(unmarshalJSONContextType)
} }

View File

@ -1,4 +1,3 @@
//go:build !race
// +build !race // +build !race
package decoder package decoder

View File

@ -1,4 +1,3 @@
//go:build race
// +build race // +build race
package decoder package decoder

View File

@ -156,15 +156,3 @@ func (d *floatDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe
d.op(p, f64) d.op(p, f64)
return cursor, nil return cursor, nil
} }
func (d *floatDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
buf := ctx.Buf
bytes, c, err := d.decodeByte(buf, cursor)
if err != nil {
return nil, 0, err
}
if bytes == nil {
return [][]byte{nullbytes}, c, nil
}
return [][]byte{bytes}, c, nil
}

View File

@ -2,7 +2,6 @@ package decoder
import ( import (
"bytes" "bytes"
"fmt"
"unsafe" "unsafe"
"github.com/goccy/go-json/internal/errors" "github.com/goccy/go-json/internal/errors"
@ -140,7 +139,3 @@ func (d *funcDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.
} }
return cursor, errors.ErrInvalidBeginningOfValue(buf[cursor], cursor) return cursor, errors.ErrInvalidBeginningOfValue(buf[cursor], cursor)
} }
func (d *funcDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: func decoder does not support decode path")
}

View File

@ -192,15 +192,15 @@ func (d *intDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) erro
} }
switch d.kind { switch d.kind {
case reflect.Int8: case reflect.Int8:
if i64 < -1*(1<<7) || (1<<7) <= i64 { if i64 <= -1*(1<<7) || (1<<7) <= i64 {
return d.typeError(bytes, s.totalOffset()) return d.typeError(bytes, s.totalOffset())
} }
case reflect.Int16: case reflect.Int16:
if i64 < -1*(1<<15) || (1<<15) <= i64 { if i64 <= -1*(1<<15) || (1<<15) <= i64 {
return d.typeError(bytes, s.totalOffset()) return d.typeError(bytes, s.totalOffset())
} }
case reflect.Int32: case reflect.Int32:
if i64 < -1*(1<<31) || (1<<31) <= i64 { if i64 <= -1*(1<<31) || (1<<31) <= i64 {
return d.typeError(bytes, s.totalOffset()) return d.typeError(bytes, s.totalOffset())
} }
} }
@ -225,22 +225,18 @@ func (d *intDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.P
} }
switch d.kind { switch d.kind {
case reflect.Int8: case reflect.Int8:
if i64 < -1*(1<<7) || (1<<7) <= i64 { if i64 <= -1*(1<<7) || (1<<7) <= i64 {
return 0, d.typeError(bytes, cursor) return 0, d.typeError(bytes, cursor)
} }
case reflect.Int16: case reflect.Int16:
if i64 < -1*(1<<15) || (1<<15) <= i64 { if i64 <= -1*(1<<15) || (1<<15) <= i64 {
return 0, d.typeError(bytes, cursor) return 0, d.typeError(bytes, cursor)
} }
case reflect.Int32: case reflect.Int32:
if i64 < -1*(1<<31) || (1<<31) <= i64 { if i64 <= -1*(1<<31) || (1<<31) <= i64 {
return 0, d.typeError(bytes, cursor) return 0, d.typeError(bytes, cursor)
} }
} }
d.op(p, i64) d.op(p, i64)
return cursor, nil return cursor, nil
} }
func (d *intDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: int decoder does not support decode path")
}

View File

@ -94,7 +94,6 @@ func (d *interfaceDecoder) numDecoder(s *Stream) Decoder {
var ( var (
emptyInterfaceType = runtime.Type2RType(reflect.TypeOf((*interface{})(nil)).Elem()) emptyInterfaceType = runtime.Type2RType(reflect.TypeOf((*interface{})(nil)).Elem())
EmptyInterfaceType = emptyInterfaceType
interfaceMapType = runtime.Type2RType( interfaceMapType = runtime.Type2RType(
reflect.TypeOf((*map[string]interface{})(nil)).Elem(), reflect.TypeOf((*map[string]interface{})(nil)).Elem(),
) )
@ -457,72 +456,3 @@ func (d *interfaceDecoder) decodeEmptyInterface(ctx *RuntimeContext, cursor, dep
} }
return cursor, errors.ErrInvalidBeginningOfValue(buf[cursor], cursor) return cursor, errors.ErrInvalidBeginningOfValue(buf[cursor], cursor)
} }
func NewPathDecoder() Decoder {
ifaceDecoder := &interfaceDecoder{
typ: emptyInterfaceType,
structName: "",
fieldName: "",
floatDecoder: newFloatDecoder("", "", func(p unsafe.Pointer, v float64) {
*(*interface{})(p) = v
}),
numberDecoder: newNumberDecoder("", "", func(p unsafe.Pointer, v json.Number) {
*(*interface{})(p) = v
}),
stringDecoder: newStringDecoder("", ""),
}
ifaceDecoder.sliceDecoder = newSliceDecoder(
ifaceDecoder,
emptyInterfaceType,
emptyInterfaceType.Size(),
"", "",
)
ifaceDecoder.mapDecoder = newMapDecoder(
interfaceMapType,
stringType,
ifaceDecoder.stringDecoder,
interfaceMapType.Elem(),
ifaceDecoder,
"", "",
)
return ifaceDecoder
}
var (
truebytes = []byte("true")
falsebytes = []byte("false")
)
func (d *interfaceDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
buf := ctx.Buf
cursor = skipWhiteSpace(buf, cursor)
switch buf[cursor] {
case '{':
return d.mapDecoder.DecodePath(ctx, cursor, depth)
case '[':
return d.sliceDecoder.DecodePath(ctx, cursor, depth)
case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
return d.floatDecoder.DecodePath(ctx, cursor, depth)
case '"':
return d.stringDecoder.DecodePath(ctx, cursor, depth)
case 't':
if err := validateTrue(buf, cursor); err != nil {
return nil, 0, err
}
cursor += 4
return [][]byte{truebytes}, cursor, nil
case 'f':
if err := validateFalse(buf, cursor); err != nil {
return nil, 0, err
}
cursor += 5
return [][]byte{falsebytes}, cursor, nil
case 'n':
if err := validateNull(buf, cursor); err != nil {
return nil, 0, err
}
cursor += 4
return [][]byte{nullbytes}, cursor, nil
}
return nil, cursor, errors.ErrInvalidBeginningOfValue(buf[cursor], cursor)
}

View File

@ -1,55 +0,0 @@
package decoder
import (
"reflect"
"unsafe"
"github.com/goccy/go-json/internal/errors"
"github.com/goccy/go-json/internal/runtime"
)
type invalidDecoder struct {
typ *runtime.Type
kind reflect.Kind
structName string
fieldName string
}
func newInvalidDecoder(typ *runtime.Type, structName, fieldName string) *invalidDecoder {
return &invalidDecoder{
typ: typ,
kind: typ.Kind(),
structName: structName,
fieldName: fieldName,
}
}
func (d *invalidDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) error {
return &errors.UnmarshalTypeError{
Value: "object",
Type: runtime.RType2Type(d.typ),
Offset: s.totalOffset(),
Struct: d.structName,
Field: d.fieldName,
}
}
func (d *invalidDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.Pointer) (int64, error) {
return 0, &errors.UnmarshalTypeError{
Value: "object",
Type: runtime.RType2Type(d.typ),
Offset: cursor,
Struct: d.structName,
Field: d.fieldName,
}
}
func (d *invalidDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, &errors.UnmarshalTypeError{
Value: "object",
Type: runtime.RType2Type(d.typ),
Offset: cursor,
Struct: d.structName,
Field: d.fieldName,
}
}

View File

@ -87,13 +87,13 @@ func (d *mapDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) erro
if mapValue == nil { if mapValue == nil {
mapValue = makemap(d.mapType, 0) mapValue = makemap(d.mapType, 0)
} }
s.cursor++ if s.buf[s.cursor+1] == '}' {
if s.skipWhiteSpace() == '}' {
*(*unsafe.Pointer)(p) = mapValue *(*unsafe.Pointer)(p) = mapValue
s.cursor++ s.cursor += 2
return nil return nil
} }
for { for {
s.cursor++
k := unsafe_New(d.keyType) k := unsafe_New(d.keyType)
if err := d.keyDecoder.DecodeStream(s, depth, k); err != nil { if err := d.keyDecoder.DecodeStream(s, depth, k); err != nil {
return err return err
@ -117,7 +117,6 @@ func (d *mapDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) erro
if !s.equalChar(',') { if !s.equalChar(',') {
return errors.ErrExpected("comma after object value", s.totalOffset()) return errors.ErrExpected("comma after object value", s.totalOffset())
} }
s.cursor++
} }
} }
@ -185,96 +184,3 @@ func (d *mapDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.P
cursor++ cursor++
} }
} }
func (d *mapDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
buf := ctx.Buf
depth++
if depth > maxDecodeNestingDepth {
return nil, 0, errors.ErrExceededMaxDepth(buf[cursor], cursor)
}
cursor = skipWhiteSpace(buf, cursor)
buflen := int64(len(buf))
if buflen < 2 {
return nil, 0, errors.ErrExpected("{} for map", cursor)
}
switch buf[cursor] {
case 'n':
if err := validateNull(buf, cursor); err != nil {
return nil, 0, err
}
cursor += 4
return [][]byte{nullbytes}, cursor, nil
case '{':
default:
return nil, 0, errors.ErrExpected("{ character for map value", cursor)
}
cursor++
cursor = skipWhiteSpace(buf, cursor)
if buf[cursor] == '}' {
cursor++
return nil, cursor, nil
}
keyDecoder, ok := d.keyDecoder.(*stringDecoder)
if !ok {
return nil, 0, &errors.UnmarshalTypeError{
Value: "string",
Type: reflect.TypeOf(""),
Offset: cursor,
Struct: d.structName,
Field: d.fieldName,
}
}
ret := [][]byte{}
for {
key, keyCursor, err := keyDecoder.decodeByte(buf, cursor)
if err != nil {
return nil, 0, err
}
cursor = skipWhiteSpace(buf, keyCursor)
if buf[cursor] != ':' {
return nil, 0, errors.ErrExpected("colon after object key", cursor)
}
cursor++
child, found, err := ctx.Option.Path.Field(string(key))
if err != nil {
return nil, 0, err
}
if found {
if child != nil {
oldPath := ctx.Option.Path.node
ctx.Option.Path.node = child
paths, c, err := d.valueDecoder.DecodePath(ctx, cursor, depth)
if err != nil {
return nil, 0, err
}
ctx.Option.Path.node = oldPath
ret = append(ret, paths...)
cursor = c
} else {
start := cursor
end, err := skipValue(buf, cursor, depth)
if err != nil {
return nil, 0, err
}
ret = append(ret, buf[start:end])
cursor = end
}
} else {
c, err := skipValue(buf, cursor, depth)
if err != nil {
return nil, 0, err
}
cursor = c
}
cursor = skipWhiteSpace(buf, cursor)
if buf[cursor] == '}' {
cursor++
return ret, cursor, nil
}
if buf[cursor] != ',' {
return nil, 0, errors.ErrExpected("comma after object value", cursor)
}
cursor++
}
}

View File

@ -51,17 +51,6 @@ func (d *numberDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsaf
return cursor, nil return cursor, nil
} }
func (d *numberDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
bytes, c, err := d.decodeByte(ctx.Buf, cursor)
if err != nil {
return nil, 0, err
}
if bytes == nil {
return [][]byte{nullbytes}, c, nil
}
return [][]byte{bytes}, c, nil
}
func (d *numberDecoder) decodeStreamByte(s *Stream) ([]byte, error) { func (d *numberDecoder) decodeStreamByte(s *Stream) ([]byte, error) {
start := s.cursor start := s.cursor
for { for {

View File

@ -7,11 +7,9 @@ type OptionFlags uint8
const ( const (
FirstWinOption OptionFlags = 1 << iota FirstWinOption OptionFlags = 1 << iota
ContextOption ContextOption
PathOption
) )
type Option struct { type Option struct {
Flags OptionFlags Flags OptionFlags
Context context.Context Context context.Context
Path *Path
} }

View File

@ -1,670 +0,0 @@
package decoder
import (
"fmt"
"reflect"
"strconv"
"github.com/goccy/go-json/internal/errors"
"github.com/goccy/go-json/internal/runtime"
)
type PathString string
func (s PathString) Build() (*Path, error) {
builder := new(PathBuilder)
return builder.Build([]rune(s))
}
type PathBuilder struct {
root PathNode
node PathNode
singleQuotePathSelector bool
doubleQuotePathSelector bool
}
func (b *PathBuilder) Build(buf []rune) (*Path, error) {
node, err := b.build(buf)
if err != nil {
return nil, err
}
return &Path{
node: node,
RootSelectorOnly: node == nil,
SingleQuotePathSelector: b.singleQuotePathSelector,
DoubleQuotePathSelector: b.doubleQuotePathSelector,
}, nil
}
func (b *PathBuilder) build(buf []rune) (PathNode, error) {
if len(buf) == 0 {
return nil, errors.ErrEmptyPath()
}
if buf[0] != '$' {
return nil, errors.ErrInvalidPath("JSON Path must start with a $ character")
}
if len(buf) == 1 {
return nil, nil
}
buf = buf[1:]
offset, err := b.buildNext(buf)
if err != nil {
return nil, err
}
if len(buf) > offset {
return nil, errors.ErrInvalidPath("remain invalid path %q", buf[offset:])
}
return b.root, nil
}
func (b *PathBuilder) buildNextCharIfExists(buf []rune, cursor int) (int, error) {
if len(buf) > cursor {
offset, err := b.buildNext(buf[cursor:])
if err != nil {
return 0, err
}
return cursor + 1 + offset, nil
}
return cursor, nil
}
func (b *PathBuilder) buildNext(buf []rune) (int, error) {
switch buf[0] {
case '.':
if len(buf) == 1 {
return 0, errors.ErrInvalidPath("JSON Path ends with dot character")
}
offset, err := b.buildSelector(buf[1:])
if err != nil {
return 0, err
}
return offset + 1, nil
case '[':
if len(buf) == 1 {
return 0, errors.ErrInvalidPath("JSON Path ends with left bracket character")
}
offset, err := b.buildIndex(buf[1:])
if err != nil {
return 0, err
}
return offset + 1, nil
default:
return 0, errors.ErrInvalidPath("expect dot or left bracket character. but found %c character", buf[0])
}
}
func (b *PathBuilder) buildSelector(buf []rune) (int, error) {
switch buf[0] {
case '.':
if len(buf) == 1 {
return 0, errors.ErrInvalidPath("JSON Path ends with double dot character")
}
offset, err := b.buildPathRecursive(buf[1:])
if err != nil {
return 0, err
}
return 1 + offset, nil
case '[', ']', '$', '*':
return 0, errors.ErrInvalidPath("found invalid path character %c after dot", buf[0])
}
for cursor := 0; cursor < len(buf); cursor++ {
switch buf[cursor] {
case '$', '*', ']':
return 0, errors.ErrInvalidPath("found %c character in field selector context", buf[cursor])
case '.':
if cursor+1 >= len(buf) {
return 0, errors.ErrInvalidPath("JSON Path ends with dot character")
}
selector := buf[:cursor]
b.addSelectorNode(string(selector))
offset, err := b.buildSelector(buf[cursor+1:])
if err != nil {
return 0, err
}
return cursor + 1 + offset, nil
case '[':
if cursor+1 >= len(buf) {
return 0, errors.ErrInvalidPath("JSON Path ends with left bracket character")
}
selector := buf[:cursor]
b.addSelectorNode(string(selector))
offset, err := b.buildIndex(buf[cursor+1:])
if err != nil {
return 0, err
}
return cursor + 1 + offset, nil
case '"':
if cursor+1 >= len(buf) {
return 0, errors.ErrInvalidPath("JSON Path ends with double quote character")
}
offset, err := b.buildQuoteSelector(buf[cursor+1:], DoubleQuotePathSelector)
if err != nil {
return 0, err
}
return cursor + 1 + offset, nil
}
}
b.addSelectorNode(string(buf))
return len(buf), nil
}
func (b *PathBuilder) buildQuoteSelector(buf []rune, sel QuotePathSelector) (int, error) {
switch buf[0] {
case '[', ']', '$', '.', '*', '\'', '"':
return 0, errors.ErrInvalidPath("found invalid path character %c after quote", buf[0])
}
for cursor := 0; cursor < len(buf); cursor++ {
switch buf[cursor] {
case '\'':
if sel != SingleQuotePathSelector {
return 0, errors.ErrInvalidPath("found double quote character in field selector with single quote context")
}
if len(buf) <= cursor+1 {
return 0, errors.ErrInvalidPath("JSON Path ends with single quote character in field selector context")
}
if buf[cursor+1] != ']' {
return 0, errors.ErrInvalidPath("expect right bracket for field selector with single quote but found %c", buf[cursor+1])
}
selector := buf[:cursor]
b.addSelectorNode(string(selector))
b.singleQuotePathSelector = true
return b.buildNextCharIfExists(buf, cursor+2)
case '"':
if sel != DoubleQuotePathSelector {
return 0, errors.ErrInvalidPath("found single quote character in field selector with double quote context")
}
selector := buf[:cursor]
b.addSelectorNode(string(selector))
b.doubleQuotePathSelector = true
return b.buildNextCharIfExists(buf, cursor+1)
}
}
return 0, errors.ErrInvalidPath("couldn't find quote character in selector quote path context")
}
func (b *PathBuilder) buildPathRecursive(buf []rune) (int, error) {
switch buf[0] {
case '.', '[', ']', '$', '*':
return 0, errors.ErrInvalidPath("found invalid path character %c after double dot", buf[0])
}
for cursor := 0; cursor < len(buf); cursor++ {
switch buf[cursor] {
case '$', '*', ']':
return 0, errors.ErrInvalidPath("found %c character in field selector context", buf[cursor])
case '.':
if cursor+1 >= len(buf) {
return 0, errors.ErrInvalidPath("JSON Path ends with dot character")
}
selector := buf[:cursor]
b.addRecursiveNode(string(selector))
offset, err := b.buildSelector(buf[cursor+1:])
if err != nil {
return 0, err
}
return cursor + 1 + offset, nil
case '[':
if cursor+1 >= len(buf) {
return 0, errors.ErrInvalidPath("JSON Path ends with left bracket character")
}
selector := buf[:cursor]
b.addRecursiveNode(string(selector))
offset, err := b.buildIndex(buf[cursor+1:])
if err != nil {
return 0, err
}
return cursor + 1 + offset, nil
}
}
b.addRecursiveNode(string(buf))
return len(buf), nil
}
func (b *PathBuilder) buildIndex(buf []rune) (int, error) {
switch buf[0] {
case '.', '[', ']', '$':
return 0, errors.ErrInvalidPath("found invalid path character %c after left bracket", buf[0])
case '\'':
if len(buf) == 1 {
return 0, errors.ErrInvalidPath("JSON Path ends with single quote character")
}
offset, err := b.buildQuoteSelector(buf[1:], SingleQuotePathSelector)
if err != nil {
return 0, err
}
return 1 + offset, nil
case '*':
if len(buf) == 1 {
return 0, errors.ErrInvalidPath("JSON Path ends with star character")
}
if buf[1] != ']' {
return 0, errors.ErrInvalidPath("expect right bracket character for index all path but found %c character", buf[1])
}
b.addIndexAllNode()
offset := len("*]")
if len(buf) > 2 {
buildOffset, err := b.buildNext(buf[2:])
if err != nil {
return 0, err
}
return offset + buildOffset, nil
}
return offset, nil
}
for cursor := 0; cursor < len(buf); cursor++ {
switch buf[cursor] {
case ']':
index, err := strconv.ParseInt(string(buf[:cursor]), 10, 64)
if err != nil {
return 0, errors.ErrInvalidPath("%q is unexpected index path", buf[:cursor])
}
b.addIndexNode(int(index))
return b.buildNextCharIfExists(buf, cursor+1)
}
}
return 0, errors.ErrInvalidPath("couldn't find right bracket character in index path context")
}
func (b *PathBuilder) addIndexAllNode() {
node := newPathIndexAllNode()
if b.root == nil {
b.root = node
b.node = node
} else {
b.node = b.node.chain(node)
}
}
func (b *PathBuilder) addRecursiveNode(selector string) {
node := newPathRecursiveNode(selector)
if b.root == nil {
b.root = node
b.node = node
} else {
b.node = b.node.chain(node)
}
}
func (b *PathBuilder) addSelectorNode(name string) {
node := newPathSelectorNode(name)
if b.root == nil {
b.root = node
b.node = node
} else {
b.node = b.node.chain(node)
}
}
func (b *PathBuilder) addIndexNode(idx int) {
node := newPathIndexNode(idx)
if b.root == nil {
b.root = node
b.node = node
} else {
b.node = b.node.chain(node)
}
}
type QuotePathSelector int
const (
SingleQuotePathSelector QuotePathSelector = 1
DoubleQuotePathSelector QuotePathSelector = 2
)
type Path struct {
node PathNode
RootSelectorOnly bool
SingleQuotePathSelector bool
DoubleQuotePathSelector bool
}
func (p *Path) Field(sel string) (PathNode, bool, error) {
if p.node == nil {
return nil, false, nil
}
return p.node.Field(sel)
}
func (p *Path) Get(src, dst reflect.Value) error {
if p.node == nil {
return nil
}
return p.node.Get(src, dst)
}
func (p *Path) String() string {
if p.node == nil {
return "$"
}
return p.node.String()
}
type PathNode interface {
fmt.Stringer
Index(idx int) (PathNode, bool, error)
Field(fieldName string) (PathNode, bool, error)
Get(src, dst reflect.Value) error
chain(PathNode) PathNode
target() bool
single() bool
}
type BasePathNode struct {
child PathNode
}
func (n *BasePathNode) chain(node PathNode) PathNode {
n.child = node
return node
}
func (n *BasePathNode) target() bool {
return n.child == nil
}
func (n *BasePathNode) single() bool {
return true
}
type PathSelectorNode struct {
*BasePathNode
selector string
}
func newPathSelectorNode(selector string) *PathSelectorNode {
return &PathSelectorNode{
BasePathNode: &BasePathNode{},
selector: selector,
}
}
func (n *PathSelectorNode) Index(idx int) (PathNode, bool, error) {
return nil, false, &errors.PathError{}
}
func (n *PathSelectorNode) Field(fieldName string) (PathNode, bool, error) {
if n.selector == fieldName {
return n.child, true, nil
}
return nil, false, nil
}
func (n *PathSelectorNode) Get(src, dst reflect.Value) error {
switch src.Type().Kind() {
case reflect.Map:
iter := src.MapRange()
for iter.Next() {
key, ok := iter.Key().Interface().(string)
if !ok {
return fmt.Errorf("invalid map key type %T", src.Type().Key())
}
child, found, err := n.Field(key)
if err != nil {
return err
}
if found {
if child != nil {
return child.Get(iter.Value(), dst)
}
return AssignValue(iter.Value(), dst)
}
}
case reflect.Struct:
typ := src.Type()
for i := 0; i < typ.Len(); i++ {
tag := runtime.StructTagFromField(typ.Field(i))
child, found, err := n.Field(tag.Key)
if err != nil {
return err
}
if found {
if child != nil {
return child.Get(src.Field(i), dst)
}
return AssignValue(src.Field(i), dst)
}
}
case reflect.Ptr:
return n.Get(src.Elem(), dst)
case reflect.Interface:
return n.Get(reflect.ValueOf(src.Interface()), dst)
case reflect.Float64, reflect.String, reflect.Bool:
return AssignValue(src, dst)
}
return fmt.Errorf("failed to get %s value from %s", n.selector, src.Type())
}
func (n *PathSelectorNode) String() string {
s := fmt.Sprintf(".%s", n.selector)
if n.child != nil {
s += n.child.String()
}
return s
}
type PathIndexNode struct {
*BasePathNode
selector int
}
func newPathIndexNode(selector int) *PathIndexNode {
return &PathIndexNode{
BasePathNode: &BasePathNode{},
selector: selector,
}
}
func (n *PathIndexNode) Index(idx int) (PathNode, bool, error) {
if n.selector == idx {
return n.child, true, nil
}
return nil, false, nil
}
func (n *PathIndexNode) Field(fieldName string) (PathNode, bool, error) {
return nil, false, &errors.PathError{}
}
func (n *PathIndexNode) Get(src, dst reflect.Value) error {
switch src.Type().Kind() {
case reflect.Array, reflect.Slice:
if src.Len() > n.selector {
if n.child != nil {
return n.child.Get(src.Index(n.selector), dst)
}
return AssignValue(src.Index(n.selector), dst)
}
case reflect.Ptr:
return n.Get(src.Elem(), dst)
case reflect.Interface:
return n.Get(reflect.ValueOf(src.Interface()), dst)
}
return fmt.Errorf("failed to get [%d] value from %s", n.selector, src.Type())
}
func (n *PathIndexNode) String() string {
s := fmt.Sprintf("[%d]", n.selector)
if n.child != nil {
s += n.child.String()
}
return s
}
type PathIndexAllNode struct {
*BasePathNode
}
func newPathIndexAllNode() *PathIndexAllNode {
return &PathIndexAllNode{
BasePathNode: &BasePathNode{},
}
}
func (n *PathIndexAllNode) Index(idx int) (PathNode, bool, error) {
return n.child, true, nil
}
func (n *PathIndexAllNode) Field(fieldName string) (PathNode, bool, error) {
return nil, false, &errors.PathError{}
}
func (n *PathIndexAllNode) Get(src, dst reflect.Value) error {
switch src.Type().Kind() {
case reflect.Array, reflect.Slice:
var arr []interface{}
for i := 0; i < src.Len(); i++ {
var v interface{}
rv := reflect.ValueOf(&v)
if n.child != nil {
if err := n.child.Get(src.Index(i), rv); err != nil {
return err
}
} else {
if err := AssignValue(src.Index(i), rv); err != nil {
return err
}
}
arr = append(arr, v)
}
if err := AssignValue(reflect.ValueOf(arr), dst); err != nil {
return err
}
return nil
case reflect.Ptr:
return n.Get(src.Elem(), dst)
case reflect.Interface:
return n.Get(reflect.ValueOf(src.Interface()), dst)
}
return fmt.Errorf("failed to get all value from %s", src.Type())
}
func (n *PathIndexAllNode) String() string {
s := "[*]"
if n.child != nil {
s += n.child.String()
}
return s
}
type PathRecursiveNode struct {
*BasePathNode
selector string
}
func newPathRecursiveNode(selector string) *PathRecursiveNode {
node := newPathSelectorNode(selector)
return &PathRecursiveNode{
BasePathNode: &BasePathNode{
child: node,
},
selector: selector,
}
}
func (n *PathRecursiveNode) Field(fieldName string) (PathNode, bool, error) {
if n.selector == fieldName {
return n.child, true, nil
}
return nil, false, nil
}
func (n *PathRecursiveNode) Index(_ int) (PathNode, bool, error) {
return n, true, nil
}
func valueToSliceValue(v interface{}) []interface{} {
rv := reflect.ValueOf(v)
ret := []interface{}{}
if rv.Type().Kind() == reflect.Slice || rv.Type().Kind() == reflect.Array {
for i := 0; i < rv.Len(); i++ {
ret = append(ret, rv.Index(i).Interface())
}
return ret
}
return []interface{}{v}
}
func (n *PathRecursiveNode) Get(src, dst reflect.Value) error {
if n.child == nil {
return fmt.Errorf("failed to get by recursive path ..%s", n.selector)
}
var arr []interface{}
switch src.Type().Kind() {
case reflect.Map:
iter := src.MapRange()
for iter.Next() {
key, ok := iter.Key().Interface().(string)
if !ok {
return fmt.Errorf("invalid map key type %T", src.Type().Key())
}
child, found, err := n.Field(key)
if err != nil {
return err
}
if found {
var v interface{}
rv := reflect.ValueOf(&v)
_ = child.Get(iter.Value(), rv)
arr = append(arr, valueToSliceValue(v)...)
} else {
var v interface{}
rv := reflect.ValueOf(&v)
_ = n.Get(iter.Value(), rv)
if v != nil {
arr = append(arr, valueToSliceValue(v)...)
}
}
}
_ = AssignValue(reflect.ValueOf(arr), dst)
return nil
case reflect.Struct:
typ := src.Type()
for i := 0; i < typ.Len(); i++ {
tag := runtime.StructTagFromField(typ.Field(i))
child, found, err := n.Field(tag.Key)
if err != nil {
return err
}
if found {
var v interface{}
rv := reflect.ValueOf(&v)
_ = child.Get(src.Field(i), rv)
arr = append(arr, valueToSliceValue(v)...)
} else {
var v interface{}
rv := reflect.ValueOf(&v)
_ = n.Get(src.Field(i), rv)
if v != nil {
arr = append(arr, valueToSliceValue(v)...)
}
}
}
_ = AssignValue(reflect.ValueOf(arr), dst)
return nil
case reflect.Array, reflect.Slice:
for i := 0; i < src.Len(); i++ {
var v interface{}
rv := reflect.ValueOf(&v)
_ = n.Get(src.Index(i), rv)
if v != nil {
arr = append(arr, valueToSliceValue(v)...)
}
}
_ = AssignValue(reflect.ValueOf(arr), dst)
return nil
case reflect.Ptr:
return n.Get(src.Elem(), dst)
case reflect.Interface:
return n.Get(reflect.ValueOf(src.Interface()), dst)
}
return fmt.Errorf("failed to get %s value from %s", n.selector, src.Type())
}
func (n *PathRecursiveNode) String() string {
s := fmt.Sprintf("..%s", n.selector)
if n.child != nil {
s += n.child.String()
}
return s
}

View File

@ -1,7 +1,6 @@
package decoder package decoder
import ( import (
"fmt"
"unsafe" "unsafe"
"github.com/goccy/go-json/internal/runtime" "github.com/goccy/go-json/internal/runtime"
@ -35,10 +34,6 @@ func (d *ptrDecoder) contentDecoder() Decoder {
//go:linkname unsafe_New reflect.unsafe_New //go:linkname unsafe_New reflect.unsafe_New
func unsafe_New(*runtime.Type) unsafe.Pointer func unsafe_New(*runtime.Type) unsafe.Pointer
func UnsafeNew(t *runtime.Type) unsafe.Pointer {
return unsafe_New(t)
}
func (d *ptrDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) error { func (d *ptrDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) error {
if s.skipWhiteSpace() == nul { if s.skipWhiteSpace() == nul {
s.read() s.read()
@ -85,13 +80,8 @@ func (d *ptrDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.P
} }
c, err := d.dec.Decode(ctx, cursor, depth, newptr) c, err := d.dec.Decode(ctx, cursor, depth, newptr)
if err != nil { if err != nil {
*(*unsafe.Pointer)(p) = nil
return 0, err return 0, err
} }
cursor = c cursor = c
return cursor, nil return cursor, nil
} }
func (d *ptrDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: ptr decoder does not support decode path")
}

View File

@ -299,82 +299,3 @@ func (d *sliceDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe
} }
} }
} }
func (d *sliceDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
buf := ctx.Buf
depth++
if depth > maxDecodeNestingDepth {
return nil, 0, errors.ErrExceededMaxDepth(buf[cursor], cursor)
}
ret := [][]byte{}
for {
switch buf[cursor] {
case ' ', '\n', '\t', '\r':
cursor++
continue
case 'n':
if err := validateNull(buf, cursor); err != nil {
return nil, 0, err
}
cursor += 4
return [][]byte{nullbytes}, cursor, nil
case '[':
cursor++
cursor = skipWhiteSpace(buf, cursor)
if buf[cursor] == ']' {
cursor++
return ret, cursor, nil
}
idx := 0
for {
child, found, err := ctx.Option.Path.node.Index(idx)
if err != nil {
return nil, 0, err
}
if found {
if child != nil {
oldPath := ctx.Option.Path.node
ctx.Option.Path.node = child
paths, c, err := d.valueDecoder.DecodePath(ctx, cursor, depth)
if err != nil {
return nil, 0, err
}
ctx.Option.Path.node = oldPath
ret = append(ret, paths...)
cursor = c
} else {
start := cursor
end, err := skipValue(buf, cursor, depth)
if err != nil {
return nil, 0, err
}
ret = append(ret, buf[start:end])
cursor = end
}
} else {
c, err := skipValue(buf, cursor, depth)
if err != nil {
return nil, 0, err
}
cursor = c
}
cursor = skipWhiteSpace(buf, cursor)
switch buf[cursor] {
case ']':
cursor++
return ret, cursor, nil
case ',':
idx++
default:
return nil, 0, errors.ErrInvalidCharacter(buf[cursor], "slice", cursor)
}
cursor++
}
case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
return nil, 0, d.errNumber(cursor)
default:
return nil, 0, errors.ErrUnexpectedEndOfJSON("slice", cursor)
}
}
}

View File

@ -103,7 +103,7 @@ func (s *Stream) statForRetry() ([]byte, int64, unsafe.Pointer) {
func (s *Stream) Reset() { func (s *Stream) Reset() {
s.reset() s.reset()
s.bufSize = int64(len(s.buf)) s.bufSize = initBufSize
} }
func (s *Stream) More() bool { func (s *Stream) More() bool {
@ -138,11 +138,8 @@ func (s *Stream) Token() (interface{}, error) {
s.cursor++ s.cursor++
case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
bytes := floatBytes(s) bytes := floatBytes(s)
str := *(*string)(unsafe.Pointer(&bytes)) s := *(*string)(unsafe.Pointer(&bytes))
if s.UseNumber { f64, err := strconv.ParseFloat(s, 64)
return json.Number(str), nil
}
f64, err := strconv.ParseFloat(str, 64)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -280,7 +277,7 @@ func (s *Stream) skipObject(depth int64) error {
if char(p, cursor) == nul { if char(p, cursor) == nul {
s.cursor = cursor s.cursor = cursor
if s.read() { if s.read() {
_, cursor, p = s.stat() _, cursor, p = s.statForRetry()
continue continue
} }
return errors.ErrUnexpectedEndOfJSON("string of object", cursor) return errors.ErrUnexpectedEndOfJSON("string of object", cursor)
@ -343,7 +340,7 @@ func (s *Stream) skipArray(depth int64) error {
if char(p, cursor) == nul { if char(p, cursor) == nul {
s.cursor = cursor s.cursor = cursor
if s.read() { if s.read() {
_, cursor, p = s.stat() _, cursor, p = s.statForRetry()
continue continue
} }
return errors.ErrUnexpectedEndOfJSON("string of object", cursor) return errors.ErrUnexpectedEndOfJSON("string of object", cursor)
@ -401,7 +398,7 @@ func (s *Stream) skipValue(depth int64) error {
if char(p, cursor) == nul { if char(p, cursor) == nul {
s.cursor = cursor s.cursor = cursor
if s.read() { if s.read() {
_, cursor, p = s.stat() _, cursor, p = s.statForRetry()
continue continue
} }
return errors.ErrUnexpectedEndOfJSON("value of string", s.totalOffset()) return errors.ErrUnexpectedEndOfJSON("value of string", s.totalOffset())
@ -426,6 +423,7 @@ func (s *Stream) skipValue(depth int64) error {
continue continue
} else if c == nul { } else if c == nul {
if s.read() { if s.read() {
s.cursor-- // for retry current character
_, cursor, p = s.stat() _, cursor, p = s.stat()
continue continue
} }

View File

@ -1,8 +1,6 @@
package decoder package decoder
import ( import (
"bytes"
"fmt"
"reflect" "reflect"
"unicode" "unicode"
"unicode/utf16" "unicode/utf16"
@ -60,17 +58,6 @@ func (d *stringDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsaf
return cursor, nil return cursor, nil
} }
func (d *stringDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
bytes, c, err := d.decodeByte(ctx.Buf, cursor)
if err != nil {
return nil, 0, err
}
if bytes == nil {
return [][]byte{nullbytes}, c, nil
}
return [][]byte{bytes}, c, nil
}
var ( var (
hexToInt = [256]int{ hexToInt = [256]int{
'0': 0, '0': 0,
@ -106,30 +93,24 @@ func unicodeToRune(code []byte) rune {
return r return r
} }
func readAtLeast(s *Stream, n int64, p *unsafe.Pointer) bool {
for s.cursor+n >= s.length {
if !s.read() {
return false
}
*p = s.bufptr()
}
return true
}
func decodeUnicodeRune(s *Stream, p unsafe.Pointer) (rune, int64, unsafe.Pointer, error) { func decodeUnicodeRune(s *Stream, p unsafe.Pointer) (rune, int64, unsafe.Pointer, error) {
const defaultOffset = 5 const defaultOffset = 5
const surrogateOffset = 11 const surrogateOffset = 11
if !readAtLeast(s, defaultOffset, &p) { if s.cursor+defaultOffset >= s.length {
return rune(0), 0, nil, errors.ErrInvalidCharacter(s.char(), "escaped string", s.totalOffset()) if !s.read() {
return rune(0), 0, nil, errors.ErrInvalidCharacter(s.char(), "escaped string", s.totalOffset())
}
p = s.bufptr()
} }
r := unicodeToRune(s.buf[s.cursor+1 : s.cursor+defaultOffset]) r := unicodeToRune(s.buf[s.cursor+1 : s.cursor+defaultOffset])
if utf16.IsSurrogate(r) { if utf16.IsSurrogate(r) {
if !readAtLeast(s, surrogateOffset, &p) { if s.cursor+surrogateOffset >= s.length {
return unicode.ReplacementChar, defaultOffset, p, nil s.read()
p = s.bufptr()
} }
if s.buf[s.cursor+defaultOffset] != '\\' || s.buf[s.cursor+defaultOffset+1] != 'u' { if s.cursor+surrogateOffset >= s.length || s.buf[s.cursor+defaultOffset] != '\\' || s.buf[s.cursor+defaultOffset+1] != 'u' {
return unicode.ReplacementChar, defaultOffset, p, nil return unicode.ReplacementChar, defaultOffset, p, nil
} }
r2 := unicodeToRune(s.buf[s.cursor+defaultOffset+2 : s.cursor+surrogateOffset]) r2 := unicodeToRune(s.buf[s.cursor+defaultOffset+2 : s.cursor+surrogateOffset])
@ -182,7 +163,6 @@ RETRY:
if !s.read() { if !s.read() {
return nil, errors.ErrInvalidCharacter(s.char(), "escaped string", s.totalOffset()) return nil, errors.ErrInvalidCharacter(s.char(), "escaped string", s.totalOffset())
} }
p = s.bufptr()
goto RETRY goto RETRY
default: default:
return nil, errors.ErrUnexpectedEndOfJSON("string", s.totalOffset()) return nil, errors.ErrUnexpectedEndOfJSON("string", s.totalOffset())
@ -328,36 +308,49 @@ func (d *stringDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, err
cursor++ cursor++
start := cursor start := cursor
b := (*sliceHeader)(unsafe.Pointer(&buf)).data b := (*sliceHeader)(unsafe.Pointer(&buf)).data
escaped := 0
for { for {
switch char(b, cursor) { switch char(b, cursor) {
case '\\': case '\\':
escaped++
cursor++ cursor++
switch char(b, cursor) { switch char(b, cursor) {
case '"', '\\', '/', 'b', 'f', 'n', 'r', 't': case '"':
cursor++ buf[cursor] = '"'
buf = append(buf[:cursor-1], buf[cursor:]...)
case '\\':
buf[cursor] = '\\'
buf = append(buf[:cursor-1], buf[cursor:]...)
case '/':
buf[cursor] = '/'
buf = append(buf[:cursor-1], buf[cursor:]...)
case 'b':
buf[cursor] = '\b'
buf = append(buf[:cursor-1], buf[cursor:]...)
case 'f':
buf[cursor] = '\f'
buf = append(buf[:cursor-1], buf[cursor:]...)
case 'n':
buf[cursor] = '\n'
buf = append(buf[:cursor-1], buf[cursor:]...)
case 'r':
buf[cursor] = '\r'
buf = append(buf[:cursor-1], buf[cursor:]...)
case 't':
buf[cursor] = '\t'
buf = append(buf[:cursor-1], buf[cursor:]...)
case 'u': case 'u':
buflen := int64(len(buf)) buflen := int64(len(buf))
if cursor+5 >= buflen { if cursor+5 >= buflen {
return nil, 0, errors.ErrUnexpectedEndOfJSON("escaped string", cursor) return nil, 0, errors.ErrUnexpectedEndOfJSON("escaped string", cursor)
} }
for i := int64(1); i <= 4; i++ { code := unicodeToRune(buf[cursor+1 : cursor+5])
c := char(b, cursor+i) unicode := []byte(string(code))
if !(('0' <= c && c <= '9') || ('a' <= c && c <= 'f') || ('A' <= c && c <= 'F')) { buf = append(append(buf[:cursor-1], unicode...), buf[cursor+5:]...)
return nil, 0, errors.ErrSyntax(fmt.Sprintf("json: invalid character %c in \\u hexadecimal character escape", c), cursor+i)
}
}
cursor += 5
default: default:
return nil, 0, errors.ErrUnexpectedEndOfJSON("escaped string", cursor) return nil, 0, errors.ErrUnexpectedEndOfJSON("escaped string", cursor)
} }
continue continue
case '"': case '"':
literal := buf[start:cursor] literal := buf[start:cursor]
if escaped > 0 {
literal = literal[:unescapeString(literal)]
}
cursor++ cursor++
return literal, cursor, nil return literal, cursor, nil
case nul: case nul:
@ -376,77 +369,3 @@ func (d *stringDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, err
} }
} }
} }
var unescapeMap = [256]byte{
'"': '"',
'\\': '\\',
'/': '/',
'b': '\b',
'f': '\f',
'n': '\n',
'r': '\r',
't': '\t',
}
func unsafeAdd(ptr unsafe.Pointer, offset int) unsafe.Pointer {
return unsafe.Pointer(uintptr(ptr) + uintptr(offset))
}
func unescapeString(buf []byte) int {
p := (*sliceHeader)(unsafe.Pointer(&buf)).data
end := unsafeAdd(p, len(buf))
src := unsafeAdd(p, bytes.IndexByte(buf, '\\'))
dst := src
for src != end {
c := char(src, 0)
if c == '\\' {
escapeChar := char(src, 1)
if escapeChar != 'u' {
*(*byte)(dst) = unescapeMap[escapeChar]
src = unsafeAdd(src, 2)
dst = unsafeAdd(dst, 1)
} else {
v1 := hexToInt[char(src, 2)]
v2 := hexToInt[char(src, 3)]
v3 := hexToInt[char(src, 4)]
v4 := hexToInt[char(src, 5)]
code := rune((v1 << 12) | (v2 << 8) | (v3 << 4) | v4)
if code >= 0xd800 && code < 0xdc00 && uintptr(unsafeAdd(src, 11)) < uintptr(end) {
if char(src, 6) == '\\' && char(src, 7) == 'u' {
v1 := hexToInt[char(src, 8)]
v2 := hexToInt[char(src, 9)]
v3 := hexToInt[char(src, 10)]
v4 := hexToInt[char(src, 11)]
lo := rune((v1 << 12) | (v2 << 8) | (v3 << 4) | v4)
if lo >= 0xdc00 && lo < 0xe000 {
code = (code-0xd800)<<10 | (lo - 0xdc00) + 0x10000
src = unsafeAdd(src, 6)
}
}
}
var b [utf8.UTFMax]byte
n := utf8.EncodeRune(b[:], code)
switch n {
case 4:
*(*byte)(unsafeAdd(dst, 3)) = b[3]
fallthrough
case 3:
*(*byte)(unsafeAdd(dst, 2)) = b[2]
fallthrough
case 2:
*(*byte)(unsafeAdd(dst, 1)) = b[1]
fallthrough
case 1:
*(*byte)(unsafeAdd(dst, 0)) = b[0]
}
src = unsafeAdd(src, 6)
dst = unsafeAdd(dst, n)
}
} else {
*(*byte)(dst) = c
src = unsafeAdd(src, 1)
dst = unsafeAdd(dst, 1)
}
}
return int(uintptr(dst) - uintptr(p))
}

View File

@ -51,14 +51,6 @@ func init() {
} }
} }
func toASCIILower(s string) string {
b := []byte(s)
for i := range b {
b[i] = largeToSmallTable[b[i]]
}
return string(b)
}
func newStructDecoder(structName, fieldName string, fieldMap map[string]*structFieldSet) *structDecoder { func newStructDecoder(structName, fieldName string, fieldMap map[string]*structFieldSet) *structDecoder {
return &structDecoder{ return &structDecoder{
fieldMap: fieldMap, fieldMap: fieldMap,
@ -99,10 +91,6 @@ func (d *structDecoder) tryOptimize() {
for k, v := range d.fieldMap { for k, v := range d.fieldMap {
key := strings.ToLower(k) key := strings.ToLower(k)
if key != k { if key != k {
if key != toASCIILower(k) {
d.isTriedOptimize = true
return
}
// already exists same key (e.g. Hello and HELLO has same lower case key // already exists same key (e.g. Hello and HELLO has same lower case key
if _, exists := conflicted[key]; exists { if _, exists := conflicted[key]; exists {
d.isTriedOptimize = true d.isTriedOptimize = true
@ -170,53 +158,49 @@ func (d *structDecoder) tryOptimize() {
} }
// decode from '\uXXXX' // decode from '\uXXXX'
func decodeKeyCharByUnicodeRune(buf []byte, cursor int64) ([]byte, int64, error) { func decodeKeyCharByUnicodeRune(buf []byte, cursor int64) ([]byte, int64) {
const defaultOffset = 4 const defaultOffset = 4
const surrogateOffset = 6 const surrogateOffset = 6
if cursor+defaultOffset >= int64(len(buf)) {
return nil, 0, errors.ErrUnexpectedEndOfJSON("escaped string", cursor)
}
r := unicodeToRune(buf[cursor : cursor+defaultOffset]) r := unicodeToRune(buf[cursor : cursor+defaultOffset])
if utf16.IsSurrogate(r) { if utf16.IsSurrogate(r) {
cursor += defaultOffset cursor += defaultOffset
if cursor+surrogateOffset >= int64(len(buf)) || buf[cursor] != '\\' || buf[cursor+1] != 'u' { if cursor+surrogateOffset >= int64(len(buf)) || buf[cursor] != '\\' || buf[cursor+1] != 'u' {
return []byte(string(unicode.ReplacementChar)), cursor + defaultOffset - 1, nil return []byte(string(unicode.ReplacementChar)), cursor + defaultOffset - 1
} }
cursor += 2 cursor += 2
r2 := unicodeToRune(buf[cursor : cursor+defaultOffset]) r2 := unicodeToRune(buf[cursor : cursor+defaultOffset])
if r := utf16.DecodeRune(r, r2); r != unicode.ReplacementChar { if r := utf16.DecodeRune(r, r2); r != unicode.ReplacementChar {
return []byte(string(r)), cursor + defaultOffset - 1, nil return []byte(string(r)), cursor + defaultOffset - 1
} }
} }
return []byte(string(r)), cursor + defaultOffset - 1, nil return []byte(string(r)), cursor + defaultOffset - 1
} }
func decodeKeyCharByEscapedChar(buf []byte, cursor int64) ([]byte, int64, error) { func decodeKeyCharByEscapedChar(buf []byte, cursor int64) ([]byte, int64) {
c := buf[cursor] c := buf[cursor]
cursor++ cursor++
switch c { switch c {
case '"': case '"':
return []byte{'"'}, cursor, nil return []byte{'"'}, cursor
case '\\': case '\\':
return []byte{'\\'}, cursor, nil return []byte{'\\'}, cursor
case '/': case '/':
return []byte{'/'}, cursor, nil return []byte{'/'}, cursor
case 'b': case 'b':
return []byte{'\b'}, cursor, nil return []byte{'\b'}, cursor
case 'f': case 'f':
return []byte{'\f'}, cursor, nil return []byte{'\f'}, cursor
case 'n': case 'n':
return []byte{'\n'}, cursor, nil return []byte{'\n'}, cursor
case 'r': case 'r':
return []byte{'\r'}, cursor, nil return []byte{'\r'}, cursor
case 't': case 't':
return []byte{'\t'}, cursor, nil return []byte{'\t'}, cursor
case 'u': case 'u':
return decodeKeyCharByUnicodeRune(buf, cursor) return decodeKeyCharByUnicodeRune(buf, cursor)
} }
return nil, cursor, nil return nil, cursor
} }
func decodeKeyByBitmapUint8(d *structDecoder, buf []byte, cursor int64) (int64, *structFieldSet, error) { func decodeKeyByBitmapUint8(d *structDecoder, buf []byte, cursor int64) (int64, *structFieldSet, error) {
@ -258,10 +242,7 @@ func decodeKeyByBitmapUint8(d *structDecoder, buf []byte, cursor int64) (int64,
return 0, nil, errors.ErrUnexpectedEndOfJSON("string", cursor) return 0, nil, errors.ErrUnexpectedEndOfJSON("string", cursor)
case '\\': case '\\':
cursor++ cursor++
chars, nextCursor, err := decodeKeyCharByEscapedChar(buf, cursor) chars, nextCursor := decodeKeyCharByEscapedChar(buf, cursor)
if err != nil {
return 0, nil, err
}
for _, c := range chars { for _, c := range chars {
curBit &= bitmap[keyIdx][largeToSmallTable[c]] curBit &= bitmap[keyIdx][largeToSmallTable[c]]
if curBit == 0 { if curBit == 0 {
@ -324,10 +305,7 @@ func decodeKeyByBitmapUint16(d *structDecoder, buf []byte, cursor int64) (int64,
return 0, nil, errors.ErrUnexpectedEndOfJSON("string", cursor) return 0, nil, errors.ErrUnexpectedEndOfJSON("string", cursor)
case '\\': case '\\':
cursor++ cursor++
chars, nextCursor, err := decodeKeyCharByEscapedChar(buf, cursor) chars, nextCursor := decodeKeyCharByEscapedChar(buf, cursor)
if err != nil {
return 0, nil, err
}
for _, c := range chars { for _, c := range chars {
curBit &= bitmap[keyIdx][largeToSmallTable[c]] curBit &= bitmap[keyIdx][largeToSmallTable[c]]
if curBit == 0 { if curBit == 0 {
@ -839,7 +817,3 @@ func (d *structDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsaf
cursor++ cursor++
} }
} }
func (d *structDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: struct decoder does not support decode path")
}

View File

@ -10,7 +10,6 @@ import (
type Decoder interface { type Decoder interface {
Decode(*RuntimeContext, int64, int64, unsafe.Pointer) (int64, error) Decode(*RuntimeContext, int64, int64, unsafe.Pointer) (int64, error)
DecodePath(*RuntimeContext, int64, int64) ([][]byte, int64, error)
DecodeStream(*Stream, int64, unsafe.Pointer) error DecodeStream(*Stream, int64, unsafe.Pointer) error
} }

View File

@ -188,7 +188,3 @@ func (d *uintDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.
d.op(p, u64) d.op(p, u64)
return cursor, nil return cursor, nil
} }
func (d *uintDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: uint decoder does not support decode path")
}

View File

@ -1,9 +1,7 @@
package decoder package decoder
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt"
"unsafe" "unsafe"
"github.com/goccy/go-json/internal/errors" "github.com/goccy/go-json/internal/errors"
@ -48,20 +46,13 @@ func (d *unmarshalJSONDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Poi
typ: d.typ, typ: d.typ,
ptr: p, ptr: p,
})) }))
switch v := v.(type) { if (s.Option.Flags & ContextOption) != 0 {
case unmarshalerContext: if err := v.(unmarshalerContext).UnmarshalJSON(s.Option.Context, dst); err != nil {
var ctx context.Context
if (s.Option.Flags & ContextOption) != 0 {
ctx = s.Option.Context
} else {
ctx = context.Background()
}
if err := v.UnmarshalJSON(ctx, dst); err != nil {
d.annotateError(s.cursor, err) d.annotateError(s.cursor, err)
return err return err
} }
case json.Unmarshaler: } else {
if err := v.UnmarshalJSON(dst); err != nil { if err := v.(json.Unmarshaler).UnmarshalJSON(dst); err != nil {
d.annotateError(s.cursor, err) d.annotateError(s.cursor, err)
return err return err
} }
@ -98,7 +89,3 @@ func (d *unmarshalJSONDecoder) Decode(ctx *RuntimeContext, cursor, depth int64,
} }
return end, nil return end, nil
} }
func (d *unmarshalJSONDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: unmarshal json decoder does not support decode path")
}

View File

@ -3,7 +3,6 @@ package decoder
import ( import (
"bytes" "bytes"
"encoding" "encoding"
"fmt"
"unicode" "unicode"
"unicode/utf16" "unicode/utf16"
"unicode/utf8" "unicode/utf8"
@ -143,11 +142,7 @@ func (d *unmarshalTextDecoder) Decode(ctx *RuntimeContext, cursor, depth int64,
return end, nil return end, nil
} }
func (d *unmarshalTextDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { func unquoteBytes(s []byte) (t []byte, ok bool) {
return nil, 0, fmt.Errorf("json: unmarshal text decoder does not support decode path")
}
func unquoteBytes(s []byte) (t []byte, ok bool) { //nolint: nonamedreturns
length := len(s) length := len(s)
if length < 2 || s[0] != '"' || s[length-1] != '"' { if length < 2 || s[0] != '"' || s[length-1] != '"' {
return return

View File

@ -1,7 +1,6 @@
package decoder package decoder
import ( import (
"fmt"
"reflect" "reflect"
"unsafe" "unsafe"
@ -67,7 +66,3 @@ func (d *wrappedStringDecoder) Decode(ctx *RuntimeContext, cursor, depth int64,
ctx.Buf = oldBuf ctx.Buf = oldBuf
return c, nil return c, nil
} }
func (d *wrappedStringDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) {
return nil, 0, fmt.Errorf("json: wrapped string decoder does not support decode path")
}

Some files were not shown because too many files have changed in this diff Show More