I am trying to implement a QAbstractProxyModel
that maps an SqlTableModel
to a tree-like data structure. The table has a column called parent_id
, whose value is added to the createIndex
call as a third argument. The question is similar to this user's post, only that I am working in Python not in C++.
The TreeView loads correctly:
But when I try to expand an item, the application crashes. Debugging tells me that there seems to be an infinite loop of index
, rowCount
and mapToSource
being called.
I am at the end of my wits. Do you have any ideas? See the MWE below.
from __future__ import annotations
from PySide6.QtWidgets import QGridLayout
from PySide6.QtWidgets import QTreeView
from PySide6.QtWidgets import QApplication
from PySide6.QtWidgets import QMainWindow
from PySide6.QtWidgets import QWidget
from PySide6.QtCore import QModelIndex
from PySide6.QtCore import Qt
from PySide6.QtCore import Slot
from PySide6.QtCore import QAbstractProxyModel
from PySide6.QtSql import QSqlDatabase
from PySide6.QtSql import QSqlQuery
from PySide6.QtSql import QSqlTableModel
class CustomTreeModel(QAbstractProxyModel):
def __init__(self, database: str, parent: QWidget = None):
QAbstractProxyModel.__init__(self, parent)
sourceModel = QSqlTableModel(parent, database)
sourceModel.setTable('test')
sourceModel.select()
self.setSourceModel(sourceModel)
def flags(self, proxyIndex: QModelIndex) -> Qt.ItemFlags:
return Qt.ItemIsEnabled | Qt.ItemIsEditable
def data(self, proxyIndex: QModelIndex, role: int):
print("data")
if proxyIndex.isValid:
sourceIndex = self.mapToSource(proxyIndex)
return sourceIndex.data(role)
return None
def index(
self,
row: int,
column: int,
parentIndex: QModelIndex
) -> QModelIndex:
print("index")
if row < 0 and column < 0:
return QModelIndex()
parentId = parentIndex.internalPointer()
return self.createIndex(row, column, parentId)
def mapFromSource(self, sourceIndex: QModelIndex) -> QModelIndex:
print("mapFromSource")
if self.isRootItem(sourceIndex):
return QModelIndex()
if sourceIndex.column() == 0:
sourceId = sourceIndex.data()
else:
sourceId = sourceIndex.siblingAtColumn(0).data()
parentId = self.getParentId(sourceId)
childIds = self.getChildIds(parentId)
row = childIds.index(sourceId)
column = sourceIndex.column()
proxyIndex = self.createIndex(row, column, parentId)
return proxyIndex
def mapToSource(self, proxyIndex: QModelIndex) -> QModelIndex:
print("mapToSource")
if self.isRootItem(proxyIndex):
return QModelIndex()
parentId = proxyIndex.internalPointer()
childIds = self.getChildIds(parentId)
rowId = childIds[proxyIndex.row()]
rowIds = self.getAllIds()
sourceRow = rowIds.index(rowId)
sourceColumn = proxyIndex.column()
sourceIndex = self.sourceModel().index(sourceRow, sourceColumn)
return sourceIndex
def rowCount(self, parentIndex: QModelIndex) -> int:
print("rowCount")
if parentIndex.column() > 0:
return 0
parentId = parentIndex.internalPointer()
childIds = self.getChildIds(parentId)
return len(childIds)
def columnCount(self, parentIndex: QModelIndex) -> int:
print("columnCount")
if parentIndex.column() > 0:
return 0
numColumns = self.sourceModel().columnCount(parentIndex)
return numColumns
def parent(self, childIndex: QModelIndex) -> QModelIndex:
print("parent")
if childIndex.column() > 0:
return QModelIndex()
sourceIndex = self.mapToSource(childIndex)
childId = sourceIndex.siblingAtColumn(0).data()
parentId = self.getParentId(childId)
if not parentId:
return QModelIndex()
parentParentId = self.getParentId(parentId)
parentIds = self.getChildIds(parentParentId)
parentRow = parentIds.index(parentId)
parentIndex = self.createIndex(parentRow, 0, parentId)
return parentIndex
def getParentId(self, childId: str) -> str | None:
table = self.sourceModel().tableName()
query = QSqlQuery()
query.prepare(f"""
SELECT parent_id
FROM {table}
WHERE id=?
""")
query.addBindValue(childId)
query.exec_()
if query.first():
parentId = query.value(0)
return parentId if parentId else None
return None
def hasChildren(self, parentIndex: QModelIndex) -> bool:
if parentIndex.column() > 0:
return False
parentId = parentIndex.internalPointer()
childIds = self.getChildIds(parentId)
return len(childIds) > 0
def getAllIds(self) -> list[str]:
table = self.sourceModel().tableName()
query = QSqlQuery()
query.prepare(f"""
SELECT id
FROM {table}
""")
query.exec_()
ids = []
while query.next():
ids.append(query.value(0))
return ids
def getChildIds(self, parentId: str | None) -> list[str]:
table = self.sourceModel().tableName()
query = QSqlQuery()
if not parentId or parentId == '':
query.prepare(f"""
SELECT id
FROM {table}
WHERE parent_id IS NULL OR parent_id=''
""")
else:
query.prepare(f"""
SELECT id
FROM {table}
WHERE parent_id=?""")
query.addBindValue(parentId)
query.exec_()
childIds = []
while query.next():
childIds.append(query.value(0))
return childIds
def isRootItem(self, index: QModelIndex):
return index.row() == -1 and index.column() == -1
class CustomTreeWidget(QWidget):
def __init__(self, parent: QWidget = None):
QWidget.__init__(self, parent)
self.model: CustomTreeModel
self.view = QTreeView(self)
layout = QGridLayout(self)
layout.addWidget(self.view)
self.setLayout(layout)
@Slot()
def setDatabase(self):
database = QSqlDatabase.database()
model = CustomTreeModel(database, self)
self.view.setModel(model)
self.model = model
def initTestDatabase():
query = QSqlQuery()
query.prepare("""
CREATE TABLE test (
"id" TEXT,
"text" TEXT,
"parent_id" TEXT,
PRIMARY KEY("id")
);
""")
query.exec_()
query = QSqlQuery()
query.prepare("""
INSERT INTO test (
id, text, parent_id)
VALUES
(?, ?, ?),
(?, ?, ?),
(?, ?, ?),
(?, ?, ?);
""")
query.addBindValue("ID101")
query.addBindValue("Text")
query.addBindValue(None)
query.addBindValue("ID102")
query.addBindValue("Text")
query.addBindValue("ID101")
query.addBindValue("ID103")
query.addBindValue("Text")
query.addBindValue("ID101")
query.addBindValue("ID104")
query.addBindValue("Text")
query.addBindValue(None)
query.exec_()
if __name__ == "__main__":
projectDb = QSqlDatabase.addDatabase("QSQLITE")
projectDb.setDatabaseName(":memory:")
projectDb.open()
initTestDatabase()
app = QApplication()
mainWindow = QMainWindow()
widget = CustomTreeWidget(mainWindow)
widget.setDatabase()
mainWindow.setCentralWidget(widget)
mainWindow.showMaximized()
app.exec_()