added functions to handle multiple auth servers
This commit is contained in:
parent
b995a26e42
commit
5d54b0289b
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,4 +1,5 @@
|
|||||||
*.ini
|
*.ini
|
||||||
/adradius
|
/adradius
|
||||||
*.pid
|
*.pid
|
||||||
*.log
|
*.log
|
||||||
|
*.pdf
|
12
README.md
12
README.md
@ -21,12 +21,12 @@ make
|
|||||||
|
|
||||||
```ini
|
```ini
|
||||||
[adradius]
|
[adradius]
|
||||||
server=localhost
|
servers=ad1,ad2 # comma separated list of authentication servers hostnames as a slice
|
||||||
port=389
|
port=389 # ldap port as an integer
|
||||||
basedn=dc=example,dc=com
|
basedn="dc=example,dc=com" # ldap basedn as string
|
||||||
secret=secret
|
tls=false # use tls for ldap as a bool
|
||||||
tls=true
|
secret=secret # radius secret
|
||||||
listen=localhost:1812
|
listen=localhost:1812 # listen params for radius server as a string
|
||||||
```
|
```
|
||||||
|
|
||||||
### Run
|
### Run
|
||||||
|
@ -3,6 +3,9 @@ package main
|
|||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
|
"net/http"
|
||||||
|
_ "net/http/pprof"
|
||||||
|
|
||||||
"git.paulbsd.com/paulbsd/adradius/src/adradius"
|
"git.paulbsd.com/paulbsd/adradius/src/adradius"
|
||||||
"git.paulbsd.com/paulbsd/adradius/src/config"
|
"git.paulbsd.com/paulbsd/adradius/src/config"
|
||||||
"github.com/sevlyar/go-daemon"
|
"github.com/sevlyar/go-daemon"
|
||||||
@ -42,8 +45,13 @@ func main() {
|
|||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
log.Println(http.ListenAndServe("localhost:6060", nil))
|
||||||
|
}()
|
||||||
|
|
||||||
adradius.RunServer(&cfg, ldapcfg)
|
adradius.RunServer(&cfg, ldapcfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -20,15 +20,13 @@ func SetADRadiusConfig(c *config.Config) (ldapconfig *auth.Config, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ldapconfig = &auth.Config{
|
ldapconfig = &auth.Config{
|
||||||
Server: c.Server,
|
Servers: c.Servers,
|
||||||
Port: c.Port,
|
Port: c.Port,
|
||||||
BaseDN: c.BaseDN,
|
BaseDN: c.BaseDN,
|
||||||
Security: security,
|
Security: security,
|
||||||
SkipVerify: c.SkipVerify,
|
SkipVerify: c.SkipVerify,
|
||||||
}
|
}
|
||||||
|
|
||||||
ldapconfig.Connect()
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -56,8 +54,10 @@ func RunServer(config *config.Config, ldapconfig *auth.Config) {
|
|||||||
|
|
||||||
if status {
|
if status {
|
||||||
code = radius.CodeAccessAccept
|
code = radius.CodeAccessAccept
|
||||||
|
log.Printf("Successful login for user %s", username)
|
||||||
} else {
|
} else {
|
||||||
code = radius.CodeAccessReject
|
code = radius.CodeAccessReject
|
||||||
|
log.Printf("Failed login for user %s", username)
|
||||||
}
|
}
|
||||||
log.Printf("Writing %v to %v", code, r.RemoteAddr)
|
log.Printf("Writing %v to %v", code, r.RemoteAddr)
|
||||||
w.Write(r.Response(code))
|
w.Write(r.Response(code))
|
||||||
|
@ -2,6 +2,7 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
auth "git.paulbsd.com/paulbsd/adradius/src/go-ad-auth"
|
auth "git.paulbsd.com/paulbsd/adradius/src/go-ad-auth"
|
||||||
"git.paulbsd.com/paulbsd/adradius/utils"
|
"git.paulbsd.com/paulbsd/adradius/utils"
|
||||||
@ -21,7 +22,10 @@ func (c *Config) GetConfig() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
adradiusSection := config.Section("adradius")
|
adradiusSection := config.Section("adradius")
|
||||||
c.Server = adradiusSection.Key("server").MustString("localhost")
|
c.Servers = adradiusSection.Key("servers").Strings(",")
|
||||||
|
if len(c.Servers) < 1 {
|
||||||
|
return fmt.Errorf("No servers provided in config")
|
||||||
|
}
|
||||||
c.Port = adradiusSection.Key("port").MustInt(389)
|
c.Port = adradiusSection.Key("port").MustInt(389)
|
||||||
c.BaseDN = adradiusSection.Key("basedn").MustString("dc=example,dc=com")
|
c.BaseDN = adradiusSection.Key("basedn").MustString("dc=example,dc=com")
|
||||||
c.TLS = adradiusSection.Key("tls").MustBool()
|
c.TLS = adradiusSection.Key("tls").MustBool()
|
||||||
@ -29,17 +33,13 @@ func (c *Config) GetConfig() (err error) {
|
|||||||
c.Secret = adradiusSection.Key("secret").MustString("secret")
|
c.Secret = adradiusSection.Key("secret").MustString("secret")
|
||||||
c.SkipVerify = adradiusSection.Key("skipverify").MustBool()
|
c.SkipVerify = adradiusSection.Key("skipverify").MustBool()
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config is the main configuration
|
// Config is the main configuration
|
||||||
type Config struct {
|
type Config struct {
|
||||||
ConfigPath string
|
ConfigPath string
|
||||||
Server string
|
Servers []string
|
||||||
Port int
|
Port int
|
||||||
BaseDN string
|
BaseDN string
|
||||||
TLS bool
|
TLS bool
|
||||||
|
@ -19,7 +19,7 @@ const (
|
|||||||
|
|
||||||
//Config contains settings for connecting to an Active Directory server.
|
//Config contains settings for connecting to an Active Directory server.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Server string
|
Servers []string
|
||||||
Port int
|
Port int
|
||||||
BaseDN string
|
BaseDN string
|
||||||
Security SecurityType
|
Security SecurityType
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
ldap "gopkg.in/ldap.v3"
|
ldap "gopkg.in/ldap.v3"
|
||||||
)
|
)
|
||||||
@ -16,40 +17,93 @@ type Conn struct {
|
|||||||
|
|
||||||
//Connect returns an open connection to an Active Directory server or an error if one occurred.
|
//Connect returns an open connection to an Active Directory server or an error if one occurred.
|
||||||
func (c *Config) Connect() (*Conn, error) {
|
func (c *Config) Connect() (*Conn, error) {
|
||||||
|
|
||||||
tlscfg := &tls.Config{ServerName: c.Server}
|
|
||||||
if c.SkipVerify {
|
|
||||||
tlscfg.InsecureSkipVerify = true
|
|
||||||
}
|
|
||||||
|
|
||||||
switch c.Security {
|
switch c.Security {
|
||||||
case SecurityNone:
|
case SecurityNone:
|
||||||
conn, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", c.Server, c.Port))
|
conn, err := c.ConnectClear()
|
||||||
if err != nil {
|
return conn, err
|
||||||
return nil, fmt.Errorf("Connection error: %v", err)
|
|
||||||
}
|
|
||||||
return &Conn{Conn: conn, Config: c}, nil
|
|
||||||
case SecurityTLS:
|
case SecurityTLS:
|
||||||
conn, err := ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", c.Server, c.Port), tlscfg)
|
conn, err := c.ConnectTLS()
|
||||||
if err != nil {
|
return conn, err
|
||||||
return nil, fmt.Errorf("Connection error: %v", err)
|
|
||||||
}
|
|
||||||
return &Conn{Conn: conn, Config: c}, nil
|
|
||||||
case SecurityStartTLS:
|
case SecurityStartTLS:
|
||||||
conn, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", c.Server, c.Port))
|
conn, err := c.ConnectStartTLS()
|
||||||
if err != nil {
|
return conn, err
|
||||||
return nil, fmt.Errorf("Connection error: %v", err)
|
|
||||||
}
|
|
||||||
err = conn.StartTLS(&tls.Config{ServerName: c.Server})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Connection error: %v", err)
|
|
||||||
}
|
|
||||||
return &Conn{Conn: conn, Config: c}, nil
|
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("Configuration error: invalid SecurityType")
|
return nil, errors.New("Configuration error: invalid SecurityType")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//ConnectClear is a function used to connect to ldap server using clear auth
|
||||||
|
func (c *Config) ConnectClear() (*Conn, error) {
|
||||||
|
var conn *ldap.Conn
|
||||||
|
var err error
|
||||||
|
|
||||||
|
for index, server := range c.Servers {
|
||||||
|
conn, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", server, c.Port))
|
||||||
|
if err != nil && index < len(c.Servers)-1 {
|
||||||
|
log.Println("Failed dial, trying next server")
|
||||||
|
} else if err != nil {
|
||||||
|
log.Println("Servers are all unavailable")
|
||||||
|
return nil, fmt.Errorf("Connection error: %v", err)
|
||||||
|
} else {
|
||||||
|
return &Conn{Conn: conn, Config: c}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Conn{Conn: conn, Config: c}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//ConnectTLS is a function used to connect to ldap server using TLS
|
||||||
|
func (c *Config) ConnectTLS() (*Conn, error) {
|
||||||
|
var conn *ldap.Conn
|
||||||
|
var err error
|
||||||
|
|
||||||
|
for index, server := range c.Servers {
|
||||||
|
tlscfg := &tls.Config{ServerName: server}
|
||||||
|
if c.SkipVerify {
|
||||||
|
tlscfg.InsecureSkipVerify = true
|
||||||
|
}
|
||||||
|
conn, err = ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", server, c.Port), tlscfg)
|
||||||
|
if err != nil && index < len(c.Servers)-1 {
|
||||||
|
log.Println("Failed dial, trying next server")
|
||||||
|
} else if err != nil {
|
||||||
|
log.Println("Servers are all unavailable")
|
||||||
|
return nil, fmt.Errorf("Connection error: %v", err)
|
||||||
|
} else {
|
||||||
|
return &Conn{Conn: conn, Config: c}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Conn{Conn: conn, Config: c}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//ConnectStartTLS is a function used to connect to ldap server using TLS
|
||||||
|
func (c *Config) ConnectStartTLS() (*Conn, error) {
|
||||||
|
var conn *ldap.Conn
|
||||||
|
var err error
|
||||||
|
|
||||||
|
for index, server := range c.Servers {
|
||||||
|
tlscfg := &tls.Config{ServerName: server}
|
||||||
|
if c.SkipVerify {
|
||||||
|
tlscfg.InsecureSkipVerify = true
|
||||||
|
}
|
||||||
|
conn, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", server, c.Port))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Connection error: %v", err)
|
||||||
|
}
|
||||||
|
err = conn.StartTLS(tlscfg)
|
||||||
|
if err != nil && index < len(c.Servers)-1 {
|
||||||
|
log.Println("Failed dial, trying next server")
|
||||||
|
} else if err != nil {
|
||||||
|
log.Println("Servers are all unavailable")
|
||||||
|
return nil, fmt.Errorf("Connection error: %v", err)
|
||||||
|
} else {
|
||||||
|
return &Conn{Conn: conn, Config: c}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Conn{Conn: conn, Config: c}, nil
|
||||||
|
}
|
||||||
|
|
||||||
//Bind authenticates the connection with the given userPrincipalName and password
|
//Bind authenticates the connection with the given userPrincipalName and password
|
||||||
//and returns the result or an error if one occurred.
|
//and returns the result or an error if one occurred.
|
||||||
func (c *Conn) Bind(upn, password string) (bool, error) {
|
func (c *Conn) Bind(upn, password string) (bool, error) {
|
||||||
|
Loading…
Reference in New Issue
Block a user