diff --git a/.gitignore b/.gitignore index 66fd13c..4ccad48 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,12 @@ # Dependency directories (remove the comment below to include it) # vendor/ + +_*/ +_* +.idea/ + +*.bak +*.env + +fmutex \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..d33e7aa --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/bry00/fmutex + +go 1.16 diff --git a/main.go b/main.go new file mode 100644 index 0000000..e5c90f3 --- /dev/null +++ b/main.go @@ -0,0 +1,199 @@ +package main + +import ( + "flag" + "fmt" + mutex "github.com/bry00/fmutex/mutex" + "io/ioutil" + "log" + "os" + "path" + "strings" + "time" +) + +const ( + FLAG_ROOT = "root" + ENV_ROOT = "FMUTEX_ROOT" + FLAG_ID = "id" + FLAG_SILENT = "s" + FLAG_PULSE = "pulse" + FLAG_REFRESH = "refresh" + FLAG_LIMIT = "limit" + FLAG_TIMEOUT = "timeout" +) + +var cmn = struct { // Common flags + Root string + Id string + Silent bool +}{ + Root: ifEmptyStr(os.Getenv(ENV_ROOT), os.TempDir()), + Silent: false, +} + +var lck = struct { // Lock flags + Pulse time.Duration + Refresh time.Duration + Limit time.Duration + Timeout time.Duration +}{ + Pulse: mutex.DefaultPulse, + Refresh: mutex.DefaultRefresh, + Limit: mutex.DefaultDeadTimeout, +} + +const ( + CMD_LOCK = "lock" + CMD_RELEASE = "release" + CMD_UNLOCK = "unlock" // An alias to CMD_RELEASE + CMD_TEST = "test" +) + +var ( + cmdLock *flag.FlagSet + cmdRelease *flag.FlagSet + cmdTest *flag.FlagSet + cmdAll []*flag.FlagSet + cmdNames []string +) + +func init() { + log.SetFlags(0) + log.SetPrefix(fmt.Sprintf("%s: ", getProg(os.Args))) + + flag.Usage = usage + flag.StringVar(&cmn.Root, FLAG_ROOT, cmn.Root, "root directory for mutex(es)") + flag.StringVar(&cmn.Id, FLAG_ID, cmn.Id, "mutex id") + flag.BoolVar(&cmn.Silent, FLAG_SILENT, cmn.Silent, "silent execution") + + cmdLock = flag.NewFlagSet(CMD_LOCK, flag.ExitOnError) + cmdLock.DurationVar(&lck.Pulse, FLAG_PULSE, lck.Pulse, "determines frequency of locking attempts") + cmdLock.DurationVar(&lck.Refresh, FLAG_REFRESH, lck.Refresh, "determines frequency of saving current timestamp in a locking file") + cmdLock.DurationVar(&lck.Limit, FLAG_LIMIT, lck.Limit, "determines how long takes to consider given mutex as \"dead\"") + cmdLock.DurationVar(&lck.Timeout, FLAG_TIMEOUT, lck.Timeout, "locking timeout (if > 0)") + + cmdRelease = flag.NewFlagSet(CMD_RELEASE, flag.ExitOnError) + cmdTest = flag.NewFlagSet(CMD_TEST, flag.ExitOnError) + + cmdAll, cmdNames = mkCommands(cmdLock, cmdRelease, cmdTest) + + flag.Parse() +} + +func main() { + + if isEmptyStr(cmn.Id) { + log.Fatalf("Flag -%s is required.", FLAG_ID) + } + + if flag.NArg() < 1 { + log.Fatalf("Parameter error - expected command, one of: %s", strings.Join(cmdNames, ", ")) + } + + if cmn.Silent { + log.SetOutput(ioutil.Discard) + } + switch flag.Arg(0) { + case CMD_LOCK: + cmdLock.Parse(flag.Args()[1:]) + doLock() + if !cmn.Silent { + fmt.Println("LOCKED") + } + case CMD_RELEASE, CMD_UNLOCK: + cmdRelease.Parse(flag.Args()[1:]) + doUnlock() + if !cmn.Silent { + fmt.Println("RELEASED") + } + case CMD_TEST: + cmdTest.Parse(flag.Args()[1:]) + os.Exit(doTest()) + + default: + log.Fatalf("Fatal parameter error - unknown command \"%s\", valid commands are: %s", flag.Arg(0), + strings.Join(cmdNames, ", ")) + } +} + +func doTest() int { + m := newMutex() + lockPath := m.LockPath() + if tm := m.When(); tm.IsZero() { + log.Printf("Mutex \"%s\" (%s) is unlocked", m.Id(), lockPath) + return 1 + } else { + log.Printf("Mutex \"%s\" (%s) is locked: %s", m.Id(), lockPath, tm.Format(time.RFC3339)) + } + return 0 +} + +func doLock() { + m := newMutex() + if err := m.TryLock(lck.Timeout); err != nil { + log.Fatalf("Cannot lock mutex \"%s\": %v", m.Id(), err) + } +} + +func doUnlock() { + m := newMutex() + if err := m.TryUnlock(); err != nil { + log.Fatalf("Cannot unlock mutex \"%s\": %v", m.Id(), err) + } +} + +func newMutex() *mutex.Mutex { + result, err := mutex.NewMutexExt(cmn.Root, cmn.Id, lck.Pulse, lck.Refresh, lck.Limit) + if err != nil { + log.Fatalf("Cannot create mutex \"%s\": %v", result.Id(), err) + } + return result +} + +func ifEmptyStr(str string, defaultStr string) string { + if isEmptyStr(str) { + return defaultStr + } + return str +} + +func isEmptyStr(str string) bool { + return strings.TrimSpace(str) == "" +} +func getProg(args []string) string { + base := path.Base(args[0]) + if i := strings.LastIndex(base, "."); i < 0 { + return base + } else { + return base[0:i] + } +} + +func mkCommands(cmds ...*flag.FlagSet) ([]*flag.FlagSet, []string) { + var result []string + for _, c := range cmds { + result = append(result, c.Name()) + } + return cmds, result +} + +func usage() { + prog := getProg(os.Args) + fmt.Fprintf(os.Stderr, "Program %s s designated to lock/unlock file-based mutexes.\n"+ + "Usage:\n"+ + "\t%s [options] {%s} [command-specific options]\n\n"+ + "options:\n", + prog, prog, strings.Join(cmdNames, ", ")) + flag.PrintDefaults() + + for _, c := range cmdAll { + var options = 0 + c.VisitAll(func(_ *flag.Flag) { options++ }) + if options > 0 { + fmt.Fprintf(os.Stderr, "\n%s's options:\n", c.Name()) + c.PrintDefaults() + } + } + fmt.Fprintln(os.Stderr) +} diff --git a/mutex/mutex.go b/mutex/mutex.go new file mode 100644 index 0000000..80f7021 --- /dev/null +++ b/mutex/mutex.go @@ -0,0 +1,205 @@ +package mutex + +import ( + "context" + "errors" + "fmt" + "io/ioutil" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "time" +) + +// A Mutex is a mutual exclusion lock based on filesystem primitives. +type Mutex struct { + id string + directory string + deadAgeRecovery time.Duration + pulse time.Duration + refresh time.Duration +} + +// DefaultPulse determines default frequency of locking attempts, i.e. defines delay between subsequent locking attempts. +const DefaultPulse = 500 * time.Millisecond + +// DefaultRefresh determines default frequency of saving current timestamp in a locking file. +const DefaultRefresh = 10 * time.Second + +// DefaultDeadTimeout determines how long takes to consider given mutex as "dead". +// "Dead" mutexes are removed during locking attempts. +const DefaultDeadTimeout = 60 * time.Minute + +// A lockCandidateTemplate defines locking candidate file name template. +const lockCandidateTemplate = "%s-candidate-*.tmp" + +// A lockTemplate defines locking file name template. +const lockTemplate = "%s-mutex.lck" + +// Id return given Mutex id. +func (m *Mutex) Id() string { + return m.id +} + +// Lock locks given Mutex. Panics in case of any error. Conforms to the sync.Locker interface. +func (m *Mutex) Lock() { + if err := m.TryLock(0); err != nil { + panic(err) + } +} + +// Unlock unlocks given Mutex. Panics in case of any error. Conforms to the sync.Locker interface. +func (m *Mutex) Unlock() { + if err := m.TryUnlock(); err != nil { + panic(err) + } +} + +// TryLock tries to lock given Mutex and returns error in case of failure. +// If timeout is greater than 0, the unsuccessful lock attempt is failed after timeout. +func (m *Mutex) TryLock(timeout time.Duration) error { + ctx := context.Background() + if timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + return m.LockWithContext(ctx) +} + +// TryUnlock unlocks given Mutex or returns error in case of failure. +func (m *Mutex) TryUnlock() error { + return os.Remove(m.LockPath()) +} + +// LockWithContext waits indefinitely to acquire given Mutex with timeout governed by passed context +// or returns error in case of failure. +func (m *Mutex) LockWithContext(ctx context.Context) error { + candidateLock, err := ioutil.TempFile(m.directory, fmt.Sprintf(lockCandidateTemplate, m.id)) + if err != nil { + return fmt.Errorf("cannot create candidate lock %s: %w", m.id, err) + } + candidateLock.Close() + candidate := candidateLock.Name() + defer os.Remove(candidate) // clean up + + target := m.LockPath() + + var lastTimestamp int64 = 0 + for { + if lastTimestamp == 0 || now()-lastTimestamp > millis(m.refresh) { + if f, err := os.Create(candidateLock.Name()); err == nil { + if lastTimestamp, err = writeCurrentTimestamp(f); err != nil { + return fmt.Errorf("cannot write current timestamp for candidate lock %s: %w", m.id, err) + } + } + if m.deadAgeRecovery >= 0 { + if otherTimestamp := readTimestamp(target); otherTimestamp > 0 { + if now()-otherTimestamp > millis(m.deadAgeRecovery) { + os.Remove(target) + time.Sleep(m.pulse * 2) + } + } + } + } + if err := os.Link(candidate, target); err == nil { + if now()-lastTimestamp > millis(m.refresh) { + if f, err := os.Create(target); err == nil { + _, err = writeCurrentTimestamp(f) + } + if err != nil { + return fmt.Errorf("cannot write current timestamp for target lock %s: %w", m.id, err) + } + } + return nil + } + if sleepOrDone(ctx, m.pulse) { + return errors.New("expired") + } + } +} + +func NewMutex(root string, lockId string) (*Mutex, error) { + return NewMutexExt(root, lockId, DefaultPulse, DefaultRefresh, DefaultDeadTimeout) +} + +func NewMutexExt(root string, lockId string, pulse time.Duration, refresh time.Duration, deadTimeout time.Duration) (*Mutex, error) { + if !filepath.IsAbs(root) { + var err error + //return nil, fmt.Errorf("root (%s) is NOT an absolute absolute path", root) + if root, err = filepath.Abs(root); err != nil { + return nil, err + } + } + dir := path.Join(root, lockId) + if err := os.MkdirAll(dir, 0700); err != nil { + return nil, fmt.Errorf("cannot create directory (%s): %w", root, err) + } + if pulse <= 0 { + pulse = DefaultPulse + } + if refresh <= 0 { + refresh = DefaultRefresh + } + return &Mutex{ + id: strings.ToLower(lockId), + directory: dir, + deadAgeRecovery: deadTimeout, + pulse: pulse, + refresh: refresh, + }, nil +} + +// LockPath returns the path of the lock file +func (m *Mutex) LockPath() string { + return path.Join(m.directory, fmt.Sprintf(lockTemplate, m.id)) +} + +// When returns time of when a given mutex has been created or "zero time" if mutext is in unlocked state +func (m *Mutex) When() time.Time { + if tm := readTimestamp(m.LockPath()); tm != 0 { + return time.Unix(0, tm*int64(time.Millisecond)) + } + return time.Time{} +} + +func sleepOrDone(ctx context.Context, delay time.Duration) bool { + select { + case <-ctx.Done(): + return true + case <-time.After(delay): + } + return false +} + +func nano2Millis(v int64) int64 { + return v / 1000000 +} + +func millis(d time.Duration) int64 { + return nano2Millis(int64(d)) +} + +func now() int64 { + return nano2Millis(time.Now().UnixNano()) +} + +func readTimestamp(fileName string) int64 { + if b, err := ioutil.ReadFile(fileName); err == nil { + if value, err := strconv.ParseInt(strings.TrimSpace(string(b)), 10, 64); err == nil { + return value + } + } + return 0 +} + +func writeCurrentTimestamp(f *os.File) (int64, error) { + defer f.Close() + timestamp := now() + if _, err := f.Write([]byte(fmt.Sprintf("%d\n", timestamp))); err != nil { + return timestamp, err + } + return timestamp, nil +} diff --git a/mutex/mutex_test.go b/mutex/mutex_test.go new file mode 100644 index 0000000..6826e70 --- /dev/null +++ b/mutex/mutex_test.go @@ -0,0 +1,80 @@ +package mutex + +import ( + "os" + "sync" + "testing" + "time" +) + +func temporaryCatalog(t *testing.T) string { + tempDir, err := os.MkdirTemp("", "temp-*.dir") + if err != nil { + t.Fatalf("error creating temporary directory: %v", err) + } + t.Cleanup(func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Errorf("error removing temporary directory: %v", err) + } + }) + return tempDir +} + +func TestSimpleMutex(t *testing.T) { + const mutexId = "simple-test-mutex" + mutexRoot := temporaryCatalog(t) + mx, err := NewMutex(mutexRoot, mutexId) + if err != nil { + t.Fatalf("cannot create the mutex: %v", err) + } + value := 0 + mx.Lock() + go func(v *int) { + mx.Lock() + defer mx.Unlock() + want := 33 + if *v != want { + t.Fatalf("wrong value %d instead of %d", *v, want) + } + }(&value) + value = 33 + mx.Unlock() +} + +func TestSimpleMutexN(t *testing.T) { + const mutexId = "simple-test-mutex" + var wg sync.WaitGroup + + mutexRoot := temporaryCatalog(t) + value := 100 + + mx, err := NewMutex(mutexRoot, mutexId) + if err != nil { + t.Fatalf("cannot create the mutex: %v", err) + } + mx.Lock() + for i := 0; i < 100; i++ { + wg.Add(1) + go func(wg *sync.WaitGroup, v *int) { + defer wg.Done() + lmx, err := NewMutex(mutexRoot, mutexId) + if err != nil { + t.Fatalf("cannot create the mutex: %v", err) + } + lmx.Lock() + defer lmx.Unlock() + *v += 1 + }(&wg, &value) + } + time.Sleep(10 * time.Millisecond) + want := 100 + if value != want { + t.Fatalf("wrong value %d instead of %d", value, want) + } + mx.Unlock() + wg.Wait() + want = 200 + if value != want { + t.Fatalf("wrong value %d instead of %d", value, want) + } +}