Coverage for providers/src/airflow/providers/teradata/hooks/teradata.py: 87%

106 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-12-27 08:27 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18"""An Airflow Hook for interacting with Teradata SQL Server.""" 

19 

20from __future__ import annotations 

21 

22import re 

23from typing import TYPE_CHECKING, Any 

24 

25import sqlalchemy 

26import teradatasql 

27from teradatasql import TeradataConnection 

28 

29from airflow.providers.common.sql.hooks.sql import DbApiHook 

30 

31if TYPE_CHECKING: 

32 from airflow.models.connection import Connection 

33 

34PARAM_TYPES = {bool, float, int, str} 

35 

36 

37def _map_param(value): 

38 if value in PARAM_TYPES: 38 ↛ 42line 38 didn't jump to line 42 because the condition on line 38 was never true

39 # In this branch, value is a Python type; calling it produces 

40 # an instance of the type which is understood by the Teradata driver 

41 # in the out parameter mapping mechanism. 

42 value = value() 

43 return value 

44 

45 

46def _handle_user_query_band_text(query_band_text) -> str: 

47 """Validate given query_band and append if required values missed in query_band.""" 

48 # Ensures 'appname=airflow' and 'org=teradata-internal-telem' are in query_band_text. 

49 if query_band_text is not None: 

50 # checking org doesn't exist in query_band, appending 'org=teradata-internal-telem' 

51 # If it exists, user might have set some value of their own, so doing nothing in that case 

52 pattern = r"org\s*=\s*([^;]*)" 

53 match = re.search(pattern, query_band_text) 

54 if not match: 

55 if not query_band_text.endswith(";"): 

56 query_band_text += ";" 

57 query_band_text += "org=teradata-internal-telem;" 

58 # Making sure appname in query_band contains 'airflow' 

59 pattern = r"appname\s*=\s*([^;]*)" 

60 # Search for the pattern in the query_band_text 

61 match = re.search(pattern, query_band_text) 

62 if match: 

63 appname_value = match.group(1).strip() 

64 # if appname exists and airflow not exists in appname then appending 'airflow' to existing 

65 # appname value 

66 if "airflow" not in appname_value.lower(): 

67 new_appname_value = appname_value + "_airflow" 

68 # Optionally, you can replace the original value in the query_band_text 

69 updated_query_band_text = re.sub(pattern, f"appname={new_appname_value}", query_band_text) 

70 query_band_text = updated_query_band_text 

71 else: 

72 # if appname doesn't exist in query_band, adding 'appname=airflow' 

73 if len(query_band_text.strip()) > 0 and not query_band_text.endswith(";"): 73 ↛ 74line 73 didn't jump to line 74 because the condition on line 73 was never true

74 query_band_text += ";" 

75 query_band_text += "appname=airflow;" 

76 else: 

77 query_band_text = "org=teradata-internal-telem;appname=airflow;" 

78 

79 return query_band_text 

80 

81 

class TeradataHook(DbApiHook):
    """
    General hook for interacting with Teradata SQL Database.

    This module contains basic APIs to connect to and interact with Teradata SQL Database. It uses teradatasql
    client internally as a database driver for connecting to Teradata database. The config parameters like
    Teradata DB Server URL, username, password and database name are fetched from the predefined connection
    config connection_id. It raises an airflow error if the given connection id doesn't exist.

    You can also specify ssl parameters in the extra field of your connection
    as ``{"sslmode": "require", "sslcert": "/path/to/cert.pem", etc}``.

    .. seealso::
        - :ref:`Teradata API connection <howto/connection:teradata>`

    :param args: passed to DbApiHook
    :param database: The Teradata database to connect to.
    :param kwargs: passed to DbApiHook
    """

    # Override to provide the connection name.
    conn_name_attr = "teradata_conn_id"

    # Override to have a default connection id for a particular dbHook
    default_conn_name = "teradata_default"

    # Override if this db supports autocommit.
    supports_autocommit = True

    # Override if this db supports executemany.
    supports_executemany = True

    # Override this for hook to have a custom name in the UI selection
    conn_type = "teradata"

    # Override hook name to give descriptive name for hook
    hook_name = "Teradata"

    # Override with the Teradata specific placeholder parameter string used for insert queries
    placeholder: str = "?"

    # Override SQL query to be used for testing database connection
    _test_connection_sql = "select 1"

    def __init__(
        self,
        *args,
        database: str | None = None,
        **kwargs,
    ) -> None:
        # The database name is handed to DbApiHook through its ``schema`` keyword.
        super().__init__(*args, schema=database, **kwargs)

    def get_conn(self) -> TeradataConnection:
        """
        Create and return a Teradata Connection object using teradatasql client.

        Establishes connection to a Teradata SQL database using config corresponding to teradata_conn_id.

        :return: a Teradata connection object
        """
        teradata_conn_config: dict = self._get_conn_config_teradatasql()
        # The query band is removed from the connect arguments; it is applied to the session
        # with a SET QUERY_BAND statement once the connection exists.
        query_band_text = teradata_conn_config.pop("query_band", None)
        teradata_conn = teradatasql.connect(**teradata_conn_config)
        self.set_query_band(query_band_text, teradata_conn)
        return teradata_conn

    def set_query_band(self, query_band_text, teradata_conn):
        """
        Set SESSION Query Band for each connection session.

        Failures are logged and not re-raised, so a query band problem does not prevent the
        caller from using the connection.

        :param query_band_text: user supplied query band text, or ``None`` to use the default.
        :param teradata_conn: connection on which the session query band is set.
        """
        try:
            query_band_text = _handle_user_query_band_text(query_band_text)
            # The band is embedded in a single-quoted SQL string literal, so any apostrophe it
            # contains is doubled to keep the literal well formed.
            escaped_query_band = query_band_text.replace("'", "''")
            set_query_band_sql = f"SET QUERY_BAND='{escaped_query_band}' FOR SESSION"
            with teradata_conn.cursor() as cur:
                cur.execute(set_query_band_sql)
        except Exception as ex:
            self.log.error("Error occurred while setting session query band: %s ", str(ex))

    def _get_conn_config_teradatasql(self) -> dict[str, Any]:
        """
        Return set of config params required for connecting to Teradata DB using teradatasql client.

        Host, port, database, user and password come from the Airflow connection (with fallback
        defaults); ``tmode``, the SSL options and ``query_band`` are copied from the connection's
        extra field when they are present and truthy.
        """
        conn: Connection = self.get_connection(self.get_conn_id())
        conn_config = {
            "host": conn.host or "localhost",
            "dbs_port": conn.port or "1025",
            "database": conn.schema or "",
            "user": conn.login or "dbc",
            "password": conn.password or "dbc",
        }

        # Read the extras once instead of going through the property for every key.
        extras = conn.extra_dejson

        if extras.get("tmode"):
            conn_config["tmode"] = extras["tmode"]

        # Handling SSL connection parameters
        if extras.get("sslmode"):
            conn_config["sslmode"] = extras["sslmode"]
            # CA settings are only copied for the "verify" style SSL modes.
            if "verify" in conn_config["sslmode"]:
                for key in ("sslca", "sslcapath"):
                    if extras.get(key):
                        conn_config[key] = extras[key]

        # Options that do not depend on sslmode being set.
        for key in ("sslcipher", "sslcrc", "sslprotocol", "query_band"):
            if extras.get(key):
                conn_config[key] = extras[key]

        return conn_config

    def get_sqlalchemy_engine(self, engine_kwargs=None):
        """
        Return a connection object using sqlalchemy.

        :param engine_kwargs: optional keyword arguments forwarded to ``sqlalchemy.create_engine``.
        :return: the SQLAlchemy engine built from the connection's login, password and host.
        """
        from urllib.parse import quote

        conn: Connection = self.get_connection(self.get_conn_id())
        # Percent-encode the credentials so characters such as "@", ":" or "/" cannot be mistaken
        # for URL delimiters. ``str()`` mirrors how the values were previously formatted.
        login = quote(str(conn.login), safe="")
        password = quote(str(conn.password), safe="")
        link = f"teradatasql://{login}:{password}@{conn.host}"
        return sqlalchemy.create_engine(link, **(engine_kwargs or {}))

    @staticmethod
    def get_ui_field_behaviour() -> dict:
        """Return custom field behaviour."""
        import json

        return {
            "hidden_fields": ["port"],
            "relabeling": {
                "host": "Database Server URL",
                "schema": "Database Name",
                "login": "Username",
            },
            "placeholders": {
                "extra": json.dumps(
                    {"tmode": "TERA", "sslmode": "verify-ca", "sslca": "/tmp/server-ca.pem"}, indent=4
                ),
                "login": "dbc",
                "password": "dbc",
            },
        }

    def callproc(
        self,
        identifier: str,
        autocommit: bool = False,
        parameters: list | dict | None = None,
    ) -> list | dict | tuple | None:
        """
        Call the stored procedure identified by the provided string.

        Any OUT parameters must be provided with a value of either the
        expected Python type (e.g., `int`) or an instance of that type.

        :param identifier: stored procedure name
        :param autocommit: What to set the connection's autocommit setting to
            before executing the query.
        :param parameters: The `IN`, `OUT` and `INOUT` parameters for Teradata
            stored procedure

        The return value is a list or mapping that includes parameters in
        both directions; the actual return type depends on the type of the
        provided `parameters` argument.

        """
        if parameters is None:
            parameters = []

        # One "?" placeholder per parameter.
        args = ",".join("?" for _ in parameters)
        sql = f"{{CALL {identifier}({args})}}"

        def handler(cursor):
            """Return a copy of the fetched records; raise TypeError for an unexpected result type."""
            records = cursor.fetchall()

            if records is None:
                return None
            if isinstance(records, list):
                return list(records)
            if isinstance(records, dict):
                return dict(records)
            raise TypeError(f"Unexpected results: {records}")

        # For a mapping only the values are sent, in the mapping's iteration order.
        values = parameters.values() if isinstance(parameters, dict) else parameters

        return self.run(
            sql,
            autocommit=autocommit,
            parameters=[_map_param(value) for value in values],
            handler=handler,
        )

276 return result