227 lines
6.7 KiB
Go
227 lines
6.7 KiB
Go
// Copyright 2015 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
// +build windows
|
|
|
|
package sspi
|
|
|
|
import (
|
|
"fmt"
|
|
"syscall"
|
|
"time"
|
|
"unsafe"
|
|
)
|
|
|
|
// PackageInfo describes a security package, as reported by QueryPackageInfo.
//
// The numeric fields are copied from the corresponding members of the
// Windows SecPkgInfo structure; Name and Comment are converted from
// NUL-terminated UTF-16 into Go strings.
type PackageInfo struct {
	// Capabilities is the package's capability bit mask.
	Capabilities uint32
	// Version is the package protocol version.
	Version uint16
	// RPCID is the package's DCE RPC identifier.
	RPCID uint16
	// MaxToken is the maximum size, in bytes, of a token produced by the package.
	MaxToken uint32
	// Name is the name of the security package.
	Name string
	// Comment is the package's descriptive comment.
	Comment string
}
|
|
|
|
func QueryPackageInfo(pkgname string) (*PackageInfo, error) {
|
|
name, err := syscall.UTF16PtrFromString(pkgname)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var pi *SecPkgInfo
|
|
ret := QuerySecurityPackageInfo(name, &pi)
|
|
if ret != SEC_E_OK {
|
|
return nil, ret
|
|
}
|
|
defer FreeContextBuffer((*byte)(unsafe.Pointer(pi)))
|
|
|
|
return &PackageInfo{
|
|
Capabilities: pi.Capabilities,
|
|
Version: pi.Version,
|
|
RPCID: pi.RPCID,
|
|
MaxToken: pi.MaxToken,
|
|
Name: syscall.UTF16ToString((*[2 << 12]uint16)(unsafe.Pointer(pi.Name))[:]),
|
|
Comment: syscall.UTF16ToString((*[2 << 12]uint16)(unsafe.Pointer(pi.Comment))[:]),
|
|
}, nil
|
|
}
|
|
|
|
type Credentials struct {
|
|
Handle CredHandle
|
|
expiry syscall.Filetime
|
|
}
|
|
|
|
// AcquireCredentials calls the windows AcquireCredentialsHandle function and
|
|
// returns Credentials containing a security handle that can be used for
|
|
// InitializeSecurityContext or AcceptSecurityContext operations.
|
|
// As a special case, passing an empty string as the principal parameter will
|
|
// pass a null string to the underlying function.
|
|
func AcquireCredentials(principal string, pkgname string, creduse uint32, authdata *byte) (*Credentials, error) {
|
|
var principalName *uint16
|
|
if principal != "" {
|
|
var err error
|
|
principalName, err = syscall.UTF16PtrFromString(principal)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
name, err := syscall.UTF16PtrFromString(pkgname)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var c Credentials
|
|
ret := AcquireCredentialsHandle(principalName, name, creduse, nil, authdata, 0, 0, &c.Handle, &c.expiry)
|
|
if ret != SEC_E_OK {
|
|
return nil, ret
|
|
}
|
|
return &c, nil
|
|
}
|
|
|
|
func (c *Credentials) Release() error {
|
|
if c == nil {
|
|
return nil
|
|
}
|
|
ret := FreeCredentialsHandle(&c.Handle)
|
|
if ret != SEC_E_OK {
|
|
return ret
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Credentials) Expiry() time.Time {
|
|
return time.Unix(0, c.expiry.Nanoseconds())
|
|
}
|
|
|
|
// TODO: add functions to display and manage RequestedFlags and EstablishedFlags fields.
|
|
// TODO: maybe get rid of RequestedFlags and EstablishedFlags fields, and replace them with input parameter for New...Context and return value of Update (instead of current bool parameter).
|
|
|
|
// updateFunc performs one step of a context exchange on behalf of Update.
// cur is the existing context handle (nil on the first call) and next is the
// handle that receives the result. initialize (client) and accept (server)
// are the two implementations.
type updateFunc func(c *Context, target *uint16, cur, next *CtxtHandle, out, in *SecBufferDesc) syscall.Errno
|
|
|
|
type Context struct {
|
|
Cred *Credentials
|
|
Handle *CtxtHandle
|
|
handle CtxtHandle
|
|
updFn updateFunc
|
|
expiry syscall.Filetime
|
|
RequestedFlags uint32
|
|
EstablishedFlags uint32
|
|
}
|
|
|
|
func NewClientContext(cred *Credentials, flags uint32) *Context {
|
|
return &Context{
|
|
Cred: cred,
|
|
updFn: initialize,
|
|
RequestedFlags: flags,
|
|
}
|
|
}
|
|
|
|
func NewServerContext(cred *Credentials, flags uint32) *Context {
|
|
return &Context{
|
|
Cred: cred,
|
|
updFn: accept,
|
|
RequestedFlags: flags,
|
|
}
|
|
}
|
|
|
|
func initialize(c *Context, targname *uint16, h, newh *CtxtHandle, out, in *SecBufferDesc) syscall.Errno {
|
|
return InitializeSecurityContext(&c.Cred.Handle, h, targname, c.RequestedFlags,
|
|
0, SECURITY_NATIVE_DREP, in, 0, newh, out, &c.EstablishedFlags, &c.expiry)
|
|
}
|
|
|
|
func accept(c *Context, targname *uint16, h, newh *CtxtHandle, out, in *SecBufferDesc) syscall.Errno {
|
|
return AcceptSecurityContext(&c.Cred.Handle, h, in, c.RequestedFlags,
|
|
SECURITY_NATIVE_DREP, newh, out, &c.EstablishedFlags, &c.expiry)
|
|
}
|
|
|
|
func (c *Context) Update(targname *uint16, out, in *SecBufferDesc) syscall.Errno {
|
|
h := c.Handle
|
|
if c.Handle == nil {
|
|
c.Handle = &c.handle
|
|
}
|
|
return c.updFn(c, targname, h, c.Handle, out, in)
|
|
}
|
|
|
|
func (c *Context) Release() error {
|
|
if c == nil {
|
|
return nil
|
|
}
|
|
ret := DeleteSecurityContext(c.Handle)
|
|
if ret != SEC_E_OK {
|
|
return ret
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Context) Expiry() time.Time {
|
|
return time.Unix(0, c.expiry.Nanoseconds())
|
|
}
|
|
|
|
// TODO: add comment to function doco that this "impersonation" is applied to current OS thread.
|
|
func (c *Context) ImpersonateUser() error {
|
|
ret := ImpersonateSecurityContext(c.Handle)
|
|
if ret != SEC_E_OK {
|
|
return ret
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Context) RevertToSelf() error {
|
|
ret := RevertSecurityContext(c.Handle)
|
|
if ret != SEC_E_OK {
|
|
return ret
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Sizes queries the context for the sizes used in per-message functions.
|
|
// It returns the maximum token size used in authentication exchanges, the
|
|
// maximum signature size, the preferred integral size of messages, the
|
|
// size of any security trailer, and any error.
|
|
func (c *Context) Sizes() (uint32, uint32, uint32, uint32, error) {
|
|
var s _SecPkgContext_Sizes
|
|
ret := QueryContextAttributes(c.Handle, _SECPKG_ATTR_SIZES, (*byte)(unsafe.Pointer(&s)))
|
|
if ret != SEC_E_OK {
|
|
return 0, 0, 0, 0, ret
|
|
}
|
|
return s.MaxToken, s.MaxSignature, s.BlockSize, s.SecurityTrailer, nil
|
|
}
|
|
|
|
// VerifyFlags determines if all flags used to construct the context
|
|
// were honored (see NewClientContext). It should be called after c.Update.
|
|
func (c *Context) VerifyFlags() error {
|
|
return c.VerifySelectiveFlags(c.RequestedFlags)
|
|
}
|
|
|
|
// VerifySelectiveFlags determines if the given flags were honored (see NewClientContext).
|
|
// It should be called after c.Update.
|
|
func (c *Context) VerifySelectiveFlags(flags uint32) error {
|
|
if valid, missing, extra := verifySelectiveFlags(flags, c.RequestedFlags); !valid {
|
|
return fmt.Errorf("sspi: invalid flags check: desired=%b requested=%b missing=%b extra=%b", flags, c.RequestedFlags, missing, extra)
|
|
}
|
|
if valid, missing, extra := verifySelectiveFlags(flags, c.EstablishedFlags); !valid {
|
|
return fmt.Errorf("sspi: invalid flags: desired=%b established=%b missing=%b extra=%b", flags, c.EstablishedFlags, missing, extra)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// verifySelectiveFlags compares the bits of flags against establishedFlags.
// missing holds the bits set in flags but not in establishedFlags; extra
// holds the bits set in establishedFlags but not in flags. valid is true
// exactly when missing is zero, i.e. establishedFlags contains every
// requested bit.
func verifySelectiveFlags(flags, establishedFlags uint32) (valid bool, missing, extra uint32) {
	missing = flags &^ establishedFlags
	extra = establishedFlags &^ flags
	return missing == 0, missing, extra
}
|
|
|
|
// NewSecBufferDesc returns an initialized SecBufferDesc describing the
|
|
// provided SecBuffer.
|
|
func NewSecBufferDesc(b []SecBuffer) *SecBufferDesc {
|
|
return &SecBufferDesc{
|
|
Version: SECBUFFER_VERSION,
|
|
BuffersCount: uint32(len(b)),
|
|
Buffers: &b[0],
|
|
}
|
|
}
|