Finish auth & pass test

This commit is contained in:
Wancat
2022-10-20 09:53:40 +08:00
parent b7bad122e6
commit c23c781c63
2 changed files with 65 additions and 10 deletions

View File

@@ -1,9 +1,11 @@
package auth
import (
"bufio"
"errors"
"io/ioutil"
"log"
"fmt"
"os"
"strings"
"golang.org/x/crypto/bcrypt"
)
@@ -20,11 +22,48 @@ type Htpasswd struct {
}
func NewHtpasswd(path string) (AuthStore, error) {
_, err := ioutil.ReadFile(path)
if err != nil {
return Htpasswd{}, err
s := Htpasswd{
filePath: path,
}
return Htpasswd{}, nil
err := s.read()
return s, err
}
func (s *Htpasswd) read() (err error) {
file, err := os.Open(s.filePath)
if err != nil {
return err
}
defer file.Close()
fileScanner := bufio.NewScanner(file)
fileScanner.Split(bufio.ScanLines)
s.accounts = make(map[string]string)
for fileScanner.Scan() {
arr := strings.SplitN(fileScanner.Text(), ":", 2)
if len(arr) < 2 {
return fmt.Errorf("invalid data %s", arr)
}
s.accounts[arr[0]] = arr[1]
}
return nil
}
func (s *Htpasswd) write() (err error) {
file, err := os.OpenFile(s.filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to open htpasswd file: %w", err)
}
defer file.Close()
for u, p := range s.accounts {
_, err = fmt.Fprintf(file, "%s:%s\n", u, p)
if err != nil {
return err
}
}
return nil
}
func (s Htpasswd) Register(user, pass string) (err error) {
@@ -32,16 +71,20 @@ func (s Htpasswd) Register(user, pass string) (err error) {
if err != nil {
return
}
log.Println(s.accounts[user])
return
return s.write()
}
func (s Htpasswd) Authenticate(user, pass string) (err error) {
return errors.New("work in progress")
hashed, ok := s.accounts[user]
if !ok {
return errors.New("user not found")
}
return bcrypt.CompareHashAndPassword([]byte(hashed), []byte(pass))
}
func (s Htpasswd) Remove(user string) (err error) {
return errors.New("work in progress")
delete(s.accounts, user)
return s.write()
}
func hash(pass string) (string, error) {

View File

@@ -51,4 +51,16 @@ func TestHtpasswdSuccess(t *testing.T) {
t.Errorf("%s not found in htpasswd file: %s", u.user, string(data))
}
}
err = store.Remove(user1.user)
if err != nil {
t.Error(err)
}
data, err = ioutil.ReadFile(path)
if err != nil {
t.Error(err)
}
if strings.Contains(string(data), user1.user) {
t.Errorf("%s is found in htpasswd file but should be removed: %s", user1.user, string(data))
}
}