118 lines
3.8 KiB
Python
118 lines
3.8 KiB
Python
import json
|
|
import uuid
|
|
from dataclasses import dataclass
|
|
|
|
from sqlalchemy import Engine, BLOB, text
|
|
from sqlmodel import SQLModel, Field, Session, select
|
|
from typing_extensions import override, Self, Optional
|
|
|
|
from app.repos.sql import dbRetry
|
|
from core import Utils
|
|
from core.domain.optimization.Optimization import Optimization
|
|
from core.domain.optimization.OptimizationPoint import OptimizationPoint
|
|
from core.domain.optimization.OptimizationResult import OptimizationResult
|
|
from core.domain.optimization.OptimizationRoute import OptimizationRoute
|
|
from core.repos.OptimizationResultRepo import OptimizationResultRepo
|
|
from core.types.Id import Id
|
|
|
|
|
|
@dataclass
class OptimizationResultSqlRepo(OptimizationResultRepo):
    """SQL-backed implementation of ``OptimizationResultRepo``.

    Every method opens its own short-lived ``Session`` on ``db``. Rows of the
    ``optimization_result`` table are converted to and from
    ``OptimizationResult`` through the nested ``Table`` model.
    """

    # Engine each repository call opens a session on.
    db: Engine

    class Table(SQLModel, table=True):
        """Row model for the ``optimization_result`` table."""

        __tablename__ = "optimization_result"

        id: str = Field(primary_key=True)
        optimization_id: str = Field(foreign_key="optimization.id")
        # BLOB columns holding JSON text encoded to bytes (see toRow/toDomain).
        # NOTE(review): annotated ``str`` although toDomain calls ``.decode()``
        # on the loaded value, so it is presumably ``bytes`` at runtime --
        # confirm before tightening the annotations.
        routes: str = Field(sa_type=BLOB)
        unvisited: str = Field(sa_type=BLOB)
        created_at: int
        info: str
        authorized_by_user_id: str
        # NOTE(review): toDomain/toRow handle ``parent`` (and ``unvisited``)
        # being None; the non-Optional annotation is kept as in the original.
        parent: str

        def toDomain(self) -> OptimizationResult:
            """Build the domain ``OptimizationResult`` described by this row.

            ``routes`` and ``unvisited`` are decoded from UTF-8 JSON; a missing
            ``unvisited`` blob yields an empty list and a missing ``parent``
            yields ``None``.
            """
            routes = [
                OptimizationRoute.fromJson(**item)
                for item in json.loads(self.routes.decode('utf-8'))
            ]
            if self.unvisited is not None:
                unvisited = [
                    OptimizationPoint.fromJson(**item)
                    for item in json.loads(self.unvisited.decode('utf-8'))
                ]
            else:
                unvisited = []

            return OptimizationResult(
                optimizationId=Id(value=uuid.UUID(self.optimization_id)),
                routes=routes,
                unvisited=unvisited,
                createdAt=self.created_at,
                info=self.info,
                authorizedByUserId=self.authorized_by_user_id,
                parent=Id(value=uuid.UUID(self.parent)) if self.parent is not None else None,
                id=Id(value=uuid.UUID(self.id)),
            )

        @classmethod
        def toRow(cls, obj: OptimizationResult) -> Self:
            """Build a table row from a domain ``OptimizationResult``.

            The JSON blobs are encoded as UTF-8 so they round-trip with the
            UTF-8 decode in ``toDomain``. The previous ``'ascii'`` codec would
            raise ``UnicodeEncodeError`` on any non-ASCII character in the
            serialized JSON; for pure-ASCII payloads the bytes are unchanged.
            """
            return cls(
                optimization_id=obj.optimizationId.value,
                routes=Utils.json_dumps(obj.routes).encode('utf-8'),
                unvisited=(
                    Utils.json_dumps(obj.unvisited).encode('utf-8')
                    if obj.unvisited is not None
                    else None
                ),
                created_at=obj.createdAt,
                info=obj.info,
                authorized_by_user_id=obj.authorizedByUserId,
                parent=obj.parent.value if obj.parent is not None else None,
                id=obj.id.value,
            )

    @override
    def getAll(self) -> list[OptimizationResult]:
        """Return every stored optimization result as a domain object."""
        with Session(self.db) as conn:
            query = select(self.Table)
            return [row.toDomain() for row in conn.exec(query).all()]

    @override
    def get(self, id: Id[OptimizationResult]) -> Optional[OptimizationResult]:
        """Return the result whose primary key is ``id``, or ``None`` if absent."""
        with Session(self.db) as conn:
            query = select(self.Table).filter_by(id=id.value)
            row = conn.exec(query).one_or_none()
            return row.toDomain() if row is not None else None

    @override
    def getAllByOptimizationId(self, optimizationId: Id[Optimization]) -> list[OptimizationResult]:
        """Return all results whose ``optimization_id`` matches ``optimizationId``."""
        with Session(self.db) as conn:
            query = select(self.Table).filter_by(optimization_id=optimizationId.value)
            return [row.toDomain() for row in conn.exec(query).all()]

    @override
    @dbRetry
    def post(self, optimizationResult: OptimizationResult) -> OptimizationResult:
        """Merge ``optimizationResult`` into the table, commit, and return it.

        ``Session.merge`` is used, so a row with the same primary key is
        updated rather than duplicated.
        """
        with Session(self.db) as conn:
            conn.merge(self.Table.toRow(optimizationResult))
            conn.commit()
            return optimizationResult

    @override
    def getLatestByOptimizationId(self, optimizationId: Id[Optimization]) -> Optional[OptimizationResult]:
        """Return the result with the greatest ``created_at`` for the optimization.

        Returns ``None`` when the optimization has no stored results.
        """
        with Session(self.db) as conn:
            # Filter first, then order and limit; the generated statement is
            # the same, but the chain now reads in SQL clause order.
            query = (
                select(self.Table)
                .filter_by(optimization_id=optimizationId.value)
                .order_by(self.Table.created_at.desc())
                .limit(1)
            )
            row = conn.exec(query).one_or_none()
            return row.toDomain() if row is not None else None

    @override
    def getAllIds(self) -> list[Id[OptimizationResult]]:
        """Return ids of results whose optimization is not in state TEST or DELETED."""
        # Plain string: the statement has no interpolated values.
        query = text("""
            select optimization_result.id from optimization_result
            join optimization o on o.id = optimization_result.optimization_id
            where state not in ('TEST', 'DELETED')
        """)

        with Session(self.db) as conn:
            results = conn.exec(query).all()
            # NOTE(review): the id is passed through exactly as the driver
            # returns it, whereas toDomain wraps ids in uuid.UUID -- confirm
            # whether Id normalizes its value.
            return [Id(value=row[0]) for row in results]
|