BackGo/pkg/backup/smb.go

116 lines
2.7 KiB
Go
Raw Normal View History

package backup
import (
"fmt"
"io"
"net"
"os"
"github.com/hirochachacha/go-smb2"
)
// SMBClient pairs an SMB2 session with the name of the share that the
// client's methods mount on that session.
type SMBClient struct {
// Share is the share name handed to Session.Mount by each operation.
Share string
// Session is the SMB2 session on which Share is mounted.
Session *smb2.Session
}
// NewSMBClient opens a TCP connection to server, starts an SMB2 session on it
// using NTLM authentication with user, pass and domain, and returns a client
// that will operate on share.
//
// server is passed unchanged to net.Dial as a "tcp" address. The share itself
// is not mounted here; each operation mounts it on demand.
//
// The TCP connection is deliberately left open on success because the
// returned session keeps using it; closing it here would make every later
// operation fail. It is closed only when the SMB handshake fails.
// NOTE(review): callers are presumably expected to end the session (and with
// it the connection) via Session.Logoff when done — confirm against go-smb2.
func NewSMBClient(server, user, pass, domain, share string) (*SMBClient, error) {
	conn, err := net.Dial("tcp", server)
	if err != nil {
		return nil, fmt.Errorf("failed to dial: %w", err)
	}

	d := &smb2.Dialer{
		Initiator: &smb2.NTLMInitiator{
			User:     user,
			Password: pass,
			Domain:   domain,
		},
	}

	session, err := d.Dial(conn)
	if err != nil {
		// Best-effort cleanup: the handshake error is what the caller needs,
		// so a secondary close error is intentionally dropped.
		_ = conn.Close()
		return nil, fmt.Errorf("failed to start SMB session: %w", err)
	}

	return &SMBClient{
		Share:   share,
		Session: session,
	}, nil
}
// UploadFile copies the local file at localPath to remotePath on the client's
// share.
//
// The share is mounted for the duration of the call and unmounted before
// returning. The remote file is opened with Create and the local contents are
// streamed into it with io.Copy, so the file is never held in memory as a
// whole.
//
// An error is returned if mounting, opening either file, copying, or closing
// the remote file fails. The underlying error is wrapped and can be inspected
// with errors.Is / errors.As.
func (c *SMBClient) UploadFile(remotePath, localPath string) error {
	fs, err := c.Session.Mount(c.Share)
	if err != nil {
		return fmt.Errorf("failed to mount share: %w", err)
	}
	defer fs.Umount()

	localFile, err := os.Open(localPath)
	if err != nil {
		return fmt.Errorf("failed to open local file: %w", err)
	}
	defer localFile.Close()

	remoteFile, err := fs.Create(remotePath)
	if err != nil {
		return fmt.Errorf("failed to create remote file: %w", err)
	}

	if _, err := io.Copy(remoteFile, localFile); err != nil {
		// The copy error is the one worth reporting; the close here only
		// releases the remote handle.
		_ = remoteFile.Close()
		return fmt.Errorf("failed to write to remote file: %w", err)
	}

	// Close explicitly rather than via defer so that a failure to finalize
	// the remote file is reported instead of silently discarded.
	if err := remoteFile.Close(); err != nil {
		return fmt.Errorf("failed to close remote file: %w", err)
	}
	return nil
}
// DeleteFile removes remotePath from the client's share. The share is mounted
// for the duration of the call and unmounted before returning. A non-nil
// error reports that either the mount or the removal failed.
func (c *SMBClient) DeleteFile(remotePath string) error {
	share, mountErr := c.Session.Mount(c.Share)
	if mountErr != nil {
		return fmt.Errorf("failed to mount share: %v", mountErr)
	}
	defer share.Umount()

	if rmErr := share.Remove(remotePath); rmErr != nil {
		return fmt.Errorf("failed to delete remote file: %v", rmErr)
	}
	return nil
}
// ListFiles returns the names of the entries in remoteDir on the client's
// share that are not directories. The share is mounted for the duration of
// the call and unmounted before returning.
//
// The returned slice is nil when remoteDir holds no non-directory entries.
// A non-nil error reports that either the mount or the directory read failed.
func (c *SMBClient) ListFiles(remoteDir string) ([]string, error) {
	share, mountErr := c.Session.Mount(c.Share)
	if mountErr != nil {
		return nil, fmt.Errorf("failed to mount share: %v", mountErr)
	}
	defer share.Umount()

	infos, readErr := share.ReadDir(remoteDir)
	if readErr != nil {
		return nil, fmt.Errorf("failed to read directory: %v", readErr)
	}

	var names []string
	for i := range infos {
		// Skip sub-directories; only plain entries are reported.
		if infos[i].IsDir() {
			continue
		}
		names = append(names, infos[i].Name())
	}
	return names, nil
}
// SetShare replaces the share name stored on the client, so operations that
// run afterwards mount share instead of the previous value. No locking is
// performed around the assignment.
func (c *SMBClient) SetShare(share string) {
	c.Share = share
}