diff --git a/.gitignore b/.gitignore index 9d97eab..8ffc286 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *.ini /adradius *.pid -*.log \ No newline at end of file +*.log +*.pdf \ No newline at end of file diff --git a/README.md b/README.md index 1123ce6..a67582b 100644 --- a/README.md +++ b/README.md @@ -21,12 +21,12 @@ make ```ini [adradius] -server=localhost -port=389 -basedn=dc=example,dc=com -secret=secret -tls=true -listen=localhost:1812 +servers=ad1,ad2 # comma separated list of authentication servers hostnames as a slice +port=389 # ldap port as an integer +basedn="dc=example,dc=com" # ldap basedn as string +tls=false # use tls for ldap as a bool +secret=secret # radius secret +listen=localhost:1812 # listen params for radius server as a string ``` ### Run diff --git a/cmd/adradius/adradius.go b/cmd/adradius/adradius.go index d729120..b07a102 100644 --- a/cmd/adradius/adradius.go +++ b/cmd/adradius/adradius.go @@ -3,6 +3,9 @@ package main import ( "log" + "net/http" + _ "net/http/pprof" + "git.paulbsd.com/paulbsd/adradius/src/adradius" "git.paulbsd.com/paulbsd/adradius/src/config" "github.com/sevlyar/go-daemon" @@ -42,8 +45,13 @@ func main() { log.Fatal(err) } + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() + adradius.RunServer(&cfg, ldapcfg) if err != nil { log.Fatal(err) } + } diff --git a/src/adradius/adradius.go b/src/adradius/adradius.go index 741736b..6207d21 100644 --- a/src/adradius/adradius.go +++ b/src/adradius/adradius.go @@ -20,15 +20,13 @@ func SetADRadiusConfig(c *config.Config) (ldapconfig *auth.Config, err error) { } ldapconfig = &auth.Config{ - Server: c.Server, + Servers: c.Servers, Port: c.Port, BaseDN: c.BaseDN, Security: security, SkipVerify: c.SkipVerify, } - ldapconfig.Connect() - return } @@ -56,8 +54,10 @@ func RunServer(config *config.Config, ldapconfig *auth.Config) { if status { code = radius.CodeAccessAccept + log.Printf("Successful login for user %s", username) } else { code = radius.CodeAccessReject + log.Printf("Failed login for user %s", username) } log.Printf("Writing %v to %v", code, r.RemoteAddr) w.Write(r.Response(code)) diff --git a/src/config/main.go b/src/config/main.go index ac6fe52..996e037 100644 --- a/src/config/main.go +++ b/src/config/main.go @@ -2,6 +2,7 @@ package config import ( "flag" + "fmt" auth "git.paulbsd.com/paulbsd/adradius/src/go-ad-auth" "git.paulbsd.com/paulbsd/adradius/utils" @@ -21,7 +22,10 @@ func (c *Config) GetConfig() (err error) { } adradiusSection := config.Section("adradius") - c.Server = adradiusSection.Key("server").MustString("localhost") + c.Servers = adradiusSection.Key("servers").Strings(",") + if len(c.Servers) < 1 { + return fmt.Errorf("No servers provided in config") + } c.Port = adradiusSection.Key("port").MustInt(389) c.BaseDN = adradiusSection.Key("basedn").MustString("dc=example,dc=com") c.TLS = adradiusSection.Key("tls").MustBool() @@ -29,17 +33,13 @@ func (c *Config) GetConfig() (err error) { c.Secret = adradiusSection.Key("secret").MustString("secret") c.SkipVerify = adradiusSection.Key("skipverify").MustBool() - if err != nil { - return - } - return nil } // Config is the main configuration type Config struct { ConfigPath string - Server string + Servers []string Port int BaseDN string TLS bool diff --git a/src/go-ad-auth/config.go b/src/go-ad-auth/config.go index c5c15f7..3ca4f55 100644 --- a/src/go-ad-auth/config.go +++ b/src/go-ad-auth/config.go @@ -19,7 +19,7 @@ const ( //Config contains settings for connecting to an Active Directory server. type Config struct { - Server string + Servers []string Port int BaseDN string Security SecurityType diff --git a/src/go-ad-auth/conn.go b/src/go-ad-auth/conn.go index 8bb38ef..aff958c 100644 --- a/src/go-ad-auth/conn.go +++ b/src/go-ad-auth/conn.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "errors" "fmt" + "log" ldap "gopkg.in/ldap.v3" ) @@ -16,40 +17,93 @@ type Conn struct { //Connect returns an open connection to an Active Directory server or an error if one occurred. func (c *Config) Connect() (*Conn, error) { - - tlscfg := &tls.Config{ServerName: c.Server} - if c.SkipVerify { - tlscfg.InsecureSkipVerify = true - } - switch c.Security { case SecurityNone: - conn, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", c.Server, c.Port)) - if err != nil { - return nil, fmt.Errorf("Connection error: %v", err) - } - return &Conn{Conn: conn, Config: c}, nil + conn, err := c.ConnectClear() + return conn, err case SecurityTLS: - conn, err := ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", c.Server, c.Port), tlscfg) - if err != nil { - return nil, fmt.Errorf("Connection error: %v", err) - } - return &Conn{Conn: conn, Config: c}, nil + conn, err := c.ConnectTLS() + return conn, err case SecurityStartTLS: - conn, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", c.Server, c.Port)) - if err != nil { - return nil, fmt.Errorf("Connection error: %v", err) - } - err = conn.StartTLS(&tls.Config{ServerName: c.Server}) - if err != nil { - return nil, fmt.Errorf("Connection error: %v", err) - } - return &Conn{Conn: conn, Config: c}, nil + conn, err := c.ConnectStartTLS() + return conn, err default: return nil, errors.New("Configuration error: invalid SecurityType") } } +//ConnectClear is a function used to connect to ldap server using clear auth +func (c *Config) ConnectClear() (*Conn, error) { + var conn *ldap.Conn + var err error + + for index, server := range c.Servers { + conn, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", server, c.Port)) + if err != nil && index < len(c.Servers)-1 { + log.Println("Failed dial, trying next server") + } else if err != nil { + log.Println("Servers are all unavailable") + return nil, fmt.Errorf("Connection error: %v", err) + } else { + return &Conn{Conn: conn, Config: c}, nil + } + } + + return &Conn{Conn: conn, Config: c}, nil +} + +//ConnectTLS is a function used to connect to ldap server using TLS +func (c *Config) ConnectTLS() (*Conn, error) { + var conn *ldap.Conn + var err error + + for index, server := range c.Servers { + tlscfg := &tls.Config{ServerName: server} + if c.SkipVerify { + tlscfg.InsecureSkipVerify = true + } + conn, err = ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", server, c.Port), tlscfg) + if err != nil && index < len(c.Servers)-1 { + log.Println("Failed dial, trying next server") + } else if err != nil { + log.Println("Servers are all unavailable") + return nil, fmt.Errorf("Connection error: %v", err) + } else { + return &Conn{Conn: conn, Config: c}, nil + } + } + + return &Conn{Conn: conn, Config: c}, nil +} + +//ConnectStartTLS is a function used to connect to ldap server using TLS +func (c *Config) ConnectStartTLS() (*Conn, error) { + var conn *ldap.Conn + var err error + + for index, server := range c.Servers { + tlscfg := &tls.Config{ServerName: server} + if c.SkipVerify { + tlscfg.InsecureSkipVerify = true + } + conn, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", server, c.Port)) + if err != nil { + return nil, fmt.Errorf("Connection error: %v", err) + } + err = conn.StartTLS(tlscfg) + if err != nil && index < len(c.Servers)-1 { + log.Println("Failed dial, trying next server") + } else if err != nil { + log.Println("Servers are all unavailable") + return nil, fmt.Errorf("Connection error: %v", err) + } else { + return &Conn{Conn: conn, Config: c}, nil + } + } + + return &Conn{Conn: conn, Config: c}, nil +} + //Bind authenticates the connection with the given userPrincipalName and password //and returns the result or an error if one occurred. func (c *Conn) Bind(upn, password string) (bool, error) {