Code

Ticket #5062: pymssql.diff

File pymssql.diff, 3.6 KB (added by mamcx, 7 years ago)

Fixed pymssql. Install from site and then apply this

Line 
1Index: pymssql.py
2===================================================================
3--- pymssql.py  (revision 190)
4+++ pymssql.py  (working copy)
5@@ -133,23 +133,30 @@
6 
7                # first try to execute all queries
8                totrows = 0
9-               sql = ""
10-               try:
11-                       for params in param_seq:
12-                               if params != None:
13-                                       sql = _quoteparams(operation, params)
14+               #import pdb
15+               #pdb.set_trace()
16+               #Respect GO terminator
17+               for sql in operation.split('\nGO'):
18+                       if sql=='':
19+                               continue
20+                       try:
21+                               for params in param_seq:
22+                                       if params != None:
23+                                               sql = _quoteparams(sql, params)
24+
25+                                       #print sql
26+                                       ret = self.__source.query(sql)
27+                                       if ret == 1:
28+                                               self._result = self.__source.fetch_array()
29+                                               totrows = totrows + self._result[self.__resultpos][1]
30+                                       else:
31+                                           self._result = None
32+                                           raise DatabaseError, "error: %s" % self.__source.errmsg()
33+                       except Exception,e:
34+                               if self.__source.errmsg() == None:
35+                                       raise e
36                                else:
37-                                       sql = operation
38-                               #print sql
39-                               ret = self.__source.query(sql)
40-                               if ret == 1:
41-                                       self._result = self.__source.fetch_array()
42-                                       totrows = totrows + self._result[self.__resultpos][1]
43-                               else:
44-                                   self._result = None
45-                                   raise DatabaseError, "error: %s" % self.__source.errmsg()
46-               except:
47-                       raise DatabaseError, "internal error: %s" % self.__source.errmsg()
48+                                       raise DatabaseError, "internal error: %s" % self.__source.errmsg()
49 
50                # then initialize result raw count and description
51                if len(self._result[self.__resultpos][0]) > 0:
52@@ -220,6 +227,8 @@
53        # alternative quoting by Luciano Pacheco <lucmult@gmail.com>
54        #elif hasattr(x, 'timetuple'):
55        #       x = time.strftime('\'%Y%m%d %H:%M:%S\'', x.timetuple())
56+       elif type(x) == types.BooleanType:
57+               x = x and 1 or 0
58        else:
59                #print "didn't like " + x + " " + str(type(x))
60                raise InterfaceError, 'do not know how to handle type %s' % type(x)
61@@ -244,8 +253,9 @@
62 
63        def __init__(self, cnx):
64                self.__cnx = cnx
65+               self.__autocommit = False
66                try:
67-                       self.__cnx.query("begin tran")
68+                       self.__cnx.query("IF @@TRANCOUNT>0 begin tran")
69                        self.__cnx.fetch_array()
70                except:
71                        raise OperationalError, "invalid connection."
72@@ -259,10 +269,14 @@
73        def commit(self):
74                if self.__cnx == None:
75                        raise OperationalError, "invalid connection."
76+
77+               if self.__autocommit == True:
78+                       return
79+
80                try:
81-                       self.__cnx.query("commit tran")
82+                       self.__cnx.query("IF @@TRANCOUNT>0 commit tran")
83                        self.__cnx.fetch_array()
84-                       self.__cnx.query("begin tran")
85+                       self.__cnx.query("IF @@TRANCOUNT>0 begin tran")
86                        self.__cnx.fetch_array()
87                except:
88                        raise OperationalError, "can't commit."
89@@ -270,14 +284,30 @@
90        def rollback(self):
91                if self.__cnx == None:
92                        raise OperationalError, "invalid connection."
93+
94+               if self.__autocommit == True:
95+                       return
96+
97                try:
98-                       self.__cnx.query("rollback tran")
99+                       self.__cnx.query("IF @@TRANCOUNT>0 rollback tran")
100                        self.__cnx.fetch_array()
101-                       self.__cnx.query("begin tran")
102+                       self.__cnx.query("IF @@TRANCOUNT>0 begin tran")
103                        self.__cnx.fetch_array()
104                except:
105                        raise OperationalError, "can't rollback."
106 
107+       def autocommit(self,status):
108+               if status:
109+                       if self.__autocommit == False:
110+                               self.__cnx.query("IF @@TRANCOUNT>0 rollback tran")
111+                               self.__cnx.fetch_array()
112+                               self.__autocommit = True
113+                       else:
114+                               if self.__autocommit == True:
115+                                       self.__cnx.query("IF @@TRANCOUNT>0 begin tran")
116+                                       self.__cnx.fetch_array()
117+                                       self.__autocommit = False
118+
119        def cursor(self):
120                if self.__cnx == None:
121                        raise OperationalError, "invalid connection."