added functions to handle multiple auth servers

This commit is contained in:
Paul 2020-03-04 12:00:45 +01:00
parent b995a26e42
commit 5d54b0289b
7 changed files with 105 additions and 42 deletions

3
.gitignore vendored
View File

@ -1,4 +1,5 @@
*.ini *.ini
/adradius /adradius
*.pid *.pid
*.log *.log
*.pdf

View File

@ -21,12 +21,12 @@ make
```ini ```ini
[adradius] [adradius]
server=localhost servers=ad1,ad2 # comma separated list of authentication servers hostnames as a slice
port=389 port=389 # ldap port as an integer
basedn=dc=example,dc=com basedn="dc=example,dc=com" # ldap basedn as string
secret=secret tls=false # use tls for ldap as a bool
tls=true secret=secret # radius secret
listen=localhost:1812 listen=localhost:1812 # listen params for radius server as a string
``` ```
### Run ### Run

View File

@ -3,6 +3,9 @@ package main
import ( import (
"log" "log"
"net/http"
_ "net/http/pprof"
"git.paulbsd.com/paulbsd/adradius/src/adradius" "git.paulbsd.com/paulbsd/adradius/src/adradius"
"git.paulbsd.com/paulbsd/adradius/src/config" "git.paulbsd.com/paulbsd/adradius/src/config"
"github.com/sevlyar/go-daemon" "github.com/sevlyar/go-daemon"
@ -42,8 +45,13 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()
adradius.RunServer(&cfg, ldapcfg) adradius.RunServer(&cfg, ldapcfg)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
} }

View File

@ -20,15 +20,13 @@ func SetADRadiusConfig(c *config.Config) (ldapconfig *auth.Config, err error) {
} }
ldapconfig = &auth.Config{ ldapconfig = &auth.Config{
Server: c.Server, Servers: c.Servers,
Port: c.Port, Port: c.Port,
BaseDN: c.BaseDN, BaseDN: c.BaseDN,
Security: security, Security: security,
SkipVerify: c.SkipVerify, SkipVerify: c.SkipVerify,
} }
ldapconfig.Connect()
return return
} }
@ -56,8 +54,10 @@ func RunServer(config *config.Config, ldapconfig *auth.Config) {
if status { if status {
code = radius.CodeAccessAccept code = radius.CodeAccessAccept
log.Printf("Successful login for user %s", username)
} else { } else {
code = radius.CodeAccessReject code = radius.CodeAccessReject
log.Printf("Failed login for user %s", username)
} }
log.Printf("Writing %v to %v", code, r.RemoteAddr) log.Printf("Writing %v to %v", code, r.RemoteAddr)
w.Write(r.Response(code)) w.Write(r.Response(code))

View File

@ -2,6 +2,7 @@ package config
import ( import (
"flag" "flag"
"fmt"
auth "git.paulbsd.com/paulbsd/adradius/src/go-ad-auth" auth "git.paulbsd.com/paulbsd/adradius/src/go-ad-auth"
"git.paulbsd.com/paulbsd/adradius/utils" "git.paulbsd.com/paulbsd/adradius/utils"
@ -21,7 +22,10 @@ func (c *Config) GetConfig() (err error) {
} }
adradiusSection := config.Section("adradius") adradiusSection := config.Section("adradius")
c.Server = adradiusSection.Key("server").MustString("localhost") c.Servers = adradiusSection.Key("servers").Strings(",")
if len(c.Servers) < 1 {
return fmt.Errorf("No servers provided in config")
}
c.Port = adradiusSection.Key("port").MustInt(389) c.Port = adradiusSection.Key("port").MustInt(389)
c.BaseDN = adradiusSection.Key("basedn").MustString("dc=example,dc=com") c.BaseDN = adradiusSection.Key("basedn").MustString("dc=example,dc=com")
c.TLS = adradiusSection.Key("tls").MustBool() c.TLS = adradiusSection.Key("tls").MustBool()
@ -29,17 +33,13 @@ func (c *Config) GetConfig() (err error) {
c.Secret = adradiusSection.Key("secret").MustString("secret") c.Secret = adradiusSection.Key("secret").MustString("secret")
c.SkipVerify = adradiusSection.Key("skipverify").MustBool() c.SkipVerify = adradiusSection.Key("skipverify").MustBool()
if err != nil {
return
}
return nil return nil
} }
// Config is the main configuration // Config is the main configuration
type Config struct { type Config struct {
ConfigPath string ConfigPath string
Server string Servers []string
Port int Port int
BaseDN string BaseDN string
TLS bool TLS bool

View File

@ -19,7 +19,7 @@ const (
//Config contains settings for connecting to an Active Directory server. //Config contains settings for connecting to an Active Directory server.
type Config struct { type Config struct {
Server string Servers []string
Port int Port int
BaseDN string BaseDN string
Security SecurityType Security SecurityType

View File

@ -4,6 +4,7 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"log"
ldap "gopkg.in/ldap.v3" ldap "gopkg.in/ldap.v3"
) )
@ -16,40 +17,93 @@ type Conn struct {
//Connect returns an open connection to an Active Directory server or an error if one occurred. //Connect returns an open connection to an Active Directory server or an error if one occurred.
func (c *Config) Connect() (*Conn, error) { func (c *Config) Connect() (*Conn, error) {
tlscfg := &tls.Config{ServerName: c.Server}
if c.SkipVerify {
tlscfg.InsecureSkipVerify = true
}
switch c.Security { switch c.Security {
case SecurityNone: case SecurityNone:
conn, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", c.Server, c.Port)) conn, err := c.ConnectClear()
if err != nil { return conn, err
return nil, fmt.Errorf("Connection error: %v", err)
}
return &Conn{Conn: conn, Config: c}, nil
case SecurityTLS: case SecurityTLS:
conn, err := ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", c.Server, c.Port), tlscfg) conn, err := c.ConnectTLS()
if err != nil { return conn, err
return nil, fmt.Errorf("Connection error: %v", err)
}
return &Conn{Conn: conn, Config: c}, nil
case SecurityStartTLS: case SecurityStartTLS:
conn, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", c.Server, c.Port)) conn, err := c.ConnectStartTLS()
if err != nil { return conn, err
return nil, fmt.Errorf("Connection error: %v", err)
}
err = conn.StartTLS(&tls.Config{ServerName: c.Server})
if err != nil {
return nil, fmt.Errorf("Connection error: %v", err)
}
return &Conn{Conn: conn, Config: c}, nil
default: default:
return nil, errors.New("Configuration error: invalid SecurityType") return nil, errors.New("Configuration error: invalid SecurityType")
} }
} }
//ConnectClear is a function used to connect to ldap server using clear auth
func (c *Config) ConnectClear() (*Conn, error) {
var conn *ldap.Conn
var err error
for index, server := range c.Servers {
conn, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", server, c.Port))
if err != nil && index < len(c.Servers)-1 {
log.Println("Failed dial, trying next server")
} else if err != nil {
log.Println("Servers are all unavailable")
return nil, fmt.Errorf("Connection error: %v", err)
} else {
return &Conn{Conn: conn, Config: c}, nil
}
}
return &Conn{Conn: conn, Config: c}, nil
}
//ConnectTLS is a function used to connect to ldap server using TLS
func (c *Config) ConnectTLS() (*Conn, error) {
var conn *ldap.Conn
var err error
for index, server := range c.Servers {
tlscfg := &tls.Config{ServerName: server}
if c.SkipVerify {
tlscfg.InsecureSkipVerify = true
}
conn, err = ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", server, c.Port), tlscfg)
if err != nil && index < len(c.Servers)-1 {
log.Println("Failed dial, trying next server")
} else if err != nil {
log.Println("Servers are all unavailable")
return nil, fmt.Errorf("Connection error: %v", err)
} else {
return &Conn{Conn: conn, Config: c}, nil
}
}
return &Conn{Conn: conn, Config: c}, nil
}
//ConnectStartTLS is a function used to connect to ldap server using TLS
func (c *Config) ConnectStartTLS() (*Conn, error) {
var conn *ldap.Conn
var err error
for index, server := range c.Servers {
tlscfg := &tls.Config{ServerName: server}
if c.SkipVerify {
tlscfg.InsecureSkipVerify = true
}
conn, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", server, c.Port))
if err != nil {
return nil, fmt.Errorf("Connection error: %v", err)
}
err = conn.StartTLS(tlscfg)
if err != nil && index < len(c.Servers)-1 {
log.Println("Failed dial, trying next server")
} else if err != nil {
log.Println("Servers are all unavailable")
return nil, fmt.Errorf("Connection error: %v", err)
} else {
return &Conn{Conn: conn, Config: c}, nil
}
}
return &Conn{Conn: conn, Config: c}, nil
}
//Bind authenticates the connection with the given userPrincipalName and password //Bind authenticates the connection with the given userPrincipalName and password
//and returns the result or an error if one occurred. //and returns the result or an error if one occurred.
func (c *Conn) Bind(upn, password string) (bool, error) { func (c *Conn) Bind(upn, password string) (bool, error) {