picopyn.pool
import asyncio
import time
import json
import sys
from collections import deque
from typing import Optional, Callable
import random

from urllib.parse import urlparse

from .connection import Connection


class Pool:
    """A connection pool.

    Connection pool can be used to manage a set of connections to the database.
    Connections are first acquired from the pool, then used, and then released
    back to the pool.

    :param dsn (str): The data source name (e.g., "postgresql://user:pass@host:port") for the cluster.
    :param max_size (int): Maximum number of connections in the pool. Must be at least 1.
    :param enable_discovery (bool): If True, the pool will automatically discover available
        picodata instances. If False, only the given `dsn` will be used.
    :param balance_strategy (callable, optional): A function that selects a connection from the pool.
        If None, a default round-robin strategy will be used.
    """

    def __init__(
        self,
        dsn: Optional[str] = None,
        max_size: int = 10,
        enable_discovery: bool = False,
        balance_strategy=None,
        **connect_kwargs
    ):
        if max_size < 1:
            raise ValueError("max_size must be at least 1")

        self._dsn = dsn
        self._connect_kwargs = connect_kwargs
        self._max_size = max_size
        self._pool = deque()
        self._used = set()
        self._lock = asyncio.Lock()
        self._default_acquire_timeout_sec = 5
        # node discovery mode:
        # if disabled, pool will be filled with given address connections;
        # if enabled, pool will be filled with available picodata instances.
        self.enable_discovery = enable_discovery
        # load balancing strategy:
        # if None, a simple round-robin strategy will be used;
        # otherwise, the provided callable will be used to select connections.
        if balance_strategy is not None and not callable(balance_strategy):
            raise ValueError("balance_strategy must be callable or None")
        self._balance_strategy = balance_strategy

    async def connect(self):
        """
        Prepares the pool by opening up to `max_size` connections.

        This should be called before using the pool to ensure connections are available.

        :raises RuntimeError: If discovery fails or the main node cannot be reached.
        """
        async with self._lock:
            # ">=" (was "==") so an already-full pool can never fall through
            # into the fill loops below
            if len(self._pool) >= self._max_size:
                return

            # if node discovery is enabled, then connect to all alive picodata
            # instances (if they fit within the max_size limit)
            if self.enable_discovery:
                try:
                    instance_addrs = await self._discover_instances()
                except Exception as e:
                    raise RuntimeError(f"Failed to discover instances using DSN {self._dsn}: {e}")

                parsed_url = urlparse(self._dsn)

                addr_index = 0
                # fill the connection pool with connections to all available nodes,
                # up to the max_size, so the pool is evenly populated across nodes.
                # a node that fails to connect is removed from the candidate list;
                # the loop exits early when no candidates remain.
                while len(self._pool) < self._max_size and instance_addrs:
                    address = instance_addrs[addr_index % len(instance_addrs)]
                    dsn = f"{parsed_url.scheme}://{parsed_url.username}:{parsed_url.password}@{address}"

                    try:
                        conn = Connection(dsn, **self._connect_kwargs)
                        await conn.connect()
                        self._pool.append(conn)
                    except Exception as e:
                        # diagnostics go to stderr so stdout stays clean
                        print(f"Could not connect to node {address} for pool: {e}", file=sys.stderr)
                        instance_addrs.remove(address)
                        if not instance_addrs:
                            break
                        continue

                    addr_index += 1

            # then fill the connection pool up to max_size with main node connections
            while len(self._pool) < self._max_size:
                main_node_conn = Connection(self._dsn, **self._connect_kwargs)
                try:
                    await main_node_conn.connect()
                    self._pool.append(main_node_conn)
                except Exception as e:
                    raise RuntimeError(f"Could not connect to main node {self._dsn} for pool: {e}")

            # rotate the pool to randomize the order of connections.
            # this helps to distribute the initial load more evenly across nodes
            # when using round-robin or when multiple clients start simultaneously.
            shift = random.randint(0, len(self._pool) - 1)
            self._pool.rotate(shift)

            return

    async def _discover_instances(self):
        """Return addresses of Online picodata instances, connected node first.

        Opens a temporary connection to the DSN, queries the cluster metadata
        tables, and always closes the temporary connection.
        """
        # make temporary connection
        temp_conn = Connection(self._dsn, **self._connect_kwargs)

        try:
            await temp_conn.connect()

            # all instance addresses excluding connected node
            alive_instances_info = await temp_conn.fetch("""
                WITH my_uuid AS (SELECT instance_uuid() AS uuid)
                SELECT i.name, i.raft_id, i.current_state, p.address
                FROM _pico_instance i
                JOIN _pico_peer_address p ON i.raft_id = p.raft_id
                JOIN my_uuid u ON 1 = 1
                WHERE p.connection_type = 'pgproto' AND i.uuid != u.uuid;
            """)

            online_addresses = []
            # place connected node first so that it ends up in the pool
            # regardless of the pool size
            parsed_url = urlparse(self._dsn)
            online_addresses.append(f"{parsed_url.hostname}:{parsed_url.port}")
            for r in alive_instances_info:
                if not r.get('current_state', None):
                    continue

                try:
                    current_state = json.loads(r.get('current_state', None))
                except json.JSONDecodeError:
                    print(f"Failed to decode current state of picodata instance {r.get('current_state', None)}", file=sys.stderr)
                    continue

                if 'Online' in current_state:
                    online_addresses.append(r['address'])

            return online_addresses
        finally:
            await temp_conn.close()

    async def acquire(self, timeout: Optional[float] = None) -> Connection:
        """
        Acquire a connection from the pool.

        If no connections are available, this method will wait until one is released.

        :param timeout: Seconds to wait for a free connection; defaults to the
            pool's default acquire timeout.
        :return: A database connection.
        :raises TimeoutError: If no connection became free within the timeout.
        :raises RuntimeError: If the custom balance strategy fails or returns
            a connection that is not in the pool.
        """
        start_time = time.monotonic()
        effective_timeout = timeout if timeout is not None else self._default_acquire_timeout_sec

        while True:
            async with self._lock:
                # check if there are any available connections in the pool
                if self._pool:
                    # round-robin strategy
                    if self._balance_strategy is None:
                        conn = self._pool.popleft()
                    # custom strategy
                    else:
                        try:
                            conn = self._balance_strategy(list(self._pool))
                        except Exception as e:
                            raise RuntimeError(f"balance_strategy raised an exception: {e}")

                        if conn not in self._pool:
                            raise RuntimeError("balance_strategy returned a connection not in pool")
                        self._pool.remove(conn)

                    # mark it as currently in use
                    self._used.add(conn)
                    return conn

            if (time.monotonic() - start_time) >= effective_timeout:
                raise TimeoutError("Timed out waiting for a free connection in the pool")

            # if no connections are available, wait briefly before retrying;
            # this gives other coroutines (like `release`) a chance to return
            # a connection to the pool
            await asyncio.sleep(0.1)

    async def release(self, conn: Connection):
        """
        Release a previously acquired connection back to the pool.

        Connections that are not currently marked as in use are ignored.

        :param conn: The connection to release.
        """
        async with self._lock:
            if conn in self._used:
                self._used.remove(conn)
                self._pool.append(conn)

    async def close(self):
        """
        Closes all connections in the pool.

        This should be called during application shutdown to clean up resources.
        Close failures are reported and skipped so that one bad connection does
        not leave the remaining ones open.
        """
        async with self._lock:
            while self._pool:
                conn = self._pool.popleft()
                try:
                    await conn.close()
                except Exception as e:
                    # best effort: keep closing the remaining connections
                    print(f"Error while closing pooled connection: {e}", file=sys.stderr)
            for conn in self._used:
                try:
                    await conn.close()
                except Exception as e:
                    print(f"Error while closing in-use connection: {e}", file=sys.stderr)
            self._used.clear()

    async def execute(self, query: str, *args):
        """
        Executes a query that does not return rows (e.g. INSERT, UPDATE, DELETE).

        :param query: The SQL query string.
        :param args: Optional parameters for the SQL query.
        :return: The result of the query execution.
        """
        conn = await self.acquire()
        try:
            return await conn.execute(query, *args)
        finally:
            await self.release(conn)

    async def fetch(self, query: str, *args):
        """
        Executes a query and fetches all resulting rows.

        :param query: The SQL query string.
        :param args: Optional parameters for the SQL query.
        :return: A list of rows returned by the query.
        """
        conn = await self.acquire()
        try:
            return await conn.fetch(query, *args)
        finally:
            await self.release(conn)

    async def fetchrow(self, query: str, *args):
        """
        Executes a query and fetches a single row (first row).

        :param query: The SQL query string.
        :param args: Optional parameters for the SQL query.
        :return: A single row returned by the query.
        """
        conn = await self.acquire()
        try:
            return await conn.fetchrow(query, *args)
        finally:
            await self.release(conn)
class Pool:
    """A connection pool.

    Connection pool can be used to manage a set of connections to the database.
    Connections are first acquired from the pool, then used, and then released
    back to the pool.

    :param dsn (str): The data source name (e.g., "postgresql://user:pass@host:port") for the cluster.
    :param max_size (int): Maximum number of connections in the pool. Must be at least 1.
    :param enable_discovery (bool): If True, the pool will automatically discover available
        picodata instances. If False, only the given `dsn` will be used.
    :param balance_strategy (callable, optional): A function that selects a connection from the pool.
        If None, a default round-robin strategy will be used.
    """

    def __init__(
        self,
        dsn: Optional[str] = None,
        max_size: int = 10,
        enable_discovery: bool = False,
        balance_strategy=None,
        **connect_kwargs
    ):
        if max_size < 1:
            raise ValueError("max_size must be at least 1")

        self._dsn = dsn
        self._connect_kwargs = connect_kwargs
        self._max_size = max_size
        self._pool = deque()
        self._used = set()
        self._lock = asyncio.Lock()
        self._default_acquire_timeout_sec = 5
        # node discovery mode:
        # if disabled, pool will be filled with given address connections;
        # if enabled, pool will be filled with available picodata instances.
        self.enable_discovery = enable_discovery
        # load balancing strategy:
        # if None, a simple round-robin strategy will be used;
        # otherwise, the provided callable will be used to select connections.
        if balance_strategy is not None and not callable(balance_strategy):
            raise ValueError("balance_strategy must be callable or None")
        self._balance_strategy = balance_strategy

    async def connect(self):
        """
        Prepares the pool by opening up to `max_size` connections.

        This should be called before using the pool to ensure connections are available.

        :raises RuntimeError: If discovery fails or the main node cannot be reached.
        """
        async with self._lock:
            # ">=" (was "==") so an already-full pool can never fall through
            # into the fill loops below
            if len(self._pool) >= self._max_size:
                return

            # if node discovery is enabled, then connect to all alive picodata
            # instances (if they fit within the max_size limit)
            if self.enable_discovery:
                try:
                    instance_addrs = await self._discover_instances()
                except Exception as e:
                    raise RuntimeError(f"Failed to discover instances using DSN {self._dsn}: {e}")

                parsed_url = urlparse(self._dsn)

                addr_index = 0
                # fill the connection pool with connections to all available nodes,
                # up to the max_size, so the pool is evenly populated across nodes.
                # a node that fails to connect is removed from the candidate list;
                # the loop exits early when no candidates remain.
                while len(self._pool) < self._max_size and instance_addrs:
                    address = instance_addrs[addr_index % len(instance_addrs)]
                    dsn = f"{parsed_url.scheme}://{parsed_url.username}:{parsed_url.password}@{address}"

                    try:
                        conn = Connection(dsn, **self._connect_kwargs)
                        await conn.connect()
                        self._pool.append(conn)
                    except Exception as e:
                        # diagnostics go to stderr so stdout stays clean
                        print(f"Could not connect to node {address} for pool: {e}", file=sys.stderr)
                        instance_addrs.remove(address)
                        if not instance_addrs:
                            break
                        continue

                    addr_index += 1

            # then fill the connection pool up to max_size with main node connections
            while len(self._pool) < self._max_size:
                main_node_conn = Connection(self._dsn, **self._connect_kwargs)
                try:
                    await main_node_conn.connect()
                    self._pool.append(main_node_conn)
                except Exception as e:
                    raise RuntimeError(f"Could not connect to main node {self._dsn} for pool: {e}")

            # rotate the pool to randomize the order of connections.
            # this helps to distribute the initial load more evenly across nodes
            # when using round-robin or when multiple clients start simultaneously.
            shift = random.randint(0, len(self._pool) - 1)
            self._pool.rotate(shift)

            return

    async def _discover_instances(self):
        """Return addresses of Online picodata instances, connected node first.

        Opens a temporary connection to the DSN, queries the cluster metadata
        tables, and always closes the temporary connection.
        """
        # make temporary connection
        temp_conn = Connection(self._dsn, **self._connect_kwargs)

        try:
            await temp_conn.connect()

            # all instance addresses excluding connected node
            alive_instances_info = await temp_conn.fetch("""
                WITH my_uuid AS (SELECT instance_uuid() AS uuid)
                SELECT i.name, i.raft_id, i.current_state, p.address
                FROM _pico_instance i
                JOIN _pico_peer_address p ON i.raft_id = p.raft_id
                JOIN my_uuid u ON 1 = 1
                WHERE p.connection_type = 'pgproto' AND i.uuid != u.uuid;
            """)

            online_addresses = []
            # place connected node first so that it ends up in the pool
            # regardless of the pool size
            parsed_url = urlparse(self._dsn)
            online_addresses.append(f"{parsed_url.hostname}:{parsed_url.port}")
            for r in alive_instances_info:
                if not r.get('current_state', None):
                    continue

                try:
                    current_state = json.loads(r.get('current_state', None))
                except json.JSONDecodeError:
                    print(f"Failed to decode current state of picodata instance {r.get('current_state', None)}", file=sys.stderr)
                    continue

                if 'Online' in current_state:
                    online_addresses.append(r['address'])

            return online_addresses
        finally:
            await temp_conn.close()

    # NOTE: "Connection" annotations are string forward references so the class
    # can be defined even before .connection is importable.
    async def acquire(self, timeout: Optional[float] = None) -> "Connection":
        """
        Acquire a connection from the pool.

        If no connections are available, this method will wait until one is released.

        :param timeout: Seconds to wait for a free connection; defaults to the
            pool's default acquire timeout.
        :return: A database connection.
        :raises TimeoutError: If no connection became free within the timeout.
        :raises RuntimeError: If the custom balance strategy fails or returns
            a connection that is not in the pool.
        """
        start_time = time.monotonic()
        effective_timeout = timeout if timeout is not None else self._default_acquire_timeout_sec

        while True:
            async with self._lock:
                # check if there are any available connections in the pool
                if self._pool:
                    # round-robin strategy
                    if self._balance_strategy is None:
                        conn = self._pool.popleft()
                    # custom strategy
                    else:
                        try:
                            conn = self._balance_strategy(list(self._pool))
                        except Exception as e:
                            raise RuntimeError(f"balance_strategy raised an exception: {e}")

                        if conn not in self._pool:
                            raise RuntimeError("balance_strategy returned a connection not in pool")
                        self._pool.remove(conn)

                    # mark it as currently in use
                    self._used.add(conn)
                    return conn

            if (time.monotonic() - start_time) >= effective_timeout:
                raise TimeoutError("Timed out waiting for a free connection in the pool")

            # if no connections are available, wait briefly before retrying;
            # this gives other coroutines (like `release`) a chance to return
            # a connection to the pool
            await asyncio.sleep(0.1)

    async def release(self, conn: "Connection"):
        """
        Release a previously acquired connection back to the pool.

        Connections that are not currently marked as in use are ignored.

        :param conn: The connection to release.
        """
        async with self._lock:
            if conn in self._used:
                self._used.remove(conn)
                self._pool.append(conn)

    async def close(self):
        """
        Closes all connections in the pool.

        This should be called during application shutdown to clean up resources.
        Close failures are reported and skipped so that one bad connection does
        not leave the remaining ones open.
        """
        async with self._lock:
            while self._pool:
                conn = self._pool.popleft()
                try:
                    await conn.close()
                except Exception as e:
                    # best effort: keep closing the remaining connections
                    print(f"Error while closing pooled connection: {e}", file=sys.stderr)
            for conn in self._used:
                try:
                    await conn.close()
                except Exception as e:
                    print(f"Error while closing in-use connection: {e}", file=sys.stderr)
            self._used.clear()

    async def execute(self, query: str, *args):
        """
        Executes a query that does not return rows (e.g. INSERT, UPDATE, DELETE).

        :param query: The SQL query string.
        :param args: Optional parameters for the SQL query.
        :return: The result of the query execution.
        """
        conn = await self.acquire()
        try:
            return await conn.execute(query, *args)
        finally:
            await self.release(conn)

    async def fetch(self, query: str, *args):
        """
        Executes a query and fetches all resulting rows.

        :param query: The SQL query string.
        :param args: Optional parameters for the SQL query.
        :return: A list of rows returned by the query.
        """
        conn = await self.acquire()
        try:
            return await conn.fetch(query, *args)
        finally:
            await self.release(conn)

    async def fetchrow(self, query: str, *args):
        """
        Executes a query and fetches a single row (first row).

        :param query: The SQL query string.
        :param args: Optional parameters for the SQL query.
        :return: A single row returned by the query.
        """
        conn = await self.acquire()
        try:
            return await conn.fetchrow(query, *args)
        finally:
            await self.release(conn)
A connection pool.
Connection pool can be used to manage a set of connections to the database. Connections are first acquired from the pool, then used, and then released back to the pool.
Parameters
- dsn (str): The data source name (e.g., "postgresql://user:pass@host:port") for the cluster.
- balance_strategy (callable, optional): A custom strategy function to select a connection from the pool. If None, round-robin strategy is used.
- max_size (int): Maximum number of connections in the pool. Must be at least 1.
- enable_discovery (bool): If True, the pool will automatically discover available
picodata instances. If False, only the given
`dsn` will be used.
- balance_strategy (callable, optional): A function that selects a connection from the pool. If None, a default round-robin strategy will be used.
def __init__(
    self,
    dsn: Optional[str] = None,
    max_size: int = 10,
    enable_discovery: bool = False,
    balance_strategy=None,
    **connect_kwargs
):
    """Initialize pool state; no connections are opened here (see connect())."""
    if max_size < 1:
        raise ValueError("max_size must be at least 1")

    # connection configuration
    self._dsn = dsn
    self._connect_kwargs = connect_kwargs
    self._max_size = max_size
    # pool bookkeeping: free connections plus the set currently checked out
    self._pool = deque()
    self._used = set()
    self._lock = asyncio.Lock()
    self._default_acquire_timeout_sec = 5
    # node discovery mode:
    # disabled -> pool is filled with connections to the given address only;
    # enabled  -> pool is spread over available picodata instances.
    self.enable_discovery = enable_discovery
    # connection-selection strategy for acquire();
    # None means the built-in round-robin is used.
    if balance_strategy is not None and not callable(balance_strategy):
        raise ValueError("balance_strategy must be callable or None")
    self._balance_strategy = balance_strategy
async def connect(self):
    """
    Prepares the pool by opening up to `max_size` connections.

    This should be called before using the pool to ensure connections are available.

    :raises RuntimeError: If discovery fails or the main node cannot be reached.
    """
    async with self._lock:
        # ">=" (was "==") so an already-full pool can never fall through
        # into the fill loops below
        if len(self._pool) >= self._max_size:
            return

        # if node discovery is enabled, then connect to all alive picodata
        # instances (if they fit within the max_size limit)
        if self.enable_discovery:
            try:
                instance_addrs = await self._discover_instances()
            except Exception as e:
                raise RuntimeError(f"Failed to discover instances using DSN {self._dsn}: {e}")

            parsed_url = urlparse(self._dsn)

            addr_index = 0
            # fill the connection pool with connections to all available nodes,
            # up to the max_size, so the pool is evenly populated across nodes.
            # a node that fails to connect is removed from the candidate list;
            # the loop exits early when no candidates remain.
            while len(self._pool) < self._max_size and instance_addrs:
                address = instance_addrs[addr_index % len(instance_addrs)]
                dsn = f"{parsed_url.scheme}://{parsed_url.username}:{parsed_url.password}@{address}"

                try:
                    conn = Connection(dsn, **self._connect_kwargs)
                    await conn.connect()
                    self._pool.append(conn)
                except Exception as e:
                    # diagnostics go to stderr so stdout stays clean
                    print(f"Could not connect to node {address} for pool: {e}", file=sys.stderr)
                    instance_addrs.remove(address)
                    if not instance_addrs:
                        break
                    continue

                addr_index += 1

        # then fill the connection pool up to max_size with main node connections
        while len(self._pool) < self._max_size:
            main_node_conn = Connection(self._dsn, **self._connect_kwargs)
            try:
                await main_node_conn.connect()
                self._pool.append(main_node_conn)
            except Exception as e:
                raise RuntimeError(f"Could not connect to main node {self._dsn} for pool: {e}")

        # rotate the pool to randomize the order of connections.
        # this helps to distribute the initial load more evenly across nodes
        # when using round-robin or when multiple clients start simultaneously.
        shift = random.randint(0, len(self._pool) - 1)
        self._pool.rotate(shift)

        return
Prepares the pool by opening up to max_size connections.
This should be called before using the pool to ensure connections are available.
async def acquire(self, timeout: Optional[float] = None) -> "Connection":
    """
    Acquire a connection from the pool.

    If no connections are available, this method will wait (polling) until
    one is released or the timeout elapses.

    :param timeout: Seconds to wait; defaults to the pool's acquire timeout.
    :return: A database connection.
    :raises TimeoutError: If no connection became free within the timeout.
    :raises RuntimeError: If the custom balance strategy misbehaves.
    """
    wait_sec = self._default_acquire_timeout_sec if timeout is None else timeout
    deadline = time.monotonic() + wait_sec

    while True:
        async with self._lock:
            if self._pool:
                strategy = self._balance_strategy
                if strategy is None:
                    # default round-robin: hand out the oldest free connection
                    conn = self._pool.popleft()
                else:
                    # custom selection over a snapshot of the free connections
                    try:
                        conn = strategy(list(self._pool))
                    except Exception as e:
                        raise RuntimeError(f"balance_strategy raised an exception: {e}")
                    if conn not in self._pool:
                        raise RuntimeError("balance_strategy returned a connection not in pool")
                    self._pool.remove(conn)

                # mark the connection as currently in use
                self._used.add(conn)
                return conn

        if time.monotonic() >= deadline:
            raise TimeoutError("Timed out waiting for a free connection in the pool")

        # nothing free: yield briefly so release() can return a connection
        await asyncio.sleep(0.1)
Acquire a connection from the pool.
If no connections are available, this method will wait until one is released.
Returns
A database connection.
async def release(self, conn: "Connection"):
    """
    Release a previously acquired connection back to the pool.

    Connections not currently marked as in use are ignored.

    :param conn: The connection to release.
    """
    async with self._lock:
        # guard clause: only connections handed out by acquire() go back
        if conn not in self._used:
            return
        self._used.discard(conn)
        self._pool.append(conn)
Release a previously acquired connection back to the pool.
Parameters
- conn: The connection to release.
async def close(self):
    """
    Closes all connections in the pool.

    This should be called during application shutdown to clean up resources.
    Close failures are reported and skipped so that one bad connection does
    not leave the remaining ones open.
    """
    async with self._lock:
        # drain idle connections first
        while self._pool:
            conn = self._pool.popleft()
            try:
                await conn.close()
            except Exception as e:
                # best effort: keep closing the remaining connections
                print(f"Error while closing pooled connection: {e}", file=sys.stderr)
        # then shut down connections still marked as in use
        for conn in self._used:
            try:
                await conn.close()
            except Exception as e:
                print(f"Error while closing in-use connection: {e}", file=sys.stderr)
        self._used.clear()
Closes all connections in the pool.
This should be called during application shutdown to clean up resources.
async def execute(self, query: str, *args):
    """
    Executes a query that does not return rows (e.g. INSERT, UPDATE, DELETE).

    A connection is borrowed from the pool for the duration of the call and
    always returned, even when the query fails.

    :param query: The SQL query string.
    :param args: Optional parameters for the SQL query.
    :return: The result of the query execution.
    """
    connection = await self.acquire()
    try:
        result = await connection.execute(query, *args)
    finally:
        await self.release(connection)
    return result
Executes a query that does not return rows (e.g. INSERT, UPDATE, DELETE).
Parameters
- query: The SQL query string.
- args: Optional parameters for the SQL query.
Returns
The result of the query execution.
async def fetch(self, query: str, *args):
    """
    Executes a query and fetches all resulting rows.

    A connection is borrowed from the pool for the duration of the call and
    always returned, even when the query fails.

    :param query: The SQL query string.
    :param args: Optional parameters for the SQL query.
    :return: A list of rows returned by the query.
    """
    connection = await self.acquire()
    try:
        rows = await connection.fetch(query, *args)
    finally:
        await self.release(connection)
    return rows
Executes a query and fetches all resulting rows.
Parameters
- query: The SQL query string.
- args: Optional parameters for the SQL query.
Returns
A list of rows returned by the query.
async def fetchrow(self, query: str, *args):
    """
    Executes a query and fetches a single row (first row).

    A connection is borrowed from the pool for the duration of the call and
    always returned, even when the query fails.

    :param query: The SQL query string.
    :param args: Optional parameters for the SQL query.
    :return: A single row returned by the query.
    """
    connection = await self.acquire()
    try:
        row = await connection.fetchrow(query, *args)
    finally:
        await self.release(connection)
    return row
Executes a query and fetches a single row (first row).
Parameters
- query: The SQL query string.
- args: Optional parameters for the SQL query.
Returns
A single row returned by the query.