Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions mongomock/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, host=None, port=None, document_class=dict,
else:
self.host, self.port = split_hosts(self.host, default_port=self.port)[0]

self.__default_datebase_name = dbase
self.__default_database_name = dbase

def __getitem__(self, db_name):
return self.get_database(db_name)
Expand Down Expand Up @@ -130,7 +130,11 @@ def drop_collections_for_db(_db):
def get_database(self, name=None, codec_options=None, read_preference=None,
write_concern=None):
if name is None:
db = self.get_default_database()
db = self.get_default_database(
codec_options=codec_options,
read_preference=read_preference,
write_concern=write_concern,
)
else:
db = self._database_accesses.get(name)
if db is None:
Expand All @@ -140,11 +144,13 @@ def get_database(self, name=None, codec_options=None, read_preference=None,
codec_options=codec_options or self._codec_options, _store=db_store)
return db

def get_default_database(self):
if self.__default_datebase_name is None:
raise ConfigurationError('No default database defined')
def get_default_database(self, default=None, **kwargs):
name = self.__default_database_name
name = name if name is not None else default
if name is None:
raise ConfigurationError('No default database name defined or provided.')

return self[self.__default_datebase_name]
return self.get_database(name=name, **kwargs)

def alive(self):
"""The original MongoConnection.alive method checks the status of the server.
Expand Down
20 changes: 20 additions & 0 deletions tests/test__mongomock.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from bson.objectid import ObjectId
import pymongo
from pymongo import MongoClient as PymongoClient
from pymongo.read_preferences import ReadPreference
_HAVE_PYMONGO = True
_PYMONGO_VERSION = version.LooseVersion(pymongo.version)
except ImportError:
Expand Down Expand Up @@ -216,6 +217,25 @@ def client(uri):
with self.assertRaises(ConfigurationError):
c.get_default_database()

def test__getting_default_database_with_default_parameter(self):
    """Check that ``default`` names the database when the URI has none.

    The client URI 'mongodb://host1/' carries no database name, so the
    value given as ``default`` (positionally or by keyword) must select
    the very same object that ``client['foo']`` returns.
    """
    client = mongomock.MongoClient('mongodb://host1/')

    # Positional form of the fallback name.
    positional = client.get_default_database('foo')
    self.assertIs(positional, client['foo'])

    # Keyword form of the fallback name.
    by_keyword = client.get_default_database(default='foo')
    self.assertIs(by_keyword, client['foo'])

def test__getting_default_database_ignoring_default_parameter(self):
    """Check that a database named in the URI wins over ``default``.

    The client URI 'mongodb://host1/bar' already names 'bar', so passing
    'foo' as ``default`` (positionally or by keyword) must still yield
    the very same object that ``client['bar']`` returns.
    """
    client = mongomock.MongoClient('mongodb://host1/bar')

    # Positional form: the supplied fallback must be ignored.
    positional = client.get_default_database('foo')
    self.assertIs(positional, client['bar'])

    # Keyword form: same expectation.
    by_keyword = client.get_default_database(default='foo')
    self.assertIs(by_keyword, client['bar'])

@skipIf(not _HAVE_PYMONGO, 'pymongo not installed')
def test__getting_default_database_preserves_options(self):
    """Check that options given to a nameless ``get_database()`` are kept.

    Calling ``get_database`` without a name on a client whose URI names
    'foo' must return the 'foo' database carrying the requested read
    preference, while the client's own read preference stays PRIMARY.
    """
    mock_client = mongomock.MongoClient('mongodb://host1/foo')
    default_db = mock_client.get_database(
        read_preference=ReadPreference.NEAREST)

    self.assertEqual(default_db.name, 'foo')
    # The option must land on the returned database...
    self.assertEqual(ReadPreference.NEAREST, default_db.read_preference)
    # ...without leaking back onto the client.
    self.assertEqual(ReadPreference.PRIMARY, mock_client.read_preference)


class UTCPlus2(datetime.tzinfo):
def fromutc(self, dt):
Expand Down