From bb086e69da967ad235ed6c31247769e75b318e61 Mon Sep 17 00:00:00 2001 From: Eugene Kabanov Date: Mon, 17 Jun 2024 10:04:14 +0300 Subject: [PATCH] Fix Windows ACL flakiness issue (Windows error 1336). (#221) * Add getHomePath(), getConfigPath() and getCachePath() implementations. Fix ACL flakiness issue. Add tests. * Add getTempPath(). Normalize path endings for all xxPath() functions. * Fix 2.0/devel compilation errors. --- stew/io2.nim | 109 ++++++++++++++++++++++++ stew/windows/acl.nim | 192 ++++++++++++++++++++++++++++-------------- tests/test_io2.nim | 24 ++++++ tests/test_winacl.nim | 24 ++++++ 4 files changed, 286 insertions(+), 63 deletions(-) diff --git a/stew/io2.nim b/stew/io2.nim index be5be90..6cbc135 100644 --- a/stew/io2.nim +++ b/stew/io2.nim @@ -55,6 +55,13 @@ when defined(windows): FileBasicInfoClass = 0'u32 + CSIDL_APPDATA = 0x001a'u32 + # \Application Data + CSIDL_PROFILE = 0x0028'u32 + # + CSIDL_LOCAL_APPDATA = 0x001c'u32 + # \Local Settings\Applicaiton Data (non roaming) + type IoErrorCode* = distinct uint32 IoHandle* = distinct uint @@ -184,6 +191,12 @@ when defined(windows): nNumberOfBytesToLockLow, nNumberOfBytesToLockHigh: uint32, lpOverlapped: pointer): uint32 {. importc: "UnlockFileEx", dynlib: "kernel32", stdcall, sideEffect.} + proc shGetSpecialFolderPathW(hwnd: uint, pszPath: WideCString, csidl: uint32, + fCreate: uint32):uint32 {. + importc: "SHGetSpecialFolderPathW", dynlib: "shell32", stdcall, + sideEffect.} + proc getTempPathW(nBufferLength: uint32, lpBuffer: WideCString): uint32 {. + importc: "GetTempPathW", dynlib: "kernel32", stdcall, sideEffect.} const NO_ERROR = IoErrorCode(0) @@ -245,6 +258,8 @@ elif defined(posix): var errno {.importc, header: "".}: cint + proc c_getenv(env: cstring): cstring {. + importc: "getenv", header: "", sideEffect.} proc write(a1: cint, a2: pointer, a3: csize_t): int {. importc, header: "", sideEffect.} proc read(a1: cint, a2: pointer, a3: csize_t): int {. @@ -1547,3 +1562,97 @@ proc unlockFile*(lock: IoLockHandle): IoResult[void] = err(res.error()) else: ok() + +when defined(windows): + proc getSpecialFolderPath(code: uint32): IoResult[string] = + var path: array[MAX_PATH, Utf16Char] + let + wpath = cast[WideCString](addr path[0]) + res = shGetSpecialFolderPathW(0'u, cast[WideCString](addr path[0]), + code, 0'u32) + if res == 0'u32: + err(ioLastError()) + else: + var strpath = `$`(wpath, len(path)) + normPathEnd(strpath, true) + ok(strpath) + +proc getHomePath*(): IoResult[string] = + ## Returns path to user's home directory. + when defined(windows): + getSpecialFolderPath(CSIDL_PROFILE) + else: + let res = c_getenv("HOME") + var path = if isNil(res): "" else: $res + normPathEnd(path, true) + ok(path) + +proc getConfigPath*(): IoResult[string] = + ## Returns path to application's configuration directory. + when defined(windows): + getSpecialFolderPath(CSIDL_APPDATA) + else: + let + subpath = ".config" + xres = c_getenv("XDG_CONFIG_HOME") + var + path = + if isNil(xres): + let hres = c_getenv("HOME") + if isNil(hres): subpath else: $hres & DirSep & subpath + else: + $xres + normPathEnd(path, true) + ok(path) + +proc getCachePath*(): IoResult[string] = + ## Returns path to application's cache directory. + when defined(windows): + getSpecialFolderPath(CSIDL_LOCAL_APPDATA) + else: + let subpath = + when defined(macos) or defined(macosx) or defined(osx): + "Library/Caches" + else: + ".cache" + let + xres = c_getenv("XDG_CACHE_HOME") + var + path = + if isNil(xres): + let hres = c_getenv("HOME") + if isNil(hres): subpath else: $hres & DirSep & subpath + else: + $xres + normPathEnd(path, true) + ok(path) + +proc getTempPath*(): IoResult[string] = + ## Returns path to OS temporary directory. + when defined(windows): + var path: array[MAX_PATH + 1, Utf16Char] + let + wpath = cast[WideCString](addr path[0]) + res = getTempPathW(uint32(MAX_PATH), wpath) + if res == 0'u32: + err(ioLastError()) + else: + var strpath = `$`(wpath, len(path)) + normPathEnd(strpath, true) + ok(strpath) + else: + for name in ["TMP", "TEMP", "TMPDIR", "TEMPDIR"]: + let res = c_getenv(cstring(name)) + if not(isNil(res)) and isDir($res): + var path = $res + normPathEnd(path, true) + return ok(path) + var defaultDir = + when defined(android): + "/data/local/tmp" + else: + "/tmp" + if isDir(defaultDir): + normPathEnd(defaultDir, true) + return ok(defaultDir) + err(IoErrorCode(2)) # ENOENT diff --git a/stew/windows/acl.nim b/stew/windows/acl.nim index 096e49f..a3615b6 100644 --- a/stew/windows/acl.nim +++ b/stew/windows/acl.nim @@ -33,8 +33,11 @@ const SECURITY_DESCRIPTOR_REVISION = 1'u32 ACCESS_ALLOWED_ACE_TYPE = 0x00'u8 SE_DACL_PROTECTED = 0x1000'u16 + LPTR = 0x0040'u32 type + LocalMemPtr = distinct pointer + ACL {.pure, final.} = object aclRevision: uint8 sbz1: uint8 @@ -45,11 +48,11 @@ type PACL* = ptr ACL SID* = object - data: seq[byte] + data: LocalMemPtr SD* = object - sddata: seq[byte] - acldata: seq[byte] + sddata: LocalMemPtr + acldata: LocalMemPtr SID_AND_ATTRIBUTES {.pure, final.} = object sid: pointer @@ -73,6 +76,8 @@ type proc closeHandle(hobj: uint): int32 {. importc: "CloseHandle", dynlib: "kernel32", stdcall, sideEffect.} +proc localAlloc(uFlags: uint32, ubytes: uint): pointer {. + importc: "LocalAlloc", stdcall, dynlib: "kernel32".} proc localFree(p: pointer): uint {. importc: "LocalFree", stdcall, dynlib: "kernel32".} proc getCurrentProcess(): uint {. @@ -128,33 +133,53 @@ proc setSecurityDescriptorControl(pSD: pointer, bitsOfInterest: uint16, importc: "SetSecurityDescriptorControl", dynlib: "advapi32", stdcall, sideEffect.} -proc len*(sid: SID): int = len(sid.data) +proc len*(sid: SID): int = + int(getLengthSid(cast[pointer](sid.data))) + +proc free(mem: LocalMemPtr): uint = + localFree(cast[pointer](mem)) + +proc free*(sd: var SD) = + ## Free memory occupied by security descriptor. + discard sd.sddata.free() + discard sd.acldata.free() + sd.sddata = LocalMemPtr(nil) + sd.acldata = LocalMemPtr(nil) + +proc free*(sid: var SID) = + ## Free memory occupied by security identifier. + discard sid.data.free() + sid.data = LocalMemPtr(nil) proc getTokenInformation(token: uint, - information: uint32): IoResult[seq[byte]] = - var tlength: uint32 - var buffer = newSeq[byte]() + information: uint32): IoResult[LocalMemPtr] = + var + tlength: uint32 = 0'u32 + localMem: pointer while true: let res = - if len(buffer) == 0: + if tlength == 0'u32: getTokenInformation(token, information, nil, 0, tlength) else: - getTokenInformation(token, information, cast[pointer](addr buffer[0]), - uint32(len(buffer)), tlength) + getTokenInformation(token, information, localMem, tlength, tlength) if res != 0: - return ok(buffer) + return ok(LocalMemPtr(localMem)) else: - let errCode = ioLastError() - if errCode == ERROR_INSUFFICIENT_BUFFER: + let errorCode = ioLastError() + if errorCode == ERROR_INSUFFICIENT_BUFFER: when sizeof(int) == 8: - buffer.setLen(int(tlength)) + localMem = localAlloc(LPTR, uint(tlength)) + if isNil(localMem): + return err(ioLastError()) elif sizeof(int) == 4: if tlength > uint32(high(int)): - return err(errCode) + return err(errorCode) else: - buffer.setLen(int(tlength)) + localMem = localAlloc(LPTR, uint(tlength)) + if isNil(localMem): + return err(ioLastError()) else: - return err(errCode) + return err(errorCode) proc getCurrentUserSid*(): IoResult[SID] = ## Returns current process user's security identifier (SID). @@ -163,45 +188,66 @@ proc getCurrentUserSid*(): IoResult[SID] = if ores == 0: err(ioLastError()) else: - let tres = getTokenInformation(token, 1'u32) - if tres.isErr(): + let localMem = getTokenInformation(token, 1'u32).valueOr: discard closeHandle(token) - err(tres.error) - else: - var buffer = tres.get() - var utoken = cast[ptr TOKEN_USER](addr buffer[0]) - let psid = utoken[].user.sid - if isValidSid(psid) != 0: - var ssid = newSeq[byte](getLengthSid(psid)) - if copySid(uint32(len(ssid)), addr ssid[0], psid) != 0: - if closeHandle(token) != 0: - ok(SID(data: ssid)) + return err(error) + var utoken = cast[ptr TOKEN_USER](localMem) + let psid = utoken[].user.sid + if isValidSid(psid) != 0: + let length = getLengthSid(psid) + var ssid = localAlloc(LPTR, length) + if isNil(ssid): + return err(ioLastError()) + if copySid(uint32(length), ssid, psid) != 0: + if closeHandle(token) != 0: + if free(localMem) != 0'u: + let errorCode = ioLastError() + discard localFree(ssid) + err(errorCode) else: - err(ioLastError()) + ok(SID(data: LocalMemPtr(ssid))) else: - let errCode = ioLastError() - discard closeHandle(token) - err(errCode) + let errorCode = ioLastError() + discard localFree(ssid) + err(errorCode) else: - let errCode = ioLastError() + let errorCode = ioLastError() discard closeHandle(token) - err(errCode) + discard free(localMem) + discard localFree(ssid) + err(errorCode) + else: + let errorCode = ioLastError() + discard closeHandle(token) + discard free(localMem) + err(errorCode) template getAddr*(sid: SID): pointer = ## Obtain Windows specific SID pointer. - unsafeAddr sid.data[0] + cast[pointer](sid.data) -proc createCurrentUserOnlyAcl(kind: SecDescriptorKind): IoResult[seq[byte]] = +template getAddr*(mem: LocalMemPtr): pointer = + cast[pointer](mem) + +proc createCurrentUserOnlyAcl(kind: SecDescriptorKind): IoResult[LocalMemPtr] = let aceMask = FILE_ALL_ACCESS var userSid = ? getCurrentUserSid() let size = - ((sizeof(ACL) + sizeof(ACCESS_ALLOWED_ACE) + len(userSid)) + - (sizeof(uint32) - 1)) and 0xFFFF_FFFC + (uint32(sizeof(ACL) + sizeof(ACCESS_ALLOWED_ACE) + len(userSid)) + + uint32(sizeof(uint32) - 1)) and 0xFFFF_FFFC'u32 - var buffer = newSeq[byte](size) - var pdacl = cast[PACL](addr buffer[0]) + var localMem = localAlloc(LPTR, uint(size)) + if isNil(localMem): + let errorCode = ioLastError() + free(userSid) + return err(errorCode) + + var pdacl = cast[PACL](localMem) if initializeAcl(pdacl, uint32(size), ACL_REVISION) == 0: - err(ioLastError()) + let errorCode = ioLastError() + discard localFree(localMem) + free(userSid) + err(errorCode) else: let aceFlags = if kind == Folder: @@ -210,9 +256,12 @@ proc createCurrentUserOnlyAcl(kind: SecDescriptorKind): IoResult[seq[byte]] = 0'u32 if addAccessAllowedAceEx(pdacl, ACL_REVISION, aceFlags, aceMask, userSid.getAddr()) == 0: - err(ioLastError()) + let errorCode = ioLastError() + discard localFree(localMem) + free userSid + err(errorCode) else: - ok(buffer) + ok(LocalMemPtr(localMem)) proc setCurrentUserOnlyAccess*(path: string): IoResult[void] = ## Set file or folder with path ``path`` to be accessed only by current @@ -227,33 +276,50 @@ proc setCurrentUserOnlyAccess*(path: string): IoResult[void] = else: File - var buffer = ? createCurrentUserOnlyAcl(descriptorKind) - var pdacl = cast[PACL](addr buffer[0]) + let + pacl = ? createCurrentUserOnlyAcl(descriptorKind) + pdacl = cast[PACL](pacl) + dflags = DACL_SECURITY_INFORMATION or + PROTECTED_DACL_SECURITY_INFORMATION + sres = setNamedSecurityInfo(newWideCString(path), SE_FILE_OBJECT, + dflags, nil, nil, pdacl, nil) - let dflags = DACL_SECURITY_INFORMATION or - PROTECTED_DACL_SECURITY_INFORMATION - let sres = setNamedSecurityInfo(newWideCString(path), SE_FILE_OBJECT, - dflags, nil, nil, pdacl, nil) + if free(pacl) != 0'u: + return err(ioLastError()) if sres != ERROR_SUCCESS: err(IoErrorCode(sres)) else: ok() proc createUserOnlySecurityDescriptor(kind: SecDescriptorKind): IoResult[SD] = - var dacl = ? createCurrentUserOnlyAcl(kind) - var buffer = newSeq[byte](SECURITY_DESCRIPTOR_MIN_LENGTH) - if initializeSecurityDescriptor(addr buffer[0], - SECURITY_DESCRIPTOR_REVISION) == 0: - err(ioLastError()) + let + dacl = ? createCurrentUserOnlyAcl(kind) + localMem = localAlloc(LPTR, SECURITY_DESCRIPTOR_MIN_LENGTH) + + if isNil(localMem): + discard free(dacl) + return err(ioLastError()) + + if initializeSecurityDescriptor(localMem, SECURITY_DESCRIPTOR_REVISION) == 0: + let errorCode = ioLastError() + discard free(dacl) + discard localFree(localMem) + err(errorCode) else: - var res = SD(sddata: buffer, acldata: dacl) + var res = SD(sddata: cast[LocalMemPtr](localMem), acldata: dacl) let bits = SE_DACL_PROTECTED - if setSecurityDescriptorControl(addr res.sddata[0], bits, bits) == 0: - err(ioLastError()) + if setSecurityDescriptorControl(localMem, bits, bits) == 0: + let errorCode = ioLastError() + discard free(dacl) + discard localFree(localMem) + err(errorCode) else: - if setSecurityDescriptorDacl(addr res.sddata[0], 1'i32, - addr res.acldata[0], 0'i32) == 0: - err(ioLastError()) + if setSecurityDescriptorDacl(localMem, 1'i32, + res.acldata.getAddr(), 0'i32) == 0: + let errorCode = ioLastError() + discard free(dacl) + discard localFree(localMem) + err(errorCode) else: ok(res) @@ -269,11 +335,11 @@ proc createFilesUserOnlySecurityDescriptor*(): IoResult[SD] {.inline.} = proc isEmpty*(sd: SD): bool = ## Returns ``true`` is security descriptor ``sd`` is not initialized. - (len(sd.sddata) == 0) or (len(sd.acldata) == 0) + isNil(sd.sddata.getAddr()) or isNil(sd.acldata.getAddr()) template getDescriptor*(sd: SD): pointer = ## Returns pointer to Windows specific security descriptor. - cast[pointer](unsafeAddr sd.sddata[0]) + sd.sddata.getAddr() proc checkCurrentUserOnlyACL*(path: string): IoResult[bool] = ## Check if specified file or folder ``path`` can be accessed and modified diff --git a/tests/test_io2.nim b/tests/test_io2.nim index e56d052..29bad78 100644 --- a/tests/test_io2.nim +++ b/tests/test_io2.nim @@ -768,3 +768,27 @@ suite "OS Input/Output procedures test suite": isDir(destDir) == false isDir(firstDir) == false isFile(destFile) == false + + test "getHomePath() test": + let res = getHomePath() + check: + res.isOk() + len(res.get()) > 0 + + test "getConfigPath() test": + let res = getConfigPath() + check: + res.isOk() + len(res.get()) > 0 + + test "getCachePath() test": + let res = getCachePath() + check: + res.isOk() + len(res.get()) > 0 + + test "getTempPath() test": + let res = getTempPath() + check: + res.isOk() + len(res.get()) > 0 diff --git a/tests/test_winacl.nim b/tests/test_winacl.nim index 612a758..21fc255 100644 --- a/tests/test_winacl.nim +++ b/tests/test_winacl.nim @@ -10,6 +10,7 @@ import unittest2 when defined(windows): import ../stew/windows/acl + const TestsCount = 50 suite "Windows security descriptor tests suite": test "File/Folder user-only ACL create/verify test": @@ -35,6 +36,8 @@ suite "Windows security descriptor tests suite": ? removeDir(path3) ? removeFile(path2) ? removeDir(path1) + free(sdd) + free(sdf) if res1 and res2 and res3 and res4: ok(true) else: @@ -43,3 +46,24 @@ suite "Windows security descriptor tests suite": performTest("testblob14", "testblob15").isOk() else: skip() + + test "Create/Verify multiple folders in user/config home directory": + when defined(windows): + proc performTest(directory: string): IoResult[void] = + var sdd = ? createFoldersUserOnlySecurityDescriptor() + var results = newSeq[bool](TestsCount) + for i in 0 ..< TestsCount: + let path = directory & "\\" & "ACLTEST" & $i + ? createPath(path, secDescriptor = sdd.getDescriptor()) + results[i] = ? checkCurrentUserOnlyACL(path) + ? removeDir(path) + free(sdd) + for chk in results: + if not(chk): + return err(IoErrorCode(UserErrorCode)) + ok() + check: + performTest(getHomePath().get()).isOk() + performTest(getConfigPath().get()).isOk() + else: + skip()