Coverage for providers/src/airflow/providers/teradata/hooks/teradata.py: 87%
106 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-12-27 08:27 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2024-12-27 08:27 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""An Airflow Hook for interacting with Teradata SQL Server."""
20from __future__ import annotations
22import re
23from typing import TYPE_CHECKING, Any
25import sqlalchemy
26import teradatasql
27from teradatasql import TeradataConnection
29from airflow.providers.common.sql.hooks.sql import DbApiHook
31if TYPE_CHECKING:
32 from airflow.models.connection import Connection
PARAM_TYPES = {bool, float, int, str}


def _map_param(value):
    """
    Translate a stored-procedure parameter into the value bound for the Teradata driver.

    A bare Python type from ``PARAM_TYPES`` (e.g. ``int``) denotes an OUT parameter. It is
    replaced by a default instance of that type (``int()`` -> ``0``), which the driver's
    out-parameter mapping mechanism understands. Every other value is returned unchanged.

    :param value: a parameter value, or one of the types listed in ``PARAM_TYPES``
    :return: the value to bind for this parameter
    """
    # Check ``isinstance(value, type)`` first: set membership hashes ``value``, so an
    # unhashable IN parameter (e.g. a list or dict) would otherwise raise TypeError.
    if isinstance(value, type) and value in PARAM_TYPES:
        value = value()
    return value
def _handle_user_query_band_text(query_band_text) -> str:
    """
    Validate the given query band and append the values Airflow requires.

    The returned text always carries an ``org`` pair and an ``appname`` pair whose value
    mentions ``airflow``:

    * ``None`` or blank input yields ``"org=teradata-internal-telem;appname=airflow;"``.
    * If no ``org`` pair is present, ``org=teradata-internal-telem;`` is appended; a
      user-supplied ``org`` is left untouched.
    * If an ``appname`` pair is present but its value does not contain ``airflow``
      (case-insensitive), ``_airflow`` is appended to that value; if there is no
      ``appname`` pair at all, ``appname=airflow;`` is appended.

    Pair names are only recognised at the start of the text or directly after a ``;``
    (optionally followed by whitespace), so e.g. ``myorg=x`` is not mistaken for ``org``.

    :param query_band_text: user-supplied query band (``name=value;`` pairs) or ``None``
    :return: the normalized query band text
    """
    default_query_band = "org=teradata-internal-telem;appname=airflow;"
    if query_band_text is None or not query_band_text.strip():
        # Treat blank input like None; otherwise the result would start with a stray ';'.
        return default_query_band

    # Append the telemetry org unless the user already supplied an org of their own.
    if not re.search(r"(?:^|;)\s*org\s*=", query_band_text):
        if not query_band_text.endswith(";"):
            query_band_text += ";"
        query_band_text += "org=teradata-internal-telem;"

    # Make sure appname mentions airflow. Group 1 captures the pair's leading separator
    # so it is preserved when the pair is rewritten.
    match = re.search(r"((?:^|;)\s*)appname\s*=\s*([^;]*)", query_band_text)
    if match:
        appname_value = match.group(2).strip()
        if "airflow" not in appname_value.lower():
            # Splice the new pair in literally: re.sub would interpret backslashes in the
            # user's value as replacement-template escapes (corrupting it or raising).
            new_pair = f"{match.group(1)}appname={appname_value}_airflow"
            query_band_text = query_band_text[: match.start()] + new_pair + query_band_text[match.end() :]
    else:
        # No appname pair: add the default one, separated from any preceding pair.
        if not query_band_text.endswith(";"):
            query_band_text += ";"
        query_band_text += "appname=airflow;"

    return query_band_text
class TeradataHook(DbApiHook):
    """
    General hook for interacting with Teradata SQL Database.

    This module contains basic APIs to connect to and interact with Teradata SQL Database. It uses teradatasql
    client internally as a database driver for connecting to Teradata database. The config parameters like
    Teradata DB Server URL, username, password and database name are fetched from the predefined connection
    config connection_id. It raises an airflow error if the given connection id doesn't exist.

    Additional options can be set in the extra field of your connection, e.g.
    ``{"tmode": "TERA", "sslmode": "verify-ca", "sslca": "/path/to/ca.pem"}``. The extras read by this
    hook are ``tmode``, ``sslmode``, ``sslca`` / ``sslcapath`` (only used when ``sslmode`` contains
    ``verify``), ``sslcipher``, ``sslcrc``, ``sslprotocol`` and ``query_band``.

    .. seealso::
        - :ref:`Teradata API connection <howto/connection:teradata>`

    :param args: passed to DbApiHook
    :param database: The Teradata database to connect to.
    :param kwargs: passed to DbApiHook
    """

    # Override to provide the connection name.
    conn_name_attr = "teradata_conn_id"

    # Override to have a default connection id for a particular dbHook
    default_conn_name = "teradata_default"

    # Override if this db supports autocommit.
    supports_autocommit = True

    # Override if this db supports executemany.
    supports_executemany = True

    # Override this for hook to have a custom name in the UI selection
    conn_type = "teradata"

    # Override hook name to give descriptive name for hook
    hook_name = "Teradata"

    # Override with the Teradata specific placeholder parameter string used for insert queries
    placeholder: str = "?"

    # Override SQL query to be used for testing database connection
    _test_connection_sql = "select 1"

    def __init__(
        self,
        *args,
        database: str | None = None,
        **kwargs,
    ) -> None:
        # ``database`` is forwarded to DbApiHook as its ``schema``.
        super().__init__(*args, schema=database, **kwargs)

    def get_conn(self) -> TeradataConnection:
        """
        Create and return a Teradata Connection object using teradatasql client.

        Establishes connection to a Teradata SQL database using config corresponding to teradata_conn_id,
        then sets the session query band on it.

        :return: a Teradata connection object
        """
        teradata_conn_config: dict = self._get_conn_config_teradatasql()
        # ``query_band`` is removed before connecting and applied separately via SET QUERY_BAND.
        query_band_text = teradata_conn_config.pop("query_band", None)
        teradata_conn = teradatasql.connect(**teradata_conn_config)
        self.set_query_band(query_band_text, teradata_conn)
        return teradata_conn

    def set_query_band(self, query_band_text, teradata_conn):
        """
        Set the SESSION query band on ``teradata_conn``.

        The text is normalized by ``_handle_user_query_band_text`` before it is applied.
        This is best-effort: any exception is logged and not re-raised.

        :param query_band_text: user-supplied query band text, or ``None`` for the default
        :param teradata_conn: an open Teradata connection
        """
        try:
            query_band_text = _handle_user_query_band_text(query_band_text)
            # Double embedded single quotes so the value cannot terminate the SQL string literal early.
            escaped_query_band = query_band_text.replace("'", "''")
            set_query_band_sql = f"SET QUERY_BAND='{escaped_query_band}' FOR SESSION"
            with teradata_conn.cursor() as cur:
                cur.execute(set_query_band_sql)
        except Exception as ex:
            self.log.error("Error occurred while setting session query band: %s ", str(ex))

    def _get_conn_config_teradatasql(self) -> dict[str, Any]:
        """Return set of config params required for connecting to Teradata DB using teradatasql client."""
        conn: Connection = self.get_connection(self.get_conn_id())
        # Read the parsed extras once instead of re-evaluating ``extra_dejson`` for every key.
        extra = conn.extra_dejson
        conn_config: dict[str, Any] = {
            "host": conn.host or "localhost",
            "dbs_port": conn.port or "1025",
            "database": conn.schema or "",
            "user": conn.login or "dbc",
            "password": conn.password or "dbc",
        }

        if extra.get("tmode", False):
            conn_config["tmode"] = extra["tmode"]

        # Handling SSL connection parameters; CA settings only apply to the "verify*" modes.
        if extra.get("sslmode", False):
            conn_config["sslmode"] = extra["sslmode"]
            if "verify" in conn_config["sslmode"]:
                for key in ("sslca", "sslcapath"):
                    if extra.get(key, False):
                        conn_config[key] = extra[key]

        # Remaining optional extras are copied through when set to a truthy value.
        for key in ("sslcipher", "sslcrc", "sslprotocol", "query_band"):
            if extra.get(key, False):
                conn_config[key] = extra[key]

        return conn_config

    def get_sqlalchemy_engine(self, engine_kwargs=None):
        """
        Return a SQLAlchemy engine for the ``teradatasql`` dialect.

        :param engine_kwargs: optional keyword arguments forwarded to ``sqlalchemy.create_engine``
        :return: the created SQLAlchemy engine
        """
        conn: Connection = self.get_connection(self.get_conn_id())
        # URL.create escapes special characters (e.g. "@", ":", "/") in the credentials,
        # which plain string interpolation does not, and omits unset parts instead of
        # rendering them as the literal text "None".
        # NOTE(review): port and schema are not included, matching the previous behaviour;
        # confirm how the teradatasql dialect expects them before adding.
        link = sqlalchemy.engine.URL.create(
            drivername="teradatasql",
            username=conn.login,
            password=conn.password,
            host=conn.host,
        )
        return sqlalchemy.create_engine(link, **(engine_kwargs or {}))

    @staticmethod
    def get_ui_field_behaviour() -> dict:
        """Return custom field behaviour."""
        import json

        return {
            "hidden_fields": ["port"],
            "relabeling": {
                "host": "Database Server URL",
                "schema": "Database Name",
                "login": "Username",
            },
            "placeholders": {
                "extra": json.dumps(
                    {"tmode": "TERA", "sslmode": "verify-ca", "sslca": "/tmp/server-ca.pem"}, indent=4
                ),
                "login": "dbc",
                "password": "dbc",
            },
        }

    def callproc(
        self,
        identifier: str,
        autocommit: bool = False,
        parameters: list | dict | None = None,
    ) -> list | dict | tuple | None:
        """
        Call the stored procedure identified by the provided string.

        Any OUT parameters must be provided with a value of either the
        expected Python type (e.g., `int`) or an instance of that type.

        :param identifier: stored procedure name
        :param autocommit: What to set the connection's autocommit setting to
            before executing the query.
        :param parameters: The `IN`, `OUT` and `INOUT` parameters for Teradata
            stored procedure

        The return value is a list or mapping that includes parameters in
        both directions; the actual return type depends on the type of the
        provided `parameters` argument.
        """
        if parameters is None:
            parameters = []

        # One "?" placeholder per parameter, wrapped in the {CALL ...} escape syntax.
        args = ",".join("?" for _ in parameters)
        sql = f"{{CALL {identifier}({args})}}"

        def handler(cursor):
            records = cursor.fetchall()

            if records is None:
                return None
            if isinstance(records, list):
                return list(records)
            if isinstance(records, dict):
                return dict(records)
            raise TypeError(f"Unexpected results: {records}")

        # For a mapping only the values are bound, positionally, in iteration order.
        values = parameters.values() if isinstance(parameters, dict) else parameters
        return self.run(
            sql,
            autocommit=autocommit,
            parameters=[_map_param(value) for value in values],
            handler=handler,
        )