added functions to handle multiple auth servers

This commit is contained in:
Paul 2020-03-04 12:00:45 +01:00
parent b995a26e42
commit 5d54b0289b
7 changed files with 105 additions and 42 deletions

3
.gitignore vendored
View File

@ -1,4 +1,5 @@
*.ini
/adradius
*.pid
*.log
*.log
*.pdf

View File

@ -21,12 +21,12 @@ make
```ini
[adradius]
server=localhost
port=389
basedn=dc=example,dc=com
secret=secret
tls=true
listen=localhost:1812
servers=ad1,ad2 # comma separated list of authentication servers hostnames as a slice
port=389 # ldap port as an integer
basedn="dc=example,dc=com" # ldap basedn as string
tls=false # use tls for ldap as a bool
secret=secret # radius secret
listen=localhost:1812 # listen params for radius server as a string
```
### Run

View File

@ -3,6 +3,9 @@ package main
import (
"log"
"net/http"
_ "net/http/pprof"
"git.paulbsd.com/paulbsd/adradius/src/adradius"
"git.paulbsd.com/paulbsd/adradius/src/config"
"github.com/sevlyar/go-daemon"
@ -42,8 +45,13 @@ func main() {
log.Fatal(err)
}
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()
adradius.RunServer(&cfg, ldapcfg)
if err != nil {
log.Fatal(err)
}
}

View File

@ -20,15 +20,13 @@ func SetADRadiusConfig(c *config.Config) (ldapconfig *auth.Config, err error) {
}
ldapconfig = &auth.Config{
Server: c.Server,
Servers: c.Servers,
Port: c.Port,
BaseDN: c.BaseDN,
Security: security,
SkipVerify: c.SkipVerify,
}
ldapconfig.Connect()
return
}
@ -56,8 +54,10 @@ func RunServer(config *config.Config, ldapconfig *auth.Config) {
if status {
code = radius.CodeAccessAccept
log.Printf("Successful login for user %s", username)
} else {
code = radius.CodeAccessReject
log.Printf("Failed login for user %s", username)
}
log.Printf("Writing %v to %v", code, r.RemoteAddr)
w.Write(r.Response(code))

View File

@ -2,6 +2,7 @@ package config
import (
"flag"
"fmt"
auth "git.paulbsd.com/paulbsd/adradius/src/go-ad-auth"
"git.paulbsd.com/paulbsd/adradius/utils"
@ -21,7 +22,10 @@ func (c *Config) GetConfig() (err error) {
}
adradiusSection := config.Section("adradius")
c.Server = adradiusSection.Key("server").MustString("localhost")
c.Servers = adradiusSection.Key("servers").Strings(",")
if len(c.Servers) < 1 {
return fmt.Errorf("No servers provided in config")
}
c.Port = adradiusSection.Key("port").MustInt(389)
c.BaseDN = adradiusSection.Key("basedn").MustString("dc=example,dc=com")
c.TLS = adradiusSection.Key("tls").MustBool()
@ -29,17 +33,13 @@ func (c *Config) GetConfig() (err error) {
c.Secret = adradiusSection.Key("secret").MustString("secret")
c.SkipVerify = adradiusSection.Key("skipverify").MustBool()
if err != nil {
return
}
return nil
}
// Config is the main configuration
type Config struct {
ConfigPath string
Server string
Servers []string
Port int
BaseDN string
TLS bool

View File

@ -19,7 +19,7 @@ const (
//Config contains settings for connecting to an Active Directory server.
type Config struct {
Server string
Servers []string
Port int
BaseDN string
Security SecurityType

View File

@ -4,6 +4,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"log"
ldap "gopkg.in/ldap.v3"
)
@ -16,40 +17,93 @@ type Conn struct {
//Connect returns an open connection to an Active Directory server or an error if one occurred.
func (c *Config) Connect() (*Conn, error) {
tlscfg := &tls.Config{ServerName: c.Server}
if c.SkipVerify {
tlscfg.InsecureSkipVerify = true
}
switch c.Security {
case SecurityNone:
conn, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", c.Server, c.Port))
if err != nil {
return nil, fmt.Errorf("Connection error: %v", err)
}
return &Conn{Conn: conn, Config: c}, nil
conn, err := c.ConnectClear()
return conn, err
case SecurityTLS:
conn, err := ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", c.Server, c.Port), tlscfg)
if err != nil {
return nil, fmt.Errorf("Connection error: %v", err)
}
return &Conn{Conn: conn, Config: c}, nil
conn, err := c.ConnectTLS()
return conn, err
case SecurityStartTLS:
conn, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", c.Server, c.Port))
if err != nil {
return nil, fmt.Errorf("Connection error: %v", err)
}
err = conn.StartTLS(&tls.Config{ServerName: c.Server})
if err != nil {
return nil, fmt.Errorf("Connection error: %v", err)
}
return &Conn{Conn: conn, Config: c}, nil
conn, err := c.ConnectStartTLS()
return conn, err
default:
return nil, errors.New("Configuration error: invalid SecurityType")
}
}
//ConnectClear is a function used to connect to ldap server using clear auth
func (c *Config) ConnectClear() (*Conn, error) {
var conn *ldap.Conn
var err error
for index, server := range c.Servers {
conn, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", server, c.Port))
if err != nil && index < len(c.Servers)-1 {
log.Println("Failed dial, trying next server")
} else if err != nil {
log.Println("Servers are all unavailable")
return nil, fmt.Errorf("Connection error: %v", err)
} else {
return &Conn{Conn: conn, Config: c}, nil
}
}
return &Conn{Conn: conn, Config: c}, nil
}
//ConnectTLS is a function used to connect to ldap server using TLS
func (c *Config) ConnectTLS() (*Conn, error) {
var conn *ldap.Conn
var err error
for index, server := range c.Servers {
tlscfg := &tls.Config{ServerName: server}
if c.SkipVerify {
tlscfg.InsecureSkipVerify = true
}
conn, err = ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", server, c.Port), tlscfg)
if err != nil && index < len(c.Servers)-1 {
log.Println("Failed dial, trying next server")
} else if err != nil {
log.Println("Servers are all unavailable")
return nil, fmt.Errorf("Connection error: %v", err)
} else {
return &Conn{Conn: conn, Config: c}, nil
}
}
return &Conn{Conn: conn, Config: c}, nil
}
//ConnectStartTLS is a function used to connect to ldap server using TLS
func (c *Config) ConnectStartTLS() (*Conn, error) {
var conn *ldap.Conn
var err error
for index, server := range c.Servers {
tlscfg := &tls.Config{ServerName: server}
if c.SkipVerify {
tlscfg.InsecureSkipVerify = true
}
conn, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", server, c.Port))
if err != nil {
return nil, fmt.Errorf("Connection error: %v", err)
}
err = conn.StartTLS(tlscfg)
if err != nil && index < len(c.Servers)-1 {
log.Println("Failed dial, trying next server")
} else if err != nil {
log.Println("Servers are all unavailable")
return nil, fmt.Errorf("Connection error: %v", err)
} else {
return &Conn{Conn: conn, Config: c}, nil
}
}
return &Conn{Conn: conn, Config: c}, nil
}
//Bind authenticates the connection with the given userPrincipalName and password
//and returns the result or an error if one occurred.
func (c *Conn) Bind(upn, password string) (bool, error) {