matterbridge/vendor/go.mau.fi/whatsmeow/download.go

214 lines
7.0 KiB
Go

// Copyright (c) 2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package whatsmeow
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"fmt"
"io"
"net/http"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
waProto "go.mau.fi/whatsmeow/binary/proto"
"go.mau.fi/whatsmeow/util/cbcutil"
"go.mau.fi/whatsmeow/util/hkdfutil"
)
// MediaType identifies the kind of file uploaded to WhatsApp.
//
// The string value is not just a label: getMediaKeys passes it to
// hkdfutil.SHA256 when expanding a media key, so each value must match the
// string WhatsApp uses for that kind of media byte-for-byte.
type MediaType string

// Known media types. Their values feed key derivation and must not change.
const (
	MediaImage    MediaType = "WhatsApp Image Keys"
	MediaVideo    MediaType = "WhatsApp Video Keys"
	MediaAudio    MediaType = "WhatsApp Audio Keys"
	MediaDocument MediaType = "WhatsApp Document Keys"
	MediaHistory  MediaType = "WhatsApp History Keys"
	MediaAppState MediaType = "WhatsApp App State Keys"
)
// DownloadableMessage is implemented by protobuf messages that reference an
// encrypted attachment which can be fetched with Client.Download.
type DownloadableMessage interface {
	proto.Message

	GetDirectPath() string
	GetMediaKey() []byte
	GetFileSha256() []byte
	GetFileEncSha256() []byte
}

// Compile-time checks that every message type meant to be downloadable
// satisfies DownloadableMessage.
var (
	_ DownloadableMessage = (*waProto.ImageMessage)(nil)
	_ DownloadableMessage = (*waProto.AudioMessage)(nil)
	_ DownloadableMessage = (*waProto.VideoMessage)(nil)
	_ DownloadableMessage = (*waProto.DocumentMessage)(nil)
	_ DownloadableMessage = (*waProto.StickerMessage)(nil)
	_ DownloadableMessage = (*waProto.HistorySyncNotification)(nil)
	_ DownloadableMessage = (*waProto.ExternalBlobReference)(nil)
)

// Optional capabilities a DownloadableMessage may have; Download and getSize
// detect them with type assertions.
type (
	// downloadableMessageWithLength declares the expected plaintext length
	// through GetFileLength.
	downloadableMessageWithLength interface {
		DownloadableMessage
		GetFileLength() uint64
	}

	// downloadableMessageWithSizeBytes declares the expected plaintext length
	// through GetFileSizeBytes.
	downloadableMessageWithSizeBytes interface {
		DownloadableMessage
		GetFileSizeBytes() uint64
	}

	// downloadableMessageWithURL exposes a full download URL; Download uses it
	// in preference to the direct path whenever it is non-empty.
	downloadableMessageWithURL interface {
		DownloadableMessage
		GetUrl() string
	}
)
var (
	// classToMediaType maps a protobuf message name to the MediaType whose
	// string seeds key derivation for it. Stickers share the image keys.
	classToMediaType = map[protoreflect.Name]MediaType{
		"AudioMessage":            MediaAudio,
		"DocumentMessage":         MediaDocument,
		"ExternalBlobReference":   MediaAppState,
		"HistorySyncNotification": MediaHistory,
		"ImageMessage":            MediaImage,
		"StickerMessage":          MediaImage,
		"VideoMessage":            MediaVideo,
	}

	// mediaTypeToMMSType maps a MediaType to the value sent in the mms-type
	// query parameter when downloading by direct path.
	mediaTypeToMMSType = map[MediaType]string{
		MediaAppState: "md-app-state",
		MediaAudio:    "audio",
		MediaDocument: "document",
		MediaHistory:  "md-msg-hist",
		MediaImage:    "image",
		MediaVideo:    "video",
	}
)
// DownloadAny downloads the first attachment present in msg, checking the
// image, audio, video, document and sticker parts in that order.
//
// It returns ErrNothingDownloadableFound when none of those parts is set.
func (cli *Client) DownloadAny(msg *waProto.Message) (data []byte, err error) {
	// Each part is nil-checked as its concrete pointer type before it is
	// converted to DownloadableMessage. Collecting the getters' results in a
	// []DownloadableMessage first would wrap absent parts in non-nil interface
	// values holding typed nil pointers, so the image slot would always
	// "match" and the other parts would never be reached.
	if img := msg.GetImageMessage(); img != nil {
		return cli.Download(img)
	}
	if audio := msg.GetAudioMessage(); audio != nil {
		return cli.Download(audio)
	}
	if video := msg.GetVideoMessage(); video != nil {
		return cli.Download(video)
	}
	if doc := msg.GetDocumentMessage(); doc != nil {
		return cli.Download(doc)
	}
	if sticker := msg.GetStickerMessage(); sticker != nil {
		return cli.Download(sticker)
	}
	return nil, ErrNothingDownloadableFound
}
// getSize returns the plaintext size declared by msg, or -1 when msg's type
// has no size getter. GetFileLength takes precedence over GetFileSizeBytes.
func getSize(msg DownloadableMessage) int {
	if withLength, ok := msg.(downloadableMessageWithLength); ok {
		return int(withLength.GetFileLength())
	}
	if withSizeBytes, ok := msg.(downloadableMessageWithSizeBytes); ok {
		return int(withSizeBytes.GetFileSizeBytes())
	}
	return -1
}
// Download fetches and decrypts the attachment referenced by msg.
//
// The MediaType used for key derivation is looked up from the protobuf
// message name; unknown message types yield ErrUnknownMediaType. A non-empty
// URL on the message is downloaded directly. Otherwise the direct path is
// tried against the client's media hosts. If neither is present,
// ErrNoURLPresent is returned.
func (cli *Client) Download(msg DownloadableMessage) (data []byte, err error) {
	className := msg.ProtoReflect().Descriptor().Name()
	mediaType, known := classToMediaType[className]
	if !known {
		return nil, fmt.Errorf("%w '%s'", ErrUnknownMediaType, string(className))
	}

	size := getSize(msg)
	if withURL, hasURL := msg.(downloadableMessageWithURL); hasURL && len(withURL.GetUrl()) > 0 {
		return downloadAndDecrypt(withURL.GetUrl(), msg.GetMediaKey(), mediaType, size, msg.GetFileEncSha256(), msg.GetFileSha256())
	}
	if directPath := msg.GetDirectPath(); len(directPath) > 0 {
		return cli.downloadMediaWithPath(directPath, msg.GetFileEncSha256(), msg.GetFileSha256(), msg.GetMediaKey(), size, mediaType, mediaTypeToMMSType[mediaType])
	}
	return nil, ErrNoURLPresent
}
// downloadMediaWithPath downloads and decrypts media identified by directPath,
// trying each host of the client's media connection in order.
//
// It returns as soon as one host yields a file that passes every check in
// downloadAndDecrypt. Each failure except the last is logged before moving on
// to the next host. The last host's error is returned wrapped. If the media
// connection lists no hosts at all, an explicit error is returned rather than
// an empty result.
func (cli *Client) downloadMediaWithPath(directPath string, encFileHash, fileHash, mediaKey []byte, fileLength int, mediaType MediaType, mmsType string) (data []byte, err error) {
	err = cli.refreshMediaConn(false)
	if err != nil {
		return nil, fmt.Errorf("failed to refresh media connections: %w", err)
	}
	hosts := cli.mediaConn.Hosts
	if len(hosts) == 0 {
		return nil, fmt.Errorf("failed to download media: no media hosts available")
	}
	// The hash parameter is the same for every host, so encode it once.
	encodedHash := base64.URLEncoding.EncodeToString(encFileHash)
	for i, host := range hosts {
		mediaURL := fmt.Sprintf("https://%s%s&hash=%s&mms-type=%s&__wa-mms=", host.Hostname, directPath, encodedHash, mmsType)
		data, err = downloadAndDecrypt(mediaURL, mediaKey, mediaType, fileLength, encFileHash, fileHash)
		if err == nil {
			// Stop at the first success; later hosts must not be queried
			// (or allowed to overwrite this result with their own error).
			return data, nil
		}
		// TODO there are probably some errors that shouldn't retry
		if i == len(hosts)-1 {
			return nil, fmt.Errorf("failed to download media from last host: %w", err)
		}
		cli.Log.Warnf("Failed to download media: %s, trying with next host...", err)
	}
	// Unreachable: hosts is non-empty and every iteration either returns on
	// success or, on the last host, returns the error.
	return nil, err
}
// downloadAndDecrypt fetches the encrypted blob at url and returns its
// plaintext.
//
// It works through these steps:
//   - derive the IV, cipher key and MAC key from mediaKey and appInfo;
//   - download the blob, which is checked against fileEncSha256 when that is
//     32 bytes long;
//   - verify the trailing MAC;
//   - decrypt with cbcutil;
//   - when the expectations are known, check the plaintext length
//     (fileLength >= 0) and SHA-256 (fileSha256 of 32 bytes).
//
// On a length or plaintext-hash mismatch the decrypted bytes are still
// returned together with the error.
func downloadAndDecrypt(url string, mediaKey []byte, appInfo MediaType, fileLength int, fileEncSha256, fileSha256 []byte) (data []byte, err error) {
	iv, cipherKey, macKey, _ := getMediaKeys(mediaKey, appInfo)

	ciphertext, mac, err := downloadEncryptedMedia(url, fileEncSha256)
	if err != nil {
		return nil, err
	}
	if err = validateMedia(iv, ciphertext, macKey, mac); err != nil {
		return nil, err
	}

	data, err = cbcutil.Decrypt(cipherKey, iv, ciphertext)
	if err != nil {
		return data, fmt.Errorf("failed to decrypt file: %w", err)
	}

	switch {
	case fileLength >= 0 && len(data) != fileLength:
		return data, fmt.Errorf("%w: expected %d, got %d", ErrFileLengthMismatch, fileLength, len(data))
	case len(fileSha256) == 32 && sha256.Sum256(data) != *(*[32]byte)(fileSha256):
		return data, ErrInvalidMediaSHA256
	}
	return data, nil
}
// getMediaKeys expands mediaKey to 112 bytes with hkdfutil.SHA256. It passes
// the MediaType string as the third argument (presumably the HKDF info
// parameter) and splits the output as follows:
//   - bytes 0-15: IV
//   - bytes 16-47: cipher key
//   - bytes 48-79: MAC key
//   - bytes 80-111: refKey
//
// The returned slices alias one backing array.
func getMediaKeys(mediaKey []byte, appInfo MediaType) (iv, cipherKey, macKey, refKey []byte) {
	expanded := hkdfutil.SHA256(mediaKey, nil, []byte(appInfo), 112)
	iv = expanded[:16]
	cipherKey = expanded[16:48]
	macKey = expanded[48:80]
	refKey = expanded[80:]
	return iv, cipherKey, macKey, refKey
}
// downloadEncryptedMedia GETs url and splits the response body into the
// encrypted file and its trailing 10-byte MAC.
//
// Errors:
//   - 404 and 410 responses map to ErrMediaDownloadFailedWith404 and
//     ErrMediaDownloadFailedWith410; any other non-200 status is reported
//     with its code.
//   - A body of 10 bytes or fewer yields ErrTooShortFile.
//   - When checksum is 32 bytes long, it is compared with the SHA-256 of the
//     whole body. A mismatch returns ErrInvalidMediaEncSHA256 alongside the
//     split file and mac.
//
// NOTE(review): http.Get uses http.DefaultClient, which has no timeout, so a
// stalled media host blocks the caller indefinitely.
func downloadEncryptedMedia(url string, checksum []byte) (file, mac []byte, err error) {
	resp, err := http.Get(url)
	if err != nil {
		return nil, nil, err
	}
	defer resp.Body.Close()

	switch resp.StatusCode {
	case http.StatusOK:
		// Success: read the body below.
	case http.StatusNotFound:
		return nil, nil, ErrMediaDownloadFailedWith404
	case http.StatusGone:
		return nil, nil, ErrMediaDownloadFailedWith410
	default:
		return nil, nil, fmt.Errorf("download failed with status code %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, nil, err
	}
	const macLength = 10
	if len(body) <= macLength {
		return nil, nil, ErrTooShortFile
	}

	split := len(body) - macLength
	file, mac = body[:split], body[split:]
	if len(checksum) == 32 && sha256.Sum256(body) != *(*[32]byte)(checksum) {
		err = ErrInvalidMediaEncSHA256
	}
	return file, mac, err
}
// validateMedia computes HMAC-SHA256 over iv followed by file under macKey.
// It compares the first 10 bytes of that digest with mac in constant time
// (hmac.Equal). It returns nil on a match and ErrInvalidMediaHMAC otherwise,
// including when mac is not exactly 10 bytes.
func validateMedia(iv, file, macKey, mac []byte) error {
	signer := hmac.New(sha256.New, macKey)
	signer.Write(iv)
	signer.Write(file)
	expected := signer.Sum(nil)[:10]
	if hmac.Equal(expected, mac) {
		return nil
	}
	return ErrInvalidMediaHMAC
}