# The COPYRIGHT file at the top level of this repository contains the full
# copyright notices and license terms.
from trytond.model import fields
from trytond.modules.edw.connector import BackendConnector
from trytond.modules.edw.tools import evaluate_sql_domain
import pymssql

# SQL templates, filled in with the ``%`` operator and a mapping.
# Table and column names are interpolated verbatim (they are not escaped),
# so callers must only pass trusted identifiers.
INSERT_QUERY = (
    """INSERT INTO [%(tablename)s] """ +
    """(%(fields)s) """ +
    """VALUES (%(values)s);""")

TRUNCATE_QUERY = (
    """TRUNCATE TABLE [%(tablename)s]; """)

CREATE_QUERY = (
    """IF NOT EXISTS (select * from sysobjects """ +
    """ where name='%(tablename)s' and xtype='U') """ +
    """ CREATE TABLE [%(tablename)s] (%(fields)s);""")

DROP_QUERY = (
    """IF EXISTS (select * from sysobjects """ +
    """ where name='%(tablename)s' and xtype='U') """ +
    """ DROP TABLE [%(tablename)s]; """)

# NOTE(review): unlike the other templates this one delimits the table name
# with double quotes instead of brackets; SQL Server only reads that as an
# identifier when QUOTED_IDENTIFIER is ON -- confirm the connection sets it.
DELETE_QUERY = (
    """ DELETE FROM \"%(tablename)s\" WHERE %(where)s; """)


class BackendConnectorMSsql(BackendConnector):
    """MS SQL Connector"""

    def connect(self):
        """Open and return a new pymssql connection.

        ``self.uri`` is expected to look like ``<scheme>//<server>/<database>``:
        everything after the first ``//`` is split on ``/`` and the first two
        parts are used as server and database. ``self.username`` and
        ``self.password`` are passed through as credentials.

        Raises IndexError when the URI has no ``//`` or no database part.
        """
        location = self.uri.split("//")[1]
        parts = location.split("/")
        server = parts[0]
        database = parts[1]
        return pymssql.connect(
            server=server,
            user=self.username,
            password=self.password,
            database=database)

    def create(self, fields, tablename):
        """Create ``tablename`` if it does not exist yet.

        ``fields`` maps column names to their SQL type declaration.
        """
        self.execute_query(self._get_create_query(fields, tablename))

    def fill(self, results, tablename):
        """Insert every row of ``results`` into ``tablename``.

        ``results`` is a sequence of mappings (column name -> value). The
        column list is taken from the first row; every row's values are
        extracted in that same column order so that rows whose keys are
        ordered differently still land in the right columns.

        Does nothing when ``results`` is empty.
        Raises KeyError when a row lacks one of the first row's columns.
        """
        if not results:
            # Nothing to insert; building the query would fail on results[0].
            return
        columns = list(results[0].keys())
        query = self._get_insert_query(tablename, results)
        rows = [tuple(row[column] for column in columns) for row in results]
        self.execute_query(query, rows)

    def clean(self, tablename, domain=None):
        """Remove rows from ``tablename``.

        Without ``domain`` the table is truncated; otherwise ``domain`` is
        used verbatim as the WHERE clause of a DELETE statement.
        """
        if domain:
            query = self._get_delete_query(tablename, domain)
        else:
            query = self._get_truncate_query(tablename)
        self.execute_query(query)

    def drop(self, tablename):
        """Drop ``tablename`` if it exists."""
        self.execute_query(self._get_drop_query(tablename))

    def _get_create_query(self, fields, tablename):
        """Return the CREATE TABLE statement for ``tablename``.

        ``fields`` maps column names to SQL type declarations; each entry
        becomes ``[name] type`` in the column list.
        """
        columns_sql = ','.join(
            self._format_object_name(name) + ' ' + sql_type
            for name, sql_type in fields.items())
        return CREATE_QUERY % {
            'tablename': tablename,
            'fields': columns_sql,
            }

    def _get_insert_query(self, tablename, results):
        """Return a parameterized INSERT statement for ``tablename``.

        Column names come from the first row of ``results``; one ``%s``
        placeholder is emitted per column.

        Raises IndexError when ``results`` is empty.
        """
        first_row = results[0]
        columns_sql = ','.join(
            self._format_object_name(key) for key in first_row.keys())
        placeholders = ','.join(['%s'] * len(first_row))
        return INSERT_QUERY % {
            'tablename': tablename,
            'fields': columns_sql,
            'values': placeholders,
            }

    def _format_object_name(self, name):
        """Return ``name`` wrapped in square brackets (not escaped)."""
        return '[%s]' % name

    def _get_truncate_query(self, tablename):
        """Return the TRUNCATE TABLE statement for ``tablename``."""
        return TRUNCATE_QUERY % {'tablename': tablename}

    def _get_drop_query(self, tablename):
        """Return the conditional DROP TABLE statement for ``tablename``."""
        return DROP_QUERY % {'tablename': tablename}

    def _get_delete_query(self, tablename, where):
        """Return the DELETE statement for ``tablename``.

        ``where`` is inserted verbatim as the WHERE clause; a falsy value
        falls back to ``1=1`` (delete every row).
        """
        return DELETE_QUERY % {
            'tablename': tablename,
            'where': where or '1=1',
            }

    def get_mapped_types(self):
        """Return the mapping from Tryton field classes to SQL column types."""
        return {
            fields.Integer: 'int',
            fields.Many2One: 'int',
            fields.Char: 'nvarchar(max)',
            fields.Text: 'nvarchar(max)',
            fields.Selection: 'nvarchar(max)',
            fields.Date: 'date',
            fields.Numeric: 'numeric(32, 18)',
            fields.TimeDelta: 'nvarchar(32)',
            fields.Timestamp: 'datetime',
            fields.Float: 'numeric(32, 18)',
            fields.Reference: 'nvarchar(max)',
            fields.Boolean: 'bit',
            fields.DateTime: 'datetime',
            fields.Time: 'nvarchar(32)',
            fields.One2One: 'int'
            }

    def execute_query(self, query, results=None):
        """Run ``query`` on a fresh connection and commit it.

        When ``results`` is truthy it is treated as a sequence of parameter
        tuples and the query is run once per tuple via ``executemany``;
        otherwise the query is executed once without parameters.

        On any exception the transaction is rolled back and the exception is
        re-raised. The connection is closed exactly once in every case.
        """
        connection = self.connect()
        try:
            cursor = connection.cursor()
            if results:
                cursor.executemany(query, results)
            else:
                cursor.execute(query)
            connection.commit()
        except Exception:
            connection.rollback()
            # Bare raise keeps the original traceback intact.
            raise
        finally:
            connection.close()

    def evaluate_domain(self, model, domain):
        """Translate ``domain`` for ``model`` via ``evaluate_sql_domain``."""
        return evaluate_sql_domain(model, domain)