#!/usr/bin/python  
 | 
# -*- coding: GBK -*-  
 | 
#-------------------------------------------------------------------------------  
 | 
#  
 | 
#-------------------------------------------------------------------------------  
 | 
#  
 | 
  
 | 
import pymongo  
 | 
from pymongo.son_manipulator import SONManipulator  
 | 
import base64  
 | 
from functools import wraps  
 | 
from time import (sleep)  
 | 
  
 | 
import CommFunc  
 | 
  
 | 
# Emulate a SQL IDENTITY (auto-increment) column.
 | 
def seq(db, collectionName, fieldName, feed, increment):
    """Emulate a SQL IDENTITY column with a MongoDB counter document.

    Counters live in the collection '<collectionName>_seq', one document per
    counter whose ``_id`` is fieldName and whose current value is in 'seq'.

    Args:
        db: pymongo database that holds the counter collection.
        collectionName: base name; the counter collection is '<name>_seq'.
        fieldName: counter identifier (the counter document's ``_id``).
        feed: value stored and returned when the counter does not exist yet.
        increment: amount added to an existing counter.

    Returns:
        (True, value) with the new counter value, or (False, None) on any
        failure.
    """
    query = {'_id': fieldName}
    bump = {'$inc': {'seq': increment}}
    try:
        collection = db['%s_seq' % collectionName]
        resultObj = collection.find_and_modify(query=query, update=bump, new=True)
        if resultObj:
            return True, resultObj['seq']
        # First use: create the counter with a plain insert. The unique _id
        # makes creation atomic, whereas the former `$set` + upsert let two
        # concurrent first callers both receive `feed` (and could reset a
        # counter that another client had just incremented).
        try:
            collection.insert({'_id': fieldName, 'seq': feed}, safe=True)
            return True, feed
        except pymongo.errors.DuplicateKeyError:
            # Another client created the counter first; take the next value.
            resultObj = collection.find_and_modify(query=query, update=bump, new=True)
            if resultObj:
                return True, resultObj['seq']
            return False, None
    except Exception:
        # Best-effort contract: any failure is reported as (False, None).
        return False, None
 | 
  
 | 
class ObjectIdRemover(SONManipulator):
    """Manipulator that hides MongoDB's '_id' field from documents read back."""

    def transform_outgoing(self, son, collection):
        # pop() with a default silently ignores documents that have no '_id'.
        son.pop('_id', None)
        return son
 | 
  
 | 
class EncodeStringManipulator(SONManipulator):
    """SON manipulator that converts string values with CommFunc helpers.

    transform_incoming passes every string through
    CommFunc.EncodingToUnicode(self.encoding, ...); transform_outgoing passes
    every string through CommFunc.UnicodeToEncoding(self.encoding, ...).
    Dict keys are left untouched; dicts and lists are walked recursively.
    """

    def __init__(self, encoding):
        # Codec name handed to the CommFunc conversion helpers.
        self.encoding = encoding

    @staticmethod
    def _walk(son, convert):
        """Apply convert() to every string value inside son.

        Dicts (including a top-level one) are updated in place and returned;
        a top-level list is updated in place and returned; nested lists are
        rebuilt. Any other top-level value is returned unchanged.
        """
        def visit(value):
            if isinstance(value, dict):
                for key, item in value.items():
                    value[key] = visit(item)
                return value
            if isinstance(value, list):
                return [visit(item) for item in value]
            if isinstance(value, basestring):
                return convert(value)
            return value

        if isinstance(son, dict):
            return visit(son)
        if isinstance(son, list):
            # Route each item through visit(): the previous code assumed every
            # top-level item was a dict and crashed on anything else.
            for index, item in enumerate(son):
                son[index] = visit(item)
            return son
        return son

    def transform_incoming(self, son, collection):
        """Convert strings of a document being sent to MongoDB to unicode."""
        def toUnicode(text):
            # The success flag is ignored, as before.
            # NOTE(review): assumes CommFunc returns a usable value on failure.
            ok, converted = CommFunc.EncodingToUnicode(self.encoding, text)
            return converted
        return self._walk(son, toUnicode)

    def transform_outgoing(self, son, collection):
        """Convert strings of a document read from MongoDB to self.encoding."""
        def fromUnicode(text):
            # The success flag is ignored, as before.
            ok, converted = CommFunc.UnicodeToEncoding(self.encoding, text)
            return converted
        return self._walk(son, fromUnicode)
 | 
      
 | 
class Base64StringManipulator(SONManipulator):
    """SON manipulator that base64-encodes string values on the way in and
    decodes them on the way out. Dict keys are left untouched; dicts and
    lists are walked recursively.

    NOTE(review): under Python 2, base64 on a unicode value with non-ASCII
    characters will likely raise UnicodeEncodeError -- confirm callers only
    pass byte strings.
    """

    @staticmethod
    def _walk(son, convert):
        """Apply convert() to every string value inside son.

        Dicts (including a top-level one) are updated in place and returned;
        a top-level list is updated in place and returned; nested lists are
        rebuilt. Any other top-level value is returned unchanged.
        """
        def visit(value):
            if isinstance(value, dict):
                for key, item in value.items():
                    value[key] = visit(item)
                return value
            if isinstance(value, list):
                return [visit(item) for item in value]
            if isinstance(value, basestring):
                return convert(value)
            return value

        if isinstance(son, dict):
            return visit(son)
        if isinstance(son, list):
            # Route each item through visit(): the previous code assumed every
            # top-level item was a dict and crashed on anything else.
            for index, item in enumerate(son):
                son[index] = visit(item)
            return son
        return son

    def transform_incoming(self, son, collection):
        """Base64-encode every string value of son."""
        return self._walk(son, base64.b64encode)

    def transform_outgoing(self, son, collection):
        """Base64-decode every string value of son."""
        return self._walk(son, base64.b64decode)
 | 
        
 | 
# Decorator for DBController's database operations:
# automatically retries the call when the connection drops.
 | 
def reconnect_decorator(func):
    """Retry func whenever pymongo raises AutoReconnect.

    The call is attempted once plus up to MAX_RECONNECT retries, pausing
    RECONNECT_INTERVAL seconds before each retry. If every attempt fails, the
    last AutoReconnect is re-raised with its original traceback. Any other
    exception propagates immediately.
    """
    MAX_RECONNECT = 10
    RECONNECT_INTERVAL = 0.1

    @wraps(func)
    def wrapper(*args, **kwds):
        failCnt = 0
        while True:
            try:
                # args still includes `self` when decorating a method.
                return func(*args, **kwds)
            except pymongo.errors.AutoReconnect:
                failCnt += 1
                # Check before sleeping so the final failure is reported
                # without a pointless delay.
                if failCnt > MAX_RECONNECT:
                    # A bare raise keeps the original traceback
                    # (Python 2's `raise e` would restart it here).
                    raise
                sleep(RECONNECT_INTERVAL)

    return wrapper
 | 
  
 | 
class DBController:
    """Thin wrapper around a pymongo connection to a single database.

    Data operations never raise for database errors: they return a success
    flag (paired with a result for find/find_one/count) and store the
    exception in ``self.lastError``. Database calls that raise
    pymongo.errors.AutoReconnect are retried through reconnect_decorator
    before the operation gives up.
    """

    def __init__(self, host, port, dbName, user, pwd, encoding=None):
        """Store connection settings and try to connect immediately.

        Args:
            host, port: passed to pymongo.Connection.
            dbName: database that all operations act on.
            user, pwd: credentials, authenticated against the 'admin' db.
            encoding: currently unused (string translation is disabled);
                kept for interface compatibility. Defaults to None.
        The result of the initial connection attempt is not returned; check
        ``connected`` / ``lastError``.
        """
        self.host = host
        self.port = port
        self.dbName = dbName
        self.user = user
        self.pwd = pwd

        self.connected = False
        self.con = None
        self.db = None
        self.lastError = None
        # String translation is disabled. To enable it, pick a manipulator
        # from `encoding`: Base64StringManipulator() for 'base64', otherwise
        # EncodeStringManipulator(encoding).
        self.translator = None
        self.initialize()

    def initialize(self):
        """Connect and authenticate unless already connected.

        Returns True if already connected, False if connecting failed,
        otherwise the result of doAuthentication().
        """
        if self.connected:
            return True
        if not self.doConnect(self.host, self.port):
            return False
        authResult = self.doAuthentication(self.dbName, self.user, self.pwd)
        if self.db is not None:
            # Strip '_id' from every document read back through this db.
            self.db.add_son_manipulator(ObjectIdRemover())
        return authResult

    def doConnect(self, ip, port):
        """Open the pymongo connection; return True on success.

        TypeError (bad arguments) propagates; other errors are stored in
        lastError and reported as False.
        """
        try:
            self.con = pymongo.Connection(ip, port)
        except TypeError:
            raise
        except Exception as e:  # includes pymongo.errors.ConnectionFailure
            self.lastError = e
            return False
        self.connected = True
        return True

    def doAuthentication(self, dbName, user, pwd):
        """Select dbName and authenticate user/pwd against the 'admin' db.

        Returns the result of authenticate(), or False on error/not connected.
        """
        if not self.connected or self.con is None:
            self.lastError = 'Not connected yet!'
            return False
        self.db = self.con[dbName]
        authDB = self.con['admin']
        try:
            return authDB.authenticate(user, pwd)
        except Exception as e:
            self.lastError = e
            return False

    def _ensureConnected(self):
        """Return a true value if connected, calling initialize() if needed."""
        return self.connected or self.initialize()

    def _execute(self, failValue, operation):
        """Run the zero-argument callable `operation`, retrying AutoReconnect.

        Only `operation` is retried, not the arguments' translation, which
        mutates its input in place and must not be applied twice. Returns
        operation()'s result, or stores the exception in lastError and
        returns `failValue` when it raises, including after
        reconnect_decorator has exhausted its retries.

        NOTE: a write that raised AutoReconnect may already have been applied
        on the server, so a retried write (e.g. an $inc update) can take
        effect twice.
        """
        try:
            return reconnect_decorator(operation)()
        except Exception as e:
            self.lastError = e
            return failValue

    def find_one(self, colName, spec, filter=None):
        """Return (True, doc) for the first match (doc is None when nothing
        matches), or (False, None) on failure."""
        ok, recList = self.find(colName, spec, filter, 1)
        if not ok:
            return False, None
        return True, (recList[0] if recList else None)

    def find(self, colName, spec=None, filter=None, maxCnt=0, sortBy=None):
        """Return (ok, docs) for documents in colName matching spec.

        filter, maxCnt and sortBy are passed to Collection.find as the field
        selection, `limit` and `sort`. On failure returns (False, []).
        """
        if not self._ensureConnected():
            return False, []
        col = self.db[colName]
        if self.translator:
            spec = self.translator.transform_incoming(spec, None)

        def run():
            docs = list(col.find(spec, filter, limit=maxCnt, sort=sortBy))
            if self.translator:
                docs = self.translator.transform_outgoing(docs, None)
            return True, docs

        return self._execute((False, []), run)

    def insert(self, colName, doc_or_docs, isSafe=True):
        """Insert one document or a list of documents; return True on success."""
        if not self._ensureConnected():
            return False
        col = self.db[colName]
        if self.translator:
            doc_or_docs = self.translator.transform_incoming(doc_or_docs, None)

        def run():
            col.insert(doc_or_docs, safe=isSafe)
            return True

        return self._execute(False, run)

    def update(self, colName, spec, doc, isUpsert=False, isSafe=True, isMulti=False):
        """Apply doc to documents matching spec; return True on success."""
        if not self._ensureConnected():
            return False
        col = self.db[colName]
        # Translate spec and doc here rather than enabling collection.update's
        # `manipulate` option, which would apply every registered manipulator.
        if self.translator:
            spec = self.translator.transform_incoming(spec, None)
            doc = self.translator.transform_incoming(doc, None)

        def run():
            col.update(spec, doc, upsert=isUpsert, safe=isSafe, multi=isMulti)
            return True

        return self._execute(False, run)

    def save(self, colName, doc, isSafe=True):
        """Save (insert or replace by _id) doc; return True on success."""
        if not self._ensureConnected():
            return False
        col = self.db[colName]
        if self.translator:
            doc = self.translator.transform_incoming(doc, None)

        def run():
            col.save(doc, safe=isSafe)
            return True

        return self._execute(False, run)

    def remove(self, colName, spec=None, isSafe=True):
        """Remove documents matching spec (all when None); True on success."""
        if not self._ensureConnected():
            return False
        col = self.db[colName]
        if self.translator:
            spec = self.translator.transform_incoming(spec, None)

        def run():
            col.remove(spec, safe=isSafe)
            return True

        return self._execute(False, run)

    def drop(self, colName):
        """Drop the collection colName; return True on success."""
        if not self._ensureConnected():
            return False
        col = self.db[colName]

        def run():
            col.drop()
            return True

        return self._execute(False, run)

    def count(self, colName):
        """Return (True, number of documents) or (False, 0) on failure."""
        if not self._ensureConnected():
            return False, 0
        col = self.db[colName]

        def run():
            return True, col.count()

        return self._execute((False, 0), run)
 | 
        
 | 
def test_seq():
    """Smoke-test seq() on a local MongoDB (admin credentials sa/sa).

    Starts from dropped collections, so the first call must yield the feed
    value 1 and the second call 2.
    """
    connection = pymongo.Connection()
    if not connection.admin.authenticate('sa', 'sa'):
        print('auth failed!')
        return
    colName = 'tagSeqTest'
    fieldName = 'ID'
    testDb = connection['test']
    for name in (colName, '%s_seq' % colName):
        testDb.drop_collection(name)

    for expected in (1, 2):
        ok, value = seq(testDb, colName, fieldName, 1, 1)
        assert ok and value == expected
 | 
            
 | 
def test_StringManipulator():
    """Check Base64StringManipulator round-trips nested documents.

    The manipulator mutates its argument in place and returns it, so every
    expectation is a fresh literal. Comparing the result with the input
    object (as before) compared an object with itself and could never fail.
    """
    translator = Base64StringManipulator()

    result = translator.transform_incoming([], None)
    assert result == []
    result = translator.transform_outgoing(result, None)
    assert result == []

    result = translator.transform_incoming([{'a': 1}], None)
    assert result == [{'a': 1}]
    result = translator.transform_outgoing(result, None)
    assert result == [{'a': 1}]

    result = translator.transform_incoming([{'a': 'a'}], None)
    assert result == [{'a': base64.b64encode('a')}]
    result = translator.transform_outgoing(result, None)
    assert result == [{'a': 'a'}]

    result = translator.transform_incoming([{'a': [{'b': 'b'}, {'c': 'c'}]}], None)
    assert result == [{'a': [{'b': base64.b64encode('b')},
                             {'c': base64.b64encode('c')}]}]
    result = translator.transform_outgoing(result, None)
    assert result == [{'a': [{'b': 'b'}, {'c': 'c'}]}]
 | 
      
 | 
def test_DBController():
    """Round-trip drop/count/insert/find/update/remove/save via DBController.

    Needs MongoDB on localhost:27017 where user 'test' / password '1'
    authenticates against the admin database.
    """
    def stripId(rec):
        # pymongo's insert()/save() add '_id' to the passed document in place,
        # while the ObjectIdRemover manipulator strips it from documents read
        # back, so compare documents without it.
        return dict((key, value) for key, value in rec.items() if key != '_id')

    testColName = 'tagTestController'
    # encoding is passed explicitly: the original constructor requires it.
    dbController = DBController('localhost', 27017, 'test', 'test', '1', None)
    assert dbController.drop(testColName)

    result, cnt = dbController.count(testColName)
    assert result and cnt == 0

    doc = {'a': 1}
    assert dbController.insert(testColName, doc)
    result, recs = dbController.find(testColName)
    assert result and len(recs) == 1
    assert stripId(recs[0]) == stripId(doc)

    spec = {'a': 1}
    updateDoc = {'a': 2}
    assert dbController.update(testColName, spec, {'$set': updateDoc})
    result, recs = dbController.find(testColName)
    assert result and len(recs) == 1
    assert stripId(recs[0]) == updateDoc

    assert dbController.remove(testColName)
    result, recs = dbController.find(testColName)
    assert result and recs == []

    saveDoc = {'b': 3}
    assert dbController.save(testColName, saveDoc)
    result, recs = dbController.find(testColName)
    assert result and len(recs) == 1
    assert stripId(recs[0]) == stripId(saveDoc)
 | 
      
 | 
def test():
    """Run every self-test in this module, then report success."""
    for case in (test_seq, test_StringManipulator, test_DBController):
        case()
    print('test ok!')


if __name__ == '__main__':
    test()
 | 
           
 |