Skip to content

Commit

Permalink
Support binds on abstract models (#373)
Browse files Browse the repository at this point in the history
  • Loading branch information
singingwolfboy authored and davidism committed Apr 9, 2016
1 parent 095f19a commit 88892a8
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
4 changes: 2 additions & 2 deletions flask_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,9 +599,9 @@ def _join(match):
return DeclarativeMeta.__new__(cls, name, bases, d)

def __init__(self, name, bases, d):
bind_key = d.pop('__bind_key__', None)
bind_key = d.pop('__bind_key__', None) or getattr(self, '__bind_key__', None)
DeclarativeMeta.__init__(self, name, bases, d)
if bind_key is not None:
if bind_key is not None and hasattr(self, '__table__'):
self.__table__.info['bind_key'] = bind_key


Expand Down
42 changes: 40 additions & 2 deletions test_sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import with_statement

import atexit
import tempfile
import os
import unittest
from datetime import datetime
import flask
Expand Down Expand Up @@ -384,12 +386,10 @@ def index():
class BindsTestCase(unittest.TestCase):

def test_basic_binds(self):
import tempfile
_, db1 = tempfile.mkstemp()
_, db2 = tempfile.mkstemp()

def _remove_files():
import os
try:
os.remove(db1)
os.remove(db2)
Expand Down Expand Up @@ -456,6 +456,44 @@ class Baz(db.Model):
Baz.__table__: db.get_engine(app, None)
})

def test_abstract_binds(self):
_, db1 = tempfile.mkstemp()
_, db2 = tempfile.mkstemp()

def _remove_files():
try:
os.remove(db1)
os.remove(db2)
except IOError:
pass
atexit.register(_remove_files)

app = flask.Flask(__name__)
app.config['SQLALCHEMY_ENGINE'] = 'sqlite://'
app.config['SQLALCHEMY_BINDS'] = {
'foo': 'sqlite:///' + db1,
'bar': 'sqlite:///' + db2
}
db = sqlalchemy.SQLAlchemy(app)

class AbstractFooBoundModel(db.Model):
__abstract__ = True
__bind_key__ = 'foo'

class FooBoundModel(AbstractFooBoundModel):
id = db.Column(db.Integer, primary_key=True)

db.create_all()

# does the model have the correct engines?
self.assertEqual(db.metadata.tables['foo_bound_model'].info['bind_key'], 'foo')

# see the tables created in an engine
metadata = db.MetaData()
metadata.reflect(bind=db.get_engine(app, 'foo'))
self.assertEqual(len(metadata.tables), 1)
self.assertTrue('foo_bound_model' in metadata.tables)


class DefaultQueryClassTestCase(unittest.TestCase):

Expand Down

0 comments on commit 88892a8

Please sign in to comment.