Fix Windows ACL flakiness issue (Windows error 1336). (#221)

* Add getHomePath(), getConfigPath() and getCachePath() implementations.
Fix ACL flakiness issue.
Add tests.

* Add getTempPath().
Normalize path endings for all xxPath() functions.

* Fix 2.0/devel compilation errors.
This commit is contained in:
Eugene Kabanov 2024-06-17 10:04:14 +03:00 committed by GitHub
parent 28743363ff
commit bb086e69da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 286 additions and 63 deletions

View File

@ -55,6 +55,13 @@ when defined(windows):
FileBasicInfoClass = 0'u32 FileBasicInfoClass = 0'u32
CSIDL_APPDATA = 0x001a'u32
# <user name>\Application Data
CSIDL_PROFILE = 0x0028'u32
# <user name>
CSIDL_LOCAL_APPDATA = 0x001c'u32
# <user name>\Local Settings\Applicaiton Data (non roaming)
type type
IoErrorCode* = distinct uint32 IoErrorCode* = distinct uint32
IoHandle* = distinct uint IoHandle* = distinct uint
@ -184,6 +191,12 @@ when defined(windows):
nNumberOfBytesToLockLow, nNumberOfBytesToLockHigh: uint32, nNumberOfBytesToLockLow, nNumberOfBytesToLockHigh: uint32,
lpOverlapped: pointer): uint32 {. lpOverlapped: pointer): uint32 {.
importc: "UnlockFileEx", dynlib: "kernel32", stdcall, sideEffect.} importc: "UnlockFileEx", dynlib: "kernel32", stdcall, sideEffect.}
proc shGetSpecialFolderPathW(hwnd: uint, pszPath: WideCString, csidl: uint32,
fCreate: uint32):uint32 {.
importc: "SHGetSpecialFolderPathW", dynlib: "shell32", stdcall,
sideEffect.}
proc getTempPathW(nBufferLength: uint32, lpBuffer: WideCString): uint32 {.
importc: "GetTempPathW", dynlib: "kernel32", stdcall, sideEffect.}
const const
NO_ERROR = IoErrorCode(0) NO_ERROR = IoErrorCode(0)
@ -245,6 +258,8 @@ elif defined(posix):
var errno {.importc, header: "<errno.h>".}: cint var errno {.importc, header: "<errno.h>".}: cint
proc c_getenv(env: cstring): cstring {.
importc: "getenv", header: "<stdlib.h>", sideEffect.}
proc write(a1: cint, a2: pointer, a3: csize_t): int {. proc write(a1: cint, a2: pointer, a3: csize_t): int {.
importc, header: "<unistd.h>", sideEffect.} importc, header: "<unistd.h>", sideEffect.}
proc read(a1: cint, a2: pointer, a3: csize_t): int {. proc read(a1: cint, a2: pointer, a3: csize_t): int {.
@ -1547,3 +1562,97 @@ proc unlockFile*(lock: IoLockHandle): IoResult[void] =
err(res.error()) err(res.error())
else: else:
ok() ok()
when defined(windows):
proc getSpecialFolderPath(code: uint32): IoResult[string] =
var path: array[MAX_PATH, Utf16Char]
let
wpath = cast[WideCString](addr path[0])
res = shGetSpecialFolderPathW(0'u, cast[WideCString](addr path[0]),
code, 0'u32)
if res == 0'u32:
err(ioLastError())
else:
var strpath = `$`(wpath, len(path))
normPathEnd(strpath, true)
ok(strpath)
proc getHomePath*(): IoResult[string] =
## Returns path to user's home directory.
when defined(windows):
getSpecialFolderPath(CSIDL_PROFILE)
else:
let res = c_getenv("HOME")
var path = if isNil(res): "" else: $res
normPathEnd(path, true)
ok(path)
proc getConfigPath*(): IoResult[string] =
## Returns path to application's configuration directory.
when defined(windows):
getSpecialFolderPath(CSIDL_APPDATA)
else:
let
subpath = ".config"
xres = c_getenv("XDG_CONFIG_HOME")
var
path =
if isNil(xres):
let hres = c_getenv("HOME")
if isNil(hres): subpath else: $hres & DirSep & subpath
else:
$xres
normPathEnd(path, true)
ok(path)
proc getCachePath*(): IoResult[string] =
## Returns path to application's cache directory.
when defined(windows):
getSpecialFolderPath(CSIDL_LOCAL_APPDATA)
else:
let subpath =
when defined(macos) or defined(macosx) or defined(osx):
"Library/Caches"
else:
".cache"
let
xres = c_getenv("XDG_CACHE_HOME")
var
path =
if isNil(xres):
let hres = c_getenv("HOME")
if isNil(hres): subpath else: $hres & DirSep & subpath
else:
$xres
normPathEnd(path, true)
ok(path)
proc getTempPath*(): IoResult[string] =
## Returns path to OS temporary directory.
when defined(windows):
var path: array[MAX_PATH + 1, Utf16Char]
let
wpath = cast[WideCString](addr path[0])
res = getTempPathW(uint32(MAX_PATH), wpath)
if res == 0'u32:
err(ioLastError())
else:
var strpath = `$`(wpath, len(path))
normPathEnd(strpath, true)
ok(strpath)
else:
for name in ["TMP", "TEMP", "TMPDIR", "TEMPDIR"]:
let res = c_getenv(cstring(name))
if not(isNil(res)) and isDir($res):
var path = $res
normPathEnd(path, true)
return ok(path)
var defaultDir =
when defined(android):
"/data/local/tmp"
else:
"/tmp"
if isDir(defaultDir):
normPathEnd(defaultDir, true)
return ok(defaultDir)
err(IoErrorCode(2)) # ENOENT

View File

@ -33,8 +33,11 @@ const
SECURITY_DESCRIPTOR_REVISION = 1'u32 SECURITY_DESCRIPTOR_REVISION = 1'u32
ACCESS_ALLOWED_ACE_TYPE = 0x00'u8 ACCESS_ALLOWED_ACE_TYPE = 0x00'u8
SE_DACL_PROTECTED = 0x1000'u16 SE_DACL_PROTECTED = 0x1000'u16
LPTR = 0x0040'u32
type type
LocalMemPtr = distinct pointer
ACL {.pure, final.} = object ACL {.pure, final.} = object
aclRevision: uint8 aclRevision: uint8
sbz1: uint8 sbz1: uint8
@ -45,11 +48,11 @@ type
PACL* = ptr ACL PACL* = ptr ACL
SID* = object SID* = object
data: seq[byte] data: LocalMemPtr
SD* = object SD* = object
sddata: seq[byte] sddata: LocalMemPtr
acldata: seq[byte] acldata: LocalMemPtr
SID_AND_ATTRIBUTES {.pure, final.} = object SID_AND_ATTRIBUTES {.pure, final.} = object
sid: pointer sid: pointer
@ -73,6 +76,8 @@ type
proc closeHandle(hobj: uint): int32 {. proc closeHandle(hobj: uint): int32 {.
importc: "CloseHandle", dynlib: "kernel32", stdcall, sideEffect.} importc: "CloseHandle", dynlib: "kernel32", stdcall, sideEffect.}
proc localAlloc(uFlags: uint32, ubytes: uint): pointer {.
importc: "LocalAlloc", stdcall, dynlib: "kernel32".}
proc localFree(p: pointer): uint {. proc localFree(p: pointer): uint {.
importc: "LocalFree", stdcall, dynlib: "kernel32".} importc: "LocalFree", stdcall, dynlib: "kernel32".}
proc getCurrentProcess(): uint {. proc getCurrentProcess(): uint {.
@ -128,33 +133,53 @@ proc setSecurityDescriptorControl(pSD: pointer, bitsOfInterest: uint16,
importc: "SetSecurityDescriptorControl", dynlib: "advapi32", stdcall, importc: "SetSecurityDescriptorControl", dynlib: "advapi32", stdcall,
sideEffect.} sideEffect.}
proc len*(sid: SID): int = len(sid.data) proc len*(sid: SID): int =
int(getLengthSid(cast[pointer](sid.data)))
proc free(mem: LocalMemPtr): uint =
localFree(cast[pointer](mem))
proc free*(sd: var SD) =
## Free memory occupied by security descriptor.
discard sd.sddata.free()
discard sd.acldata.free()
sd.sddata = LocalMemPtr(nil)
sd.acldata = LocalMemPtr(nil)
proc free*(sid: var SID) =
## Free memory occupied by security identifier.
discard sid.data.free()
sid.data = LocalMemPtr(nil)
proc getTokenInformation(token: uint, proc getTokenInformation(token: uint,
information: uint32): IoResult[seq[byte]] = information: uint32): IoResult[LocalMemPtr] =
var tlength: uint32 var
var buffer = newSeq[byte]() tlength: uint32 = 0'u32
localMem: pointer
while true: while true:
let res = let res =
if len(buffer) == 0: if tlength == 0'u32:
getTokenInformation(token, information, nil, 0, tlength) getTokenInformation(token, information, nil, 0, tlength)
else: else:
getTokenInformation(token, information, cast[pointer](addr buffer[0]), getTokenInformation(token, information, localMem, tlength, tlength)
uint32(len(buffer)), tlength)
if res != 0: if res != 0:
return ok(buffer) return ok(LocalMemPtr(localMem))
else: else:
let errCode = ioLastError() let errorCode = ioLastError()
if errCode == ERROR_INSUFFICIENT_BUFFER: if errorCode == ERROR_INSUFFICIENT_BUFFER:
when sizeof(int) == 8: when sizeof(int) == 8:
buffer.setLen(int(tlength)) localMem = localAlloc(LPTR, uint(tlength))
if isNil(localMem):
return err(ioLastError())
elif sizeof(int) == 4: elif sizeof(int) == 4:
if tlength > uint32(high(int)): if tlength > uint32(high(int)):
return err(errCode) return err(errorCode)
else: else:
buffer.setLen(int(tlength)) localMem = localAlloc(LPTR, uint(tlength))
if isNil(localMem):
return err(ioLastError())
else: else:
return err(errCode) return err(errorCode)
proc getCurrentUserSid*(): IoResult[SID] = proc getCurrentUserSid*(): IoResult[SID] =
## Returns current process user's security identifier (SID). ## Returns current process user's security identifier (SID).
@ -163,45 +188,66 @@ proc getCurrentUserSid*(): IoResult[SID] =
if ores == 0: if ores == 0:
err(ioLastError()) err(ioLastError())
else: else:
let tres = getTokenInformation(token, 1'u32) let localMem = getTokenInformation(token, 1'u32).valueOr:
if tres.isErr():
discard closeHandle(token) discard closeHandle(token)
err(tres.error) return err(error)
else: var utoken = cast[ptr TOKEN_USER](localMem)
var buffer = tres.get()
var utoken = cast[ptr TOKEN_USER](addr buffer[0])
let psid = utoken[].user.sid let psid = utoken[].user.sid
if isValidSid(psid) != 0: if isValidSid(psid) != 0:
var ssid = newSeq[byte](getLengthSid(psid)) let length = getLengthSid(psid)
if copySid(uint32(len(ssid)), addr ssid[0], psid) != 0: var ssid = localAlloc(LPTR, length)
if isNil(ssid):
return err(ioLastError())
if copySid(uint32(length), ssid, psid) != 0:
if closeHandle(token) != 0: if closeHandle(token) != 0:
ok(SID(data: ssid)) if free(localMem) != 0'u:
let errorCode = ioLastError()
discard localFree(ssid)
err(errorCode)
else: else:
err(ioLastError()) ok(SID(data: LocalMemPtr(ssid)))
else: else:
let errCode = ioLastError() let errorCode = ioLastError()
discard localFree(ssid)
err(errorCode)
else:
let errorCode = ioLastError()
discard closeHandle(token) discard closeHandle(token)
err(errCode) discard free(localMem)
discard localFree(ssid)
err(errorCode)
else: else:
let errCode = ioLastError() let errorCode = ioLastError()
discard closeHandle(token) discard closeHandle(token)
err(errCode) discard free(localMem)
err(errorCode)
template getAddr*(sid: SID): pointer = template getAddr*(sid: SID): pointer =
## Obtain Windows specific SID pointer. ## Obtain Windows specific SID pointer.
unsafeAddr sid.data[0] cast[pointer](sid.data)
proc createCurrentUserOnlyAcl(kind: SecDescriptorKind): IoResult[seq[byte]] = template getAddr*(mem: LocalMemPtr): pointer =
cast[pointer](mem)
proc createCurrentUserOnlyAcl(kind: SecDescriptorKind): IoResult[LocalMemPtr] =
let aceMask = FILE_ALL_ACCESS let aceMask = FILE_ALL_ACCESS
var userSid = ? getCurrentUserSid() var userSid = ? getCurrentUserSid()
let size = let size =
((sizeof(ACL) + sizeof(ACCESS_ALLOWED_ACE) + len(userSid)) + (uint32(sizeof(ACL) + sizeof(ACCESS_ALLOWED_ACE) + len(userSid)) +
(sizeof(uint32) - 1)) and 0xFFFF_FFFC uint32(sizeof(uint32) - 1)) and 0xFFFF_FFFC'u32
var buffer = newSeq[byte](size) var localMem = localAlloc(LPTR, uint(size))
var pdacl = cast[PACL](addr buffer[0]) if isNil(localMem):
let errorCode = ioLastError()
free(userSid)
return err(errorCode)
var pdacl = cast[PACL](localMem)
if initializeAcl(pdacl, uint32(size), ACL_REVISION) == 0: if initializeAcl(pdacl, uint32(size), ACL_REVISION) == 0:
err(ioLastError()) let errorCode = ioLastError()
discard localFree(localMem)
free(userSid)
err(errorCode)
else: else:
let aceFlags = let aceFlags =
if kind == Folder: if kind == Folder:
@ -210,9 +256,12 @@ proc createCurrentUserOnlyAcl(kind: SecDescriptorKind): IoResult[seq[byte]] =
0'u32 0'u32
if addAccessAllowedAceEx(pdacl, ACL_REVISION, aceFlags, if addAccessAllowedAceEx(pdacl, ACL_REVISION, aceFlags,
aceMask, userSid.getAddr()) == 0: aceMask, userSid.getAddr()) == 0:
err(ioLastError()) let errorCode = ioLastError()
discard localFree(localMem)
free userSid
err(errorCode)
else: else:
ok(buffer) ok(LocalMemPtr(localMem))
proc setCurrentUserOnlyAccess*(path: string): IoResult[void] = proc setCurrentUserOnlyAccess*(path: string): IoResult[void] =
## Set file or folder with path ``path`` to be accessed only by current ## Set file or folder with path ``path`` to be accessed only by current
@ -227,33 +276,50 @@ proc setCurrentUserOnlyAccess*(path: string): IoResult[void] =
else: else:
File File
var buffer = ? createCurrentUserOnlyAcl(descriptorKind) let
var pdacl = cast[PACL](addr buffer[0]) pacl = ? createCurrentUserOnlyAcl(descriptorKind)
pdacl = cast[PACL](pacl)
let dflags = DACL_SECURITY_INFORMATION or dflags = DACL_SECURITY_INFORMATION or
PROTECTED_DACL_SECURITY_INFORMATION PROTECTED_DACL_SECURITY_INFORMATION
let sres = setNamedSecurityInfo(newWideCString(path), SE_FILE_OBJECT, sres = setNamedSecurityInfo(newWideCString(path), SE_FILE_OBJECT,
dflags, nil, nil, pdacl, nil) dflags, nil, nil, pdacl, nil)
if free(pacl) != 0'u:
return err(ioLastError())
if sres != ERROR_SUCCESS: if sres != ERROR_SUCCESS:
err(IoErrorCode(sres)) err(IoErrorCode(sres))
else: else:
ok() ok()
proc createUserOnlySecurityDescriptor(kind: SecDescriptorKind): IoResult[SD] = proc createUserOnlySecurityDescriptor(kind: SecDescriptorKind): IoResult[SD] =
var dacl = ? createCurrentUserOnlyAcl(kind) let
var buffer = newSeq[byte](SECURITY_DESCRIPTOR_MIN_LENGTH) dacl = ? createCurrentUserOnlyAcl(kind)
if initializeSecurityDescriptor(addr buffer[0], localMem = localAlloc(LPTR, SECURITY_DESCRIPTOR_MIN_LENGTH)
SECURITY_DESCRIPTOR_REVISION) == 0:
err(ioLastError()) if isNil(localMem):
discard free(dacl)
return err(ioLastError())
if initializeSecurityDescriptor(localMem, SECURITY_DESCRIPTOR_REVISION) == 0:
let errorCode = ioLastError()
discard free(dacl)
discard localFree(localMem)
err(errorCode)
else: else:
var res = SD(sddata: buffer, acldata: dacl) var res = SD(sddata: cast[LocalMemPtr](localMem), acldata: dacl)
let bits = SE_DACL_PROTECTED let bits = SE_DACL_PROTECTED
if setSecurityDescriptorControl(addr res.sddata[0], bits, bits) == 0: if setSecurityDescriptorControl(localMem, bits, bits) == 0:
err(ioLastError()) let errorCode = ioLastError()
discard free(dacl)
discard localFree(localMem)
err(errorCode)
else: else:
if setSecurityDescriptorDacl(addr res.sddata[0], 1'i32, if setSecurityDescriptorDacl(localMem, 1'i32,
addr res.acldata[0], 0'i32) == 0: res.acldata.getAddr(), 0'i32) == 0:
err(ioLastError()) let errorCode = ioLastError()
discard free(dacl)
discard localFree(localMem)
err(errorCode)
else: else:
ok(res) ok(res)
@ -269,11 +335,11 @@ proc createFilesUserOnlySecurityDescriptor*(): IoResult[SD] {.inline.} =
proc isEmpty*(sd: SD): bool = proc isEmpty*(sd: SD): bool =
## Returns ``true`` is security descriptor ``sd`` is not initialized. ## Returns ``true`` is security descriptor ``sd`` is not initialized.
(len(sd.sddata) == 0) or (len(sd.acldata) == 0) isNil(sd.sddata.getAddr()) or isNil(sd.acldata.getAddr())
template getDescriptor*(sd: SD): pointer = template getDescriptor*(sd: SD): pointer =
## Returns pointer to Windows specific security descriptor. ## Returns pointer to Windows specific security descriptor.
cast[pointer](unsafeAddr sd.sddata[0]) sd.sddata.getAddr()
proc checkCurrentUserOnlyACL*(path: string): IoResult[bool] = proc checkCurrentUserOnlyACL*(path: string): IoResult[bool] =
## Check if specified file or folder ``path`` can be accessed and modified ## Check if specified file or folder ``path`` can be accessed and modified

View File

@ -768,3 +768,27 @@ suite "OS Input/Output procedures test suite":
isDir(destDir) == false isDir(destDir) == false
isDir(firstDir) == false isDir(firstDir) == false
isFile(destFile) == false isFile(destFile) == false
test "getHomePath() test":
let res = getHomePath()
check:
res.isOk()
len(res.get()) > 0
test "getConfigPath() test":
let res = getConfigPath()
check:
res.isOk()
len(res.get()) > 0
test "getCachePath() test":
let res = getCachePath()
check:
res.isOk()
len(res.get()) > 0
test "getTempPath() test":
let res = getTempPath()
check:
res.isOk()
len(res.get()) > 0

View File

@ -10,6 +10,7 @@ import unittest2
when defined(windows): when defined(windows):
import ../stew/windows/acl import ../stew/windows/acl
const TestsCount = 50
suite "Windows security descriptor tests suite": suite "Windows security descriptor tests suite":
test "File/Folder user-only ACL create/verify test": test "File/Folder user-only ACL create/verify test":
@ -35,6 +36,8 @@ suite "Windows security descriptor tests suite":
? removeDir(path3) ? removeDir(path3)
? removeFile(path2) ? removeFile(path2)
? removeDir(path1) ? removeDir(path1)
free(sdd)
free(sdf)
if res1 and res2 and res3 and res4: if res1 and res2 and res3 and res4:
ok(true) ok(true)
else: else:
@ -43,3 +46,24 @@ suite "Windows security descriptor tests suite":
performTest("testblob14", "testblob15").isOk() performTest("testblob14", "testblob15").isOk()
else: else:
skip() skip()
test "Create/Verify multiple folders in user/config home directory":
when defined(windows):
proc performTest(directory: string): IoResult[void] =
var sdd = ? createFoldersUserOnlySecurityDescriptor()
var results = newSeq[bool](TestsCount)
for i in 0 ..< TestsCount:
let path = directory & "\\" & "ACLTEST" & $i
? createPath(path, secDescriptor = sdd.getDescriptor())
results[i] = ? checkCurrentUserOnlyACL(path)
? removeDir(path)
free(sdd)
for chk in results:
if not(chk):
return err(IoErrorCode(UserErrorCode))
ok()
check:
performTest(getHomePath().get()).isOk()
performTest(getConfigPath().get()).isOk()
else:
skip()