Support providing a list of activity types when counting unread notifications (#3141)

* Support providing a list of activity types when counting unread notifications

* Minor cleanup

* Test added

* Smaller fix

* Test small fix

* uint64
This commit is contained in:
Alexander 2023-01-30 20:43:13 +01:00 committed by GitHub
parent 7e1a894ab8
commit 2fba8c4591
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 101 additions and 11 deletions

View File

@ -794,27 +794,43 @@ func (db sqlitePersistence) MarkActivityCenterNotificationsUnread(ids []types.He
}
func buildActivityCenterNotificationsCountQuery(isAccepted bool) string {
func (db sqlitePersistence) buildActivityCenterNotificationsCountQuery(isAccepted bool, activityTypes []ActivityCenterType) *sql.Row {
var args []interface{}
var acceptedWhere string
if !isAccepted {
acceptedWhere = `AND NOT accepted`
}
return fmt.Sprintf(`SELECT COUNT(1) FROM activity_center_notifications WHERE NOT read AND NOT dismissed %s`, acceptedWhere)
var inTypeWhere string
if len(activityTypes) != 0 {
inVector := strings.Repeat("?, ", len(activityTypes)-1) + "?"
inTypeWhere = fmt.Sprintf(" AND notification_type IN (%s)", inVector)
for _, activityCenterType := range activityTypes {
args = append(args, activityCenterType)
}
}
query := fmt.Sprintf(`
SELECT COUNT(1)
FROM activity_center_notifications
WHERE NOT read AND NOT dismissed
%s
%s
`, acceptedWhere, inTypeWhere)
return db.db.QueryRow(query, args...)
}
func (db sqlitePersistence) UnreadActivityCenterNotificationsCount() (uint64, error) {
var count uint64
query := buildActivityCenterNotificationsCountQuery(false)
err := db.db.QueryRow(query).Scan(&count)
err := db.buildActivityCenterNotificationsCountQuery(false, []ActivityCenterType{}).Scan(&count)
return count, err
}
func (db sqlitePersistence) UnreadAndAcceptedActivityCenterNotificationsCount() (uint64, error) {
func (db sqlitePersistence) UnreadAndAcceptedActivityCenterNotificationsCount(activityTypes []ActivityCenterType) (uint64, error) {
var count uint64
query := buildActivityCenterNotificationsCountQuery(true)
err := db.db.QueryRow(query).Scan(&count)
err := db.buildActivityCenterNotificationsCountQuery(true, activityTypes).Scan(&count)
return count, err
}

View File

@ -16,8 +16,8 @@ func (m *Messenger) UnreadActivityCenterNotificationsCount() (uint64, error) {
return m.persistence.UnreadActivityCenterNotificationsCount()
}
func (m *Messenger) UnreadAndAcceptedActivityCenterNotificationsCount() (uint64, error) {
return m.persistence.UnreadAndAcceptedActivityCenterNotificationsCount()
func (m *Messenger) UnreadAndAcceptedActivityCenterNotificationsCount(activityTypes []ActivityCenterType) (uint64, error) {
return m.persistence.UnreadAndAcceptedActivityCenterNotificationsCount(activityTypes)
}
func toHexBytes(b [][]byte) []types.HexBytes {

View File

@ -1507,6 +1507,80 @@ func TestActivityCenterReadUnread(t *testing.T) {
require.Equal(t, nID2, notifications[0].ID)
}
func TestUnreadAndAcceptedActivityCenterNotificationsCount(t *testing.T) {
db, err := openTestDB()
require.NoError(t, err)
p := newSQLitePersistence(db)
chat := CreatePublicChat("test-chat", &testTimeSource{})
message := &common.Message{}
message.Text = "sample text"
chat.LastMessage = message
err = p.SaveChat(*chat)
require.NoError(t, err)
allNotifications := []*ActivityCenterNotification{
{
ID: types.HexBytes("1"),
Type: ActivityCenterNotificationTypeMention,
ChatID: chat.ID,
Timestamp: 1,
},
{
ID: types.HexBytes("2"),
Type: ActivityCenterNotificationTypeNewOneToOne,
ChatID: chat.ID,
Timestamp: 1,
},
{
ID: types.HexBytes("3"),
Type: ActivityCenterNotificationTypeMention,
ChatID: chat.ID,
Timestamp: 1,
},
{
ID: types.HexBytes("4"),
Type: ActivityCenterNotificationTypeMention,
ChatID: chat.ID,
Timestamp: 1,
},
{
ID: types.HexBytes("5"),
Type: ActivityCenterNotificationTypeContactRequest,
ChatID: chat.ID,
Timestamp: 1,
},
}
for _, notification := range allNotifications {
err = p.SaveActivityCenterNotification(notification)
require.NoError(t, err)
}
notificationCount, err := p.UnreadAndAcceptedActivityCenterNotificationsCount(
[]ActivityCenterType{},
)
require.NoError(t, err)
require.Equal(t, notificationCount, uint64(5))
notificationCount, err = p.UnreadAndAcceptedActivityCenterNotificationsCount(
[]ActivityCenterType{
ActivityCenterNotificationTypeNewOneToOne,
},
)
require.NoError(t, err)
require.Equal(t, notificationCount, uint64(1))
notificationCount, err = p.UnreadAndAcceptedActivityCenterNotificationsCount(
[]ActivityCenterType{
ActivityCenterNotificationTypeNewOneToOne,
ActivityCenterNotificationTypeContactRequest,
},
)
require.NoError(t, err)
require.Equal(t, notificationCount, uint64(2))
}
func TestActivityCenterReadUnreadFilterByTypes(t *testing.T) {
db, err := openTestDB()
require.NoError(t, err)

View File

@ -1082,8 +1082,8 @@ func (api *PublicAPI) UnreadActivityCenterNotificationsCount() (uint64, error) {
return api.service.messenger.UnreadActivityCenterNotificationsCount()
}
func (api *PublicAPI) UnreadAndAcceptedActivityCenterNotificationsCount() (uint64, error) {
return api.service.messenger.UnreadAndAcceptedActivityCenterNotificationsCount()
func (api *PublicAPI) UnreadAndAcceptedActivityCenterNotificationsCount(activityTypes []protocol.ActivityCenterType) (uint64, error) {
return api.service.messenger.UnreadAndAcceptedActivityCenterNotificationsCount(activityTypes)
}
func (api *PublicAPI) MarkAllActivityCenterNotificationsRead(ctx context.Context) error {