From 5e334761dbf8f68a5f09a710d832060088afbcac Mon Sep 17 00:00:00 2001 From: Ben Davies Date: Fri, 18 Nov 2016 12:13:16 -0400 Subject: [PATCH] Sockets: Go rewrite WIP Fixes #2943 --- package.json | 9 +- rooms.js | 10 +- sockets.go | 513 ++++++++++++++++++++++++++ sockets.js | 944 +++++++++++++++++++++++++++--------------------- sockets_test.go | 286 +++++++++++++++ users.js | 3 +- 6 files changed, 1338 insertions(+), 427 deletions(-) create mode 100644 sockets.go create mode 100644 sockets_test.go diff --git a/package.json b/package.json index fe81ad1ff8883..b32f3f559a9cf 100644 --- a/package.json +++ b/package.json @@ -3,17 +3,10 @@ "preferGlobal": true, "description": "The server for the Pokémon Showdown battle simulator", "version": "0.10.2", - "dependencies": { - "sockjs": "0.3.18" - }, "optionalDependencies": { "cloud-env": "0.1.1", "http-proxy": "0.10.0", - "nodemailer": "1.4.0", - "node-static": "0.7.7" - }, - "nonDefaultDependencies": { - "ofe": "0.1.2" + "nodemailer": "1.4.0" }, "engines": { "node": ">=6.0.0" diff --git a/rooms.js b/rooms.js index 7881493a7c6c8..9c18e6a451c0b 100644 --- a/rooms.js +++ b/rooms.js @@ -702,11 +702,13 @@ class GlobalRoom { } } onConnect(user, connection) { + console.log('connected:', user.name, this.id, connection.socketid); let initdata = '|updateuser|' + user.name + '|' + (user.named ? '1' : '0') + '|' + user.avatar + '\n'; connection.send(initdata + this.formatListText); if (this.chatRooms.length > 2) connection.send('|queryresponse|rooms|null'); // should display room list } onJoin(user, connection) { + console.log('joined:', user.name, this.id, connection.socketid); if (!user) return false; // ??? if (this.users[user.userid]) return user; @@ -1249,10 +1251,12 @@ class BattleRoom extends Room { } } onConnect(user, connection) { + console.log('connected:', user.name, this.id, connection.socketid); this.sendUser(connection, '|init|battle\n|title|' + this.title + '\n' + this.getLogForUser(user).join('\n')); if (this.game && this.game.onConnect) this.game.onConnect(user, connection); } onJoin(user, connection) { + console.log('joined:', user.name, this.id, connection.socketid); if (!user) return false; if (this.users[user.userid]) return user; @@ -1546,6 +1550,7 @@ class ChatRoom extends Room { return message; } onConnect(user, connection) { + console.log('connected:', user.name, this.id, connection.socketid); let userList = this.userList ? this.userList : this.getUserList(); this.sendUser(connection, '|init|chat\n|title|' + this.title + '\n' + userList + '\n' + this.getLogSlice(-100).join('\n') + this.getIntroMessage(user)); if (this.poll) this.poll.onConnect(user, connection); @@ -1556,7 +1561,10 @@ class ChatRoom extends Room { if (this.users[user.userid]) return user; if (user.named) { - this.reportJoin('j', user.getIdentity(this.id)); + // Prevents a race condition where this message would send before + // Connection#joinRoom has a chance to finish, preventing it from + // reaching users joining empty rooms. + process.nextTick(() => this.reportJoin('j', user.getIdentity(this.id))); } this.users[user.userid] = user; diff --git a/sockets.go b/sockets.go new file mode 100644 index 0000000000000..8b1fb0f0d2674 --- /dev/null +++ b/sockets.go @@ -0,0 +1,513 @@ +package main + +import ( + "bufio" + "encoding/json" + "fmt" + "log" + "io" + "net" + "net/http" + "os" + "path/filepath" + // "reflect" + "regexp" + "strings" + "sync" + // "unsafe" + + // TODO: use the stable version of sockjs-go once it includes + // sockjs.Session.Request(). + "github.com/gorilla/mux" + "github.com/igm/sockjs-go/sockjs" +) + +// This silences IPC when unit testing. +var production bool = true + +// IPC delimiter +const EOT byte = 3 + +type SSLOptions struct { + Key byte `json:"key"` + Cert byte `json:"cert"` +} + +type SSL struct { + Port int `json:"port"` + Options SSLOptions `json:"options"` +} + +type Config struct { + Workers int `json:"Workers"` + Port string `json:"Port"` + BindAddress string `json:"BindAddress"` + SSL SSL `json:"SSL"` +} + +func NewConfig(envVar string) (c Config, err error) { + configEnv := os.Getenv(envVar) + err = json.Unmarshal([]byte(configEnv), &c) + return +} + +type Payload struct { + Command string + Params []string +} + +func NewPayload(command string, params ...string) Payload { + p := Payload{ + Command: command, + Params: params} + return p +} + +func (p Payload) WriteToStdout() { + output, _ := json.Marshal(p) + if (production) { + os.Stdout.Write(append(output, EOT)) + } +} + +type Job struct { + Payload Payload +} + +var JobQueue = make(chan Job) + +type Worker struct { + WorkerPool chan chan Job + JobChannel chan Job + quit chan bool +} + +func NewWorker(workerPool chan chan Job) Worker { + return Worker{ + WorkerPool: workerPool, + JobChannel: make(chan Job), + quit: make(chan bool)} +} + +func (w Worker) Start() { + go func() { + for { + w.WorkerPool <- w.JobChannel + select { + case job := <-w.JobChannel: + params := job.Payload.Params + fmt.Printf("sockets: received job: %v %v", job.Payload.Command, params) + switch job.Payload.Command { + case ">": + if err := sm.SocketSend(params[0], params[1]); err != nil { + log.Fatal(err) + } + case "!": + if err := sm.SocketRemove(params[0], true); err != nil { + log.Fatal(err) + } + case "+": + sm.ChannelAdd(params[0], params[1]) + case "-": + if err := sm.ChannelRemove(params[0], params[1]); err != nil { + log.Fatal(err) + } + case "#": + if err := sm.ChannelSend(params[0], params[1]); err != nil { + log.Fatal(err) + } + case ".": + sm.SubchannelMove(params[0], params[1], params[2]) + case ":": + if err := sm.SubchannelSend(params[0], params[1]); err != nil { + log.Fatal(err) + } + } + case <-w.quit: + return + } + } + }() +} + +func (w Worker) Stop() { + go func() { + w.quit <- true + }() +} + +type Dispatcher struct { + WorkerPool chan chan Job + MaxWorkers int +} + +func NewDispatcher(maxWorkers int) *Dispatcher { + return &Dispatcher{ + WorkerPool: make(chan chan Job, maxWorkers), + MaxWorkers: maxWorkers} +} + +func (d *Dispatcher) Run() { + for i := 0; i < d.MaxWorkers; i++ { + worker := NewWorker(d.WorkerPool) + worker.Start() + } + + go d.dispatch() +} + +func (d *Dispatcher) dispatch() { + for { + select { + case job := <-JobQueue: + go func(job Job) { + jobChannel := <-d.WorkerPool + jobChannel <- job + }(job) + } + } +} + +// socketMultiplexer acts as a wrapper for sockjs.handler, exposing its map of +// SockJS sessions to allow the parent process to be able to send messages to +// all users in a room via channels, or either side of a battle and the +// users in its audience via subchannels. socketMultiplexer's methods each +// correspond to a method of global.Sockets in the parent process. +type socketMultiplexer struct { + smux sync.Mutex + sockets map[string]*sockjs.Session + + cmux sync.Mutex + channels map[string]map[string]bool + + scmux sync.Mutex + scre *regexp.Regexp + subchannels map[string]map[string]string +} + +func newSocketMultiplexer() *socketMultiplexer { + return &socketMultiplexer{ + sockets: make(map[string]*sockjs.Session), + channels: make(map[string]map[string]bool), + scre: regexp.MustCompile("\n|split\n([^\n]*)\n([^\n]*)\n([^\n]*)\n[^\n]*"), + subchannels: make(map[string]map[string]string)} +} + +func (sm *socketMultiplexer) SocketAdd(s *sockjs.Session) error { + sm.smux.Lock() + defer sm.smux.Unlock() + + id := (*s).ID() + if _, ok := sm.sockets[id]; ok { + return fmt.Errorf("sockets: error adding socket: collision at ID %v", id) + } + + sm.sockets[id] = s + + // FIXME: payload params is missing the socket's protocol! Screw with + // reflect and unsafe to get it, since it's a private field. + req := (*s).Request() + ip, _, _ := net.SplitHostPort(req.RemoteAddr) + ips := req.Header.Get("X-Forwarded-For") + NewPayload("*", id, ip, ips).WriteToStdout() + return nil +} + +func (sm *socketMultiplexer) SocketRemove(sid string, forced bool) error { + sm.smux.Lock() + if s, ok := sm.sockets[sid]; ok { + if forced { + (*s).Close(2010, "Normal closure") + } + } else { + sm.smux.Unlock() + return fmt.Errorf("sockets: failed to remove socket of ID %v: does not exist", sid) + } + + delete((*sm).sockets, sid) + sm.smux.Unlock() + + sm.cmux.Lock() + for cid, c := range sm.channels { + if _, ok := c[sid]; ok { + delete(c, sid) + if (len(c) == 0) { + delete((*sm).channels, cid) + } + } + } + sm.cmux.Unlock() + + sm.scmux.Lock() + for cid, sc := range sm.subchannels { + if _, ok := sc[sid]; ok { + delete(sc, sid) + if (len(sc) == 0) { + delete((*sm).subchannels, cid) + } + } + } + sm.scmux.Unlock() + + if !forced { + // The parent process doesn't know that the socket was closed. Poke it + // so it can clean up any relevant connections. + NewPayload("!", sid).WriteToStdout() + } + + return nil +} + +func (sm *socketMultiplexer) SocketSend(sid string, msg string) (err error) { + sm.smux.Lock() + defer sm.smux.Unlock() + + s, ok := sm.sockets[sid] + if ok { + err = (*s).Send(msg) + } else { + err = fmt.Errorf("sockets: error sending to socket of ID %v: does not exist", sid) + } + + return +} + +func (sm *socketMultiplexer) ChannelAdd(cid string, sid string) error { + sm.cmux.Lock() + defer sm.cmux.Unlock() + + if strings.HasPrefix(cid, "battle-") { + sm.scmux.Lock() + sc, ok := sm.subchannels[cid] + if !ok { + sc = make(map[string]string) + } + + if _, ok := sc[sid]; !ok { + // This subchannel ID should never actually be 0, but let's humour the + // parent process if it tries to send something to this channel's + // subchannels before moving the users in each of the battle's sides to + // the appropriate subchannel ID. + sc[sid] = "0" + } + + sm.subchannels[cid] = sc + sm.scmux.Unlock() + } + + c, ok := sm.channels[cid] + if !ok { + c = make(map[string]bool) + sm.channels[cid] = c + } + + c[sid] = true + return nil +} + +func (sm *socketMultiplexer) ChannelRemove(cid string, sid string) error { + sm.cmux.Lock() + defer sm.cmux.Unlock() + + if strings.HasPrefix(cid, "battle-") { + sm.scmux.Lock() + if _, ok := sm.subchannels[cid]; !ok { + sm.scmux.Unlock() + return fmt.Errorf("sockets: failed to remove socket of ID %v from channel of ID %v: subchannel map does not exist", sid, cid) + } + + delete((*sm).subchannels, cid) + sm.scmux.Unlock() + } + + c, ok := sm.channels[cid] + if ok { + if _, ok := c[sid]; !ok { + return fmt.Errorf("sockets: failed to remove socket of ID %v from channel of ID %v: socket does not exist", sid, cid) + } + } else { + return fmt.Errorf("sockets: failed to remove socket of ID %v from channel of ID %v: channel does not exist", sid, cid) + } + + delete(c, sid) + if len(c) == 0 { + delete((*sm).channels, cid) + } + + return nil +} + +func (sm *socketMultiplexer) ChannelSend(cid string, msg string) error { + sm.cmux.Lock() + c, ok := sm.channels[cid] + if !ok { + sm.cmux.Unlock() + return fmt.Errorf("sockets: failed to send to channel of ID %v: does not exist", cid) + } + sm.cmux.Unlock() + + sm.smux.Lock() + defer sm.smux.Unlock() + + for sid, _ := range c { + s, ok := sm.sockets[sid] + if !ok { + delete(c, sid) + continue + } + + if err := (*s).Send(msg); err != nil { + return fmt.Errorf("sockets: failed to send to channel of ID %v: %v", cid, err) + } + } + + return nil +} + +func (sm *socketMultiplexer) SubchannelMove(cid string, scid string, sid string) { + sm.scmux.Lock() + defer sm.scmux.Unlock() + + sc, ok := sm.subchannels[cid] + if !ok { + sc = make(map[string]string) + sm.subchannels[cid] = sc + } + + sc[sid] = scid +} + +func (sm *socketMultiplexer) SubchannelSend(cid string, msg string) error { + sm.cmux.Lock() + if _, ok := sm.channels[cid]; !ok { + sm.cmux.Unlock() + return fmt.Errorf("sockets: failed to broadcast to subchannels in channel of ID %v: channel does not exist", cid) + } + sm.cmux.Unlock() + + sm.scmux.Lock() + defer sm.scmux.Unlock() + + sc, ok := sm.subchannels[cid] + if !ok { + return fmt.Errorf("sockets: failed to broadcast to subchannels in channel of ID %v: subchannel map does not exist", cid) + } + + var scmsgs [3][]string + msgs := sm.scre.FindStringSubmatch(msg) + for i := 0; i < len(msgs); i += 1 { + switch i % 3 { + case 0: + scmsgs[0] = append(scmsgs[0], "\n" + msgs[i]) + case 1: + scmsgs[1] = append(scmsgs[1], "\n" + msgs[i]) + case 2: + scmsgs[2] = append(scmsgs[2], "\n" + msgs[i]) + } + } + + for sid, scid := range sc { + s, ok := sm.sockets[sid] + if !ok { + return fmt.Errorf("sockets: failed to broadcast to subchannels in channel of ID %v: socket of ID %v in subchannel of ID %v does not exist", cid, sid, scid) + } + + switch scid { + case "0": + for _, scmsg := range scmsgs[0] { + (*s).Send(scmsg) + } + case "1": + for _, scmsg := range scmsgs[1] { + (*s).Send(scmsg) + } + case "2": + for _, scmsg := range scmsgs[2] { + (*s).Send(scmsg) + } + default: + return fmt.Errorf("sockets: failed to broadcast to subchannels in channel of ID %v: socket of ID %v has unknown subchannel ID: %v", cid, sid, scid) + } + } + + return nil +} + +var sm *socketMultiplexer + +func SockJSHandler(s sockjs.Session) { + if err := sm.SocketAdd(&s); err != nil { + panic(err) + } + + id := s.ID() + for { + if msg, err := s.Recv(); err == nil { + NewPayload("<", id, msg).WriteToStdout() + continue + } + + break + } + + sm.SocketRemove(id, false) +} + +func main() { + // Create config struct from PS_CONFIG env variable as defined by the + // parent process from relevant settings in config.js. + config, err := NewConfig("PS_CONFIG") + if err != nil { + log.Fatal("sockets: failed to read config from $PS_CONFIG") + } + + // Spawn goroutine workers. + dispatcher := NewDispatcher(config.Workers) + dispatcher.Run() + + // Initialize socket multiplexer. + sm = newSocketMultiplexer() + + // Spawn the SockJS and static servers, begin serving over HTTP (and HTTPS + // if so configured). + go func() { + opts := sockjs.Options{ + SockJSURL: "//play.pokemonshowdown.com/js/lib/sockjs-1.1.1-nwjsfix.min.js", + Websocket: true, + HeartbeatDelay: sockjs.DefaultOptions.HeartbeatDelay, + DisconnectDelay: sockjs.DefaultOptions.DisconnectDelay, + JSessionID: sockjs.DefaultOptions.JSessionID} + + r := mux.NewRouter() + staticDir, _ := filepath.Abs("./static") + avatarDir, _ := filepath.Abs("./config/avatars") + r.Handle("/", http.FileServer(http.Dir(staticDir))) + r.PathPrefix("/avatars/"). + Handler(http.FileServer(http.Dir(avatarDir))) + r.PathPrefix("/showdown"). + Handler(sockjs.NewHandler("/showdown", opts, SockJSHandler)) + fmt.Printf("sockets: now serving on http://%v%v/", config.BindAddress, config.Port) + log.Fatal(http.ListenAndServe(config.Port, r)) + }() + + // Finally, listen for any messages passed over stdin by the parent + // process until either process is killed. + reader := bufio.NewReader(os.Stdin) + for { + input, err := reader.ReadString(EOT) + if err != nil { + fmt.Printf("sockets: error reading IPC input: %v", err) + if err == io.EOF { + return + } + + continue + } + + var p Payload + input = input[:len(input) - 1] + json.Unmarshal([]byte(input), &p) + job := Job{Payload: p} + JobQueue <- job + } +} diff --git a/sockets.js b/sockets.js index 729ce6fd55365..7b3000297aca5 100644 --- a/sockets.js +++ b/sockets.js @@ -13,488 +13,598 @@ 'use strict'; -const cluster = require('cluster'); -global.Config = require('./config/config'); - -if (cluster.isMaster) { - cluster.setupMaster({ - exec: require('path').resolve(__dirname, 'sockets'), - }); - - let workers = exports.workers = {}; - - let spawnWorker = exports.spawnWorker = function () { - let worker = cluster.fork({PSPORT: Config.port, PSBINDADDR: Config.bindaddress || '', PSNOSSL: Config.ssl ? 0 : 1}); - let id = worker.id; - workers[id] = worker; - worker.on('message', data => { - // console.log('master received: ' + data); - switch (data.charAt(0)) { - case '*': { - // *socketid, ip, protocol - // connect - let nlPos = data.indexOf('\n'); - let nlPos2 = data.indexOf('\n', nlPos + 1); - Users.socketConnect(worker, id, data.slice(1, nlPos), data.slice(nlPos + 1, nlPos2), data.slice(nlPos2 + 1)); - break; +const {spawn} = require('child_process'); +const EventEmitter = require('events'); + +const isTrustedProxyIp = Dnsbl.checker(Config.proxyip); + +// IPC delimiter +const EOT = '\u0003'; + +// Let's pretend that we still uses a cluster for dealing with the SockJS +// server in this process. This creates a mock of a cluster worker, which acts +// as a wrapper for handling I/O with the Go child process, which does the real +// work with the SockJS and static servers. +const workers = exports.workers = new Map(); + +class Worker extends EventEmitter { + constructor({id, port, bindAddress, ssl, workerCount}) { + super(); + + this.id = id; + this.queue = []; + + // Note: this crashes on node v7.1.0+ using Bash on Windows due to a + // bug in how it handles Unix domain sockets. + // See https://github.com/Microsoft/BashOnWindows/issues/1354 + this.process = spawn( + 'go', + ['run', 'sockets.go'], { + env: { + GOPATH: process.env.GOPATH || '', + GOROOT: process.env.GOROOT || '', + PS_CONFIG: JSON.stringify({ + Workers: workerCount || 1, + Port: `:${port || 8000}`, + BindAddress: bindAddress || '0.0.0.0', + SSL: ssl || null, + }), + }, + stdio: 'pipe', + shell: true, } + ); + this.process.once('exit', (code, signal) => { + // Respawn if the child process killed itself. + if (!signal) process.nextTick(() => Worker.spawn()); + }); - case '!': { - // !socketid - // disconnect - Users.socketDisconnect(worker, id, data.substr(1)); - break; + this.process.stdin.on('error', () => {}); + this.process.stdin.on('drain', () => { + while (this.queue.length) { + let args = this.queue[0]; + let res = this.send.apply(this, args); + if (res) { + this.queue.shift(); + } else { + // Wait for next drain event before continuing. + break; + } } + }); - case '<': { - // { + let payloads = data.split(EOT); + for (let payload of payloads) { + if (!payload) continue; + + let idx = payload.indexOf('{'); + if (idx < 0) { + // Data received wasn't a proper payload, so it was + // probably intended to be logged to console instead. + if (isNaN(payload)) console.log(payload); + continue; + } - default: - // unhandled + if (idx > 0) { + // For whatever reason, messages written to stdout from the + // child process are prefixed by the number of messages + // written to it since it was spawned. + payload = payload.substr(idx); + } + + payload = JSON.parse(payload); + let command = payload.Command; + let params = payload.Params; + switch (command) { + case '*': + this.onSocketConnect(...params); + break; + case '!': + this.onSocketDisconnect(...params); + break; + case '<': + this.onSocketReceive(...params); + break; + default: + console.error(`sockets: parent process received job with unknown command type ${command} from child process: ${params}`); + break; + } } }); - }; - cluster.on('disconnect', worker => { - // worker crashed, try our best to clean up - require('./crashlogger')(new Error("Worker " + worker.id + " abruptly died"), "The main process"); + this.process.stderr.setEncoding('utf8'); + this.process.stderr.once('data', data => { + require('./crashlogger')(new Error(data), `Worker ${this.id}`); - // this could get called during cleanup; prevent it from crashing - worker.send = () => {}; + let {id} = this; + let count = 0; + Users.connections.forEach(connection => { + if (connection.worker === this) { + Users.socketDisconnect(this, id, connection.socketid); + count++; + } + }); + console.error(`${count} connections were lost.`); - let count = 0; - Users.connections.forEach(connection => { - if (connection.worker === worker) { - Users.socketDisconnect(worker, worker.id, connection.socketid); - count++; - } + // Leave the worker in the workers map so it can be investigated + // later, and try respawning it once the process finishes exiting. }); - console.error("" + count + " connections were lost."); + } - // don't delete the worker, so we can investigate it if necessary. + onSocketConnect(socketid, remoteAddress, header, protocol) { +// console.log(`sockets: socket connect (${socketid}, ${remoteAddress}, ${header}, ${protocol})`); - // attempt to recover - spawnWorker(); - }); + let ip = remoteAddress; + if (header && isTrustedProxyIp(remoteAddress)) { + let ips = header.split(','); + for (let i = ips.length; i--;) { + ip = ips[i].trim(); + if (!isTrustedProxyIp(ip)) break; + } + } + + Users.socketConnect(this, this.id, socketid, ip, protocol); + } + + onSocketDisconnect(socketid) { +// console.log(`sockets: socket disconnect (${socketid})`); + Users.socketDisconnect(this, this.id, socketid); + } + + onSocketReceive(socketid, message) { +// console.log(`sockets: socket receive (${socketid}, ${message})`); + Users.socketReceive(this, this.id, socketid, message); + } - exports.listen = function (port, bindAddress, workerCount) { + isDead() { + return this.process.exitCode !== null || this.process.signalCode !== null; + } + + kill(signal = 'SIGTERM') { + this.process.kill(signal); + } + + send(command, ...params) { + if (this.isDead()) return false; + + let payload = `${JSON.stringify({Command: command, Params: params})}${EOT}`; + let res = this.process.stdin.write(payload); + if (!res) this.queue.push([command, params]); + return res; + } + + static spawn(port = Config.port, bindAddress = Config.bindaddress, workerCount = Config.workers) { + // Don't spawn another child process if one is already alive -- it will + // always crash when it launches its servers otherwise! + for (let [id, worker] of workers) { // eslint-disable-line no-unused-vars + if (!worker.isDead()) return false; + } + + let {ssl} = Config; + let id = workers.size; if (port !== undefined && !isNaN(port)) { - Config.port = port; - Config.ssl = null; - } else { port = Config.port; - // Autoconfigure the app when running in cloud hosting environments: + } else { try { let cloudenv = require('cloud-env'); - bindAddress = cloudenv.get('IP', bindAddress); port = cloudenv.get('PORT', port); + bindAddress = cloudenv.get('IP', bindAddress); } catch (e) {} } - if (bindAddress !== undefined) { - Config.bindaddress = bindAddress; - } - if (workerCount === undefined) { - workerCount = (Config.workers !== undefined ? Config.workers : 1); - } - for (let i = 0; i < workerCount; i++) { - spawnWorker(); - } - }; - - exports.killWorker = function (worker) { - let idd = worker.id + '-'; - let count = 0; - Users.connections.forEach((connection, connectionid) => { - if (connectionid.substr(idd.length) === idd) { - Users.socketDisconnect(worker, worker.id, connection.socketid); - count++; - } - }); - try { - worker.kill(); - } catch (e) {} - delete workers[worker.id]; - return count; - }; - exports.killPid = function (pid) { - pid = '' + pid; - for (let id in workers) { - let worker = workers[id]; - if (pid === '' + worker.process.pid) { - return this.killWorker(worker); + if (ssl && typeof ssl === 'object' && !Array.isArray(ssl)) { + try { + ssl = JSON.stringify(ssl); + } catch (e) { + ssl = null; } + } else { + ssl = null; } - return false; - }; - exports.socketSend = function (worker, socketid, message) { - worker.send('>' + socketid + '\n' + message); - }; - exports.socketDisconnect = function (worker, socketid) { - worker.send('!' + socketid); - }; + if (workerCount === undefined) workerCount = 1; - exports.channelBroadcast = function (channelid, message) { - for (let workerid in workers) { - workers[workerid].send('#' + channelid + '\n' + message); - } - }; - exports.channelSend = function (worker, channelid, message) { - worker.send('#' + channelid + '\n' + message); - }; - exports.channelAdd = function (worker, channelid, socketid) { - worker.send('+' + channelid + '\n' + socketid); - }; - exports.channelRemove = function (worker, channelid, socketid) { - worker.send('-' + channelid + '\n' + socketid); - }; + let worker = new Worker({id, port, bindAddress, ssl, workerCount}); + return worker; + } +} - exports.subchannelBroadcast = function (channelid, message) { - for (let workerid in workers) { - workers[workerid].send(':' + channelid + '\n' + message); +exports.Worker = Worker; + +exports.listen = (...args) => { + let worker = Worker.spawn(...args); + if (worker) workers.set(worker.id, worker); +}; +exports.spawnWorker = () => { + let worker = Worker.spawn(); + if (worker) workers.set(worker.id, worker); +}; +exports.killWorker = worker => { + let {id} = worker; + let count = 0; + Users.connections.forEach(connection => { + if (connection.worker === worker) { + Users.socketDisconnect(worker, id, connection.socketid); + count++; } - }; - exports.subchannelMove = function (worker, channelid, subchannelid, socketid) { - worker.send('.' + channelid + '\n' + subchannelid + '\n' + socketid); - }; -} else { - // is worker - - if (process.env.PSPORT) Config.port = +process.env.PSPORT; - if (process.env.PSBINDADDR) Config.bindaddress = process.env.PSBINDADDR; - if (+process.env.PSNOSSL) Config.ssl = null; - - // ofe is optional - // if installed, it will heap dump if the process runs out of memory - try { - require('ofe').call(); - } catch (e) {} - - // Static HTTP server - - // This handles the custom CSS and custom avatar features, and also - // redirects yourserver:8001 to yourserver-8001.psim.us - - // It's optional if you don't need these features. - - global.Dnsbl = require('./dnsbl'); + }); - if (Config.crashguard) { - // graceful crash - process.on('uncaughtException', err => { - require('./crashlogger')(err, 'Socket process ' + cluster.worker.id + ' (' + process.pid + ')', true); - }); - } + worker.kill(); + workers.delete(id); + return count; +}; + +exports.socketSend = (worker, socketid, message) => { +// console.log(`sockets: sending to socket of ID ${socketid}: ${message}`); + worker.send('>', socketid, message); +}; +exports.socketDisconnect = (worker, socketid) => { +// console.log(`sockets: disconnecting socket of ID ${socketid}`); + worker.send('!', socketid); +}; + +exports.channelSend = (worker, channelid, message) => { +// console.log(`sockets: sending to channel of ID ${channelid}: ${message}`); + worker.send('#', channelid, message); +}; +exports.channelBroadcast = (channelid, message) => { +// console.log(`sockets: broadcasting to channel of ID ${channelid}: ${message}`); + let worker = workers.get(workers.size - 1); + if (worker) worker.send('#', channelid, message); +}; +exports.channelAdd = (worker, channelid, socketid) => { +// console.log(`sockets: adding socket of ID ${socketid} to channel of ID ${channelid}`); + worker.send('+', channelid, socketid); +}; +exports.channelRemove = (worker, channelid, socketid) => { +// console.log(`sockets: removing socket of ID ${socketid} from channel of ID ${channelid}`); + worker.send('-', channelid, socketid); +}; + +exports.subchannelBroadcast = (channelid, message) => { +// console.log(`sockets: broadcasting to subchannels in channel of ID ${channelid}: ${message}`); + let worker = workers.get(workers.size - 1); + if (worker) worker.send(':', channelid, message); +}; +exports.subchannelMove = (worker, channelid, subchannelid, socketid) => { +// console.log(`sockets: moving socketid of ${socketid} to subchannel ${subchannelid} in channel of ID ${channelid}`); + worker.send('.', channelid, subchannelid, socketid); +}; + +/* +// is worker + +if (process.env.PSPORT) Config.port = +process.env.PSPORT; +if (process.env.PSBINDADDR) Config.bindaddress = process.env.PSBINDADDR; +if (+process.env.PSNOSSL) Config.ssl = null; + +// ofe is optional +// if installed, it will heap dump if the process runs out of memory +try { + require('ofe').call(); +} catch (e) {} + +// Static HTTP server + +// This handles the custom CSS and custom avatar features, and also +// redirects yourserver:8001 to yourserver-8001.psim.us + +// It's optional if you don't need these features. + +global.Dnsbl = require('./dnsbl'); + +if (Config.crashguard) { + // graceful crash + process.on('uncaughtException', err => { + require('./crashlogger')(err, 'Socket process ' + cluster.worker.id + ' (' + process.pid + ')', true); + }); +} - let app = require('http').createServer(); - let appssl; - if (Config.ssl) { - appssl = require('https').createServer(Config.ssl.options); - } - try { - let nodestatic = require('node-static'); - let cssserver = new nodestatic.Server('./config'); - let avatarserver = new nodestatic.Server('./config/avatars'); - let staticserver = new nodestatic.Server('./static'); - let staticRequestHandler = (request, response) => { - // console.log("static rq: " + request.socket.remoteAddress + ":" + request.socket.remotePort + " -> " + request.socket.localAddress + ":" + request.socket.localPort + " - " + request.method + " " + request.url + " " + request.httpVersion + " - " + request.rawHeaders.join('|')); - request.resume(); - request.addListener('end', () => { - if (Config.customhttpresponse && - Config.customhttpresponse(request, response)) { - return; +let app = require('http').createServer(); +let appssl; +if (Config.ssl) { + appssl = require('https').createServer(Config.ssl.options); +} +try { + let nodestatic = require('node-static'); + let cssserver = new nodestatic.Server('./config'); + let avatarserver = new nodestatic.Server('./config/avatars'); + let staticserver = new nodestatic.Server('./static'); + let staticRequestHandler = (request, response) => { + // console.log("static rq: " + request.socket.remoteAddress + ":" + request.socket.remotePort + " -> " + request.socket.localAddress + ":" + request.socket.localPort + " - " + request.method + " " + request.url + " " + request.httpVersion + " - " + request.rawHeaders.join('|')); + request.resume(); + request.addListener('end', () => { + if (Config.customhttpresponse && + Config.customhttpresponse(request, response)) { + return; + } + let server; + if (request.url === '/custom.css') { + server = cssserver; + } else if (request.url.substr(0, 9) === '/avatars/') { + request.url = request.url.substr(8); + server = avatarserver; + } else { + if (/^\/([A-Za-z0-9][A-Za-z0-9-]*)\/?$/.test(request.url)) { + request.url = '/'; } - let server; - if (request.url === '/custom.css') { - server = cssserver; - } else if (request.url.substr(0, 9) === '/avatars/') { - request.url = request.url.substr(8); - server = avatarserver; - } else { - if (/^\/([A-Za-z0-9][A-Za-z0-9-]*)\/?$/.test(request.url)) { - request.url = '/'; - } - server = staticserver; + server = staticserver; + } + server.serve(request, response, (e, res) => { + if (e && (e.status === 404)) { + staticserver.serveFile('404.html', 404, {}, request, response); } - server.serve(request, response, (e, res) => { - if (e && (e.status === 404)) { - staticserver.serveFile('404.html', 404, {}, request, response); - } - }); }); - }; - app.on('request', staticRequestHandler); - if (appssl) { - appssl.on('request', staticRequestHandler); - } - } catch (e) { - console.log('Could not start node-static - try `npm install` if you want to use it'); + }); + }; + app.on('request', staticRequestHandler); + if (appssl) { + appssl.on('request', staticRequestHandler); } +} catch (e) { + console.log('Could not start node-static - try `npm install` if you want to use it'); +} - // SockJS server +// SockJS server - // This is the main server that handles users connecting to our server - // and doing things on our server. +// This is the main server that handles users connecting to our server +// and doing things on our server. - let sockjs = require('sockjs'); +let sockjs = require('sockjs'); - let server = sockjs.createServer({ - sockjs_url: "//play.pokemonshowdown.com/js/lib/sockjs-1.1.1-nwjsfix.min.js", - log: (severity, message) => { - if (severity === 'error') console.log('ERROR: ' + message); - }, - prefix: '/showdown', - }); +let server = sockjs.createServer({ + sockjs_url: "//play.pokemonshowdown.com/js/lib/sockjs-1.1.1-nwjsfix.min.js", + log: (severity, message) => { + if (severity === 'error') console.log('ERROR: ' + message); + }, + prefix: '/showdown', +}); - let sockets = {}; - let channels = {}; - let subchannels = {}; - - // Deal with phantom connections. - let sweepClosedSockets = function () { - for (let s in sockets) { - if (sockets[s].protocol === 'xhr-streaming' && - sockets[s]._session && - sockets[s]._session.recv) { - sockets[s]._session.recv.didClose(); - } +let sockets = {}; +let channels = {}; +let subchannels = {}; - // A ghost connection's `_session.to_tref._idlePrev` (and `_idleNext`) property is `null` while - // it is an object for normal users. Under normal circumstances, those properties should only be - // `null` when the timeout has already been called, but somehow it's not happening for some connections. - // Simply calling `_session.timeout_cb` (the function bound to the aformentioned timeout) manually - // on those connections kills those connections. For a bit of background, this timeout is the timeout - // that sockjs sets to wait for users to reconnect within that time to continue their session. - if (sockets[s]._session && - sockets[s]._session.to_tref && - !sockets[s]._session.to_tref._idlePrev) { - sockets[s]._session.timeout_cb(); - } +// Deal with phantom connections. +let sweepClosedSockets = function () { + for (let s in sockets) { + if (sockets[s].protocol === 'xhr-streaming' && + sockets[s]._session && + sockets[s]._session.recv) { + sockets[s]._session.recv.didClose(); } - }; - let interval = setInterval(sweepClosedSockets, 1000 * 60 * 10); // eslint-disable-line no-unused-vars - process.on('message', data => { - // console.log('worker received: ' + data); - let socket = null, socketid = ''; - let channel = null, channelid = ''; - let subchannel = null, subchannelid = ''; + // A ghost connection's `_session.to_tref._idlePrev` (and `_idleNext`) property is `null` while + // it is an object for normal users. Under normal circumstances, those properties should only be + // `null` when the timeout has already been called, but somehow it's not happening for some connections. + // Simply calling `_session.timeout_cb` (the function bound to the aformentioned timeout) manually + // on those connections kills those connections. For a bit of background, this timeout is the timeout + // that sockjs sets to wait for users to reconnect within that time to continue their session. + if (sockets[s]._session && + sockets[s]._session.to_tref && + !sockets[s]._session.to_tref._idlePrev) { + sockets[s]._session.timeout_cb(); + } + } +}; +let interval = setInterval(sweepClosedSockets, 1000 * 60 * 10); // eslint-disable-line no-unused-vars + +process.on('message', data => { + // console.log('worker received: ' + data); + let socket = null, socketid = ''; + let channel = null, channelid = ''; + let subchannel = null, subchannelid = ''; + + switch (data.charAt(0)) { + case '$': // $code + eval(data.substr(1)); + break; + + case '!': // !socketid + // destroy + socketid = data.substr(1); + socket = sockets[socketid]; + if (!socket) return; + socket.end(); + // After sending the FIN packet, we make sure the I/O is totally blocked for this socket + socket.destroy(); + delete sockets[socketid]; + for (channelid in channels) { + delete channels[channelid][socketid]; + } + break; + + case '>': { + // >socketid, message + // message + let nlLoc = data.indexOf('\n'); + socket = sockets[data.substr(1, nlLoc - 1)]; + if (!socket) return; + socket.write(data.substr(nlLoc + 1)); + break; + } - switch (data.charAt(0)) { - case '$': // $code - eval(data.substr(1)); - break; + case '#': { + // #channelid, message + // message to channel + let nlLoc = data.indexOf('\n'); + channel = channels[data.substr(1, nlLoc - 1)]; + let message = data.substr(nlLoc + 1); + for (socketid in channel) { + channel[socketid].write(message); + } + break; + } - case '!': // !socketid - // destroy - socketid = data.substr(1); - socket = sockets[socketid]; - if (!socket) return; - socket.end(); - // After sending the FIN packet, we make sure the I/O is totally blocked for this socket - socket.destroy(); - delete sockets[socketid]; - for (channelid in channels) { - delete channels[channelid][socketid]; - } - break; + case '+': { + // +channelid, socketid + // add to channel + let nlLoc = data.indexOf('\n'); + socketid = data.substr(nlLoc + 1); + socket = sockets[socketid]; + if (!socket) return; + channelid = data.substr(1, nlLoc - 1); + channel = channels[channelid]; + if (!channel) channel = channels[channelid] = Object.create(null); + channel[socketid] = socket; + break; + } - case '>': { - // >socketid, message - // message - let nlLoc = data.indexOf('\n'); - socket = sockets[data.substr(1, nlLoc - 1)]; - if (!socket) return; - socket.write(data.substr(nlLoc + 1)); + case '-': { + // -channelid, socketid + // remove from channel + let nlLoc = data.indexOf('\n'); + channelid = data.slice(1, nlLoc); + channel = channels[channelid]; + if (!channel) return; + socketid = data.slice(nlLoc + 1); + delete channel[socketid]; + if (subchannels[channelid]) delete subchannels[channelid][socketid]; + let isEmpty = true; + for (let socketid in channel) { // eslint-disable-line no-unused-vars + isEmpty = false; break; } - - case '#': { - // #channelid, message - // message to channel - let nlLoc = data.indexOf('\n'); - channel = channels[data.substr(1, nlLoc - 1)]; - let message = data.substr(nlLoc + 1); - for (socketid in channel) { - channel[socketid].write(message); - } - break; + if (isEmpty) { + delete channels[channelid]; + delete subchannels[channelid]; } + break; + } - case '+': { - // +channelid, socketid - // add to channel - let nlLoc = data.indexOf('\n'); - socketid = data.substr(nlLoc + 1); - socket = sockets[socketid]; - if (!socket) return; - channelid = data.substr(1, nlLoc - 1); - channel = channels[channelid]; - if (!channel) channel = channels[channelid] = Object.create(null); - channel[socketid] = socket; - break; + case '.': { + // .channelid, subchannelid, socketid + // move subchannel + let nlLoc = data.indexOf('\n'); + channelid = data.slice(1, nlLoc); + let nlLoc2 = data.indexOf('\n', nlLoc + 1); + subchannelid = data.slice(nlLoc + 1, nlLoc2); + socketid = data.slice(nlLoc2 + 1); + + subchannel = subchannels[channelid]; + if (!subchannel) subchannel = subchannels[channelid] = Object.create(null); + if (subchannelid === '0') { + delete subchannel[socketid]; + } else { + subchannel[socketid] = subchannelid; } + break; + } - case '-': { - // -channelid, socketid - // remove from channel - let nlLoc = data.indexOf('\n'); - channelid = data.slice(1, nlLoc); - channel = channels[channelid]; - if (!channel) return; - socketid = data.slice(nlLoc + 1); - delete channel[socketid]; - if (subchannels[channelid]) delete subchannels[channelid][socketid]; - let isEmpty = true; - for (let socketid in channel) { // eslint-disable-line no-unused-vars - isEmpty = false; + case ':': { + // :channelid, message + // message to subchannel + let nlLoc = data.indexOf('\n'); + channelid = data.slice(1, nlLoc); + channel = channels[channelid]; + subchannel = subchannels[channelid]; + let message = data.substr(nlLoc + 1); + let messages = [null, null, null]; + for (socketid in channel) { + switch (subchannel ? subchannel[socketid] : '0') { + case '1': + if (!messages[1]) {*/ +// messages[1] = message.replace(/\n\|split\n[^\n]*\n([^\n]*)\n[^\n]*\n[^\n]*/g, '\n$1'); +/* } + channel[socketid].write(messages[1]); + break; + case '2': + if (!messages[2]) {*/ +// messages[2] = message.replace(/\n\|split\n[^\n]*\n[^\n]*\n([^\n]*)\n[^\n]*/g, '\n$1'); +/* } + channel[socketid].write(messages[2]); + break; + default: + if (!messages[0]) {*/ +// messages[0] = message.replace(/\n\|split\n([^\n]*)\n[^\n]*\n[^\n]*\n[^\n]*/g, '\n$1'); +/* } + channel[socketid].write(messages[0]); break; } - if (isEmpty) { - delete channels[channelid]; - delete subchannels[channelid]; - } - break; } + break; + } - case '.': { - // .channelid, subchannelid, socketid - // move subchannel - let nlLoc = data.indexOf('\n'); - channelid = data.slice(1, nlLoc); - let nlLoc2 = data.indexOf('\n', nlLoc + 1); - subchannelid = data.slice(nlLoc + 1, nlLoc2); - socketid = data.slice(nlLoc2 + 1); - - subchannel = subchannels[channelid]; - if (!subchannel) subchannel = subchannels[channelid] = Object.create(null); - if (subchannelid === '0') { - delete subchannel[socketid]; - } else { - subchannel[socketid] = subchannelid; - } - break; - } + default: + } +}); + +process.on('disconnect', () => { + process.exit(); +}); + +// this is global so it can be hotpatched if necessary +let isTrustedProxyIp = Dnsbl.checker(Config.proxyip); +let socketCounter = 0; +server.on('connection', socket => { + if (!socket) { + // For reasons that are not entirely clear, SockJS sometimes triggers + // this event with a null `socket` argument. + return; + } else if (!socket.remoteAddress) { + // This condition occurs several times per day. It may be a SockJS bug. + try { + socket.end(); + } catch (e) {} + return; + } + let socketid = socket.id = (++socketCounter); - case ':': { - // :channelid, message - // message to subchannel - let nlLoc = data.indexOf('\n'); - channelid = data.slice(1, nlLoc); - channel = channels[channelid]; - subchannel = subchannels[channelid]; - let message = data.substr(nlLoc + 1); - let messages = [null, null, null]; - for (socketid in channel) { - switch (subchannel ? subchannel[socketid] : '0') { - case '1': - if (!messages[1]) { - messages[1] = message.replace(/\n\|split\n[^\n]*\n([^\n]*)\n[^\n]*\n[^\n]*/g, '\n$1'); - } - channel[socketid].write(messages[1]); - break; - case '2': - if (!messages[2]) { - messages[2] = message.replace(/\n\|split\n[^\n]*\n[^\n]*\n([^\n]*)\n[^\n]*/g, '\n$1'); - } - channel[socketid].write(messages[2]); - break; - default: - if (!messages[0]) { - messages[0] = message.replace(/\n\|split\n([^\n]*)\n[^\n]*\n[^\n]*\n[^\n]*/g, '\n$1'); - } - channel[socketid].write(messages[0]); - break; - } - } - break; - } + sockets[socket.id] = socket; - default: + if (isTrustedProxyIp(socket.remoteAddress)) { + let ips = (socket.headers['x-forwarded-for'] || '').split(','); + let ip; + while ((ip = ips.pop())) { + ip = ip.trim(); + if (!isTrustedProxyIp(ip)) { + socket.remoteAddress = ip; + break; + } } - }); + } - process.on('disconnect', () => { - process.exit(); - }); + process.send('*' + socketid + '\n' + socket.remoteAddress + '\n' + socket.protocol); - // this is global so it can be hotpatched if necessary - let isTrustedProxyIp = Dnsbl.checker(Config.proxyip); - let socketCounter = 0; - server.on('connection', socket => { - if (!socket) { - // For reasons that are not entirely clear, SockJS sometimes triggers - // this event with a null `socket` argument. - return; - } else if (!socket.remoteAddress) { - // This condition occurs several times per day. It may be a SockJS bug. - try { - socket.end(); - } catch (e) {} + socket.on('data', message => { + // drop empty messages (DDoS?) + if (!message) return; + // drop messages over 100KB + if (message.length > 100000) { + console.log("Dropping client message " + (message.length / 1024) + " KB..."); + console.log(message.slice(0, 160)); return; } - let socketid = socket.id = (++socketCounter); + // drop legacy JSON messages + if (typeof message !== 'string' || message.charAt(0) === '{') return; + // drop blank messages (DDoS?) + let pipeIndex = message.indexOf('|'); + if (pipeIndex < 0 || pipeIndex === message.length - 1) return; - sockets[socket.id] = socket; + process.send('<' + socketid + '\n' + message); + }); - if (isTrustedProxyIp(socket.remoteAddress)) { - let ips = (socket.headers['x-forwarded-for'] || '').split(','); - let ip; - while ((ip = ips.pop())) { - ip = ip.trim(); - if (!isTrustedProxyIp(ip)) { - socket.remoteAddress = ip; - break; - } - } + socket.on('close', () => { + process.send('!' + socketid); + delete sockets[socketid]; + for (let channelid in channels) { + delete channels[channelid][socketid]; } - - process.send('*' + socketid + '\n' + socket.remoteAddress + '\n' + socket.protocol); - - socket.on('data', message => { - // drop empty messages (DDoS?) - if (!message) return; - // drop messages over 100KB - if (message.length > 100000) { - console.log("Dropping client message " + (message.length / 1024) + " KB..."); - console.log(message.slice(0, 160)); - return; - } - // drop legacy JSON messages - if (typeof message !== 'string' || message.charAt(0) === '{') return; - // drop blank messages (DDoS?) - let pipeIndex = message.indexOf('|'); - if (pipeIndex < 0 || pipeIndex === message.length - 1) return; - - process.send('<' + socketid + '\n' + message); - }); - - socket.on('close', () => { - process.send('!' + socketid); - delete sockets[socketid]; - for (let channelid in channels) { - delete channels[channelid][socketid]; - } - }); }); - server.installHandlers(app, {}); - if (!Config.bindaddress) Config.bindaddress = '0.0.0.0'; - app.listen(Config.port, Config.bindaddress); - console.log('Worker ' + cluster.worker.id + ' now listening on ' + Config.bindaddress + ':' + Config.port); - - if (appssl) { - server.installHandlers(appssl, {}); - appssl.listen(Config.ssl.port, Config.bindaddress); - console.log('Worker ' + cluster.worker.id + ' now listening for SSL on port ' + Config.ssl.port); - } +}); +server.installHandlers(app, {}); +if (!Config.bindaddress) Config.bindaddress = '0.0.0.0'; +app.listen(Config.port, Config.bindaddress); +console.log('Worker ' + cluster.worker.id + ' now listening on ' + Config.bindaddress + ':' + Config.port); + +if (appssl) { + server.installHandlers(appssl, {}); + appssl.listen(Config.ssl.port, Config.bindaddress); + console.log('Worker ' + cluster.worker.id + ' now listening for SSL on port ' + Config.ssl.port); +} - console.log('Test your server at http://' + (Config.bindaddress === '0.0.0.0' ? 'localhost' : Config.bindaddress) + ':' + Config.port); +console.log('Test your server at http://' + (Config.bindaddress === '0.0.0.0' ? 'localhost' : Config.bindaddress) + ':' + Config.port); - require('./repl').start('sockets-', cluster.worker.id + '-' + process.pid, cmd => eval(cmd)); -} +require('./repl').start('sockets-', cluster.worker.id + '-' + process.pid, cmd => eval(cmd)); +*/ diff --git a/sockets_test.go b/sockets_test.go new file mode 100644 index 0000000000000..e709876cf3bec --- /dev/null +++ b/sockets_test.go @@ -0,0 +1,286 @@ +package main + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/igm/sockjs-go/sockjs" +) + +// Mock of sockjs.Session for socketsMultiplexer tests. +type MockSession struct { + id string + req *http.Request + state sockjs.SessionState +} + +func NewMockSession(id string) *MockSession { + ms := MockSession{ + id: id, + req: httptest.NewRequest("GET", "http://localhost:8000/showdown", nil), + state: sockjs.SessionActive} + return &ms +} + +func (ms *MockSession) ID() string { + return ms.id +} + +func (ms *MockSession) Request() *http.Request { + return ms.req +} + +func (ms *MockSession) Recv() (string, error) { + return "", nil +} + +func (ms *MockSession) Send(msg string) error { + return nil +} + +func (ms *MockSession) Close(status uint32, reason string) error { + ms.state = sockjs.SessionClosed + return nil +} + +func (ms *MockSession) GetSessionState() sockjs.SessionState { + return ms.state +} + +func (ms *MockSession) ServeHTTP(rw http.ResponseWriter, req *http.Request) {} + +func pipeErr(ech chan error, err error) { + go func() { + ech <- err + }() +} + +func scrubSM() { + for sid, _ := range sm.sockets { + delete((*sm).sockets, sid) + } + for cid, _ := range sm.channels { + delete((*sm).channels, cid) + } + for cid, _ := range sm.subchannels { + delete((*sm).subchannels, cid) + } +} + +func Test_newSocketMultiplexer(t *testing.T) { + sm = newSocketMultiplexer() + if sm.sockets == nil { + t.Errorf("SM sockets map does not exist") + } + if sm.channels == nil { + t.Errorf("SM channels map does not exist") + } + if sm.subchannels == nil { + t.Errorf("SM subchannels map does not exist") + } +} + +func Test_socketMultiplexer_SocketAdd(t *testing.T) { + ms := NewMockSession("aaaaaaa0") + s := sockjs.Session(ms) + sp := &s + sm.SocketAdd(sp) + if _, ok := sm.sockets[s.ID()]; !ok { + t.Errorf("SM SocketAdd failed to add the session to sockets map") + } + + scrubSM() +} + +func Test_socketMultiplexer_SocketSend(t *testing.T) { + ms := NewMockSession("aaaaaaa0") + s := sockjs.Session(ms) + sp := &s + sm.SocketAdd(sp) + if err := sm.SocketSend(s.ID(), ""); err != nil { + t.Errorf("SM SocketSend failed: %v", err) + } + + scrubSM() +} + +func Test_socketMultiplexer_SocketRemove(t *testing.T) { + ms := NewMockSession("aaaaaaa0") + s := sockjs.Session(ms) + sp := &s + sm.SocketAdd(sp) + sm.SocketRemove(s.ID(), true) + if _, ok := sm.sockets[s.ID()]; ok { + t.Errorf("SM SocketRemove failed to remove the session from the sockets map") + } + + // Check for race conditions. + ech := make(chan error) + for i := 0; i < 100; i += 1 { + id := fmt.Sprint(i) + digits := 8 - len(id) + id = strings.Repeat("a", digits) + id + ms := NewMockSession(id) + s := sockjs.Session(ms) + sp := &s + go func() { + pipeErr(ech, sm.SocketAdd(sp)) + pipeErr(ech, sm.SocketSend(s.ID(), "")) + pipeErr(ech, sm.SocketRemove(s.ID(), true)) + }() + } + + for i := 0; i < 300; i += 1 { + if err := <-ech; err != nil { + t.Errorf("SM sockets race condition in add/remove/send: %v", err) + break + } + } + + scrubSM() +} + +func Test_socketMultiplexer_ChannelAdd(t *testing.T) { + ms := NewMockSession("aaaaaaa0") + s := sockjs.Session(ms) + sp := &s + sm.SocketAdd(sp) + sm.ChannelAdd("global", s.ID()) + if _, ok := sm.channels["global"]; !ok { + t.Errorf("SM ChannelAdd failed to add channel to channels map") + } + if _, ok := sm.channels["global"][s.ID()]; !ok { + t.Errorf("SM ChannelAdd failed to add socket to new channel") + } + if _, ok := sm.subchannels["global"]; ok { + t.Errorf("SM ChannelAdd added a subchannel map for a non-battle room") + } + + sm.ChannelAdd("battle-ou-1", s.ID()) + if _, ok := sm.subchannels["battle-ou-1"]; !ok { + t.Errorf("SM ChannelAdd failed to add subchannel map for battle room's channel") + } + if _, ok := sm.subchannels["battle-ou-1"][s.ID()]; !ok { + t.Errorf("SM ChannelAdd failed to add socket ID to subchannel map") + } + + scrubSM() +} + +func Test_socketMultiplexer_ChannelSend(t *testing.T) { + ms := NewMockSession("aaaaaaa0") + s := sockjs.Session(ms) + sp := &s + sm.SocketAdd(sp) + sm.ChannelAdd("global", s.ID()) + if err := sm.ChannelSend("global", ""); err != nil { + t.Errorf("SM ChannelSend failed to send message: %v", err) + } + + scrubSM() +} + +func Test_socketMultiplexer_ChannelRemove(t *testing.T) { + ms := NewMockSession("aaaaaaa0") + ms2 := NewMockSession("aaaaaaa1") + s := sockjs.Session(ms) + s2 := sockjs.Session(ms2) + sp := &s + sp2 := &s2 + sm.SocketAdd(sp) + sm.SocketAdd(sp2) + sm.ChannelAdd("global", s.ID()) + sm.ChannelAdd("global", s2.ID()) + sm.ChannelRemove("global", s2.ID()) + if _, ok := sm.channels["global"][s2.ID()]; ok { + t.Errorf("SM ChannelRemove failed to remove socket from channel") + } + + sm.ChannelRemove("global", s.ID()) + if _, ok := sm.channels["global"]; ok { + t.Errorf("SM ChannelRemove failed to remove channel when removing its last socket") + } + + sm.ChannelAdd("battle-ou-1", s.ID()) + sm.ChannelRemove("battle-ou-1", s.ID()) + if _, ok := sm.subchannels["battle-ou-1"]; ok { + t.Errorf("SM ChannelRemove failed to remove subchannel map") + } + + // Check for race conditions. + ech := make(chan error) + for i := 0; i < 100; i += 1 { + id := "battle-ou-" + fmt.Sprint(i) + go func() { + pipeErr(ech, sm.ChannelAdd(id, s.ID())) + pipeErr(ech, sm.ChannelSend(id, "")) + pipeErr(ech, sm.ChannelRemove(id, s.ID())) + }() + } + + for i := 0; i < 300; i += 1 { + err := <-ech + if err != nil { + t.Errorf("SM channel race condition in add/remove/send: %v", err) + break + } + } + + scrubSM() +} + +func Test_socketMultiplexer_SubchannelMove(t *testing.T) { + ms := NewMockSession("aaaaaaa0") + s := sockjs.Session(ms) + sp := &s + sm.SocketAdd(sp) + sm.ChannelAdd("battle-ou-1", s.ID()) + sm.SubchannelMove("battle-ou-1", "1", s.ID()) + if scid, _ := sm.subchannels["battle-ou-1"][s.ID()]; scid != "1" { + t.Errorf("SM SubchannelMove failed to move socket to new subchannel") + } + + scrubSM() +} + +func Test_socketMultiplexer_SubchannelSend(t *testing.T) { + ms := NewMockSession("aaaaaaa0") + s := sockjs.Session(ms) + sp := &s + sm.SocketAdd(sp) + sm.ChannelAdd("battle-ou-1", s.ID()) + if err := sm.SubchannelSend("battle-ou-1", ""); err != nil { + t.Errorf("SM SubchannelSend failed to send to subchannel: %v", err) + } + + // Check for race conditions. + ech := make(chan error) + for i := 0; i < 100; i += 1 { + id := "battle-ou-" + fmt.Sprint(i) + go func() { + fmt.Print("adding " + s.ID()) + pipeErr(ech, sm.ChannelAdd(id, s.ID())) + fmt.Print("sending " + s.ID()) + pipeErr(ech, sm.SubchannelSend(id, "\n|split\n|c|@Morfent|sup0\n|c|@Morfent|sup1\n|c|@Morfent|sup2\n")) + fmt.Print("removing " + s.ID()) + pipeErr(ech, sm.ChannelRemove(id, s.ID())) + }() + } + + for i := 0; i < 300; i += 1 { + err := <-ech + if err != nil { + t.Errorf("SM channel race condition in add/remove/send: %v", err) + break + } + } + + scrubSM() +} + +func init() { + production = false +} diff --git a/users.js b/users.js index c768b3d3cfe29..f1deed864ac21 100644 --- a/users.js +++ b/users.js @@ -1471,7 +1471,8 @@ Users.socketConnect = function (worker, workerid, socketid, ip, protocol) { let banned = Punishments.checkIpBanned(connection); if (banned) { - return connection.destroy(); + setImmediate(() => connection.destroy()); + return; } // Emergency mode connections logging if (Config.emergency) {