import base64

from sqlalchemy import and_, select, delete, update, asc
from sqlalchemy.ext.asyncio import AsyncSession

from database.models import User, ChatMessage, Wallet, MemoryVector, Payment


class UserRepository:
    """Data-access helpers for users and their related rows.

    Wraps an ``AsyncSession`` and exposes CRUD-style coroutines for
    ``User``, ``ChatMessage``, ``Wallet``, ``MemoryVector`` and ``Payment``.
    Every mutating method commits the session before returning.
    """

    def __init__(self, session: AsyncSession):
        """Store the session used by every query in this repository."""
        self.session = session

    async def get_by_telegram_id(self, telegram_id: int):
        """Return the ``User`` whose primary key is *telegram_id*, or ``None``."""
        return await self.session.get(User, telegram_id)

    async def create_if_not_exists(self, telegram_id: int, **kwargs):
        """Return the existing user, creating and committing one if absent.

        Args:
            telegram_id: Primary key of the user.
            **kwargs: Extra column values passed to ``User(...)`` on creation;
                ignored when the user already exists.
        """
        user = await self.get_by_telegram_id(telegram_id)
        if not user:
            user = User(telegram_id=telegram_id, **kwargs)
            self.session.add(user)
            await self.session.commit()
        return user

    async def update(self, user: User, **kwargs) -> None:
        """Update columns of *user* and commit.

        ``balance_credits`` is special-cased: the supplied value is treated as
        an amount to *subtract* from the current balance, not as the new
        balance. All other keyword arguments are written as-is.
        """
        if 'balance_credits' in kwargs:
            # Compute the debit in SQL rather than from the in-memory
            # ``user.balance_credits``: the loaded value may be stale (lost
            # updates under concurrency) or expired after an earlier commit.
            kwargs['balance_credits'] = (
                User.balance_credits - kwargs['balance_credits']
            )
        await self.session.execute(
            update(User)
            .where(User.telegram_id == user.telegram_id)
            .values(**kwargs)
        )
        await self.session.commit()

    async def delete_chat_messages(self, user: User) -> None:
        """Delete every ``ChatMessage`` belonging to *user* and commit."""
        await self.session.execute(
            delete(ChatMessage).where(ChatMessage.user_id == user.telegram_id)
        )
        await self.session.commit()

    async def get_wallet(self, user_id: int):
        """Return the decoded wallet key for *user_id*, or ``None``.

        ``None`` is returned when no wallet row exists or the stored value is
        empty.
        """
        # NOTE(review): despite the column name, the stored value is only
        # base64-encoded (see add_wallet_key), not encrypted.
        stored = await self.session.scalar(
            select(Wallet.encrypted_private_key).where(Wallet.user_id == user_id)
        )
        if not stored:
            return None
        return base64.b64decode(stored.encode('utf-8')).decode('utf-8')

    async def get_messags(self, user_id: int):
        """Return all chat messages of *user_id* ordered by ascending id.

        The method name keeps its original spelling for existing callers.
        """
        query = (
            select(ChatMessage)
            .where(ChatMessage.user_id == user_id)
            .order_by(asc(ChatMessage.id))
        )
        return (await self.session.scalars(query)).fetchall()

    async def get_memory_vector(self, user_id: int):
        """Return the ``MemoryVector`` row for *user_id*, or ``None``."""
        return await self.session.scalar(
            select(MemoryVector).where(MemoryVector.user_id == user_id)
        )

    async def add_memory_vector(self, user_id: int, vector_store_id: int) -> None:
        """Insert a ``MemoryVector`` linking *user_id* to *vector_store_id*."""
        memory_vector = MemoryVector(user_id=user_id, id_vector=vector_store_id)
        self.session.add(memory_vector)
        await self.session.commit()

    async def delete_memory_vector(self, user_id: int) -> None:
        """Delete the ``MemoryVector`` rows of *user_id* and commit."""
        await self.session.execute(
            delete(MemoryVector).where(MemoryVector.user_id == user_id)
        )
        await self.session.commit()

    async def add_context(self, user_id: int, role: str, content: str):
        """Insert a chat message and return its id."""
        chat_message = ChatMessage(user_id=user_id, role=role, content=content)
        self.session.add(chat_message)
        # Flush so the primary key is populated, and read it *before* the
        # commit: commit may expire the instance, and AsyncSession cannot
        # lazy-load an expired attribute implicitly.
        await self.session.flush()
        message_id = chat_message.id
        await self.session.commit()
        return message_id

    async def delete_wallet_key(self, user_id: int) -> None:
        """Delete the ``Wallet`` rows of *user_id* and commit."""
        await self.session.execute(delete(Wallet).where(Wallet.user_id == user_id))
        await self.session.commit()

    async def add_wallet_key(self, user_id: int, key: str) -> None:
        """Replace the wallet key of *user_id* with *key* (stored as base64).

        The old row is removed and the new one inserted in a single
        transaction, so a failed insert cannot leave the user without a key.
        """
        encoded_key = base64.b64encode(key.encode('utf-8')).decode('utf-8')
        # Delete inline instead of calling delete_wallet_key(), which would
        # commit the removal before the replacement row is written.
        await self.session.execute(delete(Wallet).where(Wallet.user_id == user_id))
        self.session.add(Wallet(user_id=user_id, encrypted_private_key=encoded_key))
        await self.session.commit()

    async def add_payment(self, user_id: int, amount: int, crypto_amount: str,
                          crypto_currency: str, random_suffix: str):
        """Insert a ``Payment`` row and return its id.

        Args:
            user_id: Owner of the payment.
            amount: Stored in the ``amount_usd`` column.
            crypto_amount: Stored in the ``crypto_amount`` column.
            crypto_currency: Stored in the ``crypto_currency`` column.
            random_suffix: Stored in the ``random_suffix`` column.
        """
        payment = Payment(
            user_id=user_id,
            amount_usd=amount,
            crypto_amount=crypto_amount,
            crypto_currency=crypto_currency,
            random_suffix=random_suffix,
        )
        self.session.add(payment)
        # Capture the generated id before commit; see add_context.
        await self.session.flush()
        payment_id = payment.id
        await self.session.commit()
        return payment_id

    async def add_user_credits(self, user_id: int, balance_credits: int) -> None:
        """Atomically add *balance_credits* to the user's balance and commit."""
        await self.session.execute(
            update(User)
            .where(User.telegram_id == user_id)
            .values(balance_credits=User.balance_credits + balance_credits)
        )
        await self.session.commit()

    async def get_row_for_md(self, row_id: int):
        """Return the ``ChatMessage`` with id *row_id*, or ``None``."""
        return await self.session.scalar(
            select(ChatMessage).where(ChatMessage.id == row_id)
        )