az-dns/main.go
2025-10-20 14:33:53 +02:00

585 lines
14 KiB
Go

package main
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/netip"
"net/url"
"os"
"strconv"
"strings"
"time"
)
// Program-wide settings.
const (
// defaultTimeout is the timeout main uses for outgoing HTTP requests.
defaultTimeout = 10 * time.Second
// hetznerAPIBase is the Hetzner Cloud API root; "/zones/..." paths are appended to it.
hetznerAPIBase = "https://api.hetzner.cloud/v1"
// ipv4DiscoverURL and ipv6DiscoverURL are plain-text "what is my IP" services
// queried by discoverIP for the public IPv4 and IPv6 address respectively.
ipv4DiscoverURL = "https://api.ipify.org"
ipv6DiscoverURL = "https://api6.ipify.org"
// version is printed by -v/--version.
version = "1.0.0"
)
// logLevel is the verbosity of a log message. logger.logf drops messages whose
// level is lower than the configured one, so the iota order below
// (debug < info < warn) is significant.
type logLevel int
const (
logLevelDebug logLevel = iota
logLevelInfo
logLevelWarn
)
// config holds the settings loadConfig reads from the environment.
type config struct {
Token string // Hetzner Cloud API token, sent as "Authorization: Bearer <Token>"
Hostname string // record name relative to Domain; "@" denotes the zone apex
Domain string // zone name, lower-cased and without a trailing dot (normalizeDomain)
TTL int // desired record TTL in seconds; loadConfig only accepts values > 0
DryRun bool // when true, planned changes are logged instead of sent
LogLevel logLevel // minimum level that Logger emits
Logger *logger // logger created with LogLevel
}
// zone is the subset of a Hetzner DNS zone used by this program.
//
// ID is int64 because the Hetzner Cloud API documents resource IDs as 64-bit
// integers. A plain int is only 32 bits wide on 32-bit platforms (e.g. ARMv7
// boards, a common home for dynamic-DNS clients), where decoding a large ID
// would fail.
type zone struct {
	ID   int64  `json:"id"`
	Name string `json:"name"`
	// TTL is the zone default; upsertRecord falls back to it for RRSets whose
	// own TTL is nil.
	TTL int `json:"ttl"`
}
// zonesResponse is the decoded body of GET /zones; only the zone list is read.
type zonesResponse struct {
Zones []zone `json:"zones"`
}
// rrsetRecord is a single record value inside an RRSet (e.g. an IP address
// for A/AAAA). It is used both for decoding responses and in request payloads.
type rrsetRecord struct {
Value string `json:"value"`
}
// rrset is an RRSet as returned by GET /zones/{id}/rrsets.
type rrset struct {
ID string `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
// TTL is nil when the JSON "ttl" is null or absent; upsertRecord then treats
// the zone's TTL as the effective value.
TTL *int `json:"ttl"`
Records []rrsetRecord `json:"records"`
}
// rrsetsResponse is the decoded body of GET /zones/{id}/rrsets.
type rrsetsResponse struct {
RRSets []rrset `json:"rrsets"`
}
func main() {
switch {
case showHelp(os.Args[1:]):
printHelp()
return
case showVersion(os.Args[1:]):
fmt.Println(version)
return
}
cfg, err := loadConfig()
if err != nil {
fail("configuration error: %v", err)
}
client := &http.Client{Timeout: defaultTimeout}
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
cfg.Logger.Infof("Starting az-dns (dry-run=%t, hostname=%q, domain=%q, ttl=%d)", cfg.DryRun, cfg.Hostname, cfg.Domain, cfg.TTL)
ipv4, err := discoverIP(ctx, client, ipv4DiscoverURL)
if err != nil {
cfg.Logger.Warnf("IPv4 lookup failed: %v", err)
} else {
cfg.Logger.Debugf("Detected IPv4 %s", ipv4)
}
ipv6, err := discoverIP(ctx, client, ipv6DiscoverURL)
if err != nil {
cfg.Logger.Warnf("IPv6 lookup failed: %v", err)
} else {
cfg.Logger.Debugf("Detected IPv6 %s", ipv6)
}
if ipv4 == "" && ipv6 == "" {
fail("unable to determine public IPv4 or IPv6 address")
}
z, err := lookupZone(ctx, client, cfg, cfg.Domain)
if err != nil {
fail("zone lookup failed: %v", err)
}
cfg.Logger.Infof("Using zone %s (id=%d)", z.Name, z.ID)
if ipv4 != "" {
if err := upsertRecord(ctx, client, cfg, z, cfg.Hostname, "A", ipv4); err != nil {
fail("updating A record failed: %v", err)
}
cfg.Logger.Infof("A record for %s set to %s", cfg.fqdn(), ipv4)
}
if ipv6 != "" {
if err := upsertRecord(ctx, client, cfg, z, cfg.Hostname, "AAAA", ipv6); err != nil {
fail("updating AAAA record failed: %v", err)
}
cfg.Logger.Infof("AAAA record for %s set to %s", cfg.fqdn(), ipv6)
}
}
// loadConfig builds the program configuration from environment variables.
//
// Required: HETZNER_TOKEN, HOSTNAME, DOMAIN and TTL. TTL must be a positive
// integer. Surrounding whitespace is trimmed, and a value that is empty after
// trimming counts as unset.
// Optional: DRY_RUN (parsed by parseBoolEnv) and LOG_LEVEL (parsed by
// parseLogLevel).
// DOMAIN is normalized before HOSTNAME, because the host is made relative to
// the normalized domain.
//
// NOTE(review): shells and container runtimes commonly set HOSTNAME
// themselves, so "HOSTNAME is not set" may never trigger. An accidental
// machine or container name would then be used as the record name. Confirm
// this is acceptable for the deployment.
func loadConfig() (config, error) {
	mustEnv := func(key string) (string, error) {
		if v := strings.TrimSpace(os.Getenv(key)); v != "" {
			return v, nil
		}
		return "", errors.New(key + " is not set")
	}

	token, err := mustEnv("HETZNER_TOKEN")
	if err != nil {
		return config{}, err
	}
	rawHost, err := mustEnv("HOSTNAME")
	if err != nil {
		return config{}, err
	}
	rawDomain, err := mustEnv("DOMAIN")
	if err != nil {
		return config{}, err
	}
	domain := normalizeDomain(rawDomain)

	ttlText, err := mustEnv("TTL")
	if err != nil {
		return config{}, err
	}
	ttl, convErr := strconv.Atoi(ttlText)
	if convErr != nil || ttl <= 0 {
		return config{}, fmt.Errorf("invalid TTL: %q", ttlText)
	}

	level, err := parseLogLevel(strings.TrimSpace(os.Getenv("LOG_LEVEL")))
	if err != nil {
		return config{}, err
	}

	return config{
		Token:    token,
		Hostname: normalizeHostname(rawHost, domain),
		Domain:   domain,
		TTL:      ttl,
		DryRun:   parseBoolEnv(os.Getenv("DRY_RUN")),
		LogLevel: level,
		Logger:   newLogger(level),
	}, nil
}
// normalizeHostname converts the HOSTNAME setting into a record name relative
// to domain. It trims surrounding whitespace and one trailing dot,
// lower-cases the result, and strips a trailing ".<domain>" from a fully
// qualified name.
//
// An empty value, "@", or the domain itself yields "@" (the zone apex). Only
// whole labels are stripped: for domain "example.com", "vpn.example.com"
// becomes "vpn". "notexample.com" is kept as-is because it merely shares
// trailing characters and is not inside the zone.
func normalizeHostname(host, domain string) string {
	host = strings.ToLower(strings.TrimSuffix(strings.TrimSpace(host), "."))
	domain = normalizeDomain(domain)
	if host == "" || host == "@" || host == domain {
		return "@"
	}
	// Require the separating dot so that a partial label is never removed.
	if suffix := "." + domain; strings.HasSuffix(host, suffix) {
		return strings.TrimSuffix(host, suffix)
	}
	return host
}
// discoverIP asks the plain-text "what is my IP" service at rawURL for the
// caller's public address and returns it in canonical textual form (e.g.
// "2001:db8::1"). It fails on any non-200 status, an empty body, or a body
// that is not a plain IP address. Addresses carrying an IPv6 zone such as
// "%eth0" count as invalid, since they cannot be used in DNS.
func discoverIP(ctx context.Context, client *http.Client, rawURL string) (string, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
	if err != nil {
		return "", err
	}
	// Ask for the plain-text format explicitly.
	q := req.URL.Query()
	q.Set("format", "text")
	req.URL.RawQuery = q.Encode()

	resp, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body only enriches the error message, so a read
		// error here is deliberately ignored.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
		return "", fmt.Errorf("unexpected status %s: %s", resp.Status, string(body))
	}

	// A bare address is far shorter than 128 bytes; anything longer is
	// truncated and then rejected by the parser below.
	data, err := io.ReadAll(io.LimitReader(resp.Body, 128))
	if err != nil {
		return "", err
	}
	text := strings.TrimSpace(string(data))
	if text == "" {
		return "", errors.New("empty response")
	}
	addr, err := netip.ParseAddr(text)
	if err != nil || addr.Zone() != "" {
		return "", fmt.Errorf("invalid IP response: %q", text)
	}
	return addr.String(), nil
}
// lookupZone fetches the zone whose name equals domain. The comparison is
// case-insensitive and done after normalizeDomain. The zones endpoint is
// queried with a name filter, and the returned list is still checked for an
// exact match. A non-200 response, a decode failure, or no matching zone is
// returned as an error.
func lookupZone(ctx context.Context, client *http.Client, cfg config, domain string) (zone, error) {
	wanted := normalizeDomain(domain)
	query := url.Values{}
	query.Set("name", wanted)
	endpoint := hetznerAPIBase + "/zones?" + query.Encode()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return zone{}, err
	}
	req.Header.Set("Authorization", "Bearer "+cfg.Token)

	resp, err := client.Do(req)
	if err != nil {
		return zone{}, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return zone{}, fmt.Errorf("zones query failed with %s: %s", resp.Status, readErrorBody(resp.Body))
	}

	var payload zonesResponse
	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
		return zone{}, err
	}
	for i := range payload.Zones {
		if strings.EqualFold(payload.Zones[i].Name, wanted) {
			return payload.Zones[i], nil
		}
	}
	return zone{}, fmt.Errorf("zone %q not found", domain)
}
// upsertRecord makes the RRSet host/recordType in zone z contain exactly the
// single record value with TTL cfg.TTL, creating the RRSet if it does not
// exist.
//
// Only the parts that differ are written: an RRSet whose records and TTL
// already match causes no write request at all. Repeated runs with an
// unchanged address therefore leave the zone untouched. With cfg.DryRun set,
// the planned changes are logged instead of sent.
func upsertRecord(ctx context.Context, client *http.Client, cfg config, z zone, host, recordType, value string) error {
	name := rrsetName(host)
	target := joinFQDN(name, cfg.Domain)

	current, err := getRRSet(ctx, client, cfg, z, name, recordType)
	switch {
	case errors.Is(err, errRRSetNotFound):
		if cfg.DryRun {
			cfg.Logger.Infof("Dry-run: would create RRSet %s type %s with value %s (ttl=%d)", target, recordType, value, cfg.TTL)
			return nil
		}
		cfg.Logger.Infof("Creating RRSet %s type %s with value %s (ttl=%d)", target, recordType, value, cfg.TTL)
		return createRRSet(ctx, client, cfg, z, name, recordType, value, cfg.TTL)
	case err != nil:
		return err
	}

	// An RRSet without its own TTL uses the zone default.
	currentTTL := z.TTL
	if current.TTL != nil {
		currentTTL = *current.TTL
	}
	valueOK := rrsetHasOnlyValue(current.Records, value)
	ttlOK := ttlMatches(current.TTL, cfg.TTL, z.TTL)

	if valueOK && ttlOK {
		cfg.Logger.Debugf("RRSet %s type %s already up to date (value=%s, ttl=%d)", target, recordType, value, currentTTL)
		return nil
	}

	if cfg.DryRun {
		if !valueOK {
			cfg.Logger.Infof("Dry-run: would update RRSet %s type %s to value %s (ttl=%d)", target, recordType, value, cfg.TTL)
		}
		if !ttlOK {
			cfg.Logger.Infof("Dry-run: would change TTL for %s type %s from %d to %d", target, recordType, currentTTL, cfg.TTL)
		}
		return nil
	}

	if valueOK {
		cfg.Logger.Debugf("Records for %s type %s unchanged (%s)", target, recordType, value)
	} else {
		cfg.Logger.Infof("Updating RRSet %s type %s to value %s", target, recordType, value)
		if err := setRRSetRecords(ctx, client, cfg, z, name, recordType, value); err != nil {
			return err
		}
	}

	if ttlOK {
		cfg.Logger.Debugf("TTL for %s type %s unchanged (%d)", target, recordType, currentTTL)
		return nil
	}
	cfg.Logger.Infof("Changing TTL for %s type %s from %d to %d", target, recordType, currentTTL, cfg.TTL)
	return changeRRSetTTL(ctx, client, cfg, z, name, recordType, cfg.TTL)
}

// rrsetHasOnlyValue reports whether records consists of exactly one entry
// equal to value. If both sides parse as IP addresses they are compared as
// addresses, so different spellings of the same IPv6 address count as equal.
// Otherwise the whitespace-trimmed strings must match exactly.
func rrsetHasOnlyValue(records []rrsetRecord, value string) bool {
	if len(records) != 1 {
		return false
	}
	have := strings.TrimSpace(records[0].Value)
	want := strings.TrimSpace(value)
	if a, errA := netip.ParseAddr(have); errA == nil {
		if b, errB := netip.ParseAddr(want); errB == nil {
			return a == b
		}
	}
	return have == want
}
// errRRSetNotFound is returned by getRRSet when the filtered RRSet list is
// empty.
var errRRSetNotFound = errors.New("rrset not found")

// getRRSet lists the RRSets of zone z filtered by name and recordType and
// returns the first entry. It returns errRRSetNotFound when the list is
// empty; any other failure (transport, non-200 status, decode) is returned
// as-is.
func getRRSet(ctx context.Context, client *http.Client, cfg config, z zone, name, recordType string) (rrset, error) {
	filter := url.Values{}
	filter.Set("name", name)
	filter.Add("type", recordType)
	endpoint := fmt.Sprintf("%s/zones/%d/rrsets?%s", hetznerAPIBase, z.ID, filter.Encode())

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return rrset{}, err
	}
	req.Header.Set("Authorization", "Bearer "+cfg.Token)

	resp, err := client.Do(req)
	if err != nil {
		return rrset{}, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return rrset{}, fmt.Errorf("rrset lookup failed with %s: %s", resp.Status, readErrorBody(resp.Body))
	}

	var listing rrsetsResponse
	if err := json.NewDecoder(resp.Body).Decode(&listing); err != nil {
		return rrset{}, err
	}
	if len(listing.RRSets) > 0 {
		return listing.RRSets[0], nil
	}
	return rrset{}, errRRSetNotFound
}
// createRRSet creates the RRSet name/recordType in zone z with a single
// record holding value and the given TTL. Only 201 Created counts as success.
func createRRSet(ctx context.Context, client *http.Client, cfg config, z zone, name, recordType, value string, ttl int) error {
	payload := struct {
		Name    string        `json:"name"`
		Type    string        `json:"type"`
		TTL     int           `json:"ttl"`
		Records []rrsetRecord `json:"records"`
	}{
		Name:    name,
		Type:    recordType,
		TTL:     ttl,
		Records: []rrsetRecord{{Value: value}},
	}
	endpoint := fmt.Sprintf("%s/zones/%d/rrsets", hetznerAPIBase, z.ID)
	return postJSON(ctx, client, cfg, endpoint, payload, "create rrset", func(code int) bool {
		return code == http.StatusCreated
	})
}

// setRRSetRecords replaces the records of the existing RRSet name/recordType
// in zone z with the single record value, via the set_records action. Any 2xx
// status counts as success.
func setRRSetRecords(ctx context.Context, client *http.Client, cfg config, z zone, name, recordType, value string) error {
	payload := struct {
		Records []rrsetRecord `json:"records"`
	}{
		Records: []rrsetRecord{{Value: value}},
	}
	endpoint := rrsetActionURL(z, name, recordType, "set_records")
	return postJSON(ctx, client, cfg, endpoint, payload, "set rrset records", isSuccessStatus)
}

// changeRRSetTTL sets the TTL of the existing RRSet name/recordType in zone z
// via the change_ttl action. Any 2xx status counts as success.
func changeRRSetTTL(ctx context.Context, client *http.Client, cfg config, z zone, name, recordType string, ttl int) error {
	payload := struct {
		TTL int `json:"ttl"`
	}{
		TTL: ttl,
	}
	endpoint := rrsetActionURL(z, name, recordType, "change_ttl")
	return postJSON(ctx, client, cfg, endpoint, payload, "change rrset ttl", isSuccessStatus)
}

// isSuccessStatus reports whether code is an HTTP 2xx status.
func isSuccessStatus(code int) bool {
	return code >= 200 && code < 300
}

// rrsetActionURL returns the URL of an RRSet action endpoint, path-escaping the
// RRSet name and type.
func rrsetActionURL(z zone, name, recordType, action string) string {
	return fmt.Sprintf("%s/zones/%d/rrsets/%s/%s/actions/%s",
		hetznerAPIBase, z.ID, url.PathEscape(name), url.PathEscape(recordType), action)
}

// postJSON marshals payload and POSTs it as JSON to endpoint with the
// configured bearer token. If accept rejects the response status, it returns
// the error "<what> failed with <status>: <body excerpt>". Marshal, request
// and transport errors are returned unchanged.
func postJSON(ctx context.Context, client *http.Client, cfg config, endpoint string, payload any, what string, accept func(int) bool) error {
	body, err := json.Marshal(payload)
	if err != nil {
		return err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return err
	}
	req.Header.Set("Authorization", "Bearer "+cfg.Token)
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if !accept(resp.StatusCode) {
		return fmt.Errorf("%s failed with %s: %s", what, resp.Status, readErrorBody(resp.Body))
	}
	return nil
}
// fail prints the formatted message plus a newline to stderr and terminates
// the process with exit status 1. Deferred functions of the caller do not run.
func fail(format string, args ...interface{}) {
	const exitFailure = 1
	line := format + "\n"
	fmt.Fprintf(os.Stderr, line, args...)
	os.Exit(exitFailure)
}
// fqdn returns the fully qualified name of the managed record. This is the
// bare domain for the apex "@", otherwise "<hostname>.<domain>".
func (c config) fqdn() string {
	if c.Hostname == "@" {
		return c.Domain
	}
	return c.Hostname + "." + c.Domain
}
// normalizeDomain strips one trailing root dot and lower-cases the result,
// e.g. "Example.COM." becomes "example.com".
func normalizeDomain(domain string) string {
	return strings.ToLower(strings.TrimSuffix(domain, "."))
}
// rrsetName converts a host label into the RRSet name used in API calls by
// lower-casing it. The apex marker "@" has no letters, so it is returned
// unchanged by the same call.
func rrsetName(host string) string {
	return strings.ToLower(host)
}
// joinFQDN combines a relative record name with its domain. An empty name or
// the apex "@" maps to the domain itself.
func joinFQDN(host, domain string) string {
	switch host {
	case "", "@":
		return domain
	}
	return host + "." + domain
}
// ttlMatches reports whether the effective TTL equals desired. The effective
// TTL is *current when set, otherwise zoneDefault.
func ttlMatches(current *int, desired, zoneDefault int) bool {
	effective := zoneDefault
	if current != nil {
		effective = *current
	}
	return effective == desired
}
// readErrorBody returns at most the first 2048 bytes of r, whitespace-trimmed,
// for use in error messages. If reading fails, a description of the read
// error is returned instead.
func readErrorBody(r io.Reader) string {
	const maxErrorBody = 2048
	raw, err := io.ReadAll(io.LimitReader(r, maxErrorBody))
	if err == nil {
		return strings.TrimSpace(string(raw))
	}
	return fmt.Sprintf("unable to read error body: %v", err)
}
// parseBoolEnv interprets an environment value as a boolean. Surrounding
// whitespace and case are ignored. "1", "true", "yes" and "on" are true;
// everything else, including the empty string, is false.
func parseBoolEnv(value string) bool {
	normalized := strings.ToLower(strings.TrimSpace(value))
	return normalized == "1" || normalized == "true" || normalized == "yes" || normalized == "on"
}
// parseLogLevel maps a LOG_LEVEL value to a logLevel. The lookup is
// case-insensitive and accepts "debug", "info", "warn" and "warning". An empty
// value means the default level, warn. Any other value yields logLevelWarn
// together with an error listing the allowed names.
func parseLogLevel(value string) (logLevel, error) {
	if value == "" {
		return logLevelWarn, nil
	}
	known := map[string]logLevel{
		"debug":   logLevelDebug,
		"info":    logLevelInfo,
		"warn":    logLevelWarn,
		"warning": logLevelWarn,
	}
	if lvl, ok := known[strings.ToLower(value)]; ok {
		return lvl, nil
	}
	return logLevelWarn, fmt.Errorf("invalid LOG_LEVEL %q (allowed: debug, info, warn)", value)
}
// logger is a minimal leveled logger. Messages below the configured level are
// discarded. Warnings go to stderr and everything else to stdout, formatted as
// "[PREFIX] message". Calls on a nil *logger are no-ops.
type logger struct {
	level logLevel
}

// newLogger returns a logger that emits messages at level and above.
func newLogger(level logLevel) *logger {
	l := new(logger)
	l.level = level
	return l
}

// logf formats and writes one message if level passes the logger's threshold.
func (l *logger) logf(level logLevel, prefix string, format string, args ...interface{}) {
	if l == nil {
		return
	}
	if level < l.level {
		return
	}
	sink := os.Stdout
	if level == logLevelWarn {
		sink = os.Stderr
	}
	fmt.Fprintf(sink, "[%s] %s\n", prefix, fmt.Sprintf(format, args...))
}

// Debugf logs at debug level with the "DEBUG" prefix.
func (l *logger) Debugf(format string, args ...interface{}) {
	l.logf(logLevelDebug, "DEBUG", format, args...)
}

// Infof logs at info level with the "INFO" prefix.
func (l *logger) Infof(format string, args ...interface{}) {
	l.logf(logLevelInfo, "INFO", format, args...)
}

// Warnf logs at warn level with the "WARN" prefix (written to stderr).
func (l *logger) Warnf(format string, args ...interface{}) {
	l.logf(logLevelWarn, "WARN", format, args...)
}
// showHelp reports whether args contain -h or --help.
func showHelp(args []string) bool {
	return argsContainAny(args, "-h", "--help")
}

// showVersion reports whether args contain -v or --version.
func showVersion(args []string) bool {
	return argsContainAny(args, "-v", "--version")
}

// argsContainAny reports whether any element of args exactly equals one of
// flags.
func argsContainAny(args []string, flags ...string) bool {
	for _, arg := range args {
		for _, flag := range flags {
			if arg == flag {
				return true
			}
		}
	}
	return false
}
// printHelp writes the usage text to stdout. The text lists the required and
// optional environment variables and the -h/-v flags. Apart from the first
// two headings, the text is in German.
// NOTE(review): the DRY_RUN line mentions only "true"/"1", but parseBoolEnv
// also accepts "yes" and "on"; consider aligning the help text.
func printHelp() {
fmt.Println(`Usage: az-dns [options]
Environment variables (required):
HETZNER_TOKEN Hetzner Cloud API Token mit DNS-Schreibrechten
HOSTNAME Hostname innerhalb der Zone (z. B. "@", "vpn")
DOMAIN Zonendomain (z. B. "example.com")
TTL Time-to-live der Records in Sekunden
Optionale Environment-Variablen:
LOG_LEVEL "debug", "info" oder "warn" (Standard: "warn")
DRY_RUN "true"/"1" für Trockenlauf ohne tatsächliche API-Calls
Optionen:
-h, --help Diese Hilfe anzeigen
-v, --version Version ausgeben
`)
}