Coverage for providers/src/airflow/providers/teradata/triggers/teradata_compute_cluster.py: 77%
59 statements
« prev ^ index » next coverage.py v7.6.10, created at 2024-12-27 08:27 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2024-12-27 08:27 +0000
1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
19import asyncio
20from collections.abc import AsyncIterator
21from typing import Any
23from airflow.exceptions import AirflowException
24from airflow.providers.common.sql.hooks.sql import fetch_one_handler
25from airflow.providers.teradata.hooks.teradata import TeradataHook
26from airflow.providers.teradata.utils.constants import Constants
27from airflow.triggers.base import BaseTrigger, TriggerEvent
class TeradataComputeClusterSyncTrigger(BaseTrigger):
    """
    Fetch the status of the suspend or resume operation for the specified compute cluster.

    :param teradata_conn_id: The :ref:`Teradata connection id <howto/connection:teradata>`
        reference to a specific Teradata database.
    :param compute_profile_name: Name of the Compute Profile to manage.
    :param compute_group_name: Name of compute group to which compute profile belongs.
    :param operation_type: Compute cluster operation - SUSPEND/RESUME (or the CREATE variants).
    :param poll_interval: polling period in seconds (passed to ``asyncio.sleep``) between status
        checks; falls back to ``Constants.CC_POLL_INTERVAL`` when not given.
    """

    def __init__(
        self,
        teradata_conn_id: str,
        compute_profile_name: str,
        compute_group_name: str | None = None,
        operation_type: str | None = None,
        poll_interval: float | None = None,
    ):
        super().__init__()
        self.teradata_conn_id = teradata_conn_id
        self.compute_profile_name = compute_profile_name
        self.compute_group_name = compute_group_name
        self.operation_type = operation_type
        self.poll_interval = poll_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize TeradataComputeClusterSyncTrigger arguments and classpath."""
        return (
            "airflow.providers.teradata.triggers.teradata_compute_cluster.TeradataComputeClusterSyncTrigger",
            {
                "teradata_conn_id": self.teradata_conn_id,
                "compute_profile_name": self.compute_profile_name,
                "compute_group_name": self.compute_group_name,
                "operation_type": self.operation_type,
                "poll_interval": self.poll_interval,
            },
        )

    def _expected_status(self) -> str | None:
        """Return the DB state that marks ``operation_type`` as complete, or None if unsupported."""
        if self.operation_type in (Constants.CC_SUSPEND_OPR, Constants.CC_CREATE_SUSPEND_OPR):
            return Constants.CC_SUSPEND_DB_STATUS
        if self.operation_type in (Constants.CC_RESUME_OPR, Constants.CC_CREATE_OPR):
            return Constants.CC_RESUME_DB_STATUS
        return None

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Wait for Compute Cluster operation to complete.

        Yields exactly one ``TriggerEvent``:

        * ``{"status": "error", "message": "Invalid operation"}`` when ``operation_type`` is not
          a supported SUSPEND/RESUME/CREATE operation;
        * ``{"status": "success", ...}`` once the compute profile reaches the state expected for
          ``operation_type``;
        * ``{"status": "error", "message": <exception text>}`` when the compute profile/group does
          not exist or any other error occurs while polling.

        Cancellation (e.g. a trigger timeout) is logged and re-raised so asyncio can finish it.
        """
        expected_status = self._expected_status()
        if expected_status is None:
            # Without a known target state the polling loop below could never terminate.
            yield TriggerEvent({"status": "error", "message": "Invalid operation"})
            return
        try:
            # Resolve the interval once; kept inside ``try`` so a bad value becomes an error event.
            interval = float(
                self.poll_interval if self.poll_interval is not None else Constants.CC_POLL_INTERVAL
            )
            while True:
                status = await self.get_status()
                if not status:
                    self.log.info(Constants.CC_GRP_PRP_NON_EXISTS_MSG)
                    raise AirflowException(Constants.CC_GRP_PRP_NON_EXISTS_MSG)
                if status == expected_status:
                    break
                await asyncio.sleep(interval)
        except asyncio.CancelledError:
            # Listed first: on older Pythons CancelledError subclasses Exception. Re-raise so the
            # cancellation is not silently swallowed.
            self.log.error(Constants.CC_OPR_TIMEOUT_ERROR, self.operation_type)
            raise
        except Exception as e:
            yield TriggerEvent({"status": "error", "message": str(e)})
            return
        yield TriggerEvent(
            {
                "status": "success",
                "message": Constants.CC_OPR_SUCCESS_STATUS_MSG
                % (self.compute_profile_name, self.operation_type),
            }
        )

    async def get_status(self) -> str:
        """
        Return compute cluster SUSPEND/RESUME operation status.

        The blocking database query runs in the event loop's default executor so the
        triggerer's event loop is not stalled while Teradata responds.

        :return: the ``ComputeProfileState`` string, or ``""`` when no matching row with a
            string state is found.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._fetch_status)

    def _fetch_status(self) -> str:
        """Query DBC.ComputeProfilesVX for the compute profile state (blocking call)."""
        # Names are bound as parameters rather than concatenated into the SQL text, so quotes
        # or other special characters in them cannot break or alter the statement.
        sql = (
            "SEL ComputeProfileState FROM DBC.ComputeProfilesVX "
            "WHERE UPPER(ComputeProfileName) = UPPER(?)"
        )
        parameters: list[str] = [self.compute_profile_name]
        if self.compute_group_name:
            sql += " AND UPPER(ComputeGroupName) = UPPER(?)"
            parameters.append(self.compute_group_name)
        hook = TeradataHook(teradata_conn_id=self.teradata_conn_id)
        result_set = hook.run(sql, parameters=parameters, handler=fetch_one_handler)
        # A missing row (None) or an empty/non-string row yields "" rather than an IndexError.
        if isinstance(result_set, (list, tuple)) and result_set and isinstance(result_set[0], str):
            return result_set[0]
        return ""