Coverage for providers/src/airflow/providers/teradata/triggers/teradata_compute_cluster.py: 77%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2024-12-27 08:27 +0000

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17from __future__ import annotations 

18 

19import asyncio 

20from collections.abc import AsyncIterator 

21from typing import Any 

22 

23from airflow.exceptions import AirflowException 

24from airflow.providers.common.sql.hooks.sql import fetch_one_handler 

25from airflow.providers.teradata.hooks.teradata import TeradataHook 

26from airflow.providers.teradata.utils.constants import Constants 

27from airflow.triggers.base import BaseTrigger, TriggerEvent 

28 

29 

class TeradataComputeClusterSyncTrigger(BaseTrigger):
    """
    Fetch the status of the suspend or resume operation for the specified compute cluster.

    :param teradata_conn_id: The :ref:`Teradata connection id <howto/connection:teradata>`
        reference to a specific Teradata database.
    :param compute_profile_name: Name of the Compute Profile to manage.
    :param compute_group_name: Name of compute group to which compute profile belongs.
    :param operation_type: Compute cluster operation - SUSPEND/RESUME (or the CREATE variants).
    :param poll_interval: Polling period between status checks; the value is passed to
        ``asyncio.sleep`` and is therefore interpreted in seconds. When None,
        ``Constants.CC_POLL_INTERVAL`` is used.
    """

    def __init__(
        self,
        teradata_conn_id: str,
        compute_profile_name: str,
        compute_group_name: str | None = None,
        operation_type: str | None = None,
        poll_interval: float | None = None,
    ):
        super().__init__()
        self.teradata_conn_id = teradata_conn_id
        self.compute_profile_name = compute_profile_name
        self.compute_group_name = compute_group_name
        self.operation_type = operation_type
        self.poll_interval = poll_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize TeradataComputeClusterSyncTrigger arguments and classpath."""
        return (
            "airflow.providers.teradata.triggers.teradata_compute_cluster.TeradataComputeClusterSyncTrigger",
            {
                "teradata_conn_id": self.teradata_conn_id,
                "compute_profile_name": self.compute_profile_name,
                "compute_group_name": self.compute_group_name,
                "operation_type": self.operation_type,
                "poll_interval": self.poll_interval,
            },
        )

    def _target_status(self) -> str | None:
        """
        Return the ``ComputeProfileState`` value that marks completion of ``operation_type``.

        Returns None when ``operation_type`` is not one of the supported suspend/resume/create
        operations.
        """
        if self.operation_type in (Constants.CC_SUSPEND_OPR, Constants.CC_CREATE_SUSPEND_OPR):
            return Constants.CC_SUSPEND_DB_STATUS
        if self.operation_type in (Constants.CC_RESUME_OPR, Constants.CC_CREATE_OPR):
            return Constants.CC_RESUME_DB_STATUS
        return None

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Wait for Compute Cluster operation to complete.

        Polls :meth:`get_status` until the compute profile reaches the state expected for
        ``operation_type``, then yields one ``success`` event. Yields one ``error`` event when
        the operation type is unsupported, when the profile/group cannot be found, or when any
        exception is raised while polling.
        """
        target_status = self._target_status()
        if target_status is None:
            # No database state can satisfy an unknown operation, so polling would never end;
            # report the problem immediately instead.
            yield TriggerEvent({"status": "error", "message": "Invalid operation"})
            return
        try:
            # Resolve the interval once, without mutating the instance attribute that
            # serialize() reports.
            poll_interval = float(
                self.poll_interval if self.poll_interval is not None else Constants.CC_POLL_INTERVAL
            )
            while True:
                status = await self.get_status()
                if not status:
                    self.log.info(Constants.CC_GRP_PRP_NON_EXISTS_MSG)
                    raise AirflowException(Constants.CC_GRP_PRP_NON_EXISTS_MSG)
                if status == target_status:
                    break
                await asyncio.sleep(poll_interval)
            yield TriggerEvent(
                {
                    "status": "success",
                    "message": Constants.CC_OPR_SUCCESS_STATUS_MSG
                    % (self.compute_profile_name, self.operation_type),
                }
            )
        except Exception as e:
            yield TriggerEvent({"status": "error", "message": str(e)})
        except asyncio.CancelledError:
            # CancelledError derives from BaseException (Python 3.8+), so it is not caught by the
            # handler above. It is logged here and not re-raised, matching the existing behavior.
            self.log.error(Constants.CC_OPR_TIMEOUT_ERROR, self.operation_type)

    @staticmethod
    def _sql_literal(value: str) -> str:
        """Escape ``value`` for a single-quoted SQL string literal by doubling embedded quotes."""
        return value.replace("'", "''")

    async def get_status(self) -> str:
        """
        Return compute cluster SUSPEND/RESUME operation status.

        Queries ``DBC.ComputeProfilesVX`` for the ``ComputeProfileState`` of the configured
        profile (and group, when set), matching names case-insensitively.

        :return: The state string from the first column of the first row, or ``""`` when no
            row is returned or its first column is not a string.
        """
        sql = (
            "SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = UPPER('"
            + self._sql_literal(self.compute_profile_name)
            + "')"
        )
        if self.compute_group_name:
            sql += (
                " AND UPPER(ComputeGroupName) = UPPER('"
                + self._sql_literal(self.compute_group_name)
                + "')"
            )

        def _fetch_row() -> Any:
            # Synchronous DB round-trip; executed off the event loop so the triggerer is not blocked.
            hook = TeradataHook(teradata_conn_id=self.teradata_conn_id)
            return hook.run(sql, handler=fetch_one_handler)

        result_set = await asyncio.get_running_loop().run_in_executor(None, _fetch_row)
        # The driver may return the row as a list or a tuple; guard against an empty row too.
        if isinstance(result_set, (list, tuple)) and result_set and isinstance(result_set[0], str):
            return str(result_set[0])
        return ""