diff --git a/cmd_set.go b/cmd_set.go index 056a3093..3484f71e 100644 --- a/cmd_set.go +++ b/cmd_set.go @@ -20,6 +20,7 @@ func commandsSet(m *Miniredis) { m.srv.Register("SINTERSTORE", m.cmdSinterstore) m.srv.Register("SISMEMBER", m.cmdSismember) m.srv.Register("SMEMBERS", m.cmdSmembers) + m.srv.Register("SMISMEMBER", m.cmdSmismember) m.srv.Register("SMOVE", m.cmdSmove) m.srv.Register("SPOP", m.cmdSpop) m.srv.Register("SRANDMEMBER", m.cmdSrandmember) @@ -293,6 +294,47 @@ func (m *Miniredis) cmdSmembers(c *server.Peer, cmd string, args []string) { }) } +// SMISMEMBER +func (m *Miniredis) cmdSmismember(c *server.Peer, cmd string, args []string) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + if m.checkPubsub(c, cmd) { + return + } + + key, values := args[0], args[1:] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteInt(0) + return + } + + if db.t(key) != "set" { + c.WriteError(ErrWrongType.Error()) + return + } + + c.WriteLen(len(values)) + for _, value := range values { + if db.setIsMember(key, value) { + c.WriteInt(1) + } else { + c.WriteInt(0) + } + } + return + }) +} + // SMOVE func (m *Miniredis) cmdSmove(c *server.Peer, cmd string, args []string) { if len(args) != 3 { diff --git a/cmd_set_test.go b/cmd_set_test.go index 84cae3c6..4d61650f 100644 --- a/cmd_set_test.go +++ b/cmd_set_test.go @@ -148,6 +148,36 @@ func TestSismember(t *testing.T) { }) } +// Test SMISMEMBER +func TestSmismember(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := proto.Dial(s.Addr()) + ok(t, err) + defer c.Close() + + s.SetAdd("s", "aap", "noot", "mies") + + mustDo(t, c, "SMISMEMBER", "s", "aap", "nosuch", "mies", proto.Ints(1, 0, 1)) + + t.Run("errors", func(t *testing.T) { + mustOK(t, c, "SET", "str", "value") + mustDo(t, c, + "SMISMEMBER", "str", "foo", + proto.Error(msgWrongType), + ) + mustDo(t, c, + "SMISMEMBER", + proto.Error(errWrongNumber("smismember")), + ) + mustDo(t, c, + "SMISMEMBER", "set", + proto.Error(errWrongNumber("smismember")), + ) + }) +} + // Test SREM func TestSrem(t *testing.T) { s, err := Run() diff --git a/integration/set_test.go b/integration/set_test.go index 375d8646..c52e7b1c 100644 --- a/integration/set_test.go +++ b/integration/set_test.go @@ -16,9 +16,11 @@ func TestSet(t *testing.T) { c.DoSorted("SMEMBERS", "nosuch") c.Do("SISMEMBER", "s", "aap") c.Do("SISMEMBER", "s", "nosuch") + c.Do("SMISMEMBER", "q", "aap", "noot", "nosuch") c.Do("SCARD", "nosuch") c.Do("SISMEMBER", "nosuch", "nosuch") + c.Do("SMISMEMBER", "nosuch", "nosuch", "nosuch") // failure cases c.Error("wrong number", "SADD") @@ -30,11 +32,14 @@ func TestSet(t *testing.T) { c.Error("wrong number", "SISMEMBER") c.Error("wrong number", "SISMEMBER", "few") c.Error("wrong number", "SISMEMBER", "too", "many", "arguments") + c.Error("wrong number", "SMISMEMBER") + c.Error("wrong number", "SMISMEMBER", "few") // Wrong type c.Do("SET", "str", "I am a string") c.Error("wrong kind", "SADD", "str", "noot", "mies") c.Error("wrong kind", "SMEMBERS", "str") c.Error("wrong kind", "SISMEMBER", "str", "noot") + c.Error("wrong kind", "SMISMEMBER", "str", "noot") c.Error("wrong kind", "SCARD", "str") }) @@ -44,6 +49,7 @@ func TestSet(t *testing.T) { c.Do("SMEMBERS", "q") c.Do("SISMEMBER", "q", "aap") c.Do("SISMEMBER", "q", "noot") + c.Do("SMISMEMBER", "q", "aap", "noot", "nosuch") }) }